In [None]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, AdamW
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# dataset = load_dataset('cos_e', 'v1.11')

In [None]:
from datasets import load_from_disk

# Load the DatasetDict previously saved to ./formatted_dataset.
loaded_dataset = load_from_disk('./formatted_dataset')

# Spot-check: print the second record's `generated_output` for each split.
for split_name in ('train', 'validation'):
    print(loaded_dataset[split_name][1]['generated_output'])

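Remove the `\n$answer$` marker from `generated_output`: the next cell deletes every literal `$answer$ =` substring and trims leading/trailing whitespace.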

In [None]:
def remove_answer_prefix(example, marker="$answer$ ="):
    """Strip ``marker`` from ``example['generated_output']``.

    Works as a batched ``Dataset.map`` function (``generated_output`` is a list
    of strings) and as a per-example one (``generated_output`` is a single
    string). Every occurrence of ``marker`` is removed, then leading/trailing
    whitespace is trimmed.

    Args:
        example: A single example or a batch. It may contain a
            ``'generated_output'`` key, which is updated in place.
        marker: Literal substring to delete. Defaults to ``"$answer$ ="``.

    Returns:
        The same ``example`` mapping, as ``Dataset.map`` expects.
    """
    if 'generated_output' not in example:
        return example

    outputs = example['generated_output']
    if isinstance(outputs, str):
        # Per-example map: iterating a str would yield characters, so clean it directly.
        example['generated_output'] = outputs.replace(marker, "").strip()
    else:
        example['generated_output'] = [output.replace(marker, "").strip() for output in outputs]
    return example

In [None]:
from datasets import load_from_disk

# Clean every split. With batched=True the function receives whole batches
# (a dict of column lists) rather than single rows.
transformed_dataset = loaded_dataset.map(function=remove_answer_prefix, batched=True)

# Sanity check: the first cleaned `generated_output` of each split.
for split_name in ('train', 'validation'):
    print(transformed_dataset[split_name][0]['generated_output'])

In [None]:
# Wrapping the value in a list makes print() show the element's repr, so
# escape sequences (e.g. newlines) are visible rather than rendered.
first_val_output = transformed_dataset['validation'][0]['generated_output']
print([first_val_output])

In [None]:
# Inspect the second record of each split to confirm the marker is gone.
for split_name in ('train', 'validation'):
    print(transformed_dataset[split_name][1]['generated_output'])

In [None]:
# Show the cleaned DatasetDict's repr as this cell's output (bare last expression).
transformed_dataset

In [None]:
# Initialize the model and tokenizer.
# "t5-base" is the default; switch to "t5-small" to save memory, or to
# "t5-large" if resources allow.
model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Seed torch's global RNG (CPU and all CUDA devices). DataLoader shuffling and
# dropout then start from the same state on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Move the model to the GPU if one is available. nn.Module.to returns the module
# itself; assigning it keeps the very large module repr out of the cell output.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

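**Note:** the preprocessing must be changed. The original version (commented out in the next cell) is replaced by the cell after it, which appends each `generated_output` to the model input.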

In [None]:
# # Preprocess the dataset
# def preprocess_function(examples):
#     inputs = ["question: " + q + " answer: " + " ".join(choices) for q, choices in zip(examples['question'], examples['choices'])]
#     targets = ["answer: " + answer + " explanation: " + explanation for answer, explanation in zip(examples['answer'], examples['abstractive_explanation'])]
#     model_inputs = tokenizer(inputs, max_length=256, truncation=True, padding='max_length')
#     labels = tokenizer(targets, max_length=256, truncation=True, padding='max_length')
#     model_inputs['labels'] = labels['input_ids']
#     return model_inputs

In [None]:
# Preprocess the dataset
def preprocess_function(examples):
    """Tokenize a batch into T5 encoder inputs and decoder labels.

    Each source text is
    "question: <question> answer: <choices joined by spaces> <generated_output>".
    Each target text is
    "answer: <answer> explanation: <abstractive_explanation>".
    Both sides are truncated/padded to 256 tokens. The target token ids are
    stored under ``'labels'``; padding positions keep the tokenizer's pad id
    (they are not replaced with -100).

    Raises:
        ValueError: if the batch has no ``'generated_output'`` column.
    """
    if 'generated_output' not in examples:
        raise ValueError("The 'generated_output' field is missing in the dataset examples.")

    source_texts = []
    for question, choice_list, generated in zip(examples['question'], examples['choices'], examples['generated_output']):
        source_texts.append("question: " + question + " answer: " + " ".join(choice_list) + " " + generated)

    target_texts = []
    for gold_answer, rationale in zip(examples['answer'], examples['abstractive_explanation']):
        target_texts.append("answer: " + gold_answer + " explanation: " + rationale)

    encoded = tokenizer(source_texts, max_length=256, truncation=True, padding='max_length')
    target_encoding = tokenizer(target_texts, max_length=256, truncation=True, padding='max_length')
    encoded['labels'] = target_encoding['input_ids']
    return encoded

In [None]:
# Tokenize every split in batches. The original text columns are kept and the
# tokenizer outputs plus 'labels' are added alongside them.
encoded_dataset = transformed_dataset.map(function=preprocess_function, batched=True)

In [None]:
# Raw token ids for the first example of each split: inputs first, then labels.
for split_name in ('train', 'validation'):
    first_row = encoded_dataset[split_name][0]
    print(first_row['input_ids'])
    print(first_row['labels'])

In [None]:
# Function to decode the encoded inputs and labels for validation
def decode_example(example):
    """Decode one encoded row back to text.

    Args:
        example: A row with ``'input_ids'`` and ``'labels'`` token-id sequences.

    Returns:
        ``(decoded_input, decoded_label)`` strings with special tokens (such as
        padding) skipped.
    """
    decoded_input = tokenizer.decode(example['input_ids'], skip_special_tokens=True)
    decoded_label = tokenizer.decode(example['labels'], skip_special_tokens=True)
    return decoded_input, decoded_label


# Decode and print the first few examples of each split. Each line is tagged
# with its split name so the two splits can be told apart. min() avoids an
# IndexError on splits with fewer than NUM_PREVIEW rows.
NUM_PREVIEW = 3
for split_name in ('train', 'validation'):
    split_rows = encoded_dataset[split_name]
    for i in range(min(NUM_PREVIEW, len(split_rows))):
        decoded_input, decoded_label = decode_example(split_rows[i])
        print(f"[{split_name}] Example {i + 1} - Decoded Input: {decoded_input}")
        print(f"[{split_name}] Example {i + 1} - Decoded Label: {decoded_label}\n")

In [None]:
# Show the encoded validation split's repr as this cell's output (bare last expression).
encoded_dataset['validation']

In [None]:
from torch.utils.data import DataLoader

# Build a torch-formatted view of the tokenized data. with_format returns a new
# DatasetDict instead of mutating encoded_dataset in place (as set_format does),
# so re-running earlier cells that inspect encoded_dataset behaves the same.
BATCH_SIZE = 16
torch_dataset = encoded_dataset.with_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Only the training loader is shuffled; validation order stays fixed.
train_loader = DataLoader(torch_dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(torch_dataset['validation'], batch_size=BATCH_SIZE)

In [None]:
# Initialize optimizer.
# transformers.AdamW is deprecated (and dropped from newer releases) in favour
# of torch.optim.AdamW. eps=1e-6 and weight_decay=0.0 reproduce the defaults of
# the transformers implementation; torch's own defaults are eps=1e-8 and
# weight_decay=1e-2.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

In [None]:
# Training loop
def train_loop(model, loader, optimizer, accumulation_steps=2):
    """Run one training epoch with gradient accumulation.

    Args:
        model: Seq2seq model whose forward pass returns ``.loss`` when given labels.
        loader: Iterable of batches with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        optimizer: Optimizer over ``model``'s parameters.
        accumulation_steps: Number of mini-batches whose gradients are summed
            before each optimizer step.

    Returns:
        ``{'train_loss': mean per-batch loss}``, reported without the
        accumulation normalization.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    batch_losses = []
    optimizer.zero_grad()
    # Labels are padded with the pad id. Replace it with -100, the ignore index
    # of the model's cross-entropy, so padding does not dominate the loss.
    pad_token_id = model.config.pad_token_id

    for i, batch in enumerate(tqdm(loader, desc='Training:')):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        if pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, -100)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss / accumulation_steps  # normalize so accumulated grads match one larger batch

        batch_losses.append(loss.item() * accumulation_steps)  # log the un-normalized loss
        loss.backward()

        if (i + 1) % accumulation_steps == 0:  # update weights every accumulation_steps mini-batches
            optimizer.step()
            optimizer.zero_grad()

    if not batch_losses:
        raise ValueError("train_loop received an empty loader.")

    # Apply gradients left over from a trailing partial accumulation group.
    if len(batch_losses) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    return {'train_loss': sum(batch_losses) / len(batch_losses)}

In [None]:
def convert_to_sentences(list_of_lists):
    """Join each inner sequence of words with single spaces.

    Example: ``[['a', 'b'], ['c']] -> ['a b', 'c']``.
    """
    return list(map(' '.join, list_of_lists))

In [None]:
import bert_score
from datasets import load_metric
def _split_answer_explanation(text):
    """Parse "answer: <a> explanation: <e>" into ``(a, e)``; missing parts become ''."""
    answer, explanation = '', ''
    if 'answer: ' in text:
        answer = text.split('answer: ', 1)[1].split('explanation:', 1)[0].strip()
    if 'explanation: ' in text:
        explanation = text.split('explanation: ', 1)[1].strip()
    return answer, explanation


def validate_loop(model, loader):
    """Evaluate ``model`` on ``loader``: loss, answer accuracy and explanation BERTScore.

    Each batch gets a teacher-forced forward pass (for the loss) and a
    ``model.generate`` call. The generated text is parsed as
    "answer: <a> explanation: <e>" and compared with the decoded labels.

    Args:
        model: Seq2seq model; its forward returns ``.loss`` given labels and it
            supports ``generate``.
        loader: Iterable of batches with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.

    Returns:
        dict with
          - ``'val_loss'``: mean per-batch loss (pad positions in labels ignored);
          - ``'accuracy'``: exact-match rate of the full answer string, over
            examples whose label has an answer (a prediction without one counts
            as wrong);
          - ``'bertscore_exp'``: mean rescaled BERTScore F1 between predicted and
            reference explanations across the whole loader (0.0 if none).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    batch_losses = []
    answer_preds, answer_refs = [], []
    explanation_preds, explanation_refs = [], []
    pad_token_id = model.config.pad_token_id

    with torch.no_grad():
        for batch in tqdm(loader, desc='Validation:'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Mask label padding for the loss only; decoding below uses the raw ids.
            if pad_token_id is None:
                loss_labels = labels
            else:
                loss_labels = labels.masked_fill(labels == pad_token_id, -100)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=loss_labels)
            batch_losses.append(outputs.loss.item())

            # Generate predictions
            predictions = model.generate(input_ids, attention_mask=attention_mask, max_length=512)
            decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            # Pair predictions and labels per example. One malformed generation
            # then cannot misalign the lists or drop the rest of the batch.
            for pred_text, label_text in zip(decoded_preds, decoded_labels):
                pred_answer, pred_explanation = _split_answer_explanation(pred_text)
                ref_answer, ref_explanation = _split_answer_explanation(label_text)
                if ref_answer:
                    answer_preds.append(pred_answer)
                    answer_refs.append(ref_answer)
                if ref_explanation:
                    explanation_preds.append(pred_explanation)
                    explanation_refs.append(ref_explanation)

    if not batch_losses:
        raise ValueError("validate_loop received an empty loader.")

    # Calculate accuracy on full answer strings (not just the first word).
    correct = sum(p == r for p, r in zip(answer_preds, answer_refs))
    accuracy = correct / len(answer_refs) if answer_refs else 0.0

    # Score explanations accumulated over every batch, not just the last one.
    # Candidate and reference lists are index-aligned by construction, as
    # bert_score requires.
    if explanation_refs:
        _, _, f1_exp = bert_score.score(explanation_preds, explanation_refs, lang="en", rescale_with_baseline=True)
        bertscore_exp_avg = f1_exp.mean().item()
    else:
        bertscore_exp_avg = 0.0

    loss_value = sum(batch_losses) / len(batch_losses)
    return {'val_loss': loss_value, 'accuracy': accuracy, 'bertscore_exp': bertscore_exp_avg}


# # Training and validation
# num_epochs = 1
# for epoch in range(num_epochs):
#     train_metrics = train_loop(model, train_loader, optimizer)
#     val_metrics = validate_loop(model, val_loader)

#     print(f"Epoch {epoch + 1}/{num_epochs}")
#     print(f"Train Loss: {train_metrics['train_loss']:.4f}")
#     print(f"Validation Loss: {val_metrics['val_loss']:.4f}")
#     print(f"Validation Accuracy: {val_metrics['accuracy']:.4f}")
    
# Train for num_epochs epochs, evaluating on the validation loader after each one.
num_epochs = 3
for epoch in range(num_epochs):
    train_metrics = train_loop(model, train_loader, optimizer)
    val_metrics = validate_loop(model, val_loader)

    print(f"Epoch {epoch + 1}/{num_epochs}")
    # (label, metrics dict, key) triples, printed in the same order as before.
    for label, metrics, key in (
        ("Train Loss", train_metrics, 'train_loss'),
        ("Validation Loss", val_metrics, 'val_loss'),
        ("Validation Accuracy", val_metrics, 'accuracy'),
        ("Validation BERTScore (Explanations)", val_metrics, 'bertscore_exp'),
    ):
        print(f"{label}: {metrics[key]:.4f}")