In [1]:
import os
import json

import torch
import tqdm

from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import load_dataset


In [2]:
data = load_dataset("csv", data_files="../pleaseee.csv")
subset = data['train']  # TODO changer en test dans le ds original
print(subset)

Generating train split: 1578 examples [00:00, 12632.05 examples/s]


Dataset({
    features: ['id', 'topic_id', 'statement_medical', 'statement_pol', 'premise', 'NCT_title', 'NCT_id', 'label'],
    num_rows: 1578
})


In [3]:
model_path = "/lustre/fswork/projects/rech/hjp/ulj12fo/flan-t5-xxl" #"/lustre/fsn1/projects/rech/hjp/ulj12fo/flan-t5-base" 
tokenizer = T5Tokenizer.from_pretrained(model_path)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
model = T5ForConditionalGeneration.from_pretrained(model_path, device_map="auto")

Loading checkpoint shards: 100%|##########| 5/5 [01:08<00:00, 13.67s/it]


In [5]:
def get_input_text(premise, hypothesis):
    options_prefix = "OPTIONS:\n- "
    separator = "\n- "
    options_ = options_prefix + f"{separator}".join(["Entailment","Contradiction"])
    return f"{premise} \n Question: Imagine that you are a doctor reviewing patients profiles to enroll them for a clinical trial. Does the previous eligibility criteria imply that the following patient can participate to the trial?\n Patient profile:\n {hypothesis}? {options_}"
    # return f"Classification: {premise} \n Question: Does this imply that {hypothesis}? Entailment or Contradiction?Answer:"

In [6]:
samples = []
# TODO adapt in function of the med or POL (l.5)
for instance in subset:
    premise = instance['premise']
    sentence = f"Eligibility criteria of the trial are:\n {premise}"
    input_text = get_input_text(sentence, instance['statement_pol'])
    # temp = {"text":input_text, "label":sample['label']}
    temp = {"text":input_text, "label":0}
    print(input_text)
    samples.append(temp)

Eligibility criteria of the trial are:
 Inclusion Criteria:

          -  women with PUL

        Exclusion Criteria:
Female
Accepts Healthy Volunteers

 
 Question: Imagine that you are a doctor reviewing patients profiles to enroll them for a clinical trial. Does the previous eligibility criteria imply that the following patient can participate to the trial?
 Patient profile:
 I just turned 32 and last morning I woke up with strange blood stains on my underwear. My last periods were more than 2 months ago, which is unusual for me because I used to have regular periods lasting for 6 days every 29 days, more or less. I had several UTIs in the past. I also had appendicitis. I'm currently seeing several men and, to be honest, some of them do struggle to wear a condom. I went to the hospital to check myself up and they told me that my vitals were normal. I also had a blood test on Monday, and my β-hCG level was 1800 mIU/mL, and then on Wednesday, it went up to 2100 mIU/mL. The gynecologis

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [7]:
labels = []
pred = []
with torch.inference_mode():
    for sample in tqdm.tqdm(samples):
        labels.append(sample["label"])
        input_ids = tokenizer(sample["text"], return_tensors="pt",).input_ids.to("cuda")
        outputs = model.generate(input_ids)
        pred.append(tokenizer.decode(outputs[0]))

  0%|          | 3/1578 [00:03<25:49,  1.02it/s]  Token indices sequence length is longer than the specified maximum sequence length for this model (918 > 512). Running this sequence through the model will result in indexing errors
100%|##########| 1578/1578 [14:39<00:00,  1.80it/s]


In [8]:
pred

['<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<p

In [9]:
pred = [p[5:][:-4].strip() for p in pred]

In [10]:
set(pred)
from collections import Counter
Counter(pred)

Counter({'Entailment': 1209, 'Contradiction': 369})

In [11]:
prediction_dict = {}
for _id,pred_x in zip(data['train']['id'], pred):
    prediction_dict[str(_id)] = {"Prediction":pred_x}

In [12]:
# from sklearn.metrics import f1_score
# uuid_list = list(prediction_dict.keys())
# results_pred = []
# gold_labels = []
# for i in range(len(uuid_list)):
#     if prediction_dict[uuid_list[i]]["Prediction"] in ["Entailment", "Yes"]:
#         results_pred.append(1)
#     else:
#         results_pred.append(0)
#     if data[uuid_list[i]]["Label"] in ["Entailment", "No"]:
#         gold_labels.append(1)
#     else:
#         gold_labels.append(0)
# f1_score(gold_labels,results_pred)

In [13]:
json.dump(prediction_dict, open("results_flan_t5_xxl_zs_pol_persona.json",'w'),indent=4)

In [14]:
prediction_dict

{'621': {'Prediction': 'Entailment'},
 '5088': {'Prediction': 'Entailment'},
 '5395': {'Prediction': 'Entailment'},
 '517': {'Prediction': 'Contradiction'},
 '4068': {'Prediction': 'Entailment'},
 '3892': {'Prediction': 'Entailment'},
 '1019': {'Prediction': 'Entailment'},
 '1958': {'Prediction': 'Entailment'},
 '4179': {'Prediction': 'Entailment'},
 '6731': {'Prediction': 'Entailment'},
 '2244': {'Prediction': 'Entailment'},
 '6758': {'Prediction': 'Entailment'},
 '5297': {'Prediction': 'Contradiction'},
 '5077': {'Prediction': 'Contradiction'},
 '4665': {'Prediction': 'Contradiction'},
 '2398': {'Prediction': 'Entailment'},
 '652': {'Prediction': 'Entailment'},
 '5070': {'Prediction': 'Entailment'},
 '2127': {'Prediction': 'Entailment'},
 '3857': {'Prediction': 'Entailment'},
 '218': {'Prediction': 'Entailment'},
 '5976': {'Prediction': 'Contradiction'},
 '355': {'Prediction': 'Entailment'},
 '3775': {'Prediction': 'Entailment'},
 '3021': {'Prediction': 'Contradiction'},
 '538': {'Pr