In [None]:
from pathlib import Path
from random import shuffle
from typing import List, Tuple

from sklearn.metrics import f1_score
from torch import Tensor, cat, no_grad, tensor
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BatchEncoding, DistilBertForSequenceClassification, DistilBertTokenizer

from dao.ower.ower_dir import OwerDir
from dao.ower.samples_tsv import Sample

In [None]:
# Dataset location and shape. class_count and sent_count are passed to samples_tsv.load();
# class_count is also the classifier's num_labels.
# NOTE(review): the directory name suffix "-100-5" appears to encode these two values — keep them in sync.
ower_dir_path = 'data/ower/ower-v4-fb-irt-100-5/'
class_count = 100
sent_count = 5

batch_size = 64
# Tokenizer max_length for the *joined* context of all of a sample's sentences (not per sentence).
# NOTE(review): 64 tokens for sent_count concatenated sentences likely truncates most of them — confirm.
sent_len = 64

epoch_count = 20
# Adam learning rate for fine-tuning every DistilBERT parameter. The previous value (0.01) is
# far above the usual (Distil)BERT fine-tuning range (~2e-5 to 5e-5) and destroys the pretrained weights.
lr = 5e-5

In [None]:
# Open the OWER dataset directory and run its consistency check before any split is loaded.
# NOTE(review): check() presumably verifies the expected files exist — see OwerDir.
ower_dir_root = Path(ower_dir_path)
ower_dir = OwerDir(ower_dir_root)
ower_dir.check()

In [None]:
# Pretrained DistilBERT checkpoint plus its matching tokenizer.
# The sequence-classification head is sized to class_count outputs (one logit per class).
model_name = 'distilbert-base-uncased'

bert = DistilBertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=class_count,
)
tokenizer = DistilBertTokenizer.from_pretrained(
    model_name,
)

In [None]:
# Both splits are loaded with the same (class_count, sent_count) settings from the config cell.
split_load_args = (class_count, sent_count)
train_set = ower_dir.train_samples_tsv.load(*split_load_args)
valid_set = ower_dir.valid_samples_tsv.load(*split_load_args)


def generate_batch(batch: List[Sample]) -> Tuple['BatchEncoding', Tensor]:
    """Collate a list of samples into tokenized model inputs and class targets.

    Used as the ``collate_fn`` of both the train and the valid DataLoader.

    Each sample is unpacked as a 4-tuple whose last two fields are its classes
    and its sentences. Each sample's sentences are put in random order and
    joined with spaces into one context string. All contexts of the batch are
    then tokenized together: padded to the longest context and truncated to
    ``sent_len`` tokens.

    Args:
        batch: Samples as yielded by the dataset.

    Returns:
        The tokenizer output (a ``BatchEncoding`` exposing ``input_ids`` and
        ``attention_mask`` tensors), and ``tensor(classes)``. The latter is
        presumably a (batch, class_count) 0/1 matrix — TODO confirm against Sample.
    """
    _, _, classes_batch, sents_batch = zip(*batch)

    # Shuffle a copy. Shuffling the sample's own list in place would silently mutate
    # the dataset held by the DataLoader, and would raise if the sentences were a tuple.
    contexts_batch = []
    for sents in sents_batch:
        shuffled_sents = list(sents)
        shuffle(shuffled_sents)
        contexts_batch.append(' '.join(shuffled_sents))

    # NOTE(review): this collate_fn also serves the validation loader. Validation contexts are
    # therefore shuffled too, and together with truncation at sent_len this makes validation
    # scores vary between runs.
    encoded_batch = tokenizer(contexts_batch, padding=True, truncation=True,
                              max_length=sent_len, return_tensors='pt')

    return encoded_batch, tensor(classes_batch)


# Both loaders share batching and collation; only the training loader reshuffles samples each epoch.
loader_kwargs = dict(batch_size=batch_size, collate_fn=generate_batch)

train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_set, **loader_kwargs)

In [None]:
# Multi-label objective: one independent sigmoid/BCE term per class logit.
criterion = BCEWithLogitsLoss()

# Move the model to the GPU *before* constructing the optimizer (the order torch.optim recommends),
# so the optimizer is built over the device-resident parameters.
bert = bert.cuda()
optimizer = Adam(bert.parameters(), lr=lr)

for epoch in range(epoch_count):

    # --- Training pass ---
    bert.train()
    train_progress = tqdm(train_loader, desc=f'epoch {epoch} train')
    for ctxt_batch, gt_batch in train_progress:
        input_ids = ctxt_batch.input_ids.cuda()
        attention_mask = ctxt_batch.attention_mask.cuda()
        gt_batch = gt_batch.cuda()

        logits_batch = bert(input_ids, attention_mask).logits

        # BCEWithLogitsLoss(input, target): logits first, float targets second.
        # The previous call had these swapped, so the loss optimized the wrong quantity.
        loss = criterion(logits_batch, gt_batch.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the loss on the progress bar instead of printing one line per step.
        train_progress.set_postfix(loss=loss.item())

    # --- Validation pass ---
    # Gather predictions over the whole set and score once: per-batch F1 is noisy and
    # not an average of the set-level F1.
    bert.eval()
    gt_batches, pred_batches = [], []
    with no_grad():  # no autograd graphs are needed for evaluation
        for ctxt_batch, gt_batch in tqdm(valid_loader, desc=f'epoch {epoch} valid'):
            input_ids = ctxt_batch.input_ids.cuda()
            attention_mask = ctxt_batch.attention_mask.cuda()

            # A logit > 0 is the same as a sigmoid probability > 0.5.
            pred_batch = bert(input_ids, attention_mask).logits > 0

            gt_batches.append(gt_batch)
            pred_batches.append(pred_batch.cpu())

    gt_all = cat(gt_batches).numpy()
    pred_all = cat(pred_batches).numpy()

    # zero_division=0 keeps the default score for classes with no predicted/true positives
    # while silencing the warning those classes would otherwise raise.
    per_class_f1 = f1_score(gt_all, pred_all, average=None, zero_division=0)
    macro_f1 = f1_score(gt_all, pred_all, average='macro', zero_division=0)

    print(f'### epoch {epoch}: macro f1 = {macro_f1:.4f}')
    print('### per-class f1 = ', per_class_f1)