# Trying to make zero-shot work

Multiple model experiments including different normalizations, loss funciton modifications and so on

In [1]:
import copy
import random
from pprint import pprint
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from transformers import AutoModel, AutoTokenizer
import datasets

import wandb
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import class_attention as cat

%load_ext autoreload
%autoreload 2


def detorch(x):
    return x.detach().cpu().numpy()

# A note on labels

Note that the dataset we use (`Fraser/news-category-dataset`) has some interesting particularities in the class names.

For example, it has classes `STYLE` and `STYLE & BEAUTY` or `WORLD NEWS` and `NEWS`. I.e., some classes contain other classes names in their name.
The classes that have `&` in their name have a similar particularity. Some of the categories does not seem to be distinguishable. E.g., `THE WORLDPOST` and `WORLDPOST` or `ARTS & CULTURE` and `CULTURE & ARTS`.



* &	: STYLE & BEAUTY, ARTS & CULTURE, HOME & LIVING, FOOD & DRINK, CULTURE & ARTS
* VOICES	: LATINO VOICES, BLACK VOICES, QUEER VOICES
* NEWS	: WEIRD NEWS, GOOD NEWS, WORLD NEWS
* ARTS	: ARTS, ARTS & CULTURE, CULTURE & ARTS
* CULTURE	: ARTS & CULTURE, CULTURE & ARTS
* LIVING	: HEALTHY LIVING, HOME & LIVING
* WORLDPOST	: THE WORLDPOST, WORLDPOST
* WORLD	: THE WORLDPOST, WORLDPOST

In [2]:
(
    train_dataloader,
    test_dataloader,
    all_classes_str,
    test_classes_str,
) = cat.training_utils.prepare_dataloaders(
    test_class_frac=0.2,
    batch_size=32,
    model_name="distilbert-base-uncased",
    dataset_frac=0.1
)

Using custom data configuration default
Reusing dataset news_category (/Users/vladislavlialin/.cache/huggingface/datasets/news_category/default/0.0.0/737b7b6dff469cbba49a6202c9e94f9d39da1fed94e13170cf7ac4b61a75fb9c)


Moving the following classes to a class-test set: ['STYLE & BEAUTY', 'TECH', 'ARTS & CULTURE', 'PARENTS', 'THE WORLDPOST', 'WOMEN', 'FOOD & DRINK', 'WELLNESS']


Preprocessing Dataset:   0%|          | 0/12244 [00:00<?, ?it/s]

Preprocessing Dataset:   0%|          | 0/12244 [00:00<?, ?it/s]

# Model

In [3]:
class ClassAttentionModel(nn.Module):
    def __init__(self, txt_encoder, cls_encoder, hidden_size):
        super().__init__()

        self.txt_encoder = txt_encoder
        self.cls_encoder = cls_encoder

        txt_encoder_h = cat.modelling_utils.get_output_dim(txt_encoder)
        self.txt_out = nn.Linear(txt_encoder_h, hidden_size)

        cls_encoder_h = cat.modelling_utils.get_output_dim(cls_encoder)
        self.cls_out = nn.Linear(cls_encoder_h, hidden_size)

    def forward(self, text_input, labels_input):
        """
        Compute logits for input (input_dict,) corresponding to the classes (classes_dict)

        Optionally, you can provide additional keys in either input_dict or classes_dict
        Specifically, attention_mask, head_mask and inputs_embeds
        Howerver, one should not provide output_attentions and output_hidden_states

        Args:
            text_input: dict with key input_ids
                input_ids: LongTensor[batch_size, text_seq_len], input to the text network
            labels_input: dict with key input_ids
                input_ids: LongTensor[n_classes, class_seq_len], a list of possible classes, each class described via text
        """
        text_input, labels_input = cat.modelling_utils.maybe_format_inputs(text_input, labels_input)
        cat.modelling_utils.validate_inputs(text_input, labels_input)

        h_x = self.txt_encoder(**text_input)  # some tuple
        h_x = h_x[0]  # FloatTensor[bs, text_seq_len, hidden]
        h_x = h_x[:, 0]  # get CLS token representations, FloatTensor[bs, hidden]

        h_c = self.cls_encoder(**labels_input)  # some tuple
        h_c = h_c[0]  # FloatTensor[n_classes, class_seq_len, hidden]

        h_c, _ = torch.max(h_c, dim=1)  # [n_classes, hidden]

        # attention map
        h_x = self.txt_out(h_x)
        h_c = self.cls_out(h_c)

        # make all class embeddings to have the same Euclidean norm
        h_c = cat.modelling_utils.normalize_embeds(h_c)

        # the scaling is extremely important
        scaling = h_c.size(-1) ** 0.5
        logits = (h_x @ h_c.T) / scaling  # [bs, n_classes]

        return logits


### Look at the initial model distribution

In [4]:
random_text_encoder = transformers.BertModel(
    transformers.BertConfig(num_hidden_layers=2, intermediate_size=256)
)
random_label_encoder = transformers.BertModel(
    transformers.BertConfig(num_hidden_layers=2, intermediate_size=256)
)
random_model = cat.ClassAttentionModel(
    random_text_encoder, random_label_encoder, hidden_size=768
)

x = torch.randint(0, 100, size=[3, 5])
c = torch.unique(torch.randint(0, 50, size=[7, 1])).unsqueeze(1)

logits = random_model(x, c)
p = F.softmax(logits, -1)
p

tensor([[0.1672, 0.1691, 0.1653, 0.1639, 0.1682, 0.1663],
        [0.1647, 0.1660, 0.1667, 0.1660, 0.1698, 0.1668],
        [0.1647, 0.1675, 0.1683, 0.1638, 0.1684, 0.1673]],
       grad_fn=<SoftmaxBackward>)

### Now, same thing but with the DistilBERT initialization

In [5]:
MODEL = 'distilbert-base-uncased'
_tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)

_text_encoder = transformers.AutoModel.from_pretrained(MODEL)
_label_encoder = transformers.AutoModel.from_pretrained(MODEL)

_model = cat.ClassAttentionModel(
    _text_encoder, _label_encoder, hidden_size=768
)

x = _tokenizer.encode_plus([
    "Loads dataset with zero-shot classes, creates collators and dataloaders",
    ""
])
c = torch.unique(torch.randint(0, 50, size=[7, 1])).unsqueeze(1)

logits = random_model(x, c)
p = F.softmax(logits, -1)
p

AttributeError: 'DistilBertTokenizerFast' object has no attribute 'encode_batch'