In [None]:
%pip install -U transformers datasets peft evaluate tf-keras sacrebleu rouge_score pycocoevalcap nltk

In [None]:
# Import libraries
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from datasets import load_dataset
import evaluate
from peft import LoraConfig, get_peft_model

In [None]:
import pandas as pd
from itertools import groupby
from datasets import Dataset
# Load the data-to-text corpora with `load_dataset`: E2E NLG, WebNLG
# (2017 challenge configuration) and DART.
# NOTE(review): only `dataset_e2e` is registered in `datasets` below; the
# WebNLG and DART objects are not referenced by the other cells shown here.
dataset_e2e = load_dataset('e2e_nlg')
dataset_webnlg = load_dataset('web_nlg', 'webnlg_challenge_2017')
dataset_dart = load_dataset('dart')

# Dictionary to store datasets, keyed by the name the training loop iterates
# over; each key must also exist in `hyperparams`.
datasets = {
    'E2E': dataset_e2e,
}

# Hyperparameters for each dataset (same keys as `datasets`).
#   weight_decay   - passed to the AdamW optimizer.
#   dropout_prob   - passed to LoraConfig as `lora_dropout`.
#   label_smooth   - presumably the label-smoothing factor for the training
#                    loss; confirm against the training loop.
#   length_penalty - passed to `model.generate` during evaluation.
hyperparams = {
    'E2E': {
        'weight_decay': 0.01,
        'dropout_prob': 0.1,
        'label_smooth': 0.1,
        'length_penalty': 0.9
    }
}

In [None]:
# Grouping function for E2E NLG test dataset
def group_e2e_test_data(test_data):
    """Collapse the E2E test split to one row per meaning representation.

    Args:
        test_data: Row-oriented data accepted by ``pd.DataFrame`` that has
            ``meaning_representation`` and ``human_reference`` columns.

    Returns:
        A ``Dataset`` with one row per distinct ``meaning_representation``;
        its ``human_reference`` column holds the list of every reference
        that shared that meaning representation.
    """
    frame = pd.DataFrame(test_data).sort_values(by='meaning_representation')
    references_by_mr = (
        frame.groupby('meaning_representation')['human_reference']
        .apply(list)
        .reset_index()
    )
    return Dataset.from_pandas(references_by_mr)

def preprocess_e2e(examples):
    """Tokenize a batch of E2E training examples for causal-LM fine-tuning.

    Each example becomes the single string
    ``"<meaning representation> | <reference> <eos>"``. The labels are a copy
    of the input ids, so the whole string is used as the target.

    Relies on the module-level ``tokenizer`` being defined before the call.

    Args:
        examples: Batched mapping with ``meaning_representation`` and
            ``human_reference`` sequences of equal length.

    Returns:
        The tokenizer output with an added ``labels`` entry.
    """
    eos = tokenizer.eos_token
    pairs = zip(examples['meaning_representation'], examples['human_reference'])
    texts = [''.join((mr, ' | ', reference, " ", eos)) for mr, reference in pairs]

    encoded = tokenizer(texts, truncation=True)
    # Copy so that later changes to the labels do not alias the inputs.
    encoded["labels"] = list(encoded["input_ids"])
    return encoded

def preprocess_e2e_eval(examples):
    """Tokenize a batch of grouped E2E test examples as generation prompts.

    The prompt is ``"<meaning representation> |"`` with no trailing space.
    The training strings are built as ``mr + ' | ' + reference``, so the
    separator is followed directly by the first reference word. GPT-2's
    byte-level BPE encodes a trailing space as a separate whitespace token,
    which would give the model a context it never saw in training.

    Relies on the module-level ``tokenizer`` being defined before the call.

    Args:
        examples: Batched mapping with ``meaning_representation`` strings and
            ``human_reference`` entries (lists of references per example
            after grouping).

    Returns:
        The tokenizer output plus two pass-through columns:
        ``meaning_representation`` (the exact prompt text that was tokenized)
        and ``human_reference`` (unchanged).
    """
    prompts = [mr + ' |' for mr in examples['meaning_representation']]
    model_inputs = tokenizer(prompts, truncation=True)
    # Keep the prompt text and references next to the token ids so the
    # evaluation collate function can carry them through to scoring.
    model_inputs["meaning_representation"] = prompts
    model_inputs["human_reference"] = examples['human_reference']
    return model_inputs

In [None]:
%pip install nlg-metricverse

In [None]:
# Metric objects used when scoring generated text.
# BLEU, ROUGE and NIST are loaded through `evaluate`; their results are
# obtained with `.compute(predictions=..., references=...)`.
bleu = evaluate.load('bleu')
rouge = evaluate.load('rouge')
nist = evaluate.load('nist_mt')
# CIDEr and METEOR come from nlg-metricverse; these wrappers are called
# directly, e.g. `cider(predictions=..., references=...)`.
from nlgmetricverse import NLGMetricverse, load_metric
cider = NLGMetricverse(metrics=load_metric("cider"))
meteor = NLGMetricverse(metrics=load_metric("meteor"))

In [None]:
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import re

def custom_collate_fn(batch):
    """Collate evaluation examples while keeping their raw text fields.

    Only ``input_ids`` and ``attention_mask`` are handed to the module-level
    ``data_collator``; the ``human_reference`` and ``meaning_representation``
    values are set aside first and re-attached to the collated batch as
    plain Python lists.

    Args:
        batch: List of example dicts containing ``input_ids``,
            ``attention_mask``, ``human_reference`` and
            ``meaning_representation``.

    Returns:
        The collator output with ``human_reference`` and
        ``meaning_representation`` lists added.
    """
    references, prompts, features = [], [], []
    for example in batch:
        references.append(example['human_reference'])
        prompts.append(example['meaning_representation'])
        # Strip the text fields so the collator only sees model inputs.
        features.append({
            'input_ids': example['input_ids'],
            'attention_mask': example['attention_mask'],
        })

    collated = data_collator(features)
    collated['human_reference'] = references
    collated['meaning_representation'] = prompts
    return collated

def generate_predictions(test_dataloader, model, tokenizer, length_penalty):
    """Generate one beam-search prediction per test prompt.

    Args:
        test_dataloader: Iterable of batches holding ``input_ids`` and
            ``attention_mask`` tensors plus a ``human_reference`` list with
            the reference strings for each example.
        model: Model exposing ``eval()``, ``device`` and ``generate()``.
        tokenizer: Tokenizer providing ``decode``, ``pad_token_id`` and
            ``eos_token_id``.
        length_penalty: Forwarded to ``model.generate``.

    Returns:
        ``(predictions, references)``: ``predictions`` is a list of
        normalised strings and ``references`` is a parallel list of lists of
        normalised reference strings (both normalised by ``format_outputs``).
    """
    model.eval()
    predictions = []
    references = []
    progress_bar = tqdm(test_dataloader, desc="Generating predictions")
    for batch in progress_bar:
        # Use only the meaning representation as input
        input_ids = batch['input_ids'].to(model.device)
        attention_mask = batch['attention_mask'].to(model.device)
        human_references = batch['human_reference']
        with torch.no_grad():
            output_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_beams=10,
                length_penalty=length_penalty,
                repetition_penalty=1.0,
                no_repeat_ngram_size=4,
                max_new_tokens=64,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        # A decoder-only model returns prompt + continuation, so slice off
        # the prompt by token count. Matching the decoded text against the
        # prompt string is fragile: if decoding does not reproduce the prompt
        # exactly, the whole prompt ends up inside the prediction.
        prompt_length = input_ids.shape[1]
        continuations = output_ids[:, prompt_length:]
        for generated, example_references in zip(continuations, human_references):
            prediction = tokenizer.decode(generated, skip_special_tokens=True).strip()
            predictions.append(format_outputs(prediction))
            references.append([format_outputs(reference) for reference in example_references])
    return predictions, references

def format_outputs(text):
    """Lower-case ``text`` and space-separate words from punctuation.

    Every non-word character (anything ``\\W`` matches) becomes its own
    token, whitespace runs are collapsed, and the tokens are joined with
    single spaces, e.g. ``"Hello, World!"`` -> ``"hello , world !"``.

    Args:
        text: The string to normalise.

    Returns:
        The normalised, single-space-separated string (``""`` for empty or
        whitespace-only input).
    """
    # Raw string: '\W' in a normal literal is an invalid escape sequence and
    # triggers a SyntaxWarning on recent Python versions.
    pieces = re.split(r'(\W)', text.lower())
    # Joining then splitting on whitespace drops the empty strings and the
    # captured whitespace separators produced by re.split.
    return ' '.join(' '.join(pieces).split())

In [None]:
# Update the training and evaluation loop
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
import torch
import re

for dataset_name, dataset in datasets.items():
    print(f"Training on {dataset_name} dataset")
    params = hyperparams[dataset_name]

    # Pick the dataset-specific preprocessing first so an unsupported name
    # fails immediately instead of surfacing later as a NameError.
    if dataset_name == 'E2E':
        test_data = group_e2e_test_data(dataset['test'])
        preprocess_function = preprocess_e2e
        preprocess_function_eval = preprocess_e2e_eval
    else:
        raise ValueError(f"No preprocessing defined for dataset {dataset_name!r}")

    # Load tokenizer and model. GPT-2 ships without a pad token, so a
    # dedicated one is added and the embedding matrix is resized to match.
    # Left padding keeps each prompt adjacent to the tokens generated after it.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
    tokenizer.padding_side = 'left'
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
    with torch.no_grad():
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id

    # Apply LoRA to the fused attention projection.
    lora_config = LoraConfig(
        r=4,
        lora_alpha=32,
        target_modules=["c_attn"],
        lora_dropout=params['dropout_prob'],
        init_lora_weights="gaussian",
        bias="none"
    )
    model = get_peft_model(model, lora_config)

    train_data = dataset['train']

    train_tokenized = train_data.map(preprocess_function, batched=True, remove_columns=train_data.column_names)
    test_tokenized = test_data.map(preprocess_function_eval, batched=True, remove_columns=test_data.column_names)

    # Data collator and DataLoaders. `data_collator` is also read as a global
    # by `custom_collate_fn`, so the name must not change.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, return_tensors="pt", padding=True)

    train_dataloader = DataLoader(train_tokenized, shuffle=True, batch_size=8, collate_fn=data_collator)
    test_dataloader = DataLoader(test_tokenized, batch_size=8, collate_fn=custom_collate_fn)

    # Optimizer and scheduler. Only parameters that require gradients are
    # handed to the optimizer; the frozen base weights never receive one.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=2e-4, weight_decay=params['weight_decay'])
    num_epochs = 5
    num_training_steps = num_epochs * len(train_dataloader)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=num_training_steps)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [{dataset_name}]")
        for batch in progress_bar:
            inputs = {key: val.to(device) for key, val in batch.items()}
            labels = inputs.pop("labels")

            # With left padding the default position ids (0..seq_len-1) would
            # shift every real token by the number of pad tokens, while
            # `generate()` derives positions from the attention mask. Build
            # them the same way here so training and generation agree.
            attention_mask = inputs["attention_mask"]
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids = position_ids.masked_fill(attention_mask == 0, 1)

            logits = model(**inputs, position_ids=position_ids).logits

            # Next-token loss computed here so the configured label smoothing
            # is actually applied. Positions labelled -100 (the collator's
            # label padding) are ignored.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
                label_smoothing=params['label_smooth'],
            )

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress_bar.set_postfix(loss=loss.item())

        # Evaluation on the grouped test set after every epoch.
        predictions, references = generate_predictions(test_dataloader, model, tokenizer, params['length_penalty'])
        bleu_score = bleu.compute(predictions=predictions, references=references)
        meteor_score = meteor(predictions=predictions, references=references)
        rouge_score = rouge.compute(predictions=predictions, references=references)
        nist_score = nist.compute(predictions=predictions, references=references)
        cider_score = cider(predictions=predictions, references=references)

        test_metrics = {
            'bleu': bleu_score['bleu'],
            'meteor': meteor_score['meteor']['score'],
            'rouge_l': rouge_score['rougeL'],
            'nist': nist_score['nist_mt'],
            'cider': cider_score['cider']['score']
        }

        print(test_metrics)

    