In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import re
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import heapq
from collections import deque

# Seed every RNG this script draws from so training runs are reproducible:
# torch (layer initialization, dropout, DataLoader shuffling) and Python's
# `random`, which decides teacher forcing at each step in Seq2Seq.forward.
random.seed(42)
torch.manual_seed(42)

# Word <-> index lookup tables shared by the dataset and the decoder
class Vocabulary:
    """Bidirectional mapping between words and integer indices.

    Indices 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and
    ``<UNK>``; every newly added word receives the next consecutive index.
    ``word_count`` always equals the next index to be assigned.
    """

    def __init__(self):
        reserved = ('<PAD>', '<SOS>', '<EOS>', '<UNK>')
        self.word2idx = {token: position for position, token in enumerate(reserved)}
        self.idx2word = dict(enumerate(reserved))
        self.word_count = len(reserved)

    def add_word(self, word):
        """Assign *word* the next free index; do nothing if it is already known."""
        if word in self.word2idx:
            return
        new_index = self.word_count
        self.word2idx[word] = new_index
        self.idx2word[new_index] = word
        self.word_count = new_index + 1

    def sentence_to_indices(self, sentence):
        """Split *sentence* on whitespace and map each token to its index.

        Tokens not in the vocabulary map to the ``<UNK>`` index.
        """
        table = self.word2idx
        return [table.get(token, table['<UNK>']) for token in sentence.split()]

    def indices_to_sentence(self, indices):
        """Map *indices* back to words (unknown index -> ``'<UNK>'``), space-joined."""
        lookup = self.idx2word.get
        return ' '.join([lookup(position, '<UNK>') for position in indices])

# Normalize raw text before it is tokenized
def preprocess_text(text):
    """Lowercase *text*, trim outer whitespace, then delete punctuation.

    Every character that is neither a word character nor whitespace is
    removed, so apostrophes vanish too (``"I'm"`` becomes ``"im"``). Trimming
    happens before the removal, so whitespace next to deleted punctuation at
    the ends can remain.
    """
    trimmed = text.lower().strip()
    without_punctuation = re.sub(r'[^\w\s]', '', trimmed)
    return without_punctuation

# Encoder: embedding + stacked LSTM over the source tokens
class Encoder(nn.Module):
    """Embed source token indices and run them through a stacked LSTM.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden-state size.
        n_layers: number of stacked LSTM layers.
        dropout: dropout probability, applied to the embeddings and passed
            to ``nn.LSTM`` as its inter-layer dropout.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        # Registration order (embedding, lstm, dropout) matches the original,
        # so seeded initialization and state_dict key order are unchanged.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            n_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_seq):
        """Encode ``input_seq``.

        The LSTM is ``batch_first``, so ``input_seq`` is presumably a
        ``(batch, seq_len)`` tensor of token indices (TODO confirm at callers).

        Returns:
            ``(outputs, hidden, cell)`` exactly as produced by ``nn.LSTM``.
        """
        token_vectors = self.dropout(self.embedding(input_seq))
        outputs, final_state = self.lstm(token_vectors)
        hidden, cell = final_state
        return outputs, hidden, cell

# Attention: additive scoring of encoder positions against the decoder state
class Attention(nn.Module):
    """Additive attention over encoder outputs.

    The score for position j is ``v . tanh(W [h ; e_j])``. Here ``h`` is the
    last layer of the decoder hidden state and ``e_j`` is the encoder output
    at j. Scores are softmax-normalized over the sequence dimension.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        # Created with torch.rand, then immediately re-drawn uniformly from
        # [-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)].
        self.v = nn.Parameter(torch.rand(hidden_dim))
        bound = 1. / np.sqrt(self.v.size(0))
        self.v.data.uniform_(-bound, bound)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape ``(batch, seq_len)``.

        ``encoder_outputs`` is indexed as ``(batch, seq_len, hidden_dim)``.
        ``hidden[-1]`` must be ``(batch, hidden_dim)``; presumably ``hidden``
        is the LSTM state ``(n_layers, batch, hidden_dim)`` (TODO confirm).
        Each row of the result sums to 1.
        """
        n_batch, n_steps = encoder_outputs.size(0), encoder_outputs.size(1)
        # Broadcast the top-layer state across every encoder position.
        query = hidden[-1].unsqueeze(1).expand(-1, n_steps, -1)
        energy = torch.tanh(self.attn(torch.cat((query, encoder_outputs), dim=2)))
        # (batch, 1, hidden) @ (batch, hidden, seq_len) -> (batch, 1, seq_len)
        scorer = self.v.expand(n_batch, -1).unsqueeze(1)
        scores = scorer.bmm(energy.transpose(1, 2)).squeeze(1)
        return scores.softmax(dim=1)

# Decoder: one attention-conditioned LSTM step per call
class Decoder(nn.Module):
    """Single-step LSTM decoder whose input is the token embedding plus an attention context.

    Args:
        vocab_size: output vocabulary size (exposed as ``self.vocab_size``).
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden size; also the size of the context vector.
        n_layers: number of stacked LSTM layers.
        dropout: dropout probability for embeddings and between LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        # Keep this registration order: it fixes seeded initialization and
        # state_dict layout.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim + hidden_dim,
            hidden_dim,
            n_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_dim)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Decode one step.

        Args:
            input: current token indices. Callers pass shape ``(batch,)``;
                a time axis is added here.
            hidden, cell: previous LSTM state.
            encoder_outputs: encoder outputs that attention is computed over.

        Returns:
            ``(logits, hidden, cell, attention_weights)``, where ``logits`` is
            ``(batch, vocab_size)``.
        """
        embedded = self.dropout(self.embedding(input.unsqueeze(1)))
        attention_weights = self.attention(hidden, encoder_outputs)
        # Weighted sum of encoder outputs: (batch, 1, seq) @ (batch, seq, hid).
        context = attention_weights.unsqueeze(1).bmm(encoder_outputs)
        step_features = torch.cat((embedded, context), dim=2)
        rnn_out, (next_hidden, next_cell) = self.lstm(step_features, (hidden, cell))
        logits = self.fc(rnn_out.squeeze(1))
        return logits, next_hidden, next_cell, attention_weights

# Seq2Seq Model
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder over the target sequence."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Run the decoder for ``target.size(1) - 1`` steps.

        Args:
            source: source token indices, batch-first.
            target: target token indices, batch-first. Column 0 is fed as the
                first decoder input.
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token instead of the model's own argmax.

        Returns:
            Logits of shape ``(batch, target_len, vocab_size)``. Position 0 is
            left as zeros because nothing is predicted for it.
        """
        batch_size = source.size(0)
        target_len = target.size(1)
        vocab_size = self.decoder.vocab_size

        # Allocate directly on the input's device instead of CPU-then-copy.
        outputs = torch.zeros(batch_size, target_len, vocab_size, device=source.device)
        encoder_outputs, hidden, cell = self.encoder(source)

        step_input = target[:, 0]
        for t in range(1, target_len):
            output, hidden, cell, _ = self.decoder(step_input, hidden, cell, encoder_outputs)
            outputs[:, t, :] = output
            teacher_force = random.random() < teacher_forcing_ratio
            step_input = target[:, t] if teacher_force else output.argmax(1)
        return outputs

# Dataset of (question, answer) pairs encoded as index tensors
class ChatDataset(Dataset):
    """Map-style dataset over preprocessed (input, target) sentence pairs.

    Each item is a pair of 1-D index tensors, each wrapped in ``<SOS>`` ...
    ``<EOS>``.
    """

    def __init__(self, pairs, vocab):
        self.pairs = [
            (preprocess_text(question), preprocess_text(answer))
            for question, answer in pairs
        ]
        self.vocab = vocab

    def __len__(self):
        return len(self.pairs)

    def _encode(self, sentence):
        """Return ``[<SOS>, *word indices, <EOS>]`` for *sentence* as a tensor."""
        special = self.vocab.word2idx
        return torch.tensor(
            [special['<SOS>'], *self.vocab.sentence_to_indices(sentence), special['<EOS>']]
        )

    def __getitem__(self, idx):
        question, answer = self.pairs[idx]
        return self._encode(question), self._encode(answer)

# Training function
def train(model, iterator, optimizer, criterion, clip, teacher_forcing_ratio=0.5):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module called as ``model(source, target, teacher_forcing_ratio)``
            that returns logits of shape ``(batch, target_len, vocab)``.
        iterator: iterable of ``(source, target)`` tensor batches. It need not
            define ``__len__``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking flattened logits and flattened targets.
        clip: max gradient norm passed to ``clip_grad_norm_``.
        teacher_forcing_ratio: forwarded to ``model``.

    Returns:
        The average of ``loss.item()`` over batches, or ``0.0`` if the
        iterator yielded no batches.
    """
    # Use the model's own device. The old code read a module-level `device`
    # that only exists when this file runs as a script.
    device = next((p.device for p in model.parameters()), torch.device('cpu'))
    model.train()
    epoch_loss = 0.0
    n_batches = 0

    for source, target in iterator:
        source, target = source.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(source, target, teacher_forcing_ratio)

        # Drop time step 0. Targets start with <SOS>, and the Seq2Seq model in
        # this file leaves that slot as zeros instead of predicting it.
        output_dim = output.shape[-1]
        output = output[:, 1:].reshape(-1, output_dim)
        target = target[:, 1:].reshape(-1)

        loss = criterion(output, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1

    return epoch_loss / n_batches if n_batches else 0.0

# Tokens that must never appear in a user-facing reply.
_SPECIAL_TOKENS = frozenset({'<PAD>', '<SOS>', '<EOS>', '<UNK>'})


def _length_penalty(num_tokens, alpha=0.65):
    """Return ``((num_tokens + 5) / 6) ** alpha``, the divisor used to length-normalize scores."""
    return ((num_tokens + 5) / 6) ** alpha


def _clean_response(words):
    """Drop special-token words and collapse runs of consecutive duplicate words."""
    cleaned = []
    for word in words:
        if word in _SPECIAL_TOKENS:
            continue
        if not cleaned or cleaned[-1] != word:
            cleaned.append(word)
    return cleaned


# Beam search inference with length normalization and response validation
def beam_search(model, sentence, vocab, device, beam_width=5, max_length=50, temperature=1.0):
    """Generate a reply to *sentence* with length-normalized beam search.

    Raw log-probability sums are kept per hypothesis. The length penalty is
    applied only when finished hypotheses of different lengths are compared.
    (Previously the penalty was re-applied to an already-normalized score at
    every step, and finished scores were effectively unnormalized.)

    Args:
        model: module exposing ``encoder`` and ``decoder`` with the call
            signatures used below. It is switched to eval mode.
        sentence: raw user text, normalized with ``preprocess_text``.
        vocab: vocabulary used to encode the input and decode the reply.
        device: device that input tensors are created on.
        beam_width: hypotheses kept per step, and tokens expanded per
            hypothesis (capped at the vocabulary size).
        max_length: maximum number of decoding steps.
        temperature: softmax temperature. Must be positive.

    Returns:
        The cleaned, capitalized reply, or a fallback sentence when nothing
        usable was generated.

    Raises:
        ValueError: if ``temperature`` is not positive or ``beam_width < 1``.
    """
    import warnings

    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if beam_width < 1:
        raise ValueError(f"beam_width must be at least 1, got {beam_width}")

    model.eval()
    sos_idx = vocab.word2idx['<SOS>']
    eos_idx = vocab.word2idx['<EOS>']
    sentence = preprocess_text(sentence)
    indices = [sos_idx] + vocab.sentence_to_indices(sentence) + [eos_idx]
    source = torch.tensor([indices], device=device)

    with torch.no_grad():
        encoder_outputs, hidden, cell = model.encoder(source)

        # Live hypotheses: (raw log-prob sum, tokens incl. <SOS>, hidden, cell).
        # The stored state was produced by feeding the token *before* the last
        # one, so the last token is the next decoder input.
        beams = [(0.0, [sos_idx], hidden, cell)]
        completed = []  # (length-normalized score, tokens incl. <SOS>)

        for _ in range(max_length):
            candidates = []
            for log_prob_sum, seq, h, c in beams:
                step_input = torch.tensor([seq[-1]], device=device)
                output, new_hidden, new_cell, _ = model.decoder(step_input, h, c, encoder_outputs)
                log_probs = torch.log_softmax(output / temperature, dim=1).squeeze(0)
                k = min(beam_width, log_probs.size(0))
                top_log_probs, top_idx = log_probs.topk(k)

                for step_log_prob, idx in zip(top_log_probs.tolist(), top_idx.tolist()):
                    new_sum = log_prob_sum + step_log_prob
                    new_seq = seq + [idx]
                    if idx == eos_idx or len(new_seq) >= max_length - 1:
                        # len(new_seq) - 1 tokens were generated after <SOS>.
                        score = new_sum / _length_penalty(len(new_seq) - 1)
                        completed.append((score, new_seq))
                    else:
                        candidates.append((new_sum, new_seq, new_hidden, new_cell))

            # Every candidate has the same length here, so ranking by the raw
            # sum is identical to ranking by the normalized score.
            beams = heapq.nlargest(beam_width, candidates, key=lambda beam: beam[0])
            if len(completed) >= beam_width or not beams:
                break

        if completed:
            best_seq = max(completed, key=lambda item: item[0])[1]
            final_state = None  # decoding ended (<EOS> or length cap)
        else:
            # Only reachable when max_length <= 0: fall back to the live beam.
            _, best_seq, final_hidden, final_cell = max(beams, key=lambda beam: beam[0])
            final_state = (final_hidden, final_cell)

        words = _clean_response(vocab.indices_to_sentence(best_seq[1:]).split())

        # If decoding stopped without finishing and the reply does not end in
        # sentence punctuation, try to append one greedy word. A vocabulary
        # built from preprocess_text output has no punctuation, so this check
        # is usually true when reached.
        if words and final_state is not None and not words[-1].endswith(('.', '!', '?')):
            step_input = torch.tensor([best_seq[-1]], device=device)
            try:
                output, _, _, _ = model.decoder(step_input, final_state[0], final_state[1], encoder_outputs)
            except RuntimeError as err:
                # Best-effort extension: keep the reply, but surface the failure.
                warnings.warn(f"beam_search: skipping response extension ({err})")
            else:
                # argmax is unaffected by a positive temperature.
                next_word = vocab.idx2word.get(output.argmax(dim=1).item())
                if next_word is not None and next_word not in _SPECIAL_TOKENS and next_word != words[-1]:
                    words.append(next_word)

        response = ' '.join(words).strip()
        return response.capitalize() if response else "I don't know how to respond."

# Main execution
if __name__ == '__main__':
    # Hyperparameters
    INPUT_DIM = 1000  # NOTE(review): unused below; model sizes come from vocab.word_count
    OUTPUT_DIM = 1000  # NOTE(review): unused below; model sizes come from vocab.word_count
    EMB_DIM = 256
    HID_DIM = 512
    N_LAYERS = 2
    DROPOUT = 0.5
    N_EPOCHS = 30
    CLIP = 1
    BEAM_WIDTH = 5  # Increased for better sequence exploration
    TEMPERATURE = 0.8

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create vocabulary and expanded sample data
    vocab = Vocabulary()
    sample_pairs = [
        ("hello how are you", "I'm doing great, thanks for asking!"),
        ("what's your name", "I'm Roopal, nice to meet you!"),
        ("what can you do", "I can chat, answer questions, and help with various tasks!"),
        ("how's the weather", "I don't have weather data, but it's always sunny in my world!"),
        ("tell me a joke", "Why did the scarecrow become a motivational speaker? He was outstanding in his field!"),
        ("what is ai", "AI is like me: a clever system trying to understand and respond to the world!"),
        ("how old are you", "I'm timeless, but I was born in 2025, so pretty young!"),
        ("what's the time", "Time's a mystery, but I'm here for you right now!"),
        ("who made you", "The brilliant folks at xAI brought me to life!"),
        ("can you help me", "Sure thing, what's on your mind?"),
        ("what do you like", "I enjoy chatting with humans and learning new things!"),
        ("where are you from", "I'm from the digital realm, created by xAI!"),
        ("are you human", "Nope, I'm a friendly AI designed to assist you!"),
        ("what's your favorite color", "I like all colors, but if I had to pick, I'd say binary blue!"),
        ("how's it going", "Going great, how about you?"),
        ("tell me about yourself", "I'm Grok, an AI with a sense of humor and a passion for helping humans!"),
        ("what's new", "Just hanging out in the digital world, ready to answer your questions!"),
    ]

    for input_sent, target_sent in sample_pairs:
        for word in preprocess_text(input_sent).split() + preprocess_text(target_sent).split():
            vocab.add_word(word)

    # Create dataset and dataloader
    dataset = ChatDataset(sample_pairs, vocab)
    def collate_fn(batch):
        """Pad a list of (source, target) tensors into two batch-first PAD-padded tensors."""
        sources, targets = zip(*batch)
        sources = pad_sequence(sources, batch_first=True, padding_value=vocab.word2idx['<PAD>'])
        targets = pad_sequence(targets, batch_first=True, padding_value=vocab.word2idx['<PAD>'])
        return sources, targets

    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

    # Initialize model
    encoder = Encoder(vocab.word_count, EMB_DIM, HID_DIM, N_LAYERS, DROPOUT)
    decoder = Decoder(vocab.word_count, EMB_DIM, HID_DIM, N_LAYERS, DROPOUT)
    model = Seq2Seq(encoder, decoder).to(device)

    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx['<PAD>'])

    # Training loop with decaying teacher forcing
    print("Training the chatbot...")
    for epoch in range(N_EPOCHS):
        teacher_forcing_ratio = max(0.3, 0.7 - (0.4 * epoch / N_EPOCHS))
        loss = train(model, dataloader, optimizer, criterion, CLIP, teacher_forcing_ratio)
        print(f'Epoch: {epoch+1:02} | Loss: {loss:.3f} | Teacher Forcing: {teacher_forcing_ratio:.2f}')

    # Interactive chatbot loop with conversation history
    print("\nChatbot is ready! Type your question or 'STOP' to end the conversation.")
    history = deque(maxlen=3)  # Store last 3 exchanges for context
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed (e.g. piped input exhausted) or Ctrl-C: end like STOP.
            print("\nChatbot: Goodbye! Thanks for chatting!")
            break
        if user_input.upper() == "STOP":
            print("Chatbot: Goodbye! Thanks for chatting!")
            break
        if not user_input:
            print("Chatbot: Please say something, or type 'STOP' to end.")
            continue

        # Add context from history.
        # NOTE(review): training inputs are single short questions, so a
        # history-prefixed context is longer than anything seen in training.
        context = ' '.join([f"{q} {a}" for q, a in history]) + ' ' + user_input if history else user_input
        response = beam_search(model, context, vocab, device, beam_width=BEAM_WIDTH, temperature=TEMPERATURE)
        print(f"You: {user_input}")
        print(f"Chatbot: {response}")
        print("Ask another question or type 'STOP' to end.")

        # Update history
        history.append((user_input, response))

Training the chatbot...
Epoch: 01 | Loss: 4.873 | Teacher Forcing: 0.70
Epoch: 02 | Loss: 4.446 | Teacher Forcing: 0.69
Epoch: 03 | Loss: 4.102 | Teacher Forcing: 0.67
Epoch: 04 | Loss: 3.896 | Teacher Forcing: 0.66
Epoch: 05 | Loss: 3.667 | Teacher Forcing: 0.65
Epoch: 06 | Loss: 3.598 | Teacher Forcing: 0.63
Epoch: 07 | Loss: 3.424 | Teacher Forcing: 0.62
Epoch: 08 | Loss: 3.203 | Teacher Forcing: 0.61
Epoch: 09 | Loss: 3.023 | Teacher Forcing: 0.59
Epoch: 10 | Loss: 2.668 | Teacher Forcing: 0.58
Epoch: 11 | Loss: 2.553 | Teacher Forcing: 0.57
Epoch: 12 | Loss: 2.379 | Teacher Forcing: 0.55
Epoch: 13 | Loss: 2.330 | Teacher Forcing: 0.54
Epoch: 14 | Loss: 2.055 | Teacher Forcing: 0.53
Epoch: 15 | Loss: 2.036 | Teacher Forcing: 0.51
Epoch: 16 | Loss: 1.805 | Teacher Forcing: 0.50
Epoch: 17 | Loss: 1.592 | Teacher Forcing: 0.49
Epoch: 18 | Loss: 1.219 | Teacher Forcing: 0.47
Epoch: 19 | Loss: 1.366 | Teacher Forcing: 0.46
Epoch: 20 | Loss: 1.302 | Teacher Forcing: 0.45
Epoch: 21 | Loss