In [2]:
!pip install torch torchvision torchaudio numpy nltk matplotlib tqdm

Defaulting to user installation because normal site-packages is not writeable



[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [6]:
!pip install nltk

Defaulting to user installation because normal site-packages is not writeable
Collecting nltk
  Using cached nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Collecting click (from nltk)
  Using cached click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting joblib (from nltk)
  Using cached joblib-1.4.2-py3-none-any.whl.metadata (5.4 kB)
Collecting regex>=2021.8.3 (from nltk)
  Using cached regex-2024.11.6-cp313-cp313-win_amd64.whl.metadata (41 kB)
Collecting tqdm (from nltk)
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Using cached nltk-3.9.1-py3-none-any.whl (1.5 MB)
Using cached regex-2024.11.6-cp313-cp313-win_amd64.whl (273 kB)
Using cached click-8.1.8-py3-none-any.whl (98 kB)
Using cached joblib-1.4.2-py3-none-any.whl (301 kB)
Using cached tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm, regex, joblib, click, nltk
Successfully installed click-8.1.8 joblib-1.4.2 nltk-3.9.1 regex-2024.11.6 tqdm-4.67.1



[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import nltk
# Fetch the 'punkt' package through NLTK's downloader (returns True on success).
# NOTE(review): no cell visible in this notebook calls nltk afterwards --
# tokenising is done by a regex-based helper -- confirm this download is needed.
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\deepi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import os
import re
import unicodedata
import matplotlib.pyplot as plt


In [2]:
import re
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np

# Seed every RNG imported in this cell (Python, NumPy, PyTorch) from one
# constant so that results do not depend on which library draws random numbers
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load Cornell Dataset
def load_cornell_data(path='cornell_movie_dialogs_corpus/movie_lines.txt'):
    """Read the utterance texts from a Cornell ``movie_lines.txt`` style file.

    Every line is split on the literal `` +++$+++ `` delimiter. Lines that
    yield exactly five fields contribute their last field to the result;
    all other lines (including blank ones) are skipped.

    Args:
        path: Location of the text file to read. It is decoded as UTF-8 and
            undecodable bytes are dropped (``errors='ignore'``).

    Returns:
        list[str]: The last field of each well-formed line, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    # The context manager closes the handle even if reading raises;
    # the original left the file object open until garbage collection.
    with open(path, encoding='utf-8', errors='ignore') as corpus_file:
        lines = corpus_file.read().split('\n')

    conversations = []
    for line in lines:
        parts = line.split(" +++$+++ ")
        if len(parts) == 5:
            conversations.append(parts[-1])
    return conversations

# Use only pairs
def create_pairs(lines):
    """Pair every entry of ``lines`` with the entry that directly follows it.

    No conversation boundaries are consulted: each element is simply matched
    with its immediate successor in the sequence.

    Args:
        lines: Indexable sequence of utterances.

    Returns:
        list[tuple]: ``(lines[k], lines[k + 1])`` for each valid ``k``; empty
        when ``lines`` holds fewer than two elements.
    """
    return [(lines[k], lines[k + 1]) for k in range(len(lines) - 1)]


In [3]:
def simple_tokenizer(text):
    """Lower-case ``text`` and split it into whitespace-separated tokens.

    Any run of characters other than ASCII letters, digits and ``? . ! ,``
    is collapsed into a single space before splitting, so the kept
    punctuation stays attached to the neighbouring word.

    Args:
        text: Raw input string.

    Returns:
        list[str]: The resulting tokens (empty list for blank input).
    """
    cleaned = re.sub(r"[^a-zA-Z0-9?.!,]+", " ", text.lower())
    return cleaned.strip().split()

def build_vocab(pairs):
    """Build forward and reverse vocabularies from sentence pairs.

    Ids 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``.
    Every other token receives the next free id in order of first appearance.

    Args:
        pairs: Iterable of sentence groups (e.g. 2-tuples of strings); each
            sentence is tokenised with ``simple_tokenizer``.

    Returns:
        tuple[dict, dict]: ``(word2idx, idx2word)`` mapping token -> id and
        id -> token respectively.
    """
    specials = ('<PAD>', '<SOS>', '<EOS>', '<UNK>')
    word2idx = {token: position for position, token in enumerate(specials)}
    idx2word = {position: token for position, token in enumerate(specials)}

    for pair in pairs:
        for sentence in pair:
            for word in simple_tokenizer(sentence):
                if word in word2idx:
                    continue
                # The next free id always equals the current vocabulary size
                new_id = len(word2idx)
                word2idx[word] = new_id
                idx2word[new_id] = word
    return word2idx, idx2word


In [4]:
def sentence_to_tensor(sentence, word2idx):
    """Encode a sentence as a ``(1, n_tokens + 2)`` tensor of token ids.

    The sentence is tokenised with ``simple_tokenizer``. Tokens missing from
    ``word2idx`` map to the ``<UNK>`` id, and the sequence is framed with the
    ``<SOS>`` and ``<EOS>`` ids.

    Args:
        sentence: Raw input string.
        word2idx: Token -> id mapping containing the special tokens.

    Returns:
        torch.Tensor: ``torch.long`` tensor with a leading dimension of 1.
    """
    body = [word2idx.get(tok, word2idx['<UNK>']) for tok in simple_tokenizer(sentence)]
    framed = [word2idx['<SOS>'], *body, word2idx['<EOS>']]
    # Indexing with None adds the leading dimension of size 1
    return torch.tensor(framed, dtype=torch.long)[None, :]


In [5]:
class Encoder(nn.Module):
    """Embedding + single GRU that reduces a token-id sequence to a hidden state.

    Args:
        input_size: Number of rows in the embedding table.
        emb_size: Width of each embedding vector (GRU input width).
        hidden_size: Width of the GRU hidden state.
    """

    def __init__(self, input_size, emb_size, hidden_size):
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept stable
        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=emb_size)
        self.rnn = nn.GRU(input_size=emb_size, hidden_size=hidden_size, batch_first=True)

    def forward(self, src):
        """Embed ``src`` and return only the GRU's final hidden state.

        Args:
            src: Integer tensor of token ids with the batch dimension first
                (the GRU is built with ``batch_first=True``).

        Returns:
            The hidden-state tensor produced by the GRU; the per-step outputs
            are discarded.
        """
        _, final_hidden = self.rnn(self.embedding(src))
        return final_hidden

class Decoder(nn.Module):
    """Embedding + GRU + linear projection that scores the next token.

    Args:
        output_size: Number of rows in the embedding table and width of the
            projected score vector.
        emb_size: Width of each embedding vector (GRU input width).
        hidden_size: Width of the GRU hidden state.
    """

    def __init__(self, output_size, emb_size, hidden_size):
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept stable
        self.embedding = nn.Embedding(num_embeddings=output_size, embedding_dim=emb_size)
        self.rnn = nn.GRU(input_size=emb_size, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, input, hidden):
        """Advance the decoder and score the vocabulary.

        Args:
            input: Integer tensor of token ids, batch dimension first.
            hidden: Hidden state handed to the GRU.

        Returns:
            tuple: ``(prediction, hidden)`` where ``prediction`` is the linear
            projection of the GRU output after ``squeeze(1)`` and ``hidden``
            is the updated GRU state.
        """
        rnn_out, next_hidden = self.rnn(self.embedding(input), hidden)
        # squeeze(1) removes the sequence dimension when it has length one
        logits = self.fc(rnn_out.squeeze(1))
        return logits, next_hidden


In [6]:
def train_step(pair, encoder, decoder, enc_opt, dec_opt, loss_fn, word2idx):
    """Run one teacher-forced optimisation step on a single sentence pair.

    Args:
        pair: Indexable holding the source sentence at ``[0]`` and the target
            sentence at ``[1]``.
        encoder: Module mapping the source tensor to an initial hidden state.
        decoder: Module called as ``decoder(input, hidden)`` and returning
            ``(scores, hidden)``.
        enc_opt: Optimizer for the encoder parameters.
        dec_opt: Optimizer for the decoder parameters.
        loss_fn: Callable applied to ``(scores, target_ids)`` at every step.
        word2idx: Token -> id mapping containing the special tokens.

    Returns:
        float: Sum of the ``loss_fn`` values across all decoding steps
        (obtained via ``.item()``); it is not averaged over the steps.
    """
    for module in (encoder, decoder):
        module.train()

    src_tensor = sentence_to_tensor(pair[0], word2idx)
    tgt_tensor = sentence_to_tensor(pair[1], word2idx)

    for optimizer in (enc_opt, dec_opt):
        optimizer.zero_grad()

    hidden = encoder(src_tensor)
    decoder_input = torch.tensor([[word2idx['<SOS>']]])

    total = 0
    # Decoding starts from <SOS>, so the first position to predict is index 1
    for step in range(1, tgt_tensor.size(1)):
        logits, hidden = decoder(decoder_input, hidden)
        gold = tgt_tensor[:, step]
        total = total + loss_fn(logits, gold)
        # Teacher forcing: feed the reference token, not the model's guess
        decoder_input = gold.unsqueeze(1)

    total.backward()
    for optimizer in (enc_opt, dec_opt):
        optimizer.step()

    return total.item()


In [7]:
def evaluate(encoder, decoder, sentence, word2idx, idx2word, max_len=15):
    """Greedily decode a reply to ``sentence``.

    Both modules are switched to eval mode and gradients are disabled. At
    every step the highest-scoring id is taken and fed back as the next
    decoder input. Decoding stops at ``<EOS>`` or after ``max_len`` steps.

    Args:
        encoder: Module mapping the source tensor to an initial hidden state.
        decoder: Module called as ``decoder(input, hidden)`` and returning
            ``(scores, hidden)``.
        sentence: Raw input string.
        word2idx: Token -> id mapping containing the special tokens.
        idx2word: Id -> token mapping; unknown ids render as ``<UNK>``.
        max_len: Maximum number of decoding steps.

    Returns:
        str: The decoded tokens joined by single spaces (``<EOS>`` excluded).
    """
    encoder.eval()
    decoder.eval()

    reply_tokens = []
    with torch.no_grad():
        hidden = encoder(sentence_to_tensor(sentence, word2idx))
        step_input = torch.tensor([[word2idx['<SOS>']]])

        for _step in range(max_len):
            logits, hidden = decoder(step_input, hidden)
            best = logits.argmax(1)
            token = idx2word.get(best.item(), '<UNK>')
            if token == '<EOS>':
                break
            reply_tokens.append(token)
            # Feed the chosen id back in as a (1, 1) tensor
            step_input = best.unsqueeze(1)

    return ' '.join(reply_tokens)


In [8]:
# Demo-sized run configuration, gathered here so the sizes can be tuned in one place
NUM_LINES = 1000        # corpus lines used to build pairs and the vocabulary
NUM_TRAIN_PAIRS = 100   # pairs actually trained on in every epoch
NUM_EPOCHS = 5

lines = load_cornell_data()
pairs = create_pairs(lines[:NUM_LINES])  # Keep it small for demo
word2idx, idx2word = build_vocab(pairs)

# Encoder and decoder share one vocabulary, hence identical sizes
input_size = output_size = len(word2idx)
emb_size = 256
hidden_size = 512

encoder = Encoder(input_size, emb_size, hidden_size)
decoder = Decoder(output_size, emb_size, hidden_size)

enc_opt = optim.Adam(encoder.parameters(), lr=0.001)
dec_opt = optim.Adam(decoder.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

# Train on small set
train_pairs = pairs[:NUM_TRAIN_PAIRS]
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    for pair in train_pairs:
        loss = train_step(pair, encoder, decoder, enc_opt, dec_opt, loss_fn, word2idx)
        total_loss += loss
    # Average over the pairs actually used; a hard-coded 100 would under-report
    # the loss when fewer pairs are available. max() guards the empty case.
    print(f"Epoch {epoch+1} Loss: {total_loss/max(len(train_pairs), 1):.4f}")


Epoch 1 Loss: 65.9803
Epoch 2 Loss: 42.1352
Epoch 3 Loss: 25.6744
Epoch 4 Loss: 11.0631
Epoch 5 Loss: 3.9344


In [9]:
def chat_step(user_input):
    """Produce the bot's reply to a single user message.

    Args:
        user_input: Text typed by the user.

    Returns:
        str: ``"Chat ended."`` when the message is ``exit`` or ``quit``
        (case-insensitive); otherwise the reply decoded by ``evaluate`` using
        the notebook-level ``encoder``, ``decoder`` and vocabularies.
    """
    wants_to_leave = user_input.lower() in ('exit', 'quit')
    if not wants_to_leave:
        return evaluate(encoder, decoder, user_input, word2idx, idx2word)
    return "Chat ended."

# Example interaction: send one greeting through the bot and show both sides
greeting = "Hello!"
print("You:", greeting)
print("Bot:", chat_step(greeting))


You: Hello!
Bot: did you change your hair?
