In [1]:
%pip install torch
%pip install transformers

Collecting torch
  Downloading torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl.metadata (25 kB)
Collecting filelock (from torch)
  Using cached filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)
Collecting sympy (from torch)
  Using cached sympy-1.12-py3-none-any.whl (5.7 MB)
Collecting networkx (from torch)
  Using cached networkx-3.2.1-py3-none-any.whl.metadata (5.2 kB)
Collecting jinja2 (from torch)
  Downloading Jinja2-3.1.2-py3-none-any.whl (133 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.1/133.1 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting fsspec (from torch)
  Downloading fsspec-2023.12.2-py3-none-any.whl.metadata (6.8 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Downloading MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl.metadata (3.0 kB)
Collecting mpmath>=0.19 (from sympy->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)
Downloading torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl (147.0 MB)
[

In [4]:
%pip install sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-macosx_10_9_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99
Note: you may need to restart the kernel to use updated packages.


In [1]:
import os
import json
import torch
import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")

tokenizer.json: 100%|██████████| 2.42M/2.42M [00:01<00:00, 1.67MB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
model.safetensors.index.json: 100%|██████████| 53.0k/53.0k [00:00<00:00, 271kB/s]
model-00001-of-00002.safetensors: 100%|██████████| 9.45G/9.45G [43:37<00:00, 3.61MB/s]
model-00002-of-00002.safetensors: 100%|██████████| 1.95G/1.95G [09:20<00:00, 3.48MB/s]
Downloading shards: 100%|██████████| 2/2 [53:01<00:00, 1590.84s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:35<00:00, 17.87s/it]


In [4]:
# function to add to JSON
def write_json(id, new_data, filename='./flan-t5.json'):
    with open(filename,'r+') as file:
        # First we load existing data into a dict.
        file_data = json.load(file)
        # Join new_data with file_data inside emp_details
        file_data[id] = new_data
        # Sets file's current position at offset.
        file.seek(0)
        # convert back to json.
        json.dump(file_data, file, indent = 4)

In [5]:
def get_input_text(premise, hypothesis):
    separator = "\n- "
    # options_ = options_prefix + f"{separator}".join(["])
    return f"{premise} \n Question: Does this imply that the hypothesis '{hypothesis}' is correct or not?"

In [6]:
KEYS = json.load(open('./gemini-labels-train.json')).keys()

CONCLUSIONS = json.load(open('./gemini-conclusion.json'))

data = json.load(open('./training_data/train.json')) 
data_expanded = []
for _id, value in data.items():
    temp = {}
    temp["id"] = _id
    temp["statement"] = value["Statement"]
    temp["label"] = value["Label"]
    temp["conclusion"] = CONCLUSIONS[_id]
    data_expanded.append(temp)

In [7]:
labels = []
pred = []
with torch.inference_mode():
    for sample in tqdm.tqdm(data_expanded):
        labels.append(sample["label"])
        conclusion = get_input_text(sample["conclusion"], sample['statement'])
        input_ids = tokenizer(conclusion, return_tensors="pt",).input_ids
        outputs = model.generate(input_ids)
        pred.append(tokenizer.decode(outputs[0]))

 97%|█████████▋| 1650/1700 [1:18:15<02:03,  2.47s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (801 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 1700/1700 [1:21:03<00:00,  2.86s/it]


In [9]:
preds = [p[5:][:-4].strip() for p in pred]
set(preds)

{'it is not possible to tell', 'no', 'yes'}

In [10]:
from collections import Counter
Counter(preds)

Counter({'no': 1062, 'yes': 563, 'it is not possible to tell': 75})

In [11]:
prediction_dict = {}
for _id,pred_x in zip(data, pred):
    if pred_x == "it is not possible to tell" or pred_x == "no":
        predi = "Contradiction"
    else:
        predi = "Entailment"
    prediction_dict[str(_id)] = {"Prediction":predi}

In [12]:
data = json.load(open('./training_data/train.json'))

from sklearn.metrics import f1_score
uuid_list = list(prediction_dict.keys())
results_pred = []
gold_labels = []
for i in range(len(uuid_list)):
    if prediction_dict[uuid_list[i]]["Prediction"] in ["Entailment", "Yes"]:
        results_pred.append(1)
    else:
        results_pred.append(0)
    if data[uuid_list[i]]["Label"] in ["Entailment", "Yes"]:
        gold_labels.append(1)
    else:
        gold_labels.append(0)
f1_score(gold_labels,results_pred)

0.6666666666666666

In [13]:
from sklearn.metrics import classification_report
print(classification_report(gold_labels, results_pred))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       850
           1       0.50      1.00      0.67       850

    accuracy                           0.50      1700
   macro avg       0.25      0.50      0.33      1700
weighted avg       0.25      0.50      0.33      1700



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
