In [None]:
# Install the extra packages used by this notebook (CRF layer, seqeval metrics, HF datasets, Optuna, ...).
# NOTE(review): versions are unpinned, so a re-run may resolve to different releases;
# `evaluate` and `peft` are installed but not imported anywhere in the visible notebook.
!pip install --quiet pytorch-crf seqeval datasets evaluate peft optuna

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
import numpy as np
import pickle
import torch.nn as nn
from torchcrf import CRF
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Tuple
import optuna
import json

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient
# Fetch the W&B API key from the Kaggle secret named "wandb" instead of hardcoding it in the notebook.
user_secrets = UserSecretsClient()
key= user_secrets.get_secret("wandb")
# Authenticate this session with Weights & Biases using that key.
wandb.login(key=key)

In [None]:
# Root folders for inputs and outputs. Later cells build file paths by plain string
# concatenation, so both values must keep their trailing slash.
# NOTE(review): Kaggle normally mounts inputs under lowercase `/kaggle/input/` and paths are
# case-sensitive — confirm that this capitalised path really exists in the target environment.
input_path='/kaggle/Input/'
output_path='/kaggle/working/'

In [None]:
# Tag vocabulary (O / B-* / I-*). A tag's position in this list is used as its integer class id
# when the string tags are encoded, and to map predicted ids back to strings.
label_list= ['O', 'B-TSK','I-TSK','B-MTD','I-MTD','B-DST','I-DST']

In [None]:
# Candidate encoder checkpoints and their display names; the two lists are aligned by position.
MODELS={"checkpoint":["allenai/scibert_scivocab_uncased","malteos/scincl","allenai/specter2_base"],
       "name":["SciBERT","SciNCL","SPECTER"]}
# Change `id` to train on a different baseline model (index into the MODELS lists).
# NOTE: `id` shadows the Python builtin of the same name; it is kept because later cells
# index MODELS with it.
id=2

In [None]:
from transformers import AutoTokenizer
# Load the tokenizer that belongs to the selected checkpoint.
# NOTE(review): `add_prefix_space` is an option of byte-level BPE tokenizers; presumably it has
# no effect for these BERT-style checkpoints — confirm before relying on it.
tokenizer = AutoTokenizer.from_pretrained(MODELS["checkpoint"][id], add_prefix_space=True)

# Dataset preparation

## Load the dataset

In [None]:
# Load the pickled dataset object from the input folder (later cells treat it as a pandas DataFrame).
# SECURITY: pickle.load can execute arbitrary code while deserialising — only load files you
# created yourself or otherwise trust.
with open(input_path+'Data/df_scirex.pkl', 'rb') as file:
    data=pickle.load(file)
data

## Change the format

In [None]:
# Flatten the frame to one row per sentence: the `sentences` and `tags` list columns are
# exploded in lockstep, so each row ends up holding one token list and its tag list.
data = data.explode(['sentences', 'tags']).reset_index(drop=True)

# Encode string tags as integer class ids (a tag's id is its position in `label_list`).
label_to_id = {label: idx for idx, label in enumerate(label_list)}

# Build the whole column at once instead of writing cell-by-cell through
# `data['ner_tags'][i] = ...`: that chained assignment triggers SettingWithCopyWarning and is
# not guaranteed to write into `data` at all under pandas copy-on-write semantics.
# Tags that are not in `label_list` are skipped, exactly as before.
# NOTE(review): a skipped tag shortens the id list relative to the token list — confirm that
# the data never contains tags outside `label_list`.
data['ner_tags'] = pd.Series(
    [[label_to_id[tag] for tag in tags if tag in label_to_id] for tags in data['tags']],
    index=data.index,
    dtype='object',
)
data['tokens'] = list(data['sentences'])
data['id'] = list(data.index)
data

In [None]:
from datasets import Dataset
# Convert the three sentence-level columns into a Hugging Face Dataset and display its summary.
dataset = Dataset.from_pandas(data[['id','tokens','ner_tags']])
dataset

## Split the dataset

In [None]:
# LOAD OR TRAIN MODEL
TRAIN = 1 # 1 to TRAIN WEIGHTS or 0 to LOAD WEIGHTS
# TRAIN/TEST SPLIT
TRAIN_TEST_SPLIT = 0.2
# TRAIN/VALIDATION SPLIT
TRAIN_VAL_SPLIT = 0.125
RANDOM_SEED = 42

In [None]:
from sklearn.model_selection import train_test_split
# Hold out the "test" split first, then split the remainder into train / valid.
# Both calls are seeded with RANDOM_SEED: without a seed every kernel restart produced a
# different partition, so saved checkpoints and reported metrics could not be reproduced.
train_test = dataset.train_test_split(test_size=TRAIN_TEST_SPLIT, seed=RANDOM_SEED)
train_valid = train_test['train'].train_test_split(test_size=TRAIN_VAL_SPLIT, seed=RANDOM_SEED)

In [None]:
from datasets import DatasetDict
# Assemble three disjoint splits.
# 'train' has to come from the second split (train_valid['train']): train_test['train'] is the
# pool that 'valid' was drawn from, so using it as the training set would leak every
# validation example into training and inflate the validation scores.
datasets = DatasetDict({
    'train': train_valid['train'],
    'test': train_test['test'],
    'valid': train_valid['test']})
datasets

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize pre-split sentences and project word-level tags onto sub-word tokens.

    Args:
        examples: batch mapping with a 'tokens' entry (one word list per sentence) and an
            "ner_tags" entry (one class-id list per sentence).

    Returns:
        The tokenizer output for the batch with an added "labels" entry. Within each
        sentence the first sub-word of every word carries that word's class id; special
        tokens and the remaining sub-words of a word carry -100.
    """
    encoded = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True, max_length=512)

    aligned = []
    for batch_idx, word_labels in enumerate(examples["ner_tags"]):
        token_labels = []
        last_word = None
        for word in encoded.word_ids(batch_index=batch_idx):
            # A token is labelled only when it opens a new word; `None` marks a special token.
            opens_new_word = word is not None and word != last_word
            token_labels.append(word_labels[word] if opens_new_word else -100)
            last_word = word
        aligned.append(token_labels)

    encoded["labels"] = aligned
    return encoded

In [None]:
# Tokenize every split in batches and attach the aligned "labels" column.
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)
tokenized_datasets

# Build the model

In [None]:
# Pick the compute device: the CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.get_device_name(device))
#     print(torch.cuda.current_device())
else:
    # NOTE: this branch stores a plain string while the CUDA branch stores a torch.device.
    device ='cpu'
# Folder handed to TrainingArguments as `output_dir`.
output_dir = output_path+'Models/cache'
# BATCH SIZE (used as the per-device train and eval batch size)
# TRY 4, 8, 16, 32, 64, 128, 256. REDUCE IF OOM ERROR, HIGHER FOR TPUS
BATCH_SIZES = 4
# EPOCHS - TRANSFORMERS ARE TYPICALLY FINE-TUNED BETWEEN 1 AND 3 EPOCHS 
EPOCHS = 10

# RANDOM SEED FOR REPRODUCIBILITY
# NOTE(review): the seed is only forwarded to TrainingArguments; torch.manual_seed is commented
# out and cuDNN is put in non-deterministic benchmark mode below, so bit-exact reproducibility
# should not be expected from this configuration.
RANDOM_SEED = 42
# torch.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# SPECIFY THE WEIGHTS AND BIASES PROJECT NAME
%env WANDB_PROJECT = Thesis
# DETERMINE WHETHER TO SAVE THE MODEL IN THE 100GB OF FREE W&B STORAGE
%env WANDB_LOG_MODEL = False

In [None]:
# Logging date for w&b
from datetime import date
# Today's date formatted as DD-MM-YYYY; appended to the W&B run name.
today = date.today()
log_date = today.strftime("%d-%m-%Y")

In [None]:
from transformers import DataCollatorForTokenClassification
# Collator that batches the tokenized examples for token classification (passed to every Trainer below).
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
# Load the pretrained encoder for the selected checkpoint; it is wrapped by the tagger defined below.
bert_model=AutoModel.from_pretrained(MODELS['checkpoint'][id])

In [None]:
# Display the encoder configuration (its hidden_size is read when the BiLSTM is built).
bert_model.config

In [None]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention over a batch of sequences.

    Args:
        input_dim: feature size of the input; queries, keys, values and the output all
            keep this size.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        # Normalise the scores over the key axis (last axis of the score matrix).
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        """Attend every position of `x` to every position of the same sequence.

        Args:
            x: tensor of shape (batch, seq_len, input_dim).

        Returns:
            Tensor of shape (batch, seq_len, input_dim): for each position, the
            attention-weighted sum of the value vectors.
        """
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)
        # (batch, seq_len, seq_len) similarity matrix, scaled by sqrt(d) to keep the softmax stable.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        # The former `context_vectors = matmul(scores, values)` was never used; it has been
        # dropped to save one batched matmul and its activation memory per forward pass.
        attention = self.softmax(scores)
        return torch.matmul(attention, values)

In [None]:
class BiLSTMAttentionCRF(nn.Module):
    """Sequence tagger: encoder -> BiLSTM -> self-attention -> linear emissions -> CRF.

    Args:
        config: object exposing `hidden_size`, `num_layers`, `dropout_prob` and `num_labels`.
        bert_model: encoder module that exposes `config.hidden_size` and whose output has a
            `last_hidden_state` attribute.
    """

    def __init__(self, config, bert_model):
        super().__init__()
        self.config = config
        self.bert = bert_model
        self.bilstm = nn.LSTM(input_size=self.bert.config.hidden_size,
                              hidden_size=self.config.hidden_size,
                              num_layers=self.config.num_layers,
                              dropout=self.config.dropout_prob,
                              bidirectional=True, batch_first=True)
        # The BiLSTM concatenates both directions, hence hidden_size * 2 downstream.
        self.fc = nn.Linear(in_features=self.config.hidden_size * 2, out_features=self.config.num_labels)
        self.crf = CRF(num_tags=self.config.num_labels, batch_first=True)
        # NOTE: `dropout` and `linear` are not used by forward(). They are kept so that the
        # parameter set (and therefore previously saved state dicts) stays unchanged.
        self.dropout = nn.Dropout(self.config.dropout_prob)
        self.linear = nn.Linear(self.bert.config.hidden_size, self.config.num_labels)
        self.attention = SelfAttention(self.config.hidden_size * 2)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute the CRF loss (with labels) or decode tag ids (without labels).

        Args:
            input_ids: token ids, shape (batch, seq_len).
            attention_mask: 1 for real tokens and 0 for padding, shape (batch, seq_len).
            labels: optional class ids, shape (batch, seq_len); positions marked -100 are
                treated as class 0 for the CRF loss. The tensor is never modified.

        Returns:
            With labels: {"loss": mean negative CRF log-likelihood, "logits": emissions of
            shape (batch, seq_len, num_labels)}.
            Without labels: {"labels": Viterbi-decoded tag ids, shape (batch, seq_len)}.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # (batch, seq_len, encoder_hidden_size)
        lstm_out, _ = self.bilstm(hidden_states)
        lstm_out = self.attention(lstm_out)
        emissions = self.fc(lstm_out)

        if labels is None:
            # Viterbi decoding only runs here; it used to run on every training step as well,
            # with its result thrown away.
            # Decoding is unmasked (padding positions are decoded too), which matches how
            # the emissions are decoded elsewhere in this notebook.
            results = self.crf.decode(emissions)
            return {"labels": torch.tensor(results, device=emissions.device)}

        # The CRF cannot index class -100, so ignored positions are mapped to class 0.
        # This works on a copy: writing into `labels` in place would also alter the tensor the
        # caller still holds (e.g. the one a training loop may later pass to its metric
        # function), erasing the -100 markers it relies on to filter ignored positions.
        crf_labels = labels.masked_fill(labels == -100, 0)
        # log_softmax shifts all tag scores of a position by the same constant, which leaves the
        # CRF likelihood unchanged; kept so the loss stays numerically identical to earlier runs.
        log_probs = torch.log_softmax(emissions, dim=2)
        # The CRF mask is the attention mask, i.e. every non-padding position contributes.
        loss = -self.crf(log_probs, crf_labels, attention_mask.byte(), reduction='mean')
        return {"loss": loss, "logits": emissions}

In [None]:
from transformers import PretrainedConfig
# Default hyperparameters of the tagging head.
num_labels, num_layers = len(label_list), 2
hidden_size, dropout_prob = 768, 0.2
bert_config = bert_model.config


class CustomConfig(PretrainedConfig):
    """Configuration holding the hyperparameters of the BiLSTM / CRF tagging head.

    Args:
        num_labels: number of tag classes.
        num_layers: number of stacked LSTM layers.
        hidden_size: LSTM hidden size per direction.
        dropout_prob: dropout probability.
        **kwargs: forwarded unchanged to PretrainedConfig.

    The defaults are the module-level values as they were when this class was defined.
    """

    def __init__(self, num_labels=num_labels, num_layers=num_layers, hidden_size=hidden_size, dropout_prob=dropout_prob, **kwargs):
        super().__init__(**kwargs)
        head_settings = {
            "num_labels": num_labels,
            "num_layers": num_layers,
            "hidden_size": hidden_size,
            "dropout_prob": dropout_prob,
        }
        for setting_name, setting_value in head_settings.items():
            setattr(self, setting_name, setting_value)


# Instantiate the config with the defaults above and display it.
config = CustomConfig(
    num_labels=num_labels,
    num_layers=num_layers,
    hidden_size=hidden_size,
    dropout_prob=dropout_prob,
)
config

In [None]:
model = BiLSTMAttentionCRF(config=config, bert_model=bert_model)
model

### Find the best hyperparameters

In [None]:
from transformers import Trainer
from transformers import TrainingArguments
def objective(trial):
    """Optuna objective: train one tagger with trial-sampled hyperparameters, return its eval F1.

    Sampled hyperparameters: `hidden_size` (256-1024), `dropout_prob` (0.1-0.6) and
    `num_layers` (fixed at 2).

    Prerequisite: the module-level names `lr_max`, `weight_decay`, `warmup_ratio`,
    `num_warmup_steps`, `num_training_steps`, `num_cycles` and `compute_metrics` must already
    exist when this function is called, so the cells defining them have to be run first.

    Args:
        trial: the optuna trial used to sample hyperparameters.

    Returns:
        The "eval_f1" value reported by `trainer.evaluate()` after training.
    """
    # Imported locally so the function does not depend on an import made in a later cell.
    from transformers import get_cosine_schedule_with_warmup

    # Define hyperparameters to tune
    hidden_size = trial.suggest_int("hidden_size", 256, 1024)
    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.6)
    num_layers = trial.suggest_int("num_layers", 2, 2)
    trial_config = CustomConfig(len(label_list), num_layers, hidden_size, dropout_prob)

    # Every trial gets its own freshly loaded encoder. Wrapping the shared `bert_model`
    # instead would let each trial continue from the weights fine-tuned by the previous
    # trials, so the trials would not be comparable (and the shared encoder would be
    # altered before the final training run).
    trial_encoder = AutoModel.from_pretrained(MODELS["checkpoint"][id])
    trial_model = BiLSTMAttentionCRF(config=trial_config, bert_model=trial_encoder)

    # Optimizer
    optimizer = torch.optim.AdamW(
        trial_model.parameters(),
        lr=lr_max,
        weight_decay=weight_decay,
    )
    # Training schedule
    lr_sched = get_cosine_schedule_with_warmup(optimizer=optimizer,
                                               num_warmup_steps=num_warmup_steps,
                                               num_training_steps=num_training_steps,
                                               num_cycles=num_cycles)
    # Set up the Trainer
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZES,
        per_device_eval_batch_size=BATCH_SIZES,
        weight_decay=weight_decay,
        lr_scheduler_type='cosine',
        warmup_ratio=warmup_ratio,
        logging_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        seed=RANDOM_SEED,
        report_to='wandb',  # enable logging to W&B
        run_name=MODELS["name"][id] + "-based" + "-" + log_date,
        metric_for_best_model="f1",
        load_best_model_at_end=True,
    )
    # Create the Trainer
    trainer = Trainer(
        model=trial_model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
        optimizers=(optimizer, lr_sched),
        compute_metrics=compute_metrics,
    )

    # `compute_metrics` decodes emissions with the CRF of the module-level `model`. Point that
    # name at the trial model while this trial trains and evaluates, so the reported F1 uses
    # this trial's CRF transitions, then restore the previous binding even if training fails.
    module_scope = globals()
    previous_model = module_scope.get("model")
    module_scope["model"] = trial_model
    try:
        # Train the model
        trainer.train()
        # Evaluate the model; the result dict contains 'eval_loss', 'eval_precision',
        # 'eval_recall', 'eval_f1', 'eval_accuracy' plus runtime statistics.
        result = trainer.evaluate()
    finally:
        module_scope["model"] = previous_model
    return result["eval_f1"]

In [None]:
# Define the search space for hyperparameters
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)  # Run 10 trials for demonstration purposes; you can increase this number for a more thorough search.
print("Number of finished trials: ", len(study.trials))
print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

In [None]:
# Release cached GPU memory held by PyTorch's allocator before the final training run.
torch.cuda.empty_cache()

### Final training

In [None]:
# Training schedule
from transformers import AdamW, get_cosine_schedule_with_warmup
from transformers.optimization import Adafactor, AdafactorSchedule
# Optimiser and learning-rate schedule for the final training run.
learning_rate = 0.0000075
# Peak learning rate, scaled linearly with the batch size.
lr_max = learning_rate * BATCH_SIZES
weight_decay = 0.05
print("The maximum learning rate is: ",lr_max)
num_train_samples = len(datasets["train"])
warmup_ratio = 0.2 # Fraction of total steps spent going from zero to the max learning rate
# Number of cosine periods over the post-warmup steps.
# NOTE(review): 0.5 is a single decay to zero; with 0.8 the LR reaches zero at ~62.5% of the
# post-warmup steps and then rises again — confirm that this is intended.
num_cycles=0.8
# Step counts are whole numbers of optimiser steps: the last, partially filled batch of each
# epoch still costs one step, and the scheduler expects integer step counts (the previous
# `samples * epochs / batch_size` expression produced a float and ignored that last batch).
# NOTE(review): assumes a single device and no gradient accumulation — adjust otherwise.
steps_per_epoch = (num_train_samples + BATCH_SIZES - 1) // BATCH_SIZES
num_training_steps = steps_per_epoch * EPOCHS
num_warmup_steps = int(num_training_steps * warmup_ratio)
#Optimizer
optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=lr_max,
                weight_decay=weight_decay
            )
# Learning Rate Schedule
lr_sched = get_cosine_schedule_with_warmup(optimizer=optimizer,
                                   num_warmup_steps=num_warmup_steps,
                                   num_training_steps = num_training_steps,
                                   num_cycles=num_cycles)

In [None]:
from seqeval.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
# for the crf results
def compute_metrics(p):
    """Score CRF-decoded predictions against the gold tags with seqeval.

    Args:
        p: pair `(predictions, labels)`; `predictions` holds the emission scores and
            `labels` the gold class ids, with -100 marking positions to ignore.

    Returns:
        Dict with the overall "precision", "recall", "f1" and "accuracy".

    Note:
        Decoding uses the CRF of the module-level `model`.
    """
    emission_scores, gold_ids = p
    decoded_ids = model.crf.decode(torch.tensor(emission_scores).to(device))

    # Keep only the positions whose gold label is not the ignore marker (-100), and turn
    # both the predicted and the gold ids back into tag strings.
    true_predictions, true_labels = [], []
    for pred_seq, gold_seq in zip(decoded_ids, gold_ids):
        kept_pairs = [(pred, gold) for pred, gold in zip(pred_seq, gold_seq) if gold != -100]
        true_predictions.append([label_list[pred] for pred, _ in kept_pairs])
        true_labels.append([label_list[gold] for _, gold in kept_pairs])

    return {
        "precision": precision_score(true_labels, true_predictions),
        "recall": recall_score(true_labels, true_predictions),
        "f1": f1_score(true_labels, true_predictions),
        "accuracy": accuracy_score(true_labels, true_predictions),
    }

In [None]:
from transformers import TrainingArguments
# Arguments of the final training run, grouped by concern.
training_args = TrainingArguments(
    # where checkpoints go and how many are kept
    output_dir=output_dir,
    save_strategy="epoch",
    save_total_limit=1,
    # run length and batch sizes
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZES,
    per_device_eval_batch_size=BATCH_SIZES,
    seed=RANDOM_SEED,
    # optimisation settings
    weight_decay=weight_decay,
    lr_scheduler_type='cosine',
    warmup_ratio=warmup_ratio,
    # evaluate and log once per epoch, then reload the checkpoint with the best F1
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    metric_for_best_model="f1",
    load_best_model_at_end=True,
    # experiment tracking in W&B
    report_to='wandb',
    run_name=MODELS["name"][id] +"-based"+"-"+log_date,
)

In [None]:
from transformers import Trainer
# Create the Trainer
# Trainer of the final run: trains on the "train" split and evaluates on the "test" split,
# using the optimiser / schedule pair built above instead of the Trainer defaults.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=compute_metrics,
    optimizers=(optimizer, lr_sched),
)

In [None]:
# Run the final training.
trainer.train()

In [None]:
# Evaluate on the Trainer's eval dataset (the "test" split it was constructed with).
trainer.evaluate()

In [None]:
# Folder for the final model, named after the selected baseline (e.g. ".../Models/SPECTER-based").
model_path=output_path+"Models/"+MODELS["name"][id]+"-based"

In [None]:
# Save the trained model into model_path (the reload cells below read files from this folder).
trainer.save_model(model_path)

In [None]:
# Save the config as a JSON file
with open(model_path + "/config.json", "w") as file:
    json.dump(config.to_dict(), file)

In [None]:
# Close the active W&B run.
wandb.finish()

## Predictions on validation dataset

In [None]:
# Read and parse the JSON data from the file
with open(model_path+"/config.json", "r") as file:
    config_data = json.load(file)
config = CustomConfig(**config_data)
config 

In [None]:
# Rebuild the architecture from the reloaded config, then restore the fine-tuned weights.
# NOTE: `bert_model` is the same encoder object used by the earlier cells, so loading the
# state dict overwrites that shared encoder's weights in place.
loaded_model = BiLSTMAttentionCRF(config=config, bert_model=bert_model)
# Load the saved state dictionary. `map_location` lets a checkpoint written on a GPU be restored
# in a CPU-only session as well (torch.load otherwise tries to restore onto the saving device).
# The path is derived from `model_path` so it always matches the folder the model was saved to.
# NOTE(review): newer transformers releases may write `model.safetensors` instead of
# `pytorch_model.bin` — confirm which file `trainer.save_model` produced.
state_dict = torch.load(model_path + "/pytorch_model.bin", map_location=device)
loaded_model.load_state_dict(state_dict)
loaded_model

In [None]:
from transformers import TrainingArguments
# Load the TrainingArguments object
# Reload the pickled TrainingArguments from the model folder.
# NOTE(review): this unpickles a full Python object; recent PyTorch releases may require
# `weights_only=False` for that — confirm against the installed version.
args = torch.load(model_path+"/training_args.bin")

In [None]:
# Trainer used only for prediction with the reloaded model.
# NOTE(review): `compute_metrics` decodes with the CRF of the module-level `model`, not with
# `loaded_model`; metrics reported by this Trainer are only meaningful if both hold the same weights.
pred_trainer = Trainer(
    loaded_model,
    args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
# Predict emissions for the "valid" split and decode them with the reloaded model's CRF.
predictions, labels, _ = pred_trainer.predict(tokenized_datasets["valid"])
predicted_labels = loaded_model.crf.decode(torch.tensor(predictions).to(device))

# Drop the ignored positions (gold label -100) and map the remaining ids to tag strings.
true_predictions, true_labels = [], []
for decoded_seq, gold_seq in zip(predicted_labels, labels):
    kept_pairs = [(pred_id, gold_id) for pred_id, gold_id in zip(decoded_seq, gold_seq) if gold_id != -100]
    true_predictions.append([label_list[pred_id] for pred_id, _ in kept_pairs])
    true_labels.append([label_list[gold_id] for _, gold_id in kept_pairs])

# Per-entity precision / recall / F1 report.
results = classification_report(true_labels, true_predictions, zero_division=1)
print(results)

In [None]:
# Index of the validation example to inspect in the next cells.
check=100
datasets["valid"][check]

In [None]:
# Have a look at the predicted extracted data
check_pred = zip(datasets["valid"][check]['tokens'], true_predictions[check])
for tup in check_pred:
    if tup[1] != 'O':
        print(tup)

In [None]:
# Compare to the actual labels
check_true = zip(datasets["valid"][check]['tokens'], true_labels[check])
for tup in check_true:
    if tup[1] != 'O':
        print(tup)