# In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
from model import BERTSUM
from dataset import CNNDailyMailDataset
from train import train_fn, eval_fn

# Seed PyTorch's RNG so that shuffling and weight init are repeatable
torch.manual_seed(42)

# Build the BERTSUM model and attach the pre-trained BERT encoder
model = BERTSUM()
model.bert = BertModel.from_pretrained('bert-base-uncased')

# Enter training mode only AFTER the encoder has been attached:
# from_pretrained() returns a module in eval mode, so calling train() before
# the assignment would leave the new encoder with dropout disabled.
model.train()

# Set up tokenizer and data loaders
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = CNNDailyMailDataset('train', tokenizer)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
eval_dataset = CNNDailyMailDataset('val', tokenizer)
eval_loader = DataLoader(eval_dataset, batch_size=10, shuffle=False)

# Set up optimizer and learning rate scheduler
# (StepLR multiplies the learning rate by 0.95 on every scheduler.step() call;
# the stepping itself happens inside train_fn)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

# Fine-tune BERTSUM model on CNN/Daily Mail dataset:
# one training pass followed by one evaluation pass per epoch
num_epochs = 5
for epoch in range(num_epochs):
    train_fn(train_loader, model, optimizer, scheduler)
    eval_fn(eval_loader, model)


# In [None]:
# Save fine-tuned BERTSUM model
# model.save_pretrained('fine_tuned_bertsum_model')


import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Map-style dataset yielding tokenized (text, summary) pairs.

    Each element of ``data`` must be subscriptable with the keys ``'text'``
    and ``'summary'``.
    """

    def __init__(self, data, tokenizer=None, max_text_length=512,
                 max_summary_length=128):
        """Store the records and the tokenization settings.

        Args:
            data: Indexable, sized collection of records with ``'text'`` and
                ``'summary'`` entries.
            tokenizer: Callable used to encode the strings. It is invoked with
                ``padding``, ``truncation``, ``max_length`` and
                ``return_tensors`` keyword arguments and must return a mapping
                containing ``'input_ids'`` and ``'attention_mask'``. When
                omitted, the ``bert-base-uncased`` tokenizer is loaded.
            max_text_length: Padded/truncated length of the source text.
            max_summary_length: Padded/truncated length of the summary.
        """
        self.data = data
        # Allow callers to share one tokenizer across datasets instead of
        # loading a fresh copy for every instance.
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.tokenizer = tokenizer
        self.max_text_length = max_text_length
        self.max_summary_length = max_summary_length

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, index):
        """Tokenize the record at ``index``.

        Returns:
            A tuple ``(input_ids, attention_mask, summary_ids)``: the token
            ids and attention mask of the text, and the token ids of the
            summary, each with the leading batch dimension removed.
        """
        record = self.data[index]
        text = record['text']
        summary = record['summary']

        # Tokenize text and summary, padding both to their fixed lengths
        text_encodings = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_text_length,
            return_tensors='pt',
        )
        summary_encodings = self.tokenizer(
            summary,
            padding='max_length',
            truncation=True,
            max_length=self.max_summary_length,
            return_tensors='pt',
        )

        # squeeze(0) removes only the batch axis added by return_tensors='pt';
        # a bare squeeze() would also collapse a length-1 sequence axis.
        input_ids = text_encodings['input_ids'].squeeze(0)
        attention_mask = text_encodings['attention_mask'].squeeze(0)
        summary_ids = summary_encodings['input_ids'].squeeze(0)

        return input_ids, attention_mask, summary_ids
