In [1]:
import os, sys
import wandb
from tqdm.notebook import tqdm
import transformers
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from importlib import reload
from torch_scatter import scatter
from transformers import pipeline
from IPython.display import clear_output

sys.path.append('../')

from utils import preprocess as pp
from utils.evaluate import eval_funcs, normalize
from utils.collate import collate_fn 
from utils.ckpt import _reload_best_model
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset

# Load Datasets

In [2]:
# Select which DREAM4 gold-standard task directory to load. Earlier revisions
# of this cell also referenced these sibling directories under the same root:
#   connections_node_id, connections_node_label,
#   shortest_path_node_id, shortest_path_node_label
data_path = (
    '../data/DREAM4_gold_standards/'
    'connections_node_label'
)
dataset = BiologicalDataset(data_path)

# Split indices; the next cell reads the "train", "val" and "test" entries.
idx_split = dataset.get_idx_split()

In [3]:
# Materialise each split as an in-memory list of samples, selected by the
# indices returned from dataset.get_idx_split().
train_dataset = [dataset[i] for i in idx_split["train"]]
val_dataset = [dataset[i] for i in idx_split["val"]]
test_dataset = [dataset[i] for i in idx_split["test"]]

# options
batch_size = 1

# make dataloaders
# Training loader: shuffled, and the trailing partial batch is dropped so
# every step sees a full batch.
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn)

# Evaluation loaders: keep every sample (drop_last=False) and a fixed order
# (shuffle=False). Dropping the last partial batch would silently shrink the
# evaluation set whenever batch_size > 1, and shuffling makes per-sample
# output impossible to compare between runs.
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        drop_last=False,
                        pin_memory=True,
                        shuffle=False,
                        collate_fn=collate_fn)

test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         drop_last=False,
                         pin_memory=True,
                         shuffle=False,
                         collate_fn=collate_fn)

  graph = torch.load(f'{self.path}/graphs/{index}.pt')


# Load Models

In [4]:
# Instantiate the LLM wrapper from utils.llm with Llama-3-8B-Instruct weights.
# TODO: still need to add the remaining constructor args here.
vanilla_llm = LLM(
    max_text_len=512,
    max_max_new_tokens=32,
    max_memory=[80, 80],  # presumably a per-device memory budget -- confirm units
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    llm_frozen='True',  # passed as the string 'True', not a bool
    revision="main",
)

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA...
Finished loading LLaMA...


In [5]:
# Base GraphLLM, built before any checkpoint is loaded into it.
# Any argument not listed here falls back to the default set in the class.
base_graph_llm = GraphLLM(
    max_text_len=512,
    max_max_new_tokens=32,
    max_memory=[80, 80],  # presumably a per-device memory budget -- confirm units
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    llm_frozen='True',  # passed as the string 'True', not a bool
    revision="main",
)

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA!
Finished loading LLaMA!


In [6]:
# Checkpoint file to restore.
path = (
    '../checkpoints/test_run/'
    'epoch_5_best.pth'
)

# Restore the checkpoint into the base GraphLLM.
# NOTE(review): if _reload_best_model loads the weights in place and returns
# its argument, base_graph_llm and trained_graph_llm are the same object and
# no untrained copy remains -- confirm against utils.ckpt.
trained_graph_llm = _reload_best_model(
    base_graph_llm,
    path,
)

Loading checkpoint from ../checkpoints/test_run/epoch_5_best.pth.


  checkpoint = torch.load(path, map_location="cpu")


# Evaluate Node Prediction Accuracy

In [7]:
# options
verbose = True
model = trained_graph_llm
# NOTE(review): this scores the *training* split. Point `loader` at
# val_loader / test_loader to report held-out accuracy.
loader = train_loader

# set to eval
# Use the tokenizer's pad token for generation.
model.model.generation_config.pad_token_id = model.tokenizer.pad_token_id
model.eval()

n_correct = 0
n_total = 0  # samples scored; differs from len(loader) when batch_size > 1
# loop through dataloader
for batch in tqdm(loader):
    out = model.inference(batch)

    pred = out['pred']
    actual = out['label']

    # test accuracy
    for p, a in zip(pred, actual):
        p = normalize(p)
        a = normalize(a).strip()
        n_total += 1

        # Whole-word containment: pad both strings with spaces so the label
        # matches at the start, middle or end of the prediction (including a
        # prediction that is exactly the label) but never inside a longer
        # word such as "casino" for the label "no".
        is_correct = f" {a} " in f" {p} "

        if verbose:
            print(p)
            print(a)
            print()

        if is_correct:
            n_correct += 1

        if verbose:
            print("Correct!" if is_correct else "Incorrect :(")
            print()

# Divide by the number of scored samples, not the number of batches, and
# avoid a ZeroDivisionError on an empty loader.
acc = n_correct / n_total if n_total else 0.0
print(f"Accuracy: {acc:.2%} | {n_correct}/{n_total}")


  0%|          | 0/31 [00:00<?, ?it/s]

yes there is edge between nodes g7 and g4 edge is directed from g7 to g4 indicating that g7 is associated with
yes 

Correct!

yes there is edge between g10 and g3 as indicated by association g3 is associated with g10endoftext
yes 

Correct!

yes there is edge between nodes g1 and g5 edge is associated with string g1 is associated with g5 this edge
yes 

Correct!

yes there is edge between nodes g1 and g10 edge is result of associations between nodes which can be seen in graph
no 

Incorrect :(

yes there is edge between nodes g2 and g8 edge is formed by association g2 is associated with g8
yes 

Correct!

yes there is edge between g8 and g10 edge is associated with edge g3 is associated with g10 and edge
no 

Incorrect :(

yes there is edge between nodes g4 and g9 this edge is associated with string g4 is associated with g10 which implies
no 

Incorrect :(

yes there is edge between g9 and g3 as mentioned in last line of text g3 is associated with g10
no 

Incorrect :(

according to g