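# Summarization model: initialisation and training

This notebook initialises the `allenai/led-large-16384-arxiv` summarization model and fine-tunes it on the ground-truth training set. After every epoch it computes the validation loss and ROUGE-1/2/L F1. It saves a checkpoint whenever the validation loss improves, with early stopping. Finally, it plots the training and validation loss curves.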



In [1]:
!pip install transformers
!pip install tokenizer
!pip install datasets
!pip install rouge_score
!pip install sentencepiece
!pip install rouge



In [1]:
import os
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from ModelSummarizer import SummarizationModel
from ModelSummarizer import load_data

# Define the path for the datasets

train_file_path = os.path.join('../dataset/', 'dataset_ground_truth.json')  # 100 pdfs
test_file_path =  os.path.join('../dataset/', 'dataset_test_ground_truth.json')   #20 pdfs
val_file_path =  os.path.join('../dataset/', 'dataset_eval_ground_truth.json')  #20 pdfs

model_name = "allenai/led-large-16384-arxiv"
summarizer = SummarizationModel(model_name)
model = summarizer.model

# Load training data
train_data = load_data(train_file_path)

# Load testing data
test_data = load_data(test_file_path)

#Load val data
val_data=load_data(val_file_path)

# Maximum sequence length for the model.
# NOTE(review): not referenced anywhere else in this notebook. Presumably meant
# for tokenization inside SummarizationModel — TODO wire it through or remove it.
seq_length = 1024

# Per-epoch history filled by the training loop; the losses are plotted at the end.
train_losses = []
val_losses = []
rouge_scores = []  # one (avg ROUGE-1 F1, avg ROUGE-2 F1, avg ROUGE-L F1) tuple per epoch

# File name of the checkpoint saved whenever validation loss improves.
checkpoint_filename = "model_checkpoint.pt"

# Training / early-stopping state.
best_val_loss = float('inf')
epochs_no_improve = 0
# NOTE: with num_epochs <= patience, early stopping can never trigger.
num_epochs = 1
patience = 3  # consecutive epochs without val-loss improvement before stopping

# One sample per batch. Only the training split is shuffled. Validation order
# cannot change the summed loss/ROUGE, so keep it deterministic and avoid
# consuming torch RNG state.
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False)

# Train for up to `num_epochs`, validating after each epoch. Stop early once the
# validation loss has not improved for `patience` consecutive epochs. Averages
# are appended to train_losses / val_losses / rouge_scores and logged via
# summarizer.log_metrics.

# The per-batch averages below divide by the loader lengths; fail clearly
# instead of with a ZeroDivisionError if a split is empty.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("Training and validation sets must both be non-empty")

# torch.save does not create missing parent directories.
checkpoint_dir = 'Checkpoints/'
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    # train_model's return value is treated as a loss summed over batches
    # (it is divided by the batch count) — TODO confirm in ModelSummarizer.
    train_loss = summarizer.train_model(train_loader)
    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validate the model. The ROUGE F1 values returned are totals, averaged
    # here over num_samples.
    val_loss, total_rouge1_f1, total_rouge2_f1, total_rougeL_f1, num_samples = summarizer.validate_model(val_loader)
    if num_samples:
        avg_rouge1_f1 = total_rouge1_f1 / num_samples
        avg_rouge2_f1 = total_rouge2_f1 / num_samples
        avg_rougeL_f1 = total_rougeL_f1 / num_samples
    else:
        # No scored samples: report NaN rather than crash on division by zero.
        avg_rouge1_f1 = avg_rouge2_f1 = avg_rougeL_f1 = float('nan')
    rouge_scores.append((avg_rouge1_f1, avg_rouge2_f1, avg_rougeL_f1))

    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    # Log metrics to file
    summarizer.log_metrics(epoch, avg_train_loss, avg_val_loss, (avg_rouge1_f1, avg_rouge2_f1, avg_rougeL_f1))

    print(
        f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, '
        f'ROUGE-1/2/L F1: {avg_rouge1_f1:.4f}/{avg_rouge2_f1:.4f}/{avg_rougeL_f1:.4f}'
    )

    # Save a checkpoint whenever the validation loss improves; otherwise count
    # towards early stopping.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)
        torch.save(model.state_dict(), checkpoint_path)
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping triggered")
            break


# Plot the per-epoch average training and validation losses.
# Markers are required: with a single epoch each series is one point, and a
# line-only plot of one point draws nothing visible.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, marker='o', label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, marker='o', label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training vs. validation loss per epoch')
ax.legend()
plt.show()








  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


: 

: 