# In [None]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW
from torch.utils.data import Dataset, DataLoader


# Ask PyTorch to release cached GPU memory it is no longer using
# (presumably to start from a clean slate when this cell is re-run).
torch.cuda.empty_cache()


def to_device(batch, device):
    """Return a new dict holding every value of ``batch`` moved to ``device``.

    Args:
        batch: Mapping of names to objects exposing ``.to(device)``.
        device: Target device, forwarded unchanged to each ``.to`` call.

    Returns:
        A fresh dict with the same keys; ``batch`` itself is not modified.
    """
    moved = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device)
    return moved

def _mask_abbreviation(context, abbreviation):
    """Replace every occurrence of ``abbreviation`` in ``context`` with "<mask>"."""
    # Missing values (NaN) and empty abbreviations are passed through untouched:
    # str.replace("", ...) would otherwise insert the mask between every character.
    if not isinstance(context, str) or not isinstance(abbreviation, str):
        return context
    if not abbreviation:
        return context
    return context.replace(abbreviation, "<mask>")


def preprocess_data(df):
    """Add the model input and target columns to ``df``.

    ``input_text`` is each row's ``context`` with that row's own
    ``abbreviation`` literally (non-regex) replaced by "<mask>";
    ``target_text`` is a copy of ``expanded_abbreviation``.

    The replacement is done row by row because ``Series.str.replace`` only
    accepts a single pattern for the whole column, not one pattern per row.

    NOTE(review): "<mask>" is presumably not a dedicated mT5 special token
    (T5-style sentinels look like "<extra_id_0>") -- confirm the placeholder.

    Args:
        df: DataFrame with ``context``, ``abbreviation`` and
            ``expanded_abbreviation`` columns.

    Returns:
        The same DataFrame object, modified in place.
    """
    # Assigning a list is positional, so this works for any index.
    df['input_text'] = [
        _mask_abbreviation(context, abbreviation)
        for context, abbreviation in zip(df['context'], df['abbreviation'])
    ]
    df['target_text'] = df['expanded_abbreviation']
    return df

# Locations of the training and test CSV files
TRAIN_CSV_PATH = 'C:/Users/kaczm/OneDrive/Pulpit/Abbr_env_v2/training_df.csv'
TEST_CSV_PATH = 'C:/Users/kaczm/OneDrive/Pulpit/Abbr_env_v2/test_df.csv'

# Read both splits first, then derive the input/target text columns for each
raw_training_df = pd.read_csv(TRAIN_CSV_PATH)
raw_test_df = pd.read_csv(TEST_CSV_PATH)

training_df = preprocess_data(raw_training_df)
test_df = preprocess_data(raw_test_df)


# Define a custom dataset
class AbbreviationDataset(Dataset):
    """Dataset of tokenized (masked context, expanded abbreviation) pairs.

    Args:
        tokenizer: Tokenizer exposing ``encode_plus`` and, optionally,
            ``pad_token_id``.
        data: DataFrame (or any mapping of column name to sequence) with
            ``input_text`` and ``target_text`` columns.
        max_length: Length every input and target is padded/truncated to.
    """

    def __init__(self, tokenizer, data, max_length=512):
        self.tokenizer = tokenizer
        # Store plain lists so __getitem__ indexes by position. Indexing a
        # pandas Series with an int is label-based and breaks as soon as the
        # frame's index is not 0..n-1 (e.g. after filtering or splitting).
        self.inputs = list(data['input_text'])
        self.targets = list(data['target_text'])
        self.max_length = max_length

    def __len__(self):
        return len(self.inputs)

    def _encode(self, text):
        """Tokenize ``text``, padded/truncated to ``max_length``, as tensors."""
        return self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        """Return the tokenized input fields plus ``labels`` for row ``idx``.

        Tensors keep the leading dimension the tokenizer produces for a
        single text; consumers are expected to squeeze it.
        """
        input_tokens = self._encode(self.inputs[idx])
        target_tokens = self._encode(self.targets[idx])

        # Padding positions in the labels are set to -100, the value the
        # seq2seq loss ignores; otherwise the loss would mostly measure how
        # well the model predicts padding.
        labels = target_tokens['input_ids']
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, -100)

        # Combine input and target tokens into one dictionary
        return {**input_tokens, 'labels': labels}

# Initialize tokenizer and model from the same checkpoint so they stay in sync
MODEL_NAME = "google/mt5-small"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Move model to GPU if available, otherwise stay on the CPU.
# (The device used to be hard-coded to "cpu", so the GPU was never used even
# though the message below claimed it was.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print('CUDA being used')
model.to(device)

# Wrap each split in a dataset that tokenizes rows on demand
train_dataset = AbbreviationDataset(tokenizer=tokenizer, data=training_df)
test_dataset = AbbreviationDataset(tokenizer=tokenizer, data=test_df)

# Batch the datasets: training batches are reshuffled every epoch,
# test batches keep their original order
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=False)

# Training loop
optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 5


def _batch_loss(batch):
    """Run one forward pass on ``batch`` and return the model's loss tensor.

    Shared by the training and evaluation passes so both feed the model
    exactly the same way.
    """
    # Move batch to the appropriate device
    batch = to_device(batch, device)

    # Each item carries an extra singleton dimension at position 1 once
    # batched; drop it before calling the model.
    input_ids = batch['input_ids'].squeeze(1)
    labels = batch['labels'].squeeze(1)

    # Pass the attention mask so padded input positions are not attended to.
    attention_mask = batch.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.squeeze(1)

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    return outputs.loss


for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        # Forward pass
        loss = _batch_loss(batch)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Report the mean loss over the epoch rather than only the last batch's
    # loss; max(..., 1) keeps an empty loader from dividing by zero.
    avg_train_loss = train_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_train_loss}")

    # Evaluate on test data
    model.eval()
    with torch.no_grad():
        total_loss = 0
        for batch in test_loader:
            total_loss += _batch_loss(batch).item()

        avg_loss = total_loss / max(len(test_loader), 1)
        print(f"Test Loss after Epoch {epoch+1}: {avg_loss}")