# Training Workbook for Supercomputer Environment

## Setup Global Variables and Harm2d Functions

In [2]:
## Setup and configs
# imports
import os
import time
import matplotlib
import matplotlib.pyplot as plt
# global variables
global notebook
global axisym,set_cart,axisym,REF_1,REF_2,REF_3,set_cart,D,print_fieldlines
global lowres1,lowres2,lowres3, RAD_M1, RESISTIVE, export_raytracing_GRTRANS, export_raytracing_RAZIEH,r1,r2,r3
global r_min, r_max, theta_min, theta_max, phi_min,phi_max, do_griddata, do_box, check_files, kerr_schild

global do_train
## NOTE toggle do_train to run training on setup
do_train = False

notebook = 1

# total data is shape (10000, 224, 48, 96)
harm_directory = os.environ['HOME']+f'/bh/harm2d'
os.chdir(harm_directory)

print(f'Running setup scripts...')
start_time = time.time()
%run -i setup.py build_ext --inplace
%run -i pp.py build_ext --inplace
print(f"Execution time: {time.time() - start_time}")

# set params
lowres1 = 1 # 
lowres2 = 1 # 
lowres3 = 1 # 
r_min, r_max = 1.0, 100.0
theta_min, theta_max = 0.0, 9
phi_min, phi_max = -1, 9
do_box=0
set_cart=0
set_mpi(0)
axisym=1
print_fieldlines=0
export_raytracing_GRTRANS=0
export_raytracing_RAZIEH=0
kerr_schild=0
DISK_THICKNESS=0.03
check_files=1
notebook=1
interpolate_var=0
AMR = 0 # get all data in grid

print('Imports and setup done.')
%matplotlib inline

Running setup scripts...
Execution time: 6.298144578933716
Imports and setup done.


## Visualization

In [None]:
import os
import yaml
os.chdir(os.environ['HOME'] + '/bh/harm2d')
from models.cnn.cnn import *
# from utils.anim import make_prediciton_frames

# Roll the trained CNN forward autoregressively from one simulation dump and
# save a plot of channel 0 of every predicted frame.
#
# Relies on names provided by earlier cells / star imports: `time`, `torch`,
# `np`, `CNN_DEPTH`, `tensorize_globals`, the harm2d readers (`rblock_new_ml`,
# `rpar_new`, `rgdump_griddata`, `rdump_griddata`), `plc_cart_ml`, and the
# globals `rho`, `ug`, `uu`, `B`, `t`.

# Number of consecutive frames to predict and plot.
NUM_PRED_FRAMES = 60

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network and restore its weights.
# (Alternative architecture previously tried here: JACK_CNN_3D.)
model = CNN_DEPTH().to(device)
model_path = os.environ['HOME'] + "/bh/harm2d/models/cnn/saves/b3_v0.0.0.pth"
print(model_path)
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# written from a CUDA session fails to load when only the CPU is available.
loaded_temp = torch.load(f=model_path, map_location=device)
model.load_state_dict(loaded_temp)
model.eval()

# Training configuration supplies the index of the first dump to read.
with open('train_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

first_frame_index = config['start_dump']

# The readers below are called from inside the dump directory.
dumps_path = '/pscratch/sd/l/lalakos/ml_data_rc300/reduced'
os.chdir(dumps_path)

rblock_new_ml()
rpar_new(first_frame_index)
# get grid data
rgdump_griddata(dumps_path)
rdump_griddata(dumps_path, first_frame_index)

# Initial model input; rho and ug are passed in log10.
data_tensor = tensorize_globals(rho=np.log10(rho), ug=np.log10(ug), uu=uu, B=B)

save_path = os.environ['HOME'] + '/bh/movies/sc_pred_frames/'
# Create the frame directory up front so the first save cannot fail on a
# missing path (assumes plc_cart_ml writes to `name` -- TODO confirm).
os.makedirs(save_path, exist_ok=True)

# Move the input to the device once; every later input is a model output that
# already lives there.
data_tensor = data_tensor.to(device)

for index in range(first_frame_index, first_frame_index + NUM_PRED_FRAMES):
    pred_time_start = time.time()
    # Each prediction is fed back in as the next input. Disabling autograd
    # keeps the computation graph from growing across the whole rollout.
    with torch.no_grad():
        data_tensor = model(data_tensor)
    # On CUDA this can under-report, since kernels run asynchronously until the
    # .cpu() copy below forces a sync.
    print(f'Prediction {index} in {time.time()-pred_time_start:.4f} s')

    # plot and save
    plot_time_start = time.time()
    # NOTE(review): `t` is not updated inside this loop, so every frame carries
    # the same time in its label; the label text also says sigma while the
    # file name says rho -- confirm which is intended.
    plc_cart_ml(
        var=(data_tensor[0][0].cpu().unsqueeze(0).detach().numpy()),
        min=-2,
        max=2,
        rmax=100,
        offset=0,
        name=save_path+f'pred_rho_{index}',
        label=r"$\sigma r {\rm sin}\theta$ at %d $r_g/c$" % t
    )
    print(f'Plotted and saved in {time.time()-plot_time_start:.4f} s')

print('Frames saved')