In [1]:
import pandas as pd
import numpy as np
import json

In [2]:
train_df=pd.read_csv('./output/train_statement_section.csv')
train_df

Unnamed: 0,statement,section,label
0,All the primary trial participants do not rece...,"[""INTERVENTION 1:"", ""Diagnostic (FLT PET)"", ""P...",Contradiction
1,"Patients with Platelet count over 100,000/mm¬¨...","[""DISEASE CHARACTERISTICS:"", ""Histologically o...",Contradiction
2,Heart-related adverse events were recorded in ...,"[""Adverse Events 1:"", ""Total: 5/32 (15.63%)"", ...",Entailment
3,Adult Patients with histologic confirmation of...,"[""Inclusion Criteria:"", ""Patients with histolo...",Contradiction
4,Laser Therapy is in each cohort of the primary...,"[""INTERVENTION 1:"", ""Laser Therapy Alone"", ""th...",Contradiction
...,...,...,...
1695,"Adequate blood, kidney, and hepatic function a...","[""Inclusion Criteria:"", ""Postmenopausal women,...",Entailment
1696,The Ridaforolimus + Dalotuzumab + Exemestane g...,"[""Outcome Measurement:"", ""1. Progression-free ...",Contradiction
1697,The only difference between the interventions ...,"[""INTERVENTION 1:"", ""Prone"", ""Prone position"",...",Entailment
1698,Patients must have a white blood cell count ab...,"[""DISEASE CHARACTERISTICS:"", ""Histologically c...",Entailment


In [3]:
hypothesis_lst=train_df['statement'].values.tolist()
len(hypothesis_lst)

1700

In [4]:
evidence_lst=train_df['section'].apply(lambda l:' '.join(json.loads(l))).values.tolist()
len(evidence_lst)

1700

In [5]:
label2id={"Contradiction":0,"Entailment":1}
label_lst=train_df['label'].apply(lambda x:label2id[x]).values.tolist()
len(label_lst)

1700

In [6]:
class InputSequence:
    
    def __init__(self,tok,l_text,l_text2,l_label,batch_size=64,gpu=True):
        
        self.data_len=len(l_text)
        self.data_idx=[i for i in range(self.data_len)]
        # self.texts=tok(l_text,l_text2,padding=True, truncation=True, max_length=512, return_tensors='pt')
        self.texts=tok(l_text,padding=True, truncation=True, max_length=512, return_tensors='pt')
        self.l_label=np.array(l_label)
        print('tokenize done')
        
        self.batch_size=batch_size
        self.gpu=gpu
        
    def on_epoch_end(self):
        random.shuffle(self.data_idx)
        
    def __getitem__(self,i):
        start=i*self.batch_size
        batch_idx=self.data_idx[start:min(start+self.batch_size,self.data_len)]
        
        return_texts=dict([(k,self.texts[k][batch_idx]) for k in self.texts])
        return_labels=torch.from_numpy(
            self.l_label[batch_idx].astype(np.int64)
        )
        
        if self.gpu:
            return_texts=dict([(k,return_texts[k].cuda()) for k in return_texts])
            return_labels=return_labels.cuda()
        
        return return_texts,return_labels
    
    def __len__(self):
        return math.ceil(1.0*self.data_len/self.batch_size)
    

In [7]:
import random
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [8]:
text_tok=AutoTokenizer.from_pretrained('bert-base-uncased')
text_clf=AutoModelForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [9]:
training_data=InputSequence(text_tok,hypothesis_lst,evidence_lst,label_lst,gpu=True)
len(training_data)

tokenize done


27

In [10]:
class Model(nn.Module):
    def __init__(self,clf):
        super(Model, self).__init__()
        self.clf=clf
        self.loss=nn.CrossEntropyLoss()
    
    def forward(self, texts, labels, gpu=True):
        
        loss=self.loss(self.clf(**texts).logits, labels)
        
        return loss

In [11]:
model=Model(text_clf)

In [12]:
bat_s=32
l_rate=1e-5

training_data.batch_size=bat_s

model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=l_rate)
total_epoch_num=10
for epoch in range(total_epoch_num):
    training_data.on_epoch_end()
    loss_sum=0.0
    loss_count=0
    for batch in range(len(training_data)):
        optimizer.zero_grad()
        batch_texts,batch_labels=training_data[batch]
        loss_count+=len(batch_texts['input_ids'])
        loss = model(
            batch_texts,batch_labels
        )
        print('epoch:',epoch,'batch:',batch,'loss:',loss.item(),end='\n' if batch==0 or batch+1==len(training_data) or (batch+1)%1000==0 else '\r')
        loss_sum += 1.0*loss.item()*len(batch_texts['input_ids'])
        loss.backward()
        optimizer.step()
    model.clf.save_pretrained('./output/clf_hypothesis_models/bert-base-uncased_epoch_{}.pt'.format(format(epoch,'05d')))
_=model.cpu()

epoch: 0 batch: 0 loss: 0.749165415763855
epoch: 0 batch: 53 loss: 0.6988679766654968
epoch: 1 batch: 0 loss: 0.6916448473930359
epoch: 1 batch: 53 loss: 0.6514798998832703
epoch: 2 batch: 0 loss: 0.6576327085494995
epoch: 2 batch: 53 loss: 0.7498207092285156
epoch: 3 batch: 0 loss: 0.6945198178291321
epoch: 3 batch: 53 loss: 0.5770949721336365
epoch: 4 batch: 0 loss: 0.5928168892860413
epoch: 4 batch: 53 loss: 0.5595516562461853
epoch: 5 batch: 0 loss: 0.46898719668388367
epoch: 5 batch: 53 loss: 0.56804347038269046
epoch: 6 batch: 0 loss: 0.35951006412506104
epoch: 6 batch: 53 loss: 0.29742273688316345
epoch: 7 batch: 0 loss: 0.3593253195285797
epoch: 7 batch: 53 loss: 0.15454281866550446
epoch: 8 batch: 0 loss: 0.17078536748886108
epoch: 8 batch: 53 loss: 0.69046181440353397
epoch: 9 batch: 0 loss: 0.1680673211812973
epoch: 9 batch: 53 loss: 0.27725851535797127
