In [1]:
!pip install fasttext
!pip install fastparquet
!gdown --id '1cAthveg1d3MjrKJtMKGzfX3eH8HJ-dQp'
!unzip MedNLI_dataset.zip
!pip install medialpy
!pip install contractions

Collecting fasttext
  Downloading fasttext-0.9.3.tar.gz (73 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/73.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.4/73.4 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pybind11>=2.2 (from fasttext)
  Using cached pybind11-2.13.1-py3-none-any.whl.metadata (9.5 kB)
Using cached pybind11-2.13.1-py3-none-any.whl (238 kB)
Building wheels for collected packages: fasttext
  Building wheel for fasttext (pyproject.toml) ... [?25l[?25hdone
  Created wheel for fasttext: filename=fasttext-0.9.3-cp310-cp310-linux_x86_64.whl size=4246764 sha256=3e5a50d3ba8793191b2631e8adb0da5f0ba96b89329df57671f24f65289d40e5
  Stored in directory: /root/.cache/pip/wheels/0d/a2/00/81db54d3e6a8199b829d58

In [14]:
# @title imports
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader, random_split
from sklearn.metrics import accuracy_score
import time
import string
import medialpy
import contractions
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import time
nltk.download('punkt')
nltk.download('stopwords')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [20]:
# @title Dataset functions
class CustomDataset(Dataset):
    def __init__(self, x_list, y_list):
        self.samples = []
        for x,y in zip(x_list,y_list):
            #x_tensor = torch.tensor(x,dtype = torch.float32)
            y_tensor = torch.tensor(y,dtype = torch.float32)
            self.samples.append((x[0],x[1],y_tensor))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def find_pre_and_hyp(query):
    start_pre = query.find("[PRE]") + len("[PRE]")
    end_pre = query.find("[HYP]")
    start_hyp = query.find("[HYP]") + len("[HYP]")
    end_hyp = query.find("OUTPUT:")
    premise = query[start_pre:end_pre].strip()
    hypothesis = query[start_hyp:end_hyp].strip()

    return premise,hypothesis

def get_lists(data):
    x_list = []
    y_list = []
    i =0
    for query,answer in zip(data['query'],data['answer']):
        i = i + 1
        if answer == 'entailment':
           y = [1,0,0]
        elif answer == 'neutral':
           y = [0,1,0]
        elif answer == 'contradiction':
           y = [0,0,1]
        else:
           print('should not get here')

        premise,hypothesis = find_pre_and_hyp(query)
        x_list.append((premise,hypothesis))
        y_list.append(y)
    return x_list,y_list

In [21]:
# @title Encoder functions
class embedding_layer(nn.Module):
  def __init__(self,bert_model,tokenizer):
    super(embedding_layer, self).__init__()
    self.bert_model = bert_model
    self.tokenizer = tokenizer

  def forward(self, x):
    with torch.no_grad():
         s = tokenizer(x,return_tensors="pt",padding=True).to(device)
         vec = bert_model(**s)['last_hidden_state'].to(device)
    return vec

class PositionalEncoding(nn.Module):
    def __init__(self, seq_len,embedding_size):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.embedding_size = embedding_size
        self.seq_len = seq_len

    def forward(self, x):
        pe = torch.zeros(x.size(0), x.size(1), self.embedding_size).to(device)
        div_term = torch.zeros(x.size(0), 1, self.embedding_size).to(device)
        ks = torch.arange(self.embedding_size).float().to(device)
        values = torch.exp(-torch.log(torch.tensor(1000.0)) * 2 * ks / self.embedding_size).to(device)
        values = values.view(1, 1, -1).to(device)
        div_term = div_term + values
        x = x.reshape([x.shape[0],x.shape[1],1]).to(device)
        pe[:, :, ::2] = torch.sin(x * div_term)[:, :, ::2].to(device)
        pe[:, :, 1::2] = torch.cos(x * div_term)[:, :, 1::2].to(device)
        return self.dropout(pe)

class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_size, heads):
        super().__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dim = embedding_size // heads
        assert(self.heads * self.head_dim == self.embedding_size), "Invalid number of heads"
        self.fc_values = nn.Linear(self.head_dim, self.head_dim, bias=False).to(device)
        self.fc_keys = nn.Linear(self.head_dim, self.head_dim, bias=False).to(device)
        self.fc_queries = nn.Linear(self.head_dim, self.head_dim, bias=False).to(device)
        self.fc_out = nn.Linear(heads * self.head_dim, embedding_size).to(device)

    def forward(self, values, keys, queries, mask):
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        values = self.fc_values(values).to(device)
        keys = self.fc_keys(keys).to(device)
        queries = self.fc_queries(queries).to(device)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]).to(device)
        if mask is not None:
            energy = energy.masked_fill(mask.to(device) == 0, float("-1e20")).to(device)
        energy = torch.softmax(energy / (self.embedding_size ** 0.5), dim=3).to(device)
        attention = torch.einsum("nhql,nlhd->nqhd", [energy, values]).to(device)
        attention = attention.reshape(N, query_len, self.heads * self.head_dim).to(device)
        out = self.fc_out(attention)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embedding_size, heads, forward_expansion, p):
        super().__init__()
        self.attention = MultiHeadAttention(embedding_size, heads)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.feed_forward = nn.Sequential(nn.Linear(embedding_size, forward_expansion * embedding_size),
                                          nn.ReLU(),
                                          nn.Linear(forward_expansion * embedding_size, embedding_size))
        self.norm2 = nn.LayerNorm(embedding_size)
        self.dropout = nn.Dropout(p)
    def forward(self, values, keys, queries, mask):
        attention_out = self.attention(values, keys, queries, mask)
        x = self.norm1(attention_out + queries)
        x = self.dropout(x)
        ff_out = self.feed_forward(x)
        out = self.norm2(ff_out + x)
        out = self.dropout(out)
        return out

class Encoder(nn.Module):
    def __init__(self, src_vocab_size, embedding_size, num_layers, heads,
                 forward_expansion, max_length, p, device):
        super().__init__()
        self.device = device
        self.word_embedding = embedding_layer(bert_model,tokenizer)
        self.positional_embedding = PositionalEncoding(max_length, embedding_size)
        self.layers = nn.ModuleList([TransformerBlock(embedding_size, heads, forward_expansion, p) for _ in range(num_layers)])
        self.dropout = nn.Dropout(p)

    def forward(self, x, mask):
        mask = None
        pe = self.word_embedding(x)
        N = pe.size(0)
        seq_len = pe.size(1)
        positions = torch.arange(0, seq_len).expand(N, seq_len).to(self.device)
        out = self.dropout((pe + self.positional_embedding(positions)))
        for layer in self.layers:
            out = layer(out, out, out ,mask)
        return out

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, src_pad_idx, embedding_size=768,
                 num_layers=1, forward_expansion=8, heads=8, max_length=100, p=0.1):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.encoder = Encoder(src_vocab_size, embedding_size, num_layers, heads,
                               forward_expansion, max_length, p, device)

    def get_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2).to(device)
        return src_mask

    def forward(self, src):
        src_mask = None
        enc_out = self.encoder(src, src_mask).to(device)
        return enc_out

class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.transformer1 = Transformer(src_vocab_size, src_pad_idx).to(device)
        self.transformer2 = Transformer(src_vocab_size, src_pad_idx).to(device)
        self.fc1 = nn.Linear(768, 256).to(device)
        self.fc2 = nn.Linear(256, 3).to(device)

    def deabbreviation(self,text):
        try:
           return medialpy.find(text).meaning[0]
        except:
           return text

    def preprocess_text(self,text):
        text = contractions.fix(text)
        text = text.translate(str.maketrans('', '', string.punctuation))
        tokens = word_tokenize(text)
        ntokens = []
        for token in tokens:
            ntokens.append(self.deabbreviation(token).lower())
        stop_words = set(stopwords.words('english'))
        stop_words.remove('no')
        stop_words.remove('not')
        filtered_tokens = [word for word in ntokens if word not in stop_words]
        processed_text = ' '.join(filtered_tokens)
        return processed_text

    def forward(self, x_pre,x_hyp):
        nx_pre = ()
        nx_hyp = ()
        for t in x_pre:
            nx_pre += (self.preprocess_text(t),)
        for t in x_hyp:
            nx_hyp += (self.preprocess_text(t),)
        x_pre = nx_pre
        x_hyp = nx_hyp
        enc_x_pre = self.transformer1(x_pre)
        enc_x_hyp = self.transformer2(x_hyp)
        enc_x_pre = torch.mean(enc_x_pre,1)
        enc_x_hyp = torch.mean(enc_x_hyp,1)
        x = enc_x_pre*enc_x_hyp
        x = F.relu(self.fc1(self.dropout1(x)))
        x = nn.functional.softmax(self.fc2(self.dropout2(x)),dim=1)
        return x

class NNN(nn.Module):
    def __init__(self):
        super(NNN, self).__init__()
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.transformer1 = Transformer(src_vocab_size, src_pad_idx).to(device)
        self.transformer2 = Transformer(src_vocab_size, src_pad_idx).to(device)
        self.fc1 = nn.Linear(768*2, 256).to(device)
        self.fc2 = nn.Linear(256, 2).to(device)

    def deabbreviation(self,text):
        try:
           return medialpy.find(text).meaning[0]
        except:
           return text

    def preprocess_text(self,text):
        text = contractions.fix(text)
        text = text.translate(str.maketrans('', '', string.punctuation))
        tokens = word_tokenize(text)
        ntokens = []
        for token in tokens:
            ntokens.append(self.deabbreviation(token).lower())
        stop_words = set(stopwords.words('english'))
        stop_words.remove('no')
        stop_words.remove('not')
        filtered_tokens = [word for word in ntokens if word not in stop_words]
        processed_text = ' '.join(filtered_tokens)
        return processed_text

    def forward(self, x_pre,x_hyp):
        nx_pre = ()
        nx_hyp = ()
        for t in x_pre:
            nx_pre += (self.preprocess_text(t),)
        for t in x_hyp:
            nx_hyp += (self.preprocess_text(t),)
        x_pre = nx_pre
        x_hyp = nx_hyp
        enc_x_pre = self.transformer1(x_pre)
        enc_x_hyp = self.transformer2(x_hyp)
        enc_x_pre = torch.mean(enc_x_pre,1)
        enc_x_hyp = torch.mean(enc_x_hyp,1)
        x = torch.cat((enc_x_pre,enc_x_hyp),1)
        x = F.relu(self.fc1(self.dropout1(x)))
        x = nn.functional.softmax(self.fc2(self.dropout2(x)),dim=1)
        return x

def get_model_acc(model,data_loader):
    model.eval()
    predictions = []
    true_labels = []
    for x_pre, x_hyp, y in data_loader:
        outputs = model(x_pre,x_hyp)
        predicted_labels = torch.argmax(outputs, dim=1).cpu().numpy()
        predictions.extend(predicted_labels)
        l1 = torch.argmax(y, dim=1).cpu().numpy()
        true_labels.extend(l1)
    model.train()
    return accuracy_score(true_labels, predictions)

def inference(pre,hyp):
    return model((pre,),(hyp,))

In [22]:
# @title prepare for training main model
test_data = pd.read_parquet('MedNLI_dataset/test-00000-of-00001-47685aa42db61e77.parquet', engine='fastparquet')
train_data = pd.read_parquet('MedNLI_dataset/train-00000-of-00001-210cfe9263b99806.parquet', engine='fastparquet')
valid_data = pd.read_parquet('MedNLI_dataset/valid-00000-of-00001-cc552de6d1a6fa4b.parquet', engine='fastparquet')
###################################################################################################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("gsarti/biobert-nli")
bert_model = AutoModel.from_pretrained("gsarti/biobert-nli").to(device)
src_pad_idx = 0
src_vocab_size = bert_model.config.vocab_size
####################################################################################################################
train_x_list,train_y_list = get_lists(train_data)
test_x_list,test_y_list = get_lists(test_data)
val_x_list,val_y_list = get_lists(valid_data)
####################################################################################################################
train_dataset = CustomDataset(train_x_list,train_y_list)
test_dataset = CustomDataset(test_x_list,test_y_list)
val_dataset = CustomDataset(val_x_list,val_y_list)
####################################################################################################################
bsize = 32
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bsize, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bsize, shuffle=False)
#####################################################################################################################
model = NN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)
#####################################################################################################################

In [23]:
model.train()
num_epochs = 10
for epoch in range(num_epochs):
    for x_pre,x_hyp,y in train_loader:
        optimizer.zero_grad()
        outputs = model(x_pre,x_hyp)
        y = y.to(device)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
    #print(get_model_acc(model,train_loader))
    print(get_model_acc(model,test_loader))
    print(loss.item())
    time.sleep(20)
    print('===========================================================================')
print('Training completed.')

0.6026722925457103
0.8691142797470093
0.7011251758087201
0.8080055713653564
0.69901547116737
0.8341015577316284
0.7144866385372715
0.727597177028656
0.7165963431786216
0.8032186627388
0.7088607594936709
1.0074548721313477
0.7257383966244726
0.7155419588088989
0.7278481012658228
0.7220139503479004
0.7116736990154712
0.6873470544815063
0.720112517580872
0.7042057514190674
Training completed.


In [7]:
# @title reli model functions
def get_trust_model_inf(trustmodel,data_loader):
    trustmodel.eval()
    model.eval()
    trust_x_list = []
    untrust_x_list = []
    trust_y_list = []
    untrust_y_list = []
    for x_pre,x_hyp,y in data_loader:
        outputs = model(x_pre,x_hyp)
        outputs_tw = trustmodel(x_pre,x_hyp)
        y = y.to(device)
        for i in range(len(outputs)):
                 p = get_inference_from_tns(outputs[i])
                 t = get_inference_from_tns(y[i])
                 tw = get_trustworthy_from_tns(outputs_tw[i])
                 if tw == 'trustworthy':
                    trust_x_list.append((x_pre[i],x_hyp[i]))
                    trust_y_list.append(y[i].tolist())
                 elif tw == 'untrustworthy':
                    untrust_x_list.append((x_pre[i],x_hyp[i]))
                    untrust_y_list.append(y[i].tolist())
                 else:
                    print('shouldnt get here')
    trust_dataset = CustomDataset(trust_x_list,trust_y_list)
    untrust_dataset = CustomDataset(untrust_x_list,untrust_y_list)
    trust_data_loader = torch.utils.data.DataLoader(trust_dataset, batch_size=32, shuffle=True)
    untrust_data_loader = torch.utils.data.DataLoader(untrust_dataset, batch_size=32, shuffle=True)
    print('trust num: ' + str(len(trust_y_list)))
    print('untrust num: ' + str(len(untrust_y_list)))
    print('trust acc: ' + str(get_model_acc(model,trust_data_loader)))
    print('untrust acc: ' + str(get_model_acc(model,untrust_data_loader)))
    return trust_dataset,untrust_dataset

def get_inference_from_tns(tns):
    infr = np.argmax(tns.cpu().detach().numpy())
    if infr == 0:
       return 'entailment'
    elif infr == 1:
       return 'neutral'
    elif infr == 2:
       return 'contradiction'
    else:
       return 'shouldnt get here'

def get_trustworthy_from_tns(tns):
    infr = np.argmax(tns.cpu().detach().numpy())
    if infr == 0:
       return 'trustworthy'
    elif infr == 1:
       return 'untrustworthy'
    else:
       return 'shouldnt get here'

In [18]:
# @title prepare for training reli model
correct_label = int(get_model_acc(model,train_loader)*11232)
uncorrect_label = 11232 - correct_label
c_visit = 0
uc_visit = 0
train_relix_list = []
train_reliy_list = []
for x_pre, x_hyp, y in train_loader:
    outputs = model(x_pre,x_hyp)
    predicted_labels = torch.argmax(outputs, dim=1).cpu().numpy()
    true_labels = torch.argmax(y, dim=1).cpu().numpy()
    for i in range(len(predicted_labels)):
          predicted_label = predicted_labels[i]
          true_label = true_labels[i]
          if predicted_label != true_label:
             if uc_visit <= min(correct_label,uncorrect_label):
                uc_visit +=1
                premise = x_pre[i]
                hypothesis = x_hyp[i]
                train_relix_list.append((premise,hypothesis))
                train_reliy_list.append([0,1])
          else:
             if c_visit <= min(correct_label,uncorrect_label):
                c_visit +=1
                premise = x_pre[i]
                hypothesis = x_hyp[i]
                train_relix_list.append((premise,hypothesis))
                train_reliy_list.append([1,0])
relitrain_dataset = CustomDataset(train_relix_list,train_reliy_list)
###########################################################################################
val_relix_list = []
val_reliy_list = []
for x_pre, x_hyp, y in val_loader:
    outputs = model(x_pre,x_hyp)
    predicted_labels = torch.argmax(outputs, dim=1).cpu().numpy()
    true_labels = torch.argmax(y, dim=1).cpu().numpy()
    for i in range(len(predicted_labels)):
          predicted_label = predicted_labels[i]
          true_label = true_labels[i]
          if predicted_label != true_label:
             premise = x_pre[i]
             hypothesis = x_hyp[i]
             val_relix_list.append((premise,hypothesis))
             val_reliy_list.append([0,1])
          else:
             premise = x_pre[i]
             hypothesis = x_hyp[i]
             val_relix_list.append((premise,hypothesis))
             val_reliy_list.append([1,0])
relival_dataset = CustomDataset(val_relix_list,val_reliy_list)
###########################################################################################
test_relix_list = []
test_reliy_list = []
for x_pre, x_hyp, y in test_loader:
    outputs = model(x_pre,x_hyp)
    predicted_labels = torch.argmax(outputs, dim=1).cpu().numpy()
    true_labels = torch.argmax(y, dim=1).cpu().numpy()
    for i in range(len(predicted_labels)):
          predicted_label = predicted_labels[i]
          true_label = true_labels[i]
          if predicted_label != true_label:
             premise = x_pre[i]
             hypothesis = x_hyp[i]
             test_relix_list.append((premise,hypothesis))
             test_reliy_list.append([0,1])
          else:
             premise = x_pre[i]
             hypothesis = x_hyp[i]
             test_relix_list.append((premise,hypothesis))
             test_reliy_list.append([1,0])
relitest_dataset = CustomDataset(test_relix_list,test_reliy_list)
###################################################################################################
relitrain_loader = torch.utils.data.DataLoader(relitrain_dataset, batch_size=bsize, shuffle=True)
relival_loader = torch.utils.data.DataLoader(relival_dataset, batch_size=bsize, shuffle=True)
relitest_loader = torch.utils.data.DataLoader(relitest_dataset, batch_size=bsize, shuffle=True)
###################################################################################################
relimodel = NNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(relimodel.parameters(), lr=0.0001)

In [21]:
relimodel.train()
num_epochs = 3
for epoch in range(num_epochs):
    for x_pre,x_hyp,y in relitrain_loader:
        optimizer.zero_grad()
        outputs = relimodel(x_pre,x_hyp)
        y = y.to(device)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
    get_trust_model_inf(relimodel,test_loader)
    print(loss.item())
    print('==================================')
print('Training completed.')

trust num: 489
untrust num: 933
trust acc: 0.8179959100204499
untrust acc: 0.6495176848874598
0.6973949670791626
trust num: 420
untrust num: 1002
trust acc: 0.8738095238095238
untrust acc: 0.6447105788423154
0.6374236941337585
trust num: 489
untrust num: 933
trust acc: 0.8425357873210634
untrust acc: 0.639871382636656
0.5671142339706421
Training completed.
