In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import string
import lightning as Li
from collections import Counter
import plotly.express as px
from dash import Dash, dcc, html, Input, Output


In [3]:
# upload dataset
plays = pd.read_csv("datasets/shakespeare_plays.csv")

comedy_plays = plays[plays["genre"] == "Comedy"]
tragedy_plays = plays[plays["genre"] == "Tragedy"]
history_plays = plays[plays["genre"] == "History"]

comedy_text = comedy_plays["text"].unique()
tragedy_text = tragedy_plays["text"].unique()
history_text = history_plays["text"].unique()

#hyperparameter controlling context window
L = 2
#hyperparamter controlling neg/pos words ratio
k = 2



In [4]:

# helper function to remove punctionation
# makes words lowercase
def clean_word(word):
    word = word.lower()
        
    translator = str.maketrans('', '', string.punctuation)
    cleaned_word = word.translate(translator)
        
    return cleaned_word

# first, create dictionary to store index of each word
# will do the same for all texts to have same "subspace"

#makes array of words
comedy_words = np.array([clean_word(word) for string in comedy_text for word in string.split()])
tragedy_words = np.array([clean_word(word) for string in tragedy_text for word in string.split()])
history_words = np.array([clean_word(word) for string in history_text for word in string.split()])

all_words = np.concatenate((comedy_words, tragedy_words, history_words))
all_words = np.unique(all_words)
np.save("all_words.npy", all_words)

#create described dict
current_index = 0
common_dict = {all_words[i]:i for i in range(len(all_words))}
print(len(all_words))

27604


In [11]:
# create (word, label) pairs
def create_pairs(text, dict, L):
    word_pair = {i:[] for i in range(len(dict))}
    for sentence in text:
        words = sentence.split()
        for i in range(len(words)):
            word = clean_word(words[i])
            for j in range(1,L+1):
                if i+j < len(words):
                    context = clean_word(words[i+j])
                    word_pair[dict[word]].append(dict[context])
                if i-j >= 0:
                    context = clean_word(words[i-j])
                    word_pair[dict[word]].append(dict[context])
    return word_pair


comedy_pairs  = create_pairs(comedy_text, common_dict, L)
tragedy_pairs  = create_pairs(tragedy_text, common_dict, L)
history_pairs  = create_pairs(history_text, common_dict, L)


In [12]:
class PlayDataset(Dataset):
    def __init__(self, word_pair, text, dict) -> None:
        # creating final dataset
        # adding negative words
        self.words, self.pos, self.neg = [], [], []
        vocab_size = len(dict)
        for word, contexts in word_pair.items():
            for pos in contexts:
                neg = []
                while len(neg) < k:
                    temp = np.random.randint(vocab_size)
                    if temp != word and temp not in contexts:
                        neg.append(temp)
                self.words.append(word)
                self.pos.append(pos)
                self.neg.append(neg)
        super().__init__()


    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        return self.words[idx], self.pos[idx], self.neg[idx]
        super().__init__()

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        return self.words[idx], self.pos[idx], self.neg[idx]
    



comedy_dataset = PlayDataset(comedy_pairs, comedy_text, common_dict)
comedy_loader = DataLoader(comedy_dataset, batch_size=32, shuffle=True)

tragedy_dataset = PlayDataset(tragedy_pairs, tragedy_text, common_dict)
tragedy_loader = DataLoader(tragedy_dataset, batch_size=32, shuffle=True)

history_dataset = PlayDataset(history_pairs, history_text, common_dict)
history_loader = DataLoader(history_dataset, batch_size=32, shuffle=True)

In [13]:

class neural(Li.LightningModule):
    def __init__(self, vocab_size, word_dim, k) -> None:
        super().__init__()
        self.w_embedding = nn.Embedding(vocab_size, word_dim, device='cpu')
        self.c_embedding = nn.Embedding(vocab_size, word_dim, device='cpu')
        # for every positive context word, we will have k negative ones.
        self.k = k
        self.learning_rate = 0.01


    def forward(self, word, pos, negatives):
        first_term = torch.mm(self.c_embedding(pos), self.w_embedding(word).T).sum()
        first_term = (nn.functional.logsigmoid(first_term))


        second_term = 0
        for neg in negatives:
            neg = torch.tensor(neg, device=self.device)
            temp = torch.mm(-self.c_embedding(neg), self.w_embedding(word).T).sum()
            temp = nn.functional.logsigmoid(temp)
            second_term += temp

        loss = -1 * (first_term + second_term)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        word, pos, neg = batch
        word, pos = torch.tensor(word, device=self.device), torch.tensor(pos, device=self.device)
        loss = self.forward(word, pos, neg) 
        self.log("loss", loss, prog_bar=True)
        return loss

    







In [14]:
c_n = neural(len(common_dict),L,k)
trainer = Li.Trainer(max_epochs=2, enable_model_summary=0)
trainer.fit(c_n, train_dataloaders=comedy_loader)

t_n = neural(len(common_dict),L,k)
trainer = Li.Trainer(max_epochs=2, enable_model_summary=0)
trainer.fit(t_n, train_dataloaders=tragedy_loader)

h_n = neural(len(common_dict),L,k)
trainer = Li.Trainer(max_epochs=2, enable_model_summary=0)
trainer.fit(h_n, train_dataloaders=history_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


/Users/mohamadaltrabulsi/Desktop/coding_2024/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 0/34657 [00:00<?, ?it/s] 

  word, pos = torch.tensor(word, device=self.device), torch.tensor(pos, device=self.device)
  neg = torch.tensor(neg, device=self.device)


Epoch 1: 100%|██████████| 34657/34657 [04:14<00:00, 136.10it/s, v_num=66, loss=-0.00]   

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 34657/34657 [04:14<00:00, 136.09it/s, v_num=66, loss=-0.00]


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 1: 100%|██████████| 22510/22510 [02:39<00:00, 141.55it/s, v_num=67, loss=-0.00]   

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 22510/22510 [02:39<00:00, 141.54it/s, v_num=67, loss=-0.00]


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 1: 100%|██████████| 23910/23910 [02:44<00:00, 145.53it/s, v_num=68, loss=0.0369]  

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 23910/23910 [02:44<00:00, 145.52it/s, v_num=68, loss=0.0369]


In [15]:
#finding the common words between the texts

common_words = np.intersect1d(np.intersect1d(comedy_words, tragedy_words),history_words)
common_words = np.array([clean_word(word) for word in common_words])
common_words = common_words[common_words != ""]

# sorting the words by how much they occur in all of the texts
word_counts = Counter(common_words)
common_words = sorted(word_counts.keys(), key=lambda word: word_counts[word], reverse=True)
common_words = np.array(common_words)

np.save("common_words.npy", common_words)
#words that are banned because they are not very interesting
banned_words = ['well', 'so', 'what', 'hence', 'there', 'on', 'indeed', 'you', 'do', 'this',
                'thus', 'here', 'him', 'say', 'thee', 'be', 'it','thou', 'to', 'come', 'go', 'away',
                'he', 'more', 'one', 'see', 'speak', 'stay', 'that', 'then', 'up', 'yet', 'by', 'all',
                'in', 'we', 'again', 'down', 'enough', 'for', 'have', 'her', 'himself', 'hold', 'how',
                'no', 'i', 'now', 'me', 'mine', 'none', 'not', 'of', 'still', 'out', 'farewell', 'fellow'
                'words', 'right', 'time', 'done', 'gentlemen', 'grace', 'sir', 'fellow', 'help', 'lady', 
                'myself', 'long', 'masters', 'name', 'queen', 'rest', 'sons', 'stand', 'tomorrow', 'tonight', 'words'
                'together', 'too', 'will', 'a', 'o', 'am', 'another', 'answer', 'are', 'arm', 'bear', 'best', 'both'
                'cause', 'comes', 'arms', 'art', 'back', 'both', 'cause', 'cousin', 'daughter', 'ever', 'fall',
                'first', 'fly', 'further', 'his', 'pardon', 'much', 'madam', 'look', 'leave', 'last', 'leave',
                'look', 'madam', 'much', 'pardon', 'sleep', 'she', 'shall', 'ill', 'is', 'know', 'upon', 'thyself',
                'think', 'thine', 'thanks', 'tell', 'sister', 'return', 'prince', 'aside', 'wife', 'welcome',
                'eye']
mask_ban = ~np.isin(common_words, banned_words)

preferred_word = ['lord', 'good', 'dead', 'love', 'true', 'faith', 'death', 'honour', 
                  'god', 'die', 'hell', 'heaven', 'joy', 'patience', 'nothing', 'shame',
                  'fear', 'fortune', 'enemies']
mask_inc = np.isin(common_words, preferred_word)


common_words = common_words[mask_ban]
max_words = 500
common_words = common_words[:max_words]

In [16]:
comedy_2d = c_n.w_embedding.weight.detach().cpu().numpy() +  c_n.c_embedding.weight.detach().cpu().numpy() 
comedy_points = np.array([comedy_2d[common_dict[i]] for i in common_words])

#used to help with graphing
comedy_data = pd.DataFrame({
    'x':comedy_points[:, 0],
    'y':comedy_points[:, 1],
    'genre':'comedy',
    'word':common_words
})

tragedy_2d = t_n.w_embedding.weight.detach().cpu().numpy() +  t_n.c_embedding.weight.detach().cpu().numpy() 
tragedy_points = np.array([tragedy_2d[common_dict[i]] for i in common_words])

tragedy_data = pd.DataFrame({
    'x':tragedy_points[:, 0],
    'y':tragedy_points[:, 1],
    'genre':'tragedy',
    'word':common_words
})

history_2d = h_n.w_embedding.weight.detach().cpu().numpy() +  h_n.c_embedding.weight.detach().cpu().numpy() 
history_points = np.array([history_2d[common_dict[i]] for i in common_words])

history_data = pd.DataFrame({
    'x':history_points[:, 0],
    'y':history_points[:, 1],
    'genre':'history',
    'word':common_words
})


In [17]:
torch.save(comedy_2d, "comedy_embeddings.pt")
torch.save(tragedy_2d, "tragedy_embeddings.pt")
torch.save(history_2d, "history_embeddings.pt")

In [19]:

app = Dash(__name__)
app.layout = html.Div([
    dcc.Graph(id="scatter-plot"),
    html.P("percentage of words shown"),
    dcc.Slider(
        id='slider',
        min=10, max=100, step=10,
        value=10),
    html.P("Select genres:"),
    dcc.Checklist(
        id='checklist',
        options=[
            {'label': 'Comedy', 'value': 'comedy'},
            {'label': 'Tragedy', 'value': 'tragedy'},
            {'label': 'History', 'value': 'history'}
        ],
        value=['comedy', 'tragedy', 'history'],  
    ),
])

@app.callback(
    Output("scatter-plot", "figure"), 
    Input("slider", "value"),
    Input("checklist", "value"))

def update_plot(percentage, genres):
    global comedy_data, tragedy_data, history_data
    if genres == []:
        return px.scatter()
    max = int(comedy_data.shape[0] * percentage/100)
    lim_comedy_data = comedy_data[:max]
    lim_tragedy_data = tragedy_data[:max]
    lim_history_data = history_data[:max]
    history_data = pd.DataFrame({
    'x':history_points[:, 0],
    'y':history_points[:, 1],
    'genre':'history',
    'word':common_words
})
    data = pd.DataFrame({
        'x':[0,0,0],
        'y':[-100,-100,-100],
        'genre':['comedy', 'tragedy', 'history'],
    })
    if 'comedy' in genres:
        data = pd.concat([data,lim_comedy_data])
    if 'tragedy' in genres:
        data = pd.concat([data,lim_tragedy_data])
    if 'history' in genres:
        data = pd.concat([data,lim_history_data])
    fig = px.scatter(data, x="x", y="y", color="word", symbol='genre', hover_name='word')
    fig.update_yaxes(range = [-5,5])
    fig.update_xaxes(range = [-5,5])
    return fig

app.run_server(debug=True)