In [2]:
import os, sys
import wandb
import tqdm.notebook as tqdm
import transformers
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset, Subset
from importlib import reload
from torch_scatter import scatter
from transformers import pipeline

sys.path.append('../')

from utils import preprocess as pp
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset

# training imports
# from utils.model import load_model, llama_model_path
from utils.evaluate import eval_funcs
from utils.config import parse_args_llama
from utils.ckpt import _save_checkpoint, _reload_best_model
from utils.collate import collate_fn
from utils.seed import seed_everything
from utils.lr_schedule import adjust_learning_rate

# Dataset & Dataloaders

In [3]:
# get dataset
# Construct the dataset from `data_path` and ask it for its index split.
# The split is indexed with the keys 'train', 'val' and 'test' further down.
data_path = '../data/subgraphs/all'
dataset = BiologicalDataset(data_path)
idx_split = dataset.get_idx_split()

In [4]:
# Display the first sample to inspect the fields a single dataset item carries.
dataset[0]

  graph = torch.load(text['graph'])


{'id': 0,
 'question': 'Is there an edge between nodes 513 and 623?',
 'scope': 'all',
 'label': "['yes']",
 'desc': ' ',
 'graph': Data(x=[1000, 1024], edge_index=[2, 13124], num_nodes=1000)}

In [5]:
# Carve the train / val / test subsets out of the full dataset, using the
# index collections stored under the matching keys of `idx_split`.
train_dataset, val_dataset, test_dataset = (
    Subset(dataset, idx_split[split_name])
    for split_name in ('train', 'val', 'test')
)

In [6]:
# options
batch_size = 8

# Settings shared by all three loaders, kept in one place so they cannot drift.
loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    collate_fn=collate_fn,
    num_workers=16,
)

# make dataloaders
# Training: reshuffle every epoch and drop the ragged final batch.
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          drop_last=True,
                          **loader_kwargs)

# Evaluation: keep a fixed order and keep every sample. With drop_last=True up
# to batch_size - 1 samples would be silently excluded from the metrics, and a
# split smaller than batch_size would yield an empty loader (the validation
# average divides by len(val_loader)).
val_loader = DataLoader(val_dataset,
                        shuffle=False,
                        drop_last=False,
                        **loader_kwargs)

test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         drop_last=False,
                         **loader_kwargs)

In [7]:
# Pull a single collated batch from the training loader for a quick look.
# The name `batch` is rebound by the loops in the training cell further down.
batch = next(iter(train_loader))

# Load LLMs

In [8]:
# Build the GraphLLM wrapper. Only the options listed here are overridden;
# the remaining arguments fall back to the defaults declared in the class.
# NOTE(review): `max_max_new_tokens` looks like it may be a typo for
# `max_new_tokens` -- confirm against the GraphLLM signature.
# NOTE(review): `llm_frozen` is the string 'True', not a bool; the recorded run
# printed "Freezing LLaMA!" with this value -- confirm how the class checks it.
model = GraphLLM(
    max_text_len=512,
    max_max_new_tokens=32,
    max_memory=[80, 80],
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    llm_frozen='True',
    revision="main",
)

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA!
Finished loading LLaMA!


# Train Graph LLM

In [9]:
# specify needed args
# Replace argv with a single empty entry before parsing, so the parser does not
# see the notebook kernel's own command-line arguments.
# (presumably every option then takes its parser default -- verify in utils.config)
sys.argv = [''] # needed for argparse in notebooks
args = parse_args_llama()

In [10]:
# set up wandb, seed for tracking
# NOTE(review): this cell runs after the dataloaders and the model have already
# been created in earlier cells, so any randomness consumed there is not covered
# by this seed -- consider seeding before those cells if reproducibility matters.
seed = 42
# wandb tracking is currently disabled; the call below is kept for reference.
# wandb.init(project=f"{args.project}",
#             name=f"{dataset}_{args.model_name}_seed{seed}",
#             config=args)
seed_everything(seed)

In [11]:
# options
num_training_steps = args.num_epochs * len(train_loader)
progress_bar = tqdm.tqdm(range(num_training_steps))
best_val_loss = float('inf')
# Bound up front: if the first validation loss is NaN/inf the "improved" branch
# never runs, and the log line and patience check below would otherwise hit an
# undefined name.
best_epoch = 0
save_path = '../checkpoints/test_run/'

# set optimizer
# only update non-frozen params (graph encoder)
params = [p for _, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}, ],
    betas=(0.9, 0.95)
)

## TRAIN LOOP
for epoch in range(args.num_epochs):

    model.train()
    epoch_loss, accum_loss = 0., 0.

    for step, batch in enumerate(train_loader):

        optimizer.zero_grad()
        loss = model(batch)

        # Backpropagate before clipping/stepping; without this no gradients
        # exist and optimizer.step() leaves every parameter unchanged.
        loss.backward()

        # clip gradients so large changes don't occur - super small clipping too
        clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)

        # grad steps is a hyperparameter: the learning rate is only re-scheduled
        # every `grad_steps` iterations, using a fractional epoch position.
        if (step + 1) % args.grad_steps == 0:
            adjust_learning_rate(optimizer.param_groups[0], args.lr, step / len(train_loader) + epoch, args)

        optimizer.step()

        # Read the scalar once and reuse it for both running totals.
        step_loss = loss.item()
        epoch_loss += step_loss
        accum_loss += step_loss

        if (step + 1) % args.grad_steps == 0:
            lr = optimizer.param_groups[0]['lr']
            # wandb.log({'Lr': lr})
            # wandb.log({'Train Loss': accum_loss / args.grad_steps})
            accum_loss = 0.

        progress_bar.update(1)
    # wandb.log({'Train Loss (Epoch Mean)': epoch_loss / len(train_loader)})

    # validation: mean loss over the validation loader, no gradient tracking
    val_loss = 0.
    eval_output = []  # not used by this loop; kept because it is a notebook-level name
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            loss = model(batch)
            val_loss += loss.item()
        val_loss /= len(val_loader)
        # wandb.log({'Validation Loss': val_loss})

    # Checkpoint whenever validation improves and remember when that happened.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint(model, optimizer, epoch, args, save_path, is_best=True)
        best_epoch = epoch

    print(f"Epoch {epoch}/{args.num_epochs} | Train Loss (Epoch Mean): {epoch_loss / len(train_loader)} | Validation Loss: {val_loss} | Best Validation Loss: {best_val_loss} at epoch {best_epoch}", end="\r")

    # Early stopping: give up after `patience` epochs without improvement.
    if epoch - best_epoch >= args.patience:
        print(f"Early stopping at epoch {epoch}")
        break

# Close the bar explicitly; early stopping leaves it short of its total.
progress_bar.close()

# Release cached GPU memory and reset the peak-memory counters.
# reset_peak_memory_stats() replaces the deprecated reset_max_memory_allocated().
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

  0%|          | 0/67750 [00:00<?, ?it/s]

  graph = torch.load(text['graph'])


KeyboardInterrupt: 

# Evaluate After Training

In [None]:
# Run inference on a single test batch.
# The model is put in eval mode first because an interrupted training cell
# leaves it in train mode, and autograd is disabled since no gradients are
# needed for generation.
model.eval()
batch = next(iter(test_loader))
with torch.no_grad():
    test_output = model.inference(batch)

# Last expression of the cell, so the notebook displays the result.
test_output

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


{'id': [12],
 'pred': ['Yes</p> | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |'],
 'label': ['yes'],
 'question': ['Is there an edge between nodes G1 and G3?'],
 'desc': ["['G1 is associated with G2', 'G1 is associated with G3', 'G1 is associated with G4', 'G1 is associated with G5', 'G2 is associated with G6', 'G2 is associated with G8', 'G3 is associated with G4', 'G3 is associated with G7', 'G3 is associated with G10', 'G4 is associated with G7', 'G4 is associated with G10', 'G6 is associated with G8', 'G9 is associated with G10', 'node_id 0 is G1', 'node_id 1 is G10', 'node_id 2 is G2', 'node_id 3 is G3', 'node_id 4 is G4', 'node_id 5 is G5', 'node_id 6 is G6', 'node_id 7 is G7', 'node_id 8 is G8', 'node_id 9 is G9', 'layer 0 is from coexpression-heart']"]}