In [1]:
import os

if os.path.basename(os.getcwd()) != 'HUST-NLP-Medical-MultiDocument-Summarization-':
    %cd ../../

e:\pyenv\GTCC\KPG-RL\HUST-NLP-Medical-MultiDocument-Summarization-


In [2]:
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import LEDForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoTokenizer
import torch
from torch.nn import CrossEntropyLoss
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score
from tqdm.notebook import tqdm
from transformers import DataCollatorForSeq2Seq

In [3]:
import random

RANDOM_SEED = 42
# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) so that DataLoader
# shuffling, dropout and the initialisation of the resized embedding rows are
# reproducible across Restart & Run All. torch.manual_seed seeds all devices.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
# Hugging Face checkpoint id shared by the tokenizer here and the model cell below.
PATH = 'allenai/led-base-16384'
tokenizer = AutoTokenizer.from_pretrained(PATH)

# Register <doc-sep> as an additional special token so the tokenizer keeps it as a
# single id; the call's return value (number of tokens added) is the cell output.
special_tokens_dict = dict(additional_special_tokens=['<doc-sep>'])
tokenizer.add_special_tokens(special_tokens_dict)

1

In [None]:
# Load the pretrained LED checkpoint and place it on the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(PATH)
model = model.to(device)
# Grow the embedding matrix to the enlarged vocabulary (it now contains <doc-sep>);
# the returned Embedding module is shown as the cell output.
model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(50266, 768, padding_idx=1)

In [6]:
DOC_SEP_ = "<doc-sep>"
docsep_token_id = tokenizer.convert_tokens_to_ids(DOC_SEP_)

In [11]:
import evaluate

# Summary-quality metrics computed on the validation outputs after every epoch.
rouge, bertscore = (evaluate.load(metric_name) for metric_name in ('rouge', 'bertscore'))

In [9]:
def _resolve_docsep_id(tokenizer, docsep_token):
    """Return the vocabulary id of `docsep_token`, raising if the tokenizer lacks it.

    Raises:
        ValueError: if the token maps to None or to the tokenizer's unk id, which would
            otherwise put global attention on every unknown token.
    """
    docsep_id = tokenizer.convert_tokens_to_ids(docsep_token)
    if docsep_id is None or docsep_id == getattr(tokenizer, 'unk_token_id', None):
        raise ValueError(f"{docsep_token!r} is not in the tokenizer vocabulary; add it as a special token first")
    return docsep_id


def _encode_example(tokenizer, docsep_token_id, source, target,
                    max_source_length=4096, max_target_length=1024):
    """Tokenize one (source, target) pair into LED model inputs.

    Args:
        tokenizer: Hugging Face tokenizer used for both source and target.
        docsep_token_id: Vocabulary id of the document-separator token.
        source: Concatenated input documents.
        target: Reference summary.
        max_source_length: Truncation length of the source, in tokens.
        max_target_length: Truncation length of the target, in tokens.

    Returns:
        Dict of 1-D tensors: ``input_ids``, ``attention_mask``, ``labels`` and
        ``global_attention_mask`` (1 on CLS and separator tokens, 0 elsewhere).
    """
    encoding = tokenizer(source, return_tensors='pt', padding=False,
                         truncation=True, max_length=max_source_length)
    target_encoding = tokenizer(target, return_tensors='pt', padding=False,
                                truncation=True, max_length=max_target_length)
    input_ids = encoding['input_ids'].squeeze(0)  # drop the batch dimension of size 1
    # Global attention on the CLS token and on every document separator; skip ids the
    # tokenizer does not define (cls_token_id may be None for some tokenizers).
    global_ids = [tok_id for tok_id in (tokenizer.cls_token_id, docsep_token_id) if tok_id is not None]
    global_attention_mask = torch.isin(
        input_ids, torch.tensor(global_ids, dtype=input_ids.dtype)
    ).long()
    return {
        'input_ids': input_ids,
        'attention_mask': encoding['attention_mask'].squeeze(0),
        'labels': target_encoding['input_ids'].squeeze(0),
        'global_attention_mask': global_attention_mask,
    }


class PT_Medical_Dataset(Dataset):
    """Dataset over one DataFrame holding both the 'Abstracts' input and the 'Target' summary.

    Rows are accessed by position, so a frame with any index (e.g. a filtered slice)
    works.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerBase", train_data, docsep_token="<doc-sep>",
                 max_source_length=4096, max_target_length=1024):
        self.data = train_data.copy()
        self.tokenizer = tokenizer
        self.docsep_token_id = _resolve_docsep_id(tokenizer, docsep_token)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return _encode_example(
            self.tokenizer, self.docsep_token_id,
            self.data['Abstracts'].iloc[idx], self.data['Target'].iloc[idx],
            self.max_source_length, self.max_target_length,
        )


class Medical_Dataset(Dataset):
    """Dataset pairing row `i` of `train_data` ('Abstracts') with row `i` of `train_label` ('Target').

    Both frames are read by position and must have the same number of rows.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerBase", train_data, train_label, docsep_token="<doc-sep>",
                 max_source_length=4096, max_target_length=1024):
        if len(train_data) != len(train_label):
            raise ValueError(
                f"inputs ({len(train_data)} rows) and labels ({len(train_label)} rows) must have the same length"
            )
        self.data = train_data
        self.label = train_label
        self.tokenizer = tokenizer
        self.docsep_token_id = _resolve_docsep_id(tokenizer, docsep_token)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return self.label.shape[0]

    def __getitem__(self, idx):
        return _encode_example(
            self.tokenizer, self.docsep_token_id,
            self.data['Abstracts'].iloc[idx], self.label['Target'].iloc[idx],
            self.max_source_length, self.max_target_length,
        )

In [10]:
cochrane_train_input = pd.read_csv("./datasets/mslr_data/ms2/train-inputs-pretrain.csv")
cochrane_train_input = cochrane_train_input.iloc[0:1,:]

train_dataset = PT_Medical_Dataset(tokenizer,cochrane_train_input)


cochrane_dev_input = pd.read_csv(".\datasets\mslr_data\ms2\dev-inputs.csv")
cochrane_dev_input["Abstract"].fillna("",inplace = True)
cochrane_dev_input = cochrane_dev_input.groupby('ReviewID').apply(lambda group:
    "".join([f"{row['Title']}{DOC_SEP_}{row['Abstract']}{DOC_SEP_}" for index, row in group.iterrows()])
).reset_index(name="Abstracts")
cochrane_dev_label = pd.read_csv(".\datasets\mslr_data\ms2\dev-targets.csv")

cochrane_dev_input.sort_values(by='ReviewID', inplace=True)
cochrane_dev_input.reset_index(drop=True, inplace=True)

cochrane_dev_label.drop_duplicates(subset=['ReviewID'], keep='first', inplace=True)
cochrane_dev_label.sort_values(by='ReviewID', inplace=True)
cochrane_dev_label.reset_index(drop=True, inplace=True)

cochrane_dev_input = cochrane_dev_input.iloc[0:100,:]
cochrane_dev_label = cochrane_dev_label.iloc[0:100,:]

valid_dataset = Medical_Dataset(tokenizer,cochrane_dev_input,cochrane_dev_label)

In [16]:
# Pads each batch dynamically; passing the model lets the collator build decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer)

In [17]:
BATCH_SIZE = 1  # one (up to 4096-token) example per step, as in the original run

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collator)
# Evaluation does not need shuffling; a fixed order keeps the predictions/references
# lists in the same order every epoch.
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator)

In [18]:
# Adam step size. 1e-3 is far too large for full fine-tuning of a pretrained
# transformer and tends to destroy the pretrained weights; 5e-5 is the Hugging Face
# Trainer default and a common starting point for seq2seq fine-tuning.
learning_rate = 5e-5
epochs = 3

In [None]:
import os
from getpass import getpass

import wandb

# Never hard-code the W&B key in a notebook: read it from the WANDB_API_KEY
# environment variable, or prompt for it. The key previously committed in this cell
# is exposed in version history and should be revoked and regenerated.
api_key = os.environ.get("WANDB_API_KEY") or getpass("W&B API key: ")
wandb.login(key=api_key)
# Record the hyper-parameters with the run so logged results are traceable.
wandb.init(project="summarization", config={"learning_rate": learning_rate, "epochs": epochs})

In [None]:
# Fine-tune on train_dataloader. After every epoch, compute the validation loss plus
# ROUGE/BERTScore on *generated* summaries, and send all of that epoch's metrics to
# W&B in a single wandb.log call so they share one step.
GEN_MAX_NEW_TOKENS = 256  # length cap for generated validation summaries (tunable)
MODEL_INPUT_KEYS = ('input_ids', 'attention_mask', 'labels', 'global_attention_mask')

optim = torch.optim.Adam(model.parameters(), lr=learning_rate)


def _to_device(batch):
    """Return the tensors the model consumes from a collated `batch`, moved to `device`."""
    return {key: batch[key].to(device) for key in MODEL_INPUT_KEYS}


def _train_one_epoch():
    """Run one optimisation pass over `train_dataloader` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_dataloader, desc="train"):
        optim.zero_grad()
        loss = model(**_to_device(batch)).loss
        loss.backward()
        optim.step()
        total_loss += loss.item()
    return total_loss / max(len(train_dataloader), 1)


@torch.no_grad()
def _evaluate():
    """Score `valid_dataloader`.

    Returns:
        Tuple of (mean validation loss, decoded generated summaries, decoded references).
    """
    model.eval()
    total_loss = 0.0
    predictions, references = [], []
    for batch in tqdm(valid_dataloader, desc="valid"):
        inputs = _to_device(batch)
        total_loss += model(**inputs).loss.item()
        # Decode real summaries with generate(). The argmax of teacher-forced logits is
        # conditioned on the reference prefix and would inflate ROUGE/BERTScore.
        generated = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            global_attention_mask=inputs['global_attention_mask'],
            max_new_tokens=GEN_MAX_NEW_TOKENS,
        )
        predictions.extend(tokenizer.batch_decode(generated, skip_special_tokens=True))
        # -100 marks label padding; strip it before decoding the references.
        label_ids = [seq[seq != -100] for seq in inputs['labels']]
        references.extend(tokenizer.batch_decode(label_ids, skip_special_tokens=True))
    return total_loss / max(len(valid_dataloader), 1), predictions, references


try:
    for epoch in range(epochs):
        avg_train_loss = _train_one_epoch()
        avg_valid_loss, predictions, references = _evaluate()

        rouge_results = rouge.compute(predictions=predictions, references=references)
        bertscore_results = bertscore.compute(predictions=predictions, references=references, lang="en")
        # BERTScore returns per-example lists (plus a hashcode string); log the means as scalars.
        bertscore_means = {
            f"bertscore_{name}": float(np.mean(bertscore_results[name]))
            for name in ("precision", "recall", "f1")
        }
        wandb.log({
            "epoch": epoch + 1,
            "avg_train_loss": avg_train_loss,
            "avg_valid_loss": avg_valid_loss,
            **rouge_results,
            **bertscore_means,
        })
        print(f"Epoch {epoch+1} completed, Avg. Train Loss: {avg_train_loss:.4f}, Avg. Valid Loss: {avg_valid_loss:.4f}")
finally:
    # Close the W&B run even if training fails, so a later wandb.init starts cleanly.
    wandb.finish()

In [15]:
tokenizer_save_path = "my_centrum_tokenizer"
model_save_path = "my_centrum_led_model"

tokenizer.save_pretrained(tokenizer_save_path)
model.save_pretrained(model_save_path)