# Frozen, do not modify

In [46]:
import argparse
import logging
import pprint
import os
import sys
from itertools import chain
from collections import Counter

import torch
import torch.utils.data
import torch.nn.functional as F
import transformers
import datasets
import wandb

from tqdm import tqdm

import class_attention as cat

# Prepare and train a model

In [3]:
class args:
    model = "distilbert-base-uncased"
    test_class_frac = 0.2
    dataset_frac = 0.1
    batch_size = 32
    device = "cuda:2"
    lr = 1e-4
    max_epochs = 4
    normailze_cls = True
    normalize_txt = True
    scale_attention = False
    freeze_cls_network = False
    learn_temperature = True
    use_n_projection_layers = 1
    hidden = 128


In [4]:
(
    train_dataloader,
    test_dataloader,
    all_classes_str,
    test_classes_str,
) = cat.training_utils.prepare_dataloaders(args.test_class_frac, args.batch_size, args.model, args.dataset_frac)

Using custom data configuration default
Reusing dataset news_category (/home/vlialin/.cache/huggingface/datasets/news_category/default/0.0.0/737b7b6dff469cbba49a6202c9e94f9d39da1fed94e13170cf7ac4b61a75fb9c)


Moving the following classes to a class-test set: ['MEDIA', 'DIVORCE', 'QUEER VOICES', 'FOOD & DRINK', 'WELLNESS', 'HOME & LIVING', 'WORLDPOST', 'SPORTS']


HBox(children=(FloatProgress(value=0.0, description='Preprocessing Dataset', max=12214.0, style=ProgressStyle(…




HBox(children=(FloatProgress(value=0.0, description='Preprocessing Dataset', max=12214.0, style=ProgressStyle(…




In [5]:
text_encoder = transformers.AutoModel.from_pretrained(args.model)
label_encoder = transformers.AutoModel.from_pretrained(args.model)

In [6]:
model = cat.ClassAttentionModel(
    text_encoder,
    label_encoder,
    *vars(args),
)
model = model.to(args.device)

In [7]:
parameters = model.get_trainable_parameters()
optimizer = torch.optim.Adam(parameters, lr=args.lr)

In [9]:
config = {k: v for k, v in vars(args).items() if not k.startswith("_")}

wandb.init(project="class_attention", config=config, tags=["notebook"])
wandb.watch(model, log="all")

[34m[1mwandb[0m: Currently logged in as: [33mguitaricet[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.20 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[<wandb.wandb_torch.TorchGraph at 0x7fc0e46f98e0>]

In [10]:
model = cat.training_utils.train_cat_model(
    model,
    optimizer,
    train_dataloader,
    test_dataloader,
    all_classes_str,
    test_classes_str,
    args.max_epochs,
    args.device,
)

HBox(children=(FloatProgress(value=0.0, description='Epochs', max=4.0, style=ProgressStyle(description_width='…




# Error analysis

In [17]:
text_tokenizer = test_dataloader.dataset.text_tokenizer
label_tokenizer = test_dataloader.dataset.label_tokenizer

_t, _c = test_dataloader.dataset[4]
text_tokenizer.decode(_t), label_tokenizer.decode(_c)

('[CLS] former cia chief smacks down donald trump in clinton endorsement [SEP]',
 '[CLS] politics [SEP]')

In [21]:
set(all_classes_str).difference(test_classes_str)

{'ARTS',
 'ARTS & CULTURE',
 'BLACK VOICES',
 'BUSINESS',
 'COLLEGE',
 'COMEDY',
 'CRIME',
 'CULTURE & ARTS',
 'EDUCATION',
 'ENTERTAINMENT',
 'ENVIRONMENT',
 'FIFTY',
 'GOOD NEWS',
 'GREEN',
 'HEALTHY LIVING',
 'IMPACT',
 'LATINO VOICES',
 'MONEY',
 'PARENTING',
 'PARENTS',
 'POLITICS',
 'RELIGION',
 'SCIENCE',
 'STYLE',
 'STYLE & BEAUTY',
 'TASTE',
 'TECH',
 'THE WORLDPOST',
 'TRAVEL',
 'WEDDINGS',
 'WEIRD NEWS',
 'WOMEN',
 'WORLD NEWS'}

In [18]:
test_classes_str

['MEDIA',
 'FOOD & DRINK',
 'WELLNESS',
 'HOME & LIVING',
 'QUEER VOICES',
 'DIVORCE',
 'SPORTS',
 'WORLDPOST']

In [31]:
news_dataset = datasets.load_dataset("Fraser/news-category-dataset")
_, only_test_classes_data = cat.utils.split_classes(news_dataset["validation"], test_classes=test_classes_str)

Using custom data configuration default
Reusing dataset news_category (/home/vlialin/.cache/huggingface/datasets/news_category/default/0.0.0/737b7b6dff469cbba49a6202c9e94f9d39da1fed94e13170cf7ac4b61a75fb9c)


In [32]:
otc_dataset = cat.CatDataset(
    only_test_classes_data["headline"],
    text_tokenizer,
    only_test_classes_data["category"],
    label_tokenizer,
)

test_classes_ids = label_tokenizer.batch_encode_plus(
    test_classes_str,
    return_tensors="pt",
    add_special_tokens=True,
    padding=True,
)["input_ids"]

otc_collator = cat.CatTestCollator(
    possible_labels_ids=test_classes_ids, pad_token_id=label_tokenizer.pad_token_id
)

otc_dataloader = torch.utils.data.DataLoader(otc_dataset, collate_fn=otc_collator, shuffle=False, pin_memory=True)

HBox(children=(FloatProgress(value=0.0, description='Preprocessing Dataset', max=2463.0, style=ProgressStyle(d…




In [36]:
metrics = cat.utils.evaluate_model_per_class(
    model,
    otc_dataloader,
    device=args.device,
    labels_str=test_classes_str,
    zeroshot_labels=test_classes_str,
)

metrics

{'acc': tensor(0.5989, device='cuda:2'),
 'P/MEDIA': 0.34239130416174385,
 'R/MEDIA': 0.4632352937770329,
 'F1/MEDIA': 0.3937499508789124,
 'P/FOOD & DRINK': 0.5472027971071324,
 'R/FOOD & DRINK': 0.9399399396576756,
 'F1/FOOD & DRINK': 0.691712660516593,
 'P/WELLNESS': 0.8706047818747392,
 'R/WELLNESS': 0.6931690928675063,
 'F1/WELLNESS': 0.7718203994253021,
 'P/HOME & LIVING': 0.23280423268105596,
 'R/HOME & LIVING': 0.2046511626955111,
 'F1/HOME & LIVING': 0.21782173227748416,
 'P/QUEER VOICES': 0.6776315785015582,
 'R/QUEER VOICES': 0.3121212120266299,
 'F1/QUEER VOICES': 0.42738584875777363,
 'P/DIVORCE': 0.4508928569415657,
 'R/DIVORCE': 0.5489130431799386,
 'F1/DIVORCE': 0.4950979894535806,
 'P/SPORTS': 0.5357142855017006,
 'R/SPORTS': 0.5421686744810568,
 'F1/SPORTS': 0.5389221054752815,
 'P/WORLDPOST': 0.541899441038045,
 'R/WORLDPOST': 0.7886178855377091,
 'F1/WORLDPOST': 0.6423840572540714,
 'R_zero_shot': 0.5616020380278826,
 'P_zero_shot': 0.5248926597259426,
 'F1_zero_sho

In [43]:
random_model = cat.ClassAttentionModel(
    transformers.AutoModel.from_config(transformers.AutoConfig.from_pretrained(args.model)),
    transformers.AutoModel.from_config(transformers.AutoConfig.from_pretrained(args.model)),
)

metrics_random = cat.utils.evaluate_model_per_class(
    random_model,
    otc_dataloader,
    device=args.device,
    labels_str=test_classes_str,
    zeroshot_labels=test_classes_str,
)

metrics_random

{'acc': tensor(0.0966, device='cuda:2'),
 'P/MEDIA': 0.0,
 'R/MEDIA': 0.0,
 'F1/MEDIA': 0.0,
 'P/FOOD & DRINK': 0.0,
 'R/FOOD & DRINK': 0.0,
 'F1/FOOD & DRINK': 0.0,
 'P/WELLNESS': 0.43181818149104684,
 'R/WELLNESS': 0.06382978722689477,
 'F1/WELLNESS': 0.11121948973421036,
 'P/HOME & LIVING': 0.06586826343361182,
 'R/HOME & LIVING': 0.05116279067387777,
 'F1/HOME & LIVING': 0.05759157379599062,
 'P/QUEER VOICES': 0.24999999375000015,
 'R/QUEER VOICES': 0.0030303030293847566,
 'F1/QUEER VOICES': 0.005988021581986607,
 'P/DIVORCE': 0.0803921568588043,
 'R/DIVORCE': 0.8913043473416824,
 'F1/DIVORCE': 0.14748199919744218,
 'P/SPORTS': 0.0,
 'R/SPORTS': 0.0,
 'F1/SPORTS': 0.0,
 'P/WORLDPOST': 0.05102040811120367,
 'R/WORLDPOST': 0.040650406471015935,
 'F1/WORLDPOST': 0.04524881937721648,
 'R_zero_shot': 0.13124720434285697,
 'P_zero_shot': 0.10988737545558334,
 'F1_zero_shot': 0.04594123796085578}

In [56]:
bert_model = cat.ClassAttentionModel(
    transformers.AutoModel.from_pretrained(args.model),
    transformers.AutoModel.from_pretrained(args.model),
)

metrics_bert = cat.utils.evaluate_model_per_class(
    bert_model,
    otc_dataloader,
    device=args.device,
    labels_str=test_classes_str,
    zeroshot_labels=test_classes_str,
)

metrics_bert



{'acc': tensor(0.1518, device='cuda:2'),
 'P/MEDIA': 0.0,
 'R/MEDIA': 0.0,
 'F1/MEDIA': 0.0,
 'P/FOOD & DRINK': 0.16327543423507318,
 'R/FOOD & DRINK': 0.9879879876912948,
 'F1/FOOD & DRINK': 0.2802384764860978,
 'P/WELLNESS': 0.0,
 'R/WELLNESS': 0.0,
 'F1/WELLNESS': 0.0,
 'P/HOME & LIVING': 0.10044642854900748,
 'R/HOME & LIVING': 0.20930232548404543,
 'F1/HOME & LIVING': 0.13574656246916003,
 'P/QUEER VOICES': 0.0,
 'R/QUEER VOICES': 0.0,
 'F1/QUEER VOICES': 0.0,
 'P/DIVORCE': 0.0,
 'R/DIVORCE': 0.0,
 'F1/DIVORCE': 0.0,
 'P/SPORTS': 0.0,
 'R/SPORTS': 0.0,
 'F1/SPORTS': 0.0,
 'P/WORLDPOST': 0.0,
 'R/WORLDPOST': 0.0,
 'F1/WORLDPOST': 0.0,
 'R_zero_shot': 0.14966128914691754,
 'P_zero_shot': 0.03296523284801008,
 'F1_zero_shot': 0.05199812986940723}

In [48]:
class_counts = Counter(otc_dataloader.dataset.labels)
class_counts

Counter({'SPORTS': 249,
         'MEDIA': 136,
         'WELLNESS': 893,
         'QUEER VOICES': 330,
         'DIVORCE': 184,
         'FOOD & DRINK': 333,
         'WORLDPOST': 123,
         'HOME & LIVING': 215})

In [55]:
{k: v / len(otc_dataloader.dataset) for k, v in class_counts.most_common()}

{'WELLNESS': 0.3625659764514819,
 'FOOD & DRINK': 0.13520097442143728,
 'QUEER VOICES': 0.13398294762484775,
 'SPORTS': 0.10109622411693057,
 'HOME & LIVING': 0.08729192042224929,
 'DIVORCE': 0.07470564352415753,
 'MEDIA': 0.055217214778725134,
 'WORLDPOST': 0.049939098660170524}

## Result

Trained class attention model, **when evaluated on zero-shot classes only**, is significantly better than a random baseline or a constant baseline

* Random accuracy: 0.10
* BERT without fine-tuning: 0.15
* Best constant accuracy: 0.36
* Model: 0.60

In [57]:
all_classes_ids = label_tokenizer.batch_encode_plus(
    all_classes_str,
    return_tensors="pt",
    add_special_tokens=True,
    padding=True,
)["input_ids"]

all_classes_collator = cat.CatTestCollator(
    possible_labels_ids=all_classes_ids, pad_token_id=label_tokenizer.pad_token_id
)

otc_dataloader_all_classes = torch.utils.data.DataLoader(otc_dataset, collate_fn=all_classes_collator, shuffle=False, pin_memory=True)

In [60]:
cat.utils.evaluate_model_per_class(
    model,
    otc_dataloader_all_classes,
    device=args.device,
    labels_str=all_classes_str,
    zeroshot_labels=test_classes_str,
)

{'acc': tensor(0.0171, device='cuda:2'),
 'P/TRAVEL': 0.0,
 'R/TRAVEL': 0.0,
 'F1/TRAVEL': 0.0,
 'P/FOOD & DRINK': 0.0,
 'R/FOOD & DRINK': 0.0,
 'F1/FOOD & DRINK': 0.0,
 'P/IMPACT': 0.0,
 'R/IMPACT': 0.0,
 'F1/IMPACT': 0.0,
 'P/BLACK VOICES': 0.0,
 'R/BLACK VOICES': 0.0,
 'F1/BLACK VOICES': 0.0,
 'P/BUSINESS': 0.0,
 'R/BUSINESS': 0.0,
 'F1/BUSINESS': 0.0,
 'P/POLITICS': 0.0,
 'R/POLITICS': 0.0,
 'F1/POLITICS': 0.0,
 'P/SPORTS': 0.9285714219387756,
 'R/SPORTS': 0.05220883532039806,
 'F1/SPORTS': 0.09885930543451643,
 'P/FIFTY': 0.0,
 'R/FIFTY': 0.0,
 'F1/FIFTY': 0.0,
 'P/PARENTS': 0.0,
 'R/PARENTS': 0.0,
 'F1/PARENTS': 0.0,
 'P/MEDIA': 0.4999999750000013,
 'R/MEDIA': 0.0073529411710640145,
 'F1/MEDIA': 0.014492750745642099,
 'P/GREEN': 0.0,
 'R/GREEN': 0.0,
 'F1/GREEN': 0.0,
 'P/ARTS & CULTURE': 0.0,
 'R/ARTS & CULTURE': 0.0,
 'F1/ARTS & CULTURE': 0.0,
 'P/WELLNESS': 0.9999999500000026,
 'R/WELLNESS': 0.0022396416570840266,
 'F1/WELLNESS': 0.0044692732960894075,
 'P/QUEER VOICES': 0.999

In [62]:
metrics["P/SPORTS"], metrics["R/SPORTS"], metrics["F1/SPORTS"]

(0.5357142855017006, 0.5421686744810568, 0.5389221054752815)