# Requirements

In [173]:
#!g1.1
%pip install transformers
%pip install pyyaml==5.4.1
%pip install gdown
%pip install wandb

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m
Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m
Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m
Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m


# Config

In [113]:
#!g1.1
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [114]:
#!g1.1
BERT_TYPE = 'bert-base-uncased'

In [115]:
#!g1.1
from transformers import logging
logging.set_verbosity_error()

In [116]:
#!g1.1
PAD = 0
UNK = 1
BOS = 2
EOS = 3

tgt_vocab_size = 54 + 4

In [117]:
#!g1.1
BATCH_SIZE = 16

# Data

In [118]:
#!g1.1
import gdown

url = 'https://drive.google.com/drive/folders/1qw05BnA1O-XDgJ50OgNGFSlTa9Kls00j?usp=sharing'
gdown.download_folder(url, quiet=True)

['/home/jupyter/work/resources/label_test',
 '/home/jupyter/work/resources/label_train',
 '/home/jupyter/work/resources/label_val',
 '/home/jupyter/work/resources/test.tsv',
 '/home/jupyter/work/resources/text_test',
 '/home/jupyter/work/resources/text_train',
 '/home/jupyter/work/resources/text_val',
 '/home/jupyter/work/resources/train.tsv',
 '/home/jupyter/work/resources/validation.tsv']

In [119]:
# !cp -r drive/MyDrive/AAPD .

In [121]:
!mkdir AAPD
!mv *.tsv AAPD
!mv text_* AAPD
!mv label_* AAPD


In [122]:
#!g1.1
def apply_to_dict_values(dict, f):
    for key, value in dict.items():
        dict[key] = f(value)

In [123]:
#!g1.1
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, BertConfig

class AAPDDataset(Dataset):
    """AAPD dataset."""

    def __init__(self, path):
        self.path = path
        self.data = pd.read_csv(self.path, sep='\t', header=None)
        self.tokenizer = BertTokenizer.from_pretrained(BERT_TYPE)

    def __len__(self):
        return self.data.shape[0]

    @staticmethod
    def target_to_tensor(target):
        return torch.tensor([float(label) for label in target])

    @staticmethod
    def target_to_tensor_with_specials(target):
        return torch.tensor([BOS] + [float(index) + 4 for index, label in enumerate(target) if label == '1'] + [EOS])

    def __getitem__(self, idx):
        data = self.tokenizer(self.data.iloc[idx, 1], return_tensors="pt", max_length=512, padding="max_length", truncation=True) # max_len=512 !DocBERT
        apply_to_dict_values(data, lambda x: x.flatten())
        return data, AAPDDataset.target_to_tensor_with_specials(self.data.iloc[idx, 0])

In [124]:
#!g1.1
train_dataset = AAPDDataset('./AAPD/train.tsv')
val_dataset = AAPDDataset('./AAPD/validation.tsv')
test_dataset = AAPDDataset('./AAPD/test.tsv')

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=231508.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=28.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=466062.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=570.0), HTML(value='')))







In [146]:
#!g1.1
def padding(data):
    src, tgt = zip(*data)

    keys = src[0].keys()
    src_agg = {}
    for key in keys:
        agg = [s[key] for s in src]
        src_agg[key] = torch.stack(agg)    

    tgt_len = [len(t) for t in tgt]
    tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long()
    for i, s in enumerate(tgt):
        tgt_pad[i, :tgt_len[i]] = s.detach().clone()[:tgt_len[i]]

    return src_agg, tgt_pad

In [147]:
#!g1.1
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=padding)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=padding)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=padding)

# Models

## Encoder

In [181]:
#!g1.1
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig


class BertEncdoer(nn.Module):
    def __init__(self, bert_type="bert-base-uncased"):
        super(BertEncdoer, self).__init__()
        self.bert_model = BertModel.from_pretrained(bert_type)
                
        dropout_prob = 0.2
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, inputs):
        '''
        Bert input -> hidden states for SGM attention, (hidden state, cell state) for decoder init 
        [batch_size, seq_len] -> [batch_size, seq_len, bert_hidden_size], ([1, batch_size, bert_hidden_size] x 2)
        1 = n_decoder_layers
        '''
        bert_output = self.bert_model(**inputs)
        pooler_output = self.dropout(bert_output.pooler_output)
        return bert_output.last_hidden_state, (torch.unsqueeze(pooler_output, 0), torch.unsqueeze(pooler_output, 0))


## Attention

In [128]:
#!g1.1
class SgmAttention(nn.Module):

    def __init__(self, encoder_hidden_size, decoder_hidden_size, att_hidden_size):
        super(SgmAttention, self).__init__()
        self.U = nn.Linear(encoder_hidden_size, att_hidden_size)
        self.W = nn.Linear(decoder_hidden_size, att_hidden_size)
        self.tanh = nn.Tanh()
        self.V = nn.Linear(att_hidden_size, 1) # seq_len x att_sz -> seq_len
        self.softmax = nn.Softmax(dim=-1)

    def init_context(self, context):
        '''
        Context from encoder. Size: [batch_size, seq_len, encoder_hidden_size]
        '''
        self.context = context


    def forward(self, s):
        state_term = self.W(s).unsqueeze(1) # batch_size x decoder_hidden_size -> batch_size x 1 x att_hidden_size
        context_term = self.U(self.context) # batch_size x seq_len x encoder_hidden_size -> batch_size x seq_len x att_hidden_size
        sum_activation = self.tanh(context_term + state_term.expand_as(context_term)) # batch_size x seq_len x att_hidden_size
        weights = self.V(sum_activation).squeeze(-1) # batch_size x seq_len
        softmax_weights = self.softmax(weights)
        c_t = torch.bmm(softmax_weights.unsqueeze(1), self.context).squeeze(1) # batch_size x seq_len
        # output = self.linear_out(torch.cat([h, c_t], 1))

        return c_t

## Decoder

In [129]:
#!g1.1
class StackedLSTM(nn.Module):
    def __init__(self, num_layers, input_size, hidden_size, dropout):
        super(StackedLSTM, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        for _ in range(num_layers):
            self.layers.append(nn.LSTMCell(input_size, hidden_size))
            input_size = hidden_size

    def forward(self, input, hidden):
        h_0, c_0 = hidden
        h_1, c_1 = [], []
        for i, layer in enumerate(self.layers):
            h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
            input = h_1_i
            if i + 1 != self.num_layers:
                input = self.dropout(input)
            h_1 += [h_1_i]
            c_1 += [c_1_i]

        h_1 = torch.stack(h_1)
        c_1 = torch.stack(c_1)

        return input, (h_1, c_1)

In [193]:
#!g1.1
class RnnDecoder(nn.Module):

    def __init__(self, tgt_vocab_size, hidden_size):
        super(RnnDecoder, self).__init__()

        self.hidden_size = hidden_size
        dropout_prob = 0.2
        num_layers=1

        self.rnn = StackedLSTM(input_size=2 * self.hidden_size, 
                               hidden_size=self.hidden_size,
                               num_layers=num_layers, 
                               dropout=dropout_prob)

        self.inner_hidden_size = 768
        self.W_d = nn.Linear(self.hidden_size, self.inner_hidden_size)
        self.V_d = nn.Linear(self.hidden_size, self.inner_hidden_size)
        self.W_o = nn.Linear(self.inner_hidden_size, tgt_vocab_size)
        self.activation = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, input, state, c_t, prev_predicted_labels=None): 
        output, state = self.rnn(input, state)
        return output, state

    def compute_score(self, hiddens, c_t, prev_predicted_labels=None, use_softmax=False):
        scores = self.W_o(self.activation(self.W_d(hiddens) + self.V_d(c_t)))
        I = torch.zeros_like(scores)
        if prev_predicted_labels:
            for predicted_labels in prev_predicted_labels:
                I[(list(range(I.size(0))), predicted_labels)] = -1 * float('inf')
        scores = scores + I
        if use_softmax:
            scores = self.softmax(scores)
        return scores

# BERT + SGM

In [194]:
#!g1.1
class BertSGM(nn.Module):
    def __init__(self):
        super(BertSGM, self).__init__()
        tgt_vocab_size = 58
        tgt_embedding_size = 768
        decoder_hidden_size = 768
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, tgt_embedding_size)
        self.encoder = BertEncdoer()
        self.attention = SgmAttention(encoder_hidden_size=768, decoder_hidden_size=decoder_hidden_size, att_hidden_size=768) 
        self.decoder = RnnDecoder(tgt_vocab_size=tgt_vocab_size, hidden_size=decoder_hidden_size)
        self.criterion = self.create_criterion(tgt_vocab_size)
        

    def forward(self, src, tgt):
        context, decoder_init_state = self.encoder(src)
        
        self.attention.init_context(context)
        y_hats = self.tgt_embedding(tgt[:, :-1])

        batch_size = y_hats.size(0)
        prev_predicted_labels = []
        saved_scores = []
        decoder_state = decoder_init_state

        for y_hat, t in zip(y_hats.split(1, dim=1), tgt[:, 1:].transpose(0, 1)):
            c_t = self.attention(decoder_state[0].squeeze(0))
            input = torch.cat([y_hat.squeeze(1), c_t], dim=-1)
            output, decoder_state = self.decoder(input, decoder_state, c_t)
            scores = self.decoder.compute_score(output, c_t, prev_predicted_labels)
            saved_scores.append(scores)
            prev_predicted_labels.append(t)
        
        scores = torch.stack(saved_scores).transpose(0, 1)
        return self.compute_loss(scores, tgt)
    
    def compute_loss(self, scores, tgt):
        loss = 0.
        for score, t in zip(scores, tgt[:, 1:]):
            loss += self.criterion(score, t)
        return loss / tgt.size(0)
    
    def create_criterion(self, tgt_vocab_size):
        weight = torch.ones(tgt_vocab_size)
        weight[PAD] = 0
        crit = nn.CrossEntropyLoss(weight, ignore_index=PAD)
        return crit
    
    def predict(self, src, max_steps=10):
        context, decoder_init_state = self.encoder(src)
        self.attention.init_context(context)
        batch_size = src['input_ids'].size(0)
        y_hat = self.tgt_embedding(torch.tensor([BOS for _ in range(batch_size)]).to(device))
        decoder_state = decoder_init_state
        
        predicted_labels = []
        eos_predicted = torch.tensor([False for _ in range(batch_size)]).to(device)

        for _ in range(max_steps):
            c_t = self.attention(decoder_state[0].squeeze(0))
            input = torch.cat([y_hat.squeeze(1), c_t], dim=-1)
            output, decoder_state = self.decoder(input, decoder_state, c_t)
            scores = self.decoder.compute_score(output, c_t, predicted_labels)
            prediction = torch.argmax(scores, dim=-1)
            y_hat = self.tgt_embedding(prediction.to(device))
            predicted_labels.append(prediction.tolist())
            eos_predicted = eos_predicted | (prediction == EOS)
            if torch.all(eos_predicted):
                break

        return torch.tensor(predicted_labels)
    


# Metrics

In [180]:
#!g1.1
def one_hot_labels(batch, n_classes=54, n_specials=4):
    batch_labels = []
    for tensor in batch:
        labels = [0 for _ in range(n_classes)]
        for elem in tensor:
            if elem == EOS:
                break
            if elem >= n_specials:
                labels[elem - n_specials] = 1
        batch_labels.append(labels)
    return batch_labels

In [133]:
#!g1.1
!wget https://gist.githubusercontent.com/ArseniyBolotin/7623835da1631b00fb150bcd5b0d909f/raw/wandb_writer.py -O wandb_writer.py

--2022-05-05 21:49:48--  https://gist.githubusercontent.com/ArseniyBolotin/7623835da1631b00fb150bcd5b0d909f/raw/wandb_writer.py
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2609 (2.5K) [text/plain]
Saving to: ‘wandb_writer.py’


2022-05-05 21:49:48 (52.3 MB/s) - ‘wandb_writer.py’ saved [2609/2609]



In [162]:
#!g1.1
from sklearn import metrics
from wandb_writer import WandbWriter

def get_metrics(y, y_pre):
        hamming_loss = metrics.hamming_loss(y, y_pre)
        macro_f1 = metrics.f1_score(y, y_pre, average="macro")
        macro_precision = metrics.precision_score(y, y_pre, average="macro")
        macro_recall = metrics.recall_score(y, y_pre, average="macro")
        micro_f1 = metrics.f1_score(y, y_pre, average="micro")
        micro_precision = metrics.precision_score(y, y_pre, average="micro")
        micro_recall = metrics.recall_score(y, y_pre, average="micro")
        
        return {
            "hamming_loss": hamming_loss,
            "macro_f1": macro_f1,
            "macro_precision": macro_precision,
            "macro_recall": macro_recall,
            "micro_f1": micro_f1,
            "micro_precision": micro_precision,
            "micro_recall": micro_recall
        }

# Init

In [195]:
#!g1.1
model = BertSGM().to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [196]:
#!g1.1
import torch.optim as optim

optimizer = optim.Adam(params=model.parameters(), lr=2e-5, betas=(0.9, 0.99))

In [197]:
#!g1.1
wb_writer = WandbWriter("BERT+SGM experiment")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mc3n34ka[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Train

In [142]:
#!g1.1
from tqdm import tqdm

In [198]:
#!g1.1
def train_epoch(model, optimizer,  dataloader, val_dataloader, val_freq, wb_writer=None):
    model.train()
    index = 0
    for src, tgt in tqdm(dataloader, leave=False):
        index += 1
        apply_to_dict_values(src, lambda x: x.to(device))
        tgt = tgt.to(device)
        optimizer.zero_grad()
        loss = model(src, tgt)
        loss.backward()
        optimizer.step()
        wb_writer.add_scalar("Batch train loss", loss.item())
        wb_writer.next_step()
        wb_writer.add_scalar("Step", wb_writer.step)
        if index % val_freq == 0:
            log = eval_model(model, val_dataloader, wb_writer, '_validation')
            model.train()

In [None]:
train_epoch(model, optimizer, train_dataloader, wb_writer)

In [None]:
torch.save(model, 'bert_sgm_colab.pt')

# Validation

In [None]:
gdown.download(id='1R2uPw2xKgfWtlw7ms7AqVSCz0XmOplok', output='bert_sgm_colab.pt', quiet=True)

In [None]:
# !cp drive/MyDrive/bert_sgm_colab.pt .

In [None]:
model = torch.load('bert_sgm_colab.pt')

In [188]:
#!g1.1
def eval_model(model, dataloader, wb_writer, suffix):
    model.eval()

    targets = []
    predictions = []
    with torch.no_grad():
        for src, tgt in tqdm(dataloader, leave=False):
            apply_to_dict_values(src, lambda x: x.to(device))
            tgt = tgt.to(device)
            prediction = model.predict(src)
            targets.extend(tgt.tolist())
            predictions.extend(prediction.t().tolist())
    
    results = get_metrics(one_hot_labels(targets), one_hot_labels(predictions))

    if wb_writer:
        for k, v in results.items():
            name = k
            if suffix:
                name += suffix
            wb_writer.add_scalar(name, v)
        wb_writer.next_step()
        wb_writer.add_scalar("Step", wb_writer.step)
    
    return results

In [None]:
eval_model(model, val_dataloader, wb_writer)

# Train loop

In [199]:
#!g1.1
EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
    train_epoch(model, optimizer, train_dataloader, val_dataloader, 100, wb_writer)
    log = eval_model(model, train_dataloader, wb_writer, '_train')
    print(log)
    log = eval_model(model, val_dataloader, wb_writer, '_validation')
    print(log)
    torch.save(model, 'bert_sgm_final_' + str(epoch) + '.pt')


{'hamming_loss': 0.023990493093390568, 'macro_f1': 0.4772448791314393, 'macro_precision': 0.6406247808556488, 'macro_recall': 0.43216103766620007, 'micro_f1': 0.7095957598291274, 'micro_precision': 0.770472238044864, 'micro_recall': 0.6576347836824152}
{'hamming_loss': 0.024777777777777777, 'macro_f1': 0.45318665981527634, 'macro_precision': 0.5572174457419228, 'macro_recall': 0.41975773017990775, 'micro_f1': 0.6990553306342779, 'micro_precision': 0.7595307917888563, 'micro_recall': 0.6475}
{'hamming_loss': 0.01890099609267514, 'macro_f1': 0.5914650273717825, 'macro_precision': 0.7330007028116076, 'macro_recall': 0.5488321508950043, 'micro_f1': 0.7718660223517494, 'micro_precision': 0.8352531042786034, 'micro_recall': 0.71742116717344}
{'hamming_loss': 0.02287037037037037, 'macro_f1': 0.5188781409589596, 'macro_precision': 0.6009210233873957, 'macro_recall': 0.49344563288703436, 'micro_f1': 0.7232803047277615, 'micro_precision': 0.782355792535143, 'micro_recall': 0.6725}
{'hamming_loss

  3%|▎         | 99/3365 [00:57<31:44,  1.71it/s]
  0%|          | 0/63 [00:00<?, ?it/s][A
  2%|▏         | 1/63 [00:00<00:14,  4.17it/s][A
  3%|▎         | 2/63 [00:00<00:14,  4.20it/s][A
  5%|▍         | 3/63 [00:00<00:14,  4.21it/s][A
  6%|▋         | 4/63 [00:00<00:13,  4.24it/s][A
  8%|▊         | 5/63 [00:01<00:13,  4.23it/s][A
 10%|▉         | 6/63 [00:01<00:13,  4.14it/s][A
 11%|█         | 7/63 [00:01<00:13,  4.14it/s][A
 13%|█▎        | 8/63 [00:01<00:13,  4.10it/s][A
 14%|█▍        | 9/63 [00:02<00:13,  4.09it/s][A
 16%|█▌        | 10/63 [00:02<00:13,  4.04it/s][A
 17%|█▋        | 11/63 [00:02<00:13,  3.97it/s][A
 19%|█▉        | 12/63 [00:02<00:12,  3.96it/s][A
 21%|██        | 13/63 [00:03<00:12,  4.01it/s][A
 22%|██▏       | 14/63 [00:03<00:12,  3.99it/s][A
 24%|██▍       | 15/63 [00:03<00:12,  3.99it/s][A
 25%|██▌       | 16/63 [00:03<00:11,  3.98it/s][A
 27%|██▋       | 17/63 [00:04<00:11,  3.93it/s][A
 29%|██▊       | 18/63 [00:04<00:11,  3.91it/s][A

KeyboardInterrupt: 

In [191]:
#!g1.1
for epoch in range(1, EPOCHS + 1):
    model = torch.load('bert_sgm_final_' + str(epoch) + '.pt')
    print(epoch)
    print(eval_model(model, val_dataloader, None, '_validation'))
    print(eval_model(model, test_dataloader, None, '_test'))    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


1
{'hamming_loss': 0.023833333333333335, 'macro_f1': 0.4706473863931913, 'macro_precision': 0.5729747333190717, 'macro_recall': 0.4307202777824952, 'micro_f1': 0.7100698355485469, 'micro_precision': 0.7729279058361942, 'micro_recall': 0.6566666666666666}
{'hamming_loss': 0.025777777777777778, 'macro_f1': 0.46449113406274195, 'macro_precision': 0.5687185931505901, 'macro_recall': 0.42742205860352017, 'micro_f1': 0.6883116883116883, 'micro_precision': 0.7515892420537897, 'micro_recall': 0.6348616274266832}
2
{'hamming_loss': 0.02248148148148148, 'macro_f1': 0.5021458265250056, 'macro_precision': 0.6050154253549676, 'macro_recall': 0.46934153592296396, 'micro_f1': 0.7301022676745219, 'micro_precision': 0.782650142993327, 'micro_recall': 0.6841666666666667}
{'hamming_loss': 0.024555555555555556, 'macro_f1': 0.49817337929273403, 'macro_precision': 0.6001679268103159, 'macro_recall': 0.4610761160760501, 'micro_f1': 0.7075430083811204, 'micro_precision': 0.759110269758637, 'micro_recall': 0.6

In [None]:
#!g1.1
