In [1]:
import time
import os, sys
import wandb
import tqdm.notebook as tqdm
import transformers
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset, Subset
from importlib import reload
from torch_scatter import scatter
from transformers import pipeline

sys.path.append('../')

from utils import preprocess as pp
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset

# training imports
# from utils.model import load_model, llama_model_path
from utils.evaluate import eval_funcs
from utils.config import parse_args_llama
from utils.ckpt import _save_checkpoint, _reload_best_model
from utils.collate import collate_fn
from utils.seed import seed_everything
from utils.lr_schedule import adjust_learning_rate

# Dataset & Dataloaders

In [2]:
# get dataset
data_path = '../data/subgraphs/all'
dataset = BiologicalDataset(data_path)
idx_split = dataset.get_idx_split()

In [4]:
# split datasets on idx
train_dataset = Subset(dataset, idx_split['train'])
val_dataset = Subset(dataset, idx_split['val'])
test_dataset = Subset(dataset, idx_split['test'])

In [5]:
# options
batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, 
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn,
                          num_workers=16)

val_loader = DataLoader(val_dataset, 
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn,
                          num_workers=16)

test_loader = DataLoader(test_dataset, 
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn,
                          num_workers=16)

# Load LLMs

In [6]:
model = GraphLLM(max_text_len=512,
                     max_max_new_tokens=32,
                     max_memory=[80, 80],
                     llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
                     llm_frozen='True',
                     revision="main") # args are defaulted in the class

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA!
Finished loading LLaMA!


# Train Graph LLM

In [7]:
# specify needed args
sys.argv = [''] # needed for argparse in notebooks
args = parse_args_llama()

In [8]:
# set up seed for tracking
seed = 42
seed_everything(seed)

In [9]:
# options
num_training_steps = args.num_epochs * len(train_loader)
progress_bar = tqdm.tqdm(range(num_training_steps))
best_val_loss = float('inf')
save_path = '../checkpoints/graph_llm_no_text/'

# set optimizer
params = [p for _, p in model.named_parameters() if p.requires_grad] # only update non-frozen params (graph encoder)
optimizer = torch.optim.AdamW(
    [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}, ],
    betas=(0.9, 0.95)
)

## TRAIN LOOP
for epoch in range(args.num_epochs):

    model.train()
    epoch_loss, accum_loss = 0., 0.

    for step, batch in enumerate(train_loader):

        optimizer.zero_grad()
        loss = model(batch)
        
        # clip gradients so large changes don't occur - super small clipping too
        clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
        
        # grad steps is a hyprparameter
        if (step + 1) % args.grad_steps == 0:
            adjust_learning_rate(optimizer.param_groups[0], args.lr, step / len(train_loader) + epoch, args)
        
        optimizer.step()
        epoch_loss += loss.item()
        accum_loss += loss.item()

        if  (step + 1) % args.grad_steps == 0:
            lr = optimizer.param_groups[0]['lr']
            # wandb.log({'Lr': lr})
            # wandb.log({'Train Loss': accum_loss / args.grad_steps})
            accum_loss = 0.
        
        progress_bar.update(1)
    # wandb.log({'Train Loss (Epoch Mean)': epoch_loss / len(train_loader)})

    # validation
    val_loss = 0.
    eval_output = []
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(val_loader):
            loss = model(batch)
            val_loss += loss.item()
        val_loss /= len(val_loader)
        # wandb.log({'Validation Loss': val_loss})
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint(model, optimizer, epoch, args, save_path, is_best=True)
        best_epoch = epoch
    
    print(f"Epoch {epoch}/{args.num_epochs} | Train Loss (Epoch Mean): {epoch_loss / len(train_loader)} | Validation Loss: {val_loss} | Best Validation Loss: {best_val_loss} at epoch {best_epoch}", end="\r")

    if epoch - best_epoch >= args.patience:
        print(f"Early stopping at epoch {epoch}")
        break

torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

  0%|          | 0/135460 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Saving checkpoint at epoch 0 to ../checkpoints/graph_llm_no_text/.
Epoch 0/10 | Train Loss (Epoch Mean): 1.051144796544756 | Validation Loss: 1.0523156256638757 | Best Validation Loss: 1.0523156256638757 at epoch 0

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

KeyboardInterrupt: 

# Evaluate After Training

In [10]:
batch = next(iter(test_loader))
model.inference(batch)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

{'id': [40500, 88015, 90059, 3360],
 'pred': ['Yes, there is an edge between nodes 330 and 77.\n\nExplanation:\n\nThe graph is given in the format of a series of nodes and edges. Each',
  'yes\nFinal Answer: The final answer is yes. I hope it is correct.',
  ' node 195 node 207 edge\nTo solve this problem, you need to analyze the biological graph and find the edge that connects node 195 and node ',
  ' \nHere is the biological graph:\n```\nNode 345  |  Edge  |  Node 864\n          |  ->    |\n          | '],
 'label': ['yes', 'yes', 'no', 'no'],
 'question': ['Is there an edge between nodes 330 and 77?',
  'Is there an edge between nodes 459 and 410?',
  'Is there an edge between nodes 195 and 207?',
  'Is there an edge between nodes 435 and 864?'],
 'desc': ['You will be given a biological graph and a question. You need to determine the answer to the question based on the graph.',
  'You will be given a biological graph and a question. You need to determine the answer to the question 