In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 5.3 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 44.4 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 37.9 MB/s 
Collecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 30.2 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 4.9 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found ex

In [1]:
import pandas as pd
import numpy as np
import torch.nn as nn
import torch

import transformers
from transformers import AdamW, get_linear_schedule_with_warmup
import json

In [2]:
class ChemProtDataset:
    def __init__(self, tokenizer, sentence, label, max_len, ss, os):
        self.sentence = sentence
        self.ss = ss
        self.os = os
        self.label = label
        
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.sentence)
        
    def __getitem__(self, item):
        sentence = str(self.sentence[item])
        inputs = self.tokenizer.encode_plus(
            sentence,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True
        )
        
        ids = inputs['input_ids']
        token_type_ids = inputs['token_type_ids']
        mask = inputs['attention_mask']

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'label': torch.tensor(self.label[item], dtype=torch.long),
            'ss': self.ss[item],
            'os': self.os[item]
        } 
    
class REModel(nn.Module):
    def __init__(self):
        super(REModel, self).__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.hidden_size = 768
        self.classifier = nn.Sequential(
            nn.Linear(2 * self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, 5),
            nn.Softmax(dim=1)
        )
            
    def forward(self, ids, mask, token_type_ids, ss, os):
        outputs = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        pooled_output = outputs[0]
        idx = torch.arange(ids.size(0)).to(ids.device)
        
        ss_emb = pooled_output[idx, ss]
        os_emb = pooled_output[idx, os]

        h = torch.cat((ss_emb, os_emb), dim=-1)
        return self.classifier(h)
    
    
def loss_fn(outputs, targets):
    return nn.CrossEntropyLoss()(outputs, targets)


def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    model.train()
    for bi, d in enumerate(data_loader):
        ids = d['ids']
        mask = d['mask']
        token_type_ids = d['token_type_ids']
        label = d['label']
        ss = d['ss']
        os = d['os']

        
        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        token_type_ids = token_type_ids.to(device, dtype=torch.long)
        label = label.to(device, dtype=torch.long)
        ss = ss.to(device, dtype=torch.long)
        os = os.to(device, dtype=torch.long)
        
        optimizer.zero_grad()
        outputs = model(ids, mask, token_type_ids, ss, os)

        loss = loss_fn(outputs, label)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        if bi % 50 == 0:
            print(f'bi={bi}, loss={loss}')


def eval_loop_fn(data_loader, model, device):
    model.eval()
    fin_labels = []
    fin_outputs = []
    for bi, d in enumerate(data_loader):
        with torch.no_grad():
            ids = d['ids'].to(device, dtype=torch.long)
            mask = d['mask'].to(device, dtype=torch.long)
            token_type_ids = d['token_type_ids'].to(device, dtype=torch.long)
            label = d['label'].to(device, dtype=torch.long)
            ss = d['ss'].to(device, dtype=torch.long)
            os = d['os'].to(device, dtype=torch.long)
          
            outputs = model(ids, mask, token_type_ids, ss, os)
          
            fin_labels.append(label.cpu().detach().numpy())
            fin_outputs.append(outputs.cpu().detach().numpy())

    return np.vstack(fin_outputs), np.hstack(fin_labels)


def read_data(path):
    with open(path) as f:
        result = []
        for line in f:
            res = json.loads(line)
            for i, word in enumerate(res['text'].split(' ')):
                if '<<' in word:
                    ss = i
                elif '>>' in word:
                    se = i
                elif '[[' in word:
                    os = i
                elif ']]' in word:
                    oe = i
            res['ss'], res['se'], res['os'], res['oe'] = ss+1, se-1, os+1, oe-1
            res['label'] = LABEL_DICT[res['label']]
            result.append(res)
    return pd.DataFrame(result)

In [3]:
LABEL_DICT = {'UPREGULATOR': 0, 'ACTIVATOR': 0, 'INDIRECT-UPREGULATOR': 0,
              'DOWNREGULATOR': 1, 'INHIBITOR': 1, 'INDIRECT-DOWNREGULATOR': 1,
              'AGONIST': 2,'AGONIST-ACTIVATOR': 2,'AGONIST-INHIBITOR': 2,
              'ANTAGONIST': 3, 'SUBSTRATE': 4, 'PRODUCT-OF': 4, 'SUBSTRATE_PRODUCT-OF': 4}
MAX_LEN = 512
TRAIN_BATCH_SIZE = 8
EPOCHS = 4
SEED = 42
LEARNING_RATE = 2e-5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
model = REModel().to(device)

df_train = read_data('./RE_data/chemprot/train.txt')
df_test = read_data('./RE_data/chemprot/test.txt')

train_dataset = ChemProtDataset(
    sentence=df_train.text.values,
    label=df_train.label.values,
    tokenizer=tokenizer,
    max_len=MAX_LEN,
    ss = df_train.ss.values, 
    os = df_train.os.values
)
train_data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True
)

test_dataset = ChemProtDataset(
    sentence=df_test.text.values,
    label=df_test.label.values,
    tokenizer=tokenizer,
    max_len=MAX_LEN,
    ss = df_test.ss.values, 
    os = df_test.os.values
)
test_data_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True
)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
num_training_steps = int(len(df_train) / TRAIN_BATCH_SIZE * EPOCHS)
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)


for epoch in range(EPOCHS):
    train_loop_fn(train_data_loader, model, optimizer, device, scheduler)
outputs, labels = eval_loop_fn(test_data_loader, model, device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


bi=0, loss=1.601925253868103
bi=50, loss=1.528367280960083
bi=100, loss=1.0347719192504883
bi=150, loss=1.6531282663345337
bi=200, loss=1.1566364765167236
bi=250, loss=1.1560050249099731
bi=300, loss=1.0312153100967407
bi=350, loss=1.4049549102783203
bi=400, loss=1.2802988290786743
bi=450, loss=1.5296716690063477
bi=500, loss=1.2802188396453857
bi=0, loss=1.1553738117218018
bi=50, loss=1.2800735235214233
bi=100, loss=1.5297130346298218
bi=150, loss=1.0303175449371338
bi=200, loss=1.1551741361618042
bi=250, loss=1.5297671556472778
bi=300, loss=1.4048430919647217
bi=350, loss=1.155063509941101
bi=400, loss=1.1550943851470947
bi=450, loss=1.5297911167144775
bi=500, loss=1.2799491882324219
bi=0, loss=1.5297601222991943
bi=50, loss=1.4048662185668945
bi=100, loss=1.4048595428466797
bi=150, loss=1.2799512147903442
bi=200, loss=1.279968023300171
bi=250, loss=1.2799317836761475
bi=300, loss=1.6547120809555054
bi=350, loss=1.2799172401428223
bi=400, loss=1.1549978256225586
bi=450, loss=1.529785

In [8]:
from sklearn.metrics import f1_score, classification_report
f1_score(np.argmax(outputs, axis=1), labels, average='micro')

0.4803921568627451

In [21]:
print(classification_report(labels, np.argmax(outputs, axis=1), zero_division=0))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       667
           1       0.48      1.00      0.65      1666
           2       0.00      0.00      0.00       198
           3       0.00      0.00      0.00       293
           4       0.00      0.00      0.00       644

    accuracy                           0.48      3468
   macro avg       0.10      0.20      0.13      3468
weighted avg       0.23      0.48      0.31      3468



In [11]:
outputs.shape

(3468, 5)

In [19]:
pd.Series(np.argmax(outputs, axis=1)).value_counts()

1    3468
dtype: int64