<a href="https://colab.research.google.com/github/MorenoSara/Few-Shot_Text_Classification/blob/main/zero_shot_text_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U sentence-transformers

In [None]:
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
from sentence_transformers.util import cos_sim
import numpy as np

# Run the encoder on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Held-out split; first column of the sheet is used as the index.
test_dataset = pd.read_excel('test.xlsx', index_col=0)  # 9398 samples
test_dataset.head()

Unnamed: 0,Y1,Y2,Y,Domain,area,keywords,Abstract
37587,0,16,16,CS,Bioinformatics,Secretome; Ascending thoracic aortic aneurysm...,Background: Ascending thoracic aortic aneurysm...
37588,2,4,37,Psychology,Prosocial behavior,Post-institutionalized children; Internationa...,The study examined the social skills of 92 Rus...
37589,5,15,87,Medical,Diabetes,IL-3; CELL SURVIVAL; PI3k/Akt; Erk; OXIDATIVE...,Interleukin-3 (IL-3) is a well-characterized g...
37590,2,15,48,Psychology,Gender roles,European comparison; gender inequality; socia...,Family policies in France and various European...
37591,6,0,125,biochemistry,Molecular biology,convergent evolution; effector; entomopathoge...,Entomopathogenic fungi play a pivotal role in ...


In [None]:
# Level-1 domain codes as they appear in the `Domain` column -> the
# human-readable class names that are embedded as zero-shot labels.
REMAP_LEV1 = dict(
    CS='Computer Science',
    Civil='Civil Engineering',
    ECE='Electrical Engineering',
    Psychology='Psychology',
    MAE='Mechanical Engineering',
    Medical='Medical Science',
    biochemistry='Biochemistry',
)

In [None]:
def get_mapped_labels(data, mapping_dict):
  """Translate raw label strings into their mapped names.

  Every element of `data` is stripped of surrounding whitespace and then
  looked up in `mapping_dict`; a code missing from the mapping raises
  KeyError. Returns a list in the iteration order of `data`.
  """
  cleaned = [raw.strip() for raw in data]
  return [mapping_dict[code] for code in cleaned]

In [None]:
# Candidate class names. Sorted because iterating a set of strings gives an
# order that can change between Python sessions. Sorting also keeps `labels`
# aligned with the sorted label order that sklearn's classification_report
# uses when it pairs `target_names` with report rows.
labels = sorted(set(get_mapped_labels(set(test_dataset['Domain']), REMAP_LEV1)))
abstracts = list(test_dataset['Abstract'])

In [None]:
# Pretrained sentence encoder shared by the label names and the documents.
model = SentenceTransformer(
    'sentence-transformers/all-mpnet-base-v2',
    device=device,
)

In [None]:
# Embed the class names and the abstracts in the same vector space.
labels_embeddings = model.encode(labels)
# Whole abstracts are encoded directly here (no sentence splitting).
doc_embeddings = model.encode(
    abstracts,
    batch_size=256,
    show_progress_bar=True,
)

Batches:   0%|          | 0/37 [00:00<?, ?it/s]

In [None]:
# with open('test_doc_embeddings.txt','wb') as f:
    # for line in np.matrix(doc_embeddings):
       # np.savetxt(f, line)

In [None]:
# df = pd.read_csv('/content/test_doc_embeddings.txt', sep = ' ', header=None) 
# retrieve corresponding document using test_dataset.iloc[i]

In [None]:
def floored_cosine_knn(x, y):
  """Cosine *distance* 1 - cos(x, y) between two vectors, clipped below at 0.

  Used as the custom `metric` of KNeighborsClassifier. The clip only absorbs
  tiny negative values from floating-point rounding when x and y point in the
  same direction.
  """
  cosine = np.dot(x / np.linalg.norm(x), y / np.linalg.norm(y))
  distance = 1 - cosine
  return max(0.0, distance)

In [None]:
from sklearn.neighbors import KNeighborsClassifier

# 1-nearest-neighbour over the label embeddings: each document receives the
# label whose embedding is closest under the floored cosine distance.
knn = KNeighborsClassifier(
    n_neighbors=1,
    algorithm='brute',
    metric=floored_cosine_knn,
)
y_pred = knn.fit(labels_embeddings, labels).predict(doc_embeddings)

In [None]:
import pprint
from sklearn.metrics import classification_report

# `labels=labels` makes the report rows follow the order of `labels`, the
# same list passed as `target_names`. Without it, classification_report orders
# rows by the sorted unique labels. If `labels` (built from a set in an
# earlier cell) is not sorted, the class names get attached to the wrong rows.
pp = pprint.PrettyPrinter(depth=4)
pp.pprint(classification_report(
    get_mapped_labels(test_dataset['Domain'], REMAP_LEV1),
    y_pred,
    labels=labels,
    target_names=labels,
    output_dict=True,
))

{'Biochemistry': {'f1-score': 0.6402164111812443,
                  'precision': 0.7709011943539631,
                  'recall': 0.5474171164225135,
                  'support': 1297},
 'Civil Engineering': {'f1-score': 0.48351648351648346,
                       'precision': 0.42259887005649716,
                       'recall': 0.5649546827794562,
                       'support': 662},
 'Computer Science': {'f1-score': 0.6030624263839811,
                      'precision': 0.6153846153846154,
                      'recall': 0.5912240184757506,
                      'support': 866},
 'Electrical Engineering': {'f1-score': 0.6959418534221683,
                            'precision': 0.6141101015499733,
                            'recall': 0.8029350104821803,
                            'support': 1431},
 'Mechanical Engineering': {'f1-score': 0.556475458308693,
                            'precision': 0.42254153569824876,
                            'recall': 0.8147186147186147,
     

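### Document embeddings with entropy calculation

Each abstract is split into sentences, and each sentence embedding is weighted by how confidently it points to a single class. Let $p$ be the sentence's cosine similarities to the $K$ label embeddings, floored at 0 and normalized to sum to 1. The weight is then $1 - H(p)/\log K$. The document embedding is the weighted average of its sentence embeddings.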

In [None]:
import nltk
import numpy as np
import scipy
import scipy.stats  # scipy.stats.entropy is used below; older SciPy releases do not load submodules on `import scipy`

# Sentence-tokenizer data for nltk.tokenize.sent_tokenize. NLTK >= 3.9 looks
# for 'punkt_tab' instead of 'punkt', so both are fetched to keep this cell
# working on old and new NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

In [None]:
def floored_cosine(X, Y):
  """Pairwise cosine similarity between the rows of X and the rows of Y, clipped below at 0.

  Returns a matrix with one row per row of X and one column per row of Y.
  """
  unit_x = X / np.linalg.norm(X, axis=1, keepdims=True)
  unit_y = Y / np.linalg.norm(Y, axis=1, keepdims=True)
  similarities = unit_x @ unit_y.T
  return np.maximum(0, similarities)

def floored_cosine_tensors(X, Y):
  """Pairwise cosine similarity clipped below at 0, computed with `cos_sim`.

  Same quantity as floored_cosine, but the similarity matrix comes from
  sentence_transformers.util.cos_sim before the floor is applied.
  """
  similarity = cos_sim(X, Y)
  return np.maximum(0, similarity)

In [None]:
def get_entropies(sentences, labels_embeddings):
  """Score how confidently each sentence points to a single label.

  For every sentence embedding, the cosine similarities to all label
  embeddings are floored at 0 and normalized to sum to 1. This gives a
  distribution p over the K labels. The returned score is

      1 - H(p) / log(K)

  It is 1 when the sentence is similar to exactly one label and 0 when it is
  equally similar to all of them. Despite the function name, this is one
  minus the normalized entropy, not the entropy itself.

  Args:
    sentences: sentence embeddings, one row per sentence (a single 1-D
      vector is treated as one sentence).
    labels_embeddings: label embeddings, one row per label.

  Returns:
    np.ndarray with one score per sentence. A sentence whose similarities are
    all floored to 0 carries no label information and scores 0 rather than
    NaN. With fewer than 2 labels there is no uncertainty to measure
    (log(1) == 0), so every sentence scores 1.
  """
  from scipy.stats import entropy

  sents = np.asarray(sentences, dtype=np.float64)
  labs = np.asarray(labels_embeddings, dtype=np.float64)
  if sents.size == 0:
    return np.zeros(0)
  if sents.ndim == 1:
    sents = sents[None, :]
  if labs.ndim == 1:
    labs = labs[None, :]

  n_labels = labs.shape[0]
  if n_labels < 2:
    return np.ones(sents.shape[0])

  # Unit-normalize rows; the eps floor keeps all-zero vectors from dividing by 0.
  eps = 1e-12
  unit_sents = sents / np.maximum(np.linalg.norm(sents, axis=1, keepdims=True), eps)
  unit_labs = labs / np.maximum(np.linalg.norm(labs, axis=1, keepdims=True), eps)
  sims = np.maximum(0.0, unit_sents @ unit_labs.T)  # one row per sentence, one column per label

  # Only rows with some positive similarity define a distribution; the rest
  # would be 0/0 inside the entropy and keep score 0.
  totals = sims.sum(axis=1)
  informative = totals > 0
  scores = np.zeros(sims.shape[0])
  if informative.any():
    probs = sims[informative] / totals[informative, None]
    scores[informative] = 1.0 - entropy(probs, axis=1) / np.log(n_labels)
  return scores

In [None]:
from tqdm.auto import tqdm

# Entropy-weighted document embeddings. Each abstract is split into sentences,
# and each sentence embedding is weighted by its label confidence
# (get_entropies). Non-finite weights are treated as 0. When no sentence has a
# positive weight, the plain mean of the sentence embeddings is used instead,
# so a document is not turned into NaN just because its weights sum to 0.
docs = []
for abstract in tqdm(abstracts):
  # Fall back to the whole text when the tokenizer yields no sentence.
  sentences = nltk.tokenize.sent_tokenize(abstract) or [abstract]
  sent_embs = model.encode(sentences)
  weights = np.asarray(get_entropies(sent_embs, labels_embeddings), dtype=float)
  weights = np.where(np.isfinite(weights), weights, 0.0)
  total = weights.sum()
  if total > 0:
    docs.append(weights @ sent_embs / total)
  else:
    docs.append(np.mean(sent_embs, axis=0))

docs_embeddings = np.array(docs)  # one row per abstract, one column per embedding dimension

  pk = 1.0*pk / np.sum(pk, axis=axis, keepdims=True)
100%|██████████| 9398/9398 [13:12<00:00, 11.86it/s]


In [None]:
# Row numbers of documents whose embedding contains at least one NaN; these
# rows are removed from both predictions and ground truth below.
indices = np.flatnonzero(np.isnan(docs_embeddings).any(axis=1)).tolist()

In [None]:
# Document embeddings with the NaN rows listed in `indices` removed (new array).
dc = np.delete(docs_embeddings, indices, axis=0)

In [None]:
# Same 1-NN label assignment as before, now on the entropy-weighted embeddings.
knn = KNeighborsClassifier(
    n_neighbors=1,
    algorithm='brute',
    metric=floored_cosine_knn,
)
y_pred = knn.fit(labels_embeddings, labels).predict(dc)  # almost the same results obtained without entropy

In [None]:
# Ground-truth class names aligned with `dc`: the same rows are dropped.
labs = np.delete(
    np.asarray(get_mapped_labels(test_dataset['Domain'], REMAP_LEV1)),
    indices,
    axis=0,
)

In [None]:
# `labels=labels` makes the report rows follow the order of `labels`, the
# same list passed as `target_names`. Without it, classification_report orders
# rows by the sorted unique labels. If `labels` (built from a set in an
# earlier cell) is not sorted, the class names get attached to the wrong rows.
pp = pprint.PrettyPrinter(depth=4)
pp.pprint(classification_report(
    labs,
    y_pred,
    labels=labels,
    target_names=labels,
    output_dict=True,
))

{'Biochemistry': {'f1-score': 0.6241610738255035,
                  'precision': 0.791970802919708,
                  'recall': 0.5150316455696202,
                  'support': 1264},
 'Civil Engineering': {'f1-score': 0.47536617842876167,
                       'precision': 0.4180327868852459,
                       'recall': 0.5509259259259259,
                       'support': 648},
 'Computer Science': {'f1-score': 0.5909090909090909,
                      'precision': 0.6509298998569385,
                      'recall': 0.5410225921521997,
                      'support': 841},
 'Electrical Engineering': {'f1-score': 0.6929577464788732,
                            'precision': 0.6250705815923208,
                            'recall': 0.7773876404494382,
                            'support': 1424},
 'Mechanical Engineering': {'f1-score': 0.5674008810572688,
                            'precision': 0.42781222320637735,
                            'recall': 0.8421970357454228,
      