In [2]:
import os, sys
import wandb
import tqdm
import transformers
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from importlib import reload

from transformers import pipeline

sys.path.append('../')

from utils import preprocess as pp
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset

# training imports
# from utils.model import load_model, llama_model_path
from utils.evaluate import eval_funcs
from utils.config import parse_args_llama
from utils.ckpt import _save_checkpoint, _reload_best_model
from utils.collate import collate_fn
from utils.seed import seed_everything
from utils.lr_schedule import adjust_learning_rate


# Step 1: Load Networks as `Multiplex` object

In [3]:
flist_name = '../data/DREAM4_gold_standards/mono_flist.tsv'
mp = Multiplex(flist_name)

# Step 2: Textualize Graphs (Ken's Code)

In [4]:
textualize = load_textualizer['all']
graph_text = textualize(mp)

# view first 10 items
for i in range(len(graph_text)):
    print(graph_text[i])

G1 is associated with G2
G1 is associated with G3
G1 is associated with G4
G1 is associated with G5
G2 is associated with G6
G2 is associated with G8
G3 is associated with G4
G3 is associated with G7
G3 is associated with G10
G4 is associated with G7
G4 is associated with G10
G6 is associated with G8
G9 is associated with G10
node_id 0 is G1
node_id 1 is G10
node_id 2 is G2
node_id 3 is G3
node_id 4 is G4
node_id 5 is G5
node_id 6 is G6
node_id 7 is G7
node_id 8 is G8
node_id 9 is G9
layer 0 is from coexpression-heart


# Step 3: Make Dataloader
* dataloader returns dict with keys `["ids"]`, `["question"]`, `["label"]`, `["desc"]`, `["graph"]`
* ask ken if our dataset setup is the same

In [5]:
# get dataset
data_path = '../data/DREAM4_gold_standards/connections_node_id'
dataset = BiologicalDataset(data_path)
idx_split = dataset.get_idx_split()

In [6]:
# split datasets on idx
train_dataset = [dataset[i] for i in idx_split["train"]]
val_dataset = [dataset[i] for i in idx_split["val"]]
test_dataset = [dataset[i] for i in idx_split["test"]]

# options
batch_size = 1

# make dataloaders
train_loader = DataLoader(train_dataset, 
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn)

val_loader = DataLoader(val_dataset, 
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn)

test_loader = DataLoader(test_dataset, 
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn)

  graph = torch.load(f'{self.path}/graphs/{index}.pt')


# Step 4: Load In Encoder + LLM

In [8]:
vanilla_llm = LLM(max_text_len=512,
                  max_max_new_tokens=32,
                  max_memory=[80, 80],
                  llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
                #   llm_model_path='meta-llama/Llama-2-7B-chat-hf',
                  llm_frozen='True',
                  revision="main") # need to add args

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA...
Finished loading LLaMA...


In [9]:
graph_llm = GraphLLM(max_text_len=512,
                     max_max_new_tokens=32,
                     max_memory=[80, 80],
                     llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
                     llm_frozen='True',
                     revision="main") # args are defaulted in the class

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA!
Finished loading LLaMA!


# Step 5: Perform Initial Untrained Inference

In [10]:
batch = next(iter(train_loader))
print(batch)

{'id': [1], 'question': ['Is there an edge between nodes 4 and 0?'], 'label': ['yes'], 'desc': ["['G1 is associated with G2', 'G1 is associated with G3', 'G1 is associated with G4', 'G1 is associated with G5', 'G2 is associated with G6', 'G2 is associated with G8', 'G3 is associated with G4', 'G3 is associated with G7', 'G3 is associated with G10', 'G4 is associated with G7', 'G4 is associated with G10', 'G6 is associated with G8', 'G9 is associated with G10', 'node_id 0 is G1', 'node_id 1 is G10', 'node_id 2 is G2', 'node_id 3 is G3', 'node_id 4 is G4', 'node_id 5 is G5', 'node_id 6 is G6', 'node_id 7 is G7', 'node_id 8 is G8', 'node_id 9 is G9', 'layer 0 is from coexpression-heart']"], 'graph': DataBatch(x=[1], edge_index=[2, 13], num_nodes=10, batch=[10], ptr=[2])}


In [11]:
vanilla_llm.inference(batch)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


{'id': [1],
 'pred': [" 1\ns>[INST]['yes', 'no'][INST] 1\ns>[INST]['yes', 'yes'][INST] 1"],
 'label': ['yes'],
 'question': ['Is there an edge between nodes 4 and 0?'],
 'desc': ["['G1 is associated with G2', 'G1 is associated with G3', 'G1 is associated with G4', 'G1 is associated with G5', 'G2 is associated with G6', 'G2 is associated with G8', 'G3 is associated with G4', 'G3 is associated with G7', 'G3 is associated with G10', 'G4 is associated with G7', 'G4 is associated with G10', 'G6 is associated with G8', 'G9 is associated with G10', 'node_id 0 is G1', 'node_id 1 is G10', 'node_id 2 is G2', 'node_id 3 is G3', 'node_id 4 is G4', 'node_id 5 is G5', 'node_id 6 is G6', 'node_id 7 is G7', 'node_id 8 is G8', 'node_id 9 is G9', 'layer 0 is from coexpression-heart']"]}

In [12]:
graph_llm.inference(batch)

TypeError: LlamaModel.get_input_embeddings() takes 1 positional argument but 2 were given

In [10]:
print(out['pred'], out['label'])

['s>\nYes, there is a path between nodes G9 and G5. The path is: G1 -> G2 -> G5 or G1 ->'] ['yes']


# Step 6: Train Model

In [2]:
# specify needed args
args = parse_args_llama()

usage: ipykernel_launcher.py [-h] [--model_name MODEL_NAME]
                             [--project PROJECT] [--seed SEED]
                             [--dataset DATASET] [--lr LR] [--wd WD]
                             [--patience PATIENCE] [--batch_size BATCH_SIZE]
                             [--grad_steps GRAD_STEPS]
                             [--num_epochs NUM_EPOCHS]
                             [--warmup_epochs WARMUP_EPOCHS]
                             [--eval_batch_size EVAL_BATCH_SIZE]
                             [--llm_model_name LLM_MODEL_NAME]
                             [--llm_model_path LLM_MODEL_PATH]
                             [--llm_frozen LLM_FROZEN]
                             [--llm_num_virtual_tokens LLM_NUM_VIRTUAL_TOKENS]
                             [--output_dir OUTPUT_DIR]
                             [--max_txt_len MAX_TXT_LEN]
                             [--max_new_tokens MAX_NEW_TOKENS]
                             [--max_memory MAX_MEMORY]
     

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# set up wandb, seed for tracking
seed = 42
wandb.init(project=f"{project}",
            name=f"{dataset}_{model_name}_seed{seed}",
            config=args)
seed_everything(seed)

In [None]:
# options
num_training_steps = num_epochs * len(train_loader)
progress_bar = tqdm(range(num_training_steps))
best_val_loss = float('inf')

# set optimizer
params = [p for _, p in model.named_parameters() if p.requires_grad] # only update non-frozen params (graph encoder)
optimizer = torch.optim.AdamW(
    [{'params': params, 'lr': lr, 'weight_decay': wd}, ],
    betas=(0.9, 0.95)
)

## TRAIN LOOP
for epoch in range(num_epochs):

    model.train()
    epoch_loss, accum_loss = 0., 0.

    for step, batch in enumerate(train_loader):

        optimizer.zero_grad()
        loss = model(batch)
        
        # clip gradients so large changes don't occur - super small clipping too
        clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
        
        # grad steps is a hyprparameter
        if (step + 1) % grad_steps == 0:
            adjust_learning_rate(optimizer.param_groups[0], lr, step / len(train_loader) + epoch)
        
        optimizer.step()
        epoch_loss, accum_loss += loss.item(), loss.item()

        if  (step + 1) % grad_steps == 0:
            lr = optimizer.param_groups[0]['lr']
            wandb.log({'Lr': lr})
            wandb.log({'Train Loss': accum_loss / grad_steps})
            accum_loss = 0.
        
        progress_bar.update(1)
    
    print(f"Epoch {epoch}/{num_epochs} | Train Loss (Epoch Mean): {epoch_loss / len(train_loader)}")
    wandb.log({'Train Loss (Epoch Mean)': epoch_loss / len(train_loader)})

    # validation
    val_loss = 0.
    eval_output = []
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(val_loader):
            loss = model(batch)
            val_loss += loss.item()
        val_loss /= len(val_loader)
        print(f"Epoch {epoch}/{num_epochs} | Validation Loss: {val_loss}")
        wandb.log({'Validation Loss': val_loss})
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint(model, optimizer, epoch, args, is_best=True)
        best_epoch = epoch
    
    print(f"Epoch {epoch}/{num_epochs} | Best Validation Loss: {best_val_loss} at epoch {best_epoch}")

    if epoch - best_epoch >= patience:
        print(f"Early stopping at epoch {epoch}")
        break

torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

# Step 7: Evaluate After Training

In [None]:
# eval
model = _reload_best_model(model, args)
model.eval()

progress_bar_test = tqdm(range(len(test_loader)))
with open(path, "w") as f:
    for step, batch in enumerate(test_loader):
        with torch.no_grad():
            output = model.inference(batch)
            df = pd.DataFrame(output)
            for _, row in df.iterrows():
                f.write(json.dumps(dict(row)) + "\n")
        progress_bar_test.update(1)

# post process + compute metrics
acc = eval_funcs[dataset](path)
print(f'Test Acc: {acc}')
wandb.log({'Test Acc': acc})