# Class Attetion with BERT model

In [1]:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import datasets
import transformers
from transformers import AutoModel, AutoTokenizer

import wandb
from tqdm.auto import tqdm

In [2]:
def get_output_dim(model):
    config = model.config
    if isinstance(config, transformers.DistilBertConfig):
        return config.hidden_dim
    
    return config.hidden_size


class ClassAttentionModel(nn.Module):
    def __init__(self, txt_encoder, cls_encoder, hidden_size):
        super().__init__()

        self.txt_encoder = txt_encoder
        self.cls_encoder = cls_encoder
        
        txt_encoder_h = get_output_dim(txt_encoder)
        self.txt_out = nn.Sequential(nn.Linear(txt_encoder_h, hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size))

        cls_encoder_h = get_output_dim(cls_encoder)
        self.cls_out = nn.Linear(cls_encoder_h, hidden_size)
    
    def forward(self, input_dict, classes_dict):
        """
        Compute logits for input (input_dict,) corresponding to the classes (classes_dict)

        Optionally, you can provide additional keys in either input_dict or classes_dict
        Specifically, attention_mask, head_mask and inputs_embeds
        Howerver, you cannot provide output_attentions and output_hidden_states

        Args:
            input_ids: dict with key input_ids
                input_ids: LongTensor[batch_size, text_seq_len], input to the text network
            classes_ids: dict with key input_ids
                input_ids: LongTensor[n_classes, class_seq_len], a list of possible classes, each class described via text
        
        
        """
        h_x = self.txt_encoder(**input_dict)  # some tuple
        h_x = h_x[0]  # FloatTensor[bs, text_seq_len, hidden]
        h_x = h_x[:, 0]  # get CLS token representations, FloatTensor[bs, hidden]

        h_c = self.cls_encoder(**classes_dict)  # some tuple
        h_c = h_c[0]  # FloatTensor[n_classes, class_seq_len, hidden]

        h_c, _ = torch.max(h_c, dim=1)  # [n_classes, hidden]
        
        # attention map
        h_x = self.txt_out(h_x)
        h_c = self.cls_out(h_c)

        logits = h_x @ h_c.T  # [bs, n_classes]
        return logits


In [3]:
dataset = datasets.load_dataset("Fraser/news-category-dataset")
all_classes = dataset['train'].features['category_num'].names
print(all_classes)

Using custom data configuration default
Reusing dataset news_category (/Users/vladislavlialin/.cache/huggingface/datasets/news_category/default/0.0.0/e1ca79a7dd2ddfef8393f386829a339f4212bea93dface547d0de38cbecfa97b)


['POLITICS', 'WELLNESS', 'ENTERTAINMENT', 'TRAVEL', 'STYLE & BEAUTY', 'PARENTING', 'HEALTHY LIVING', 'QUEER VOICES', 'FOOD & DRINK', 'BUSINESS', 'COMEDY', 'SPORTS', 'BLACK VOICES', 'HOME & LIVING', 'PARENTS', 'THE WORLDPOST', 'WEDDINGS', 'WOMEN', 'IMPACT', 'DIVORCE', 'CRIME', 'MEDIA', 'WEIRD NEWS', 'GREEN', 'WORLDPOST', 'RELIGION', 'STYLE', 'SCIENCE', 'WORLD NEWS', 'TASTE', 'TECH', 'MONEY', 'ARTS', 'FIFTY', 'GOOD NEWS', 'ARTS & CULTURE', 'ENVIRONMENT', 'COLLEGE', 'LATINO VOICES', 'CULTURE & ARTS', 'EDUCATION']


In [4]:
# TODO: plot class frequencies

In [5]:
train_classes = random.sample(all_classes, int(len(all_classes) * 0.9))
train_dataset = [d for d in dataset['train'] if d['category'] in train_classes]
valid_classes = set(all_classes).difference(train_classes)
print(valid_classes)
print(f'original dataset len {len(dataset["train"])}')
print(f'Train dataset    len {len(train_dataset)}')

{'WORLDPOST', 'SPORTS', 'MEDIA', 'RELIGION', 'WEIRD NEWS'}
original dataset len 160682
Train dataset    len 148208


In [6]:
def add_extra_classes(batch_classes, all_classes, n_extra_classes=None):
    if n_extra_classes is None:
        n_extra_classes = random.randint(0, len(all_classes) - 1)  # TODO: experiment with distribution

    extra_classes = random.sample(all_classes, n_extra_classes)  # TODO: experiment with distribution

    new_classes = list(set(batch_classes) | set(extra_classes))
    return new_classes


def get_class_ids(batch_classes, possible_classes):
    return torch.LongTensor([possible_classes.index(c) for c in batch_classes])


_batch_classes = ['ARTS', 'BUSINESS', 'TRAVEL', 'POLITICS', 'ENTERTAINMENT', 'HEALTHY LIVING', 'ENTERTAINMENT']
_possible_classes = ['ENTERTAINMENT', 'BUSINESS', 'HEALTHY LIVING', 'TRAVEL', 'HOME & LIVING', 'POLITICS', 'TECH', 'WOMEN', 'ARTS', 'GREEN']
_expected_class_ids = torch.LongTensor([8, 1, 3, 5, 0, 2, 0])
assert torch.equal(get_class_ids(_batch_classes, _possible_classes), _expected_class_ids)

## Training loop

In [7]:
TXT_MODEL = 'bert-base-uncased'
CLS_MODEL = TXT_MODEL

txt_tokenizer = AutoTokenizer.from_pretrained(TXT_MODEL)
cls_tokenizer = AutoTokenizer.from_pretrained(CLS_MODEL)

txt_encoder = AutoModel.from_pretrained(TXT_MODEL)
cls_encoder = AutoModel.from_pretrained(CLS_MODEL)

In [8]:
def get_parameters_without_txt_encoder_embeddings(model):
    return (p for n, p in model.named_parameters() if 'cls_encoder.embeddings' not in n)


model = ClassAttentionModel(txt_encoder, cls_encoder, hidden_size=128)
assert len(list(model.named_parameters())) - 5 == len(list(get_parameters_without_txt_encoder_embeddings(model)))

In [9]:
def evaluate(model, dataset, batch_size=32, device=None):
    device = device or model.device

    eval_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    model.eval()

    n_matched = 0

    for batch in tqdm(eval_dataloader):
        text = batch['headline']

        batch_classes = batch['category']
        possible_classes = add_extra_classes(batch_classes, all_classes)
        labels = get_class_ids(batch_classes, possible_classes)

        if device != 'cpu':
            txt_input = {k: v.to(device) for k, v in txt_input.items()}
            cls_input = {k: v.to(device) for k, v in cls_input.items()}
            labels = labels.to(device)

        txt_input = txt_tokenizer(text, return_tensors='pt', padding=True)
        cls_input = cls_tokenizer(possible_classes, return_tensors='pt', padding=True, add_special_tokens=False)

        assert False
        logits = model(txt_input, cls_input)

        _, preds = torch.max(logits, -1)
        n_matched += torch.sum(preds == labels)
        
    model.train()
        
    return {'accuracy': n_matched.detach().cpu().numpy() / len(dataset)}


# TODO: test

In [10]:
valid_dataset = [d for d in dataset['validation'] if d['category'] in train_classes]
valid_zero_shot_dataset = [d for d in dataset['validation'] if d['category'] in valid_classes]

print(f'Orig val len         : {len(dataset["validation"])}')
print(f'Train classes val len: {len(valid_dataset)}')
print(f'Valid classes val len: {len(valid_zero_shot_dataset)}')

assert len(dataset["validation"]) == len(valid_dataset) + len(valid_zero_shot_dataset)

Orig val len         : 10043
Train classes val len: 9275
Valid classes val len: 768


In [11]:
from itertools import chain

In [17]:
batch_size = 3
device = 'cuda' if torch.cuda.is_available() else 'cpu'

dataloader = torch.utils.data.DataLoader(train_dataset[:batch_size], batch_size=batch_size)
model = ClassAttentionModel(txt_encoder, cls_encoder, hidden_size=32)
model = model.to(device)

# do not update cls_encoder embeddings so that the embedding geometry is not corrupted
parameters = get_parameters_without_txt_encoder_embeddings(model)

# parameters = chain(model.txt_encoder.parameters(), model.txt_out.parameters(), model.cls_out.parameters())
parameters = model.txt_out.parameters()

optimizer = torch.optim.Adam(parameters, lr=1e-4)

In [18]:
wandb.init()
wandb.watch(model)

epochs = 1000
eval_freq = 1000

all_classes = dataset['train'].features['category_num'].names

step = 0

for e in tqdm(range(epochs), desc='epochs'):
    for batch in tqdm(dataloader):
        step += 1

        model.zero_grad()

        text = batch['headline']

        batch_classes = batch['category']

#         possible_classes = add_extra_classes(batch_classes, all_classes)

        possible_classes = batch_classes
        labels = get_class_ids(batch_classes, possible_classes)

        txt_input = txt_tokenizer(text, return_tensors='pt', padding=True)
        cls_input = cls_tokenizer(possible_classes, return_tensors='pt', padding=True, add_special_tokens=False)

        if device != 'cpu':
            txt_input = {k: v.to(device) for k, v in txt_input.items()}
            cls_input = {k: v.to(device) for k, v in cls_input.items()}
            labels = labels.to(device)

        assert False

        logits = model(txt_input, cls_input)

        loss = F.cross_entropy(logits, labels)        
        _, preds = logits.max(-1)

        wandb.log({'loss': loss, 'train/batch_accuracy': sum(preds == labels) / batch_size})

        loss.backward()
        optimizer.step()

        if step % eval_freq != 0:
            continue

        eval_train_classes = evaluate(model, valid_dataset[:100])
        eval_valid_classes = evaluate(model, valid_zero_shot_dataset[:100])

        wandb.log({'valid/train_classes/accuracy': eval_train_classes['accuracy'],
                   'valid/valid_classes/accuracy': eval_valid_classes['accuracy']})


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

HBox(children=(FloatProgress(value=0.0, description='epochs', max=1000.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))





AssertionError: 

In [19]:
txt_input['input_ids']

tensor([[  101,  1005,  1055, 20554,  1005, 12934,  2015,  5035, 18528,  4895,
         11256,  1999,  3147,  2330,   102],
        [  101,  1996,  5264, 27857,  3527,  1999,  6613,   102,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  1020,  3971,  1037,  1002,  2260,  6263, 11897,  2052,  2393,
          1996,  4610,   102,     0,     0]])

In [20]:
cls_input['input_ids']

tensor([[4038],
        [2840],
        [2449]])

In [22]:
labels

tensor([0, 1, 2])

In [21]:
possible_classes

['COMEDY', 'ARTS', 'BUSINESS']

In [23]:
batch_classes

['COMEDY', 'ARTS', 'BUSINESS']

## Validation on train classes

In [None]:
logs = evaluate(model, valid_dataset, batch_size=32)

print(f'Accuracy: {logs['accuracy']}')

## Validation on test classes

In [None]:
logs = evaluate(model, valid_zero_shot_dataset, batch_size=32)

print(f'Accuracy: {logs['accuracy']}')