# Using class attention for zero-shot learning

Terrible zero-shot performance for now, but we have a testing setup now.

In [1]:
import copy
import random
from pprint import pprint
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from transformers import AutoModel, AutoTokenizer
import datasets

import wandb
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import class_attention as cat

%load_ext autoreload
%autoreload 2


def detorch(x):
    return x.detach().cpu().numpy()

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


# Data

In [2]:
def split_classes(dataset, p_valid_classes=None, valid_classes=None, class_field_name='category', verbose=False):
    """
    Move random classes to a class-validation set (i.e. meta-test).
    
    All dataset examples with these classes are removed from the original dataset

    Args:
        dataset: datasets.arrow_dataset.Dataset object
        p_valid_classes: 0 < float < 1
        valid_classes: alternative to p_valid_classes, a list of classes to move to the class-validation set
        class_field_name: name of the class field in the dataset
        verbose: log splitted classes info
    
    Returns:
        (train_set, class_validation_set)
        where both objects are ArrowDataset and all validation classes are moved to class_validation_set
    """
    
    if not (p_valid_classes is None) ^ (valid_classes is None):
        raise ValueError("Only one of p_valid_classes or valid_classes should be specified. "
                         f"Got p_valid_classes = {p_valid_classes}\n"
                         f"valid_classes = {valid_classes}")

    if p_valid_classes is not None:
        all_classes = list(set(dataset[class_field_name]))
        n_valid_classes = int(len(all_classes) * p_valid_classes)
        if n_valid_classes == 0:
            raise ValueError(f"p_valid_classes={p_valid_classes} is too small for the dataset with {len(all_classes)} classes.")

        valid_classes = random.sample(all_classes, k=n_valid_classes)
    
    if verbose:
        print(f"Moving the following classes to a class-validation set: {valid_classes}")
    
    valid_mask = [c in valid_classes for c in dataset[class_field_name]]
    train_mask = [not m for m in valid_mask]

    valid_subset = dataset[valid_mask]
    train_subset = dataset[train_mask]
    
    valid_dataset = datasets.arrow_dataset.Dataset.from_dict(valid_subset)
    train_dataset = datasets.arrow_dataset.Dataset.from_dict(train_subset)
    
    return train_dataset, valid_dataset

### A note on labels

The dataset we use (`Fraser/news-category-dataset`) has some interesting particularities in the class names.

For example, it has classes `STYLE` and `STYLE & BEAUTY` or `WORLD NEWS` and `NEWS`. I.e., some classes contain other classes names in their name.
The classes that have `&` in their name have a similar particularity. Some of the categories does not seem to be distinguishable. E.g., `THE WORLDPOST` and `WORLDPOST` or `ARTS & CULTURE` and `CULTURE & ARTS`.



* &	: STYLE & BEAUTY, ARTS & CULTURE, HOME & LIVING, FOOD & DRINK, CULTURE & ARTS
* VOICES	: LATINO VOICES, BLACK VOICES, QUEER VOICES
* NEWS	: WEIRD NEWS, GOOD NEWS, WORLD NEWS
* ARTS	: ARTS, ARTS & CULTURE, CULTURE & ARTS
* CULTURE	: ARTS & CULTURE, CULTURE & ARTS
* LIVING	: HEALTHY LIVING, HOME & LIVING
* WORLDPOST	: THE WORLDPOST, WORLDPOST
* WORLD	: THE WORLDPOST, WORLDPOST

In [3]:
news_dataset = datasets.load_dataset("Fraser/news-category-dataset")

# some magic is happening here to make a toy dataset that is consistent, read carefuly
train_set = cat.utils.sample_dataset(news_dataset['train'], p=0.1)

all_classes = list(set(news_dataset['train']['category']))
classes_left = list(set(train_set['category']))

valid_set = news_dataset['validation']
if len(all_classes) > len(classes_left):
    _, valid_set = split_classes(valid_set, valid_classes=classes_left)

valid_set = cat.utils.sample_dataset(valid_set, p=0.1)

del news_dataset

# TODO: you need to work on naming

# valid_classes = ['TASTE', 'WELLNESS', 'EDUCATION']
train_reduced_set, _train_remainder = split_classes(train_set, p_valid_classes=0.21, verbose=True)

valid_classes = list(set(_train_remainder['category']))
valid_reduced_set, valid_remainder = split_classes(valid_set, valid_classes=valid_classes)

del train_set

print(f"Total classes: {len(all_classes)}")
print(f"Size of the resulting training set: {len(train_reduced_set)}")
print(f"Size of the resulting validation set for the train classes: {len(valid_reduced_set)}")
print(f"Size of the resulting validation set for the valid classes: {len(valid_remainder)}")

Using custom data configuration default
Reusing dataset news_category (/home/vlialin/.cache/huggingface/datasets/news_category/default/0.0.0/737b7b6dff469cbba49a6202c9e94f9d39da1fed94e13170cf7ac4b61a75fb9c)


Moving the following classes to a class-validation set: ['EDUCATION', 'STYLE & BEAUTY', 'RELIGION', 'SCIENCE', 'PARENTING', 'ARTS & CULTURE', 'ARTS', 'DIVORCE']
Total classes: 41
Size of the resulting training set: 13641
Size of the resulting validation set for the train classes: 843
Size of the resulting validation set for the valid classes: 161


# Test Collator

In [4]:
class CatTestCollator:
    """
    Collates text into batches with a fixed set of labels.

    During inference, we do not know in advance what class we have to predict.
    Thus, using a regular CatCollator that only inputs into the model
    a subset of classes is not possible as it would be cheating.

    The set of possible classes in this collator is defined during initialization
    and all of them are used for every batch.

    Args:
        possible_labels: torch.LongTensor[n_labels, label_len], a matrix of padded label ids
        pad_token_id: paddng token id used for BOTH texts and labels
    """
    def __init__(self, possible_labels: torch.LongTensor, pad_token_id):
        self.possible_labels = possible_labels
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        """
        Collates examples into batches and creates targets for the ``contrastive loss''.

        The main difference with CatCollator is the unique_labels.
        In the case of test collator, it is always equal to self.possible_labels
        and do not depend on the batch labels.

        Args:
            examples: list of tuples (text_seq, label_seq)
                where
                text_seq: LongTensor[text_len,]
                label_seq: LongTensor[label_len,]

        Returns:
            a tuple (batch_x, unique_labels, targets)
                where
                batch_x: LongTensor[batch_size, max_text_len]
                unique_labels: LongTensor[batch_size,] = self.possible_labels
                targets: LongTensor[batch_size,]
        """
        _validate_input(examples)

        batch_size = len(examples)
        max_text_len = max(len(text) for text, label in examples)
        max_label_len = max(len(label) for label in self.possible_labels)  # 1st major difference from CatCollator
        device = examples[0][0].device

        # we construct this tensor only to use it in get_index, we do not return it
        _pre_batch_y = torch.full(
            size=[batch_size, max_label_len],
            fill_value=self.pad_token_id,
            dtype=torch.int64,
            device=device,
        )

        batch_x = torch.full(
            size=[batch_size, max_text_len],
            fill_value=self.pad_token_id,
            dtype=torch.int64,
            device=device,
        )

        for i, (text, label) in enumerate(examples):
            batch_x[i, : len(text)] = text
            
            try:
                _pre_batch_y[i, : len(label)] = label
            except:
                import pdb; pdb.set_trace()
                print(_pre_batch_y.shape)
                print(i)
                print(len(label))
                print(label.shape)

        targets = get_index(self.possible_labels, _pre_batch_y)  # 2nd major difference from CatCollator

        if batch_size != targets.shape[0]:
            import pdb; pdb.set_trace()
            raise RuntimeError(f"Wrong number of targets. Expected {batch_size}, got {targets.shape[0]} instead.")

        return batch_x, self.possible_labels, targets


def get_index(host, target):
    diff = target.unsqueeze(1) - host.unsqueeze(0)
    dsum = torch.abs(diff).sum(-1)
    loc = torch.nonzero(dsum == 0)
    return loc[:, -1]

def _validate_input(examples):
    if not isinstance(examples[0], tuple):
        raise ValueError(examples)

    text_0, label_0 = examples[0]
    if not len(text_0.shape) == 1:
        raise ValueError(
            f"Wrong number of dimensions in the text tensor. "
            f"Expected a rank-one tensor, got rank-{len(text_0.shape)} instead"
        )

    if not len(label_0.shape) == 1:
        raise ValueError(
            f"Wrong number of dimensions in the label tensor. "
            f"Expected a rank-one tensor, got rank-{len(label_0.shape)} instead"
        )

# Model

In [5]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

__all__ = ["MaskPowerNorm"]


def _sum_ft(tensor):
    """sum over the first and last dimention"""
    return tensor.sum(dim=0).sum(dim=-1)


class GroupScaling1D(nn.Module):
    r"""Scales inputs by the second moment for the entire layer."""

    def __init__(self, eps=1e-5, group_num=4):
        super(GroupScaling1D, self).__init__()
        self.eps = eps
        self.group_num = group_num

    def extra_repr(self):
        return f"eps={self.eps}, group={self.group_num}"

    def forward(self, input):
        # calculate second moment
        # different group use different mean
        T, B, C = input.shape[0], input.shape[1], input.shape[2]
        Cg = C // self.group_num
        gn_input = input.contiguous().reshape(T, B, self.group_num, Cg)
        moment2 = (
            torch.repeat_interleave(
                torch.mean(gn_input * gn_input, dim=3, keepdim=True), repeats=Cg, dim=-1
            )
            .contiguous()
            .reshape(T, B, C)
        )
        # divide out second moment
        return input / torch.sqrt(moment2 + self.eps)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


class PowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        running_phi,
        eps,
        afwd,
        abkw,
        ema_gz,
        debug,
        warmup_iters,
        current_iter,
        mask_x,
    ):
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        rmax = 1
        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

        var = x2.reshape(1, C, 1, 1)
        if current_iter <= warmup_iters:
            z = x / (var + eps).sqrt()
        else:
            z = x / (running_phi + eps).sqrt()

        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(
                running_phi * (current_iter - 1) / current_iter
                + var.mean(dim=0, keepdim=True) / current_iter
            )
        running_phi.copy_(afwd * running_phi + (1 - afwd) * var.mean(dim=0, keepdim=True))
        y = weight.reshape(1, C, 1, 1) * y + bias.reshape(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_variables

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)

        approx_grad_g = g - (1 - abkw) * ema_gz * z
        ema_gz.add_(
            (approx_grad_g * z)
            .mean(dim=3, keepdim=True)
            .mean(dim=2, keepdim=True)
            .mean(dim=0, keepdim=True)
        )

        gx = 1.0 / torch.sqrt(var + eps) * approx_grad_g
        return (
            gx,
            (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
            grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class MaskPowerNorm(nn.Module):
    """
    An implementation of masked batch normalization, used for testing the numerical
    stability.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        alpha_fwd=0.9,
        alpha_bkw=0.9,
        affine=True,
        warmup_iters=10000,
        group_num=1,
    ):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.affine = affine

        self.register_parameter("weight", nn.Parameter(torch.ones(num_features)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(num_features)))
        self.register_buffer("running_phi", torch.ones(1, num_features, 1, 1))
        self.register_buffer("ema_gz", torch.zeros(1, num_features, 1, 1))
        self.register_buffer("iters", torch.zeros(1).type(torch.LongTensor))

        self.afwd = alpha_fwd
        self.abkw = alpha_bkw

        self.eps = eps
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, "
            "affine={affine}, warmup={warmup_iters}, group_num={group_num}".format(**self.__dict__)
        )

    def forward(self, input, pad_mask=None, is_encoder=False):
        """
        input:  T x B x C -> B x C x T
             :  B x C x T -> T x B x C
        pad_mask: B x T (padding is True)
        """
        shaped_input = len(input.shape) == 2
        if shaped_input:
            input = input.unsqueeze(0)
        T, B, C = input.shape
        input = self.gp(input)

        # construct the mask_input, size to be (BxL) x C: L is the real length here
        if pad_mask is None:
            mask_input = input.clone()
        else:
            # Transpose the bn_mask (B x T -> T x B)
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)

        if pad_mask is not None:
            pad_size = (~bn_mask).sum()
            mask_input = input[bn_mask, :]
        else:
            mask_input = input.clone()

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1)
            output = PowerFunction.apply(
                input,
                self.weight,
                self.bias,
                self.running_phi,
                self.eps,
                self.afwd,
                self.abkw,
                self.ema_gz,
                self.debug,
                self.warmup_iters,
                self.iters,
                mask_input,
            )

        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1, C, 1, 1) * output + self.bias.reshape(1, C, 1, 1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        # Reshape it.
        if shaped_input:
            output = output.squeeze(0)

        return output


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from class_attention import modelling_utils


def normalize_embeds(embeds):
    return embeds / torch.sqrt(torch.sum(embeds * embeds, dim=1, keepdim=True))


class ClassAttentionModel(nn.Module):
    def __init__(self, txt_encoder, cls_encoder, hidden_size):
        super().__init__()

#         print('Hidden size is not used')

        self.txt_encoder = txt_encoder
        self.cls_encoder = cls_encoder

        txt_encoder_h = modelling_utils.get_output_dim(txt_encoder)
        self.txt_out = nn.Linear(txt_encoder_h, hidden_size)
#         self.txt_out_norm = MaskPowerNorm(hidden_size)

        cls_encoder_h = modelling_utils.get_output_dim(cls_encoder)
        self.cls_out = nn.Linear(cls_encoder_h, hidden_size)
#         self.cls_out_norm = MaskPowerNorm(hidden_size)

    def forward(self, text_input, labels_input):
        """
        Compute logits for input (input_dict,) corresponding to the classes (classes_dict)

        Optionally, you can provide additional keys in either input_dict or classes_dict
        Specifically, attention_mask, head_mask and inputs_embeds
        Howerver, one should not provide output_attentions and output_hidden_states

        Args:
            text_input: dict with key input_ids
                input_ids: LongTensor[batch_size, text_seq_len], input to the text network
            labels_input: dict with key input_ids
                input_ids: LongTensor[n_classes, class_seq_len], a list of possible classes, each class described via text
        """
        text_input, labels_input = modelling_utils.maybe_format_inputs(text_input, labels_input)
        modelling_utils.validate_inputs(text_input, labels_input)

        h_x = self.txt_encoder(**text_input)  # some tuple
        h_x = h_x[0]  # FloatTensor[bs, text_seq_len, hidden]
        h_x = h_x[:, 0]  # get CLS token representations, FloatTensor[bs, hidden]

        h_c = self.cls_encoder(**labels_input)  # some tuple
        h_c = h_c[0]  # FloatTensor[n_classes, class_seq_len, hidden]

        h_c, _ = torch.max(h_c, dim=1)  # [n_classes, hidden]

        # attention map
        h_x = self.txt_out(h_x)
#         h_x = F.dropout(h_x, p=0.5)
#         h_x = self.txt_out_norm(h_x)

        h_c = self.cls_out(h_c)
        h_c = normalize_embeds(h_c)

#         h_c = F.dropout(h_c, p=0.5)
#         h_c = self.txt_out_norm(h_c)

        # the scaling is extremely important
        scaling = h_c.size(-1) ** 0.5
        logits = (h_x @ h_c.T) / scaling  # [bs, n_classes]

        return logits


In [17]:
# Tokenizers
MODEL = 'distilbert-base-uncased'
text_tokenizer = AutoTokenizer.from_pretrained(MODEL, fast=True)
label_tokenizer = AutoTokenizer.from_pretrained(MODEL, fast=True)

# Dataset
dataset = cat.CatDataset(train_reduced_set['headline'], text_tokenizer, train_reduced_set['category'], label_tokenizer)
valid_dataset = cat.CatDataset(valid_set['headline'], text_tokenizer, valid_set['category'], label_tokenizer)
valid_train_dataset = cat.CatDataset(valid_reduced_set['headline'], text_tokenizer, valid_reduced_set['category'], label_tokenizer)
valid_valid_dataset = cat.CatDataset(valid_remainder['headline'], text_tokenizer, valid_remainder['category'], label_tokenizer)

HBox(children=(FloatProgress(value=0.0, description='Preprocessing Dataset', max=13641.0, style=ProgressStyle(…




HBox(children=(FloatProgress(value=0.0, description='Preprocessing Dataset', max=1004.0, style=ProgressStyle(d…




HBox(children=(FloatProgress(value=0.0, description='Preprocessing Dataset', max=843.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Preprocessing Dataset', max=161.0, style=ProgressStyle(de…




In [19]:
label_tokenizer.batch_encode_plus(
    all_classes_str, return_tensors="pt", add_special_tokens=True
)

{'input_ids': tensor([[  101, 20429,   102,     0,     0],
         [  101, 25860,   102,     0,     0],
         [  101,  2308,   102,     0,     0],
         [  101,  2267,   102,     0,     0],
         [  101,  2840,  1004,  3226,   102],
         [  101,  3604,   102,     0,     0],
         [  101,  7402,  5755,   102,     0],
         [  101,  8179,   102,     0,     0],
         [  101,  5595,   102,     0,     0],
         [  101,  1996,  2088, 19894,   102],
         [  101,  2665,   102,     0,     0],
         [  101,  6627,   102,     0,     0],
         [  101,  2449,   102,     0,     0],
         [  101,  4254,   102,     0,     0],
         [  101,  2806,  1004,  5053,   102],
         [  101,  2204,  2739,   102,     0],
         [  101,  2088, 19894,   102,     0],
         [  101,  5510,   102,     0,     0],
         [  101, 19483,  5755,   102,     0],
         [  101,  4024,   102,     0,     0],
         [  101,  6881,  2739,   102,     0],
         [  101,  796

In [8]:
# Train Dataloader
batch_size = 24
collator = cat.CatCollator(pad_token_id=0)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collator, shuffle=True)

# Classes
training_classes_str = set(train_reduced_set['category'])
validation_train_classes_str = set(valid_reduced_set['category'])
validation_valid_classes_str = set(valid_remainder['category'])

assert training_classes_str == validation_train_classes_str

all_classes_str = set(all_classes)  # defined above
assert training_classes_str.issubset(all_classes_str)
assert validation_train_classes_str.issubset(all_classes_str)

In [16]:
all_classes_str

['WEDDINGS',
 'WELLNESS',
 'WOMEN',
 'COLLEGE',
 'ARTS & CULTURE',
 'TRAVEL',
 'LATINO VOICES',
 'DIVORCE',
 'FIFTY',
 'THE WORLDPOST',
 'GREEN',
 'TECH',
 'BUSINESS',
 'IMPACT',
 'STYLE & BEAUTY',
 'GOOD NEWS',
 'WORLDPOST',
 'TASTE',
 'QUEER VOICES',
 'ENTERTAINMENT',
 'WEIRD NEWS',
 'HEALTHY LIVING',
 'RELIGION',
 'MEDIA',
 'SPORTS',
 'WORLD NEWS',
 'EDUCATION',
 'ARTS',
 'STYLE',
 'PARENTING',
 'FOOD & DRINK',
 'ENVIRONMENT',
 'CULTURE & ARTS',
 'HOME & LIVING',
 'COMEDY',
 'BLACK VOICES',
 'SCIENCE',
 'PARENTS',
 'MONEY',
 'CRIME',
 'POLITICS']

In [15]:
label_tokenizer.batch_encode_plus()

<bound method PreTrainedTokenizer.batch_encode_plus of <transformers.tokenization_distilbert.DistilBertTokenizer object at 0x7f458c2b01f0>>

In [9]:
training_classes_str = list(training_classes_str)
training_classes_ids = label_tokenizer.batch_encode_plus(
    training_classes_str,
    return_tensors='pt',
    add_special_tokens=True  # should be consistent with the tokenization used in the CatDataset
)['input_ids']

all_classes_str = list(all_classes_str)
all_classes_ids = label_tokenizer.batch_encode_plus(
    all_classes_str,
    return_tensors='pt',
    add_special_tokens=True
)['input_ids']


test_collator_all_labels = CatTestCollator(
    possible_labels=all_classes_ids,
    pad_token_id=label_tokenizer.pad_token_id
)


test_collator_train_labels = CatTestCollator(
    possible_labels=training_classes_ids,
    pad_token_id=label_tokenizer.pad_token_id
)


# Validation dataloader for the training classes
valid_train_dataloader = torch.utils.data.DataLoader(
    valid_train_dataset,
    batch_size=batch_size,
    collate_fn=test_collator_train_labels,
    shuffle=False,
)

# Validation dataloader for the test classes (zero-shot)
valid_valid_dataloader = torch.utils.data.DataLoader(
    valid_valid_dataset,
    batch_size=batch_size,
    collate_fn=test_collator_all_labels,
    shuffle=False
)

# Validation dataloader for the train classes given all possible classes
valid_train_dataloader_but_all = torch.utils.data.DataLoader(
    valid_train_dataset,
    batch_size=batch_size,
    collate_fn=test_collator_all_labels,
    shuffle=False
)

# Validation dataloader for all classes
valid_dataloader_full = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=batch_size,
    collate_fn=test_collator_all_labels,
    shuffle=False
)

## Train class attention

In [10]:
def validate_model_on_dataloader(model, dataloader, device):
    if "CatTestCollator" not in str(dataloader.collate_fn):
        raise RuntimeError("Validation or test dataloader should have a CatTestCollator instead of CatCollator")

    model = model.to(device)
    model.eval()
    n_correct = 0
    n_total = 0

    with torch.no_grad():
        for x, c, y in dataloader:
            # Note: `c` does not change in CatTestCollator
            x, c, y = x.to(device), c.to(device), y.to(device)

            logits = model(x, c)

            _, preds = logits.max(-1)

            n_correct += torch.sum(preds == y).float()
            n_total += x.shape[0]

    acc = n_correct / n_total

    model.train()

    return acc


_random_text_encoder = transformers.BertModel(transformers.BertConfig(num_hidden_layers=2, intermediate_size=256))
_random_label_encoder = transformers.BertModel(transformers.BertConfig(num_hidden_layers=2, intermediate_size=256))
_random_model = cat.ClassAttentionModel(_random_text_encoder, _random_label_encoder, hidden_size=768)

_acc = validate_model_on_dataloader(_random_model, valid_train_dataloader, device='cuda')
print(_acc)

_acc = validate_model_on_dataloader(_random_model, valid_valid_dataloader, device='cuda')
print(_acc)

tensor(0.0403, device='cuda:0')
tensor(0.0311, device='cuda:0')


In [11]:
from collections import defaultdict


def validate_model_per_class_on_dataloader(model, dataloader, device, labels):
    """
    Args:
        labels: List[str], names of classes, in the same order as in the CatTestCollator.possible_labels
    """
    model = model.to(device)
    model.eval()

    n_correct = 0
    n_total = 0
    label2n_correct = defaultdict(int)
    label2n_predicted = defaultdict(int)
    label2n_expected = defaultdict(int)

    with torch.no_grad():
        for x, c, y in dataloader:
            # Note: `c` does not change in CatTestCollator
            x, c, y = x.to(device), c.to(device), y.to(device)

            logits = model(x, c)

            _, preds = logits.max(-1)

            predicted_labels = [labels[i] for i in preds]
            expected_labels = [labels[i] for i in y]

            for label_pred, label_exp in zip(predicted_labels, expected_labels):
                label2n_predicted[label_pred] += 1
                label2n_expected[label_exp] += 1
                label2n_correct[label_pred] += int(label_pred == label_exp)

            n_correct += torch.sum(preds == y).float()
            n_total += x.shape[0]

    res = {
        "acc": n_correct / n_total,
    }

    for label in label2n_expected.keys():
        label_str = "_".join(label.split(' '))
        p = label2n_correct[label] / (label2n_predicted[label] + 1e-7)
        r = label2n_correct[label] / (label2n_expected[label] + 1e-7)

        res["P/" + label_str] = p
        res["R/" + label_str] = r
        res["F1/" + label_str] = 2 * (p * r) / (p + r + 1e-7)

    model.train()
    return res


In [12]:
_random_text_encoder = transformers.BertModel(transformers.BertConfig(num_hidden_layers=2, intermediate_size=256))
_random_label_encoder = transformers.BertModel(transformers.BertConfig(num_hidden_layers=2, intermediate_size=256))
_random_model = cat.ClassAttentionModel(_random_text_encoder, _random_label_encoder, hidden_size=768)

_acc = validate_model_per_class_on_dataloader(_random_model, valid_dataloader_full, device='cuda', labels=all_classes_str)
print(_acc)

{'acc': tensor(0.0159, device='cuda:0'), 'P/COMEDY': 0.0, 'R/COMEDY': 0.0, 'F1/COMEDY': 0.0, 'P/PARENTING': 0.0, 'R/PARENTING': 0.0, 'F1/PARENTING': 0.0, 'P/ENTERTAINMENT': 0.0, 'R/ENTERTAINMENT': 0.0, 'F1/ENTERTAINMENT': 0.0, 'P/FOOD_&_DRINK': 0.0, 'R/FOOD_&_DRINK': 0.0, 'F1/FOOD_&_DRINK': 0.0, 'P/WORLDPOST': 0.0, 'R/WORLDPOST': 0.0, 'F1/WORLDPOST': 0.0, 'P/HEALTHY_LIVING': 0.0, 'R/HEALTHY_LIVING': 0.0, 'F1/HEALTHY_LIVING': 0.0, 'P/TRAVEL': 0.0, 'R/TRAVEL': 0.0, 'F1/TRAVEL': 0.0, 'P/WOMEN': 0.0, 'R/WOMEN': 0.0, 'F1/WOMEN': 0.0, 'P/BLACK_VOICES': 0.0, 'R/BLACK_VOICES': 0.0, 'F1/BLACK_VOICES': 0.0, 'P/BUSINESS': 0.0, 'R/BUSINESS': 0.0, 'F1/BUSINESS': 0.0, 'P/WEIRD_NEWS': 0.0, 'R/WEIRD_NEWS': 0.0, 'F1/WEIRD_NEWS': 0.0, 'P/SPORTS': 0.0, 'R/SPORTS': 0.0, 'F1/SPORTS': 0.0, 'P/CRIME': 0.0, 'R/CRIME': 0.0, 'F1/CRIME': 0.0, 'P/IMPACT': 0.0, 'R/IMPACT': 0.0, 'F1/IMPACT': 0.0, 'P/POLITICS': 0.0, 'R/POLITICS': 0.0, 'F1/POLITICS': 0.0, 'P/QUEER_VOICES': 0.0, 'R/QUEER_VOICES': 0.0, 'F1/QUEER_VOICES

In [13]:
# Model
text_encoder = AutoModel.from_pretrained(MODEL)
label_encoder = AutoModel.from_pretrained(MODEL)

model = cat.ClassAttentionModel(text_encoder, label_encoder, hidden_size=4096)

x = torch.randint(0, 100, size=[3, 5])
c = torch.unique(torch.randint(0, 100, size=[7, 1])).unsqueeze(1)

out = model(text_input=x, labels_input=c)
assert out.shape == (3, 7)

device = 'cuda'
model = model.to(device)

parameters = chain(model.txt_encoder.parameters(), model.txt_out.parameters(), model.cls_out.parameters())
optimizer = torch.optim.Adam(parameters, lr=1e-4)

In [14]:
test_classes_str = ','.join(valid_classes)
wandb.init(project='class_attention', tags=['notebooks'], notes=' ', config={"test_classes": test_classes_str})
wandb.watch(model)


for _ in tqdm(range(50)):
    for x, c, y in dataloader:
        optimizer.zero_grad()

        x = x.to(device)
        c = c.to(device)
        y = y.to(device)

        x_dict = {'input_ids': x}
        c_dict = {'input_ids': c}
        logits = model(x_dict, c_dict)

        loss = F.cross_entropy(logits, y)

        _, preds = logits.max(-1)
        acc = torch.sum(preds == y).float() / x.shape[0]

        wandb.log({
            'train_acc': acc,
            'loss': loss,
        })

        loss.backward()
        optimizer.step()

    valid_train_acc = validate_model_on_dataloader(model, valid_train_dataloader, device=device)
    valid_valid_acc = validate_model_on_dataloader(model, valid_valid_dataloader, device=device)
    valid_train_acc_given_all = validate_model_on_dataloader(model, valid_train_dataloader_but_all, device=device)
    valid_acc_per_class = validate_model_per_class_on_dataloader(model, valid_dataloader_full, device=device, labels=all_classes_str)

    wandb.log({
        'eval/train_classes_acc': valid_train_acc,
        'eval/train_classes_given_all_acc': valid_train_acc_given_all,
        'eval/valid_classes_acc': valid_valid_acc,
        **{f'eval_c/{k}': v for k, v in valid_acc_per_class.items()},
    })

[34m[1mwandb[0m: Currently logged in as: [33mguitaricet[0m (use `wandb login --relogin` to force relogin)


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




KeyboardInterrupt: 

In [24]:
wandb.join()

VBox(children=(Label(value=' 0.09MB of 0.09MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_acc,0.75
loss,0.59033
_step,2689.0
_runtime,166.0
_timestamp,1611783098.0
validation/train_classes_acc,0.52755
validation/train_classes_given_all_acc,0.52755
validation/valid_classes_acc,0.0036


0,1
train_acc,▁▂▁▃▃▅▄▄▄▂▄▅▅▅▆▅▇▆▅▆▆▅▇▇▆▆▅▇▆▆▇▆█▇▆▇█▆▇▆
loss,█▇▇▆▅▄▄▅▄▆▄▄▃▄▄▄▂▂▃▂▂▄▁▁▃▂▄▂▂▃▂▂▂▁▂▁▁▃▁▂
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
validation/train_classes_acc,▁█▅█▂
validation/train_classes_given_all_acc,▁▆▅█▃
validation/valid_classes_acc,█▂▂▁▂
