In [2]:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
import datasets

import class_attention as cat


MODEL = "distilbert-base-uncased"

In [3]:
data = datasets.load_from_disk("../data/emotion_v0")
data

DatasetDict({
    train: Dataset({
        features: ['category', 'text'],
        num_rows: 20464
    })
    validation: Dataset({
        features: ['category', 'text'],
        num_rows: 7700
    })
    test: Dataset({
        features: ['category', 'text'],
        num_rows: 16000
    })
})

In [16]:
class BinaryBERTdataset(cat.CatDataset):
    """Generates triplets (example, class_name, binary_label)
    where example and class_name are stirngs and binary_label is a torch.FloatTensor(0.) or torch.FloatTensor(1.)

    Args:
        texts: List[str], a list of texts
        labels: optional, List[str], a list of labels
        text_tokenizer: transformers.Tokenizer
        label_tokenizer: optional, transformers.Tokenizer, if not provided `text_tokenizer` is used to tokenize class names
        negative_examples_ratio: a proportion of negative examples (example, class_name, False)
    """
    def __init__(self, texts, labels=None, negative_examples_ratio=0.5):
        self.texts = texts

        if labels is None:
            raise NotImplementedError()

        self.labels = labels
        self.negative_examples_ratio = negative_examples_ratio

        self._unique_labels_set = set(labels)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        is_correct = torch.tensor(1.)
        
        if random.random() < self.negative_examples_ratio:
            # randomly select a negative sample label
            label, = random.sample(self._unique_labels_set - {label}, 1)
            is_correct = torch.tensor(0.)

        return text, label, is_correct        


class TokenizerCollator:
    def __init__(self, tokenizer: transformers.BertTokenizer, max_len=512):
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __call__(self, items):
        pairs = []
        is_correct_pt = torch.zeros(len(items))
        
        for i, (text, label, is_correct) in enumerate(items):
            pairs.append((text, label))
            is_correct_pt[i] = is_correct

        tok_out = self.tokenizer.batch_encode_plus(pairs, return_tensors="pt", padding=True, truncation=True, max_length=self.max_len)

        return tok_out, is_correct_pt


_texts = ["toy texts", "more toy texts"]
_labels = ["label1", "label2"]
_tok = transformers.AutoTokenizer.from_pretrained(MODEL)
_dataset = BinaryBERTdataset(texts=_texts, labels=_labels, negative_examples_ratio=0.5)
print(len(_dataset))
for d in _dataset:
    print(d)


_loader = torch.utils.data.DataLoader(_dataset, collate_fn=TokenizerCollator(_t), batch_size=2)
for x, y in _loader:
    print(x, y)

2
('toy texts', 'label2', tensor(0.))
('more toy texts', 'label1', tensor(0.))
{'input_ids': tensor([[ 101, 9121, 6981,  102, 3830, 2475,  102,    0],
        [ 101, 2062, 9121, 6981,  102, 3830, 2475,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1]])} tensor([0., 1.])


In [18]:
import pytorch_lightning as pl

In [21]:
class BinaryBERT(pl.LightningModule):

    def __init__(self, model_name):
        super().__init__()
        self.encoder = transformers.AutoModel.from_pretrained(model_name)

    def training_step(self, batch, batch_idx):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        logits = self.encoder(**x)
        
        loss = F.binary_cross_entropy_with_logits(logits, y)
        self.log('train_loss', loss)
        return loss
        # --------------------------

    def validation_step(self, batch, batch_idx):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        logits = self.encoder(**x)
        
        loss = F.binary_cross_entropy_with_logits(logits, y)
        acc = torch.mean(torch.sigmoid(logits) > 0.5 == y)

        self.log('val_loss', loss)
        self.log('binary_acc', acc)
        # --------------------------

    def test_step(self, batch, batch_idx):
        raise NotImplementedError()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

In [24]:
model = BinaryBERT(MODEL)

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)

train_dataset = BinaryBERTdataset(texts=data['train']['text'], labels=data['train']['category'], negative_examples_ratio=0.5)
train_dataloader = torch.utils.data.DataLoader(train_dataset, collate_fn=TokenizerCollator(tokenizer), batch_size=32)

In [25]:
trainer = pl.Trainer(gpus=1, max_epochs=3)
trainer.fit(model, train_dataloader)

2021-04-16 16:01:08 | INFO | pytorch_lightning.utilities.distributed | GPU available: True, used: True
2021-04-16 16:01:08 | INFO | pytorch_lightning.utilities.distributed | TPU available: False, using: 0 TPU cores
2021-04-16 16:01:08 | INFO | pytorch_lightning.accelerators.gpu | LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]




2021-04-16 16:01:10 | INFO | pytorch_lightning.core.lightning | 
  | Name    | Type            | Params
--------------------------------------------
0 | encoder | DistilBertModel | 66.4 M
--------------------------------------------
66.4 M    Trainable params
0         Non-trainable params
66.4 M    Total params
265.452   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




AttributeError: 'BaseModelOutput' object has no attribute 'size'