<a href="https://colab.research.google.com/github/JacopoMangiavacchi/SBERT-ZSC/blob/main/ZSC-Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
sentence = 'Who are you voting for in 2020?'
labels = ['business', 'art & culture', 'politics']

# Test HuggingFace Zero Shot Classification Pipeline

In [2]:
!pip install transformers

from transformers import pipeline

classifier = pipeline('zero-shot-classification')

Collecting transformers
  Downloading transformers-4.10.2-py3-none-any.whl (2.8 MB)
[K     |████████████████████████████████| 2.8 MB 19.5 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 27.9 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 22.9 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 29.1 MB/s 
Collecting huggingface-hub>=0.0.12
  Downloading huggingface_hub-0.0.16-py3-none-any.whl (50 kB)
[K     |████████████████████████████████| 50 kB 4.9 MB/s 
Installing collected packages: tokenizers, sacremoses, pyyaml, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 

No model was supplied, defaulted to facebook/bart-large-mnli (https://huggingface.co/facebook/bart-large-mnli)


Downloading:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [68]:
?? classifier

In [3]:
classes = classifier(sentence, labels)
classes

{'labels': ['politics', 'business', 'art & culture'],
 'scores': [0.9604312181472778, 0.020186087116599083, 0.019382672384381294],
 'sequence': 'Who are you voting for in 2020?'}

# Test with simple Sentence BERT Embedding mean

In [4]:
from transformers import AutoTokenizer, AutoModel
from torch.nn import functional as F
from scipy import spatial

# tokenizer = AutoTokenizer.from_pretrained('deepset/sentence_bert')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('deepset/sentence_bert')

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of the model checkpoint at deepset/sentence_bert were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# run inputs through model and mean-pool over the sequence
# dimension to get sequence-level representations
inputs = tokenizer.batch_encode_plus([sentence] + labels,
                                     return_tensors='pt',
                                     pad_to_max_length=True)



In [6]:
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
output = model(input_ids, attention_mask=attention_mask)[0]
sentence_rep = output[:1].mean(dim=1)
label_reps = output[1:].mean(dim=1)

In [7]:
# find the highest cosine similarities between sentences
print(F.cosine_similarity(sentence_rep[0], label_reps[0], dim=0))
print(F.cosine_similarity(sentence_rep[0], label_reps[1], dim=0))
print(F.cosine_similarity(sentence_rep[0], label_reps[2], dim=0))

tensor(0.0045, grad_fn=<DivBackward0>)
tensor(-0.0274, grad_fn=<DivBackward0>)
tensor(0.2156, grad_fn=<DivBackward0>)


In [8]:
sentence_embedding = sentence_rep.detach().numpy()
label_embedding = label_reps.detach().numpy()

print(1 - spatial.distance.cosine(sentence_embedding[0], label_embedding[0]))
print(1 - spatial.distance.cosine(sentence_embedding[0], label_embedding[1]))
print(1 - spatial.distance.cosine(sentence_embedding[0], label_embedding[2]))


0.0045241788029670715
-0.02739689312875271
0.21561525762081146


In [9]:
similarities = F.cosine_similarity(sentence_rep, label_reps)
closest = similarities.argsort(descending=True)
for ind in closest:
  print(f'label: {labels[ind]} \t similarity: {similarities[ind]}')

label: politics 	 similarity: 0.21561525762081146
label: business 	 similarity: 0.004524169024080038
label: art & culture 	 similarity: -0.02739688940346241


In [10]:
similarities, F.softmax(similarities)

  """Entry point for launching an IPython kernel.


(tensor([ 0.0045, -0.0274,  0.2156], grad_fn=<DivBackward0>),
 tensor([0.3121, 0.3023, 0.3855], grad_fn=<SoftmaxBackward>))

# Test augmenting labels with static embedding neighbors

In [11]:
import torchtext.vocab

glove = torchtext.vocab.GloVe(name='6B', dim=100)
print(f"{len(glove.itos)} words in dictionary")

.vector_cache/glove.6B.zip: 862MB [02:44, 5.25MB/s]                           
100%|█████████▉| 399999/400000 [00:18<00:00, 22041.90it/s]


400000 words in dictionary


In [12]:
import torch

def get_vector(embeddings, w):
  return embeddings.vectors[embeddings.stoi[w]]

def closest_words(embeddings, vector, n=10):
  distances = [(w, torch.dist(vector, get_vector(embeddings, w)).item()) for w in embeddings.itos]
  return sorted(distances, key = lambda w: w[1])[:n]

In [13]:
closest_words(glove, get_vector(glove, 'politics'))

[('politics', 0.0),
 ('political', 3.8383750915527344),
 ('debate', 4.631179332733154),
 ('matters', 4.661602973937988),
 ('influence', 4.729617118835449),
 ('culture', 4.731587886810303),
 ('rather', 4.750455856323242),
 ('history', 4.752238750457764),
 ('politicians', 4.768784999847412),
 ('matter', 4.817280292510986)]

In [14]:
labels_neighbours = [closest_words(glove, get_vector(glove, word)) for label in labels for word in label.split(' & ')]
labels_neighbours

[[('business', 0.0),
  ('industry', 3.5567009449005127),
  ('businesses', 3.84977126121521),
  ('marketing', 3.870338201522827),
  ('corporate', 3.901237726211548),
  ('enterprise', 4.052821636199951),
  ('companies', 4.098732948303223),
  ('company', 4.115787982940674),
  ('well', 4.250703811645508),
  ('commercial', 4.251638889312744)],
 [('art', 0.0),
  ('arts', 3.688779592514038),
  ('museum', 3.934798240661621),
  ('sculpture', 4.103562355041504),
  ('works', 4.126135349273682),
  ('photography', 4.151274681091309),
  ('contemporary', 4.155360221862793),
  ('painting', 4.276235103607178),
  ('gallery', 4.385191440582275),
  ('collection', 4.4654622077941895)],
 [('culture', 0.0),
  ('cultural', 3.783661127090454),
  ('tradition', 4.208914279937744),
  ('traditions', 4.227012634277344),
  ('cultures', 4.243590831756592),
  ('civilization', 4.2488861083984375),
  ('society', 4.413925647735596),
  ('history', 4.420716285705566),
  ('religion', 4.51834774017334),
  ('context', 4.55044

In [15]:
labels_neighbours = []
for l in range(len(labels)):
  neighbours = []
  for word in labels[l].split(' & '):
    neighbours.extend([n[0] for n in closest_words(glove, get_vector(glove, word))])
  labels_neighbours.append(neighbours)

labels_neighbours

[['business',
  'industry',
  'businesses',
  'marketing',
  'corporate',
  'enterprise',
  'companies',
  'company',
  'well',
  'commercial'],
 ['art',
  'arts',
  'museum',
  'sculpture',
  'works',
  'photography',
  'contemporary',
  'painting',
  'gallery',
  'collection',
  'culture',
  'cultural',
  'tradition',
  'traditions',
  'cultures',
  'civilization',
  'society',
  'history',
  'religion',
  'context'],
 ['politics',
  'political',
  'debate',
  'matters',
  'influence',
  'culture',
  'rather',
  'history',
  'politicians',
  'matter']]

In [43]:
labels_sentences = [' & '.join(neighbor) for neighbor in labels_neighbours]
labels_sentences

['business & industry & businesses & marketing & corporate & enterprise & companies & company & well & commercial',
 'art & arts & museum & sculpture & works & photography & contemporary & painting & gallery & collection & culture & cultural & tradition & traditions & cultures & civilization & society & history & religion & context',
 'politics & political & debate & matters & influence & culture & rather & history & politicians & matter']

In [44]:
# run inputs through model and mean-pool over the sequence
# dimension to get sequence-level representations
inputs = tokenizer.batch_encode_plus([sentence] + labels_sentences,
                                     return_tensors='pt',
                                     pad_to_max_length=True)

input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
output = model(input_ids, attention_mask=attention_mask)[0]
sentence_rep = output[:1].mean(dim=1)
label_reps = output[1:].mean(dim=1)

similarities = F.cosine_similarity(sentence_rep, label_reps)
closest = similarities.argsort(descending=True)
for ind in closest:
  print(f'label: {labels[ind]} \t similarity: {similarities[ind]}')



label: business 	 similarity: 0.9124128818511963
label: politics 	 similarity: 0.7269384264945984
label: art & culture 	 similarity: 0.5993659496307373


# Test with MNLI Sequence Classification (BART-LARGE-MNLI)

In [40]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

In [49]:
print(f"{sentence}\n")
for label_sentence in labels_sentences:
  # run through model pre-trained on MNLI matching sentense with first label neighbors sentence
  x = tokenizer.encode(sentence, label_sentence, return_tensors='pt',
                       truncation_strategy='only_first')
  logits = nli_model(x)[0]

  # we throw away "neutral" (dim 1) and take the probability of
  # "entailment" (2) as the probability of the label being true 
  entail_contradiction_logits = logits[:,[0,2]]
  probs = entail_contradiction_logits.softmax(dim=1)
  prob_label_is_true = probs[:,1]
  print(prob_label_is_true.item(), label_sentence)

Who are you voting for in 2020?





0.7021744251251221 business & industry & businesses & marketing & corporate & enterprise & companies & company & well & commercial
0.9785778522491455 art & arts & museum & sculpture & works & photography & contemporary & painting & gallery & collection & culture & cultural & tradition & traditions & cultures & civilization & society & history & religion & context
0.9197666049003601 politics & political & debate & matters & influence & culture & rather & history & politicians & matter


In [69]:
sentence = 'Who are you voting for in 2020?'
labels = ['business', 'art & culture', 'politics']

print(f"{sentence}\n")
for label in labels:
  # run through model pre-trained on MNLI matching sentense with first label neighbors sentence
  x = tokenizer.encode(sentence, label, return_tensors='pt',
                       truncation_strategy='only_first')
  logits = nli_model(x)[0]

  # we throw away "neutral" (dim 1) and take the probability of
  # "entailment" (2) as the probability of the label being true 
  entail_contradiction_logits = logits[:,[0,2]]
  probs = entail_contradiction_logits.softmax(dim=1)
  prob_label_is_true = probs[:,1]
  print(prob_label_is_true.item(), label)

Who are you voting for in 2020?





0.2593141198158264 business
0.06867872178554535 art & culture
0.7672092914581299 politics


In [67]:
sentence = "Who are you voting for in 2020?"
labels = ['foreign policy', 'Europe', 'elections & vote', 'business & industry & businesses & marketing', '2020', 'outdoor recreation', 'politics & political & debate & matters & influence & culture & rather & history & politicians & matter', 'art & culture & arts & museum & sculpture & photography & painting & gallery']

print(f"{sentence}\n")
for label in labels:
  # run through model pre-trained on MNLI matching sentense with first label neighbors sentence
  x = tokenizer.encode(sentence, label, return_tensors='pt',
                       truncation_strategy='only_first')
  logits = nli_model(x)[0]

  # we throw away "neutral" (dim 1) and take the probability of
  # "entailment" (2) as the probability of the label being true 
  entail_contradiction_logits = logits[:,[0,2]]
  probs = entail_contradiction_logits.softmax(dim=1)
  prob_label_is_true = probs[:,1]
  print(prob_label_is_true.item(), label)

Who are you voting for in 2020?





0.5547330975532532 foreign policy
0.038544122129678726 Europe
0.88347327709198 elections & vote
0.04872370883822441 business & industry & businesses & marketing
0.9925205707550049 2020
0.16658049821853638 outdoor recreation
0.9197666049003601 politics & political & debate & matters & influence & culture & rather & history & politicians & matter
0.13452927768230438 art & culture & arts & museum & sculpture & photography & painting & gallery
