In [None]:
# Install or upgrade the third-party packages used by the cells below.
# NOTE(review): versions are not pinned, so results may vary between runs of this cell.
%pip install -U transformers datasets peft evaluate tf-keras sacrebleu rouge_score pycocoevalcap

In [None]:
# Import libraries
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, DataCollatorForLanguageModeling
from datasets import load_dataset
import evaluate
from peft import LoraConfig, get_peft_model

In [None]:
import pandas as pd
from itertools import groupby
from datasets import Dataset
# Load the E2E NLG Challenge dataset and keep handles on the train/validation splits.
dataset = load_dataset('e2e_nlg')
train_data, validation_data = dataset['train'], dataset['validation']

# Collapse the test split so each distinct meaning representation appears once,
# paired with the list of every human reference written for it. References keep
# their original relative order; meaning representations are emitted in sorted order.
references_by_mr = {}
for row in dataset['test']:
    references_by_mr.setdefault(row['meaning_representation'], []).append(row['human_reference'])

test_records = [
    {
        "meaning_representation": mr,
        "human_reference": references_by_mr[mr],
    }
    for mr in sorted(references_by_mr)
]
test_data = Dataset.from_pandas(pd.DataFrame(test_records))

# Tokenizer setup: GPT-2 ships without a padding token, so the end-of-sequence
# token is reused for padding, and batches are padded on the left.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token

def preprocess_function(examples):
    """Tokenize a batch of training examples.

    Each example is rendered as ``Input: <meaning representation> Output:
    <human reference> <eos>`` and passed through the module-level tokenizer.

    Args:
        examples: Batch mapping with ``meaning_representation`` and
            ``human_reference`` columns (one entry per example).

    Returns:
        The tokenizer output for the rendered texts, truncated to 512 tokens.
    """
    eos = tokenizer.eos_token
    texts = []
    for mr, reference in zip(examples['meaning_representation'], examples['human_reference']):
        texts.append("Input: " + mr + ' Output: ' + reference + " " + eos)
    return tokenizer(texts, truncation=True, max_length=512)

def preprocess_function_eval(examples):
    """Tokenize a batch of evaluation prompts (no reference text appended).

    Each prompt is rendered as ``Input: <meaning representation> Output:``.
    The prompt deliberately has no trailing space: GPT-2's byte-level BPE
    attaches a leading space to the following word, so the training texts
    built by ``preprocess_function`` encode ``" Output: <word>"`` as
    ``Output``, ``:``, ``Ġ<word>``. A prompt ending in a space would instead
    end in a stand-alone space token that never occurs at that position in
    training, leaving the model to continue from an unseen context.

    Args:
        examples: Batch mapping with ``meaning_representation`` and
            ``human_reference`` columns.

    Returns:
        The tokenizer output for the prompts, with two extra columns:
        ``meaning_representation`` (overwritten with the full prompt text, which
        ``generate_predictions`` uses to strip the prompt from decoded output)
        and ``human_reference`` (passed through unchanged).
    """
    prompts = ["Input: " + mr + " Output:" for mr in examples['meaning_representation']]
    model_inputs = tokenizer(prompts, truncation=True, max_length=512)
    model_inputs["meaning_representation"] = prompts
    model_inputs["human_reference"] = examples["human_reference"]
    return model_inputs

# Tokenize the training and validation splits, then drop their original text
# columns so that only the tokenizer outputs remain.
train_tokenized = (
    train_data
    .map(preprocess_function, batched=True)
    .remove_columns(train_data.column_names)
)
validation_tokenized = (
    validation_data
    .map(preprocess_function, batched=True)
    .remove_columns(validation_data.column_names)
)

# The test split keeps its text columns: the eval preprocessing stores the
# prompt and the human references alongside the token ids.
test_tokenized = test_data.map(preprocess_function_eval, batched=True)

In [None]:
# Load GPT-2 Medium and wrap it with LoRA adapters on the attention input projection.
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')

# Prepare LoRA configuration
lora_config = LoraConfig(
    r=4,
    lora_alpha=32,
    # GPT-2 fuses the query, key and value projections into the single `c_attn`
    # layer, so targeting it adapts Wq, Wk and Wv together (not only Wq and Wv).
    target_modules=["c_attn"],
    lora_dropout=0.1,
    # `c_attn` is a Conv1D layer whose weight is stored as (fan_in, fan_out).
    # Stating this explicitly avoids relying on PEFT's warn-and-correct fallback.
    fan_in_fan_out=True,
    init_lora_weights="gaussian"
)

model = get_peft_model(model, lora_config)

In [None]:
# Data collator
def data_collator(features):
    """Pad a list of tokenized examples and attach causal-LM labels.

    The tokenizer reuses the EOS token as its padding token. A collator that
    masks labels by *token id* (as ``DataCollatorForLanguageModeling`` does)
    would therefore also ignore the genuine EOS appended to every training
    text, and the model would never be trained to stop. Labels are instead
    masked by *attention mask*, so only padding positions are excluded from
    the loss.

    Args:
        features: List of dicts with ``input_ids`` and ``attention_mask`` lists.

    Returns:
        A batch with padded ``input_ids`` / ``attention_mask`` tensors and a
        ``labels`` tensor equal to ``input_ids`` with padding set to -100.
    """
    batch = tokenizer.pad(features, return_tensors="pt")
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100  # ignore padding only, keep the real EOS
    batch["labels"] = labels
    return batch

In [None]:
# Install nlg-metricverse, which supplies the CIDEr metric loaded in the next cell.
%pip install nlg-metricverse

In [None]:
from nlgmetricverse import NLGMetricverse, load_metric

# Overlap metrics loaded through the `evaluate` library, in the order
# BLEU, METEOR, ROUGE, NIST.
bleu, meteor, rouge, nist = [
    evaluate.load(metric_name)
    for metric_name in ('bleu', 'meteor', 'rouge', 'nist_mt')
]

# CIDEr comes from nlg-metricverse rather than `evaluate`.
cider = NLGMetricverse(metrics=load_metric("cider"))

In [None]:
from torch.utils.data import DataLoader

def custom_collate_fn(batch):
    """Collate evaluation examples while carrying their text fields along.

    Only ``input_ids`` and ``attention_mask`` are handed to ``data_collator``;
    the ``human_reference`` and ``meaning_representation`` entries are
    re-attached to the collated batch afterwards as plain Python lists.

    Args:
        batch: List of examples, each providing ``input_ids``,
            ``attention_mask``, ``human_reference`` and ``meaning_representation``.

    Returns:
        The collated batch with the two text fields added.
    """
    tensor_features = [
        {key: example[key] for key in ('input_ids', 'attention_mask')}
        for example in batch
    ]
    collated = data_collator(tensor_features)
    collated['human_reference'] = [example['human_reference'] for example in batch]
    collated['meaning_representation'] = [example['meaning_representation'] for example in batch]
    return collated

def generate_predictions(test_dataloader, max_new_tokens=64):
    """Generate one prediction per evaluation prompt with beam search.

    Args:
        test_dataloader: Iterable of batches providing ``input_ids``,
            ``attention_mask`` and ``human_reference``.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Without an explicit budget, generation falls back to
            the library-default total length, which counts the prompt as well
            and is shorter than a typical prompt here.

    Returns:
        ``(predictions, references)``: the generated continuations (prompt
        removed, surrounding whitespace stripped) and the matching
        human-reference entries, in dataloader order.
    """
    # Imported locally so the function does not depend on a later cell having run.
    from tqdm.auto import tqdm

    model.eval()
    predictions = []
    references = []
    progress_bar = tqdm(test_dataloader, desc="Generating predictions")
    for batch in progress_bar:
        # Use only the meaning representation as input
        input_ids = batch['input_ids'].to(model.device)
        attention_mask = batch['attention_mask'].to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                num_beams=10,
                length_penalty=0.9,
                no_repeat_ngram_size=4,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        # Batches are left-padded, so every prompt occupies the same leading
        # columns and the continuation starts at the same index in each row.
        # Slicing tokens is more reliable than string-matching the decoded
        # prompt, which silently keeps the prompt if decoding alters it.
        prompt_length = input_ids.shape[1]
        for output, human_reference in zip(output_ids, batch['human_reference']):
            prediction = tokenizer.decode(output[prompt_length:], skip_special_tokens=True).strip()
            predictions.append(prediction)
            references.append(human_reference)
    return predictions, references

# Evaluation batches of 8 prompts, collated together with their reference texts.
test_dataloader = DataLoader(test_tokenized, collate_fn=custom_collate_fn, batch_size=8)

In [None]:
from transformers import get_linear_schedule_with_warmup
from tqdm.auto import tqdm

# Prepare DataLoaders
train_dataloader = DataLoader(train_tokenized, shuffle=True, batch_size=8, collate_fn=data_collator)
validation_dataloader = DataLoader(validation_tokenized, batch_size=8, collate_fn=data_collator)

# Move the model to its device before the optimizer is built.
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Initialize optimizer and scheduler.
# Only parameters that require gradients are handed to the optimizer; after
# get_peft_model the base weights are expected to be frozen, so this should be
# just the LoRA adapter weights.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=2e-4, weight_decay=0.01)
num_epochs = 5
num_training_steps = num_epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=num_training_steps)

# Track best validation loss
best_eval_loss = float('inf')

# Training loop
for epoch in range(num_epochs):
    model.train()
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch in progress_bar:
        inputs = batch.to(model.device)
        outputs = model(**inputs)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        progress_bar.set_postfix(loss=loss.item())

    # Generation metrics on the grouped (multi-reference) test set
    predictions, references = generate_predictions(test_dataloader)

    # Compute metrics
    bleu_score = bleu.compute(predictions=predictions, references=references)
    meteor_score = meteor.compute(predictions=predictions, references=references)
    rouge_score = rouge.compute(predictions=predictions, references=references)
    nist_score = nist.compute(predictions=predictions, references=references)
    cider_score = cider(predictions=predictions, references=references)

    test_metrics = {
        'bleu': bleu_score['bleu'],
        'meteor': meteor_score['meteor'],
        'rouge_l': rouge_score['rougeL'],
        'nist': nist_score['nist_mt'],
        'cider': cider_score['cider']['score'],
    }
    print(test_metrics)

    # Validation loss. This pass was previously disabled as a string literal and
    # accumulated into a variable that was never initialised; it is restored so
    # that validation_dataloader and best_eval_loss are actually used.
    model.eval()
    eval_loss = 0.0
    with torch.no_grad():
        for batch in validation_dataloader:
            inputs = batch.to(model.device)
            eval_loss += model(**inputs).loss.item()
    # max(..., 1) guards against an empty validation set.
    avg_eval_loss = eval_loss / max(len(validation_dataloader), 1)
    print(f"Validation Loss: {avg_eval_loss}")

    # Remember the best validation loss seen so far. No checkpoint is written
    # here, so the run has no file-system side effects.
    if avg_eval_loss < best_eval_loss:
        best_eval_loss = avg_eval_loss

#model.load_state_dict(torch.load("./results/best_model.bin"))