In [1]:
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Directory of the fine-tuned checkpoint; the tokenizer and the weights are
# both read from this same location.
mt_path = "output/full_vicuna_7b_bsz1_epoch3_lr8e-06_13858199"

tok = AutoTokenizer.from_pretrained(mt_path)

# Load the weights, then cast them to half precision and move them to the GPU.
model = AutoModelForCausalLM.from_pretrained(mt_path)
model = model.half().cuda()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
# Input files: the instruction-tuning records and the knowledge-graph table.
train_data_path = "data/kg_instruction_1000.json"
kg_path = "data/umls_kg_filter_count_5.csv"

# Use a context manager so the file handle is closed deterministically
# (the bare `json.load(open(...))` form never closes it explicitly).
with open(train_data_path, "r") as train_file:
    train_dst = json.load(train_file)

# Map every entity mentioned in a record (input side followed by output side)
# to the triplet entry stored at the same position of that record's triplet
# lists. The evaluation cell iterates each value as a collection of row
# indices into `kg`. If an entity occurs in several records, the last
# occurrence wins, exactly as with a plain dict comprehension.
et = {
    entity: triplet_ids
    for record in train_dst
    for entity, triplet_ids in zip(
        record['input_entities'] + record['output_entities'],
        record['input_triplets'] + record['output_triplets'],
    )
}

# One dict per CSV row; the evaluation cell reads the 'source', 'edge' and
# 'target' fields and indexes this list by row position.
kg = pd.read_csv(kg_path).to_dict(orient='records')

dash_token = "[DASH]"

In [3]:
from tqdm.auto import tqdm

# Settings that do not depend on the entity or triplet being evaluated are
# configured once here instead of being re-assigned for every triplet.
bsz = 16  # not read anywhere in this cell; kept in case other cells rely on it
max_new_tokens = 10
tok.padding_side = 'left'
tok.pad_token = tok.eos_token
tok.pad_token_id = tok.eos_token_id

# Per-entity accuracy: the fraction of an entity's triplets for which the model
# recovers the missing end of the triplet in at least one of the two directions.
all_acc = []
for e, t in tqdm(et.items(), total=len(et)):
    # Guard against entities without triplets: dividing by len(t) below would
    # otherwise raise ZeroDivisionError and abort the whole evaluation run.
    if len(t) == 0:
        print(f"{e} has no triplets, skipped")
        continue

    e_acc = 0
    for tid in t:
        tri = kg[tid]
        # Query both directions: "source edge" -> target and "target edge" -> source.
        prompts = [f"{tri['source']} {tri['edge']}", f"{tri['target']} {tri['edge']}"]
        targets = [tri['target'], tri['source']]
        flag = False
        for prompt, target in zip(prompts, targets):
            print(f"{tid} PROMPT:{prompt}")
            print(f"{tid} GT:{target}")
            inp = tok(prompt, return_tensors="pt")
            inp_len = inp['input_ids'].shape[1]
            output = model.generate(
                inp['input_ids'].to(model.device),
                attention_mask=inp['attention_mask'].to(model.device),
                max_new_tokens=max_new_tokens,
            )
            # Decode only the newly generated tokens (everything after the prompt).
            pred = tok.decode(output[0, inp_len:], skip_special_tokens=True).strip()
            print(f"{tid} PRED:{pred}")
            # A direction succeeds when the non-empty prediction is a
            # case-insensitive substring of the ground truth; generation is
            # capped at `max_new_tokens`, so predictions may be truncated.
            if pred != "" and pred.lower() in target.lower():
                flag = True
        print(f"{tid} SUCCESS? {flag}\n")
        e_acc += int(flag)
    e_acc /= len(t)
    print(f"{e} acc:{e_acc}")
    all_acc.append(e_acc)

# Avoid ZeroDivisionError when no entity could be evaluated.
if all_acc:
    print(f"avg acc:{sum(all_acc)/len(all_acc)}")
else:
    print("avg acc: n/a (no entities evaluated)")

  0%|          | 0/1337 [00:00<?, ?it/s]

30912 PROMPT:Malignant melanoma of lower limb or hip NOS (disorder)  [DASH] possibly equivalent to
30912 GT:Malignant melanoma of skin of hip (disorder)




30912 PRED:Malignant melanoma of lower limb
30912 PROMPT:Malignant melanoma of skin of hip (disorder)  [DASH] possibly equivalent to
30912 GT:Malignant melanoma of lower limb or hip NOS (disorder)
30912 PRED:Malignant melanoma of skin of th
30912 SUCCESS? False

30913 PROMPT:Malignant melanoma of lower limb or hip NOS (disorder)  [DASH] possibly equivalent to
30913 GT:Malignant melanoma of lower limb (disorder)
30913 PRED:Malignant melanoma of lower limb
30913 PROMPT:Malignant melanoma of lower limb (disorder)  [DASH] possibly equivalent to
30913 GT:Malignant melanoma of lower limb or hip NOS (disorder)
30913 PRED:Malignant melanoma of skin of lower
30913 SUCCESS? True

30914 PROMPT:Malignant melanoma of skin of hip (disorder)  [DASH] possibly equivalent to
30914 GT:Malignant melanoma of lower limb or hip NOS (disorder)
30914 PRED:Malignant melanoma of skin of th
30914 PROMPT:Malignant melanoma of lower limb or hip NOS (disorder)  [DASH] possibly equivalent to
30914 GT:Malignant melano

In [None]:
import sys
sys.path.append("/home/cs/yangyuchen/guoyiqiu/gpt_re/")
import os
import time
import pytorch_lightning as pl
import torch
from model import *
import torch.utils.data as tud
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger
from tqdm.notebook import tqdm
from utils.my_utils import *
import torch.nn.functional as F
import random
import regex as re
from dataset import *
import ipywidgets as widgets
from IPython.display import display
from typing import Union, List
# Select PyTorch's 'medium' float32 matmul precision setting.
torch.set_float32_matmul_precision('medium')

# Process-wide environment switches, applied together.
# To pin the run to one GPU, additionally set: os.environ['CUDA_VISIBLE_DEVICES'] = '3'
os.environ.update({
    'TOKENIZERS_PARALLELISM': 'true',
    'CUDA_LAUNCH_BLOCKING': '1',
})

# Selectable checkpoints as (display label, checkpoint path) pairs, in the
# order they should appear in the model dropdown. Written as an ordered
# label -> path mapping and flattened into the list of tuples used downstream.
model_list = list({
    "gpt2": "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8",
    "gpt2-xl": "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2-xl/snapshots/33cdb5c0db5423c1879b1b9f16c352988e8754a8",
    "gpt2-medium": "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2-medium/snapshots/425b0cc90498ac177aa51ba07be26fc2fea6af9d",
    "llama_7b": "/nvme/share/guoyiqiu/llama-7b",
    "llama_13b": "/nvme/share/guoyiqiu/llama-13b",
    "vicuna_7b": "/home/cs/yangyuchen/yushengliao/Medical_LLM/vicuna-7b",
    "vicuna_13b": "/mnt/workspace/guoyiqiu/coding/vicuna-13b-v1.1",
    "book_7b": "/mnt/workspace/guoyiqiu/coding/Book_7B/checkpoint-4968",
    "book_13b": "/home/cs/yangyuchen/yushengliao/Medical_LLM/FastChat/checkpoints/medical_llama_13b_chatv1.3/checkpoint-4974/",
    "book_13b_kg": "/home/cs/yangyuchen/guoyiqiu/kg_llm/output/full_book_13b_bsz1_epoch3_lr1e-05",
    "vicuna_7b_kg": "/home/cs/yangyuchen/guoyiqiu/kg_llm/output/full_vicuna_7b_bsz2_epoch3_lr1e-05",
}.items())


def setup_widgets(model_list):
    """Build and display the interactive model-loading / text-generation panel.

    Args:
        model_list: Options for the model dropdown, handed unchanged to
            ``widgets.Dropdown(options=...)``. The selected value is passed
            as ``model_name`` to ``LLM.from_pretrained``.

    Side effects:
        Publishes every widget as a module-level global (``mt_dropdown``,
        ``setup_btn``, ``device_tbtn``, ``precision_tbtn``, ``mnt_slider``,
        ``input_textarea``, ``output_textarea``, ``submit_btn``,
        ``chat_checkbox``, ``sample_checkbox``) and displays the assembled
        panel. The "Setup everything" button additionally rebinds the
        globals ``mt``, ``model`` and ``tok``.
    """
    global mt_dropdown
    global setup_btn
    global device_tbtn
    global precision_tbtn
    global mnt_slider
    global input_textarea
    global output_textarea
    global submit_btn
    global chat_checkbox
    global sample_checkbox
    global model
    global tok
    global mt

    def _loaded_llm():
        """Return the module-level ``mt`` wrapper, or None if it does not exist yet."""
        # A bare `mt` reference raises NameError before any model was set up.
        return globals().get("mt")

    def setup_llm(btn):
        """Load the model chosen in the dropdown and publish it as mt/model/tok."""
        global mt
        global model
        global tok
        time_st = time.time()
        previous_label = btn.description
        btn.description = "Loading model..."
        try:
            mt = LLM.from_pretrained(model_name=mt_dropdown.value, fp16=(precision_tbtn.value == "half"),)
        except Exception:
            # Do not leave the button stuck on "Loading model..." after a failure.
            btn.description = previous_label
            raise
        btn.description = "Everything is ready."
        # Reset the device toggle; if it was on another value this fires
        # switch_device for the freshly loaded wrapper.
        device_tbtn.value = 'cpu'
        model = mt.model
        tok = mt.tokenizer
        print(f"Time cost: {time.time() - time_st:.2f}s")

    def switch_device(change):
        """Move the loaded wrapper to the device selected on the toggle."""
        llm = _loaded_llm()
        if llm is None:
            # Nothing loaded yet, so there is nothing to move.
            return
        device_tbtn.disabled = True
        try:
            llm.to(change.new)
            if change.new == 'cpu':
                torch.cuda.empty_cache()
        finally:
            # Always re-enable the toggle, even if the move failed.
            device_tbtn.disabled = False

    def switch_precision(change):
        """Cast the loaded model to the precision selected on the toggle."""
        precision_tbtn.disabled = True
        try:
            llm = _loaded_llm()
            if llm is not None:
                llm.model = llm.model.half() if change.new == 'half' else llm.model.float()
        finally:
            precision_tbtn.disabled = False

    def generate(btn):
        """Run generation on the input text and show the result in the output area."""
        CHAT_TEMPLATE = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n##USER:\n{}\n\n##ASSISTANT:\n"
        btn.disabled = True
        submit_btn.description = "Generating..."
        try:
            input_text = CHAT_TEMPLATE.format(input_textarea.value) if chat_checkbox.value else input_textarea.value
            gen_kwargs = {
                "input_texts": input_text,
                "max_new_tokens": mnt_slider.value,
                "do_sample": sample_checkbox.value,
            }
            result = mt.generate(**gen_kwargs)
            # In chat mode, remove the prompt text from the first result so
            # only the reply is shown.
            output_text = result[0].replace(input_text, "") if chat_checkbox.value else result[0]
            output_textarea.value = output_text
        finally:
            # Restore the button even when generation raises; otherwise it
            # would stay disabled for the rest of the session.
            btn.disabled = False
            submit_btn.description = "generate"

    # model dropdown
    mt_dropdown = widgets.Dropdown(options=model_list, description='Model:', disabled=False,)

    # setup button
    setup_btn = widgets.Button(description="Setup everything", disabled=False,)
    setup_btn.on_click(setup_llm)

    # switch device
    device_tbtn = widgets.ToggleButtons(options=['cpu', 'cuda',], disabled=False,)
    device_tbtn.observe(switch_device, names='value')

    # switch precision
    precision_tbtn = widgets.ToggleButtons(options=['float', 'half'], disabled=False,)
    precision_tbtn.observe(switch_precision, names='value')

    # max new token slider
    mnt_slider = widgets.IntSlider(value=64,min=1,max=512,step=1,description='new token:',disabled=False,)

    # sample checkbox
    sample_checkbox = widgets.Checkbox(value=False,description='do sample',disabled=False,)

    # input and output textarea
    input_textarea = widgets.Textarea(value='',description='Input:',layout=widgets.Layout(width='30%', height='250px'),disabled=False)
    output_textarea = widgets.Textarea(value='',description='Output:',layout=widgets.Layout(width='30%', height='250px'),disabled=False)

    # submit button
    submit_btn = widgets.Button(description="generate",disabled=False,)
    submit_btn.on_click(generate)

    # chat mode checkbox
    chat_checkbox = widgets.Checkbox(value=False,description='chat mode',disabled=False,)

    # panel layout
    control_panel = widgets.HBox([mt_dropdown, setup_btn, precision_tbtn, device_tbtn])
    generate_panel = widgets.HBox([input_textarea, widgets.VBox([mnt_slider, sample_checkbox, chat_checkbox, submit_btn]), output_textarea])
    all_panel = widgets.VBox([control_panel, generate_panel])
    display(all_panel)

# Render the control and generation panels for the configured checkpoints.
setup_widgets(model_list)

# Wrap an existing model/tokenizer pair so the panel can be used without
# pressing "Setup everything" first.
# NOTE(review): `model` and `tok` are not assigned in this cell; presumably
# they are the objects loaded in the first cell -- confirm before running this
# cell on a fresh kernel.
mt = LLM.from_mt(model, tok)