# Extended Subproject - Fine Tuning

In [None]:
import pandas as pd
import torch
from torch.optim import AdamW  # Import PyTorch's AdamW

from torch.utils.data import DataLoader, Dataset
#from transformers import BertTokenizer, BertForSequenceClassification
from transformers import LongformerTokenizer, LongformerForSequenceClassification, AdamW

from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt

In [None]:
# Read the three ';'-separated CSV splits (train, validation, test), in that order
train_data, valid_data, test_data = (
    pd.read_csv(csv_path, sep=';')
    for csv_path in (
        "cnndm/train_data_ext.csv",
        "cnndm/valid_data_ext.csv",
        "cnndm/test_data_ext.csv",
    )
)

In [None]:
# Number of missing values in each column of the training split
print(train_data.isna().sum())
# Rows of the training split that have at least one missing field
missing_rows = train_data.loc[train_data.isna().any(axis="columns")]
# Last expression of the cell: preview of the first training rows
train_data.head()

## Model and Tokenizer Loading

In [None]:
# Tokenizer belonging to the pretrained Longformer checkpoint
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

# Pretrained Longformer with a classification head over 2 labels (binary).
# NOTE(review): TokenLevelDataset produces one label per token, while this
# head appears to predict one label per sequence - confirm that the training
# loop reconciles the two, or whether a token-classification head was intended.
model = LongformerForSequenceClassification.from_pretrained("allenai/longformer-base-4096", num_labels=2)

# # Check model details
# print(model.config)

## Prepare Data for Fine-Tuning

In [None]:
class TokenLevelDataset(Dataset):
    """Fixed-length, token-labelled (document, summary) examples.

    Each row of ``data`` (columns ``article``, ``highlights``, ``labels``) is
    encoded as ``[CLS] document [SEP] summary-chunk [SEP]`` and padded to
    ``max_length``. Only summary positions carry a label; every other position
    (special tokens, document, padding) is set to ``IGNORE_INDEX``.

    When document plus summary do not fit, the summary is split into 2, 3, ...
    up to ``max_chunks`` consecutive chunks, each paired with the full
    document. Rows that cannot be encoded are skipped and counted:

    * ``skipped_count``    - the document alone does not fit,
    * ``skipped_bc_chunk`` - no split into at most ``max_chunks`` chunks fits,
    * ``skipped_invalid``  - article/highlights are not strings (e.g. NaN) or
      the labels cannot be parsed.

    Args:
        data: DataFrame with the columns ``article``, ``highlights``, ``labels``.
        tokenizer: object providing ``tokenize``, ``convert_tokens_to_ids``,
            ``cls_token_id``, ``sep_token_id`` and ``pad_token_id``.
        max_length: length of every encoded example.
        max_chunks: largest number of summary chunks to try (default 3, i.e.
            whole summary, then halves, then thirds).

    Raises:
        ValueError: if ``max_chunks`` is smaller than 1.

    NOTE(review): the ``labels`` column is assumed to hold one label per
    summary token of *this* tokenizer - confirm against the data preparation.
    Labels stored as text (as happens after a CSV round trip) are parsed with
    ``ast.literal_eval``.
    """

    # Label value for positions that must not contribute to the loss.
    IGNORE_INDEX = -100
    # [CLS] + [SEP] after the document + [SEP] after the summary chunk.
    NUM_SPECIAL_TOKENS = 3

    def __init__(self, data, tokenizer, max_length=2048, max_chunks=3):
        if max_chunks < 1:
            raise ValueError("max_chunks must be at least 1")
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_chunks = max_chunks
        self.examples = []
        self.skipped_count = 0  # Rows skipped because the document is too long
        self.skipped_bc_chunk = 0  # Rows skipped because no chunking fits
        self.skipped_invalid = 0  # Rows skipped because text/labels are unusable

        self._create_examples()

    @staticmethod
    def _parse_labels(raw_labels):
        """Return ``raw_labels`` as a list of ints, parsing a string if needed."""
        if isinstance(raw_labels, str):
            import ast  # local import: only needed for text-encoded labels

            raw_labels = ast.literal_eval(raw_labels)
        return [int(label) for label in raw_labels]

    @staticmethod
    def _chunk_bounds(length, parts):
        """Return ``parts`` consecutive (start, end) pairs covering ``length``.

        All chunks have ``length // parts`` items except the last one, which
        also takes the remainder (same boundaries as slicing at ``mid`` or at
        ``third`` / ``2 * third``).
        """
        size = length // parts
        starts = [index * size for index in range(parts)]
        ends = starts[1:] + [length]
        return list(zip(starts, ends))

    def _create_examples(self):
        """Encode every usable row of ``self.data`` into ``self.examples``."""
        rows = zip(self.data['article'], self.data['highlights'], self.data['labels'])
        for doc, summ, raw_labels in rows:
            # Missing text shows up as a non-string value (NaN/None); the
            # tokenizer cannot process those, so skip and count the row.
            if not (isinstance(doc, str) and isinstance(summ, str)):
                self.skipped_invalid += 1
                continue

            doc_tokens = self.tokenizer.tokenize(doc)

            # Room left for summary tokens once the document and the three
            # special tokens are placed. Negative: the document alone is too long.
            budget = self.max_length - len(doc_tokens) - self.NUM_SPECIAL_TOKENS
            if budget < 0:
                self.skipped_count += 1
                continue

            try:
                hallucination_labels = self._parse_labels(raw_labels)
            except (TypeError, ValueError, SyntaxError):
                self.skipped_invalid += 1
                continue

            summ_tokens = self.tokenizer.tokenize(summ)

            # Try the whole summary first, then progressively finer splits.
            for parts in range(1, self.max_chunks + 1):
                bounds = self._chunk_bounds(len(summ_tokens), parts)
                if all(end - start <= budget for start, end in bounds):
                    break
            else:
                # No attempted split fits: skip the row.
                self.skipped_bc_chunk += 1
                continue

            for start, end in bounds:
                input_ids, attention_mask, labels = self._create_input(
                    doc_tokens, summ_tokens[start:end], hallucination_labels[start:end]
                )
                self.examples.append(
                    {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
                )

    def _create_input(self, doc_tokens, summ_tokens, hallucination_labels):
        """Build padded ``(input_ids, attention_mask, labels)`` lists.

        ``labels`` always has the same length as ``input_ids``. Surplus entries
        of ``hallucination_labels`` (beyond the number of summary tokens) are
        dropped; summary tokens without a label keep ``IGNORE_INDEX``.
        """
        input_ids = [self.tokenizer.cls_token_id] + \
                    self.tokenizer.convert_tokens_to_ids(doc_tokens) + \
                    [self.tokenizer.sep_token_id] + \
                    self.tokenizer.convert_tokens_to_ids(summ_tokens) + \
                    [self.tokenizer.sep_token_id]

        attention_mask = [1] * len(input_ids)

        # Align labels with the summary tokens, which start after
        # [CLS] + document + [SEP]. Truncating to the number of summary tokens
        # keeps the slice assignment from changing the length of ``labels``.
        labels = [self.IGNORE_INDEX] * len(input_ids)
        summary_start = len(doc_tokens) + 2
        token_labels = list(hallucination_labels)[:len(summ_tokens)]
        labels[summary_start:summary_start + len(token_labels)] = token_labels

        # Pad if necessary
        pad_length = self.max_length - len(input_ids)
        if pad_length > 0:
            input_ids += [self.tokenizer.pad_token_id] * pad_length
            attention_mask += [0] * pad_length
            labels += [self.IGNORE_INDEX] * pad_length  # Ignore padding tokens

        return input_ids, attention_mask, labels

    def __len__(self):
        """Return the number of encoded examples (chunks), not source rows."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of ``torch.long`` tensors."""
        example = self.examples[idx]
        return {
            "input_ids": torch.tensor(example["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(example["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(example["labels"], dtype=torch.long)
        }


In [None]:
# Parameters
# Chosen as a balance between how much padding is needed and how many rows drop out
MAX_LEN = 2048

# Encode the train, validation and test splits (in that order) with the same settings
train_dataset, valid_dataset, test_dataset = (
    TokenLevelDataset(frame, tokenizer, max_length=MAX_LEN)
    for frame in (train_data, valid_data, test_data)
)


In [None]:
# Per split: rows skipped because of document length, absolute and as a percentage of all rows
print(
    f'Train: {train_dataset.skipped_count}/{train_data.shape[0]} ~ {train_dataset.skipped_count/train_data.shape[0] * 100} %',
    f'Test: {test_dataset.skipped_count}/{test_data.shape[0]} ~ {test_dataset.skipped_count/test_data.shape[0] * 100} %',
    f'Valid: {valid_dataset.skipped_count}/{valid_data.shape[0]} ~ {valid_dataset.skipped_count/valid_data.shape[0] * 100} %',
    sep="\n",
)

In [None]:
# Per split: rows skipped because no summary chunking fit, absolute and as a percentage of all rows
print(
    f'Train: {train_dataset.skipped_bc_chunk}/{train_data.shape[0]} ~ {train_dataset.skipped_bc_chunk/train_data.shape[0] * 100} %',
    f'Test: {test_dataset.skipped_bc_chunk}/{test_data.shape[0]} ~ {test_dataset.skipped_bc_chunk/test_data.shape[0] * 100} %',
    f'Valid: {valid_dataset.skipped_bc_chunk}/{valid_data.shape[0]} ~ {valid_dataset.skipped_bc_chunk/valid_data.shape[0] * 100} %',
    sep="\n",
)

In [None]:
BATCH_SIZE = 16

# One DataLoader per split; only the training split is reshuffled
train_loader, valid_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=reshuffle)
    for split, reshuffle in (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
# Device setup: use the GPU when one is available. The model is moved first so
# that the optimizer below is built over the parameters in their final place.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)

# Optimizer: PyTorch's AdamW, referenced explicitly. The bare name `AdamW` is
# imported twice at the top of the file and ends up bound to the deprecated
# `transformers.AdamW`, not to the PyTorch class the import comment asks for.
# NOTE(review): torch.optim.AdamW uses its own defaults (e.g. weight_decay=0.01);
# pass weight_decay=0.0 here if the previous no-decay behaviour is wanted.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Loss function. ignore_index=-100 is the default already; it is spelled out
# because TokenLevelDataset marks unlabelled positions with -100.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)