In [None]:
# Notebook bootstrap: make the repository's `src/` directory importable,
# initialise project environment variables, quiet noisy loggers, and set up
# Hydra so later cells can `compose` configs.
import os
import sys

# `_dh` is IPython's directory-history list (entry 0 is presumably the
# directory the kernel started in). NOTE(review): `_dh` only exists inside
# IPython/Jupyter — this cell will fail under a plain Python interpreter.
current_folder = str(globals()['_dh'][0])
# Repo root = everything before the first 'src/' in the path, i.e. this assumes
# the notebook lives somewhere under <root>/src/. If 'src/' is not in the path,
# split() returns the whole path unchanged.
root_path = current_folder.split('src/')[0]
os.chdir(root_path)
# When 'src/' was found, root_path ends with '/', so this adds '<root>/src'
# to the import path (needed for the `utils.*` / `covid_llm.*` imports below).
sys.path.append(root_path + 'src')

from utils.basics import init_env_variables, print_important_cfg, time_logger
from tqdm import tqdm
from math import ceil

# Deliberately called before the remaining imports; presumably some of them
# read the variables it sets — TODO confirm before reordering this cell.
init_env_variables()

from utils.pkg.distributed import initialize_deepspeed, initialize_distributed
from utils.project.exp import init_experiment
import logging
import hydra
import numpy as np
import pandas as pd

# Reduce log noise from HuggingFace transformers (warnings only; tokenizer
# utilities errors only).
logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
# Turn off HuggingFace `tokenizers` internal parallelism.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# NOTE(review): DeepSpeedAgent, Agent, print_important_cfg, time_logger, tqdm,
# ceil, pd and calc_prediction_class_distribution are not used in the cells
# shown here.
from covid_llm.agent import DeepSpeedAgent, Agent
from covid_llm.instruction_dataset import InstructionDataset, load_sft_dataset
from covid_llm.model import CovidLLM
from utils.data.covid_data import CovidData
import torch as th

from covid_llm.metrics import calc_prediction_class_distribution

from hydra import compose, initialize
from omegaconf import OmegaConf

# Register the config directory with Hydra's global state.
# NOTE(review): Hydra resolves a relative config_path against the caller's
# location — confirm it points at <root>/configs. Re-running this cell in the
# same kernel is expected to fail because Hydra is already initialised
# (clear GlobalHydra first if you need to re-run).
initialize(config_path=f'../../configs', version_base=None)

In [None]:
# Compose the experiment config (t1 target, variant prompt, DeepSpeed and W&B
# disabled), build the data and model objects, then restore weights from a
# previously trained checkpoint.
cfg = compose(config_name="main", overrides=['seed=2023', 'splits_type=sta_aug_splits', 'target=t1', 
                                               'total_steps=1301', 'use_cont_fields=True',
                                               'use_deepspeed=False', 'use_trends=False',
                                               'use_wandb=False','wandb.name=zero_shot_t1_simple_prompt_decrease',
                                               'data_file=processed_v5_3_BA1.pkl', 'use_variant_prompt=True',
                                               'eval_freq=1300', 'save_model=True'])
print(OmegaConf.to_yaml(cfg))
cfg, logger = init_experiment(cfg)
# Use bf16 only if it was requested AND the GPU supports it. Check CUDA
# availability first: on some torch versions is_bf16_supported() queries the
# current CUDA device and raises on a machine without CUDA.
cfg.use_bf16 = th.cuda.is_available() and th.cuda.is_bf16_supported() and cfg.use_bf16
initialize_deepspeed(cfg)
data = CovidData(cfg=cfg)
model = CovidLLM(cfg, data, logger)

# NOTE(review): absolute, user-specific path to an earlier run's checkpoint —
# point this at your own output directory.
model_path = ('/home/hy235/zy/llm/CovidLLM4_5/output/None/'
              'CovidLLM/nmwnlvqy-t1-sta_aug_splits-True-False-val_mse/checkpoints/final_model/pytorch_model.pt')

# Load onto CPU so restoring neither depends on nor allocates memory on the GPU
# the checkpoint was saved from; load_state_dict copies values into the model's
# existing parameters wherever they live.
pretrained_dict = th.load(model_path, map_location='cpu')
model_dict = model.state_dict()
# Keep only checkpoint entries whose names exist in the current model, and make
# the dropped keys visible instead of discarding them silently. If *nothing*
# matches (e.g. a 'module.' prefix mismatch) fail loudly rather than silently
# evaluating an un-restored model.
matched_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
ignored_keys = sorted(k for k in pretrained_dict if k not in model_dict)
assert matched_dict, f'No checkpoint keys match the model state_dict (checkpoint: {model_path})'
print(f'Restoring {len(matched_dict)}/{len(model_dict)} model tensors from checkpoint; '
      f'ignoring {len(ignored_keys)} unmatched checkpoint keys: {ignored_keys[:10]}')
model_dict.update(matched_dict)
model.load_state_dict(model_dict)

In [None]:
# Reload the tokenizer saved next to the checkpoint. `from_pretrained` RETURNS a
# new tokenizer instead of modifying the one it is called on, so the result must
# be assigned back; otherwise `model.tokenizer` is left unchanged.
# The directory is the one containing `model_path` (.../checkpoints/final_model),
# so the checkpoint location is defined in a single place.
model.tokenizer = model.tokenizer.from_pretrained(os.path.dirname(model_path))
model.tokenizer

In [None]:
# Set up the distributed context, then build a test-split loader over the
# instruction dataset rendered with the variant prompt.
initialize_distributed(cfg, logger)

# World-size x per-GPU micro batch. NOTE(review): not used in the cells shown;
# the loader below takes cfg.inf_batch_size instead.
batch_size = cfg.world_size * cfg.ds['train_micro_batch_size_per_gpu']

variant_dataset = InstructionDataset(
    data, cfg, cfg.mode,
    use_variant_prompt=True,
)
variant_ids = data.variant_splits

# load_sft_dataset returns three values; only the middle one is kept and is
# iterated in the next cell.
_, variant_iter, _ = load_sft_dataset(
    cfg,
    split='test',
    full_dataset=variant_dataset,
    split_ids=variant_ids,
    batch_size=cfg.inf_batch_size,
    world_size=cfg.world_size,
    rank=cfg.local_rank,
)

In [None]:
# Debug check: run one forward pass of the restored model on the first batch of
# the variant test split, on CPU, and print inputs and outputs.

# NOTE(review): device index carried over from the original run. The model is
# moved to CPU below anyway; the guard avoids an "invalid device ordinal" error
# on machines with fewer GPUs (where the unguarded call would have raised).
GPU_ID = 4
if th.cuda.is_available() and th.cuda.device_count() > GPU_ID:
    th.cuda.set_device(GPU_ID)
model.init_rank(cfg)
model.device = 'cpu'
model.to(model.device)

# Take just the first batch; an empty loader is reported instead of silently
# doing nothing.
first_batch = next(iter(variant_iter), None)
assert first_batch is not None, 'variant_iter yielded no batches'
node_ids, prompt_tree_lol, conversation_list = first_batch
# NOTE(review): node ids are cast to float32 before forward — confirm this is
# what CovidLLM.forward expects.
batch = np.array(node_ids).astype(np.float32), prompt_tree_lol, conversation_list
print(node_ids, prompt_tree_lol, conversation_list)
# Inference only: don't build an autograd graph.
with th.no_grad():
    results = model.forward(batch)
print(results)