# Requirements

In [None]:
#!g1.1
# NOTE(review): only pyyaml is pinned; pin transformers / gdown / wandb as well
# if this notebook has to be reproducible on a fresh machine.
%pip install transformers
%pip install pyyaml==5.4.1
%pip install gdown
%pip install wandb

# Config

In [33]:
#!g1.1
import torch
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [34]:
#!g1.1
# Checkpoint name for the tokenizer loaded in AAPDDataset.
# NOTE(review): BertSGM builds its encoder with BertEncdoer's own default
# ("bert-base-uncased"), not with this constant — keep the two in sync.
BERT_TYPE = 'bert-base-uncased'

In [35]:
#!g1.1
from transformers import logging
# Only show transformers log messages at ERROR level or above.
logging.set_verbosity_error()

In [36]:
#!g1.1
# Special-token ids of the target (label-sequence) vocabulary.
# Real labels are shifted past them: label index i is encoded as token i + 4.
PAD = 0
UNK = 1
BOS = 2
EOS = 3

# 54 AAPD labels + the 4 special tokens above.
# NOTE(review): BertSGM hard-codes 58 instead of reading this value.
tgt_vocab_size = 54 + 4

In [37]:
#!g1.1
# Batch size shared by the train / validation / test DataLoaders.
BATCH_SIZE = 16

# Download AAPD

In [38]:
#!g1.1
import gdown

# Shared Google Drive folder holding the AAPD splits (*.tsv plus text_* / label_* files).
url = 'https://drive.google.com/drive/folders/1qw05BnA1O-XDgJ50OgNGFSlTa9Kls00j?usp=sharing'
# Downloads into the current working directory; the returned list of file paths is the cell output.
gdown.download_folder(url, quiet=True)

['/home/jupyter/work/resources/label_test',
 '/home/jupyter/work/resources/label_train',
 '/home/jupyter/work/resources/label_val',
 '/home/jupyter/work/resources/test.tsv',
 '/home/jupyter/work/resources/text_test',
 '/home/jupyter/work/resources/text_train',
 '/home/jupyter/work/resources/text_val',
 '/home/jupyter/work/resources/train.tsv',
 '/home/jupyter/work/resources/validation.tsv']

In [8]:
# !cp -r drive/MyDrive/AAPD .

In [42]:
#!g1.1
!mkdir AAPD
!mv *.tsv AAPD
!mv text_* AAPD
!mv label_* AAPD


# Data

In [43]:
#!g1.1
def apply_to_dict_values(dict, f):
    """Replace every value of the mutable mapping ``dict`` with ``f(value)``, in place.

    Used on the tokenizer output in AAPDDataset and on batched inputs in the
    train/eval loops. Returns None. (The parameter name shadows the builtin
    ``dict``; it is kept so existing callers are unaffected.)
    """
    for key in list(dict):
        dict[key] = f(dict[key])

In [44]:
#!g1.1
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, BertConfig

class AAPDDataset(Dataset):
    """AAPD multi-label text classification dataset.

    Reads a header-less tab-separated file where column 0 is the label string
    (one '0'/'1' character per label) and column 1 is the document text.

    ``__getitem__`` returns ``(encoding, target)``:
      * ``encoding`` — BERT tokenizer output, each value flattened to 1-D
        (padded/truncated to 512 tokens);
      * ``target`` — label-token sequence from ``target_to_tensor_with_specials``.
    """

    # Number of special tokens (PAD, UNK, BOS, EOS) placed before the real labels
    # in the target vocabulary; label index i becomes token i + N_SPECIALS.
    N_SPECIALS = 4

    def __init__(self, path):
        self.path = path
        # Read every column as text. The label column is a long run of digits
        # ("0100...0"); left to type inference, pandas may parse it as an integer
        # (dropping leading zeros) whenever the value fits in int64.
        # keep_default_na=False keeps texts such as "NA"/"null" as strings
        # instead of turning them into float NaN, which the tokenizer rejects.
        self.data = pd.read_csv(self.path, sep='\t', header=None, dtype=str, keep_default_na=False)
        self.tokenizer = BertTokenizer.from_pretrained(BERT_TYPE)

    def __len__(self):
        return self.data.shape[0]

    @staticmethod
    def target_to_tensor(target):
        """Label string -> float 0/1 indicator tensor (one entry per character)."""
        return torch.tensor([float(label) for label in target])

    @staticmethod
    def target_to_tensor_with_specials(target):
        """Label string -> LongTensor ``[BOS, i_1 + N_SPECIALS, ..., i_k + N_SPECIALS, EOS]``.

        ``i_j`` are the positions of the '1' characters in ``target``, in order.
        The result is an integer tensor because the values are token ids that
        end up indexing an embedding table.
        """
        label_tokens = [index + AAPDDataset.N_SPECIALS
                        for index, label in enumerate(target) if label == '1']
        return torch.tensor([BOS] + label_tokens + [EOS], dtype=torch.long)

    def __getitem__(self, idx):
        # max_length=512 is BERT's positional limit (same truncation as DocBERT).
        data = self.tokenizer(self.data.iloc[idx, 1], return_tensors="pt", max_length=512, padding="max_length", truncation=True)
        # return_tensors="pt" adds a leading batch dim of 1; flatten it away so the
        # collate function can stack examples.
        apply_to_dict_values(data, lambda x: x.flatten())
        return data, AAPDDataset.target_to_tensor_with_specials(self.data.iloc[idx, 0])

In [None]:
#!g1.1
# One dataset per split, read from the ./AAPD directory filled by the `mv` cell earlier.
# Each AAPDDataset instance loads its own copy of the BERT tokenizer.
train_dataset = AAPDDataset('./AAPD/train.tsv')
val_dataset = AAPDDataset('./AAPD/validation.tsv')
test_dataset = AAPDDataset('./AAPD/test.tsv')

In [46]:
#!g1.1
def padding(data):
    """Collate a list of ``(encoding, target)`` pairs into one batch.

    Returns ``(src, tgt)`` where ``src`` maps each encoding key to the stacked
    tensors of the batch, and ``tgt`` is a LongTensor ``[batch, max_target_len]``
    with shorter targets right-padded with 0.
    """
    encodings, targets = zip(*data)

    # Tokenizer outputs are fixed-length, so each field stacks directly.
    batched_src = {key: torch.stack([enc[key] for enc in encodings]) for key in encodings[0].keys()}

    # Right-pad with 0 (the PAD id) up to the longest target in the batch.
    batched_tgt = torch.nn.utils.rnn.pad_sequence(list(targets), batch_first=True, padding_value=0).long()

    return batched_src, batched_tgt

In [47]:
#!g1.1
# Only the training split is shuffled; `padding` stacks the encodings and right-pads the targets.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=padding)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=padding)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=padding)

# Models

## Encoder

In [48]:
#!g1.1
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig


class BertEncdoer(nn.Module):
    """Thin wrapper around a pretrained ``BertModel`` used as the SGM encoder.

    The misspelled class name is kept because other cells refer to it.
    """
    def __init__(self, bert_type="bert-base-uncased", dropout_prob=0.):
        super(BertEncdoer, self).__init__()
        self.bert_model = BertModel.from_pretrained(bert_type)
        # NOTE(review): kept for interface/state compatibility, but forward() does
        # not apply it to the returned hidden states.
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, inputs):
        '''
        Encode a batch of tokenized documents.

        Args:
            inputs: mapping of BERT keyword inputs (input_ids, attention_mask, ...),
                forwarded as ``self.bert_model(**inputs)``.

        Returns:
            ``(last_hidden_state, (None, None))``: the per-token hidden states that the
            decoder cross-attends to (presumably [batch_size, seq_len, bert_hidden_size]),
            plus a placeholder pair for callers that unpack an RNN-style
            (hidden, cell) initial state. No decoder initial state is produced.
        '''
        bert_output = self.bert_model(**inputs)
        return bert_output.last_hidden_state, (None, None)


## Decoder: SGM output head

In [49]:
#!g1.1
class SgmHead(nn.Module):
    """SGM-style output layer.

    Projects decoder states to per-token scores through a tanh bottleneck, and
    sets the score of every token emitted at an earlier step to -inf so it
    cannot be emitted again.

    Args:
        tgt_vocab_size: number of output tokens (labels + special tokens).
        hidden_size: size of the incoming decoder states.
        inner_hidden_size: width of the intermediate tanh layer (default 768,
            the value that used to be hard-coded).
    """

    def __init__(self, tgt_vocab_size, hidden_size, inner_hidden_size=768):
        super(SgmHead, self).__init__()

        self.hidden_size = hidden_size

        self.inner_hidden_size = inner_hidden_size
        self.W_d = nn.Linear(self.hidden_size, self.inner_hidden_size)
        self.W_o = nn.Linear(self.inner_hidden_size, tgt_vocab_size)
        self.activation = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, hiddens, c_t, prev_predicted_labels=None, use_softmax=False):
        """Score every target token for one decoding step.

        Args:
            hiddens: decoder states, [batch, hidden_size].
            c_t: unused; kept so existing call sites keep working.
            prev_predicted_labels: optional sequence with one entry per earlier step.
                Each entry holds one token id per batch row (list of ints or 1-D
                tensor). Those tokens are scored -inf. None or empty disables masking.
            use_softmax: return probabilities instead of raw scores.

        Returns:
            Tensor [batch, tgt_vocab_size].
        """
        scores = self.W_o(self.activation(self.W_d(hiddens)))
        # Explicit None/len check: truth-testing a multi-element tensor raises.
        if prev_predicted_labels is not None and len(prev_predicted_labels) > 0:
            rows = torch.arange(scores.size(0), device=scores.device)
            already_emitted = torch.zeros_like(scores, dtype=torch.bool)
            for step_labels in prev_predicted_labels:
                cols = torch.as_tensor(step_labels, device=scores.device)
                already_emitted[rows, cols] = True
            scores = scores.masked_fill(already_emitted, float('-inf'))
        if use_softmax:
            scores = self.softmax(scores)
        return scores

## Transformer decoder: causal mask helper

In [50]:
#!g1.1
def generate_square_subsequent_mask(size: int):
    """Return a float (size, size) additive causal attention mask.

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf when
    j > i (future positions are blocked). Same result as the PyTorch docs helper.
    """
    blocked = torch.full((size, size), float('-inf'))
    # Keep -inf strictly above the diagonal; triu zeroes the diagonal and below.
    return torch.triu(blocked, diagonal=1)

# BERT + SGM

In [51]:
#!g1.1
class BertSGM(nn.Module):
    """BERT encoder + Transformer decoder + SGM head that generates label sequences.

    Targets look like ``[BOS, label tokens..., EOS, PAD, ...]``.
    ``forward`` returns the teacher-forced loss; ``predict`` decodes greedily.
    """
    def __init__(self):
        super(BertSGM, self).__init__()
        tgt_vocab_size = 58  # 54 AAPD labels + PAD/UNK/BOS/EOS
        tgt_embedding_size = 768
        decoder_hidden_size = 768
        decoder_num_layers = 1
        dropout_prob=0.2
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, tgt_embedding_size)
        self.encoder = BertEncdoer(dropout_prob=dropout_prob)
        # Cached causal mask; _causal_mask() regenerates it if a longer sequence appears.
        self.mask = generate_square_subsequent_mask(20).to(device)
        self.decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model=768, nhead=8, batch_first=True), num_layers=decoder_num_layers)
        self.sgm_head = SgmHead(tgt_vocab_size=tgt_vocab_size, hidden_size=decoder_hidden_size)
        self.criterion = self.create_criterion(tgt_vocab_size)

    def _causal_mask(self, size, target_device):
        """(size, size) additive causal mask on target_device, growing the cache when too small."""
        if size > self.mask.size(0):
            self.mask = generate_square_subsequent_mask(size).to(target_device)
        return self.mask[:size, :size].to(target_device)

    @staticmethod
    def _memory_padding_mask(src):
        """Bool [batch, seq_len] mask, True at tokenizer-padding positions; None if no attention_mask."""
        attention_mask = src.get('attention_mask')
        if attention_mask is None:
            return None
        return attention_mask == 0

    def forward(self, src, tgt):
        """Teacher-forced training loss.

        Args:
            src: dict of batched BERT inputs (input_ids, attention_mask, ...).
            tgt: LongTensor [batch, tgt_len]; each row starts with BOS.

        Returns:
            Scalar loss from ``compute_loss``.
        """
        context, _ = self.encoder(src)
        memory_padding = self._memory_padding_mask(src)
        # Inputs are the targets without their last token: step i predicts token i + 1.
        decoder_input = self.tgt_embedding(tgt[:, :-1])
        steps = decoder_input.size(1)
        decoder_output = self.decoder(
            decoder_input,
            context,
            tgt_mask=self._causal_mask(steps, decoder_input.device),
            # Do not cross-attend to BERT's [PAD] positions.
            memory_key_padding_mask=memory_padding,
        )
        previous_targets = []
        saved_scores = []

        for decoder_output_step, step_targets in zip(decoder_output.split(1, dim=1), tgt[:, 1:].transpose(0, 1)):
            # Gold labels from earlier steps are masked (SGM). They must be passed by
            # keyword: SgmHead.forward's second positional parameter is c_t.
            scores = self.sgm_head(decoder_output_step.squeeze(1), None, prev_predicted_labels=previous_targets)
            saved_scores.append(scores)
            previous_targets.append(step_targets)

        scores = torch.stack(saved_scores).transpose(0, 1)
        return self.compute_loss(scores, tgt)

    def compute_loss(self, scores, tgt):
        """Mean over the batch of each example's cross-entropy (PAD positions ignored).

        scores: [batch, tgt_len - 1, vocab]; tgt: [batch, tgt_len].
        """
        loss = 0.
        for score, t in zip(scores, tgt[:, 1:]):
            loss += self.criterion(score, t)
        return loss / tgt.size(0)

    def create_criterion(self, tgt_vocab_size):
        """Cross-entropy that ignores PAD targets (PAD weight is also 0)."""
        weight = torch.ones(tgt_vocab_size)
        weight[PAD] = 0
        crit = nn.CrossEntropyLoss(weight, ignore_index=PAD)
        return crit

    def predict(self, src, max_steps=10):
        """Greedy decoding.

        Returns:
            LongTensor [steps, batch] of predicted tokens (steps <= max_steps).
            Decoding stops early once every row has produced EOS.
        """
        context, _ = self.encoder(src)
        memory_padding = self._memory_padding_mask(src)
        batch_size = src['input_ids'].size(0)
        bos = torch.full((batch_size,), BOS, dtype=torch.long, device=context.device)
        decoder_input = self.tgt_embedding(bos).unsqueeze(1)  # (B, 1, emb_len)
        predicted_labels = []
        eos_predicted = torch.zeros(batch_size, dtype=torch.bool, device=context.device)

        for _ in range(max_steps):
            steps = decoder_input.size(1)
            output = self.decoder(
                decoder_input,
                context,
                tgt_mask=self._causal_mask(steps, context.device),
                memory_key_padding_mask=memory_padding,
            )
            scores = self.sgm_head(output[:, -1], None, prev_predicted_labels=predicted_labels)
            prediction = torch.argmax(scores, dim=-1)

            predicted_labels.append(prediction.tolist())
            eos_predicted = eos_predicted | (prediction == EOS)
            if torch.all(eos_predicted):
                break
            # Feed back the embedding of the predicted token (matching training, which
            # feeds target-token embeddings), not the decoder's output state.
            next_input = self.tgt_embedding(prediction).unsqueeze(1)
            decoder_input = torch.cat((decoder_input, next_input), dim=1)

        return torch.tensor(predicted_labels)

# Metrics

In [69]:
#!g1.1
def one_hot_labels(batch, n_classes=54, n_specials=4):
    """Convert token sequences to 0/1 label-indicator rows.

    For each sequence, tokens are read until the first EOS. Tokens below
    ``n_specials`` are skipped. Token t sets position ``t - n_specials`` to 1.
    Returns a list of ``n_classes``-long int lists, one per input sequence.
    """
    def encode(sequence):
        row = [0] * n_classes
        for token in sequence:
            if token == EOS:
                return row
            if token >= n_specials:
                row[token - n_specials] = 1
        return row

    return [encode(sequence) for sequence in batch]

In [None]:
#!g1.1
# Fetch the WandbWriter helper module (imported in the next cell) from a GitHub gist.
!wget https://gist.githubusercontent.com/ArseniyBolotin/7623835da1631b00fb150bcd5b0d909f/raw/wandb_writer.py -O wandb_writer.py

In [54]:
#!g1.1
from sklearn import metrics
from wandb_writer import WandbWriter

def get_metrics(y, y_pre):
    """Compute multi-label classification metrics.

    Args:
        y: ground-truth label-indicator matrix, one 0/1 row per example
           (as produced by ``one_hot_labels``).
        y_pre: predicted label-indicator matrix of the same shape.

    Returns:
        dict with "hamming_loss" followed by macro- and micro-averaged
        F1, precision and recall (keys "<average>_<metric>").
    """
    # zero_division=0 spells out the value sklearn already uses by default for a
    # label that is never predicted / never present, and silences the
    # UndefinedMetricWarning that would otherwise fire on every evaluation.
    results = {"hamming_loss": metrics.hamming_loss(y, y_pre)}
    for average in ("macro", "micro"):
        results[average + "_f1"] = metrics.f1_score(y, y_pre, average=average, zero_division=0)
        results[average + "_precision"] = metrics.precision_score(y, y_pre, average=average, zero_division=0)
        results[average + "_recall"] = metrics.recall_score(y, y_pre, average=average, zero_division=0)
    return results

# Init

In [68]:
#!g1.1
# Build the full model (BertEncdoer loads pretrained BERT weights) and move it to `device`.
model = BertSGM().to(device)

In [70]:
#!g1.1
import torch.optim as optim

# One Adam optimizer over all parameters: the pretrained encoder and the freshly
# initialized decoder/head share lr=2e-5.
optimizer = optim.Adam(params=model.parameters(), lr=2e-5, betas=(0.9, 0.99))

In [None]:
#!g1.1
# Experiment logger; the train/eval loops below call add_scalar, next_step and read .step.
# NOTE(review): WandbWriter comes from the downloaded gist — presumably it starts a W&B run here; confirm.
wb_writer = WandbWriter("BERT+SGM experiment")

# Train

In [58]:
#!g1.1
from tqdm import tqdm

In [73]:
#!g1.1
def eval_model(model, dataloader, wb_writer, suffix):
    """Greedy-decode every batch of ``dataloader`` and compute multi-label metrics.

    Leaves ``model`` in eval mode. If ``wb_writer`` is truthy, every metric is
    logged under ``name + suffix`` (just ``name`` when ``suffix`` is empty), and
    then the writer's step is advanced and logged as "Step".
    Returns the dict from ``get_metrics``.
    """
    model.eval()

    gold_sequences, predicted_sequences = [], []
    with torch.no_grad():
        for src, tgt in dataloader:
            apply_to_dict_values(src, lambda x: x.to(device))
            tgt = tgt.to(device)
            batch_predictions = model.predict(src)
            gold_sequences.extend(tgt.tolist())
            # predict() returns [steps, batch]; transpose to one sequence per example.
            predicted_sequences.extend(batch_predictions.t().tolist())

    results = get_metrics(one_hot_labels(gold_sequences), one_hot_labels(predicted_sequences))

    if not wb_writer:
        return results

    for metric_name, value in results.items():
        wb_writer.add_scalar(metric_name + suffix if suffix else metric_name, value)
    wb_writer.next_step()
    wb_writer.add_scalar("Step", wb_writer.step)
    return results

In [72]:
#!g1.1
def train_epoch(model, optimizer, dataloader, val_dataloader, val_freq, wb_writer=None):
    """Run one teacher-forced training pass over ``dataloader``.

    After each batch, the loss is logged (when ``wb_writer`` is given) and the
    writer's step is advanced. Every ``val_freq`` batches the model is evaluated
    on ``val_dataloader`` (suffix '_validation') and switched back to train mode.
    """
    model.train()
    for batch_number, (src, tgt) in enumerate(tqdm(dataloader, leave=False), start=1):
        apply_to_dict_values(src, lambda x: x.to(device))
        tgt = tgt.to(device)

        optimizer.zero_grad()
        loss = model(src, tgt)
        loss.backward()
        optimizer.step()

        if wb_writer:
            wb_writer.add_scalar("Batch train loss", loss.item())
            wb_writer.next_step()
            wb_writer.add_scalar("Step", wb_writer.step)

        if batch_number % val_freq == 0:
            eval_model(model, val_dataloader, wb_writer, '_validation')
            # eval_model leaves the model in eval mode.
            model.train()

# Train loop

In [None]:
#!g1.1
EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
    # Validation also runs every 100 batches inside the epoch.
    train_epoch(model, optimizer, train_dataloader, val_dataloader, val_freq=100, wb_writer=wb_writer)
    # End-of-epoch metrics on the full training set, then on the validation set.
    for split_loader, split_suffix in ((train_dataloader, '_train'), (val_dataloader, '_validation')):
        log = eval_model(model, split_loader, wb_writer, split_suffix)
        print(log)
    # Pickles the whole module (not a state_dict); loading it requires these class definitions.
    torch.save(model, 'decoder_transformer_' + str(epoch) + '.pt')
