In [14]:
import os, sys
import wandb
import tqdm.notebook as tqdm
import transformers
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from importlib import reload
from torch_scatter import scatter
from transformers import pipeline

sys.path.append('../')

from utils import preprocess as pp
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset

# training imports
from utils.evaluate import eval_funcs, normalize
from utils.config import parse_args_llama
from utils.ckpt import _save_checkpoint, _reload_best_model
from utils.collate import collate_fn
from utils.seed import seed_everything
from utils.lr_schedule import adjust_learning_rate

In [2]:
# get dataset
# Path (relative to this notebook) of the DREAM4 gold-standard data directory
# that is handed to BiologicalDataset.
data_path = '../data/DREAM4_gold_standards/connections_node_label'
dataset = BiologicalDataset(data_path)
# Mapping with "train" / "val" / "test" entries; each entry is iterated below
# to pick the dataset items that belong to that split.
idx_split = dataset.get_idx_split()

In [28]:
# Materialize each split as a plain list of samples, using the index lists
# provided by the dataset's own train/val/test split.
train_dataset = [dataset[i] for i in idx_split["train"]]
val_dataset = [dataset[i] for i in idx_split["val"]]
test_dataset = [dataset[i] for i in idx_split["test"]]

# options
batch_size = 12

# Keyword arguments shared by all three loaders.
loader_kwargs = dict(batch_size=batch_size,
                     pin_memory=True,
                     collate_fn=collate_fn)

# Training: reshuffle on every pass and drop the trailing partial batch so
# every batch has exactly `batch_size` samples.
train_loader = DataLoader(train_dataset,
                          drop_last=True,
                          shuffle=True,
                          **loader_kwargs)

# Evaluation: keep every sample (drop_last=True would silently discard up to
# batch_size - 1 examples, or the whole split when it is smaller than one
# batch) and keep a fixed order so metrics are reproducible across runs.
val_loader = DataLoader(val_dataset,
                        drop_last=False,
                        shuffle=False,
                        **loader_kwargs)

test_loader = DataLoader(test_dataset,
                         drop_last=False,
                         shuffle=False,
                         **loader_kwargs)

  graph = torch.load(f'{self.path}/graphs/{index}.pt')


In [26]:
# Build the LLaMA wrapper used for inference below.
# TODO: take these settings from parsed args instead of hard-coding them.
model = LLM(
    max_text_len=512,
    max_max_new_tokens=32,
    max_memory=[80, 80],  # presumably a per-device memory budget; verify against LLM
    # Alternative checkpoint tried previously: 'meta-llama/Llama-2-7B-chat-hf'
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    # NOTE(review): passed as the string 'True', not a bool; confirm LLM expects a string.
    llm_frozen='True',
    revision="main",
)

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA...
Finished loading LLaMA...


In [29]:
# make prediction
# Take the first batch from a fresh iterator over the training loader. The
# loader is built with shuffle=True, so which samples land in this batch
# changes from run to run.
batch = next(iter(train_loader))
# Run the model's inference routine on that single batch; the scoring cell
# below reads out['pred'] and out['label'] from the result.
out = model.inference(batch)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [30]:
    # Score one inference batch: a prediction is counted as correct when the
    # normalized gold label occurs in the normalized prediction as a whole
    # word (or whole run of words), and the batch accuracy is printed.
    pred = out['pred']
    actual = out['label']

    n_correct = 0
    for p, a in zip(pred, actual):
        p = normalize(p)
        a = normalize(a)
        print(p)
        print(a)
        print()
        # Collapse whitespace and pad both sides with a single space so the
        # label only matches on word boundaries. This counts a prediction
        # that is exactly the label, or ends with it, and rejects labels
        # embedded inside longer words (e.g. "no" inside "casino").
        padded_pred = f" {' '.join(p.split())} "
        padded_label = f" {' '.join(a.split())} "
        if padded_label in padded_pred:
            n_correct += 1
            print("Correct!")
            print()
    # Guard against an empty prediction list instead of dividing by zero.
    acc = n_correct / len(pred) if len(pred) else 0.0
    print(f"Accuracy: {acc:.2%}")

yes according to graph there is edge between g8 and g10 edge is induced by edge g2 is associated with g
no 

yes according to graph there is edge between g7 and g4 edge is associated with layer coexpressionheart
yes 

Correct!

yes there is edge between nodes g9 and g5 edge is associated with relation g1 is associated with g5
no 

yes there is edge between nodes g5 and g1 because g1 is associated with g5
yes 

Correct!

yes there is edge between nodes g3 and g1 as indicated by statement g1 is associated with g3 this is und
yes 

Correct!

yes there is edge between g1 and g5 which is also between g1 and g4 and g4 and g5 so
no 

there are no edges between g5 and g9 edges are defined in following format nodeid x is gy where x is
no 

Correct!

there is edge between nodes g8 and g3 edge is result of association g2 is associated with g8 and g
no 

yes there is edge between nodes g6 and g3 edge is associated with relationship g3 is associated with g6
no 

yes there is edge between nodes g3 a