In [16]:
# !pip install protobuf==3.20.*

In [79]:
import os
import json
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification

import pytorch_lightning as pl
from torchmetrics.classification import Accuracy

import argparse

from tqdm import tqdm
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

In [48]:
class Load_Preprocess():
    def __init__(self):
        self.Label2num = {'Entailment': 1, "Contradiction": 0}
        
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
        self.train, self.train_labels = self.prepare_dataset(self.load_Processed_Data(self.load_Processed_cTAKES('ctakes/Train_Statements_cTAKES_Processed.json'), 'train'))
        self.dev, self.dev_labels = self.prepare_dataset(self.load_Processed_Data(self.load_Processed_cTAKES('ctakes/Dev_Statements_cTAKES_Processed.json'), 'dev'))
        self.test, self.test_labels = self.prepare_dataset(self.load_Processed_Data(self.load_Processed_cTAKES('ctakes/Test_Statements_cTAKES_Processed.json'), 'test'))

        self.max_len = self.get_max_length(self.train)
        print(self.max_len)

    def load_Processed_cTAKES(self, split):
        ctakes_tokens_path = split
        with open(ctakes_tokens_path) as json_file:
            ctakes_tokens = json.load(json_file)

        preferred_text_dict = dict()
        for i in range(len(ctakes_tokens)):
            cm_dict = ctakes_tokens[i]['clinical_mention']
            temp_list = []
            for key in cm_dict.keys():
                try:
                    temp_list.append(cm_dict[key][0]['preferredText'])
                except:
                    continue
            preferred_text_dict[ctakes_tokens[i]['UUID'][0]] = temp_list
        return preferred_text_dict

    def load_Processed_Data(self, ctakes, split):
        preferred_text = []
        statement = []
        trail1 = []
        trail2 = []
        section = []
        label = []
        
        with open(f"training_data/{split}" + ".json") as file:
            data = json.load(file)
            uuid_list = list(data.keys())

        for id in uuid_list:
            statement.append(data[id]['Statement'])
            if split != 'test':
                label.append(self.Label2num[data[id]['Label']])
            section.append(data[id]['Section_id'])
        
            with open(f"training_data/CT json/{data[id]['Primary_id']}" + ".json") as file: 
                ct = json.load(file)
                trail1.append(self.join_list(ct[data[id]['Section_id']]))
                
            if data[id]['Type'] == "Comparison":  
                with open(f"training_data/CT json/{data[id]['Secondary_id']}" + ".json") as file:
                    ct = json.load(file)
                    trail2.append(self.join_list(ct[data[id]['Section_id']]))
            else:
                trail2.append("_")
                
            preferred_text.append(','.join(ctakes[id]))
        
        return {'preferred_text':preferred_text, 'statement':statement, 'trail1':trail1, 'trail2':trail2, 'section':section, 'label':label}
        
    def join_list(self, sentences):
        return ", ".join([sent.strip() for sent in sentences])
    
    def merge_inputs(self, idx, Data): 
        stat, sec = Data['statement'][idx], Data['section'][idx]
        trail1, trail2, pref_text = Data['trail1'][idx], Data['trail2'][idx], Data['preferred_text'][idx]
        
        if trail2 == '_':
            sent = f"{stat} [SEP] {sec} [SEP] {trail1} [SEP] {pref_text}"
        else:
            sent = f"{stat} [SEP] {sec} [SEP] {trail1} [SEP] {trail2} [SEP] {pref_text}"
        return sent.strip()

    def prepare_dataset(self, Data):
        Texts = []
        for idx in range(0, len(Data['statement'])):
            Texts.append(self.merge_inputs(idx, Data))
        return Texts, Data['label']
    
    def get_max_length(self, Texts):
        lens = [len(text.split(' ')) for text in Texts]
        print('index', lens.index(max(lens)))
        return max(lens)

CT_Data = Load_Preprocess()

index 273
1747


In [44]:
CT_Data.dev[0]

'there is a 13.2% difference between the results from the two the primary trial cohorts [SEP] Results [SEP] Outcome Measurement:, Event-free Survival, Event free survival, the primary endpoint of this study, is defined as the time from randomization to the time of documented locoregional or distant recurrence, new primary breast cancer, or death from any cause., Time frame: 5 years, Results 1:, Arm/Group Title: Exemestane, Arm/Group Description: Patients receive oral exemestane (25 mg) once daily for 5 years., exemestane: Given orally, Overall Number of Participants Analyzed: 3789, Measure Type: Number, Unit of Measure: percentage of participants  88        (87 to 89), Results 2:, Arm/Group Title: Anastrozole, Arm/Group Description: Patients receive oral anastrozole (1 mg) once daily for 5 years., anastrozole: Given orally, Overall Number of Participants Analyzed: 3787, Measure Type: Number, Unit of Measure: percentage of participants  89        (88 to 90) [SEP] Primary operation'

In [76]:
# CT_Data.train[273], CT_Data.train_labels[273]

In [96]:
class NLI_Dataset(Dataset):
    def __init__(self, Texts, labels, max_len, tokenizer):
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        self.Texts = Texts
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        sent = self.tokenizer(self.Texts[idx], return_tensors="pt", padding="max_length", max_length=self.max_len, truncation=True)
        label = torch.tensor(self.labels[idx])
        return sent, label

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

train_dataset = NLI_Dataset(CT_Data.train, CT_Data.train_labels, CT_Data.max_len, tokenizer)
Val_dataset = NLI_Dataset(CT_Data.dev, CT_Data.dev_labels, CT_Data.max_len, tokenizer)
test_dataset = NLI_Dataset(CT_Data.test, CT_Data.test_labels, CT_Data.max_len, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(Val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [97]:
train_dataset.__getitem__(0)

({'input_ids': tensor([[ 101, 2035, 1996,  ...,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0]])},
 tensor(0))

In [100]:
train_dataset.__getitem__(1)[0]['input_ids'].shape

torch.Size([1, 1747])

In [101]:
class SemEval_NLI(pl.LightningModule):
    def __init__(self) :
        super(SemEval_NLI, self).__init__()

        self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
                
        self.CE = nn.CrossEntropyLoss()
        self.Faithfulness = nn.L1Loss()

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
#         print(y_hat.shape)
        loss = self.CE(y_hat, y)
        self.log('Train_loss', loss, prog_bar=True, logger=True)
        acc = self.Faithfulness(y_hat, y)
        self.log('Train_acc', acc, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.CE(y_hat, y)
        self.log('Val_loss', loss, prog_bar=True, logger=True)
        acc = self.Faithfulness(y_hat, y)
        self.log('Val_acc', acc, prog_bar=True)
        
        return loss

    def test_step(self, test_batch, batch_idx):
        x, y = test_batch
        y_hat = self.forward(x)
        loss = self.CE(y_hat, y)
        acc = self.Faithfulness(y_hat, y)
        
        self.log('test_loss', loss)
        self.log('test_acc', acc, prog_bar=True)

        print('Test_acc', acc)
    
    def predict_step(self, test_batch, batch_idx, dataloader_idx=None):
        with torch.no_grad():
            x, y = test_batch  # Assuming your input data is the first element of the batch
            logits = self.forward(x)  # Get logits from your model

            # Apply softmax to get probability distribution over classes
            probabilities = F.softmax(logits, dim=1)

            # Get class predictions by finding the class with the highest probability
            _, predicted_classes = torch.max(probabilities, dim=1)

        return predicted_classes.tolist()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=5e-6)

In [102]:
torch.manual_seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

Model = SemEval_NLI().to(device)
print(Model)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

SemEval_NLI(
  (model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bi

In [103]:
trainer = pl.Trainer(max_epochs=5)
trainer.fit(Model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3080 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                          | Params
---------------------------------------------------------------
0 | model        | BertForSequenceClassification | 109 M 
1 | CE           | CrossEntropyLoss              | 0     
2 | Faithfulness | L1Loss                        | 0     
---------------------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     To

Sanity Checking: 0it [00:00, ?it/s]

AttributeError: 

In [None]:
predictions = trainer.predict(Model, test_loader)