In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [3]:
# Load the entire training corpus into memory as one string.
with open("data.txt", mode="r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

In [4]:
# Lower-case the corpus and keep only the lines that contain something other
# than whitespace (kept lines are not stripped themselves).
# `text` is deliberately left as the raw string: rebinding it to a list here
# would make this cell fail with AttributeError when it is run a second time.
cleaned_text = [line for line in text.lower().split("\n") if line.strip()]



In [None]:
# Number of non-blank lines kept from the corpus.
len(cleaned_text)


In [None]:
# Token -> integer id mapping, seeded with the "UNK" placeholder at id 0.
# Id 0 is also what text_to_sequence emits for unknown words and what
# padding() uses as filler.
vocab={"UNK":0}


In [None]:
import nltk
# Fetch the "punkt" and "punkt_tab" NLTK resources; presumably these are what
# nltk.word_tokenize (used in later cells) needs — confirm against the
# installed NLTK version.
nltk.download("punkt")
nltk.download("punkt_tab")

In [None]:
import re


In [None]:
# Populate `vocab`: every previously unseen token receives the next free
# integer id, in order of first appearance across the corpus.
for raw_line in cleaned_text:
    unquoted = raw_line.replace('"', "")
    # Surround punctuation with spaces so each mark becomes its own token,
    # then squeeze the resulting whitespace runs down to single spaces.
    spaced = re.sub(r"([,.:;!?\"'])", r" \1 ", unquoted)
    normalised = re.sub(r"\s+", " ", spaced).strip()
    for token in nltk.word_tokenize(normalised):
        # len(vocab) is evaluated before the insert, and setdefault leaves
        # tokens that already have an id untouched.
        vocab.setdefault(token, len(vocab))

In [None]:
# Vocabulary size, including the "UNK" entry.
len(vocab)


In [None]:
# Preview the first (token, id) pairs rather than dumping the entire mapping
# into the notebook output.
list(vocab.items())[:20]

In [None]:
def text_to_sequence(sentence,vocab):
    """Encode a raw sentence as a list of vocabulary ids.

    The sentence is lower-cased, stripped of double quotes, has its
    punctuation surrounded by spaces, is whitespace-normalised and is then
    split with ``nltk.word_tokenize``.

    Args:
        sentence: Text to encode.
        vocab: Mapping from token to integer id.

    Returns:
        One integer per token, in order; tokens missing from ``vocab`` map to 0.
    """
    normalised = sentence.lower().replace('"', "")
    normalised = re.sub(r"([,.:;!?\"'])", r" \1 ", normalised)
    normalised = re.sub(r"\s+", " ", normalised).strip()
    tokens = nltk.word_tokenize(normalised)
    return [vocab.get(token, 0) for token in tokens]

In [None]:
# Encode every cleaned corpus line as a list of vocabulary ids.
sequence = [text_to_sequence(line, vocab) for line in cleaned_text]

In [None]:
# Show only the first few encoded lines instead of the whole corpus.
sequence[:5]


In [None]:
# Training samples: for every encoded sentence, each leading slice of length
# 2 .. len(sentence) - 1. Sentences with fewer than three ids contribute nothing.
# NOTE(review): the slices stop one token short of the full sentence, so a
# sentence's final token is never the last element of a sample — confirm this
# is intended before changing it (the pad length used later is hard-coded).
input = [
    encoded[:end]
    for encoded in sequence
    for end in range(2, len(encoded))
]


In [None]:
# Show only the first few training prefixes instead of the whole list.
input[:5]


In [None]:
import numpy as np


In [None]:
def padding(sentence,maxlen):
    """Return ``sentence`` as a new list of exactly ``maxlen`` ids.

    Shorter inputs are left-padded with zeros. Longer inputs keep only their
    last ``maxlen`` ids, so the result never exceeds ``maxlen`` (previously an
    over-long input was returned at its full length).

    Args:
        sentence: List of integer token ids.
        maxlen: Desired output length.

    Returns:
        A new list; ``sentence`` itself is not modified.
    """
    overflow = len(sentence) - maxlen
    if overflow > 0:
        # Slice from the front so the most recent ids survive; using the
        # positive offset (not sentence[-maxlen:]) also handles maxlen == 0.
        return sentence[overflow:]
    return [0] * (-overflow) + sentence

In [None]:
# Left-pad every training prefix with zeros to a common length of 28.
padded_input = [padding(prefix, 28) for prefix in input]

In [None]:
# Spot-check the length of one padded sample (index 10).
len(padded_input[10])

In [None]:
# Convert the list of padded id lists into a torch.long tensor; note this
# rebinds `padded_input` from a Python list to a tensor.
padded_input=torch.tensor(padded_input,dtype=torch.long)

In [None]:
# Split each padded row: every column except the last is the model input,
# the last column is the label to predict.
X, y = padded_input[:, :-1], padded_input[:, -1]
for part in (X, y):
    print(part)

In [None]:
class CustomDataset(Dataset):
    """Pairs each row of ``input`` with the entry of ``output`` at the same index."""

    def __init__(self,input,output):
        """Keep references to the feature and label containers (no copy is made)."""
        self.input, self.output = input, output

    def __len__(self):
        """Return the size of the first dimension of ``input``."""
        n_samples = self.input.shape[0]
        return n_samples

    def __getitem__(self,index):
        """Return the ``(input, output)`` pair stored at ``index``."""
        features = self.input[index]
        label = self.output[index]
        return features, label

In [None]:
# Wrap the input tensor X and label tensor y so a DataLoader can index them.
dataset=CustomDataset(X,y)


In [None]:
# Batches of 32 samples, reshuffled each epoch. pin_memory=True asks the
# loader to return batches in page-locked host memory, which is intended to
# speed up host-to-GPU copies.
dataloader=DataLoader(dataset,batch_size=32,shuffle=True,pin_memory=True)


In [None]:
class my_nn(nn.Module):
    """Next-token classifier: embedding -> GRU -> dropout -> linear layer.

    Args:
        vocab_length: Number of rows in the embedding table and number of
            output scores produced per sequence.
        embedding_dim: Size of each token embedding (default 300).
        hidden_size: Size of the GRU hidden state (default 512).
        dropout: Dropout probability applied to the final hidden state
            (default 0.3).
    """

    def __init__(self,vocab_length,embedding_dim=300,hidden_size=512,dropout=0.3):
        super().__init__()
        self.embedding=nn.Embedding(vocab_length,embedding_dim)
        self.gru=nn.GRU(embedding_dim,hidden_size,batch_first=True)
        self.dropout=nn.Dropout(p=dropout)
        self.output=nn.Linear(hidden_size,vocab_length)

    def forward(self,input):
        """Return unnormalised vocabulary scores for each input sequence.

        Args:
            input: Integer tensor of token ids laid out as (batch, seq_len),
                matching ``batch_first=True`` on the GRU.

        Returns:
            Float tensor of shape (batch, vocab_length).
        """
        embedded=self.embedding(input)
        # The per-step outputs are not needed; only the hidden state after the
        # last time step is used. [-1] selects the last GRU layer.
        _,final_hidden=self.gru(embedded)
        return self.output(self.dropout(final_hidden[-1]))

In [None]:
# Training hyperparameters: `learning_rate` is passed to the Adam optimiser,
# `epochs` is the number of full passes over the dataloader.
learning_rate=0.001
epochs=30


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so this
# cell does not raise on machines without CUDA.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=my_nn(len(vocab)).to(device)
# CrossEntropyLoss takes the raw scores from the model plus integer class labels.
loss_func=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)

In [None]:
# Training loop: one optimiser step per mini-batch, mean batch loss printed
# once per epoch.
# Put the model in training mode explicitly: if this cell is re-run after the
# model has been switched to eval mode, dropout would otherwise stay disabled.
model.train()
# Move batches to wherever the model's parameters live rather than assuming CUDA.
train_device=next(model.parameters()).device
for epoch in range(epochs):
    total_loss=0.0
    for batch_features,batch_labels in dataloader:
        batch_features=batch_features.to(train_device)
        batch_labels=batch_labels.to(train_device)
        ypred=model(batch_features)
        loss=loss_func(ypred,batch_labels)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()
    print(f"Epoch {epoch+1} loss is : {(total_loss/len(dataloader)):.4f}")

In [None]:
# Reverse lookup (integer id -> token), used to turn a predicted id back into a word.
inverse_vocab = dict(zip(vocab.values(), vocab.keys()))


In [None]:
# Switch the model to evaluation mode (dropout becomes a no-op) before predicting.
model.eval()


In [None]:
def prediction(model,vocab_length,text):
    """Predict the single most likely next token for ``text``.

    Relies on the notebook-level ``vocab`` / ``inverse_vocab`` mappings and on
    ``text_to_sequence`` / ``padding``.

    Args:
        model: Network that maps a (1, seq_len) tensor of token ids to a
            (1, vocab) tensor of scores.
        vocab_length: Unused; kept so existing calls keep working.
        text: Prompt to continue.

    Returns:
        The token whose score is highest.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    device=next(model.parameters()).device
    sequence=text_to_sequence(text,vocab)
    padded=torch.tensor(padding(sequence,27),dtype=torch.long,device=device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        ypred=model(padded.unsqueeze(0))
    # softmax is monotonic, so the arg-max of the raw scores is the same index
    # the arg-max of the probabilities would give.
    index=torch.argmax(ypred,dim=1)
    return inverse_vocab[index.item()]

In [None]:
text="To Sherlock Holmes she"
# Expected continuation (presumably the matching line of the corpus):
# To Sherlock Holmes she is always _the_ woman. I have seldom heard him
# Greedy generation: append the predicted word eight times, printing the
# growing text after every step.
for step in range(8):
    next_word=prediction(model,len(vocab),text)
    text=f"{text} {next_word}"
    print(text)