In [None]:
import torch

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')

# Release unoccupied memory held by PyTorch's CUDA caching allocator (no-op without CUDA).
torch.cuda.empty_cache()

In [None]:
from datasets import load_dataset
import pandas as pd

# Work on the first N_SAMPLES rows of the cleaned corpus to keep runs short.
N_SAMPLES = 2000
df = pd.read_csv('output_cleaned.csv').head(N_SAMPLES)

In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration

# Pretrained BART-base checkpoint: its tokenizer and the seq2seq LM model, moved to `device`.
MODEL_CHECKPOINT = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT).to(device)

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

# Function to truncate text for max token length
def truncate_to_max_tokens(text, max_tokens=1024):
    encoded = tokenizer.encode(text, max_length=max_tokens, truncation=True)
    return tokenizer.decode(encoded, skip_special_tokens=True).strip()

# Encoder input: each introduction clipped to the model's token budget.
df['input_text'] = df['introduction'].map(truncate_to_max_tokens)

# Token length of each clipped input (diagnostic only; BartDataset does not read it).
df['token_count'] = df['input_text'].map(lambda text: len(tokenizer.encode(text)))

# Hold out 10% of rows for validation, with a fixed seed for a reproducible split.
train_df, val_df = train_test_split(df, random_state=42, test_size=0.1)

# Custom PyTorch Dataset for BART: source = 'input_text', target = 'abstract'
class BartDataset(Dataset):
    """Map-style dataset of fixed-length encoder inputs and decoder labels.

    Each item is a dict with ``input_ids`` and ``attention_mask`` for the
    source text (``df['input_text']``) and ``labels`` for the target text
    (``df['abstract']``). Padding positions in ``labels`` are set to -100 so
    the loss ignores them.

    Args:
        df: DataFrame with ``input_text`` and ``abstract`` columns.
        tokenizer: tokenizer applied to both source and target text.
        max_len: length the source is truncated/padded to.
        max_target_len: length the target is truncated/padded to. Defaults to
            ``max_len``, which matches the previous behaviour. A smaller value
            avoids padding short abstracts out to ``max_len`` decoder positions.
    """

    def __init__(self, df, tokenizer, max_len=1024, max_target_len=None):
        self.input_texts = df['input_text'].tolist()
        self.target_texts = df['abstract'].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_target_len = max_len if max_target_len is None else max_target_len

    def __len__(self):
        return len(self.input_texts)

    def _encode(self, text, max_length):
        # Tokenize one string, truncated and padded to exactly `max_length` tokens.
        return self.tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        input_encoding = self._encode(self.input_texts[idx], self.max_len)
        target_encoding = self._encode(self.target_texts[idx], self.max_target_len)

        # return_tensors='pt' adds a leading batch dim of size 1. Index it away rather than
        # calling squeeze(), which would also collapse the sequence dim when its length is 1.
        labels = target_encoding['input_ids'][0]
        labels[labels == self.tokenizer.pad_token_id] = -100  # Ignore pad tokens in loss

        return {
            'input_ids': input_encoding['input_ids'][0],
            'attention_mask': input_encoding['attention_mask'][0],
            'labels': labels
        }

# Wrap the training and validation splits in the BART dataset class (default lengths).
train_dataset, val_dataset = (BartDataset(split_df, tokenizer) for split_df in (train_df, val_df))

In [None]:
# Plain PyTorch loaders over the two datasets (shuffled for training only).
# NOTE(review): the Trainer configured below is given the datasets, not these loaders,
# so these are only useful for manual inspection/iteration.
BATCH_SIZE = 8
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [None]:
import evaluate

rouge = evaluate.load("rouge")

def compute_metrics(p):
    """Decode model predictions and reference labels, then score them with ROUGE.

    Args:
        p: the Trainer's evaluation prediction. ``p.predictions`` may be logits
           (3-D, reduced here with argmax), token ids (2-D), or a tuple whose
           first element is one of those. ``p.label_ids`` are token ids in
           which -100 marks positions ignored by the loss.

    Returns:
        The dict produced by ``rouge.compute``.
    """
    import numpy as np

    # If predictions is a tuple (extra model outputs), the first element holds the logits/ids.
    predictions = p.predictions
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)

    # Convert predicted logits to token IDs if needed.
    if predictions.ndim == 3:
        predictions = predictions.argmax(-1)

    labels = np.asarray(p.label_ids)

    # -100 cannot be decoded by the tokenizer. It marks ignored label positions, and it can
    # also appear in predictions (e.g. as padding when batches of different lengths are
    # concatenated). Map it to the pad token, which skip_special_tokens then drops.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    return rouge.compute(predictions=decoded_preds, references=decoded_labels)

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq

# Seq2seq data collator: stacks examples into batches, padding labels with -100.
# BartDataset already pads every item to max_length, so no extra padding is added in practice.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)


def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids before the Trainer accumulates them.

    Without this, evaluation gathers the full vocabulary-sized logits for every
    validation token on the host, which can exhaust memory. ``compute_metrics``
    only needs the argmax ids (it accepts 2-D id arrays directly).
    """
    if isinstance(logits, tuple):
        # The model may return extra outputs alongside the logits; the logits come first.
        logits = logits[0]
    return logits.argmax(dim=-1)


# Define training arguments
training_args = TrainingArguments(
    output_dir='./results_bart',
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=50,
    # fp16 mixed precision requires a GPU. Enabling it unconditionally makes TrainingArguments
    # raise on CPU-only machines, even though cell 1 falls back to the CPU there.
    fp16=torch.cuda.is_available(),
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# Train the model
trainer.train()

In [None]:
# Persist the fine-tuned weights and the tokenizer into one directory so that both can be
# reloaded from it with from_pretrained.
final_model_dir = './final_bart_model'
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

In [None]:
# Score the fine-tuned model on the validation split (eval loss plus compute_metrics output).
eval_results = trainer.evaluate()
eval_results