## gen 200 sample

In [1]:
import os
import torch
import pandas as pd
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "./results/lr5e-6/checkpoint-270"

model = AutoModelForCausalLM.from_pretrained(model_name, use_cache = False, device_map = "cuda:0", dtype = torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.42s/it]


In [3]:
gen_ds = load_dataset("json", data_files = "data/dpo_resource.json")["train"]
output_path = os.path.join("data", f"dpo_dataset_generated.csv")

In [4]:
results = []

In [5]:
for idx in tqdm(range(gen_ds.num_rows)):
    ith_inference = {"subject_id" : gen_ds[idx]["subject_id"]}
    ith_inference["text"] = gen_ds[idx]["messages"][1]["content"]

    for i in range(5):
        input_ids = tokenizer.apply_chat_template(
                        gen_ds[idx]["messages"],
                        add_generation_prompt=True,
                        return_tensors="pt"
        ).to(model.device)

        terminators = [tokenizer.eos_token_id]

        outputs = model.generate(
            input_ids,
            max_new_tokens=1024,
            eos_token_id=terminators,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=1.0,
            top_p = 0.95
        )

        response = outputs[0][input_ids.shape[-1]:]
        generation = tokenizer.decode(response, skip_special_tokens=True)
        ith_inference[f"Gen_{i}"] = generation

    results.append(ith_inference)

  0%|          | 0/200 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 200/200 [1:45:02<00:00, 31.51s/it]


In [7]:
pd.DataFrame(results).to_csv("data/for_dpo_5_gen.csv")

In [8]:
pd.DataFrame(results)


Unnamed: 0,subject_id,text,Gen_0,Gen_1,Gen_2,Gen_3,Gen_4
0,14711614,\r\nName: ___ Unit No: ...,A female patient presented with vertigo; on ph...,A female patient presented with vertigo; on ph...,A female patient presented with vertigo; on ph...,A female patient presented with vertigo; on ph...,A female patient presented with vertigo; on ph...
1,17622322,\r\nName: ___ Unit No: __...,"A male patient presented with ""R knee pain""; o...",A male patient presented with Right knee pain;...,A male patient presented with right knee pain;...,A male patient presented with right knee pain;...,A male patient presented with a complaint of r...
2,10063762,\r\nName: ___ Unit No: ___\...,A female patient presented with dizziness; on ...,A 45-year-old woman presented with dizziness; ...,A female patient presented with dizziness; on ...,A female patient presented with dizziness; on ...,A 40-year-old woman presented with a 3-day his...
3,15732392,\r\nName: ___ Unit No: _...,A male patient presented with hematemesis and ...,A male patient presented with hematemesis and ...,A male patient presented with hematemesis and ...,A male patient presented with chief complaints...,A male patient presented with hematemesis and ...
4,10172264,\r\nName: ___ Unit No: ___...,A 36-year-old female presented with right calf...,A female patient presented with right calf pai...,A female patient presented with a chief compla...,A female patient presented with acute right ca...,A 35-year-old female presents with right calf ...
...,...,...,...,...,...,...,...
195,18399891,\r\nName: ___ Unit No:...,A male patient presented with left cerebellar ...,A male patient presented with a large left cer...,A male patient presented with cerebellar hemor...,A male patient presented with left cerebellar ...,A 83-year-old male presented with a cerebellar...
196,11902324,\r\nName: ___ Unit No: ___\...,"A male patient presented with abdominal pain, ...","A male patient presented with abdominal pain, ...","A male patient presented with epigastric pain,...","A male patient presented with abdominal pain, ...","A male patient presented with abdominal pain, ..."
197,13826876,\r\nName: ___ Unit No:...,A male patient presented with headache; on phy...,A male patient presented with headache; on phy...,A male patient presented with a headache; on p...,A male patient presented with headache; on phy...,A male patient presented with headache; on phy...
198,13901886,\r\nName: ___ Unit No: ...,A female patient presented with shortness of b...,A female patient presented with shortness of b...,A female patient presented with shortness of b...,A female patient presented with shortness of b...,A female patient presented with shortness of b...
