# Training Workbook for Supercomputer Environment

## Setup

In [None]:
## Setup and configs
# imports
import os
import matplotlib
import matplotlib.pylot as plt
# global variables
global notebook
global axisym,set_cart,axisym,REF_1,REF_2,REF_3,set_cart,D,print_fieldlines
global lowres1,lowres2,lowres3, RAD_M1, RESISTIVE, export_raytracing_GRTRANS, export_raytracing_RAZIEH,r1,r2,r3
global r_min, r_max, theta_min, theta_max, phi_min,phi_max, do_griddata, do_box, check_files, kerr_schild

notebook = 1

harm_directory = '/global/u1/j/jackh/bh/harm2d'
os.chdir(harm_directory)

%run -i setup.py build_ext --inplace
%run -i pp.py build_ext --inplace

# set params
lowres1 = 1 # 
lowres2 = 1 # 
lowres3 = 1 # 
r_min, r_max = 1.0, 100.0
theta_min, theta_max = 0.0, 9
phi_min, phi_max = -1, 9
do_box=0
set_cart=0
set_mpi(0)
axisym=1
print_fieldlines=0
export_raytracing_GRTRANS=0
export_raytracing_RAZIEH=0
kerr_schild=0
DISK_THICKNESS=0.03
check_files=1
notebook=1
interpolate_var=0
AMR = 0 # get all data in grid

print('Imports and setup done.')
%matplotlib inline

## Training

In [None]:
# imports
import numpy as np
import torch
from tqdm import tqdm

# training utilities
from utils.sc_utils import custom_batcher, tensorize_globals
from models.cnn.cnn import CNN_3D

# Path to dumps; override with the ML_DUMPS_PATH environment variable.
dumps_path = os.environ.get('ML_DUMPS_PATH', '/pscratch/sd/l/lalakos/ml_data_rc300/reduced')
os.chdir(dumps_path)
# Number of usable input dumps. Each input dump i is paired with label dump
# i + 1 (see _load_batch), hence the "- 1" — presumably 10 dumps exist on disk.
num_dumps = 10 - 1
# batch size
batch_size = 2
# number of epochs
num_epochs = 2
# set model
model = CNN_3D()
# set optimizer and loss
optim = torch.optim.Adam(params=model.parameters())
loss_fn = torch.nn.MSELoss()


def _load_frame(dump_idx):
    """Read dump ``dump_idx`` and return it as a tensor.

    The pp.py readers communicate only through notebook-level globals, so the
    readers are called for their side effects and ``rho``/``ug``/``uu``/``B``
    are then looked up from the notebook namespace.
    """
    rblock_new(dump_idx)
    rpar_new(dump_idx)
    rgdump_griddata(dumps_path)
    rdump_griddata(dumps_path, dump_idx)
    return tensorize_globals(rho=rho, ug=ug, uu=uu, B=B)


def _load_batch(batch_indexes):
    """Return ``(inputs, labels)`` where dump i is the input and dump i + 1 its label."""
    inputs, labels = [], []
    # Input and label are loaded back to back, as in the original loop, because
    # each read overwrites the shared globals.
    for dump_idx in batch_indexes:
        inputs.append(_load_frame(dump_idx))
        labels.append(_load_frame(dump_idx + 1))
    return torch.cat(inputs, dim=0), torch.cat(labels, dim=0)


def _mean(values):
    """Average of a list of floats; NaN for an empty list instead of ZeroDivisionError."""
    return sum(values) / len(values) if values else float('nan')


# get indexes for training data
train_indexes, validation_indexes = custom_batcher(
    batch_size=batch_size,
    num_dumps=num_dumps,
    split=0.8,
    seed=1,
)

# Per-epoch average losses. Created once, outside the epoch loop, so the
# learning curves accumulate one point per epoch.
train_losses, valid_losses = [], []

for epoch in range(num_epochs):
    ## Training
    model.train()
    epoch_train_loss = []

    # shuffle training indexes
    np.random.shuffle(train_indexes)

    for batch_indexes in tqdm(train_indexes.reshape(-1, batch_size)):
        batch_data, label_data = _load_batch(batch_indexes)

        # Clear gradients from the previous step; otherwise backward() accumulates them.
        optim.zero_grad()
        pred = model(batch_data)
        loss_value = loss_fn(pred, label_data)
        loss_value.backward()
        optim.step()
        # Store a plain float so each step's autograd graph can be freed.
        epoch_train_loss.append(loss_value.item())

    # training loss tracking
    avg_loss_after_epoch = _mean(epoch_train_loss)
    train_losses.append(avg_loss_after_epoch)
    print(f"Train loss value: {avg_loss_after_epoch}")

    ## Validation
    model.eval()
    epoch_valid_loss = []

    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        for batch_indexes in tqdm(validation_indexes.reshape(-1, batch_size)):
            batch_data, label_data = _load_batch(batch_indexes)
            pred = model(batch_data)
            epoch_valid_loss.append(loss_fn(pred, label_data).item())

    avg_vloss_after_epoch = _mean(epoch_valid_loss)
    valid_losses.append(avg_vloss_after_epoch)
    print(f"Valid loss value: {avg_vloss_after_epoch}")

# plot learning curves (one point per epoch)
plt.plot(range(len(train_losses)), train_losses, label='Train Loss')
# TODO: plot the loss of always predicting the average (avg_baseline_loss) as a dashed reference line.
plt.plot(range(len(valid_losses)), valid_losses, label='Validation Loss')
plt.title('Training and Validation Curve')
plt.xlabel('Epoch')
plt.ylabel('Loss (MSE)')
plt.legend()
plt.show()


## Visualization

In [None]:
# TODO: add visualization