<a href="https://colab.research.google.com/github/JacopoMangiavacchi/SBERT-ZSC/blob/main/ZeroShotClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers


In [2]:
sentence = 'Who are you voting for in 2020?'
labels = ['business', 'art & culture', 'politics']

# Test HuggingFace Zero Shot Classification Pipeline

In [3]:
from transformers import pipeline

classifier = pipeline('zero-shot-classification')

No model was supplied, defaulted to facebook/bart-large-mnli (https://huggingface.co/facebook/bart-large-mnli)



In [4]:
classes = classifier(sentence, labels)
classes

{'labels': ['politics', 'business', 'art & culture'],
 'scores': [0.9604312181472778, 0.020186087116599083, 0.019382672384381294],
 'sequence': 'Who are you voting for in 2020?'}
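
The pipeline returns `labels` and `scores` as parallel lists. A minimal sketch of pairing them up and picking the best label follows; `rank_labels` is a hypothetical helper name, and the dict literal simply copies the output above.

```python
def rank_labels(result):
    """Pair each label with its score, highest score first."""
    pairs = zip(result["labels"], result["scores"])
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)

# Copied from the pipeline output shown above.
result = {
    "labels": ["politics", "business", "art & culture"],
    "scores": [0.9604312181472778, 0.020186087116599083, 0.019382672384381294],
    "sequence": "Who are you voting for in 2020?",
}

ranked = rank_labels(result)
best_label, best_score = ranked[0]
print(best_label, round(best_score, 4))
```

In this output the three scores add up to roughly 1, so they can be read as a distribution over the candidate labels.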

# Test with BERT Sentence Embedding

In [5]:
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from torch.nn import functional as F
import torch
from scipy import spatial

In [6]:
def ZeroShotClassification(sentence, labels, tokenizer, model, template = "{}"):
  print(f"Template: {template}")
  print(f"Sentence: {sentence}\n")

  for label in labels:
    template_label = template.format(label)
    # run the (sentence, label hypothesis) pair through a model pre-trained on MNLI
    x = tokenizer.encode(sentence, template_label, return_tensors='pt',
                         add_special_tokens = True, padding = True,
                         truncation="do_not_truncate")
    logits = model(x)[0]

    # we throw away "neutral" (index 1) and take the probability of
    # "entailment" (index 2) as the probability of the label being true;
    # this assumes the class order is contradiction, neutral, entailment
    entail_contradiction_logits = logits[:,[0,2]]
    probs = entail_contradiction_logits.softmax(dim=1)
    prob_label_is_true = probs[:,1]
    print(f" {prob_label_is_true.item()} {label}")
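
Dropping the neutral logit and applying a softmax to the remaining two is the same as taking a sigmoid of the difference between the entailment and contradiction logits. A sketch of that arithmetic in plain Python follows; `entailment_probability` is a hypothetical name and the logits are made up for illustration.

```python
import math

def entailment_probability(logits):
    """Probability of entailment after discarding the neutral logit.

    `logits` is ordered [contradiction, neutral, entailment], the order
    the function above assumes.
    """
    contradiction, _neutral, entailment = logits
    # Two-way softmax over (contradiction, entailment); subtracting the
    # larger logit first keeps the exponentials numerically stable.
    m = max(contradiction, entailment)
    exp_c = math.exp(contradiction - m)
    exp_e = math.exp(entailment - m)
    return exp_e / (exp_c + exp_e)

# Hypothetical logits, chosen only for illustration.
example_logits = [-1.0, 0.5, 2.0]
p = entailment_probability(example_logits)

# The same value written as a sigmoid of (entailment - contradiction).
sigmoid = 1.0 / (1.0 + math.exp(-(example_logits[2] - example_logits[0])))
print(p, sigmoid)
```

Because each label is scored on its own, the probabilities for different labels do not have to sum to 1.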

## Test with simple BERT Sentence Embedding and cosine similarity

In [7]:
# tokenizer = AutoTokenizer.from_pretrained('deepset/sentence_bert')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('deepset/sentence_bert')


Some weights of the model checkpoint at deepset/sentence_bert were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
# tokenize the sentence and the labels as one padded batch;
# the model call and the mean-pooling happen in the next cell
inputs = tokenizer.batch_encode_plus([sentence] + labels,
                                     return_tensors='pt',
                                     padding=True)



In [9]:
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
output = model(input_ids, attention_mask=attention_mask)[0]
sentence_rep = output[:1].mean(dim=1)
label_reps = output[1:].mean(dim=1)
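
Note that `.mean(dim=1)` averages over every position, including the padding tokens added to the shorter inputs. A common alternative is to weight the average by the attention mask. The sketch below shows the idea in NumPy on toy arrays; `masked_mean_pool` is a hypothetical helper that is not part of this notebook, and the similarities printed in this notebook were computed with the unmasked mean.

```python
import numpy as np

def masked_mean_pool(token_embeddings, attention_mask):
    """Average token vectors per sequence, ignoring padded positions.

    token_embeddings: (batch, seq_len, hidden)
    attention_mask:   (batch, seq_len), 1 for real tokens and 0 for padding
    """
    mask = attention_mask[:, :, None].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # avoid division by zero
    return summed / counts

# Toy batch: 2 sequences, 3 positions, hidden size 2.
embeddings = np.array([
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last position is padding
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],      # no padding
])
mask = np.array([[1, 1, 0], [1, 1, 1]])

pooled = masked_mean_pool(embeddings, mask)
print(pooled)
```

The same arithmetic carries over to torch tensors by multiplying the model output by the mask expanded along the hidden dimension.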

In [10]:
# find the highest cosine similarities between sentences
print(F.cosine_similarity(sentence_rep[0], label_reps[0], dim=0))
print(F.cosine_similarity(sentence_rep[0], label_reps[1], dim=0))
print(F.cosine_similarity(sentence_rep[0], label_reps[2], dim=0))

tensor(0.0045, grad_fn=<DivBackward0>)
tensor(-0.0274, grad_fn=<DivBackward0>)
tensor(0.2156, grad_fn=<DivBackward0>)


In [11]:
sentence_embedding = sentence_rep.detach().numpy()
label_embedding = label_reps.detach().numpy()

print(1 - spatial.distance.cosine(sentence_embedding[0], label_embedding[0]))
print(1 - spatial.distance.cosine(sentence_embedding[0], label_embedding[1]))
print(1 - spatial.distance.cosine(sentence_embedding[0], label_embedding[2]))


0.0045241788029670715
-0.02739689312875271
0.21561525762081146


In [12]:
similarities = F.cosine_similarity(sentence_rep, label_reps)
closest = similarities.argsort(descending=True)
for ind in closest:
  print(f'label: {labels[ind]} \t similarity: {similarities[ind]}')

label: politics 	 similarity: 0.21561525762081146
label: business 	 similarity: 0.004524169024080038
label: art & culture 	 similarity: -0.02739688940346241


In [13]:
similarities, F.softmax(similarities, dim=0)


(tensor([ 0.0045, -0.0274,  0.2156], grad_fn=<DivBackward0>),
 tensor([0.3121, 0.3023, 0.3855], grad_fn=<SoftmaxBackward>))
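
Cosine similarities lie in [-1, 1], so a plain softmax over them gives a fairly flat distribution, as the output above shows. Dividing by a temperature below 1 before the softmax sharpens it. A sketch in plain Python follows; the temperature of 0.05 is an arbitrary assumption for illustration, not a tuned value.

```python
import math

def softmax(values, temperature=1.0):
    """Softmax over a list of floats, with optional temperature scaling."""
    scaled = [v / temperature for v in values]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Rounded cosine similarities from the cell above, in label order:
# business, art & culture, politics.
sims = [0.0045, -0.0274, 0.2156]

plain = softmax(sims)
sharp = softmax(sims, temperature=0.05)
print([round(p, 4) for p in plain])
print([round(p, 4) for p in sharp])
```

The ranking of the labels is unchanged by the temperature; only the spread of the probabilities changes.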

## Test augmenting labels with static embedding neighbors (GloVe) and BERT Sentence Embedding with cosine similarity

In [14]:
import torchtext.vocab

glove = torchtext.vocab.GloVe(name='6B', dim=100)
print(f"{len(glove.itos)} words in dictionary")



400000 words in dictionary


In [15]:
def get_vector(embeddings, w):
  return embeddings.vectors[embeddings.stoi[w]]

def closest_words(embeddings, vector, n=10):
  distances = [(w, torch.dist(vector, get_vector(embeddings, w)).item()) for w in embeddings.itos]
  return sorted(distances, key = lambda w: w[1])[:n]
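
`closest_words` computes one distance per vocabulary word in a Python loop, which is slow for a 400000-word vocabulary. The same Euclidean nearest-neighbour search can be expressed as one vectorized operation. The sketch below uses NumPy on a toy vocabulary; `closest_words_vectorized` and the toy vectors are hypothetical and only illustrate the idea.

```python
import numpy as np

def closest_words_vectorized(words, vectors, query, n=10):
    """Return the n (word, distance) pairs nearest to `query` (Euclidean)."""
    # One distance per row of `vectors`, computed in a single call.
    distances = np.linalg.norm(vectors - query, axis=1)
    order = np.argsort(distances)[:n]
    return [(words[i], float(distances[i])) for i in order]

# Toy 2-d "embeddings", made up for illustration.
words = ["politics", "political", "museum", "company"]
vectors = np.array([
    [1.0, 0.0],
    [0.9, 0.1],
    [-1.0, 0.5],
    [0.0, -1.0],
])

neighbours = closest_words_vectorized(words, vectors, vectors[0], n=2)
print(neighbours)
```

As in the original function, the query word itself comes back first with distance 0.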

In [16]:
closest_words(glove, get_vector(glove, 'politics'))

[('politics', 0.0),
 ('political', 3.8383750915527344),
 ('debate', 4.631179332733154),
 ('matters', 4.661602973937988),
 ('influence', 4.729617118835449),
 ('culture', 4.731587886810303),
 ('rather', 4.750455856323242),
 ('history', 4.752238750457764),
 ('politicians', 4.768784999847412),
 ('matter', 4.817280292510986)]

In [17]:
labels_neighbours = [closest_words(glove, get_vector(glove, word)) for label in labels for word in label.split(' & ')]
labels_neighbours

[[('business', 0.0),
  ('industry', 3.5567009449005127),
  ('businesses', 3.84977126121521),
  ('marketing', 3.870338201522827),
  ('corporate', 3.901237726211548),
  ('enterprise', 4.052821636199951),
  ('companies', 4.098732948303223),
  ('company', 4.115787982940674),
  ('well', 4.250703811645508),
  ('commercial', 4.251638889312744)],
 [('art', 0.0),
  ('arts', 3.688779592514038),
  ('museum', 3.934798240661621),
  ('sculpture', 4.103562355041504),
  ('works', 4.126135349273682),
  ('photography', 4.151274681091309),
  ('contemporary', 4.155360221862793),
  ('painting', 4.276235103607178),
  ('gallery', 4.385191440582275),
  ('collection', 4.4654622077941895)],
 [('culture', 0.0),
  ('cultural', 3.783661127090454),
  ('tradition', 4.208914279937744),
  ('traditions', 4.227012634277344),
  ('cultures', 4.243590831756592),
  ('civilization', 4.2488861083984375),
  ('society', 4.413925647735596),
  ('history', 4.420716285705566),
  ('religion', 4.51834774017334),
  ('context', 4.55044

In [18]:
labels_neighbours = []
for l in range(len(labels)):
  neighbours = []
  for word in labels[l].split(' & '):
    neighbours.extend([n[0] for n in closest_words(glove, get_vector(glove, word))])
  labels_neighbours.append(neighbours)

labels_neighbours

[['business',
  'industry',
  'businesses',
  'marketing',
  'corporate',
  'enterprise',
  'companies',
  'company',
  'well',
  'commercial'],
 ['art',
  'arts',
  'museum',
  'sculpture',
  'works',
  'photography',
  'contemporary',
  'painting',
  'gallery',
  'collection',
  'culture',
  'cultural',
  'tradition',
  'traditions',
  'cultures',
  'civilization',
  'society',
  'history',
  'religion',
  'context'],
 ['politics',
  'political',
  'debate',
  'matters',
  'influence',
  'culture',
  'rather',
  'history',
  'politicians',
  'matter']]

In [19]:
labels_sentences = [' & '.join(neighbor) for neighbor in labels_neighbours]
labels_sentences

['business & industry & businesses & marketing & corporate & enterprise & companies & company & well & commercial',
 'art & arts & museum & sculpture & works & photography & contemporary & painting & gallery & collection & culture & cultural & tradition & traditions & cultures & civilization & society & history & religion & context',
 'politics & political & debate & matters & influence & culture & rather & history & politicians & matter']

In [20]:
# run inputs through model and mean-pool over the sequence
# dimension to get sequence-level representations
inputs = tokenizer.batch_encode_plus([sentence] + labels_sentences,
                                     return_tensors='pt',
                                     padding=True)

input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
output = model(input_ids, attention_mask=attention_mask)[0]
sentence_rep = output[:1].mean(dim=1)
label_reps = output[1:].mean(dim=1)

similarities = F.cosine_similarity(sentence_rep, label_reps)
closest = similarities.argsort(descending=True)
for ind in closest:
  print(f'label: {labels[ind]} \t similarity: {similarities[ind]}')



label: politics 	 similarity: 0.19136427342891693
label: business 	 similarity: 0.011012410745024681
label: art & culture 	 similarity: -0.04034051671624184


## Test with MNLI Sequence Classification large (BART-LARGE-MNLI) with simple labels

In [21]:
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

ZeroShotClassification(sentence, labels, tokenizer, model)

Template: {}
Sentence: Who are you voting for in 2020?

 0.2593141198158264 business
 0.06867872178554535 art & culture
 0.7672092914581299 politics


## Test BART MNLI Sequence Classification with augmented labels by static embedding neighbors (GloVe)

In [22]:
ZeroShotClassification(sentence, labels_sentences, tokenizer, model)

Template: {}
Sentence: Who are you voting for in 2020?

 0.7021744251251221 business & industry & businesses & marketing & corporate & enterprise & companies & company & well & commercial
 0.9785778522491455 art & arts & museum & sculpture & works & photography & contemporary & painting & gallery & collection & culture & cultural & tradition & traditions & cultures & civilization & society & history & religion & context
 0.9197666049003601 politics & political & debate & matters & influence & culture & rather & history & politicians & matter
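
The helper scores every label independently, so several labels can receive a high probability at once, as above. The Hugging Face pipeline, in its default single-label mode, instead applies a softmax to the entailment logits across all candidate labels, which forces the scores to sum to 1. A sketch of that normalisation in plain Python follows; `normalize_across_labels` is a hypothetical name and the logits are made up, since the helper only prints probabilities.

```python
import math

def normalize_across_labels(entailment_logits):
    """Softmax over the entailment logits of all candidate labels."""
    m = max(entailment_logits.values())  # for numerical stability
    exps = {label: math.exp(v - m) for label, v in entailment_logits.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}

# Hypothetical entailment logits, one per candidate label.
logits = {"business": 0.5, "art & culture": 1.5, "politics": 3.0}
scores = normalize_across_labels(logits)
print(scores)
```

With this normalisation the labels compete with each other, which is why the pipeline output and the per-label helper output are not directly comparable.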


In [23]:
labels_extended = ['foreign policy', 'Europe', 'elections & vote', 'business & industry & businesses & marketing', '2020', 'outdoor recreation', 'politics & political & debate & matters & influence & culture & rather & history & politicians & matter', 'art & culture & arts & museum & sculpture & photography & painting & gallery']

ZeroShotClassification(sentence, labels_extended, tokenizer, model)

Template: {}
Sentence: Who are you voting for in 2020?

 0.5547330975532532 foreign policy
 0.038544122129678726 Europe
 0.88347327709198 elections & vote
 0.04872370883822441 business & industry & businesses & marketing
 0.9925205707550049 2020
 0.16658049821853638 outdoor recreation
 0.9197666049003601 politics & political & debate & matters & influence & culture & rather & history & politicians & matter
 0.13452927768230438 art & culture & arts & museum & sculpture & photography & painting & gallery


## Test with BERT MNLI Sequence Classification small (DistilBert-Uncased-MNLI) with simple labels

In [24]:
tokenizer = AutoTokenizer.from_pretrained('textattack/distilbert-base-uncased-MNLI')
model = AutoModelForSequenceClassification.from_pretrained('textattack/distilbert-base-uncased-MNLI')

ZeroShotClassification(sentence, labels, tokenizer, model)


Template: {}
Sentence: Who are you voting for in 2020?

 0.14961938560009003 business
 0.6417831778526306 art & culture
 0.8972325921058655 politics


In [25]:
labels_extended = ['foreign policy', 'Europe', 'elections & vote', 'business & industry & businesses & marketing', '2020', 'outdoor recreation', 'politics & political & debate & matters & influence & culture & rather & history & politicians & matter', 'art & culture & arts & museum & sculpture & photography & painting & gallery']

ZeroShotClassification(sentence, labels_extended, tokenizer, model)

Template: {}
Sentence: Who are you voting for in 2020?

 0.846084475517273 foreign policy
 0.7868844270706177 Europe
 0.9221031069755554 elections & vote
 0.9136581420898438 business & industry & businesses & marketing
 0.6420759558677673 2020
 0.5007471442222595 outdoor recreation
 0.6367895603179932 politics & political & debate & matters & influence & culture & rather & history & politicians & matter
 0.7028874754905701 art & culture & arts & museum & sculpture & photography & painting & gallery


In [26]:
tokenizer = AutoTokenizer.from_pretrained('ishan/distilbert-base-uncased-mnli')
model = AutoModelForSequenceClassification.from_pretrained('ishan/distilbert-base-uncased-mnli')

ZeroShotClassification(sentence, labels, tokenizer, model)


Template: {}
Sentence: Who are you voting for in 2020?

 0.09194833785295486 business
 0.8147938251495361 art & culture
 0.8844190835952759 politics
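
The helper hard-codes index 0 as contradiction and index 2 as entailment. Whether a given checkpoint uses that order is not checked anywhere in this notebook, and a different order would make the printed scores misleading. In transformers the mapping is exposed as `model.config.id2label`. The sketch below looks up a class index from such a mapping; `find_label_index` is a hypothetical helper, and the example mappings are made up rather than read from the checkpoints used here.

```python
def find_label_index(id2label, prefix):
    """Return the class index whose label name starts with `prefix`.

    Returns None when no label matches, for example when the checkpoint
    only provides generic names such as LABEL_0.
    """
    for index, name in id2label.items():
        if name.lower().startswith(prefix.lower()):
            return int(index)
    return None

# Made-up mappings in the shape of `model.config.id2label`.
named = {0: "contradiction", 1: "neutral", 2: "entailment"}
generic = {0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"}

print(find_label_index(named, "entail"))
print(find_label_index(generic, "entail"))
```

When the lookup returns `None`, the class order has to be taken from the model card instead.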


## Test with MNLI Sequence Classification (BART-LARGE-MNLI) with the standard 'bert-base-uncased' tokenizer (DOES NOT WORK; only works with distilbert-base-uncased-mnli)

In [27]:
# tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# model = AutoModelForSequenceClassification.from_pretrained('ishan/distilbert-base-uncased-mnli')
# model = AutoModelForSequenceClassification.from_pretrained('textattack/distilbert-base-uncased-MNLI')

# the 'bert-base-uncased' tokenizer does NOT work with 'facebook/bart-large-mnli'

tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

ZeroShotClassification(sentence, labels, tokenizer, model)

Template: {}
Sentence: Who are you voting for in 2020?

 0.2593141198158264 business
 0.06867872178554535 art & culture
 0.7672092914581299 politics


# Test with proper hypothesis_template "This example is {}."

From the Hugging Face ZeroShotClassificationPipeline documentation:

The template used to turn each label into an NLI-style hypothesis. This template must include a {} or similar syntax for the candidate label to be inserted into the template.

The default template works well in many cases, but it may be worthwhile to experiment with different templates depending on the task setting.
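
In the helper used in this notebook, the template is applied with `str.format`, so each label becomes a full hypothesis sentence before it is paired with the input sentence. A small sketch of that step:

```python
labels = ["business", "art & culture", "politics"]
template = "This example is {}."

# One NLI hypothesis per candidate label.
hypotheses = [template.format(label) for label in labels]
for hypothesis in hypotheses:
    print(hypothesis)
```

With the `"{}"` template used in the earlier cells, the hypothesis is just the bare label, which is not a full sentence.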

## Test with BART-LARGE-MNLI

In [28]:
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

ZeroShotClassification(sentence, labels, tokenizer, model, template="This example is {}.")

Template: This example is {}.
Sentence: Who are you voting for in 2020?

 0.010291516780853271 business
 0.009040906094014645 art & culture
 0.972069501876831 politics


## Test with small distilbert-base-uncased-mnli

In [29]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('textattack/distilbert-base-uncased-MNLI')

ZeroShotClassification(sentence, labels, tokenizer, model, template="This example is {}.")

Template: This example is {}.
Sentence: Who are you voting for in 2020?

 0.44128814339637756 business
 0.7829711437225342 art & culture
 0.9490302205085754 politics


In [30]:
model = AutoModelForSequenceClassification.from_pretrained('ishan/distilbert-base-uncased-mnli')

ZeroShotClassification(sentence, labels, tokenizer, model, template="This example is {}.")

Template: This example is {}.
Sentence: Who are you voting for in 2020?

 0.1918676644563675 business
 0.9734393954277039 art & culture
 0.9768148064613342 politics
