In [1]:
import json
import os
import sys
from itertools import islice

import pandas as pd
import torch
import torch.nn as nn
import transformers
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, Dataset
from transformers import pipeline

sys.path.append('../')

from utils import preprocess as pp
# from utils.llm import llm
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.GetFileNames import GetFileNames
from utils.GetLowestGPU import GetLowestGPU
from utils.bio_graphs import BiologicalDataset

# Step 1: Load Networks as `Multiplex` object

In [2]:
# Root of the data directory, relative to the notebook's working directory.
DATA_DIR = '../data'

# TSV file list consumed by Multiplex (presumably one network file per row — TODO confirm).
flist_name = f'{DATA_DIR}/DREAM4_gold_standards/flist.tsv'
mp = Multiplex(flist_name)

# Step 2: Textualize Graphs (Ken's Code)

In [3]:
# Turn the multiplex into text using the 'edges' textualizer; judging by the output
# below, each item is one sentence per edge (e.g. "G1 is associated with G2").
textualize = load_textualizer['edges']
graph_text = textualize(mp)

# Preview the first few items. islice stops early on graphs with fewer edges,
# whereas indexing graph_text[i] over range(10) would raise IndexError.
N_PREVIEW = 10
for line in islice(graph_text, N_PREVIEW):
    print(line)

G1 is associated with G2
G1 is associated with G3
G1 is associated with G4
G1 is associated with G5
G1 is associated with G6
G1 is associated with G7
G1 is associated with G8
G1 is associated with G9
G1 is associated with G10
G2 is associated with G6


# Step 3: Make Dataloader
* Each batch yielded by the dataloader is a dict with keys `"ids"`, `"desc"`, `"question"`, and `"label"`.

In [4]:
# Instantiate the biological-graph dataset with its default constructor arguments.
bio_data = BiologicalDataset()

In [5]:
# Seed used for the shuffle order so that the example pulled in Step 5 is the same
# on every Restart & Run All.
DATALOADER_SEED = 0

# One example per batch, shuffled with a dedicated seeded generator (unseeded
# shuffling made the inspected example change from run to run).
bio_dataloader = DataLoader(
    bio_data,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(DATALOADER_SEED),
)

# Step 4: Load the Baseline LLM and the Graph Encoder + LLM

In [7]:
def _shared_llm_kwargs():
    """Return a fresh dict of constructor arguments shared by LLM and GraphLLM.

    A new dict (and a new ``max_memory`` list) is built on every call, so the two
    models never share a mutable argument object. Any constructor argument not
    listed here is left at the class default.
    """
    return dict(
        max_text_len=512,
        max_max_new_tokens=32,
        max_memory=[80, 80],  # presumably a per-GPU memory budget — TODO confirm units
        llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
        llm_frozen='True',  # passed as the string 'True', not a bool; kept as the classes receive it
        revision="main",
    )


# Plain LLM and graph-augmented GraphLLM, built from identical arguments.
vanilla_llm = LLM(**_shared_llm_kwargs())
graph_llm = GraphLLM(**_shared_llm_kwargs())

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA...
Finished loading LLaMA...
Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA!
Finished loading LLaMA!


# Step 5: Perform Initial Untrained Inference

In [9]:
# Pull one (shuffled) example and run the untrained baseline LLM on it.
# no_grad matches the Step 7 inference cell and avoids recording an autograd
# graph for a forward pass that is never back-propagated.
batch = next(iter(bio_dataloader))
with torch.no_grad():
    out = vanilla_llm.inference(batch)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [10]:
# Show the generated prediction(s) next to the ground-truth label(s).
print(*(out[key] for key in ('pred', 'label')))

['s>\nYes, there is a path between nodes G9 and G5. The path is: G1 -> G2 -> G5 or G1 ->'] ['yes']


# Step 6: Train Model

In [None]:
# Train `model`: per-step LR scheduling and gradient clipping, per-epoch validation,
# best-checkpoint saving, and early stopping on validation loss.
#
# NOTE(review): this cell uses names that no visible cell defines:
#   model, optimizer, train_loader, val_loader, num_epochs, grad_steps, lr, patience,
#   args, tqdm, wandb, adjust_learning_rate, _save_checkpoint.
# They must be created in an earlier cell, or Restart & Run All will fail here.

# options
num_training_steps = num_epochs * len(train_loader)
progress_bar = tqdm(range(num_training_steps))
best_val_loss = float('inf')
# Defined up front so the per-epoch report and early-stopping check never see an
# unbound name (e.g. if epoch 0's validation loss is NaN and never "improves").
best_epoch = 0

## TRAIN LOOP
for epoch in range(num_epochs):

    model.train()
    epoch_loss, accum_loss = 0., 0.

    for step, batch in enumerate(train_loader):

        optimizer.zero_grad()
        loss = model(batch)
        # Populate .grad; without this, clipping and optimizer.step() have nothing to act on.
        loss.backward()

        # Clip gradients so large changes don't occur.
        clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)

        # `grad_steps` is a hyperparameter: every `grad_steps` batches the LR is
        # re-scheduled and the running loss is logged. The optimizer itself still
        # steps on every batch (there is no gradient accumulation here).
        if (step + 1) % grad_steps == 0:
            # `lr` is the base LR; the schedule position is a fractional epoch.
            adjust_learning_rate(optimizer.param_groups[0], lr, step / len(train_loader) + epoch)

        optimizer.step()

        # Augmented assignment to a tuple is a SyntaxError, so update each total separately.
        step_loss = loss.item()
        epoch_loss += step_loss
        accum_loss += step_loss

        if (step + 1) % grad_steps == 0:
            # Separate name so the base `lr` used by adjust_learning_rate is not overwritten.
            current_lr = optimizer.param_groups[0]['lr']
            wandb.log({'Lr': current_lr})
            wandb.log({'Train Loss': accum_loss / grad_steps})
            accum_loss = 0.

        progress_bar.update(1)

    print(f"Epoch {epoch}/{num_epochs} | Train Loss (Epoch Mean): {epoch_loss / len(train_loader)}")
    wandb.log({'Train Loss (Epoch Mean)': epoch_loss / len(train_loader)})

    # validation
    val_loss = 0.
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            loss = model(batch)
            val_loss += loss.item()
        val_loss /= len(val_loader)
        print(f"Epoch {epoch}/{num_epochs} | Validation Loss: {val_loss}")
        wandb.log({'Validation Loss': val_loss})

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint(model, optimizer, epoch, args, is_best=True)
        best_epoch = epoch

    print(f"Epoch {epoch}/{num_epochs} | Best Validation Loss: {best_val_loss} at epoch {best_epoch}")

    # Early stopping: no validation improvement for `patience` epochs.
    if epoch - best_epoch >= patience:
        print(f"Early stopping at epoch {epoch}")
        break

torch.cuda.empty_cache()
if torch.cuda.is_available():
    # reset_max_memory_allocated() is deprecated in favour of reset_peak_memory_stats().
    torch.cuda.reset_peak_memory_stats()

# Step 7: Evaluate After Training

In [None]:
# Evaluate the best checkpoint on the test split. Each prediction is written as one
# JSON object per line to `path`, and that file is then scored with the
# dataset-specific metric.
#
# NOTE(review): `path`, `test_loader`, `eval_funcs`, `dataset`, `args`, `tqdm`,
# `wandb` and `_reload_best_model` are not defined in any visible cell; they must
# be created earlier.

# eval
model = _reload_best_model(model, args)
model.eval()

progress_bar_test = tqdm(range(len(test_loader)))
with open(path, "w") as f, torch.no_grad():
    for batch in test_loader:
        output = model.inference(batch)
        # Expand the batch output into one record per example. to_dict(orient="records")
        # returns Python-native scalars in recent pandas; iterrows() rows can carry NumPy
        # scalars (e.g. int64), which json.dumps refuses to serialize.
        for record in pd.DataFrame(output).to_dict(orient="records"):
            f.write(json.dumps(record) + "\n")
        progress_bar_test.update(1)
progress_bar_test.close()

# post process + compute metrics
acc = eval_funcs[dataset](path)
print(f'Test Acc: {acc}')
wandb.log({'Test Acc': acc})