In [8]:
import os, sys
import wandb
import tqdm.notebook as tqdm
import transformers
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from importlib import reload
from torch_scatter import scatter
from transformers import pipeline

sys.path.append('../')

from utils import preprocess as pp
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.bio_graphs import BiologicalDataset

# training imports
# from utils.model import load_model, llama_model_path
from utils.evaluate import eval_funcs
from utils.config import parse_args_llama
from utils.ckpt import _save_checkpoint, _reload_best_model
from utils.collate import collate_fn
from utils.seed import seed_everything
from utils.lr_schedule import adjust_learning_rate

# Dataset & Dataloaders

In [9]:
# get dataset
data_path = '../data/DREAM4_gold_standards/connections_node_label'
dataset = BiologicalDataset(data_path)
idx_split = dataset.get_idx_split()

In [10]:
# split datasets on idx
train_dataset = [dataset[i] for i in idx_split["train"]]
val_dataset = [dataset[i] for i in idx_split["val"]]
test_dataset = [dataset[i] for i in idx_split["test"]]

# options
batch_size = 1

# make dataloaders
train_loader = DataLoader(train_dataset, 
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn)

val_loader = DataLoader(val_dataset, 
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn)

test_loader = DataLoader(test_dataset, 
                          batch_size=batch_size,
                          drop_last=True,
                          pin_memory=True,
                          shuffle=True,
                          collate_fn=collate_fn)

  graph = torch.load(f'{self.path}/graphs/{index}.pt')


# Load LLMs

In [4]:
# vanilla_llm = LLM(max_text_len=512,
#                   max_max_new_tokens=32,
#                   max_memory=[80, 80],
#                   llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
#                 #   llm_model_path='meta-llama/Llama-2-7B-chat-hf',
#                   llm_frozen='True',
#                   revision="main") # need to add args

In [11]:
model = GraphLLM(max_text_len=512,
                     max_max_new_tokens=32,
                     max_memory=[80, 80],
                     llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
                     llm_frozen='True',
                     revision="main") # args are defaulted in the class

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 1002.00 MiB. GPU 0 has a total capacity of 39.50 GiB of which 31.38 MiB is free. Including non-PyTorch memory, this process has 39.46 GiB memory in use. Of the allocated memory 39.05 GiB is allocated by PyTorch, and 1.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [5]:
batch = next(iter(train_loader))

# Train Graph LLM

In [9]:
# specify needed args
sys.argv = [''] # needed for argparse in notebooks
args = parse_args_llama()

In [10]:
# set up wandb, seed for tracking
seed = 42
# wandb.init(project=f"{args.project}",
#             name=f"{dataset}_{args.model_name}_seed{seed}",
#             config=args)
seed_everything(seed)

In [11]:
# options
num_training_steps = args.num_epochs * len(train_loader)
progress_bar = tqdm.tqdm(range(num_training_steps))
best_val_loss = float('inf')
save_path = '../checkpoints/test_run/'

# set optimizer
params = [p for _, p in model.named_parameters() if p.requires_grad] # only update non-frozen params (graph encoder)
optimizer = torch.optim.AdamW(
    [{'params': params, 'lr': args.lr, 'weight_decay': args.wd}, ],
    betas=(0.9, 0.95)
)

## TRAIN LOOP
for epoch in range(args.num_epochs):

    model.train()
    epoch_loss, accum_loss = 0., 0.

    for step, batch in enumerate(train_loader):

        optimizer.zero_grad()
        loss = model(batch)
        
        # clip gradients so large changes don't occur - super small clipping too
        clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
        
        # grad steps is a hyprparameter
        if (step + 1) % args.grad_steps == 0:
            adjust_learning_rate(optimizer.param_groups[0], args.lr, step / len(train_loader) + epoch, args)
        
        optimizer.step()
        epoch_loss += loss.item()
        accum_loss += loss.item()

        if  (step + 1) % args.grad_steps == 0:
            lr = optimizer.param_groups[0]['lr']
            # wandb.log({'Lr': lr})
            # wandb.log({'Train Loss': accum_loss / args.grad_steps})
            accum_loss = 0.
        
        progress_bar.update(1)
    
    print(f"Epoch {epoch}/{args.num_epochs} | Train Loss (Epoch Mean): {epoch_loss / len(train_loader)}")
    # wandb.log({'Train Loss (Epoch Mean)': epoch_loss / len(train_loader)})

    # validation
    val_loss = 0.
    eval_output = []
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(val_loader):
            loss = model(batch)
            val_loss += loss.item()
        val_loss /= len(val_loader)
        print(f"Epoch {epoch}/{args.num_epochs} | Validation Loss: {val_loss}")
        # wandb.log({'Validation Loss': val_loss})
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint(model, optimizer, epoch, args, save_path, is_best=True)
        best_epoch = epoch
    
    print(f"Epoch {epoch}/{args.num_epochs} | Best Validation Loss: {best_val_loss} at epoch {best_epoch}")

    if epoch - best_epoch >= args.patience:
        print(f"Early stopping at epoch {epoch}")
        break

torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

  0%|          | 0/310 [00:00<?, ?it/s]

Epoch 0/10 | Train Loss (Epoch Mean): 1.380584547596593
Epoch 0/10 | Validation Loss: 1.4246362447738647
Saving checkpoint at epoch 0 to ../checkpoints/test_run/.
Epoch 0/10 | Best Validation Loss: 1.4246362447738647 at epoch 0
Epoch 1/10 | Train Loss (Epoch Mean): 1.3799904700248473
Epoch 1/10 | Validation Loss: 1.4084311723709106
Saving checkpoint at epoch 1 to ../checkpoints/test_run/.
Epoch 1/10 | Best Validation Loss: 1.4084311723709106 at epoch 1
Epoch 2/10 | Train Loss (Epoch Mean): 1.382697170780551
Epoch 2/10 | Validation Loss: 1.3987900376319886
Saving checkpoint at epoch 2 to ../checkpoints/test_run/.
Epoch 2/10 | Best Validation Loss: 1.3987900376319886 at epoch 2
Epoch 3/10 | Train Loss (Epoch Mean): 1.382141690100393
Epoch 3/10 | Validation Loss: 1.3936089634895326
Saving checkpoint at epoch 3 to ../checkpoints/test_run/.
Epoch 3/10 | Best Validation Loss: 1.3936089634895326 at epoch 3
Epoch 4/10 | Train Loss (Epoch Mean): 1.3817901803601174
Epoch 4/10 | Validation Loss: 



# Evaluate After Training

In [18]:
batch = next(iter(test_loader))
model.inference(batch)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


{'id': [37],
 'pred': ['Yes, there is an edge between nodes G1 and G2. The edge is directed from G1 to G2. The edge is part of the graph'],
 'label': ['yes'],
 'question': ['Is there an edge between nodes G1 and G2?'],
 'desc': ["['G1 is associated with G2', 'G1 is associated with G3', 'G1 is associated with G4', 'G1 is associated with G5', 'G2 is associated with G6', 'G2 is associated with G8', 'G3 is associated with G4', 'G3 is associated with G7', 'G3 is associated with G10', 'G4 is associated with G7', 'G4 is associated with G10', 'G6 is associated with G8', 'G9 is associated with G10', 'node_id 0 is G1', 'node_id 1 is G10', 'node_id 2 is G2', 'node_id 3 is G3', 'node_id 4 is G4', 'node_id 5 is G5', 'node_id 6 is G6', 'node_id 7 is G7', 'node_id 8 is G8', 'node_id 9 is G9', 'layer 0 is from coexpression-heart']"]}