In [1]:
import os
# Ask PyTorch's CUDA caching allocator for expandable segments (see the PyTorch
# CUDA memory-management docs). It is set in the very first cell so it is already
# in the environment before torch is imported in the next cell.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")


In [2]:
import torch
# Release cached-but-unused CUDA memory held by PyTorch's allocator.
# NOTE(review): on a fresh kernel nothing has been allocated yet, so this presumably
# only matters when cells are re-run after an earlier (e.g. out-of-memory) attempt;
# the Trainer cell below runs on the CPU, so this may be unnecessary — confirm.
torch.cuda.empty_cache()

In [3]:
from transformers import BartTokenizer, BartForConditionalGeneration

# Folder holding the saved fine-tuned BART tokenizer and model weights
model_path = "./final_bart_model"  # your saved model folder

tokenizer = BartTokenizer.from_pretrained(model_path)
# .eval() switches to inference mode and returns the module itself
model = BartForConditionalGeneration.from_pretrained(model_path).eval()
model


  from .autonotebook import tqdm as notebook_tqdm


BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_lay

In [4]:
from datasets import load_dataset
import pandas as pd

# Row cap for this experiment; set to None to use every row of the CSV.
N_ROWS = 800

df = pd.read_csv('output_cleaned.csv')
if N_ROWS is not None:
    df = df.head(N_ROWS)

# Fail fast if the columns the following cells rely on are missing.
_missing_cols = {'introduction', 'abstract'} - set(df.columns)
assert not _missing_cols, f"output_cleaned.csv is missing columns: {sorted(_missing_cols)}"

In [5]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

def truncate_to_max_tokens(text, max_tokens=1024):
    """Clip ``text`` so it fits in ``max_tokens`` tokens of the global ``tokenizer``.

    The text is encoded with truncation, then decoded back without special tokens,
    so the returned string is the (whitespace-stripped) text of the kept tokens.

    Args:
        text: input string to clip.
        max_tokens: token budget passed to ``tokenizer.encode`` as ``max_length``.

    Returns:
        The decoded, stripped text of the truncated token sequence.
    """
    token_ids = tokenizer.encode(text, truncation=True, max_length=max_tokens)
    clipped_text = tokenizer.decode(token_ids, skip_special_tokens=True)
    return clipped_text.strip()

# Model input: each introduction clipped to the token budget
df['input_text'] = df['introduction'].map(truncate_to_max_tokens)

# Optional: per-row token count of the clipped text (debugging / verification only)
df['token_count'] = df['input_text'].map(lambda clipped: len(tokenizer.encode(clipped)))

# Hold out 10% of rows for validation, with a fixed seed for reproducibility
train_df, val_df = train_test_split(df, random_state=42, test_size=0.1)

# Custom PyTorch Dataset for BART (introduction -> abstract pairs)
class BartDataset(Dataset):
    """Map-style dataset of tokenized (input_text, abstract) pairs for BART.

    Each item is a dict of 1-D tensors: ``input_ids`` and ``attention_mask`` for the
    source text, and ``labels`` for the target text, with pad positions set to -100
    so they are ignored in the loss.

    Args:
        df: DataFrame with an ``input_text`` (source) and an ``abstract`` (target) column.
        tokenizer: Hugging Face tokenizer, called once per text with ``return_tensors='pt'``.
        max_len: maximum source length in tokens; also used for the target unless
            ``max_target_len`` is given.
        max_target_len: optional separate maximum target length; ``None`` (default)
            reuses ``max_len``.
        padding: padding strategy forwarded to the tokenizer. The default
            ``'max_length'`` pads every example to its max length; pass ``False`` to
            leave padding to a dynamic-padding data collator.
    """

    def __init__(self, df, tokenizer, max_len=1024, max_target_len=None, padding='max_length'):
        self.input_texts = df['input_text'].tolist()
        self.target_texts = df['abstract'].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_target_len = max_len if max_target_len is None else max_target_len
        self.padding = padding

    def __len__(self):
        return len(self.input_texts)

    def _encode(self, text, max_length):
        """Tokenize one text with this dataset's padding/truncation settings."""
        return self.tokenizer(
            text,
            max_length=max_length,
            padding=self.padding,
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        input_encoding = self._encode(self.input_texts[idx], self.max_len)
        target_encoding = self._encode(self.target_texts[idx], self.max_target_len)

        # squeeze(0) drops only the leading batch dim of the single-text encoding;
        # a bare squeeze() would also collapse a length-1 sequence into a 0-d tensor.
        labels = target_encoding['input_ids'].squeeze(0)
        labels[labels == self.tokenizer.pad_token_id] = -100  # Ignore pad tokens in loss

        return {
            'input_ids': input_encoding['input_ids'].squeeze(0),
            'attention_mask': input_encoding['attention_mask'].squeeze(0),
            'labels': labels
        }

# Wrap the train and validation splits (same tokenizer, default lengths) as Datasets
train_dataset, val_dataset = (BartDataset(split_df, tokenizer) for split_df in (train_df, val_df))

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
import evaluate

# Evaluation metric: ROUGE from the Hugging Face `evaluate` library, loaded once here
# and reused by compute_metrics. NOTE(review): evaluate.load may need to fetch the
# metric implementation on first use — confirm network access / local cache.
rouge = evaluate.load("rouge")

def compute_metrics(p):
    """Compute ROUGE between decoded predictions and decoded references.

    Args:
        p: an ``EvalPrediction``-like object with ``predictions`` (token ids, or
            logits of shape (batch, seq, vocab); if a tuple, its first element is
            used) and ``label_ids`` (token ids, with -100 at ignored positions).

    Returns:
        The dict returned by ``rouge.compute`` (rouge1, rouge2, rougeL, rougeLsum).

    NOTE(review): with a plain ``Trainer`` the predictions are argmax tokens of
    teacher-forced logits, not generated summaries, so this is not the ROUGE of the
    model's actual outputs. For that, use ``Seq2SeqTrainer`` with
    ``predict_with_generate=True``.
    """
    import numpy as np

    # If predictions is a tuple, the token logits/ids are the first element
    predictions = p.predictions
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)

    # Convert predicted logits to token IDs if needed
    if predictions.ndim == 3:
        predictions = predictions.argmax(-1)

    # -100 marks ignored positions in labels, and padding that the Trainer inserts
    # into predictions when batches differ in length; it is not a decodable id.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.asarray(p.label_ids)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    return rouge.compute(predictions=decoded_preds, references=decoded_labels)

# The dataset already pads every example to max_length, so this collator mostly just
# batches the examples; passing `model` lets it build decoder_input_ids from labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Run on the CPU; fp16 mixed precision is only enabled when a GPU is used.
USE_CPU = True


def preprocess_logits_for_metrics(logits, labels):
    """Reduce each eval batch's logits to predicted token ids before accumulation.

    Otherwise the Trainer keeps the full (batch, seq_len, vocab) logits for the whole
    eval set. With 1024-token labels and a 50265-token vocab, that is about 16 GB in
    float32 for the 80-row validation split. compute_metrics accepts either form.
    """
    if isinstance(logits, tuple):
        # When the model returns several outputs, the LM logits come first
        logits = logits[0]
    return logits.argmax(dim=-1)


# Define training arguments
training_args = TrainingArguments(
    output_dir='./results_bart',
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=50,
    no_cuda=USE_CPU,
    fp16=not USE_CPU,  # mixed precision only makes sense on a GPU
)

# Initialize Trainer.
# NOTE(review): a plain Trainer scores teacher-forced predictions; use Seq2SeqTrainer
# with predict_with_generate=True to score actually generated summaries.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# Evaluate
results = trainer.evaluate()

  trainer = Trainer(
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


{'eval_loss': 2.743373394012451, 'eval_model_preparation_time': 0.001, 'eval_rouge1': 0.4478078485062965, 'eval_rouge2': 0.14918591556017471, 'eval_rougeL': 0.32873686676753666, 'eval_rougeLsum': 0.3285933015361794, 'eval_runtime': 466.4884, 'eval_samples_per_second': 0.171, 'eval_steps_per_second': 0.043}


In [7]:
# One "name: value" line per evaluation metric
for metric_name, metric_value in results.items():
    print(metric_name, metric_value, sep=": ")


eval_loss: 2.743373394012451
eval_model_preparation_time: 0.001
eval_rouge1: 0.4478078485062965
eval_rouge2: 0.14918591556017471
eval_rougeL: 0.32873686676753666
eval_rougeLsum: 0.3285933015361794
eval_runtime: 466.4884
eval_samples_per_second: 0.171
eval_steps_per_second: 0.043
