In [257]:
!pip install fasttext
!pip install fastparquet



In [258]:
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader, random_split
from sklearn.metrics import accuracy_score

In [259]:
test_data = pd.read_parquet('test-00000-of-00001-47685aa42db61e77.parquet', engine='fastparquet')
train_data = pd.read_parquet('train-00000-of-00001-210cfe9263b99806.parquet', engine='fastparquet')
valid_data = pd.read_parquet('valid-00000-of-00001-cc552de6d1a6fa4b.parquet', engine='fastparquet')
test_data

Unnamed: 0,id,query,answer,choices,gold
0,MedNLI0,\nTASK: Please classify the relationship betwe...,entailment,"[entailment, contradiction, neutral]",0
1,MedNLI1,\nTASK: Please classify the relationship betwe...,contradiction,"[entailment, contradiction, neutral]",1
2,MedNLI2,\nTASK: Please classify the relationship betwe...,neutral,"[entailment, contradiction, neutral]",2
3,MedNLI3,\nTASK: Please classify the relationship betwe...,entailment,"[entailment, contradiction, neutral]",0
4,MedNLI4,\nTASK: Please classify the relationship betwe...,contradiction,"[entailment, contradiction, neutral]",1
...,...,...,...,...,...
1417,MedNLI1417,\nTASK: Please classify the relationship betwe...,contradiction,"[entailment, contradiction, neutral]",1
1418,MedNLI1418,\nTASK: Please classify the relationship betwe...,neutral,"[entailment, contradiction, neutral]",2
1419,MedNLI1419,\nTASK: Please classify the relationship betwe...,entailment,"[entailment, contradiction, neutral]",0
1420,MedNLI1420,\nTASK: Please classify the relationship betwe...,contradiction,"[entailment, contradiction, neutral]",1


In [260]:
def find_pre_and_hyp(query):
    start_pre = query.find("[PRE]") + len("[PRE]")
    end_pre = query.find("[HYP]")
    start_hyp = query.find("[HYP]") + len("[HYP]")
    end_hyp = query.find("OUTPUT:")
    premise = query[start_pre:end_pre].strip()
    hypothesis = query[start_hyp:end_hyp].strip()

    return premise,hypothesis

In [261]:
class CustomDataset(Dataset):
    def __init__(self, x_list, y_list):
        self.samples = []
        for x,y in zip(x_list,y_list):
            #x_tensor = torch.tensor(x,dtype = torch.float32)
            y_tensor = torch.tensor(y,dtype = torch.float32)
            self.samples.append((x[0],x[1],y_tensor))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [262]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("gsarti/biobert-nli")
bert_model = AutoModel.from_pretrained("gsarti/biobert-nli").to(device)
src_pad_idx = 0
src_vocab_size = bert_model.config.vocab_size

In [263]:
x_list = []
y_list = []
i =0
for query,answer in zip(train_data['query'],train_data['answer']):
    i = i + 1
    if answer == 'entailment':
       y = [1,0,0]
    elif answer == 'neutral':
       y = [0,1,0]
    elif answer == 'contradiction':
       y = [0,0,1]
    else:
        print('should not get here')

    premise,hypothesis = find_pre_and_hyp(query)

    x_list.append((premise,hypothesis))
    y_list.append(y)

In [264]:
def get_lists(data):
    x_list = []
    y_list = []
    i =0
    for query,answer in zip(data['query'],data['answer']):
        i = i + 1
        if answer == 'entailment':
           y = [1,0,0]
        elif answer == 'neutral':
           y = [0,1,0]
        elif answer == 'contradiction':
           y = [0,0,1]
        else:
           print('should not get here')

        premise,hypothesis = find_pre_and_hyp(query)
        x_list.append((premise,hypothesis))
        y_list.append(y)
    return x_list,y_list

In [265]:
train_x_list,train_y_list = get_lists(train_data)
test_x_list,test_y_list = get_lists(test_data)
val_x_list,val_y_list = get_lists(valid_data)

In [266]:
print(len(train_x_list))
print(len(train_y_list))
print(len(test_x_list))
print(len(test_y_list))
print(len(val_x_list))
print(len(val_y_list))

11232
11232
1422
1422
1395
1395


In [267]:
train_dataset = CustomDataset(train_x_list,train_y_list)
test_dataset = CustomDataset(test_x_list,test_y_list)
val_dataset = CustomDataset(val_x_list,val_y_list)

In [268]:
class embedding_layer(nn.Module):
  def __init__(self,bert_model,tokenizer):
    super(embedding_layer, self).__init__()
    self.bert_model = bert_model
    self.tokenizer = tokenizer

  def forward(self, x):
    with torch.no_grad():
         s = tokenizer(x,return_tensors="pt",padding=True).to(device)
         vec = bert_model(**s)['last_hidden_state'].to(device)
    return vec

In [269]:
class PositionalEncoding(nn.Module):
    def __init__(self, seq_len,embedding_size):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.embedding_size = embedding_size
        self.seq_len = seq_len

    def forward(self, x):
        pe = torch.zeros(x.size(0), x.size(1), self.embedding_size).to(device)
        div_term = torch.zeros(x.size(0), 1, self.embedding_size).to(device)
        ks = torch.arange(self.embedding_size).float().to(device)
        values = torch.exp(-torch.log(torch.tensor(1000.0)) * 2 * ks / self.embedding_size).to(device)
        values = values.view(1, 1, -1).to(device)
        div_term = div_term + values
        x = x.reshape([x.shape[0],x.shape[1],1]).to(device)
        pe[:, :, ::2] = torch.sin(x * div_term)[:, :, ::2].to(device)
        pe[:, :, 1::2] = torch.cos(x * div_term)[:, :, 1::2].to(device)
        return self.dropout(pe)

In [270]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_size, heads):
        super().__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dim = embedding_size // heads
        assert(self.heads * self.head_dim == self.embedding_size), "Invalid number of heads"
        self.fc_values = nn.Linear(self.head_dim, self.head_dim, bias=False).to(device)
        self.fc_keys = nn.Linear(self.head_dim, self.head_dim, bias=False).to(device)
        self.fc_queries = nn.Linear(self.head_dim, self.head_dim, bias=False).to(device)
        self.fc_out = nn.Linear(heads * self.head_dim, embedding_size).to(device)

    def forward(self, values, keys, queries, mask):
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        values = self.fc_values(values).to(device)
        keys = self.fc_keys(keys).to(device)
        queries = self.fc_queries(queries).to(device)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]).to(device)
        if mask is not None:
            energy = energy.masked_fill(mask.to(device) == 0, float("-1e20")).to(device)
        energy = torch.softmax(energy / (self.embedding_size ** 0.5), dim=3).to(device)
        attention = torch.einsum("nhql,nlhd->nqhd", [energy, values]).to(device)
        attention = attention.reshape(N, query_len, self.heads * self.head_dim).to(device)
        out = self.fc_out(attention)
        return out

In [271]:
class TransformerBlock(nn.Module):
    def __init__(self, embedding_size, heads, forward_expansion, p):
        super().__init__()
        self.attention = MultiHeadAttention(embedding_size, heads)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.feed_forward = nn.Sequential(nn.Linear(embedding_size, forward_expansion * embedding_size),
                                          nn.ReLU(),
                                          nn.Linear(forward_expansion * embedding_size, embedding_size))
        self.norm2 = nn.LayerNorm(embedding_size)
        self.dropout = nn.Dropout(p)
    def forward(self, values, keys, queries, mask):
        attention_out = self.attention(values, keys, queries, mask)
        x = self.norm1(attention_out + queries)
        x = self.dropout(x)
        ff_out = self.feed_forward(x)
        out = self.norm2(ff_out + x)
        out = self.dropout(out)
        return out

In [272]:
class Encoder(nn.Module):
    def __init__(self, src_vocab_size, embedding_size, num_layers, heads,
                 forward_expansion, max_length, p, device):
        super().__init__()
        self.device = device
        self.word_embedding = embedding_layer(bert_model,tokenizer)
        self.positional_embedding = PositionalEncoding(max_length, embedding_size)
        self.layers = nn.ModuleList([TransformerBlock(embedding_size, heads, forward_expansion, p) for _ in range(num_layers)])
        self.dropout = nn.Dropout(p)

    def forward(self, x, mask):
        mask = None
        pe = self.word_embedding(x)
        N = pe.size(0)
        seq_len = pe.size(1)
        positions = torch.arange(0, seq_len).expand(N, seq_len).to(self.device)
        out = self.dropout((pe + self.positional_embedding(positions)))
        for layer in self.layers:
            out = layer(out, out, out ,mask)
        return out

In [309]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, src_pad_idx, embedding_size=768,
                 num_layers=1, forward_expansion=8, heads=8, max_length=100, p=0):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.encoder = Encoder(src_vocab_size, embedding_size, num_layers, heads,
                               forward_expansion, max_length, p, device)

    def get_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2).to(device)
        return src_mask

    def forward(self, src):
        src_mask = None
        enc_out = self.encoder(src, src_mask).to(device)
        return enc_out

In [311]:
class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.dropout = nn.Dropout(0.1)
        self.transformer = Transformer(src_vocab_size, src_pad_idx).to(device)
        self.fc = nn.Linear(768*2, 3).to(device)

    def forward(self, x_pre,x_hyp):
        out1 = self.transformer(x_pre)
        out2 = self.transformer(x_hyp)
        out1 = torch.mean(out1,1)
        out2 = torch.mean(out2,1)
        x = torch.cat((out1,out2),1)
        x = self.fc(x)
        x = nn.functional.softmax(x,dim=1)
        return x

In [312]:
bsize = 32
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bsize, shuffle=True)
val_loader = torch.utils.data.DataLoader(test_data, batch_size=bsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=bsize, shuffle=False)

In [315]:
model = NN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)
#optimizer = optim.Adam(model.parameters(), lr=0.001)

In [316]:
def get_model_acc(model):
    model.eval()
    predictions = []
    true_labels = []
    i = 0
    for x_pre, x_hyp, y in train_loader:
        i = i + 1
        outputs = model(x_pre,x_hyp)
        predicted_labels = torch.argmax(outputs, dim=1).cpu().numpy()
        predictions.extend(predicted_labels)
        l1 = torch.argmax(y, dim=1).cpu().numpy()
        true_labels.extend(l1)
    model.train()
    return accuracy_score(true_labels, predictions)

In [323]:
model.train()
num_epochs = 60
for epoch in range(num_epochs):
    for x_pre,x_hyp,y in train_loader:
        optimizer.zero_grad()
        outputs = model(x_pre,x_hyp)
        #print(outputs)
        y = y.to(device)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
    acc = get_model_acc(model)
    print(acc)
    if acc > 0.84 :
       torch.save(model,'model1.pth')
    print(loss.item())
    print('==================================')

print('Training complete.')

0.8400997150997151
0.7050190567970276
0.8465099715099715
0.758567214012146


KeyboardInterrupt: 

In [162]:
get_model_acc(model)

0.5461182336182336