In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import random
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Seed every random number generator the script relies on so runs are repeatable
SEED = 1234
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)

# Run on the GPU when one is visible, otherwise stay on the CPU
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f"Using device: {device}")

# Load the English-French configuration of the "iwslt2017" dataset.
# trust_remote_code=True is passed so `datasets` may run the dataset's loading script.
dataset = load_dataset("iwslt2017", "iwslt2017-en-fr", trust_remote_code=True)
print(f"Dataset loaded with {len(dataset['train'])} training examples")

# Keep only a prefix of each split: at most 10K training and 1K validation pairs
MAX_EXAMPLES = 10000
train_split = dataset['train']
val_split = dataset['validation']
train_data = train_split.select(range(min(MAX_EXAMPLES, len(train_split))))
val_data = val_split.select(range(min(1000, len(val_split))))

# Simple tokenization function (in a real scenario, use proper tokenizers)
def tokenize(text, lang):
    """Lowercase *text* and split it on whitespace.

    ``lang`` is accepted so call sites can name the language, but it is not
    used by this simple tokenizer.
    """
    lowered = text.lower()
    return lowered.split()

# Create vocabularies
class Vocab:
    """Bidirectional word/index mapping with a cap on the number of entries.

    Indices 0-3 are reserved for ``<pad>``, ``<sos>``, ``<eos>`` and ``<unk>``.
    Other words receive indices in first-seen order; once ``max_size`` entries
    exist, further unseen words are silently skipped.

    Args:
        name: Free-form label for the vocabulary (stored only).
        max_size: Maximum number of entries, special tokens included.
            Defaults to 10000, the limit that used to be hard-coded.
    """

    # Defined once so both lookup tables are guaranteed to agree
    SPECIAL_TOKENS = ("<pad>", "<sos>", "<eos>", "<unk>")

    def __init__(self, name, max_size=10000):
        self.name = name
        self.word2idx = {token: idx for idx, token in enumerate(self.SPECIAL_TOKENS)}
        self.idx2word = dict(enumerate(self.SPECIAL_TOKENS))
        self.count = len(self.SPECIAL_TOKENS)
        self.max_size = max_size

    def add_sentence(self, sentence):
        """Register each unseen word of *sentence* (an iterable of tokens).

        Words already known, and any word seen after the vocabulary is full,
        are ignored.
        """
        for word in sentence:
            if word in self.word2idx or self.count >= self.max_size:
                continue
            new_index = self.count
            self.word2idx[word] = new_index
            self.idx2word[new_index] = word
            self.count = new_index + 1

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return self.count

# One vocabulary per language, filled from the training pairs only
src_vocab = Vocab("English")
tgt_vocab = Vocab("French")

print("Building vocabularies...")
for item in tqdm(train_data):
    pair = item['translation']
    src_vocab.add_sentence(tokenize(pair['en'], 'en'))
    tgt_vocab.add_sentence(tokenize(pair['fr'], 'fr'))

print(f"Source vocabulary size: {len(src_vocab)}")
print(f"Target vocabulary size: {len(tgt_vocab)}")

# Dataset class
class TranslationDataset(Dataset):
    """Map-style dataset yielding (source ids, target ids) tensors for one pair.

    Each item of ``data`` is expected to expose ``item['translation']['en']``
    and ``item['translation']['fr']``. Sentences are tokenized, truncated to
    50 tokens, mapped through the vocabularies (unknown words -> index 3) and
    wrapped in ``<sos>`` (1) / ``<eos>`` (2).
    """

    def __init__(self, data, src_vocab, tgt_vocab):
        self.data = data
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pair = self.data[idx]['translation']
        sides = (
            (pair['en'], 'en', self.src_vocab),
            (pair['fr'], 'fr', self.tgt_vocab),
        )

        encoded = []
        for text, lang, vocab in sides:
            # Cap the sentence at 50 tokens before adding the boundary markers
            kept_tokens = tokenize(text, lang)[:50]
            ids = [1]
            ids.extend(vocab.word2idx.get(token, 3) for token in kept_tokens)
            ids.append(2)
            encoded.append(torch.LongTensor(ids))

        src_tensor, tgt_tensor = encoded
        return src_tensor, tgt_tensor

# Custom collate function to handle padding
def collate_fn(batch):
    """Stack a list of (src, tgt) tensor pairs into two zero-padded batches.

    Returns a ``(src_batch, tgt_batch)`` tuple, each laid out batch-first and
    right-padded with index 0 up to the longest sequence on that side.
    """
    sources = [src for src, _ in batch]
    targets = [tgt for _, tgt in batch]

    pad = nn.utils.rnn.pad_sequence
    padded_sources = pad(sources, batch_first=True, padding_value=0)
    padded_targets = pad(targets, batch_first=True, padding_value=0)
    return padded_sources, padded_targets

# Wrap both splits and batch them; only the training loader shuffles
BATCH_SIZE = 64
train_dataset = TranslationDataset(train_data, src_vocab, tgt_vocab)
val_dataset = TranslationDataset(val_data, src_vocab, tgt_vocab)

loader_options = {'batch_size': BATCH_SIZE, 'collate_fn': collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)

# Basic Encoder (shared for both models)
class Encoder(nn.Module):
    """Embed source token ids and run them through a (possibly stacked) GRU.

    Args:
        input_dim: Size of the source vocabulary.
        emb_dim: Embedding width.
        hid_dim: GRU hidden size.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout probability for the embeddings and, when
            ``n_layers > 1``, between GRU layers.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        # Inter-layer dropout is only requested when there is more than one layer
        inter_layer_dropout = dropout if n_layers > 1 else 0
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode ``src`` ([batch_size, src_len]) and return ``(outputs, hidden)``."""
        token_vectors = self.dropout(self.embedding(src))
        step_outputs, final_hidden = self.rnn(token_vectors)
        return step_outputs, final_hidden

# Decoder without attention
class DecoderNoAttention(nn.Module):
    """Single-step GRU decoder that conditions only on its hidden state.

    Args:
        output_dim: Size of the target vocabulary.
        emb_dim: Embedding width.
        hid_dim: GRU hidden size.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout probability for the embeddings and, when
            ``n_layers > 1``, between GRU layers.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        inter_layer_dropout = dropout if n_layers > 1 else 0
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden):
        """Decode one step.

        ``input`` holds one token id per batch element ([batch_size]);
        returns the vocabulary logits for the next token and the new hidden state.
        """
        # The GRU expects a sequence axis, so treat the token as a length-1 sequence
        step_tokens = input.unsqueeze(1)
        step_vectors = self.dropout(self.embedding(step_tokens))
        rnn_out, next_hidden = self.rnn(step_vectors, hidden)
        logits = self.fc_out(rnn_out.squeeze(1))
        return logits, next_hidden

# Attention mechanism
class Attention(nn.Module):
    """Additive attention: score each encoder step against the decoder state.

    The decoder state is concatenated with every encoder output, projected
    through ``tanh(Linear)`` and reduced to one scalar score per source step.
    """

    def __init__(self, hid_dim):
        super().__init__()
        self.attn = nn.Linear(hid_dim * 2, hid_dim)
        self.v = nn.Linear(hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        """Return ``(context, attention_weights)``.

        Args:
            hidden: Decoder state, [batch_size, hid_dim].
            encoder_outputs: [batch_size, src_len, hid_dim].
            mask: Optional [batch_size, src_len] tensor; positions equal to 0
                are excluded from the softmax.
        """
        n_steps = encoder_outputs.shape[1]

        # Copy the decoder state along the source axis so it pairs with each step
        tiled_hidden = hidden.unsqueeze(1).repeat(1, n_steps, 1)
        paired = torch.cat((tiled_hidden, encoder_outputs), dim=2)

        scores = self.v(torch.tanh(self.attn(paired))).squeeze(2)
        if mask is not None:
            # A very negative score drives the softmax weight of masked steps to ~0
            scores = scores.masked_fill(mask == 0, -1e10)

        weights = F.softmax(scores, dim=1)

        # Weighted sum of encoder outputs: [B, 1, S] x [B, S, H] -> [B, H]
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, weights

# Decoder with attention
class DecoderWithAttention(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs.

    The attention context is fed to the GRU together with the token
    embedding, and again to the output projection together with the GRU
    output and the embedding.

    Args:
        output_dim: Size of the target vocabulary.
        emb_dim: Embedding width.
        hid_dim: GRU hidden size (also the width of the context vector).
        n_layers: Number of stacked GRU layers.
        attention: Callable ``(hidden, encoder_outputs, mask)`` returning
            ``(context, weights)``.
        dropout: Dropout probability for the embeddings and, when
            ``n_layers > 1``, between GRU layers.
    """

    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, attention, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        inter_layer_dropout = dropout if n_layers > 1 else 0
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(
            input_size=emb_dim + hid_dim,
            hidden_size=hid_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.fc_out = nn.Linear(emb_dim + hid_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs, mask=None):
        """Decode one step and return ``(prediction, hidden, attention)``."""
        # Length-1 sequence axis for the GRU
        token_vectors = self.dropout(self.embedding(input.unsqueeze(1)))

        # Query the attention with the top GRU layer's current state
        context, attention = self.attention(hidden[-1], encoder_outputs, mask)

        gru_in = torch.cat((token_vectors, context.unsqueeze(1)), dim=2)
        gru_out, hidden = self.rnn(gru_in, hidden)

        # Project [GRU output ; context ; embedding] onto the vocabulary
        features = torch.cat(
            (gru_out.squeeze(1), context, token_vectors.squeeze(1)), dim=1
        )
        prediction = self.fc_out(features)

        return prediction, hidden, attention

# Seq2Seq model without attention
class Seq2SeqNoAttention(nn.Module):
    """Encoder-decoder wrapper whose decoder sees only the encoder's final state."""

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """Decode ``tgt.shape[1] - 1`` steps and return the per-step logits.

        Returns a [batch_size, tgt_len, vocab] tensor; position 0 is never
        written and stays zero. At every step the next decoder input is the
        gold token with probability ``teacher_forcing_ratio``, otherwise the
        decoder's own argmax.
        """
        n_sentences, n_steps = src.shape[0], tgt.shape[1]
        vocab_size = self.decoder.output_dim

        outputs = torch.zeros(n_sentences, n_steps, vocab_size).to(self.device)

        # Only the final hidden state is used; the per-step outputs are discarded
        _, hidden = self.encoder(src)

        # Decoding starts from the first target column (<sos>)
        input = tgt[:, 0]
        for step in range(1, n_steps):
            logits, hidden = self.decoder(input, hidden)
            outputs[:, step] = logits

            # One random draw per step decides between gold and predicted token
            use_gold = random.random() < teacher_forcing_ratio
            predicted = logits.argmax(1)
            input = tgt[:, step] if use_gold else predicted

        return outputs

# Seq2Seq model with attention
class Seq2SeqWithAttention(nn.Module):
    """Encoder-decoder wrapper whose decoder attends over all encoder outputs."""

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def create_mask(self, src):
        """Return a boolean mask that is True wherever ``src`` is not padding (0)."""
        not_padding = src != 0
        return not_padding.to(self.device)

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """Decode ``tgt.shape[1] - 1`` steps; return ``(outputs, attentions)``.

        ``outputs`` is [batch_size, tgt_len, vocab] and ``attentions`` is
        [batch_size, tgt_len, src_len]; position 0 of both is never written
        and stays zero. At every step the next decoder input is the gold
        token with probability ``teacher_forcing_ratio``, otherwise the
        decoder's own argmax.
        """
        n_sentences, n_src_steps = src.shape[0], src.shape[1]
        n_tgt_steps = tgt.shape[1]
        vocab_size = self.decoder.output_dim

        outputs = torch.zeros(n_sentences, n_tgt_steps, vocab_size).to(self.device)
        attentions = torch.zeros(n_sentences, n_tgt_steps, n_src_steps).to(self.device)

        # Padding positions must not receive attention
        mask = self.create_mask(src)

        encoder_outputs, hidden = self.encoder(src)

        # Decoding starts from the first target column (<sos>)
        input = tgt[:, 0]
        for step in range(1, n_tgt_steps):
            logits, hidden, weights = self.decoder(input, hidden, encoder_outputs, mask)
            outputs[:, step] = logits
            attentions[:, step] = weights

            # One random draw per step decides between gold and predicted token
            use_gold = random.random() < teacher_forcing_ratio
            predicted = logits.argmax(1)
            input = tgt[:, step] if use_gold else predicted

        return outputs, attentions

# Hyperparameters shared by both models.
# The vocabulary sizes fix the width of the embedding tables and of the output layer.
INPUT_DIM = len(src_vocab)
OUTPUT_DIM = len(tgt_vocab)
EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
DROPOUT = 0.5

# Baseline model: encoder + plain GRU decoder, moved to the selected device.
# NOTE(review): weight initialisation presumably draws from the torch RNG seeded
# earlier, so reordering these constructors would change the initial weights.
encoder = Encoder(INPUT_DIM, EMB_DIM, HID_DIM, N_LAYERS, DROPOUT)
decoder_no_attn = DecoderNoAttention(OUTPUT_DIM, EMB_DIM, HID_DIM, N_LAYERS, DROPOUT)
model_no_attn = Seq2SeqNoAttention(encoder, decoder_no_attn, device).to(device)

# Attention model: it gets its own encoder instance, so no weights are shared with the baseline
encoder_attn = Encoder(INPUT_DIM, EMB_DIM, HID_DIM, N_LAYERS, DROPOUT)
attention = Attention(HID_DIM)
decoder_attn = DecoderWithAttention(OUTPUT_DIM, EMB_DIM, HID_DIM, N_LAYERS, attention, DROPOUT)
model_attn = Seq2SeqWithAttention(encoder_attn, decoder_attn, device).to(device)

# One Adam optimizer per model, both with the library's default hyperparameters
optimizer_no_attn = optim.Adam(model_no_attn.parameters())
optimizer_attn = optim.Adam(model_attn.parameters())

# Shared loss; ignore_index=0 skips targets equal to the <pad> index
criterion = nn.CrossEntropyLoss(ignore_index=0)

# Training function
def train(model, data_loader, optimizer, criterion, clip=1.0):
    """Run one optimisation pass over *data_loader* and return the mean batch loss.

    The model is called with its default teacher-forcing ratio. Gradients are
    clipped to a total norm of ``clip`` before every optimizer step.
    """
    model.train()
    # The attention model returns (outputs, attentions); the baseline returns outputs only
    returns_logits_only = isinstance(model, Seq2SeqNoAttention)
    running_loss = 0

    for src, tgt in tqdm(data_loader):
        src = src.to(device)
        tgt = tgt.to(device)

        optimizer.zero_grad()

        result = model(src, tgt)
        if returns_logits_only:
            logits = result
        else:
            logits, _ = result

        # Drop position 0 (the <sos> slot) and flatten batch x time for the loss
        vocab_size = logits.shape[-1]
        flat_logits = logits[:, 1:].reshape(-1, vocab_size)
        flat_targets = tgt[:, 1:].reshape(-1)

        batch_loss = criterion(flat_logits, flat_targets)
        batch_loss.backward()

        # Guard against exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        running_loss += batch_loss.item()

    return running_loss / len(data_loader)

# Evaluation function
def evaluate(model, data_loader, criterion):
    """Return the mean batch loss of *model* over *data_loader* without updating it.

    The model is put in eval mode, gradients are disabled and the
    teacher-forcing ratio is 0, so the decoder always consumes its own
    predictions.
    """
    model.eval()
    # The attention model returns (outputs, attentions); the baseline returns outputs only
    returns_logits_only = isinstance(model, Seq2SeqNoAttention)
    running_loss = 0

    with torch.no_grad():
        for src, tgt in data_loader:
            src = src.to(device)
            tgt = tgt.to(device)

            result = model(src, tgt, 0)
            if returns_logits_only:
                logits = result
            else:
                logits, _ = result

            # Drop position 0 (the <sos> slot) and flatten batch x time for the loss
            vocab_size = logits.shape[-1]
            flat_logits = logits[:, 1:].reshape(-1, vocab_size)
            flat_targets = tgt[:, 1:].reshape(-1)

            running_loss += criterion(flat_logits, flat_targets).item()

    return running_loss / len(data_loader)

# Translation function
# def translate_sentence(model, sentence, src_vocab, tgt_vocab, device, max_len=50):
#     model.eval()
    
#     # Tokenize if string
#     if isinstance(sentence, str):
#         tokens = tokenize(sentence, 'en')
#     else:
#         tokens = sentence
    
#     # Convert to indices
#     src_indices = [src_vocab.word2idx.get(token, 3) for token in tokens]
#     src_indices = [1] + src_indices + [2]  # Add <sos> and <eos>
#     src_tensor = torch.LongTensor(src_indices).unsqueeze(0).to(device)
    
#     with torch.no_grad():
#         if isinstance(model, Seq2SeqNoAttention):
#             encoder_outputs, hidden = model.encoder(src_tensor)
            
#             # Start with <sos>
#             tgt_idx = [1]
#             for i in range(max_len):
#                 tgt_tensor = torch.LongTensor([tgt_idx[-1]]).to(device)
#                 output, hidden = model.decoder(tgt_tensor, hidden)
#                 pred_token = output.argmax(1).item()
#                 tgt_idx.append(pred_token)
                
#                 if pred_token == 2:  # <eos>
#                     break
                    
#             translation = [tgt_vocab.idx2word[i] for i in tgt_idx[1:]]  # Skip <sos>
#             return translation, None
            
#         else:  # With attention
#             encoder_outputs, hidden = model.encoder(src_tensor)
#             mask = model.create_mask(src_tensor)
            
#             # Start with <sos>
#             tgt_idx = [1]
#             attentions = []
            
#             for i in range(max_len):
#                 tgt_tensor = torch.LongTensor([tgt_idx[-1]]).to(device)
#                 output, hidden, attention = model.decoder(tgt_tensor, hidden, encoder_outputs, mask)
                
#                 attentions.append(attention.cpu().numpy())
#                 pred_token = output.argmax(1).item()
#                 tgt_idx.append(pred_token)
                
#                 if pred_token == 2:  # <eos>
#                     break
                    
#             translation = [tgt_vocab.idx2word[i] for i in tgt_idx[1:]]  # Skip <sos>
#             return translation, np.array(attentions).squeeze(0)

def translate_sentence(model, sentence, src_vocab, tgt_vocab, device, max_len=50):
    """Greedily translate one sentence.

    Args:
        model: A ``Seq2SeqNoAttention`` instance, or a model exposing
            ``encoder``, ``decoder`` and ``create_mask`` like the attention model.
        sentence: Raw string (tokenized here) or an already tokenized list.
        src_vocab, tgt_vocab: Vocabularies used for encoding / decoding.
        device: Device the input tensors are moved to.
        max_len: Maximum number of decoding steps.

    Returns:
        ``(tokens, attention)``: the decoded target tokens (including
        ``<eos>`` when it was produced) and, for the attention model, a NumPy
        array stacking the per-step attention weights; ``None`` otherwise.
    """
    model.eval()

    tokens = tokenize(sentence, 'en') if isinstance(sentence, str) else sentence

    # Encode the source and wrap it in <sos> (1) / <eos> (2); unknown words map to 3
    src_indices = [1]
    src_indices.extend(src_vocab.word2idx.get(token, 3) for token in tokens)
    src_indices.append(2)
    src_tensor = torch.LongTensor(src_indices).unsqueeze(0).to(device)

    uses_attention = not isinstance(model, Seq2SeqNoAttention)

    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src_tensor)
        mask = model.create_mask(src_tensor) if uses_attention else None

        decoded = [1]  # decoding starts from <sos>
        step_weights = []

        for _ in range(max_len):
            previous = torch.LongTensor([decoded[-1]]).to(device)
            if uses_attention:
                logits, hidden, weights = model.decoder(previous, hidden, encoder_outputs, mask)
                step_weights.append(weights.cpu().numpy())
            else:
                logits, hidden = model.decoder(previous, hidden)

            next_token = logits.argmax(1).item()
            decoded.append(next_token)

            if next_token == 2:  # <eos>
                break

        # Skip the leading <sos> when mapping ids back to words
        translation = [tgt_vocab.idx2word[i] for i in decoded[1:]]

    if not uses_attention:
        return translation, None
    return translation, np.array(step_weights)


# Attention visualization
# def display_attention(sentence, translation, attention):
#     fig = plt.figure(figsize=(10, 8))
#     ax = fig.add_subplot(1, 1, 1)
    
#     # Remove <eos> token if present
#     if translation[-1] == '<eos>':
#         translation = translation[:-1]
    
#     # Get attention from first layer
#     attention = attention[:len(translation), :len(sentence)]
    
#     # Plot heatmap
#     cax = ax.matshow(attention, cmap='viridis')
#     fig.colorbar(cax)
    
#     # Set axes
#     ax.set_xticklabels([''] + sentence, rotation=90)
#     ax.set_yticklabels([''] + translation)
    
#     # Show label at every tick
#     ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
#     ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    
#     plt.tight_layout()
#     return fig

# def display_attention(sentence, translation, attention):
#     fig = plt.figure(figsize=(10, 8))
#     ax = fig.add_subplot(1, 1, 1)
    
#     # Remove <eos> token if present
#     if translation[-1] == '<eos>':
#         translation = translation[:-1]
    
#     # Get attention and reshape if needed
#     if len(attention.shape) > 2:
#         attention = attention.reshape(len(translation), -1)
#     attention = attention[:len(translation), :len(sentence)]
    
#     # Plot heatmap
#     cax = ax.matshow(attention, cmap='viridis')
#     fig.colorbar(cax)
    
#     # Set axes
#     ax.set_xticklabels([''] + sentence, rotation=90)
#     ax.set_yticklabels([''] + translation)
    
#     # Show label at every tick
#     ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
#     ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    
#     plt.tight_layout()
#     return fig

def display_attention(sentence, translation, attention):
    """Plot an attention heatmap (rows: target tokens, columns: source tokens).

    Args:
        sentence: List of source tokens used as column labels.
        translation: List of target tokens used as row labels; a trailing
            ``'<eos>'`` is dropped. May be empty.
        attention: NumPy array of attention weights. 1-D input is treated as
            a single row, 2-D input is used as is, 3-D input is reduced by
            taking index 0 of the middle axis, and higher ranks are flattened
            to 2-D.

    Returns:
        The Matplotlib figure containing the heatmap.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(1, 1, 1)

    # Remove <eos> token if present (guarded so an empty translation does not raise)
    if translation and translation[-1] == '<eos>':
        translation = translation[:-1]

    # Debug information
    print(f"Attention shape: {attention.shape}")
    print(f"Translation length: {len(translation)}")
    print(f"Sentence length: {len(sentence)}")

    # Reduce the attention array to 2-D, whatever rank it arrives with
    rank = len(attention.shape)
    if rank == 1:
        attention_2d = attention.reshape(1, -1)
    elif rank == 2:
        attention_2d = attention
    elif rank == 3:
        # Assumes the batch axis is the middle dimension
        if attention.shape[1] > 0:
            attention_2d = attention[:, 0, :]
        else:
            attention_2d = attention.reshape(attention.shape[0], -1)
    else:
        attention_2d = attention.reshape(attention.shape[0], -1)

    # Only plot the part for which labels exist
    n_rows = min(len(translation), attention_2d.shape[0])
    n_cols = min(len(sentence), attention_2d.shape[1])
    attention_plot = attention_2d[:n_rows, :n_cols]

    # Plot heatmap
    cax = ax.matshow(attention_plot, cmap='viridis')
    fig.colorbar(cax)

    # Pin one tick per cell BEFORE assigning labels: fixed labels need fixed
    # tick positions, otherwise they are matched to whatever ticks the
    # locator happens to produce and can drift off their cells.
    ax.set_xticks(list(range(n_cols)))
    ax.set_xticklabels(sentence[:n_cols], rotation=90)
    ax.set_yticks(list(range(n_rows)))
    ax.set_yticklabels(translation[:n_rows])

    plt.tight_layout()
    return fig


# Train both models for the same number of epochs, recording per-epoch losses
N_EPOCHS = 25
train_losses_no_attn, val_losses_no_attn = [], []
train_losses_attn, val_losses_attn = [], []

# (banner, model, optimizer, train history, validation history), baseline first
training_runs = (
    ("Training model without attention...", model_no_attn, optimizer_no_attn,
     train_losses_no_attn, val_losses_no_attn),
    ("Training model with attention...", model_attn, optimizer_attn,
     train_losses_attn, val_losses_attn),
)

for banner, run_model, run_optimizer, train_history, val_history in training_runs:
    print(banner)
    for epoch in range(N_EPOCHS):
        train_loss = train(run_model, train_loader, run_optimizer, criterion)
        val_loss = evaluate(run_model, val_loader, criterion)
        train_history.append(train_loss)
        val_history.append(val_loss)
        print(f'Epoch {epoch+1}/{N_EPOCHS}, Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}')

# Draw the four loss curves on one figure, save it, then show it
plt.figure(figsize=(12, 6))
loss_curves = (
    (train_losses_no_attn, 'Train Loss (No Attention)'),
    (val_losses_no_attn, 'Valid Loss (No Attention)'),
    (train_losses_attn, 'Train Loss (With Attention)'),
    (val_losses_attn, 'Valid Loss (With Attention)'),
)
for curve_values, curve_label in loss_curves:
    plt.plot(curve_values, label=curve_label)
plt.title('Training and Validation Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('loss_comparison.png')
plt.show()

# Translate a few hand-written sentences with both models and plot the attention
example_sentences = [
    "I love learning new languages.",
    "The cat sits on the mat.",
    "She plays the piano very well."
]

for sentence in example_sentences:
    print(f"\nOriginal: {sentence}")

    # Baseline model
    translation_no_attn, _ = translate_sentence(model_no_attn, sentence, src_vocab, tgt_vocab, device)
    words_no_attn = [t for t in translation_no_attn if t != '<eos>']
    print(f"No Attention: {' '.join(words_no_attn)}")

    # Attention model
    translation_attn, attention_weights = translate_sentence(model_attn, sentence, src_vocab, tgt_vocab, device)
    words_attn = [t for t in translation_attn if t != '<eos>']
    print(f"With Attention: {' '.join(words_attn)}")

    if attention_weights is None:
        continue

    # Label the source axis with the same <sos>/<eos> markers the encoder saw
    source_tokens = ['<sos>'] + tokenize(sentence, 'en') + ['<eos>']
    fig = display_attention(source_tokens, translation_attn, attention_weights)
    plt.savefig(f'attention_{sentence[:15].replace(" ", "_")}.png')
    plt.show()
