In [1]:
import os
import json

import torch
import tqdm

from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import load_dataset


In [2]:
data = load_dataset("csv", data_files="../test_split.csv")
subset = data['train']  # TODO changer en test dans le ds original
print(subset)

Generating train split: 1578 examples [00:00, 4211.91 examples/s]


Dataset({
    features: ['id', 'topic_id', 'statement_medical', 'statement_pol', 'premise', 'NCT_title', 'NCT_id', 'label'],
    num_rows: 1578
})


In [4]:
model_path = "/lustre/fswork/projects/rech/hjp/ulj12fo/flan-t5-xxl" #"/lustre/fsn1/projects/rech/hjp/ulj12fo/flan-t5-base" 
tokenizer = T5Tokenizer.from_pretrained(model_path)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
model = T5ForConditionalGeneration.from_pretrained(model_path, device_map="auto")

Loading checkpoint shards: 100%|##########| 5/5 [02:04<00:00, 24.86s/it]


In [6]:
def get_input_text(premise, hypothesis):
    options_prefix = "OPTIONS:\n- "
    separator = "\n- "
    options_ = options_prefix + f"{separator}".join(["Entailment","Contradiction"])
    return f"{premise} \n Question: Does the previous eligibility criteria imply that the following patient can participate to the trial?\n {hypothesis}? {options_}"
    # return f"Classification: {premise} \n Question: Does this imply that {hypothesis}? Entailment or Contradiction?Answer:"

In [7]:
samples = []
# TODO adapt in function of the med or POL (l.5)
for instance in subset:
    premise = instance['premise']
    sentence = f"Eligibility criteria of the trial are:\n {premise}"
    input_text = get_input_text(sentence, instance['statement_medical'])
    # temp = {"text":input_text, "label":sample['label']}
    temp = {"text":input_text, "label":0}
    print(input_text)
    samples.append(temp)

Eligibility criteria of the trial are:
 Inclusion Criteria:

          -  women with PUL

        Exclusion Criteria:
Female
Accepts Healthy Volunteers

 
 Question: Does the previous eligibility criteria imply that the following patient can participate to the trial?
 A 32-year-old woman comes to the hospital with vaginal spotting.  Her last menstrual period was 10 weeks ago. She has regular menses lasting for 6 days and repeating every 29 days. Medical history is significant for appendectomy and several complicated UTIs. She has multiple male partners, and she is inconsistent with using barrier contraceptives. Vital signs are normal.  Serum β-hCG level is 1800 mIU/mL, and a repeat level after 2 days shows an abnormal rise to 2100 mIU/mL.  Pelvic ultrasound reveals a thin endometrium with no gestational sac in the uterus.? OPTIONS:
- Entailment
- Contradiction
Eligibility criteria of the trial are:
 Inclusion Criteria:

          -  Postmenopausal women and men referred for bone densit

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [8]:
labels = []
pred = []
with torch.inference_mode():
    for sample in tqdm.tqdm(samples):
        labels.append(sample["label"])
        input_ids = tokenizer(sample["text"], return_tensors="pt",).input_ids.to("cuda")
        outputs = model.generate(input_ids)
        pred.append(tokenizer.decode(outputs[0]))

  0%|          | 3/1578 [00:24<2:58:38,  6.81s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (850 > 512). Running this sequence through the model will result in indexing errors
100%|##########| 1578/1578 [28:40<00:00,  1.09s/it] 


In [9]:
pred

['<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Ent

In [10]:
pred = [p[5:][:-4].strip() for p in pred]

In [17]:
set(pred)
from collections import Counter
Counter(pred)

Counter({'Entailment': 982, 'Contradiction': 596})

In [19]:
prediction_dict = {}
for _id,pred_x in zip(data['train']['id'], pred):
    prediction_dict[str(_id)] = {"Prediction":pred_x}

In [None]:
# from sklearn.metrics import f1_score
# uuid_list = list(prediction_dict.keys())
# results_pred = []
# gold_labels = []
# for i in range(len(uuid_list)):
#     if prediction_dict[uuid_list[i]]["Prediction"] in ["Entailment", "Yes"]:
#         results_pred.append(1)
#     else:
#         results_pred.append(0)
#     if data[uuid_list[i]]["Label"] in ["Entailment", "No"]:
#         gold_labels.append(1)
#     else:
#         gold_labels.append(0)
# f1_score(gold_labels,results_pred)

In [20]:
json.dump(prediction_dict, open("results_flan_t5_xxl_zs_med.json",'w'),indent=4)

In [16]:
prediction_dict

{'{\'id\': 621, \'topic_id\': 2, \'statement_medical\': \'A 32-year-old woman comes to the hospital with vaginal spotting.  Her last menstrual period was 10 weeks ago. She has regular menses lasting for 6 days and repeating every 29 days. Medical history is significant for appendectomy and several complicated UTIs. She has multiple male partners, and she is inconsistent with using barrier contraceptives. Vital signs are normal.  Serum β-hCG level is 1800 mIU/mL, and a repeat level after 2 days shows an abnormal rise to 2100 mIU/mL.  Pelvic ultrasound reveals a thin endometrium with no gestational sac in the uterus.\', \'statement_pol\': "I just turned 32 and last morning I woke up with strange blood stains on my underwear. My last periods were more than 2 months ago, which is unusual for me because I used to have regular periods lasting for 6 days every 29 days, more or less. I had several UTIs in the past. I also had appendicitis. I\'m currently seeing several men and, to be honest, s