In [None]:
pip install datasets


In [None]:
from datasets import load_dataset
# Load the first 50,000 examples of the WMT16 German-English training split.
train_data = load_dataset("wmt16","de-en", split="train[:50000]")

In [None]:
pip install torch transformers nltk bert-score matplotlib tqdm


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from transformers import EncoderDecoderModel, EncoderDecoderConfig, BertTokenizer
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from bert_score import score as bert_score
from torch.nn.utils.rnn import pad_sequence

# Define dataset class
class TranslationDataset(Dataset):
    """Map-style dataset that tokenizes German/English sentence pairs on access.

    Each row of ``data`` must expose ``row['translation']['de']`` and
    ``row['translation']['en']``. Items are returned as a dict with the keys
    ``'src_tokens'`` (German ids) and ``'tgt_tokens'`` (English ids).
    """

    def __init__(self, data, tokenizer_de, tokenizer_en):
        """Store the raw rows and the source/target tokenizers."""
        self.data = data
        self.tokenizer_de = tokenizer_de
        self.tokenizer_en = tokenizer_en

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize the pair at ``idx`` and return its source/target id tensors."""
        translation = self.data[idx]['translation']
        src_tokens = self.tokenizer_de(
            translation['de'], padding=True, truncation=True, return_tensors='pt'
        )['input_ids']
        tgt_tokens = self.tokenizer_en(
            translation['en'], padding=True, truncation=True, return_tensors='pt'
        )['input_ids']
        # The tokenizers are expected to return a (1, seq_len) tensor for a
        # single sentence. Drop only that leading batch axis: a bare squeeze()
        # would also collapse a length-1 sequence into a 0-d tensor, which
        # breaks length-based sorting and padding in the collate function.
        return {
            'src_tokens': src_tokens.squeeze(0),
            'tgt_tokens': tgt_tokens.squeeze(0),
        }

def collate_fn(batch):
    """Merge per-example token tensors into padded batch tensors.

    Examples are ordered longest-source-first, then sources and targets are
    right-padded to the longest sequence in the batch using the pad ids of the
    module-level German and English tokenizers respectively.
    """
    ordered = sorted(
        batch,
        key=lambda example: example['src_tokens'].shape[0],
        reverse=True,
    )

    sources = []
    targets = []
    for example in ordered:
        sources.append(example['src_tokens'])
        targets.append(example['tgt_tokens'])

    padded_sources = pad_sequence(
        sources, batch_first=True, padding_value=tokenizer_de.pad_token_id
    )
    padded_targets = pad_sequence(
        targets, batch_first=True, padding_value=tokenizer_en.pad_token_id
    )
    return {'src_tokens': padded_sources, 'tgt_tokens': padded_targets}

# Initialize tokenizers (German source, English target)
tokenizer_de = BertTokenizer.from_pretrained('bert-base-german-cased')
tokenizer_en = BertTokenizer.from_pretrained('bert-base-uncased')

# Define model.
# Load the pretrained encoder/decoder exactly once and keep the weights.
# Building EncoderDecoderModel(config) from the configs alone would create a
# randomly initialised network and throw the pretrained weights away.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    'bert-base-german-cased', 'bert-base-uncased'
)

# Model configuration (kept under the name `config` for later cells).
config = model.config
# BERT has no dedicated BOS token, so decoding starts from [CLS]; the pad id
# lets the model/generation utilities recognise target padding.
config.decoder_start_token_id = tokenizer_en.cls_token_id
config.pad_token_id = tokenizer_en.pad_token_id

# Define optimizer and loss
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# Target batches are padded with the English pad id; exclude those positions
# so padding does not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer_en.pad_token_id)

# Load the WMT16 de-en splits: a 50k-example training slice plus the full
# validation and test sets.
train_data, valid_data, test_data = (
    load_dataset("wmt16", "de-en", split=split_name)
    for split_name in ("train[:50000]", "validation", "test")
)

# Wrap every split so that examples are tokenized on access.
train_dataset, valid_dataset, test_dataset = (
    TranslationDataset(raw_split, tokenizer_de, tokenizer_en)
    for raw_split in (train_data, valid_data, test_data)
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(
    train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn
)
valid_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)
    for eval_dataset in (valid_dataset, test_dataset)
)

# Number of mini-batches whose gradients are accumulated per optimizer step.
accumulation_steps = 4

# Training loop
# NOTE: the triple-quoted block below is an earlier train_epoch variant without
# gradient accumulation. It is a bare string literal, so it is never executed;
# the active definition is the train_epoch function that follows it.
'''def train_epoch(model, dataloader, criterion, optimizer):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc="Training"):
        src_tokens = batch['src_tokens'].to(device)
        tgt_tokens = batch['tgt_tokens'].to(device)
        optimizer.zero_grad()
        outputs = model(input_ids=src_tokens, decoder_input_ids=tgt_tokens[:, :-1], labels=tgt_tokens[:, 1:])
        loss = criterion(outputs.logits.view(-1, outputs.logits.size(-1)), tgt_tokens[:, 1:].view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)'''

def train_epoch(model, dataloader, criterion, optimizer, accumulation_steps, device):
    """Run one teacher-forced training epoch with gradient accumulation.

    Args:
        model: Sequence-to-sequence model called with ``input_ids`` and
            ``decoder_input_ids``; its output must expose ``.logits``.
        dataloader: Yields dicts with ``'src_tokens'`` and ``'tgt_tokens'``
            tensors (2-D, batch first).
        criterion: Loss taking ``(N, vocab)`` logits and ``(N,)`` labels.
        optimizer: Optimizer over the model parameters.
        accumulation_steps: Number of batches whose gradients are summed
            before each optimizer step.
        device: Device the batch tensors are moved to.

    Returns:
        The mean (unscaled) per-batch loss, or ``0.0`` for an empty loader.
    """
    model.train()
    num_batches = len(dataloader)
    total_loss = 0.0

    # Clear gradients once up front. Zeroing them at the start of every batch
    # would erase what was accumulated and defeat gradient accumulation.
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(dataloader, desc="Training")):
        src_tokens = batch['src_tokens'].to(device)
        tgt_tokens = batch['tgt_tokens'].to(device)

        # Teacher forcing: the decoder sees tokens [0, T-1) and must predict
        # tokens [1, T).
        decoder_inputs = tgt_tokens[:, :-1]
        labels = tgt_tokens[:, 1:]

        # `labels` is not forwarded to the model: the loss comes from
        # `criterion`, so the model's internal loss would be wasted work.
        # NOTE(review): no attention_mask is passed, so padded source
        # positions are presumably attended to - confirm whether intended.
        outputs = model(input_ids=src_tokens, decoder_input_ids=decoder_inputs)
        logits = outputs.logits
        loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        # Track the true batch loss; only the backward pass uses the scaled
        # value so the summed gradients average over the accumulation group.
        total_loss += loss.item()
        (loss / accumulation_steps).backward()

        step_now = (batch_idx + 1) % accumulation_steps == 0
        last_batch = batch_idx + 1 == num_batches
        # Also step on the final batch so a trailing partial group is applied.
        if step_now or last_batch:
            optimizer.step()
            optimizer.zero_grad()

    return total_loss / max(num_batches, 1)


# Pick the GPU when one is available, otherwise fall back to the CPU,
# and place the model on that device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)

def evaluate(model, dataloader, device):
    """Generate translations for every batch and decode them next to the targets.

    Args:
        model: Model exposing ``generate``.
        dataloader: Yields dicts with ``'src_tokens'`` and ``'tgt_tokens'``
            tensors (2-D, batch first).
        device: Device the source tokens are moved to before generation.

    Returns:
        ``(predictions, targets)``: two lists of decoded strings (special
        tokens removed), aligned by position. Decoding uses the module-level
        English tokenizer.
    """
    model.eval()
    predicted_ids = []
    target_ids = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluation"):
            src_tokens = batch['src_tokens'].to(device)

            # Generate predictions, starting the decoder from [CLS].
            generated = model.generate(
                src_tokens, decoder_start_token_id=tokenizer_en.cls_token_id
            )
            # Convert to plain Python lists right away so no device tensors
            # are kept alive for the whole evaluation run.
            predicted_ids.extend(generated.cpu().tolist())
            # Targets are only decoded, never fed to the model, so they do
            # not need to visit the device. Drop the leading start token.
            target_ids.extend(batch['tgt_tokens'][:, 1:].tolist())

    predictions = [
        tokenizer_en.decode(ids, skip_special_tokens=True) for ids in predicted_ids
    ]
    targets = [
        tokenizer_en.decode(ids, skip_special_tokens=True) for ids in target_ids
    ]
    return predictions, targets




# Training loop
num_epochs = 10
train_losses = []
valid_losses = []


def _validation_loss(model, dataloader, criterion, device):
    """Return the mean teacher-forced loss of ``model`` over ``dataloader``."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            src_tokens = batch['src_tokens'].to(device)
            tgt_tokens = batch['tgt_tokens'].to(device)
            # Same input/label shift as in training.
            outputs = model(input_ids=src_tokens, decoder_input_ids=tgt_tokens[:, :-1])
            logits = outputs.logits
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt_tokens[:, 1:].reshape(-1),
            )
            total += loss.item()
    return total / max(len(dataloader), 1)


for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, accumulation_steps, device)
    # evaluate() returns decoded (predictions, targets), not a loss, so the
    # validation loss is computed with a dedicated teacher-forced pass.
    valid_loss = _validation_loss(model, valid_loader, criterion, device)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

# Save model checkpoint
torch.save(model.state_dict(), 'model_checkpoint.pth')

# Evaluation
test_predictions, test_targets = evaluate(model, test_loader, device)

# METEOR relies on the WordNet corpus; make sure it is available locally.
import nltk
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

# Calculate evaluation metrics
# Whitespace-tokenize once and reuse the tokens for BLEU and METEOR.
reference_tokens = [tgt.split() for tgt in test_targets]
hypothesis_tokens = [pred.split() for pred in test_predictions]

bleu_score = corpus_bleu([[ref] for ref in reference_tokens], hypothesis_tokens)

# meteor_score scores ONE tokenized hypothesis against its list of tokenized
# references, so score each sentence pair and average over the corpus.
meteor_score_avg = sum(
    meteor_score([ref], hyp)
    for ref, hyp in zip(reference_tokens, hypothesis_tokens)
) / max(len(hypothesis_tokens), 1)

bert_score_p, bert_score_r, bert_score_f1 = bert_score(test_predictions, test_targets, lang='en', model_type='bert-base-uncased', nthreads=4)

# Print evaluation metrics
print("BLEU Score:", bleu_score)
print("METEOR Score:", meteor_score_avg)
# .item() prints plain floats instead of tensor reprs.
print("BERTScore Precision:", bert_score_p.mean().item())
print("BERTScore Recall:", bert_score_r.mean().item())
print("BERTScore F1:", bert_score_f1.mean().item())
