<a href="https://colab.research.google.com/github/JacopoMangiavacchi/SBERT-ZSC/blob/main/ZSC-Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [49]:
sentence = 'Who are you voting for in 2020?'
labels = ['business', 'art & culture', 'politics']

# Test HuggingFace Zero Shot Classification Pipeline

In [15]:
!pip install transformers

from transformers import pipeline

classifier = pipeline('zero-shot-classification')



No model was supplied, defaulted to facebook/bart-large-mnli (https://huggingface.co/facebook/bart-large-mnli)


Downloading:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [16]:
classes = classifier(sentence, labels)
classes

{'labels': ['politics', 'business', 'art & culture'],
 'scores': [0.9604312181472778, 0.020186087116599083, 0.019382672384381294],
 'sequence': 'Who are you voting for in 2020?'}

# Test with simple Sentence BERT Embedding mean

In [17]:
from transformers import AutoTokenizer, AutoModel
from torch.nn import functional as F
from scipy import spatial

# tokenizer = AutoTokenizer.from_pretrained('deepset/sentence_bert')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('deepset/sentence_bert')

Some weights of the model checkpoint at deepset/sentence_bert were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
# run inputs through model and mean-pool over the sequence
# dimension to get sequence-level representations
inputs = tokenizer.batch_encode_plus([sentence] + labels,
                                     return_tensors='pt',
                                     pad_to_max_length=True)



In [19]:
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
output = model(input_ids, attention_mask=attention_mask)[0]
sentence_rep = output[:1].mean(dim=1)
label_reps = output[1:].mean(dim=1)

In [22]:
# find the highest cosine similarities between sentences
print(F.cosine_similarity(sentence_rep[0], label_reps[0], dim=0))
print(F.cosine_similarity(sentence_rep[0], label_reps[1], dim=0))
print(F.cosine_similarity(sentence_rep[0], label_reps[2], dim=0))

tensor(0.0045, grad_fn=<DivBackward0>)
tensor(-0.0274, grad_fn=<DivBackward0>)
tensor(0.2156, grad_fn=<DivBackward0>)


In [23]:
sentence_embedding = sentence_rep.detach().numpy()
label_embedding = label_reps.detach().numpy()

print(1 - spatial.distance.cosine(sentence_embedding[0], label_embedding[0]))
print(1 - spatial.distance.cosine(sentence_embedding[0], label_embedding[1]))
print(1 - spatial.distance.cosine(sentence_embedding[0], label_embedding[2]))


0.0045241788029670715
-0.02739689312875271
0.21561525762081146


In [24]:
similarities = F.cosine_similarity(sentence_rep, label_reps)
closest = similarities.argsort(descending=True)
for ind in closest:
  print(f'label: {labels[ind]} \t similarity: {similarities[ind]}')

label: politics 	 similarity: 0.21561525762081146
label: business 	 similarity: 0.004524169024080038
label: art & culture 	 similarity: -0.02739688940346241


In [33]:
similarities, F.softmax(similarities)

  """Entry point for launching an IPython kernel.


(tensor([ 0.0045, -0.0274,  0.2156], grad_fn=<DivBackward0>),
 tensor([0.3121, 0.3023, 0.3855], grad_fn=<SoftmaxBackward>))

# Test augmenting labels with static embedding neighbors

In [35]:
import torchtext.vocab

glove = torchtext.vocab.GloVe(name='6B', dim=100)
print(f"{len(glove.itos)} words in dictionary")

.vector_cache/glove.6B.zip: 862MB [02:40, 5.36MB/s]                           
100%|█████████▉| 399999/400000 [00:19<00:00, 20529.47it/s]


400000 words in dictionary


In [45]:
import torch

def get_vector(embeddings, w):
  return embeddings.vectors[embeddings.stoi[w]]

def closest_words(embeddings, vector, n=10):
  distances = [(w, torch.dist(vector, get_vector(embeddings, w)).item()) for w in embeddings.itos]
  return sorted(distances, key = lambda w: w[1])[:n]

In [47]:
closest_words(glove, get_vector(glove, 'politics'))

[('politics', 0.0),
 ('political', 3.8383750915527344),
 ('debate', 4.631179332733154),
 ('matters', 4.661602973937988),
 ('influence', 4.729617118835449),
 ('culture', 4.731587886810303),
 ('rather', 4.750455856323242),
 ('history', 4.752238750457764),
 ('politicians', 4.768784999847412),
 ('matter', 4.817280292510986)]

In [53]:
labels_neighbours = [closest_words(glove, get_vector(glove, word)) for label in labels for word in label.split(' & ')]
labels_neighbours

[[('business', 0.0),
  ('industry', 3.5567009449005127),
  ('businesses', 3.84977126121521),
  ('marketing', 3.870338201522827),
  ('corporate', 3.901237726211548),
  ('enterprise', 4.052821636199951),
  ('companies', 4.098732948303223),
  ('company', 4.115787982940674),
  ('well', 4.250703811645508),
  ('commercial', 4.251638889312744)],
 [('art', 0.0),
  ('arts', 3.688779592514038),
  ('museum', 3.934798240661621),
  ('sculpture', 4.103562355041504),
  ('works', 4.126135349273682),
  ('photography', 4.151274681091309),
  ('contemporary', 4.155360221862793),
  ('painting', 4.276235103607178),
  ('gallery', 4.385191440582275),
  ('collection', 4.4654622077941895)],
 [('culture', 0.0),
  ('cultural', 3.783661127090454),
  ('tradition', 4.208914279937744),
  ('traditions', 4.227012634277344),
  ('cultures', 4.243590831756592),
  ('civilization', 4.2488861083984375),
  ('society', 4.413925647735596),
  ('history', 4.420716285705566),
  ('religion', 4.51834774017334),
  ('context', 4.55044

In [56]:
labels_neighbours = []
for l in range(len(labels)):
  neighbours = []
  for word in labels[l].split(' & '):
    neighbours.append(word)
  labels_neighbours.append(neighbours)

labels_neighbours

[['business'], ['art', 'culture'], ['politics']]