# Fine-Tuning T5-small with PEFT (LoRA)

# In [None]:
import os
import pandas as pd
import torch
import gc
from datetime import datetime
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import DataCollatorForSeq2Seq
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from peft import get_peft_model, LoraConfig, TaskType
import ast

# Let cuDNN autotune kernels and cache the fastest choice per input shape.
# This only pays off when shapes repeat; CVEDataset pads every sample to a
# fixed max_length, so batch shapes mostly stay constant.
# NOTE(review): cuDNN autotuning mainly affects convolutions and T5 has none,
# so this flag likely has little effect here — confirm before relying on it.
torch.backends.cudnn.benchmark = True

def clear_memory():
    """Run Python garbage collection, then release cached CUDA memory if a GPU exists.

    Returns:
        None.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    # Hand unused blocks held by PyTorch's CUDA caching allocator back to the driver.
    torch.cuda.empty_cache()

class CVEDataset(Dataset):
    """Map-style dataset of (query, answer) pairs tokenized for T5 seq2seq training.

    Each record must provide an ``'input'`` string. The target text is taken
    from ``record['output']['answer']`` when ``record['output']`` is a dict;
    otherwise ``record['output']`` itself is treated as the answer. List/tuple
    answers are joined with ``"; "``, and a missing or ``None`` answer becomes
    the empty string.

    Args:
        data: Indexable sequence of record dicts
            (e.g. ``DataFrame.to_dict(orient='records')``).
        tokenizer: Hugging Face-style tokenizer: callable returning a mapping
            with ``input_ids`` / ``attention_mask`` tensors of shape
            ``(1, max_length)`` and exposing ``pad_token_id``.
        max_input_length: Token length the query is truncated/padded to.
        max_output_length: Token length the answer is truncated/padded to.
    """

    def __init__(self, data, tokenizer, max_input_length=128, max_output_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _answer_text(record):
        """Flatten a record's answer into a single target string."""
        output = record.get('output', {})
        # Non-dict outputs (e.g. a bare string/list from literal_eval) used to
        # raise AttributeError on ``.get``; treat them as the answer itself.
        answer = output.get('answer', '') if isinstance(output, dict) else output
        if answer is None:
            return ''
        if isinstance(answer, (list, tuple)):
            # Multiple answers are joined into one target sequence.
            return "; ".join(str(item) for item in answer)
        return str(answer)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` (pads -> -100) for ``idx``."""
        record = self.data[idx]
        query = record['input']
        answer = self._answer_text(record)

        input_encodings = self.tokenizer(query, truncation=True, padding="max_length",
                                         max_length=self.max_input_length, return_tensors="pt")
        output_encodings = self.tokenizer(answer, truncation=True, padding="max_length",
                                          max_length=self.max_output_length, return_tensors="pt")

        # The tokenizer returns a batch of one; drop only that leading axis.
        # A bare squeeze() would also collapse a length-1 sequence to a 0-d tensor.
        input_ids = input_encodings['input_ids'].squeeze(0)
        attention_mask = input_encodings['attention_mask'].squeeze(0)
        labels = output_encodings['input_ids'].squeeze(0)
        # Hugging Face seq2seq models ignore label positions set to -100 in the
        # loss, so target padding does not contribute.
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

def train_with_peft(model, train_loader, val_loader, optimizer, scaler, device, accumulation_steps=4, epochs=3):
    """Fine-tune ``model`` with mixed precision and gradient accumulation.

    Gradients from up to ``accumulation_steps`` consecutive batches are summed
    (each loss divided by the window size) before a single optimizer step. A
    trailing partial window at the end of an epoch is also stepped, so those
    batches are not silently discarded. The model is evaluated with
    ``validate_model`` after every epoch.

    Args:
        model: Module whose forward accepts ``input_ids``, ``attention_mask``
            and ``labels`` and returns an object with a ``.loss`` attribute.
        train_loader: DataLoader yielding dicts with those three tensors.
        val_loader: DataLoader passed to ``validate_model`` after each epoch.
        optimizer: Optimizer over the model's trainable parameters.
        scaler: ``GradScaler`` used for loss scaling.
        device: Device (or device string) the batches are moved to.
        accumulation_steps: Number of batches per optimizer step (>= 1).
        epochs: Number of passes over ``train_loader``.

    Returns:
        List with one dict per epoch holding ``epoch``, ``avg_loss`` (mean
        unscaled per-batch training loss) and ``val_loss``.

    Raises:
        ValueError: If ``accumulation_steps`` < 1 or ``train_loader`` is empty.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    # torch.cuda.amp.autocast only does anything on CUDA; enabling it
    # explicitly avoids the "CUDA is not available" warning on CPU runs.
    use_amp = str(device).startswith("cuda")
    epoch_results = []

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        num_batches = len(train_loader)
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # The last window may hold fewer than accumulation_steps batches;
        # divide by its real size so its gradient is still a per-batch mean.
        tail = num_batches % accumulation_steps or accumulation_steps
        tail_start = num_batches - tail

        # validate_model() puts the model in eval mode, so (re-)enable dropout,
        # including LoRA dropout, at the start of every epoch.
        model.train()
        # Zero once per accumulation window, never per batch: zeroing on every
        # batch would discard all but the last micro-batch's gradients.
        optimizer.zero_grad()
        epoch_loss = 0.0

        for step, batch in enumerate(tqdm(train_loader, desc="Training", ncols=100)):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            with autocast(enabled=use_amp):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs.loss

            window_size = tail if step >= tail_start else accumulation_steps
            scaler.scale(loss / window_size).backward()
            # Track the unscaled loss so it is comparable with the validation loss.
            epoch_loss += loss.item()

            window_full = (step + 1) % accumulation_steps == 0
            last_batch = step + 1 == num_batches
            if window_full or last_batch:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        avg_epoch_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1} Average Loss: {avg_epoch_loss:.4f}")

        # Validate the model
        validate_loss = validate_model(model, val_loader, device)
        print(f"Epoch {epoch + 1} Validation Loss: {validate_loss:.4f}")
        epoch_results.append({"epoch": epoch + 1, "avg_loss": avg_epoch_loss, "val_loss": validate_loss})

        clear_memory()

    return epoch_results

def validate_model(model, val_loader, device):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    Runs in eval mode without gradient tracking. Afterwards it restores the
    model's previous train/eval mode, so a caller in the middle of training
    keeps dropout enabled.

    Args:
        model: Module whose forward accepts ``input_ids``, ``attention_mask``
            and ``labels`` and returns an object with a ``.loss`` attribute.
        val_loader: Iterable of dicts with those three tensors.
        device: Device (or device string) the batches are moved to.

    Returns:
        Mean of ``outputs.loss`` over all batches (each batch weighted equally),
        or ``float('nan')`` if ``val_loader`` yields no batches.
    """
    was_training = model.training
    # autocast from torch.cuda.amp only has an effect on CUDA.
    use_amp = str(device).startswith("cuda")
    total_loss = 0.0
    num_batches = 0

    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", ncols=100):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                with autocast(enabled=use_amp):
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

                total_loss += outputs.loss.item()
                num_batches += 1
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    # An empty loader yields NaN instead of crashing mid-training.
    return total_loss / num_batches if num_batches else float('nan')

def main():
    """Fine-tune t5-small with LoRA on the generated prompt pairs and save the adapter.

    Reads ``./data/generated_input_output_pairs.csv`` (columns ``input`` and
    ``output``; ``output`` holds a Python-literal value). It trains for three
    epochs on a seeded 80/20 train/validation split. It then writes per-epoch
    losses to ``./validation_results.csv`` and saves the PEFT model and
    tokenizer to ``./fine_tuned_t5_peft``.

    Raises:
        ValueError: If the CSV yields fewer than two examples (no training split).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'

    # Load existing prompts. literal_eval (unlike eval) parses the
    # stringified values without executing arbitrary code.
    prompts_data = pd.read_csv('./data/generated_input_output_pairs.csv')
    prompts_data['output'] = prompts_data['output'].apply(ast.literal_eval)
    dynamic_prompts = prompts_data.to_dict(orient='records')

    # Prepare tokenizer and dataset
    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    dataset = CVEDataset(dynamic_prompts, tokenizer)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0:
        raise ValueError(f"Need at least 2 examples to train, got {len(dataset)}")
    # Fixed seed so the split (and the reported validation loss) is reproducible.
    split_seed = 42
    split_generator = torch.Generator().manual_seed(split_seed)
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

    # DataLoader: pinned memory only helps host->GPU copies, and more workers
    # than CPU cores just oversubscribes the machine.
    data_collator = DataCollatorForSeq2Seq(tokenizer, padding=True)
    num_workers = min(8, os.cpu_count() or 1)
    train_loader = DataLoader(train_dataset, batch_size=8, collate_fn=data_collator, shuffle=True,
                              pin_memory=use_cuda, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=data_collator,
                            pin_memory=use_cuda, num_workers=num_workers)

    # Model and PEFT setup
    model = T5ForConditionalGeneration.from_pretrained('t5-small')
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1
    )
    model = get_peft_model(model, peft_config)
    model = model.to(device)

    # Only the LoRA parameters require gradients; skip frozen base weights.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=5e-5)
    # Loss scaling is a no-op without CUDA; disable it explicitly to avoid warnings.
    scaler = GradScaler(enabled=use_cuda)

    # Training and validation with PEFT
    epoch_results = train_with_peft(model, train_loader, val_loader, optimizer, scaler, device, accumulation_steps=4, epochs=3)
    print("Training completed.")

    # Save validation results to CSV
    pd.DataFrame(epoch_results).to_csv('./validation_results.csv', index=False)
    print("Validation results saved to './validation_results.csv'")

    # Save PEFT model and tokenizer
    model.save_pretrained('./fine_tuned_t5_peft')
    tokenizer.save_pretrained('./fine_tuned_t5_peft')
    print("PEFT model and tokenizer saved to './fine_tuned_t5_peft'")

# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
