In [None]:
!pip install beir
!pip install tensorflow-text
!pip install --upgrade pip
!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]

In [2]:
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

import logging
import pathlib, os

from typing import List
import requests
import pandas as pd
from haystack import Document
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import RAGenerator, DensePassageRetriever
from haystack.utils import fetch_archive_from_http

In [5]:
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()],
                    force=True)

dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(os.getcwd(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
print("Dataset downloaded here: {}".format(data_path))

/content/datasets/scifact.zip:   0%|          | 0.00/2.69M [00:00<?, ?iB/s]

Dataset downloaded here: /content/datasets/scifact


In [6]:
from beir.datasets.data_loader import GenericDataLoader

data_path = "datasets/scifact"
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test") # or split = "train" or "dev"

  0%|          | 0/5183 [00:00<?, ?it/s]

In [7]:
pd_corpus = pd.DataFrame(corpus)
pd_corpus = pd_corpus.transpose()
pd.DataFrame(pd_corpus)

Unnamed: 0,text,title
4983,Alterations of the architecture of cerebral white matter in the developing h...,Microstructural development of human newborn cerebral white matter assessed ...
5836,Myelodysplastic syndromes (MDS) are age-dependent stem cell malignancies tha...,Induction of myelodysplasia by myeloid-derived suppressor cells.
7912,ID elements are short interspersed elements (SINEs) found in high copy numbe...,"BC1 RNA, the transcript from a master gene for ID element amplification, is ..."
18670,DNA methylation plays an important role in biological processes in human hea...,The DNA Methylome of Human Peripheral Blood Mononuclear Cells
19238,Two human Golli (for gene expressed in the oligodendrocyte lineage)-MBP (for...,The human myelin basic protein gene is included within a 179-kilobase transc...
...,...,...
195689316,BACKGROUND The main associations of body-mass index (BMI) with overall and c...,Body-mass index and cause-specific mortality in 900 000 adults: collaborativ...
195689757,A key aberrant biological difference between tumor cells and normal differen...,Targeting metabolic remodeling in glioblastoma multiforme.
196664003,A signaling pathway transmits information from an upstream system to downstr...,Signaling architectures that transmit unidirectional information despite ret...
198133135,AIMS Trabecular bone score (TBS) is a surrogate indicator of bone microarchi...,"Association between pre-diabetes, type 2 diabetes and trabecular bone score:..."


In [8]:
pd_queries = pd.DataFrame(queries,queries.items())
pd.DataFrame(pd_queries)

Unnamed: 0,Unnamed: 1,1,3,5,13,36,42,48,49,50,51,...,1359,1362,1363,1368,1370,1379,1382,1385,1389,1395
1,0-dimensional biomaterials show inductive properties.,0-dimensional biomaterials show inductive properties.,"1,000 genomes project enables mapping of genetic sequence variation consisti...",1/2000 in UK have abnormal PrP positivity.,5% of perinatal mortality is due to low birth weight.,A deficiency of vitamin B12 increases blood levels of homocysteine.,A high microerythrocyte count raises vulnerability to severe anemia in homoz...,"A total of 1,000 people in the UK are asymptomatic carriers of vCJD infection.",ADAR1 binds to Dicer to cleave pre-miRNA.,AIRE is expressed in some skin tumors.,ALDH1 expression is associated with better breast cancer outcomes.,...,Varenicline monotherapy is more effective after 12 weeks of treatment compar...,Venules have a larger lumen diameter than arterioles.,Venules have a thinner or absent smooth layer compared to arterioles.,Vitamin D deficiency effects the term of delivery.,Vitamin D deficiency is unrelated to birth weight.,Women with a higher birth weight are more likely to develop breast cancer la...,aPKCz causes tumour enhancement by affecting glutamine metabolism.,cSMAC formation enhances weak ligand signalling.,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,p16INK4A accumulation is linked to an abnormal wound response caused by the...
3,"1,000 genomes project enables mapping of genetic sequence variation consisting of rare variants with larger penetrance effects than common variants.",0-dimensional biomaterials show inductive properties.,"1,000 genomes project enables mapping of genetic sequence variation consisti...",1/2000 in UK have abnormal PrP positivity.,5% of perinatal mortality is due to low birth weight.,A deficiency of vitamin B12 increases blood levels of homocysteine.,A high microerythrocyte count raises vulnerability to severe anemia in homoz...,"A total of 1,000 people in the UK are asymptomatic carriers of vCJD infection.",ADAR1 binds to Dicer to cleave pre-miRNA.,AIRE is expressed in some skin tumors.,ALDH1 expression is associated with better breast cancer outcomes.,...,Varenicline monotherapy is more effective after 12 weeks of treatment compar...,Venules have a larger lumen diameter than arterioles.,Venules have a thinner or absent smooth layer compared to arterioles.,Vitamin D deficiency effects the term of delivery.,Vitamin D deficiency is unrelated to birth weight.,Women with a higher birth weight are more likely to develop breast cancer la...,aPKCz causes tumour enhancement by affecting glutamine metabolism.,cSMAC formation enhances weak ligand signalling.,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,p16INK4A accumulation is linked to an abnormal wound response caused by the...
5,1/2000 in UK have abnormal PrP positivity.,0-dimensional biomaterials show inductive properties.,"1,000 genomes project enables mapping of genetic sequence variation consisti...",1/2000 in UK have abnormal PrP positivity.,5% of perinatal mortality is due to low birth weight.,A deficiency of vitamin B12 increases blood levels of homocysteine.,A high microerythrocyte count raises vulnerability to severe anemia in homoz...,"A total of 1,000 people in the UK are asymptomatic carriers of vCJD infection.",ADAR1 binds to Dicer to cleave pre-miRNA.,AIRE is expressed in some skin tumors.,ALDH1 expression is associated with better breast cancer outcomes.,...,Varenicline monotherapy is more effective after 12 weeks of treatment compar...,Venules have a larger lumen diameter than arterioles.,Venules have a thinner or absent smooth layer compared to arterioles.,Vitamin D deficiency effects the term of delivery.,Vitamin D deficiency is unrelated to birth weight.,Women with a higher birth weight are more likely to develop breast cancer la...,aPKCz causes tumour enhancement by affecting glutamine metabolism.,cSMAC formation enhances weak ligand signalling.,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,p16INK4A accumulation is linked to an abnormal wound response caused by the...
13,5% of perinatal mortality is due to low birth weight.,0-dimensional biomaterials show inductive properties.,"1,000 genomes project enables mapping of genetic sequence variation consisti...",1/2000 in UK have abnormal PrP positivity.,5% of perinatal mortality is due to low birth weight.,A deficiency of vitamin B12 increases blood levels of homocysteine.,A high microerythrocyte count raises vulnerability to severe anemia in homoz...,"A total of 1,000 people in the UK are asymptomatic carriers of vCJD infection.",ADAR1 binds to Dicer to cleave pre-miRNA.,AIRE is expressed in some skin tumors.,ALDH1 expression is associated with better breast cancer outcomes.,...,Varenicline monotherapy is more effective after 12 weeks of treatment compar...,Venules have a larger lumen diameter than arterioles.,Venules have a thinner or absent smooth layer compared to arterioles.,Vitamin D deficiency effects the term of delivery.,Vitamin D deficiency is unrelated to birth weight.,Women with a higher birth weight are more likely to develop breast cancer la...,aPKCz causes tumour enhancement by affecting glutamine metabolism.,cSMAC formation enhances weak ligand signalling.,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,p16INK4A accumulation is linked to an abnormal wound response caused by the...
36,A deficiency of vitamin B12 increases blood levels of homocysteine.,0-dimensional biomaterials show inductive properties.,"1,000 genomes project enables mapping of genetic sequence variation consisti...",1/2000 in UK have abnormal PrP positivity.,5% of perinatal mortality is due to low birth weight.,A deficiency of vitamin B12 increases blood levels of homocysteine.,A high microerythrocyte count raises vulnerability to severe anemia in homoz...,"A total of 1,000 people in the UK are asymptomatic carriers of vCJD infection.",ADAR1 binds to Dicer to cleave pre-miRNA.,AIRE is expressed in some skin tumors.,ALDH1 expression is associated with better breast cancer outcomes.,...,Varenicline monotherapy is more effective after 12 weeks of treatment compar...,Venules have a larger lumen diameter than arterioles.,Venules have a thinner or absent smooth layer compared to arterioles.,Vitamin D deficiency effects the term of delivery.,Vitamin D deficiency is unrelated to birth weight.,Women with a higher birth weight are more likely to develop breast cancer la...,aPKCz causes tumour enhancement by affecting glutamine metabolism.,cSMAC formation enhances weak ligand signalling.,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,p16INK4A accumulation is linked to an abnormal wound response caused by the...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1379,Women with a higher birth weight are more likely to develop breast cancer later in life.,0-dimensional biomaterials show inductive properties.,"1,000 genomes project enables mapping of genetic sequence variation consisti...",1/2000 in UK have abnormal PrP positivity.,5% of perinatal mortality is due to low birth weight.,A deficiency of vitamin B12 increases blood levels of homocysteine.,A high microerythrocyte count raises vulnerability to severe anemia in homoz...,"A total of 1,000 people in the UK are asymptomatic carriers of vCJD infection.",ADAR1 binds to Dicer to cleave pre-miRNA.,AIRE is expressed in some skin tumors.,ALDH1 expression is associated with better breast cancer outcomes.,...,Varenicline monotherapy is more effective after 12 weeks of treatment compar...,Venules have a larger lumen diameter than arterioles.,Venules have a thinner or absent smooth layer compared to arterioles.,Vitamin D deficiency effects the term of delivery.,Vitamin D deficiency is unrelated to birth weight.,Women with a higher birth weight are more likely to develop breast cancer la...,aPKCz causes tumour enhancement by affecting glutamine metabolism.,cSMAC formation enhances weak ligand signalling.,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,p16INK4A accumulation is linked to an abnormal wound response caused by the...
1382,aPKCz causes tumour enhancement by affecting glutamine metabolism.,0-dimensional biomaterials show inductive properties.,"1,000 genomes project enables mapping of genetic sequence variation consisti...",1/2000 in UK have abnormal PrP positivity.,5% of perinatal mortality is due to low birth weight.,A deficiency of vitamin B12 increases blood levels of homocysteine.,A high microerythrocyte count raises vulnerability to severe anemia in homoz...,"A total of 1,000 people in the UK are asymptomatic carriers of vCJD infection.",ADAR1 binds to Dicer to cleave pre-miRNA.,AIRE is expressed in some skin tumors.,ALDH1 expression is associated with better breast cancer outcomes.,...,Varenicline monotherapy is more effective after 12 weeks of treatment compar...,Venules have a larger lumen diameter than arterioles.,Venules have a thinner or absent smooth layer compared to arterioles.,Vitamin D deficiency effects the term of delivery.,Vitamin D deficiency is unrelated to birth weight.,Women with a higher birth weight are more likely to develop breast cancer la...,aPKCz causes tumour enhancement by affecting glutamine metabolism.,cSMAC formation enhances weak ligand signalling.,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,p16INK4A accumulation is linked to an abnormal wound response caused by the...
1385,cSMAC formation enhances weak ligand signalling.,0-dimensional biomaterials show inductive properties.,"1,000 genomes project enables mapping of genetic sequence variation consisti...",1/2000 in UK have abnormal PrP positivity.,5% of perinatal mortality is due to low birth weight.,A deficiency of vitamin B12 increases blood levels of homocysteine.,A high microerythrocyte count raises vulnerability to severe anemia in homoz...,"A total of 1,000 people in the UK are asymptomatic carriers of vCJD infection.",ADAR1 binds to Dicer to cleave pre-miRNA.,AIRE is expressed in some skin tumors.,ALDH1 expression is associated with better breast cancer outcomes.,...,Varenicline monotherapy is more effective after 12 weeks of treatment compar...,Venules have a larger lumen diameter than arterioles.,Venules have a thinner or absent smooth layer compared to arterioles.,Vitamin D deficiency effects the term of delivery.,Vitamin D deficiency is unrelated to birth weight.,Women with a higher birth weight are more likely to develop breast cancer la...,aPKCz causes tumour enhancement by affecting glutamine metabolism.,cSMAC formation enhances weak ligand signalling.,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,p16INK4A accumulation is linked to an abnormal wound response caused by the...
1389,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,0-dimensional biomaterials show inductive properties.,"1,000 genomes project enables mapping of genetic sequence variation consisti...",1/2000 in UK have abnormal PrP positivity.,5% of perinatal mortality is due to low birth weight.,A deficiency of vitamin B12 increases blood levels of homocysteine.,A high microerythrocyte count raises vulnerability to severe anemia in homoz...,"A total of 1,000 people in the UK are asymptomatic carriers of vCJD infection.",ADAR1 binds to Dicer to cleave pre-miRNA.,AIRE is expressed in some skin tumors.,ALDH1 expression is associated with better breast cancer outcomes.,...,Varenicline monotherapy is more effective after 12 weeks of treatment compar...,Venules have a larger lumen diameter than arterioles.,Venules have a thinner or absent smooth layer compared to arterioles.,Vitamin D deficiency effects the term of delivery.,Vitamin D deficiency is unrelated to birth weight.,Women with a higher birth weight are more likely to develop breast cancer la...,aPKCz causes tumour enhancement by affecting glutamine metabolism.,cSMAC formation enhances weak ligand signalling.,mTORC2 regulates intracellular cysteine levels through xCT inhibition.,p16INK4A accumulation is linked to an abnormal wound response caused by the...


In [9]:
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval import models
from beir.retrieval.search.dense import DenseRetrievalExactSearch 

#### Dense Retrieval using SBERT (Sentence-BERT) ####
#### Provide any pretrained sentence-transformers model
#### The model was fine-tuned using cosine-similarity.
#### Complete list - https://www.sbert.net/docs/pretrained_models.html

model = DenseRetrievalExactSearch(models.SentenceBERT("msmarco-distilbert-base-v3"), batch_size=128)
retriever = EvaluateRetrieval(model, score_function="cos_sim")

#### Retrieve dense results (format of results is identical to qrels)
results = retriever.retrieve(corpus, queries)

Downloading:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.71k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/545 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/265M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/499 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/41 [00:00<?, ?it/s]

In [10]:
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
for k in retriever.k_values:
    ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
    print("Retriever evaluation for k in: {}".format(k))
    print(ndcg, _map, recall, precision)

Retriever evaluation for k in: 1
{'NDCG@1': 0.42333, 'NDCG@3': 0.48416, 'NDCG@5': 0.51037, 'NDCG@10': 0.53789, 'NDCG@100': 0.57592, 'NDCG@1000': 0.59134} {'MAP@1': 0.39944, 'MAP@3': 0.45935, 'MAP@5': 0.47679, 'MAP@10': 0.48894, 'MAP@100': 0.49742, 'MAP@1000': 0.49797} {'Recall@1': 0.39944, 'Recall@3': 0.52561, 'Recall@5': 0.58872, 'Recall@10': 0.67233, 'Recall@100': 0.846, 'Recall@1000': 0.96833} {'P@1': 0.42333, 'P@3': 0.19333, 'P@5': 0.13333, 'P@10': 0.07567, 'P@100': 0.0096, 'P@1000': 0.0011}
Retriever evaluation for k in: 3
{'NDCG@1': 0.42333, 'NDCG@3': 0.48416, 'NDCG@5': 0.51037, 'NDCG@10': 0.53789, 'NDCG@100': 0.57592, 'NDCG@1000': 0.59134} {'MAP@1': 0.39944, 'MAP@3': 0.45935, 'MAP@5': 0.47679, 'MAP@10': 0.48894, 'MAP@100': 0.49742, 'MAP@1000': 0.49797} {'Recall@1': 0.39944, 'Recall@3': 0.52561, 'Recall@5': 0.58872, 'Recall@10': 0.67233, 'Recall@100': 0.846, 'Recall@1000': 0.96833} {'P@1': 0.42333, 'P@3': 0.19333, 'P@5': 0.13333, 'P@10': 0.07567, 'P@100': 0.0096, 'P@1000': 0.0011

In [11]:
import random

#### Print top-k documents retrieved ####
top_k = 3

query_id, ranking_scores = random.choice(list(results.items()))
scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
print("Query : %s\n" % queries[query_id])

for rank in range(top_k):
    doc_id = scores_sorted[rank][0]
    # Format: Rank x: ID [Title] Body
    print("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))

Query : Localization of PIN1 in the Arabidopsis embryo does not require VPS9a

Rank 1: 435529 [Uridylation of miRNAs by HEN1 SUPPRESSOR1 in Arabidopsis] - HEN1-mediated 2'-O-methylation has been shown to be a key mechanism to protect plant microRNAs (miRNAs) and small interfering RNAs (siRNAs) as well as animal piwi-interacting RNAs (piRNAs) from degradation and 3' terminal uridylation [1-8]. However, enzymes uridylating unmethylated miRNAs, siRNAs, or piRNAs in hen1 are unknown. In this study, a genetic screen identified a second-site mutation hen1 suppressor1-2 (heso1-2) that partially suppresses the morphological phenotypes of the hypomorphic hen1-2 allele and the null hen1-1 allele in Arabidopsis. HESO1 encodes a terminal nucleotidyl transferase that prefers to add untemplated uridine to the 3' end of RNA, which is completely abolished by 2'-O-methylation. heso1-2 affects the profile of u-tailed miRNAs and siRNAs and increases the abundance of truncated and/or normal sized ones in 

In [None]:
!pip install transformers-interpret

In [13]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

In [14]:
model_name = "sentence-transformers/msmarco-distilbert-base-v3"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading:   0%|          | 0.00/545 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/253M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v3 and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading:   0%|          | 0.00/499 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [15]:
print (tokenizer, model)

PreTrainedTokenizerFast(name_or_path='sentence-transformers/msmarco-distilbert-base-v3', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}) DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
         

In [28]:
queries[query_id]

'Localization of PIN1 in the Arabidopsis embryo does not require VPS9a'

In [20]:
sample_text = queries[query_id]

In [21]:
multiclass_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)
word_attributions = multiclass_explainer(text=sample_text)

In [22]:
word_attributions

[('[CLS]', 0.0),
 ('local', -0.0748019458360897),
 ('##ization', 0.16485497855988732),
 ('of', -0.004193454024853225),
 ('pin', 0.214898223388914),
 ('##1', 0.02513307525741778),
 ('in', 0.1972248721071557),
 ('the', 0.5857358648085718),
 ('arab', -0.5130440398402036),
 ('##ido', -0.11250133767137623),
 ('##psis', -0.10866664984771567),
 ('embryo', 0.33954034114490056),
 ('does', 0.10812975278236774),
 ('not', -0.008703144869958664),
 ('require', 0.10485826049915359),
 ('vp', 0.2667948868085848),
 ('##s', 0.058195930583887535),
 ('##9', 0.11199179593180941),
 ('##a', 0.15992042824091618),
 ('[SEP]', 0.0)]

In [23]:
html = multiclass_explainer.visualize()

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,LABEL_0 (0.55),LABEL_0,1.52,[CLS] local ##ization of pin ##1 in the arab ##ido ##psis embryo does not require vp ##s ##9 ##a [SEP]
,,,,


#Zero shot classification 

In [24]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import ZeroShotClassificationExplainer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")

zero_shot_explainer = ZeroShotClassificationExplainer(model, tokenizer)

Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

In [29]:
word_attributions = zero_shot_explainer(sample_text, labels = ["finance", "technology", "sports", "medicine"])

In [30]:
word_attributions

{'finance': [('<s>', 0.0),
  ('Local', 0.0),
  ('ization', -0.05434263928631245),
  ('of', -0.06337591182237089),
  ('PIN', -0.01799789605018157),
  ('1', -0.006152274409481252),
  ('in', 0.03351544192816125),
  ('the', 0.00120762570092583),
  ('Arab', -0.06179183000413382),
  ('id', 0.14604364016660773),
  ('opsis', -0.1154736391064592),
  ('embryo', 0.33369005466193347),
  ('does', -0.03109996719414142),
  ('not', -0.39139438698395024),
  ('require', 0.6396668212763092),
  ('V', -0.3668298757008845),
  ('PS', -0.0902121618902987),
  ('9', -0.36834882405868197)],
 'medicine': [('<s>', 0.0),
  ('Local', 0.0),
  ('ization', -0.1060215686058711),
  ('of', -0.17919097387084323),
  ('PIN', -0.1891141801139749),
  ('1', -0.18394102025438624),
  ('in', -0.018289478420385964),
  ('the', -0.24287508169994632),
  ('Arab', -0.0014543862446121793),
  ('id', 0.04784564053864182),
  ('opsis', -0.07512869895880672),
  ('embryo', -0.08214453128030345),
  ('does', -0.1927287216107089),
  ('not', 0.244

In [31]:
zero_shot_explainer.visualize("zero_shot.html")

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
finance,finance (0.08),finance,-0.41,#s Local ization of PIN 1 in the Arab id opsis embryo does not require V PS 9
,,,,
technology,technology (0.52),technology,2.11,#s Local ization of PIN 1 in the Arab id opsis embryo does not require V PS 9
,,,,
sports,sports (0.09),sports,-1.01,#s Local ization of PIN 1 in the Arab id opsis embryo does not require V PS 9
,,,,
medicine,medicine (0.31),medicine,-2.25,#s Local ization of PIN 1 in the Arab id opsis embryo does not require V PS 9
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
finance,finance (0.08),finance,-0.41,#s Local ization of PIN 1 in the Arab id opsis embryo does not require V PS 9
,,,,
technology,technology (0.52),technology,2.11,#s Local ization of PIN 1 in the Arab id opsis embryo does not require V PS 9
,,,,
sports,sports (0.09),sports,-1.01,#s Local ization of PIN 1 in the Arab id opsis embryo does not require V PS 9
,,,,
medicine,medicine (0.31),medicine,-2.25,#s Local ization of PIN 1 in the Arab id opsis embryo does not require V PS 9
,,,,


#TODO

## integrated gradients
