In [23]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# TODO: pin the exact version this notebook was run with, e.g. transformers==X.Y.Z.
%pip install -q transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [24]:
import torch
from transformers import pipeline, AutoTokenizer, AutoModel
import pandas as pd
import json
from tqdm import tqdm

Helper functions

In [25]:
def batched(lst, batch_size):
    """
    Split a sliceable sequence into consecutive chunks of at most ``batch_size`` items.

    Args:
        lst: Any sliceable sequence (list, tuple, str, ...).
        batch_size: Maximum number of items per chunk; must be >= 1.

    Returns:
        A list of slices of ``lst``. Every chunk has exactly ``batch_size``
        items except possibly the last, which holds the remainder. An empty
        input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Without this guard a negative size produces an empty range and silently
    # returns [], dropping all of the input.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    return [lst[i:i + batch_size] for i in range(0, len(lst), batch_size)]


Loading the zero-shot classification model

In [26]:
# Build a Hugging Face zero-shot-classification pipeline around the
# `facebook/bart-large-mnli` checkpoint.
model = 'facebook/bart-large-mnli'
task = 'zero-shot-classification'
# Use the first GPU when one is visible and fall back to CPU (-1) otherwise;
# a hard-coded device=0 makes this cell fail on CPU-only machines.
device = 0 if torch.cuda.is_available() else -1
clf = pipeline(task, model, device=device)
print(clf.device)

cuda:0


Sanity-checking the zero-shot model on a toy example

In [27]:
# Toy sanity check. With multi_label=True every label is scored independently,
# so the returned scores need not sum to 1.
probe_text = 'This is a course about Python Programming'
probe_labels = ['Computer Education', 'Physics Education', 'Computer Networks', 'Python Snake']
clf(probe_text, candidate_labels=probe_labels, multi_label=True)

{'sequence': 'This is a course about Python Programming',
 'labels': ['Computer Education',
  'Python Snake',
  'Computer Networks',
  'Physics Education'],
 'scores': [0.8667939901351929,
  0.05108256638050079,
  0.00023025953851174563,
  7.492794247809798e-05]}

Loading the training split for the 20-class and 54-class zero-shot setups

In [28]:
# Training split in two label granularities. Columns from index 4 onward are
# used as the candidate label names; the first four columns presumably hold
# id/Conclusion/Stance/Premise fields — TODO confirm against the TSV headers.
data20 = pd.read_csv('data/prepared/train-zshot-20.tsv', delimiter='\t')
data54 = pd.read_csv('data/prepared/train-zshot-54.tsv', delimiter='\t')
labels20 = list(data20.columns[4:])
labels54 = list(data54.columns[4:])

# Explicit UTF-8 so the JSON decodes identically regardless of platform locale.
with open('data/value-categories.json', 'r', encoding='utf-8') as file:
    js = json.load(file)
# For every top-level key, join the keys of its nested mapping with " or ".
vals = {key: ' or '.join(js[key].keys()) for key in js.keys()}

# One input text per row: "<Conclusion> <Stance> <Premise>".
# Built with zip rather than DataFrame.apply(axis=1): on an empty frame apply()
# returns the DataFrame itself, and list() of it would yield column names.
# NOTE(review): only data20's rows are used, so the 54-label predictions assume
# data54 lists the same arguments in the same order — TODO confirm.
inps = [
    f'{conclusion} {stance} {premise}'
    for conclusion, stance, premise in zip(
        data20['Conclusion'], data20['Stance'], data20['Premise']
    )
]

In [29]:
# Zero-shot inference over every training input, once per label set.
# BATCH_SIZE is the number of texts per pipeline call; it is unrelated to the
# 20-class label set. Each element of preds20 / preds54 is the pipeline output
# for one batch, so both lists are nested per batch rather than flat per input.
# The recorded run took roughly 3h10m for 270 batches.
BATCH_SIZE = 20
preds20, preds54 = [], []
for chunk in tqdm(batched(inps, BATCH_SIZE)):
    preds20.append(clf(chunk, candidate_labels=labels20, multi_label=True))
    preds54.append(clf(chunk, candidate_labels=labels54, multi_label=True))


100%|██████████| 270/270 [3:10:52<00:00, 42.42s/it]


In [30]:
# Persist both prediction lists to data.json so the zero-shot inference does
# not have to be re-run; the JSON keys mirror the variable names.
data = dict(preds20=preds20, preds54=preds54)
with open("data.json", mode="w") as out_file:
    json.dump(data, out_file)
