In [1]:
import os, sys
import wandb
from tqdm.notebook import tqdm
import transformers
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from importlib import reload
from torch_scatter import scatter
from transformers import pipeline
from IPython.display import clear_output

sys.path.append('../')

from utils import preprocess as pp
from utils.evaluate import eval_funcs, normalize
from utils.collate import collate_fn 
from utils.ckpt import _reload_best_model
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset

# Load Datasets

In [None]:
# Dataset selection.
# Variants available under DATA_ROOT (formerly kept as commented-out code):
#   connections_node_id, connections_node_label,
#   shortest_path_node_id, shortest_path_node_label
DATA_ROOT = '../data/DREAM4_gold_standards'
DATASET_NAME = 'connections_node_label'

data_path = f'{DATA_ROOT}/{DATASET_NAME}'
dataset = BiologicalDataset(data_path)
# Index split provided by the dataset itself; presumably a mapping with
# 'train' / 'val' / 'test' index lists — confirm against BiologicalDataset.
idx_split = dataset.get_idx_split()

In [None]:
# split datasets on idx
train_dataset = [dataset[i] for i in idx_split["train"]]
val_dataset = [dataset[i] for i in idx_split["val"]]
test_dataset = [dataset[i] for i in idx_split["test"]]

# options
batch_size = 1


def _make_loader(ds, batch_size, train):
    """Build a DataLoader over `ds` with the notebook's shared settings.

    Args:
        ds: sequence of samples consumable by `collate_fn`.
        batch_size: number of samples per batch.
        train: if True, shuffle and drop the last incomplete batch (training
            behaviour). If False, keep a fixed order and every sample, so
            evaluation metrics are reproducible and no examples are silently
            skipped when batch_size > 1.

    Returns:
        torch.utils.data.DataLoader
    """
    return DataLoader(ds,
                      batch_size=batch_size,
                      drop_last=train,
                      pin_memory=True,
                      shuffle=train,
                      collate_fn=collate_fn)


# make dataloaders
train_loader = _make_loader(train_dataset, batch_size, train=True)
val_loader = _make_loader(val_dataset, batch_size, train=False)
test_loader = _make_loader(test_dataset, batch_size, train=False)

  graph = torch.load(f'{self.path}/graphs/{index}.pt')


# Load Models

In [None]:
# Plain (graph-free) LLM baseline built from the same base checkpoint.
# TODO: expose the remaining LLM constructor arguments here.
# NOTE(review): the recorded output of this cell is a CUDA OOM on a GPU with
# 39.5 GiB total, while max_memory=[80, 80] requests more than that — confirm
# how LLM uses max_memory. llm_frozen is the string 'True', not a bool;
# presumably LLM compares it as a string — verify.
vanilla_llm_kwargs = dict(
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    revision="main",
    llm_frozen='True',
    max_memory=[80, 80],
    max_text_len=512,
    max_max_new_tokens=32,
)
vanilla_llm = LLM(**vanilla_llm_kwargs)

Loading LLaMA...


Downloading shards:   0%|          | 0/30 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 448.00 MiB. GPU 0 has a total capacity of 39.50 GiB of which 142.25 MiB is free. Process 533927 has 448.00 MiB memory in use. Process 2084245 has 2.91 GiB memory in use. Including non-PyTorch memory, this process has 35.99 GiB memory in use. Of the allocated memory 35.58 GiB is allocated by PyTorch, and 1.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Base GraphLLM; any constructor argument not listed here uses the class default.
# NOTE(review): llm_frozen is the string 'True' (not a bool), and
# max_memory=[80, 80] is larger than the 39.5 GiB GPU reported in this
# notebook's recorded OOM error — confirm both against GraphLLM.
graph_llm_kwargs = dict(
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    revision="main",
    llm_frozen='True',
    max_memory=[80, 80],
    max_text_len=512,
    max_max_new_tokens=32,
)
base_graph_llm = GraphLLM(**graph_llm_kwargs)

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA!
Finished loading LLaMA!


In [None]:
# Checkpoint to restore (the filename suggests the best model from epoch 5 of
# the `test_run` training run).
ckpt_dir = '../checkpoints/test_run'
path = f'{ckpt_dir}/epoch_5_best.pth'

# Restore the trained weights on top of the base GraphLLM.
# NOTE(review): _reload_best_model may load into base_graph_llm in place and
# return it, making trained_graph_llm and base_graph_llm the same object — confirm.
trained_graph_llm = _reload_best_model(base_graph_llm, path)

Loading checkpoint from ../checkpoints/test_run/epoch_5_best.pth.


  checkpoint = torch.load(path, map_location="cpu")


# Evaluate Node Prediction Accuracy

In [None]:
# options
verbose = True
model = trained_graph_llm
# NOTE(review): this scores the *training* split; use test_loader (or
# val_loader) to report held-out accuracy.
loader = train_loader


def _answer_matches(pred, label):
    """Return True if the normalized label appears as whole word(s) in the normalized prediction.

    Both sides are whitespace-collapsed and padded with single spaces, so the
    label matches only on word boundaries: 'yes' is found when it is the last
    word of the prediction, but 'no' is not found inside 'piano'.
    An empty label never matches.
    """
    label_norm = ' '.join(normalize(label).split())
    if not label_norm:
        return False
    pred_norm = ' '.join(normalize(pred).split())
    return f' {label_norm} ' in f' {pred_norm} '


def evaluate_accuracy(model, loader, verbose=False):
    """Run `model.inference` over `loader` and count correct answers per example.

    Args:
        model: object exposing `inference(batch)` that returns a dict with
            parallel 'pred' and 'label' sequences of strings.
        loader: iterable of batches accepted by `model.inference`.
        verbose: if True, print each normalized prediction/label and the verdict.

    Returns:
        (n_correct, n_total): counts over individual examples, not batches.
    """
    n_correct = 0
    n_total = 0
    for batch in tqdm(loader):
        # inference() may already disable autograd; this is cheap insurance.
        with torch.no_grad():
            out = model.inference(batch)

        for pred, label in zip(out['pred'], out['label']):
            is_correct = _answer_matches(pred, label)
            n_total += 1
            n_correct += int(is_correct)
            if verbose:
                print(normalize(pred))
                print(normalize(label))
                print()
                print("Correct!" if is_correct else "Incorrect :(")
                print()
    return n_correct, n_total


# set to eval
model.model.generation_config.pad_token_id = model.tokenizer.pad_token_id
model.eval()

n_correct, n_total = evaluate_accuracy(model, loader, verbose=verbose)
# Divide by examples (not len(loader), which counts batches); guard empty loaders.
acc = n_correct / n_total if n_total else float('nan')
print(f"Accuracy: {acc:.2%} | {n_correct}/{n_total}")


  0%|          | 0/31 [00:00<?, ?it/s]

yes there is edge between nodes g4 and g3 edge is directed from g4 to g3 edge is part of following
yes 

Correct!

yes there is edge between g10 and g2 as g2 is associated with g10 graph shows that g2 is associated with g
no 

Incorrect :(

yes there is edge between nodes g5 and g1 edge is labeled g1 is associated with g5 node g5 is
yes 

Correct!

yes there is edge between nodes g3 and g1 edge is directed from g3 to g1 meaning that g3 is associated with
yes 

Correct!

there is edge between nodes g1 and g10 edge is bidirectional because there are two edges g1 is associated with g5
no 

Incorrect :(

yes there is edge between g7 and g4 graph shows that g7 is associated with g10 and g4 is associated with g
yes 

Correct!

yes there is edge between g3 and g1 and since g1 is associated with g5 there is also edge between g3 and
no 

Incorrect :(

yes there is edge between nodes g8 and g9 edge is inferred from association g9 is associated with g10 and
no 

Incorrect :(

yes there is edge be