In [1]:
! pip install transformers -q
! pip install tokenizers -q

You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
import re
import json
import ast
import os
import pandas as pd
from pathlib import Path
import matplotlib.cm as cm
import numpy as np
import pandas as pd
from typing import *
from tqdm.notebook import tqdm
from sklearn.utils.extmath import softmax
from sklearn import model_selection
from sklearn.metrics import classification_report, f1_score

In [3]:
import torch
import torch.optim as optim
import torch.nn as nn
import transformers
from transformers import AdamW
import tokenizers



In [4]:
def seed_all(seed = 42):
  """
  Seed every random number generator this notebook touches.

  Covers Python's `random`, NumPy's global RNG and PyTorch (CPU, plus all CUDA
  devices when CUDA is available), and switches cuDNN to deterministic mode.

  Args:
    seed: integer seed shared by all generators (default 42).
  """
  # Imported locally so the helper does not depend on earlier cells.
  import random
  import numpy as np
  import torch

  # NumPy global RNG
  np.random.seed(seed)

  # PyTorch RNGs (CPU first, then every visible GPU)
  torch.manual_seed(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
  torch.backends.cudnn.deterministic = True

  # Python's built-in RNG
  random.seed(seed)

In [5]:
class config:
  """
  Static settings for this notebook: input paths, tokenizer, batch sizes and the
  acronym dictionary. Everything below is evaluated once, when the class body
  runs, so the tokenizer vocab and the dictionary file are read at definition time.
  """
  SEED = 42
  KFOLD = 5
  TRAIN_FILE = '../input/sdu-shared/train.csv'
  VAL_FILE = '../input/sdu-shared/dev.csv'
  TEST_FILE = '../input/sdu-shared/test.csv'
  SAVE_DIR = '../input/gpu-scibert-uncased-wiki-article-ws-sdu-2'
  MAX_LEN = 192
  MODEL = '../input/scibert-uncased'
  TOKENIZER = tokenizers.BertWordPieceTokenizer(f"{MODEL}/vocab.txt", lowercase=True)
  EPOCHS = 5
  TRAIN_BATCH_SIZE = 32
  VALID_BATCH_SIZE = 16

  # Read through a context manager so the file handle is closed immediately
  # instead of being left for the garbage collector.
  with open('../input/sdu-shared/diction.json') as _dictionary_file:
    DICTIONARY = json.load(_dictionary_file)
  del _dictionary_file  # keep the temporary handle out of the class namespace

  # Dense integer id for every distinct string found in the DICTIONARY values
  # (iterated here as {key: iterable of strings}). setdefault keeps the first id
  # a string received; plain assignment of len(A2ID) would re-assign a repeated
  # string the id that the next new string also gets, producing duplicate ids.
  A2ID = {}
  for _expansions in DICTIONARY.values():
    for _expansion in _expansions:
      A2ID.setdefault(_expansion, len(A2ID))


In [6]:
class AverageMeter:
    """
    Computes and stores the average and current value
    Source : https://www.kaggle.com/abhishek/bert-base-uncased-using-pytorch/

    Attributes:
        val: the most recent value passed to update().
        sum: running sum of val * n over all updates.
        count: running sum of n over all updates.
        avg: sum / count, or 0 while count is 0.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record `val` as the latest value, weighted as `n` observations.

        Args:
            val: the value to record (e.g. a batch mean).
            n: how many observations `val` stands for (e.g. the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A total count of zero (update(x, n=0) on a fresh meter, e.g. for an
        # empty batch) used to raise ZeroDivisionError; report 0 instead.
        self.avg = self.sum / self.count if self.count else 0

In [7]:
def sample_text(text, acronym, max_len):
  """
  Crop `text` to a window of whitespace-separated words around `acronym`.

  The window spans max_len//2 words before the acronym and max_len//2 words
  starting at it, clipped to the sentence boundaries.

  Args:
    text: the sentence to crop.
    acronym: the word to centre the window on.
    max_len: nominal window size in words.

  Returns:
    The selected words joined by single spaces.

  The acronym is located as an exact word first. If that fails (e.g. it is
  attached to punctuation, as in "(SVM)"), the first word containing it is
  used. If no word contains it, the first `max_len` words are returned instead
  of raising ValueError.
  """
  tokens = text.split()
  half_window = max_len // 2

  try:
    idx = tokens.index(acronym)
  except ValueError:
    # Not a standalone word: fall back to the first word that contains it.
    idx = next((i for i, tok in enumerate(tokens) if acronym in tok), None)
    if idx is None:
      return ' '.join(tokens[:max_len])

  left_idx = max(0, idx - half_window)
  right_idx = min(len(tokens), idx + half_window)
  return ' '.join(tokens[left_idx:right_idx])

In [8]:
def process_data(text, acronym, expansion, tokenizer, max_len, max_text_tokens=120):
  """
  Encode one example as a fixed-length "answers / text" pair for span prediction.

  The first segment ("answers") is the acronym followed by every candidate in
  config.DICTIONARY[acronym], joined by spaces; the second segment is the
  sentence. The target is the token span covering `expansion` inside the first
  segment.

  Args:
    text: sentence containing the acronym (converted with str()).
    acronym: key into config.DICTIONARY (converted with str()).
    expansion: string to locate inside the answers segment (converted with str()).
    tokenizer: object whose encode(str) result exposes `.ids` and `.offsets`.
      The first and last entries of each encoding are dropped here (assumed to
      be the special tokens the tokenizer adds).
    max_len: exact length of every returned sequence (padded or truncated).
    max_text_tokens: sentences with more whitespace-separated words than this
      are cropped around the acronym with sample_text. Defaults to 120.

  Returns:
    dict with 'ids', 'mask', 'token_type', 'offset' (lists of length max_len),
    'start' / 'end' (inclusive token positions of the expansion), 'text'
    (answers + sentence, the string the offsets index into), 'expansion' and
    'acronym'.

  Raises:
    KeyError: `acronym` is not a key of config.DICTIONARY.
    ValueError: `expansion` cannot be located in the answers segment. This used
      to yield a silently wrong span, because str.find's -1 was used as an index.
  """
  text = str(text)
  expansion = str(expansion)
  acronym = str(acronym)

  # Long sentences are cropped to a window of words around the acronym.
  if len(text.split()) > max_text_tokens:
    text = sample_text(text, acronym, max_text_tokens)

  answers = acronym + ' ' + ' '.join(config.DICTIONARY[acronym])

  # Character span of the expansion inside the answers string.
  span_start = answers.find(expansion)
  if span_start < 0:
    raise ValueError(
      f"expansion {expansion!r} not found among the candidates of acronym {acronym!r}"
    )
  span_end = span_start + len(expansion)

  tok_answer = tokenizer.encode(answers)
  # Drop the first and last entries of the encoding; they are re-added by hand below.
  answer_ids = tok_answer.ids[1:-1]
  answer_offsets = tok_answer.offsets[1:-1]

  # Tokens whose character range overlaps the expansion's character range.
  target_idx = [
    i for i, (off1, off2) in enumerate(answer_offsets)
    if max(off1, span_start) < min(off2, span_end)
  ]
  if not target_idx:
    raise ValueError(
      f"expansion {expansion!r} does not cover any token of acronym {acronym!r}'s candidates"
    )

  start = target_idx[0]
  end = target_idx[-1]

  text_ids = tokenizer.encode(text).ids[1:-1]

  # NOTE(review): 101 / 102 are hard-coded as the leading / separator ids.
  # Confirm they are the [CLS] / [SEP] ids of the vocab at config.MODEL. They
  # are left unchanged here because the saved checkpoints presumably saw them.
  token_ids = [101] + answer_ids + [102] + text_ids + [102]
  # Only answer tokens carry real character offsets; everything else is (0, 0).
  offsets =   [(0,0)] + answer_offsets + [(0,0)]*(len(text_ids) + 2)
  mask = [1] * len(token_ids)
  token_type = [0]*(len(answer_ids) + 1) + [1]*(2+len(text_ids))

  text = answers + text
  # Shift by one for the id prepended before the answer tokens.
  start = start + 1
  end = end + 1

  padding = max_len - len(token_ids)

  if padding>=0:
    token_ids = token_ids + ([0] * padding)
    token_type = token_type + [1] * padding
    mask = mask + ([0] * padding)
    offsets = offsets + ([(0, 0)] * padding)
  else:
    # Plain truncation: the trailing 102 and part of the sentence are cut off,
    # and start / end are not re-checked against max_len.
    token_ids = token_ids[0:max_len]
    token_type = token_type[0:max_len]
    mask = mask[0:max_len]
    offsets = offsets[0:max_len]

  assert len(token_ids)==max_len
  assert len(mask)==max_len
  assert len(offsets)==max_len
  assert len(token_type)==max_len

  return {
          'ids': token_ids,
          'mask': mask,
          'token_type': token_type,
          'offset': offsets,
          'start': start,
          'end': end,  
          'text': text,
          'expansion': expansion,
          'acronym': acronym,
        }


In [9]:
class Dataset:
    """
    Map-style dataset over parallel sequences of texts, acronyms and expansions.

    Each item is encoded with process_data using config.TOKENIZER and
    config.MAX_LEN. Numeric fields come back as torch.long tensors; the string
    fields are passed through unchanged.
    """

    # Fields of the process_data result that are converted to tensors, in output order.
    _TENSOR_FIELDS = ('ids', 'mask', 'token_type', 'offset', 'start', 'end')
    # Fields that stay plain Python strings, in output order.
    _STRING_FIELDS = ('text', 'expansion', 'acronym')

    def __init__(self, text, acronym, expansion):
        self.text = text
        self.acronym = acronym
        self.expansion = expansion
        self.tokenizer = config.TOKENIZER
        self.max_len = config.MAX_LEN

    def __len__(self):
        """Number of examples (length of the text sequence)."""
        return len(self.text)

    def __getitem__(self, item):
        """Encode example `item` and return it as a dict of tensors and strings."""
        encoded = process_data(
            self.text[item],
            self.acronym[item],
            self.expansion[item],
            self.tokenizer,
            self.max_len,
        )

        sample = {
            field: torch.tensor(encoded[field], dtype=torch.long)
            for field in self._TENSOR_FIELDS
        }
        for field in self._STRING_FIELDS:
            sample[field] = encoded[field]
        return sample

In [10]:
def get_loss(start, start_logits, end, end_logits):
  """
  Span loss: cross-entropy of the start logits against `start` plus
  cross-entropy of the end logits against `end`.

  Each term uses nn.CrossEntropyLoss with its default settings.
  """
  criterion = nn.CrossEntropyLoss()
  return criterion(start_logits, start) + criterion(end_logits, end)

In [11]:
class BertAD(nn.Module):
  """
  BERT encoder (loaded from config.MODEL) with a linear head that produces one
  span-start score and one span-end score per input token.
  """
  def __init__(self):
    super(BertAD, self).__init__()
    self.bert = transformers.BertModel.from_pretrained(config.MODEL, output_hidden_states=True)
    # Size the head from the encoder's own config instead of hard-coding 768,
    # so the class keeps working if config.MODEL has a different hidden size.
    self.layer = nn.Linear(self.bert.config.hidden_size, 2)

  def forward(self, ids, mask, token_type, start=None, end=None):
    """
    Args:
      ids, mask, token_type: passed to the encoder as input_ids,
        attention_mask and token_type_ids.
      start, end: optional target positions. When both are given the span loss
        is computed with get_loss.

    Returns:
      (loss, start_logits, end_logits). `loss` is None when `start` or `end` is
      omitted. The None defaults used to crash inside the loss, so label-free
      inference needed dummy targets.
    """
    output = self.bert(input_ids = ids,
                       attention_mask = mask,
                       token_type_ids = token_type)

    # output[0] is the encoder's first output (the last-layer hidden states per
    # the transformers BertModel documentation); the head maps it to 2 scores.
    logits = self.layer(output[0])
    start_logits, end_logits = logits.split(1, dim=-1)

    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    loss = None
    if start is not None and end is not None:
      loss = get_loss(start, start_logits, end, end_logits)

    return loss, start_logits, end_logits

In [12]:
def jaccard(str1, str2):
    """
    Jaccard similarity between the sets of lower-cased, whitespace-separated
    words of two strings: |A ∩ B| / |A ∪ B|.

    Returns 0.0 when neither string contains a word. That case used to raise
    ZeroDivisionError (e.g. when an empty predicted span met an empty candidate).
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    union_size = len(a) + len(b) - len(c)
    if union_size == 0:
        return 0.0
    return float(len(c)) / union_size

In [13]:
def evaluate_jaccard(text, selected_text, acronym, offsets, idx_start, idx_end):
  """
  Decode the predicted token span back to a string and snap it to the closest
  dictionary candidate.

  Args:
    text: string that `offsets` index into.
    selected_text: accepted for interface compatibility; not used.
    acronym: key into config.DICTIONARY selecting the candidate list.
    offsets: per-token (char_start, char_end) pairs.
    idx_start, idx_end: inclusive token positions of the predicted span.

  Returns:
    (best_score, best_candidate): the highest Jaccard score between the decoded
    span and any candidate, and that candidate. An empty span (idx_end <
    idx_start) scores 0 against every candidate, so the first one is returned.
  """
  pieces = []
  n_offsets = len(offsets)
  for pos in range(idx_start, idx_end + 1):
    tok_start, tok_end = offsets[pos][0], offsets[pos][1]
    pieces.append(text[tok_start: tok_end])
    # Re-insert a space when there is a character gap before the next token.
    has_next = (pos + 1) < n_offsets
    if has_next and tok_end < offsets[pos + 1][0]:
      pieces.append(" ")
  predicted = "".join(pieces).strip()

  candidates = config.DICTIONARY[acronym]
  scores = [jaccard(candidate.strip(), predicted) for candidate in candidates]
  best = np.argmax(scores)

  return scores[best], candidates[best]

In [14]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_fold_model(fold):
  """
  Build a BertAD model, load the weights stored in
  config.SAVE_DIR/model_{fold}.bin, move it to `device` and set eval mode.
  """
  model = BertAD()
  chkp = torch.load(os.path.join(config.SAVE_DIR, f'model_{fold}.bin'), map_location=device)
  # Use the freshly built model's own position_ids entry rather than whatever
  # the checkpoint holds. Presumably the checkpoints were saved without this
  # key (or with an incompatible one); confirm against the training notebook.
  chkp['bert.embeddings.position_ids'] = model.state_dict()['bert.embeddings.position_ids']
  model.load_state_dict(chkp)
  model.to(device)
  model.eval()
  return model


# One model per cross-validation fold, replacing five copy-pasted load stanzas.
models = [load_fold_model(fold) for fold in range(config.KFOLD)]

# Later cells refer to the folds by these individual names.
model0, model1, model2, model3, model4 = models

print('Models loaded')

Models loaded


In [15]:
test = pd.read_csv(config.TEST_FILE)
# process_data must be able to find `expansion` inside the answers segment.
# The acronym itself is a placeholder that always qualifies, because the
# answers string starts with the acronym.
test = test.assign(expansion=test['acronym_'])

test_dataset = Dataset(
    text=test['text'].values,
    acronym=test['acronym_'].values,
    expansion=test['expansion'].values,
)

test_data_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    num_workers=2,
    batch_size=config.VALID_BATCH_SIZE,
)

In [16]:
# Running mean of the Jaccard score between each decoded span and the dictionary
# candidate it was snapped to (a confidence proxy, not accuracy against labels).
jac = AverageMeter()

tk0 = tqdm(test_data_loader, total=len(test_data_loader))

pred_expansion_ = []

# Every fold model votes with equal weight.
fold_models = (model0, model1, model2, model3, model4)

for d in tk0:
  ids = d['ids'].to(device, dtype=torch.long)
  mask = d['mask'].to(device, dtype=torch.long)
  token_type = d['token_type'].to(device, dtype=torch.long)
  # start / end only feed the loss, which is discarded below; they are still
  # passed so the models are called exactly as before.
  start = d['start'].to(device, dtype=torch.long)
  end = d['end'].to(device, dtype=torch.long)

  text = d['text']
  expansion = d['expansion']
  offset = d['offset']
  acronym = d['acronym']

  # Per-model position probabilities, as numpy arrays.
  start_probs = []
  end_probs = []
  with torch.no_grad():
    for fold_model in fold_models:
      _, start_logits, end_logits = fold_model(ids, mask, token_type, start, end)
      start_probs.append(torch.softmax(start_logits, dim=1).cpu().numpy())
      end_probs.append(torch.softmax(end_logits, dim=1).cpu().numpy())

  # Average the probabilities across models (summed in model order).
  start_prob = sum(start_probs) / len(fold_models)
  end_prob = sum(end_probs) / len(fold_models)

  jac_ = []

  for px, s in enumerate(text):
    start_idx = np.argmax(start_prob[px,:])
    end_idx = np.argmax(end_prob[px,:])

    js, exp = evaluate_jaccard(s, expansion[px], acronym[px], offset[px], start_idx, end_idx)
    jac_.append(js)
    pred_expansion_.append(exp)

  jac.update(np.mean(jac_), len(jac_))

  tk0.set_postfix(jaccard=jac.avg)
  

HBox(children=(FloatProgress(value=0.0, max=389.0), HTML(value='')))




In [17]:
# Attach the ensemble's chosen expansion to each test row. This relies on
# pred_expansion_ having been filled in dataset order (the DataLoader is built
# without requesting shuffling).
test['pred_expansion'] = pred_expansion_

In [18]:
# Preview the first rows with the new pred_expansion column.
test.head()

Unnamed: 0,acronym_,id,text,expansion,pred_expansion
0,AD,TS-0,Experiment 2 : As an indirect way of measuring...,AD,alzheimer 's disease
1,SVM,TS-1,We consider a finite set of SVM regularization...,SVM,support vector machine
2,POS,TS-2,Skeleton n - grams with standard POS : Sorcery...,POS,part of speech
3,CT,TS-3,The dice loss for th CT image in equation 5eq ...,CT,computed tomography
4,DCM,TS-4,"Although we can represent every DCM as a BCF ,...",DCM,discrete choice models


In [19]:
# One {'id', 'prediction'} record per test row, written out as a JSON list.
predictions = [
  {'id': row['id'], 'prediction': row['pred_expansion']}
  for _, row in test.iterrows()
]

with open(os.path.join('.', 'pred.json'), 'w') as out_file:
  json.dump(predictions, out_file)

In [20]:
# Save the full test frame, including pred_expansion, without the index column.
test.to_csv('test_preds.csv', index=False)