In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import re
import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from torch.nn.utils.rnn import pad_sequence
    


In [None]:
class EncoderLSTM(nn.Module):
    """Encode a batch of token-id sequences into a final LSTM state.

    Args:
        embedding: Module mapping token ids to vectors; its ``embedding_dim``
            attribute sets the LSTM input size.
        hidden_size: Number of features in the LSTM hidden state.
    """

    def __init__(self, embedding, hidden_size):
        super().__init__()
        self.embedding = embedding
        self.lstm = nn.LSTM(embedding.embedding_dim, hidden_size, batch_first=True)

    def forward(self, src, src_lengths=None):
        """Return the final ``(hidden, cell)`` state of the encoder LSTM.

        Args:
            src: Token ids of shape [batch_size, src_len].
            src_lengths: Optional true (unpadded) length of every row of
                ``src`` (tensor or sequence of ints). When given, the batch is
                packed so the returned state is the one after each row's last
                real token instead of after its trailing padding. When omitted
                the LSTM runs over every position, padding included.

        Returns:
            Tuple ``(hidden, cell)``, each of shape [1, batch_size, hidden_size].
        """
        embedded = self.embedding(src)  # [batch_size, src_len, emb_dim]
        if src_lengths is not None:
            # pack_padded_sequence needs CPU int64 lengths that are all >= 1;
            # an all-padding row is treated as a single (padding) step.
            lengths = torch.as_tensor(src_lengths, dtype=torch.int64).cpu().clamp(min=1)
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False
            )
        _, (hidden, cell) = self.lstm(embedded)
        return hidden, cell
    
class DecoderLSTM(nn.Module):
    """LSTM decoder that maps target-side token ids to per-step output logits.

    Args:
        embedding: Module mapping token ids to vectors; its ``embedding_dim``
            attribute sets the LSTM input size.
        hidden_size: Number of features in the LSTM hidden state.
        output_dim: Size of the final linear projection (one logit per class).
    """

    def __init__(self, embedding, hidden_size, output_dim):
        super().__init__()
        self.embedding = embedding
        self.lstm = nn.LSTM(
            input_size=embedding.embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc_out = nn.Linear(in_features=hidden_size, out_features=output_dim)

    def forward(self, trg, hidden, cell):
        """Decode ``trg`` starting from the given LSTM state.

        Args:
            trg: Token ids of shape [batch_size, trg_len].
            hidden: Initial hidden state handed to the LSTM.
            cell: Initial cell state handed to the LSTM.

        Returns:
            Tuple ``(logits, hidden, cell)`` where ``logits`` has shape
            [batch_size, trg_len, output_dim] and the other two are the LSTM
            state after the last step.
        """
        initial_state = (hidden, cell)
        step_features, final_state = self.lstm(self.embedding(trg), initial_state)
        next_hidden, next_cell = final_state
        return self.fc_out(step_features), next_hidden, next_cell
    
class Seq2Seq(nn.Module):
    """Wire an encoder and a decoder together for teacher-forced decoding.

    Args:
        encoder: Module called as ``encoder(src)`` returning ``(hidden, cell)``.
        decoder: Module called as ``decoder(trg, hidden, cell)`` returning a
            3-tuple whose first element is the output to return.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg):
        """Encode ``src``, then decode ``trg`` from the encoder's final state.

        Returns:
            The decoder's first output ([batch_size, trg_len, vocab_size] when
            used with the LSTM decoder defined in this notebook).
        """
        enc_hidden, enc_cell = self.encoder(src)
        # Only the per-step outputs are needed; the decoder's final state is dropped.
        logits, _dec_hidden, _dec_cell = self.decoder(trg, enc_hidden, enc_cell)
        return logits


In [None]:
# Reserved vocabulary entries. build_vocab copies this list to the front of the
# vocabulary, so the order here fixes their ids (0..3).
SPECIAL_TOKENS = [
    '<pad>',  # filler used when padding sequences in a batch
    '<sos>',  # prepended to the decoder input
    '<eos>',  # appended to the decoder target
    '<unk>',  # stands in for words missing from the vocabulary
]

def load_text_from_file(file_path):
    """Read ``file_path`` as UTF-8 and return its full contents as one string."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()

def split_into_pairs(text):
    """Build (utterance, response) pairs from ``__eou__``-delimited dialogue text.

    Each line of ``text`` is treated as one dialogue, and consecutive utterances
    are paired only within that line. Pairing across the whole text would link
    the last utterance of one dialogue to the first of the next.
    NOTE(review): assumes one dialogue per line (DailyDialog-style dump) — confirm
    against the actual data file.

    Args:
        text: Raw corpus text with utterances terminated by ``__eou__``.

    Returns:
        List of ``(utterance, next_utterance)`` string tuples; empty when no
        line contains at least two non-blank utterances.
    """
    pairs = []
    for dialogue in text.split('\n'):
        utterances = [utt.strip() for utt in dialogue.split('__eou__') if utt.strip()]
        # zip against the list shifted by one yields each adjacent pair once.
        pairs.extend(zip(utterances, utterances[1:]))
    return pairs

def preprocess_text(text):
    """Normalise text: lowercase it, drop disallowed characters, trim the ends.

    Only ``a-z``, digits, whitespace and the apostrophe survive; every other
    character is deleted (not replaced by a space).
    """
    disallowed = r"[^a-z0-9\s']"
    return re.sub(disallowed, "", text.lower()).strip()

def build_vocab(pairs, min_freq=2):
    """Build word<->id lookup tables from (input, response) text pairs.

    Words are counted over both sides of every pair after ``preprocess_text``
    normalisation. The vocabulary starts with ``SPECIAL_TOKENS`` followed by
    every word seen at least ``min_freq`` times, in first-seen order.

    NOTE(review): when ``pairs`` holds overlapping consecutive pairs, an
    utterance that appears as the response of one pair and the input of the
    next is counted twice — confirm ``min_freq`` is meant to apply to those
    doubled counts.

    Args:
        pairs: Iterable of ``(input_text, response_text)`` tuples.
        min_freq: Minimum count a word needs to enter the vocabulary.

    Returns:
        Tuple ``(word2idx, idx2word)`` of dicts mapping word -> id and id -> word.
    """
    counter = Counter()
    for inp, res in pairs:
        counter.update(preprocess_text(inp).split())
        counter.update(preprocess_text(res).split())

    # Set membership keeps the duplicate check O(1) per word; scanning the
    # growing vocab list made the whole build quadratic in vocabulary size.
    reserved = set(SPECIAL_TOKENS)
    vocab = SPECIAL_TOKENS.copy()
    vocab.extend(
        word
        for word, freq in counter.items()
        if freq >= min_freq and word not in reserved
    )
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for word, idx in word2idx.items()}
    return word2idx, idx2word

def tokenize(text, word2idx):
    """Convert raw text into a list of vocabulary ids.

    The text is normalised with ``preprocess_text`` and split on whitespace;
    any word missing from ``word2idx`` is replaced by the id of ``'<unk>'``.
    """
    ids = []
    for word in preprocess_text(text).split():
        ids.append(word2idx.get(word, word2idx['<unk>']))
    return ids

In [None]:
class ChatbotDataset(Dataset):
    """Dataset of (input, response) text pairs served as token-id tensors.

    Args:
        pairs: Sequence of ``(input_text, response_text)`` tuples.
        word2idx: Mapping from word to id; must contain ``'<sos>'`` and
            ``'<eos>'`` (and whatever ``tokenize`` needs).
    """

    def __init__(self, pairs, word2idx):
        self.pairs = pairs
        self.word2idx = word2idx

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(input_ids, decoder_input_ids, decoder_target_ids)`` for one pair.

        The decoder input is the response prefixed with ``<sos>``; the decoder
        target is the same response suffixed with ``<eos>``, so the two have
        equal length and are offset by one position.
        """
        inp, res = self.pairs[idx]
        inp_ids = tokenize(inp, self.word2idx)
        res_ids = tokenize(res, self.word2idx)  # tokenize the response once, reuse twice
        res_input = [self.word2idx['<sos>']] + res_ids
        res_target = res_ids + [self.word2idx['<eos>']]
        # Explicit dtype: torch.tensor([]) would otherwise be float32, which
        # breaks the embedding lookup when an utterance normalises to nothing.
        return (
            torch.tensor(inp_ids, dtype=torch.long),
            torch.tensor(res_input, dtype=torch.long),
            torch.tensor(res_target, dtype=torch.long),
        )

def collate_fn(batch, pad_idx=None):
    """Pad a list of dataset samples into three batch-first tensors.

    Args:
        batch: List of ``(input_ids, decoder_input_ids, decoder_target_ids)``
            1-D tensors, as produced by the dataset's ``__getitem__``.
        pad_idx: Id used to fill padded positions. When omitted, falls back to
            the module-level ``word2idx['<pad>']`` so it can still be handed to
            ``DataLoader(collate_fn=collate_fn)`` directly.

    Returns:
        Tuple of three tensors, each [batch_size, longest_len_in_that_group].
    """
    if pad_idx is None:
        # Resolved at call time, so word2idx may be defined after this function.
        pad_idx = word2idx['<pad>']
    inputs, res_inputs, res_targets = zip(*batch)
    inputs_padded = pad_sequence(inputs, batch_first=True, padding_value=pad_idx)
    res_inputs_padded = pad_sequence(res_inputs, batch_first=True, padding_value=pad_idx)
    res_targets_padded = pad_sequence(res_targets, batch_first=True, padding_value=pad_idx)
    return inputs_padded, res_inputs_padded, res_targets_padded

# ----------- USAGE ------------
# Builds the notebook-level data pipeline: raw text -> utterance pairs ->
# vocabulary -> Dataset -> DataLoader. The names defined here are module-level
# state; collate_fn reads `word2idx` from this scope when it is called.

file_path = 'dataset/dialogues_train.txt'  # put your filename here
raw_text = load_text_from_file(file_path)
pairs = split_into_pairs(raw_text)
# Words seen fewer than min_freq times are left out and tokenize to <unk>.
word2idx, idx2word = build_vocab(pairs, min_freq=2)

dataset = ChatbotDataset(pairs, word2idx)
# Batches of 32 samples in shuffled order; collate_fn pads each batch to its longest sequence.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Example: fetch one batch
# Sanity check only: print the shapes of the first batch, then stop iterating.
for batch_inp, batch_res_inp, batch_res_tgt in dataloader:
    print(batch_inp.shape)       # (batch_size, seq_len)
    print(batch_res_inp.shape)   # (batch_size, seq_len)
    print(batch_res_tgt.shape)   # (batch_size, seq_len)
    break

In [None]:
# Model hyperparameters.
embed_size = 128
hidden_size = 512

# EncoderLSTM / DecoderLSTM take an embedding *module* (they read
# `embedding.embedding_dim`), not an embedding size, and the decoder also needs
# the number of output classes. One table is shared by both sides because
# inputs and responses use the same vocabulary.
vocab_size = len(word2idx)
embedding = nn.Embedding(vocab_size, embed_size, padding_idx=word2idx['<pad>'])
model = Seq2Seq(
    encoder=EncoderLSTM(embedding, hidden_size),
    decoder=DecoderLSTM(embedding, hidden_size, output_dim=vocab_size),
)

In [None]:
# The decoder emits raw logits, so use CrossEntropyLoss (log-softmax + NLL in
# one step); NLLLoss alone would expect log-probabilities. Padded target
# positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word2idx['<pad>'])
optimizer = torch.optim.Adam(model.parameters())

model.train()
for epoch in range(30):
    # Each batch from `dataloader` is (encoder input, decoder input, decoder target).
    for src, trg_input, trg_target in dataloader:
        optimizer.zero_grad()
        logits = model(src, trg_input)  # [batch_size, trg_len, vocab_size]
        # CrossEntropyLoss wants [N, num_classes] logits and [N] class ids,
        # so fold the batch and time dimensions together.
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            trg_target.reshape(-1),
        )
        loss.backward()
        optimizer.step()

In [None]:
from transformers import GPT2Tokenizer

In [None]:
# Score the trained model and report perplexity and BLEU.
# NOTE(review): `evaluate` and `val_data` are not defined in any cell visible in
# this notebook — they must come from elsewhere, or this cell fails on a fresh
# kernel. Confirm where they are defined.
# NOTE(review): the model is trained on ids from `word2idx`, while this passes a
# GPT-2 tokenizer; presumably `evaluate` only uses it for text handling — verify
# that the two vocabularies are not mixed.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
scores = evaluate(model, val_data, tokenizer)  
# `scores` is indexed by the keys 'perplexity' and 'bleu' below.
print(f"Perplexity score: {scores['perplexity']}")
print(f"BLEU score: {scores['bleu']}")