In [1]:
from transformers import AutoModel, AutoConfig, AutoTokenizer
from transformers import BertTokenizerFast, BertForSequenceClassification, Trainer, TrainingArguments
import torch
import torch.nn as nn
import pytorch_lightning as pl
from config import CONFIG
import pandas as pd
from torch.utils.data import Dataset, DataLoader


In [2]:
# TODO: generalize to [String] -> [Class] system 
# TODO: generalize forward pass
# TODO: set parameter that enables cls token utilization or arbitrary hidden layer utilization
# TODO: (MAYBE) generalize models to extend BASE, otherwise add 
from config import CONFIG

from transformers import AutoModel, AutoTokenizer, AutoConfig
from torch import nn

class BaseClassificationModel(nn.Module):
    
    def __init__(self, dropout = 0.05, n_classes = 2, injection = False):
        super(BaseClassificationModel, self).__init__()
        
        
        # model body
        self.model = AutoModel.from_pretrained(CONFIG.pretrained_model_name)
        
        self.hidden_size = self.model.config.hidden_size #768
        
        # (standard) model classification head
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, n_classes),
            nn.Softmax(dim=1)
        )
        
        # initialize weights in linear layers
        self.init_weights(self.head)
        
        
    def init_weights(self, module):
        for layer in module:
            if isinstance(layer, nn.Linear):
                layer.weight.data.normal_(mean = 0.0, std = 0.02)
                if layer.bias is not None:
                    layer.bias.data.zero_()
                    
    
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        
        output = self.model(input_ids = input_ids, 
                       token_type_ids = token_type_ids,
                       attention_mask = attention_mask,
                       output_hidden_states = True)
        
        
        # last hidden state of all tokens
        last_hidden_state = output.last_hidden_state
        
        ################ Hidden States of each Transformer block
        ## Index=0 -> initial hidden state as token embedding + position embedding + segment embedding
        ## Index=13 -> last hidden state for each token in the sequence
        ## 
        ## "The ELMO authors suggest that lower levels encode syntax, while higher levels encode semantics."
        # hidden_states = output.hidden_states
        ################
        
        # hidden state of the first token e.g. classification token [CLS] or <s>
        # not a good representation of the whole sequence for decoder-based models such as GPT2
        cls_hidden_state = last_hidden_state[:, 0, :]
        
        # TODO: average hidden state of each token for each layer as better representation
        # TODO: consider earlier hidden states for syntax focused classification 
        return self.head(cls_hidden_state)

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
# Program
torch.manual_seed(0)

class RelationDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __getitem__(self, index):
        premise = self.data["premise"].iloc[index]
        claim = self.data["claim"].iloc[index]

        encoding = self.tokenizer.encode_plus(
            premise,
            claim,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=True,
            return_tensors='pt'
        )


        if 'label' in self.data.columns:
            
            label = torch.tensor(0 if self.data["label"].iloc[index] == "Attack" else 1, dtype=torch.int64)
            
            return {
            'input_ids': encoding['input_ids'].flatten(),
            'token_type_ids': encoding['token_type_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': label
         }
            
        else:
            return {
            'input_ids': encoding['input_ids'].flatten(),
            'token_type_ids': encoding['token_type_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten()
        }

    def __len__(self):
        return len(self.data)

In [4]:
from config import CONFIG

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from typing import List
import pandas

import pytorch_lightning as pl
import torchmetrics 
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup


class ClassificationModule(pl.LightningModule):
    
    def __init__(self, model):
        super().__init__()

        self.model = model
        
        self.loss = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task="binary")

    def forward(self, input_ids, attention_mask, token_type_ids):
        return self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    
    def step(self, batch, batch_idx, mode):
        input_ids = batch["input_ids"]
        token_type_ids = batch["token_type_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        #x, y = batch
        logits = self.forward(input_ids, attention_mask, token_type_ids )

        predictions = logits.argmax(dim = 1)
        
        loss = self.loss(logits, labels)
        accuracy = self.accuracy(predictions, labels)

        self.log(f'{mode}_loss', loss, on_epoch=True, prog_bar=True)
        self.log(f'{mode}_accuracy', accuracy, on_epoch=True, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'train')
    
    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'test')

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        
        input_ids = batch["input_ids"]
        token_type_ids = batch["token_type_ids"]
        attention_mask = batch["attention_mask"]
        logits = self(input_ids, attention_mask, token_type_ids)
        predictions = logits.argmax(dim=-1)
        
        return predictions

        
    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=CONFIG.learning_rate)
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=CONFIG.warmup_steps,
            num_training_steps=len(self.train_dataloader().dataset) // CONFIG.batch_size * CONFIG.epochs,
        )
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

    def create_data_loader(self, mode: str, shuffle=False):       
        df = pd.read_pickle("../data/microtext_references.pickle")
        split = df[df['mode'] == mode]
        
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
            
        return DataLoader(
            RelationDataset(split, tokenizer),
            batch_size = CONFIG.batch_size if mode == "train" else CONFIG.batch_size // 4,
            shuffle=shuffle, num_workers = CONFIG.num_workers
        )
    
    def train_dataloader(self):
        return self.create_data_loader(mode = "train", shuffle=True)

    def val_dataloader(self):
        return self.create_data_loader(mode = "validate")

    def test_dataloader(self):
        return self.create_data_loader(mode = "test")
        


In [5]:
class CombinedModel(nn.Module):
    def __init__(self, base_model_checkpoint, ernie_model_path, dropout=0.05, n_classes=2):
        super(CombinedModel, self).__init__()

        # Load the PyTorch trained base model
        # Initialize base model class
        self.base_model = BaseClassificationModel()
        
        # Initialize the ClassificationModule with the base model
        self.base_module = ClassificationModule(self.base_model)
        
        # Load the checkpoint
        checkpoint = torch.load(base_model_checkpoint)
        
        # Load the state dict into your base model
        self.base_model.load_state_dict(checkpoint['state_dict'], strict=False)
    

        # Load the pre-trained ERNIE model
        ernie_config = AutoConfig.from_pretrained(ernie_model_path)
        self.ernie_model = AutoModel.from_pretrained(ernie_model_path, config=ernie_config)

        # Set the hidden size based on one of the models
        self.hidden_size = ernie_config.hidden_size
        

        # (combined) model classification head
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(770, 768),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(768, n_classes),
        )
        #self.head = nn.Sequential(
        #    nn.Linear(770, 768),
        #   nn.ReLU(),
        #     nn.Linear(768, n_classes),
        #)

    def forward(self, base_model_inputs, ernie_model_inputs):
        # Pass input through base model
        base_output = self.base_model(**base_model_inputs)
        
        # Pass input through ERNIE model
        ernie_output = self.ernie_model(**ernie_model_inputs)

        # Get the cls hidden state of the ERNIE model
        ernie_hidden_state = ernie_output.last_hidden_state[:, 0, :]
        
        # Concatenate the base model's output logits and ERNIE model's cls hidden states
        combined_output = torch.cat((base_output, ernie_hidden_state), dim=1)

        # Pass through final classification head
        return self.head(combined_output)


In [22]:
base_model_checkpoint = "./trained_model/base_model.ckpt"
ernie_model_path = "../notebooks/models/"
combined_model = CombinedModel(base_model_checkpoint, ernie_model_path)

# Freeze the parameters
for param in combined_model.base_model.parameters():
    param.requires_grad = False

for param in combined_model.ernie_model.parameters():
    param.requires_grad = False


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at ../notebooks/models/ were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing B

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model_checkpoint = "./trained_model/base_model_kialo.ckpt"
ernie_model_path = "../notebooks/models/kialo/"
combined_model = CombinedModel(base_model_checkpoint, ernie_model_path)

# Freeze the parameters
for param in combined_model.base_model.parameters():
    param.requires_grad = False

for param in combined_model.ernie_model.parameters():
    param.requires_grad = False

combined_model = combined_model.to(device)

# Tokenizers
base_model_tokenizer = AutoTokenizer.from_pretrained(CONFIG.pretrained_model_name)
ernie_model_tokenizer = BertTokenizerFast.from_pretrained('nghuyong/ernie-2.0-en')

# Example test data
# Load the test dataset
mapping = {'Attack': 0, 'Support': 1}
df = pd.read_pickle("../data/kialo_references.pickle")
#split = df[df['mode'] == 'test']
#true_labels = split['label'].map(mapping)
split = df[(df['mode'] == 'test') & (df['label'] != 'Rephrase')] #Kialo data set
true_labels = split['label'].map(mapping) #kialo data set

# Prepare the datasets
ernie_dataset = RelationDataset(split[['premise', 'claim']], ernie_model_tokenizer)
base_dataset = RelationDataset(split[['premise', 'claim']], base_model_tokenizer)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at ../notebooks/models/kialo/ were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initiali

In [7]:
# Prepare the dataloaders
batch_size = 16  # Choose an appropriate batch size for your environment
base_dataloader = DataLoader(base_dataset, batch_size=batch_size)
ernie_dataloader = DataLoader(ernie_dataset, batch_size=batch_size)

combined_model.eval()  # Set the model to evaluation mode
predictions = []

# Iterate over batches from both dataloaders
for (base_batch, ernie_batch) in zip(base_dataloader, ernie_dataloader):
    with torch.no_grad():
        # Move batch to device
        base_batch = {k: v.to(device) for k, v in base_batch.items()}
        ernie_batch = {k: v.to(device) for k, v in ernie_batch.items()}

        # Get model outputs
        outputs = combined_model(base_batch, ernie_batch)
    
    # Get the predictions from the outputs
    predictions.extend(torch.argmax(outputs, dim=1).tolist())

#print(predictions)


In [8]:
from sklearn.metrics import classification_report

# Print classification report
report = classification_report(true_labels, predictions)
print(report)

              precision    recall  f1-score   support

           0       0.00      0.00      0.00     13237
           1       0.50      1.00      0.67     13480

    accuracy                           0.50     26717
   macro avg       0.25      0.50      0.34     26717
weighted avg       0.25      0.50      0.34     26717



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
