In [5]:
import os, sys, re
import wandb
from tqdm.notebook import tqdm
import transformers
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset, Subset
from importlib import reload
from torch_scatter import scatter
from transformers import pipeline
from IPython.display import clear_output
from accelerate import Accelerator

sys.path.append('../')

from utils import preprocess as pp
from utils.evaluate import eval_funcs, normalize
from utils.collate import collate_fn 
from utils.ckpt import _reload_best_model
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset
from utils.seed import seed_everything

In [8]:
# -------
# OPTIONS
# -------
verbose = False
seed = 42
seed_everything(seed)
accelerator = Accelerator

batch_size = 8
data_path = '../data/subgraphs/all'
model_path = '../checkpoints/graph_llm_fsdp/' # REPLACE WITH BEST MODEL PATH
eval_path = '../logs/eval/graph_llm_fsdp/'

In [9]:
# --------------------
# DATASET / DATALOADER
# --------------------
dataset = BiologicalDataset(data_path)
idx_split = dataset.get_idx_split()

# split datasets on idx
test_dataset = Subset(dataset, idx_split['test'])

loader = DataLoader(test_dataset, 
                    batch_size=batch_size,
                    drop_last=True,
                    pin_memory=True,
                    shuffle=True,
                    collate_fn=collate_fn)

In [12]:
# ----------
# LOAD MODEL
# ----------
base = GraphLLM(max_text_len=256,
                max_max_new_tokens=512,
                max_memory=[80, 80],
                llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
                llm_frozen=True,
                fsdp=False,
                revision="main") # args are defaulted in the class
base.load_state_dict(torch.load(model_path))

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model is frozen


  base.load_state_dict(torch.load(model_path))


IsADirectoryError: [Errno 21] Is a directory: '../checkpoints/graph_llm_fsdp/'

In [8]:
batch = next(iter(loader))
model.model.generation_config.pad_token_id = model.tokenizer.pad_token_id
model.eval()
with torch.no_grad():
    out = model.inference(batch)
print(out)

{'id': [77047, 84994, 71125, 49244, 52217, 52341, 7132, 25087], 'pred': [' <think>Looking at the graph, I see that there is no edge directly between nodes 304 and 743. However, I notice that there is an', ' <think>Looking at the graph, I can see that there is a node labeled "454" and another node labeled "400". I can also see that', " <think>Let's examine the graph. We can see that there are multiple edges between nodes, but we need to focus on the specific question. We're", ' <think>Looking at the graph, I notice that there is no direct edge between nodes 141 and 40. However, there is a path that connects', 'Think: We can look at the graph and see if there is an edge between nodes 209 and 411. The graph shows us that there is an edge', 'thinkThere is an edge between nodes 152 and 433 in the graph, as it is labeled as a directed edge. The edge is pointing from node ', ' <think>Looking at the graph, I can see that there is no edge directly between nodes 116 and 374. However, I can see t

In [10]:
len(test_dataset)

18062

In [None]:
# --------
# EVALUATE
# --------
# set to eval


n_correct = 0
i = 0
pbar = tqdm(total=len(test_dataset))

# loop through dataloader
with torch.no_grad():
    with pbar:
        for batch in loader:
            out = model.inference(batch)
            pred = out['pred']
            actual = out['label']

            # test accuracy
            for p, a in zip(pred, actual):
                p_ans, p_think = normalize(p)
                a = str(a)
                if verbose:
                    print(p_think)
                    print(p_ans)
                    print(a)
                    print()
                if a in p_ans:
                    n_correct += 1
                    if verbose:
                        print("Correct!")
                        print()
                else:
                    if verbose:
                        print("Incorrect :(")
                        print()
                i += 1
                pbar.update(1)
            print(f"Accuracy: {n_correct/i:.2%} | {n_correct}/{i}", end='\r')
acc = n_correct / i
print(f"Accuracy: {acc:.2%} | {n_correct}/{i}")

  0%|          | 0/18062 [00:00<?, ?it/s]

Accuracy: 0.00% | 0/56

KeyboardInterrupt: 

: 

In [None]:
batch = next(iter(loader))

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)