In [8]:
import json
import os, sys
from importlib import reload

import pandas as pd
import torch
import torch.nn as nn
import tqdm.notebook as tqdm
import transformers
import wandb
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from torch_scatter import scatter
from transformers import pipeline

sys.path.append('../')

from utils import preprocess as pp
from utils.evaluate import eval_funcs
from utils.collate import collate_fn 
from utils.ckpt import _reload_best_model
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset

# Load Datasets

In [9]:
# Load one DREAM4 gold-standard graph dataset and its train/val/test index split.
# Sibling datasets under the same root that were previously loaded here:
#   connections_node_id, shortest_path_node_id, shortest_path_node_label
DREAM4_ROOT = '../data/DREAM4_gold_standards'
data_path = f'{DREAM4_ROOT}/connections_node_label'

dataset = BiologicalDataset(data_path)
idx_split = dataset.get_idx_split()
idx_split = dataset.get_idx_split()

In [10]:
# Materialise each split as a list of examples, using the indices returned by
# `dataset.get_idx_split()`.
train_dataset = [dataset[i] for i in idx_split["train"]]
val_dataset = [dataset[i] for i in idx_split["val"]]
test_dataset = [dataset[i] for i in idx_split["test"]]

# options
batch_size = 1


def make_loader(split, batch_size, shuffle, drop_last):
    """Build a DataLoader over `split` with the project's `collate_fn`.

    Args:
        split: sequence of examples (one of the lists built above).
        batch_size: number of examples per batch.
        shuffle: reshuffle every epoch; only wanted for training.
        drop_last: drop a trailing partial batch; only acceptable for training,
            since for evaluation it silently skips examples.
    """
    return DataLoader(split,
                      batch_size=batch_size,
                      drop_last=drop_last,
                      pin_memory=True,
                      shuffle=shuffle,
                      collate_fn=collate_fn)


train_loader = make_loader(train_dataset, batch_size, shuffle=True, drop_last=True)
# Evaluation loaders must see every example, in a deterministic order, so that
# metrics and the written predictions are complete and reproducible.
val_loader = make_loader(val_dataset, batch_size, shuffle=False, drop_last=False)
test_loader = make_loader(test_dataset, batch_size, shuffle=False, drop_last=False)

  graph = torch.load(f'{self.path}/graphs/{index}.pt')


# Load Models

In [11]:
# Graph-free LLM baseline, built from an explicit kwargs dict.
# NOTE(review): `llm_frozen` is the *string* 'True'; presumably `utils.llm.LLM`
# compares it as a string (the cell output prints "Freezing LLaMA...") — confirm
# before switching to a bool. `max_memory` presumably lists a per-GPU memory
# budget — TODO confirm units.
vanilla_llm_kwargs = dict(
    max_text_len=512,
    max_max_new_tokens=32,
    max_memory=[80, 80],
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    llm_frozen='True',
    revision="main",
)
vanilla_llm = LLM(**vanilla_llm_kwargs)  # TODO: expose these settings as configurable args

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA...
Finished loading LLaMA...


In [12]:
# Base GraphLLM (same LLM backbone settings as the baseline). Any GraphLLM
# options not listed here fall back to the class defaults.
# NOTE(review): `llm_frozen` is passed as the string 'True' — presumably compared
# as a string inside GraphLLM; confirm before changing to a bool.
base_graph_llm_kwargs = dict(
    max_text_len=512,
    max_max_new_tokens=32,
    max_memory=[80, 80],
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    llm_frozen='True',
    revision="main",
)
base_graph_llm = GraphLLM(**base_graph_llm_kwargs)

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA!
Finished loading LLaMA!


In [13]:
# Restore the best checkpoint of the `test_run` experiment into the base GraphLLM.
# NOTE(review): `_reload_best_model` may load the weights into `base_graph_llm`
# in place and return that same object — confirm before treating the two names
# as independent models.
CKPT_DIR = '../checkpoints/test_run'
path = f'{CKPT_DIR}/epoch_5_best.pth'

trained_graph_llm = _reload_best_model(base_graph_llm, path)

Loading checkpoint from ../checkpoints/test_run/epoch_5_best.pth.


  checkpoint = torch.load(path, map_location="cpu")


# Evaluate

In [14]:
# Run the restored GraphLLM over the test split, write one JSON record per
# example to a JSONL file, then score that file with the project evaluator.
#
# Fixes vs. the previous version of this cell: it used the undefined names
# `model` and `args`, and called the `tqdm.notebook` *module* as if it were the
# progress-bar class. Most importantly, it opened `path` (the checkpoint
# loaded in the previous cell) in "w" mode, truncating `epoch_5_best.pth`.
# Predictions now go to a separate file next to the checkpoint.

# Key into `utils.evaluate.eval_funcs` selecting the scorer for this dataset.
# TODO confirm: the original used an undefined `args.dataset`.
EVAL_DATASET = 'connections_node_label'
preds_path = os.path.join(os.path.dirname(path), 'test_predictions.jsonl')

# Fail fast, before the (slow) inference loop, if the scorer key is wrong.
if EVAL_DATASET not in eval_funcs:
    raise KeyError(f'No evaluator {EVAL_DATASET!r} in eval_funcs; '
                   f'available: {sorted(eval_funcs)}')

trained_graph_llm.eval()
with open(preds_path, "w") as f, torch.no_grad():
    # The imports cell binds `tqdm` to the `tqdm.notebook` module; the
    # progress-bar class is `tqdm.tqdm`.
    for batch in tqdm.tqdm(test_loader, total=len(test_loader), desc='test'):
        output = trained_graph_llm.inference(batch)
        # presumably `output` maps field name -> per-example list (as in the
        # sample printed below); one DataFrame row == one example.
        for _, row in pd.DataFrame(output).iterrows():
            f.write(json.dumps(dict(row)) + "\n")

# Post-processing & evaluation
acc = eval_funcs[EVAL_DATASET](preds_path)
print(f'Test Acc {acc}')
# wandb.log errors when no run is active, and this notebook never calls wandb.init.
if wandb.run is not None:
    wandb.log({'Test Acc': acc})

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


{'id': [28],
 'pred': ["Yes, there is an edge between G8 and G10. The edge is associated with the association 'G3 is associated with G10'. The node G"],
 'label': ['no'],
 'question': ['Is there an edge between nodes G8 and G10?'],
 'desc': ["['G1 is associated with G2', 'G1 is associated with G3', 'G1 is associated with G4', 'G1 is associated with G5', 'G2 is associated with G6', 'G2 is associated with G8', 'G3 is associated with G4', 'G3 is associated with G7', 'G3 is associated with G10', 'G4 is associated with G7', 'G4 is associated with G10', 'G6 is associated with G8', 'G9 is associated with G10', 'node_id 0 is G1', 'node_id 1 is G10', 'node_id 2 is G2', 'node_id 3 is G3', 'node_id 4 is G4', 'node_id 5 is G5', 'node_id 6 is G6', 'node_id 7 is G7', 'node_id 8 is G8', 'node_id 9 is G9', 'layer 0 is from coexpression-heart']"]}