# Setup

In [1]:
# import packages
import glob
import os
import sys
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Subset

from DGXutils import GetFileNames, GetLowestGPU
from importlib import reload
from tqdm.auto import tqdm
from IPython.display import clear_output

sys.path.append('../')

from utils.dataset import *
from utils.model import *
from utils.GetLR import get_lr

# compute device used for the model and every batch below
# NOTE(review): presumably GetLowestGPU() returns the least-loaded GPU — confirm against DGXutils
device = GetLowestGPU()

In [2]:
# load the dataset (training split)
e = E_Dataset(split="train")

# train/validation split
# A fixed seed keeps the partition identical across kernel restarts, so
# validation losses from different runs can be compared and a resumed run
# does not leak former validation samples into the training subset.
train_prop = 0.8
split_seed = 42
p = np.random.default_rng(split_seed).permutation(len(e))
n_train = int(train_prop * len(e))  # number of samples in the training part
train_idx = p[:n_train]
val_idx = p[n_train:]

e_train = Subset(e, train_idx)
e_val = Subset(e, val_idx)

In [3]:
# batch loaders over the train and validation subsets
train_loader = DataLoader(
    dataset=e_train,
    batch_size=32,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=e_val,
    batch_size=32,
    shuffle=True,
)

In [4]:
# pull a single batch from the training loader to check that dataset indexing
# and batch collation work end to end
xb, yb = next(iter(train_loader))

ValueError: could not convert string to float: 'ARH2'

In [8]:
# build the model and move it to the selected device
# NOTE(review): `TransformerConfig` is passed without being called — confirm the
# constructor expects the class itself rather than a config instance
model = GxE_Transformer(config=TransformerConfig, g_enc=False).to(device)

# Training

In [9]:
# set up optimizers, loss function, and scheduler
# NOTE: the lr passed here is only an initial value; the training loop overwrites
# every param group's lr with the output of get_lr() before each optimizer step
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
def rmse_loss(y_pred, y_true, eps=1e-8):
    """Root-mean-square error between predictions and targets.

    Args:
        y_pred: tensor of model predictions.
        y_true: tensor of targets, passed straight to ``F.mse_loss`` together
            with ``y_pred``.
        eps: small constant added to the mean squared error before the square
            root. The derivative of ``sqrt`` is infinite at 0, so without it a
            batch that is predicted exactly yields NaN gradients (inf * 0).

    Returns:
        Scalar tensor ``sqrt(mse(y_pred, y_true) + eps)``.
    """
    return torch.sqrt(F.mse_loss(y_pred, y_true) + eps)
# loss used by the evaluation and training steps
loss_function = rmse_loss

In [10]:
# paths
log_path = '../logs/e_model/log.txt'
chckpnt_path = '../checkpoints/e_model/checkpoint_{0}.pt'

# create the output directories up front; otherwise opening the log file or
# saving a checkpoint fails on a fresh checkout
for out_path in (log_path, chckpnt_path):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

# dataloaders
# never request more worker processes than the machine has CPU cores
num_workers = min(64, os.cpu_count() or 1)

# training batches are reshuffled every epoch
train_loader = DataLoader(
    e_train,
    batch_size=64,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True)

# validation batches keep a fixed order: the reported loss is a mean of
# per-batch RMSE values, which depends on how samples are grouped, so shuffling
# would add noise to the metric that drives early stopping
val_loader = DataLoader(
    e_val,
    batch_size=64,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True)

# epochs
num_epochs = 100
batches_per_epoch = len(train_loader)
num_iters = num_epochs * batches_per_epoch

# other options
batches_per_eval = len(val_loader)    # batches consumed per loss estimate
warmup_iters = batches_per_epoch      # lr warmup lasts one epoch
lr_decay_iters = num_iters            # lr decays over the whole run
max_lr = 1e-3
min_lr = 1e-7
max_iters = num_iters                 # hard cap on optimizer steps
log_interval = 1
eval_interval = batches_per_epoch     # evaluate/checkpoint once per epoch
early_stop = 10                       # epochs without improvement before stopping

In [11]:
# non-customizable options
iter_update = 'train loss {1:.4e}, val loss {2:.4e}\r'
best_val_loss = None # initialize best validation loss
last_improved = 0 # start early stopping counter
iter_num = 0 # initialize iteration counter
epoch_num = 0 # initialize epoch counter
nan_loss = False # set when a NaN training loss forces the run to stop
t0 = time.time() # start timer

# training loop
# Each pass of the outer loop: (1) estimate train/val loss, (2) log it,
# (3) update the early-stopping state, (4) checkpoint, (5) train for
# `eval_interval` batches. The loop ends on early stopping, on reaching
# `max_iters` optimizer steps, or on a NaN training loss.

# refresh log
with open(log_path, 'w') as f:
    f.write('epoch,train_loss,val_loss\n')

# keep training until break
while True:

    # clear print output
    clear_output(wait=True)

    if best_val_loss is not None:
        print('---------------------------------------\n',
            f'Epoch: {epoch_num} | Best Loss: {best_val_loss:.4e}\n',
            '---------------------------------------', sep = '')
    else:
        print('-------------\n',
            f'Epoch: {epoch_num}\n',
            '-------------', sep = '')

    # ----------
    # checkpoint
    # ----------

    # estimate loss: average the per-batch loss over `batches_per_eval`
    # batches drawn in lockstep from the train and validation loaders
    model.eval()
    with torch.no_grad():
        train_loss, val_loss = 0, 0
        with tqdm(total=batches_per_eval, desc=' Eval') as pbar:
            for (xbt, ybt), (xbv, ybv) in zip(train_loader, val_loader):

                # send to device
                xbt = {key: value.to(device) for key, value in xbt.items()}
                ybt = ybt.to(device)
                xbv = {key: value.to(device) for key, value in xbv.items()}
                ybv = ybv.to(device)

                train_loss += loss_function(model(xbt), ybt).item()
                val_loss += loss_function(model(xbv), ybv).item()
                pbar.update(1)
                if pbar.n == pbar.total:
                    break
        train_loss /= batches_per_eval
        val_loss /= batches_per_eval
    model.train()

    # update user
    print(iter_update.format(epoch_num, train_loss, val_loss))

    # update log
    with open(log_path, 'a') as f:
        f.write(f'{epoch_num},{train_loss},{val_loss}\n')

    # book keeping (done before checkpointing so the saved `best_val_loss`
    # already reflects the epoch that was just evaluated)
    if best_val_loss is None:
        best_val_loss = val_loss

    if epoch_num > 0:
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            last_improved = 0
            print(f'*** validation loss improved: {best_val_loss:.4e} ***')
        else:
            last_improved += 1
            print(f'validation has not improved in {last_improved} epochs')

    # checkpoint model (skipped before the first optimizer step)
    if iter_num > 0:
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch_num,
            'best_val_loss': best_val_loss,
        }
        torch.save(checkpoint, chckpnt_path.format(epoch_num))

    # early stopping
    if epoch_num > 0 and last_improved > (early_stop - 1):
        print()
        print(f'*** no improvement for {early_stop} epochs, stopping ***')
        break

    # break once hitting max_iters; checked after evaluation/checkpointing and
    # before training so that exactly `num_epochs` epochs are trained and the
    # last one is still evaluated and saved
    if iter_num >= max_iters:
        print(f'maximum epochs reached: {num_epochs}')
        break

    # --------
    # backprop
    # --------

    # iterate over batches
    with tqdm(total=eval_interval, desc='Train') as pbar:
        for xb, yb in train_loader:

            # update the model
            xb = {key: value.to(device) for key, value in xb.items()}
            yb = yb.to(device)

            loss = loss_function(model(xb), yb)

            if torch.isnan(loss):
                print('loss is NaN, stopping')
                # a bare `break` only leaves the batch loop; the flag lets the
                # outer loop terminate too instead of re-evaluating forever
                nan_loss = True
                break

            # apply learning rate schedule
            lr = get_lr(it = iter_num,
                        warmup_iters = warmup_iters,
                        lr_decay_iters = lr_decay_iters,
                        max_lr = max_lr,
                        min_lr = min_lr)

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            loss.backward()

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            # update book keeping
            pbar.update(1)
            iter_num += 1
            if iter_num % batches_per_epoch == 0:
                epoch_num += 1
            if pbar.n == pbar.total:
                break

    # stop the whole run after a NaN loss
    if nan_loss:
        break


-------------
Epoch: 0
-------------


 Eval:   0%|          | 0/373 [00:00<?, ?it/s]

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/pkr/miniconda3/envs/gp-transformer/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/pkr/miniconda3/envs/gp-transformer/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 50, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pkr/miniconda3/envs/gp-transformer/lib/python3.12/site-packages/torch/utils/data/dataset.py", line 420, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
            ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "/mnt/DGX01/Personal/krusepi/codebase/projects/phenotyping/GP-Transformer/notebooks/../utils/dataset.py", line 180, in __getitem__
    y = torch.tensor(self.y_data.iloc[index].values, dtype=torch.float32)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.
