In [1]:
import os, sys
import wandb
from tqdm.notebook import tqdm
import transformers
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset, Subset
from importlib import reload
from torch_scatter import scatter
from transformers import pipeline
from IPython.display import clear_output

sys.path.append('../')

from utils import preprocess as pp
from utils.evaluate import eval_funcs, normalize
from utils.collate import collate_fn 
from utils.ckpt import _reload_best_model
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset

# Load Datasets

In [2]:
# Load the combined subgraph dataset and its predefined index split.
# The earlier per-task DREAM4 gold-standard datasets (connections / shortest_path,
# keyed by node id or node label under '../data/DREAM4_gold_standards/') are no
# longer loaded in this notebook.
data_path = "../data/subgraphs/all"
dataset = BiologicalDataset(data_path)

# `idx_split` is indexed by split name below (only the 'test' entry is used here).
idx_split = dataset.get_idx_split()

In [3]:
# Restrict the dataset to the held-out test indices.
test_dataset = Subset(dataset, idx_split['test'])

# options
batch_size = 8

# Build the evaluation loader.
# Shuffling is disabled on purpose: accuracy does not depend on example order,
# and a fixed order makes per-batch outputs and interrupted (partial) runs
# reproducible across kernel restarts.
test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         collate_fn=collate_fn)

# Load Models

In [4]:
# Base GraphLLM built on the Llama-3-8B-Instruct checkpoint with the LLM frozen.
# Per the original note, arguments not passed here fall back to the class defaults.
# NOTE(review): `max_max_new_tokens` looks like a typo for `max_new_tokens`;
# confirm against GraphLLM's constructor before renaming, since the class is
# not visible from this notebook.
graph_llm_kwargs = dict(
    max_text_len=512,
    max_max_new_tokens=32,
    max_memory=[80, 80],
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    llm_frozen=True,
    revision="main",
)
base_graph_llm = GraphLLM(**graph_llm_kwargs)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model is frozen


In [5]:
# Relative path of the checkpoint to evaluate.
path = '../checkpoints/graph_llm_fsdp/epoch_1_best.pth'

# Load the checkpoint weights into the base model.
# NOTE(review): presumably `_reload_best_model` updates `base_graph_llm` in place
# and returns it, in which case `trained_graph_llm` and `base_graph_llm` refer to
# the same object -- confirm in utils.ckpt before comparing the two.
trained_graph_llm = _reload_best_model(base_graph_llm, path)

Loading checkpoint from ../checkpoints/graph_llm_fsdp/epoch_1_best.pth.


  checkpoint = torch.load(path, map_location="cpu")


[2025-02-15 11:32:55,122] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/pkr/miniconda3/envs/rag/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/pkr/miniconda3/envs/rag/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlvsym'
/home/pkr/miniconda3/envs/rag/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlopen'
/home/pkr/miniconda3/envs/rag/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlclose'
/home/pkr/miniconda3/envs/rag/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlerror'
/home/pkr/miniconda3/envs/rag/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlsym'
collect2: error: ld returned 1 exit status


# Evaluate Node Prediction Accuracy

In [7]:
# options
verbose = True          # print every prediction / label pair while scoring
model = trained_graph_llm
loader = test_loader

# Use the tokenizer's pad token for generation, then switch the model to eval mode.
model.model.generation_config.pad_token_id = model.tokenizer.pad_token_id
model.eval()

n_correct = 0
i = 0  # number of examples scored so far

# Gradients are never needed while scoring, so disable autograd tracking for the
# whole loop to avoid building graphs and holding activations in memory.
with torch.no_grad():
    for batch in tqdm(loader):
        out = model.inference(batch)

        pred = out['pred']
        actual = out['label']

        # Score each prediction in the batch against its label.
        for p, a in zip(pred, actual):
            p_ans, p_think = normalize(p)
            a = str(a)
            if verbose:
                print(p_think)
                print(p_ans)
                print(a)
                print()

            # NOTE(review): this is a containment test; if `p_ans` is a string it is
            # a substring match, so a label such as "no" also matches inside longer
            # words. Confirm that `normalize` returns a bare answer token.
            if a in p_ans:
                n_correct += 1
                if verbose:
                    print("Correct!")
                    print()
            else:
                if verbose:
                    print("Incorrect :(")
                    print()
            i += 1

        # Running accuracy; skipped while nothing has been scored yet so an empty
        # batch cannot trigger a division by zero.
        if i:
            print(f"Accuracy: {n_correct/i:.2%} | {n_correct}/{i}", end='\r')

# Final accuracy; defined as 0.0 when the loader yielded no examples.
acc = n_correct / i if i else 0.0
print(f"Accuracy: {acc:.2%} | {n_correct}/{i}")


  0%|          | 0/2258 [00:00<?, ?it/s]



yes

Incorrect :(



yes

Incorrect :(



no

Incorrect :(



no

Incorrect :(



no

Incorrect :(



no

Incorrect :(



no

Incorrect :(



no

Incorrect :(

Accuracy: 0.00% | 0/8

yes

Incorrect :(



yes

Incorrect :(



no

Incorrect :(



no

Incorrect :(



no

Incorrect :(



no

Incorrect :(



no

Incorrect :(



no

Incorrect :(

Accuracy: 0.00% | 0/16

yes

Incorrect :(



no

Incorrect :(



no

Incorrect :(



yes

Incorrect :(



yes

Incorrect :(



yes

Incorrect :(



yes

Incorrect :(



yes

Incorrect :(

Accuracy: 0.00% | 0/24

KeyboardInterrupt: 

In [12]:
# Debug cell: re-run inference on `batch`, the loop variable left over from the
# evaluation cell, to inspect the raw output dict. It fails on a fresh kernel
# unless that loop has run for at least one iteration.
# NOTE(review): the recorded output shows every `pred` entry echoing the
# instruction prompt instead of an answer -- check how `inference` separates the
# prompt from the generated text.
model.inference(batch)

{'id': [14364, 26652, 53831, 35767, 52295, 42935, 67547, 87806],
 'pred': ['You will be given a biological graph and a question. Provide an answer of YES or NO based on the question and the given input graph. Explain your reasoning,',
  'You will be given a biological graph and a question. Provide an answer of YES of NO based on the question and the given input graph. Explain your reasoning,',
  'You will be given a biological graph and a question. Provide an answer of YES of NO based on the question and the given input graph. Explain your reasoning,',
  'You will be given a biological graph and a question. Provide an answer of YES or NO based on the question and the given input graph. Explain your reasoning,',
  'You will be given a biological graph and a question. Provide an answer of YES of NO based on the question and the given input graph. Explain your reasoning,',
  'You will be given a biological graph and a question. Provide an answer of YES or NO based on the question and the 