In [1]:
import tqdm
import torch
import requests
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tokenizers import ByteLevelBPETokenizer


## Let's load the data

In [2]:
# Fetch the Project Gutenberg plain-text edition of Shakespeare's complete
# works and store it next to the notebook.
url = "https://www.gutenberg.org/files/100/100-0.txt"
# requests applies no timeout unless asked to, so an unresponsive server would
# otherwise block this cell forever.
response = requests.get(url, timeout=60)
if response.status_code == 200:
    with open("shakespeare_complete_works.txt", "wb") as f:
        f.write(response.content)
    print("Shakespeare's complete works have been downloaded successfully.")
else:
    # Best effort: a copy saved by an earlier run (if any) is still read below.
    print("Failed to download Shakespeare's complete works. Status code:", response.status_code)

# Read the corpus back as raw byte lines, keep only lines longer than 10 bytes
# (this drops blank lines and very short fragments), and decode them as UTF-8.
# Line terminators are kept as part of each line.
with open('shakespeare_complete_works.txt','rb') as f:
    text_data = [line.decode('utf-8') for line in f if len(line) > 10]

Shakespeare's complete works have been downloaded successfully.


## First, let's build our tokenizer


In [3]:
# Train a byte-level BPE tokenizer directly on the corpus lines.
tokenizer = ByteLevelBPETokenizer()
tokenizer.train_from_iterator(
    text_data,
    vocab_size=52_000,
    min_frequency=2,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
)
# Pad every encoding up to 32 ids. pad_id=1 presumes "<pad>" (the second
# special token above) was assigned id 1 -- TODO confirm with
# tokenizer.token_to_id("<pad>").
tokenizer.enable_padding(pad_id=1, pad_token='<pad>', length=32)
# Truncation is deliberately left disabled on the tokenizer itself:
# tokenizer.enable_truncation(max_length=32)

## Now the dataset

In [4]:
class TextDataset(Dataset):
    """Next-token-prediction dataset over a list of text lines.

    Each item wraps one line as ``'<s> ' + line + ' </s>'``, encodes it, keeps
    at most ``max_length`` ids and returns the sequence shifted by one position
    as input/target pair.

    Args:
        text_data: Indexable collection of strings, one training example each.
        text_tokenizer: Object with an ``encode(str)`` method whose result
            exposes an ``ids`` list. When ``None`` (the default) the
            notebook-level ``tokenizer`` is looked up at item-access time,
            which is the original behaviour.
        max_length: Maximum number of ids kept per example (default 32).
            Batching with the default collate function requires every example
            to end up with the same length, so this should not exceed the
            padding length configured on the tokenizer.

    Raises:
        ValueError: If ``max_length`` is smaller than 2 (no input/target pair
            could be formed).
    """

    def __init__(self, text_data, text_tokenizer=None, max_length=32):
        if max_length < 2:
            raise ValueError("max_length must be at least 2")
        self.text_data = text_data
        self.text_tokenizer = text_tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of text lines."""
        return len(self.text_data)

    def __getitem__(self, idx):
        """Return ``{'x': ids[:-1], 'y': ids[1:]}`` for the line at ``idx``."""
        # Resolve the fallback lazily so a tokenizer (re)built after the
        # dataset was created is still picked up.
        active_tokenizer = self.text_tokenizer if self.text_tokenizer is not None else tokenizer
        encoding = active_tokenizer.encode('<s> ' + self.text_data[idx] + ' </s>')
        ids = torch.tensor(encoding.ids)[:self.max_length]
        # Target is the input shifted left by one token.
        return {'x': ids[:-1], 'y': ids[1:]}

In [5]:
# Wrap the corpus lines in the dataset and serve them as shuffled mini-batches of 32 examples.
text_dataset = TextDataset(text_data)
train_loader = DataLoader(dataset=text_dataset, batch_size=32, shuffle=True)

## Now the model

In [6]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class RNN(nn.Module):
    """Single-step recurrent cell.

    One call to ``forward`` computes

        h_new = tanh(i2h(x) + h2h(h_prev))
        out   = h2o(h_new)

    where ``i2h`` is a bias-free linear layer and ``h2h`` / ``h2o`` are linear
    layers with bias. Unrolling over a sequence is left to the caller.

    Args:
        input_size: Size of the last dimension of ``x``.
        hidden_size: Size of the hidden state.
        output_size: Size of the per-step output.
    """

    def __init__(self, input_size=16, hidden_size=64, output_size=128):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_state):
        """Advance the cell by one step.

        Args:
            x: Input for this step; last dimension must be ``input_size``.
            hidden_state: Previous hidden state; last dimension must be
                ``hidden_size``.

        Returns:
            Tuple ``(output, new_hidden_state)``.
        """
        x = self.i2h(x)
        hidden_state = self.h2h(hidden_state)
        hidden_state = torch.tanh(x + hidden_state)
        return self.h2o(hidden_state), hidden_state

    def init_zero_hidden(self, batch_size=1):
        """Return an all-zero hidden state of shape ``(batch_size, hidden_size)``.

        The tensor is allocated on the same device (and with the same dtype) as
        the cell's own weights, so it always matches wherever the module has
        been moved instead of depending on a module-level ``device`` variable.
        """
        reference = self.h2h.weight
        return torch.zeros(
            batch_size,
            self.hidden_size,
            device=reference.device,
            dtype=reference.dtype,
            requires_grad=False,
        )
    
class LanguageModel(nn.Module):
    """Token-level language model: embedding -> unrolled RNN cell -> vocabulary projection.

    The vocabulary size is read from the notebook-level ``tokenizer`` when the
    model is constructed.

    Args:
        embedding_size: Dimension of the token embeddings (input size of the cell).
        hidden_size: Hidden-state size of the recurrent cell.
        output_size: Per-step output size of the cell, which is then projected
            onto the vocabulary.
        batch_size: Accepted but not used by the model.
    """

    def __init__(self,embedding_size=16, hidden_size=64, output_size=128,batch_size=32):
        super().__init__()
        vocab_size = tokenizer.get_vocab_size()
        self.num_vocabs = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.rnn_cell = RNN(embedding_size, hidden_size, output_size)
        self.output = nn.Linear(output_size, vocab_size)

    def forward(self, x):
        """Return per-position vocabulary scores for a batch of token ids.

        Args:
            x: Integer token ids; the first dimension is the batch and the
                second is the sequence position.

        Returns:
            Tensor of scores with the vocabulary as its last dimension, one
            score vector per input position.
        """
        batch = x.size(0)
        embedded = self.embedding(x)  # batch, length, embedding_size
        hidden = self.rnn_cell.init_zero_hidden(batch)

        # Unroll the cell over the sequence dimension, one position at a time,
        # carrying the hidden state forward.
        step_outputs = []
        for step_input in embedded.unbind(dim=1):
            step_out, hidden = self.rnn_cell(step_input.reshape(batch, -1), hidden)
            step_outputs.append(step_out)

        sequence_out = torch.stack(step_outputs, dim=1)
        return self.output(sequence_out)





In [7]:
# Build the language model with its default sizes, then move its parameters to `device`.
lm = LanguageModel()
lm = lm.to(device)

In [8]:
def predict_next_word(sentence):
    """Greedily predict the text of the token that follows ``sentence``.

    The sentence is encoded with the notebook-level ``tokenizer``, padding ids
    are removed, and the last 31 remaining ids are fed to the notebook-level
    ``lm``. The highest-scoring token at the final position is returned as text.

    Padding is stripped *before* truncating: the tokenizer pads every encoding
    to 32 ids, so cutting to the last 31 first would always discard the first
    real token of a short prompt.

    Args:
        sentence: Prompt text; may contain special tokens such as ``'<s>'``.

    Returns:
        The decoded text of the predicted token, or the literal string
        ``'</s>'`` when the end-of-sequence token is predicted (plain decoding
        drops special tokens, which would make end-of-sequence undetectable
        for callers).

    Raises:
        ValueError: If ``sentence`` encodes to no non-padding tokens.
    """
    pad_id = 1        # id passed to tokenizer.enable_padding(pad_id=1)
    max_context = 31  # longest context window used here

    token_ids = [tok for tok in tokenizer.encode(sentence).ids if tok != pad_id]
    if not token_ids:
        raise ValueError("sentence must contain at least one token")

    context = torch.tensor([token_ids[-max_context:]], device=device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        scores = lm(context)
    # Only the scores at the last real position are needed.
    next_id = int(scores[0, -1].argmax())

    if next_id == tokenizer.token_to_id('</s>'):
        return '</s>'
    return tokenizer.decode([next_id])
def generate_text(sentence,n_iter=100):
    """Extend ``sentence`` one predicted token at a time.

    Calls ``predict_next_word`` on the text generated so far and appends its
    result, for at most ``n_iter`` rounds. Generation ends early as soon as a
    round yields the string ``'</s>'`` (which is still appended).

    Args:
        sentence: Starting text.
        n_iter: Maximum number of tokens to append.

    Returns:
        The starting text followed by everything that was generated.
    """
    generated = sentence
    for _round in range(n_iter):
        piece = predict_next_word(generated)
        generated += piece
        if piece == '</s>':
            return generated
    return generated

In [9]:
# Baseline sample taken before any training has run, so the continuation is expected to be noise.
generate_text(sentence='<s> Hello')

'<s> Hellochester redbre jointedPropheobservance’!ImprisonedStoremember cinders kin By’!Intends stayed crutches stewardship VAU Squire Rutlandantages corners jealousy 13� dusty satisfy RICHARD whispering lyingest rescuedSIR RICHARD GAOLchester spritespierced imaginations speariwork hog June dwind flats lyingest rescuedSIR RICHARD GAOLchester spritespierced imaginations speariwork hog June dwind flats lyingest rescuedSIR RICHARD GAOLchester spritespierced imaginations speariwork hog June dwind flats lyingest rescuedSIR RICHARD GAOLchester spritespierced imaginations speariwork hog June dwind flats lyingest rescuedSIR RICHARD GAOLchester spritespierced imaginations speariwork'

In [10]:
def train(model,opt,loss_fn,train_loader,device='cpu',n_epochs=100,save=False):
    """Train ``model`` for ``n_epochs`` passes over ``train_loader``.

    After every epoch a sample produced by ``generate_text('<s> Hello')`` and
    the mean batch loss are printed. Note that ``generate_text`` samples from
    the notebook-level ``lm``, not from the ``model`` argument.

    Args:
        model: Module mapping a batch of token ids to per-position scores with
            the class dimension last.
        opt: Optimizer over ``model``'s parameters.
        loss_fn: Loss called as ``loss_fn(scores, targets)`` with the class
            dimension of ``scores`` moved to position 1.
        train_loader: Iterable of batches, each a dict with keys ``'x'``
            (inputs) and ``'y'`` (targets); must support ``len()``.
        device: Device the batches are moved to before the forward pass.
        n_epochs: Number of passes over ``train_loader``.
        save: When true, write ``model.state_dict()`` to ``best_model.pt``
            whenever an epoch's mean loss is lower than every earlier epoch's.

    Returns:
        None.
    """
    model.train()
    # Tracked across epochs: resetting this inside the loop would make every
    # epoch look like the best one and overwrite the checkpoint each time.
    best_loss = float('inf')
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0.0
        for batch in tqdm.tqdm(train_loader):
            x = batch['x'].to(device)
            y = batch['y'].to(device)
            opt.zero_grad()
            # The model emits (batch, length, classes); the loss is called with
            # the class dimension at position 1.
            scores = model(x).permute((0,2,1))
            loss = loss_fn(scores, y)
            loss.backward()
            opt.step()
            total_loss += loss.item()

        epoch_loss = total_loss / len(train_loader)
        print(generate_text('<s> Hello'))
        if save and epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), "best_model.pt")
        print(f"TRAIN: EPOCH {epoch}: loss: {epoch_loss}")

In [11]:
# Adam over all parameters of the language model, learning rate 1e-3.
opt = torch.optim.Adam(lm.parameters(), lr=0.001)
# Every sequence is padded with id 1 (see tokenizer.enable_padding(pad_id=1)).
# Excluding those targets keeps the loss, and therefore the gradients, focused
# on real tokens instead of on predicting padding.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=1)

In [12]:
# Train on the same device the model was moved to, rather than repeating a hard-coded device name.
train(lm, opt, loss_fn, train_loader, device)

 43%|████▎     | 1807/4180 [01:11<01:33, 25.32it/s]


KeyboardInterrupt: 