In [17]:
from datasets import load_dataset
from gensim.utils import tokenize
from tqdm import tqdm
from textblob import TextBlob
from collections import Counter
from torch.optim import AdamW
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
data=load_dataset('imdb')
device='cuda'
from torch.nn import Module
from torch import nn 
epochs=5

In [18]:
all_sentences=[]
sent_threshold=32

for text_block in tqdm(data['train']['text']):
    for sentences in TextBlob(text_block).sentences:
        len_of_sent=sentences.words 
        for sent in sentences.split('.<br /><br />'):
            if len(len_of_sent)<sent_threshold:
                all_sentences.append(sent)
    
    


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25000/25000 [00:26<00:00, 960.39it/s]


In [19]:
words=[]
for sent in tqdm(all_sentences):
    for word in tokenize(sent):
        words.append(word.lower())

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 227434/227434 [00:01<00:00, 124312.33it/s]


In [20]:
cntr=Counter(words)

In [21]:
vocab_size=40000
our_words=set(['<unk>','<bos>','<eos>','<pad>'])
for word,cnt in cntr.most_common()[0:vocab_size]:
    our_words.add(word)

In [22]:
len(our_words)

40004

In [23]:
for i in tokenize(all_sentences[5]):
    print(i)

I
Am
Curious
Yellow
is
a
risible
and
pretentious
steaming
pile


In [24]:
word2int={word:cnt for cnt,word in enumerate(our_words)}
int2word={cnt:word for cnt,word in enumerate(our_words)}

In [25]:
class word_dataset(Dataset):
    def __init__(self,data):
        super(word_dataset,self).__init__()
        self.data=data

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        selected_sent=self.data[idx]
        indexed_text=[word2int['<bos>']]
        main_text=[word2int[word] if word in word2int else word2int['<unk>'] for word in tokenize(selected_sent.lower())]
        indexed_text=indexed_text+main_text+[word2int['<eos>']]
        return indexed_text
        
        
        

In [26]:
def make_batch(batch):
    lenghts=(len(x) for x in batch)
    max_len=max(lenghts)

    new_batch=[]
    for sent in batch:
        for pad in range(max_len-len(sent)):
            sent.append(word2int['<pad>'])
        new_batch.append(sent)

    new_batch=torch.LongTensor(new_batch).to(device)

    

    return {
        'inp':new_batch[:,:-1],
        'label':new_batch[:,1:]
    }
    

In [27]:
train_data=word_dataset(all_sentences)
train_dataloader=DataLoader(train_data,batch_size=32,shuffle=True,collate_fn=make_batch)

In [28]:
class LanguageModel(Module):
    def __init__(self,vocab_size,hidden_size):
        super().__init__()

        self.emb=nn.Embedding(vocab_size,hidden_size)
        self.gru=nn.GRU(hidden_size,hidden_size,num_layers=3,batch_first=True)

        self.lay_with_drop=nn.Sequential(
            nn.Tanh(),
            nn.Linear(hidden_size,hidden_size),
            nn.Dropout(),
            nn.Tanh()
        )

        self.final_lin=nn.Sequential(
            nn.Linear(hidden_size,vocab_size)
        )

        
    def forward(self,text):
        
        emb_x=self.emb(text)
        
    
        x,_=self.gru(emb_x)
        

        #agr_x=x.mean(dim=1)#???????
        

        x=self.lay_with_drop(x)
        x=self.final_lin(x)
        return x
        

        

In [29]:
model=LanguageModel(len(our_words),128).to(device)
loss_fn=nn.CrossEntropyLoss(ignore_index=word2int['<pad>'])
optimizer=AdamW(model.parameters())

In [30]:
len(our_words)

40004

In [35]:
def train(epochs,model,loss_fn,optimizer,dataloader):
    for epoch in range(epochs):
        for batch in (pbar:=tqdm(dataloader)):
            optimizer.zero_grad()
            pred=model(batch['inp']).flatten(start_dim=0,end_dim=1)
            
            loss=loss_fn(pred,batch['label'].flatten())
            loss_item=loss.item()
            loss.backward()
            optimizer.step()
            pbar.set_description(f'loss:{loss_item}')

In [63]:
train(epochs,model,loss_fn,optimizer,train_dataloader)

loss:5.442234516143799: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7108/7108 [02:07<00:00, 55.81it/s]
loss:5.241026401519775: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7108/7108 [02:12<00:00, 53.48it/s]
loss:5.638738632202148: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7108/7108 [02:15<00:00, 52.55it/s]
loss:5.652792930603027: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7108/7108 [02:14<00:00, 52.79it/s]
loss:5.26169490814209: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [68]:
def check_the_acc(model,starting_seq,max_seq_len=128):
    model=model.to(device)
    input_ids=[word2int['<bos>']] +[word2int[word] if word in word2int else word2int['<unk>'] for word in starting_seq.split()]

    input_ids=torch.LongTensor(input_ids).to(device)

    with torch.no_grad():
        for i in range(max_seq_len):
            new_word_dist=model(input_ids[-1].unsqueeze(dim=0))
            new_word=new_word_dist.squeeze().argmax()
            input_ids=torch.cat([input_ids,new_word.unsqueeze(0)])
        return input_ids

In [81]:
a=check_the_acc(model,'This is what I',5)

In [82]:
b=[int2word[i.item()] for i in a]

In [83]:
b

['<bos>', '<unk>', 'is', 'what', '<unk>', '<eos>', 's', 'way', 'to', 'the']