In [49]:
import pandas as pd
import ast
#reading the data   
lines_path = "1/movie_lines.txt"

lines_df = pd.read_csv(
    lines_path,
    sep=" \+\+\+\$\+\+\+ ",
    engine='python',
    header=None,
    names=['lineID', 'characterID', 'movieID', 'characterName', 'text'],
    encoding='latin-1'  # <-- Important
)
conv_path = "1/movie_conversations.txt"

conv_df = pd.read_csv(
    conv_path,
    sep=" \+\+\+\$\+\+\+ ",
    engine='python',
    header=None,
    names=['character1', 'character2', 'movieID', 'lineIDs'],
    encoding='latin-1'
)

# Convert the string representation of list to an actual Python list
conv_df['lineIDs'] = conv_df['lineIDs'].apply(ast.literal_eval)
id2line = {}
for inbdx,row in lines_df.iterrows():
    id2line[row["lineID"]] = row['text']


  sep=" \+\+\+\$\+\+\+ ",
  sep=" \+\+\+\$\+\+\+ ",


In [50]:
#making the input and output pair
pairs = []
for inbdx,row in conv_df.iterrows():
    for i in range(len(row["lineIDs"]) - 1):
        input_lineID  = row["lineIDs"][i]
        output_lineID = row["lineIDs"][i + 1]
        pairs.append((input_lineID,output_lineID))


        

In [51]:
text_pairs = []
for (line_input,line_output) in pairs:
    line_input = id2line[line_input]
    line_output = id2line[line_output]
    text_pairs.append((line_input,line_output))

In [52]:
len(text_pairs)

221616

In [53]:
import numpy as np

arr = np.array(text_pairs,dtype=str)

In [54]:
text_pair = np.char.lower(arr)

In [55]:
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import string

In [56]:
nltk.download('stopwords')
nltk.download('punkt')

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\ahmed\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\ahmed\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [57]:
stop_words = set(stopwords.words('english'))
filterd_text_pair = []
for (i,o) in text_pair:
    i.encode("ascii", "ignore").decode()
    o.encode("ascii", "ignore").decode()
    i = word_tokenize(i)
    o = word_tokenize(o)
    filtered_i = [word for word in i if word not in stop_words]
    filtered_o = [word for word in o if word not in stop_words]
    filterd_text_pair.append((filtered_i,filtered_o))

In [58]:
filtered_text_pair_no_punct = []

for (i, o) in filterd_text_pair:
    # Remove punctuation from input
    filtered_i = [word for word in i if word not in string.punctuation]
    
    # Remove punctuation from output
    filtered_o = [word for word in o if word not in string.punctuation]
    
    filtered_text_pair_no_punct.append((filtered_i, filtered_o))


In [59]:
from collections import Counter
word_freq = Counter()
for (sentence_i,sentence_o) in filtered_text_pair_no_punct:
    word_freq.update(sentence_i)
    word_freq.update(sentence_o)

print(word_freq)



In [60]:
PAD = 0
SOS = 1
EOS = 2
UNK = 3


In [61]:
vocab = {"<PAD>":PAD,
         "<SOS>": SOS
         ,"<EOS>":EOS,
         "<UNK>" :UNK}

In [62]:
i = 4
for word, count in word_freq.most_common():
    if word not in vocab:
        vocab[word] = i 
    i = i + 1


In [63]:
encoded_pairs = []
for (i,o) in filtered_text_pair_no_punct:
    encoded_input = [vocab.get(w, vocab["<UNK>"]) for w in i]
    encoded_output = [vocab["<SOS>"]] + [vocab.get(w, vocab["<UNK>"]) for w in o] + [vocab["<EOS>"]]
    encoded_pairs.append((encoded_input, encoded_output))

In [64]:
from torch.nn.utils.rnn import pad_sequence
import torch
#turn the list into a tensor
filtered_pairs = []
for inp, out in encoded_pairs:
    if len(inp) > 0 and len(out) > 0:
        filtered_pairs.append((inp, out))

input_tensors = [torch.tensor(inp,dtype=torch.long)for inp, out in filtered_pairs]
output_tensors = [torch.tensor(out,dtype=torch.long)for inp, out in filtered_pairs]
#padding them
pad_input = pad_sequence(input_tensors,True,vocab["<PAD>"])
pad_output = pad_sequence(output_tensors,True,vocab["<PAD>"])

In [65]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [66]:
from torch import nn

In [77]:
max_len = max(pad_output,key=len)

In [78]:
print(len(max_len))

294


In [69]:
print(len(pad_input))
print(len(pad_output))

212479
212479


In [80]:
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

  scaler = GradScaler()


In [81]:
class transformer_model(nn.Module):
    def __init__(self, embed_size,vocab_size,num_heads=4,num_layers=3,forward_expansion=4,max_len=165,dropout=0.1):
        super().__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(vocab_size,embed_size)
        self.pos_embedding = nn.Embedding(max_len,embed_size)
        decode_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads,
                                                   dim_feedforward=embed_size *forward_expansion, dropout=dropout,batch_first=True)
        self.transformer = nn.TransformerEncoder(decode_layer,num_layers=num_layers)
        
        self.linear = nn.Linear(embed_size,vocab_size)
        self.max_len = max_len
    def forward(self,tgt,padding_mask):
        seq_len = tgt.shape[1]
        position = torch.arange(0,seq_len).unsqueeze(0).to(device)
        tgt = self.word_embedding(tgt) + self.pos_embedding(position)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device)
        out = self.transformer(tgt,mask=tgt_mask,
            src_key_padding_mask=padding_mask)
        out = self.linear(out)
        return out 
   

In [82]:
embed_size = 512
dec = transformer_model(embed_size,len(vocab)).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<PAD>"])
optimizer = torch.optim.Adam(list(dec.parameters()), lr=1e-4)

In [83]:
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(pad_input, pad_output)
loader = DataLoader(dataset, batch_size=8, shuffle=True)


In [86]:
import tqdm

In [None]:
def train(dec,device,dataloader,optimizer,pad_indx,loss_fn,epochs = 50):
    dec.train()
    dec.to(device)
    for epoch in range(epochs):
        total_loss = 0
        loop = tqdm(loader, desc=f"Epoch [{epoch+1}/{epochs}]")
        for input , target in dataloader:
            input = input.to(device)
            target = target.to(device)
            input = input[:,:-1]
            target = target[:,1:]
            padding_mask = (input == pad_indx)
            logits = dec(input,padding_mask)
            target = target[:, :logits.size(1)]
            loss = loss_fn(logits.reshape(-1,logits.size(-1)),
                           target.reshape(-1))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            loop.set_postfix(loss=loss.item())
        print(f"Epoch {epoch+1}: loss = {total_loss / len(dataloader):.4f}")


In [85]:
train(dec,device,loader,optimizer,vocab["<PAD>"],criterion)



Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1
Epoch 1


KeyboardInterrupt: 