In [47]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import streamlit as st
import re
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt


In [39]:
# NOTE(review): absolute, machine-specific path — point this at your own copy of the corpus.
DATA_PATH = r"C:\Users\abhin\OneDrive\Documents\Cpp\es335-24-fall-assignment-3\shakespeare_input.txt"

# Read the whole corpus; the `with` block closes the file handle as soon as it is read,
# and an explicit encoding keeps the result independent of the OS default code page.
with open(DATA_PATH, "r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()
# Drop everything except ASCII letters, whitespace, apostrophes and the punctuation . , ! ? :
text = re.sub(r'[^a-zA-Z\s.,!?:\']+', '', text)
# Surround each kept punctuation mark with spaces so it becomes a token of its own.
text = re.sub(r'([.,!?:])', r' \1 ', text)
# Lower-cased, whitespace-separated token stream used for the vocabulary and training data.
words = text.lower().split()


In [40]:
# Sort the distinct words so the word <-> index mapping is identical on every run.
# Iterating a set of strings follows per-process hash randomisation, so the previous
# `list(set(words))` renumbered the vocabulary after each kernel restart, silently
# invalidating any model saved earlier with torch.save.
# '<UNK>' is appended last; it is the fallback index for out-of-vocabulary words.
vocab = sorted(set(words)) + ['<UNK>']
vocab_size = len(vocab)
print("Length of vocabulary: ", vocab_size)
word2idx = {w: idx for idx, w in enumerate(vocab)}
idx2word = dict(enumerate(vocab))

Length of vocabulary:  28241


In [41]:
# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [42]:
# The sentence terminator '.' doubles as the padding token. Contexts that start before
# the beginning of the corpus, and prompts shorter than the context window, are
# left-filled with it, so they read as if they followed the end of a sentence.
PAD_TOKEN = '.'
padidx = word2idx[PAD_TOKEN]

In [43]:
class NextWordModel(nn.Module):
    """MLP next-word model: embed a fixed-length context, concatenate, score every word.

    Args:
        emb_dim: size of each word embedding.
        context_len: number of preceding words the model conditions on.
        actvn: hidden-layer activation, 'ReLU' or 'Tanh'.
        hidden_dim: width of the hidden layer (256 by default, as before).
        n_vocab: vocabulary size; when None, the notebook-level ``vocab_size`` is used.

    Raises:
        KeyError: if ``actvn`` is not a supported activation name.

    Input: index tensor of shape (batch, context_len).
    Output: unnormalised logits of shape (batch, n_vocab).
    """

    # Activation name -> layer class; instantiated only for the requested name.
    ACTIVATIONS = {'ReLU': nn.ReLU, 'Tanh': nn.Tanh}

    def __init__(self, emb_dim, context_len, actvn='ReLU', hidden_dim=256, n_vocab=None):
        super().__init__()
        if actvn not in self.ACTIVATIONS:
            raise KeyError(f"unknown activation {actvn!r}; expected one of {sorted(self.ACTIVATIONS)}")
        if n_vocab is None:
            # Fall back to the global computed in the vocabulary cell.
            n_vocab = vocab_size
        self.context_len = context_len
        self.vocab_size = n_vocab
        # Attribute names are kept unchanged so existing state_dicts / pickles still load.
        self.embedding = nn.Embedding(n_vocab, emb_dim)
        self.lin1 = nn.Linear(context_len * emb_dim, hidden_dim)
        self.act1 = self.ACTIVATIONS[actvn]()
        self.lin2 = nn.Linear(hidden_dim, n_vocab)

    def forward(self, x):
        # (batch, context_len) -> (batch, context_len, emb_dim)
        emb = self.embedding(x)
        # Concatenate the context embeddings into one vector per example.
        flat = emb.flatten(start_dim=1)
        hidden = self.act1(self.lin1(flat))
        return self.lin2(hidden)
    
def get_ctxt(context_len, tokens=None):
    """Build (context, next-word) training pairs from a token sequence.

    Every word becomes a target, paired with the indices of the ``context_len`` words
    before it; positions before the start of the sequence are filled with ``padidx``.
    Words missing from ``word2idx`` map to the '<UNK>' index. Assumes context_len >= 1.

    Args:
        context_len: number of preceding words in each context.
        tokens: words to encode; defaults to the notebook-level ``words`` corpus.

    Returns:
        (X, y): long tensor X of shape (len(tokens), context_len) holding context
        indices, and long tensor y of shape (len(tokens),) holding target indices.
    """
    if tokens is None:
        tokens = words
    unk = word2idx['<UNK>']  # looked up once instead of once per word
    X, y = [], []
    context = [padidx] * context_len
    for w in tokens:
        idx = word2idx.get(w, unk)
        X.append(context)
        y.append(idx)
        # Slide the window; this builds a new list, so rows already in X are untouched.
        context = context[1:] + [idx]
    # The explicit shape/dtype keep X 2-D and integer even when `tokens` is empty.
    return (torch.tensor(X, dtype=torch.long).view(len(X), context_len),
            torch.tensor(y, dtype=torch.long))

def train(model, epochs, context_len, batch_size=128, lr=0.001, log_every=10, num_workers=0):
    """Train ``model`` for next-word prediction on the notebook corpus.

    Windows are built by ``get_ctxt(context_len)`` and optimised with Adam and
    cross-entropy. The model is moved to ``device`` and left there.

    Args:
        model: module mapping (batch, context_len) index tensors to per-word logits.
        epochs: number of passes over the data.
        context_len: context length; must match the model's input size.
        batch_size: mini-batch size.
        lr: Adam learning rate.
        log_every: print the mean loss every ``log_every`` epochs (and after the last one).
        num_workers: DataLoader worker processes. The data is an in-memory TensorDataset,
            so 0 avoids per-worker startup/transfer overhead.

    Returns:
        list[float]: mean training loss per epoch (nan for an epoch with no batches).
    """
    X, y = get_ctxt(context_len)
    dataset = torch.utils.data.TensorDataset(X, y)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # Consecutive text windows are highly correlated; reshuffle every epoch.
        shuffle=True,
        # Pinned host memory only helps host->GPU copies.
        pin_memory=(device.type == 'cuda'),
        num_workers=num_workers,
    )
    loss_fn = nn.CrossEntropyLoss()
    # Move the model first so the optimizer is built over the final parameter tensors.
    model.to(device)
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        batch_losses = []
        for x_batch, y_batch in loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            loss = loss_fn(model(x_batch), y_batch)
            opt.zero_grad()
            loss.backward()
            opt.step()
            batch_losses.append(loss.item())
        epoch_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        history.append(epoch_loss)
        if (log_every > 0 and epoch % log_every == 0) or epoch == epochs - 1:
            print(f"Epoch {epoch+1} Loss:", epoch_loss)
    return history

In [44]:
# One model per (context length, activation, embedding size) combination.
# The comprehension creates them in the same order as the names on the left, so
# random weight initialisation draws from the RNG in the same sequence.
(e32_c5_r, e64_c5_r,
 e32_c5_t, e64_c5_t,
 e32_c10_r, e64_c10_r,
 e32_c10_t, e64_c10_t) = [
    NextWordModel(emb_dim, ctx_len, act)
    for ctx_len in (5, 10)
    for act in ('ReLU', 'Tanh')
    for emb_dim in (32, 64)
]


In [45]:
# Train the 32-dim embedding / 5-word context / Tanh model for 200 epochs and save it.
# The context length is read from the model itself so the training windows always
# match the input size its first Linear layer was built for.
train(e32_c5_t, 200, e32_c5_t.context_len)
# NOTE(review): this pickles the whole nn.Module (after train() moved it to `device`),
# so loading needs the NextWordModel class importable under the same name, and a
# map_location when the target machine has no GPU. Saving e32_c5_t.state_dict()
# would be more portable.
torch.save(e32_c5_t, "e32_c5_t.pth")

KeyboardInterrupt: 

In [33]:
def generate(model, text, num_words, temperature):
    """Continue ``text`` with ``num_words`` words sampled from ``model``.

    The prompt is normalised with the same regexes as the training corpus. Its last
    ``model.context_len`` words form the initial context, left-padded with ``padidx``
    when the prompt is shorter. Each step samples from softmax(logits / temperature)
    and slides the chosen word into the context.

    Args:
        model: next-word model exposing ``context_len`` and returning per-word logits.
        text: prompt string.
        num_words: number of words to generate.
        temperature: > 0 scales randomness (lower is more conservative);
            <= 0 picks the most likely word greedily.

    Returns:
        The generated words only (prompt excluded), joined by single spaces.
    """
    text = re.sub(r'[^a-zA-Z\s.,!?:\']+', '', text)
    text = re.sub(r'([.,!?:])', r' \1 ', text)
    tokens = text.lower().split()
    context_len = model.context_len
    unk = word2idx['<UNK>']
    prompt_ids = [word2idx.get(w, unk) for w in tokens]
    if len(prompt_ids) < context_len:
        context = [padidx] * (context_len - len(prompt_ids)) + prompt_ids
    else:
        context = prompt_ids[-context_len:]

    # Build inputs on whatever device holds the weights (train() may have moved them to the GPU).
    model_device = next(model.parameters(), torch.empty(0)).device
    was_training = model.training
    model.eval()
    generated_words = []
    try:
        with torch.no_grad():
            for _ in range(num_words):
                x = torch.tensor([context], device=model_device)
                logits = model(x)[0]
                if temperature <= 0:
                    idx = int(logits.argmax())
                else:
                    idx = torch.distributions.Categorical(logits=logits / temperature).sample().item()
                generated_words.append(idx2word[idx])
                context = context[1:] + [idx]
    finally:
        # Restore the caller's train/eval mode.
        model.train(was_training)
    return ' '.join(generated_words)


    


In [48]:
def visualize_embeddings(model):
    """Project all of ``model``'s word embeddings to 2-D with t-SNE and plot them.

    Every vocabulary point is drawn faintly in grey. A fixed list of contrasting word
    pairs (those that exist in ``word2idx``) is highlighted in blue and labelled.
    """
    weights = model.embedding.weight.detach().cpu().numpy()
    coords = TSNE(n_components=2, random_state=42).fit_transform(weights)

    highlight = [
        'love', 'hate', 'life', 'death', 'good', 'evil',
        'light', 'dark', 'sweet', 'bitter', 'joy', 'sorrow',
        'peace', 'war', 'heaven', 'hell', 'truth', 'lie',
        'fair', 'foul', 'friend', 'foe', 'honor', 'shame',
        'king', 'queen', 'lord', 'lady', 'day', 'night',
        'young', 'old', 'laugh', 'weep',
    ]

    fig, ax = plt.subplots(figsize=(15, 10))
    # Background cloud: the whole vocabulary.
    ax.scatter(coords[:, 0], coords[:, 1], c='lightgray', alpha=0.1, s=5)
    for word in (w for w in highlight if w in word2idx):
        px, py = coords[word2idx[word]]
        ax.scatter(px, py, c='blue', s=100)
        ax.annotate(word, (px, py), fontsize=10, alpha=0.8)

    ax.set_title(f'Word Embeddings Visualization (dim={model.embedding.embedding_dim})')
    ax.axis('off')
    fig.tight_layout()
    plt.show()