# Setup

In [1]:
# import packages
import glob
import os
import sys
import time
from importlib import reload

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import clear_output
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm

from DGXutils import GetFileNames, GetLowestGPU

# make the parent of the current working directory importable so the local
# `utils` package below can be found (the path is relative to the kernel's CWD)
sys.path.append('../')

# NOTE(review): star imports hide where names come from. This notebook uses
# GxE_Dataset, GxE_Transformer and TransformerConfig without defining them, so
# they presumably come from utils.dataset / utils.model -- confirm and consider
# importing them by name. Order matters: anything utils.model exports shadows
# a same-named export of utils.dataset.
from utils.dataset import *
from utils.model import *
from utils.GetLR import get_lr  # per-iteration learning-rate schedule used by the training loop

# device used for the model and every batch; GetLowestGPU is a cluster helper
# (presumably selects the least-loaded GPU -- confirm)
device = GetLowestGPU()

In [2]:
# load datasets: hold out part of the 'train' split for validation.
# The split is seeded so it is identical after every kernel restart. Without a
# seed, a run resumed from a saved checkpoint would be validated on rows it has
# already trained on, and results could not be reproduced.
split_seed = 0     # change to draw a different (but still reproducible) split
train_frac = 0.8   # fraction of rows used for training; the rest is validation

gxe = GxE_Dataset(split='train')
n_train = int(len(gxe) * train_frac)
gxe_train, gxe_val = random_split(
    gxe,
    [n_train, len(gxe) - n_train],
    generator=torch.Generator().manual_seed(split_seed),
)

In [3]:
# build the transformer from its config, then move its parameters to `device`.
# NOTE(review): TransformerConfig is passed without being called; confirm that
# GxE_Transformer expects it in this form.
# NOTE(review): the saved output of the training cell shows a CUDA index_select
# assertion (`srcIndex < srcSelectDimSize`) on the first eval batch. That usually
# means an index fed to an embedding / index_select is >= the table size, so
# check the config's embedding/vocabulary sizes against the dataset's largest
# index (rerun with CUDA_LAUNCH_BLOCKING=1 to locate it).
model = GxE_Transformer(config=TransformerConfig)
model = model.to(device)

# Train

In [4]:
# loss: mean squared error between model predictions and targets
loss_function = nn.MSELoss()

# optimizer: Adam over all model parameters. The lr given here is only a
# starting value -- the training loop overwrites every param group's lr on each
# step with the value returned by get_lr (there is no separate scheduler object).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [5]:
# paths (the checkpoint path is formatted with the epoch number)
log_path = '../logs/gxe_model/log.txt'
chckpnt_path = '../checkpoints/gxe_model/checkpoint_{0}.pt'

# open() and torch.save() do not create missing directories, so create them
# up front; otherwise a fresh checkout fails when the log is first written
os.makedirs(os.path.dirname(log_path), exist_ok=True)
os.makedirs(os.path.dirname(chckpnt_path), exist_ok=True)

# dataloaders
batch_size = 64
num_workers = 16

# training batches are reshuffled every epoch
train_loader = DataLoader(
    gxe_train,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    pin_memory=True)

# evaluation does not need shuffling; a fixed order gives the same validation
# batches every epoch, so the val loss is deterministic for given weights
val_loader = DataLoader(
    gxe_val,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    pin_memory=True)

# epochs
num_epochs = 100
batches_per_epoch = len(train_loader)
num_iters = num_epochs * batches_per_epoch

# other options
batches_per_eval = len(val_loader)   # eval pass covers the whole validation set
warmup_iters = batches_per_epoch     # passed to get_lr: one epoch of warmup
lr_decay_iters = num_iters           # passed to get_lr: decay over the full run
max_lr = 1e-3
min_lr = 1e-7
max_iters = num_iters                # training stops once this many steps are taken
log_interval = 1
eval_interval = batches_per_epoch    # train batches between evaluations
early_stop = 10                      # epochs without val improvement before stopping

In [6]:
# Loop state. The training loop below reads and updates these names;
# re-running this cell resets them and truncates the log for a fresh run.
iter_update = 'train loss {1:.4e}, val loss {2:.4e}\r'  # status line; placeholder 0 (epoch) is not shown
epoch_num, iter_num = 0, 0   # epoch counter / optimizer steps taken
last_improved = 0            # epochs since the validation loss last improved
best_val_loss = None         # filled in by the first evaluation
t0 = time.time()             # wall-clock start of training

# start a new log that contains only the CSV header
with open(log_path, 'w') as log_file:
    log_file.write('epoch,train_loss,val_loss\n')

# keep training until break
while True:

    # clear print output
    clear_output(wait=True)

    if best_val_loss is not None:
        print('---------------------------------------\n',
            f'Epoch: {epoch_num} | Best Loss: {best_val_loss:.4e}\n', 
            '---------------------------------------', sep = '')
    else:
        print('-------------\n',
            f'Epoch: {epoch_num}\n', 
            '-------------', sep = '')

    # ----------
    # checkpoint
    # ----------

    # estimate loss
    model.eval()
    with torch.no_grad():
        train_loss, val_loss = 0, 0
        with tqdm(total=batches_per_eval, desc=' Eval') as pbar:
            for (xbt, ybt), (xbv, ybv) in zip(train_loader, val_loader):

                # send to device
                for key, value in xbt.items():
                    xbt[key] = value.to(device)
                ybt = ybt.to(device)
                for key, value in xbv.items():
                    xbv[key] = value.to(device)
                ybv = ybv.to(device)

                train_loss += loss_function(model(xbt), ybt).item()
                val_loss += loss_function(model(xbv), ybv).item()
                pbar.update(1)
                if pbar.n == pbar.total:
                    break
        train_loss /= batches_per_eval
        val_loss /= batches_per_eval
    model.train()

    # update user
    print(iter_update.format(epoch_num, train_loss, val_loss)) 

    # update log
    with open(log_path, 'a') as f: 
        f.write(f'{epoch_num},{train_loss},{val_loss}\n')

    # checkpoint model
    if iter_num > 0:
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch_num,
            'best_val_loss': best_val_loss,
        }
        torch.save(checkpoint, chckpnt_path.format(epoch_num))

    # book keeping
    if best_val_loss is None:
        best_val_loss = val_loss

    if epoch_num > 0:
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            last_improved = 0
            print(f'*** validation loss improved: {best_val_loss:.4e} ***')
        else:
            last_improved += 1
            print(f'validation has not improved in {last_improved} epochs')
        if last_improved > early_stop:
            print()
            print(f'*** no improvement for {early_stop} epochs, stopping ***')
            break

    # --------
    # backprop
    # --------

    # iterate over batches
    with tqdm(total=eval_interval, desc='Train') as pbar:
        for xb, yb in train_loader:

            # update the model
            for key, value in xb.items():
                xb[key] = value.to(device)
            yb = yb.to(device)

            loss = loss_function(model(xb), yb)

            if torch.isnan(loss):
                print('loss is NaN, stopping')
                break
            
            # apply learning rate schedule
            lr = get_lr(it = iter_num,
                        warmup_iters = warmup_iters, 
                        lr_decay_iters = lr_decay_iters, 
                        max_lr = max_lr, 
                        min_lr = min_lr)
            
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            
            loss.backward()

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            # update book keeping
            pbar.update(1)
            iter_num += 1
            if iter_num % batches_per_epoch == 0:
                epoch_num += 1
            if pbar.n == pbar.total:
                break

    # break once hitting max_iters
    if iter_num > max_iters:
        print(f'maximum epochs reached: {num_epochs}')
        break

-------------
Epoch: 0
-------------


 Eval:   0%|          | 0/451 [00:00<?, ?it/s]

/opt/conda/conda-bld/pytorch_1729647378361/work/aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [158,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1729647378361/work/aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [158,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1729647378361/work/aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [158,0,0], thread: [34,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1729647378361/work/aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [158,0,0], thread: [35,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1729647378361/work/aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [158,0,0], thread: [36,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
