## gen 200 sample

In [1]:
import os
import torch
import pandas as pd
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "./results/lr5e-6/checkpoint-450"

model = AutoModelForCausalLM.from_pretrained(model_name, use_cache = False, device_map = "cuda:0", dtype = torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

Loading checkpoint shards: 100%|██████████| 4/4 [00:18<00:00,  4.60s/it]


In [3]:
gen_ds = load_dataset("json", data_files = "data/dpo_resource.json")["train"]
output_path = os.path.join("data", f"dpo_dataset_generated.csv")

In [4]:
results = []

In [5]:
idx = 0

ith_inference = {"subject_id" : gen_ds[idx]["subject_id"]}
ith_inference["text"] = gen_ds[idx]["messages"][1]["content"]

for i in range(5):
    input_ids = tokenizer.apply_chat_template(
                    gen_ds[idx]["messages"],
                    add_generation_prompt=True,
                    return_tensors="pt"
    ).to(model.device)

    terminators = [tokenizer.eos_token_id]

    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=1.0,
        top_p = 0.95
    )

    response = outputs[0][input_ids.shape[-1]:]
    generation = tokenizer.decode(response, skip_special_tokens=True)
    ith_inference[f"Gen_{i}"] = generation

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [6]:
ith_inference


{'subject_id': 14711614,
 'text': " \r\nName:  ___                    Unit No:   ___\r\n \r\nAdmission Date:  ___              Discharge Date:   ___\r\n \r\nDate of Birth:  ___             Sex:   F\r\n \r\nService: OBSTETRICS/GYNECOLOGY\r\n \r\nAllergies: \r\nReglan / Compazine / Viramune\r\n \r\nAttending: ___\r\n \r\nChief Complaint:\r\nVeritgo\r\n \r\nMajor Surgical or Invasive Procedure:\r\nlumbar puncture\r\n\r\n \r\nHistory of Present Illness:\r\n___ yo G3P1 at 19w3d WGA presents to gyn triage with persistent \r\nvertigo x 2 days. Worse with changes in position and upon \r\nlooking down, but also present when closes eyes and head is at \r\nrest. Had nausea, but no emesis. Denies tinnitus. Hearing is \r\nWNL. Denies headache. Denies visual changes. Denies weakness / \r\nsensory changes / falls. Of note, HIV meds changed ___. Has had \r\nsimilar sx prior to med change, resulting in a fall. However, \r\nthose sx resolved\r\nspontaneously and are now recurring.  Denies fevers / chill