In [None]:
import os

import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader

from functions import EMA, iterate_ipf, load_data
from score_models import ReluScoreKANConv as ScoreNetworkConv

In [None]:
# Prefer the GPU when requested, but fall back to the CPU when CUDA is not
# available so the notebook still runs (slowly) instead of failing at the
# first `.to(device)`.
CUDA = True
device = torch.device("cuda" if CUDA and torch.cuda.is_available() else "cpu")

In [None]:
# Tag inserted into checkpoint file names (see the training cell).
suffix = '_GFlash_Conv'

num_steps = 20          # number of steps passed to iterate_ipf
n = num_steps // 2      # length of each half of the gamma schedule
batch_size = 8 * 1024   # DataLoader batch size
lr = 1e-5               # Adam learning rate for both networks
n_iter_glob = 50        # n_iter for every refinement pass after the first

In [None]:
# Step-size schedule: linearly spaced from gamma_min to gamma_max over the
# first half, then mirrored for the second half (constant while min == max).
gamma_max = 0.001
gamma_min = 0.001
gamma_half = np.linspace(gamma_min, gamma_max, n)
gammas = np.concatenate([gamma_half, np.flip(gamma_half)])
# The mirrored schedule has 2*n entries; guard against it disagreeing with
# num_steps (e.g. an odd num_steps would leave it one entry short).
assert len(gammas) == num_steps, (
    f"gamma schedule has {len(gammas)} entries but num_steps={num_steps}; "
    "use an even num_steps")
gammas = torch.tensor(gammas).to(device)
# Sum of all step sizes; passed to iterate_ipf as T.
T = torch.sum(gammas)

In [None]:
# Forwarded unchanged to load_data in the data-loading cell.
# (Architecture hyperparameters and model_version are defined in the model cell.)
normalize_energy = False

In [None]:
# Root of the data tree. Set the SB_REF_DATA_DIR environment variable to use
# another machine's location (e.g. '/mnt/d/UFRGS/TCC/Dados/'); otherwise the
# original workstation path is used.
abs_path = os.path.join(
    os.environ.get('SB_REF_DATA_DIR', '/media/marcelomd/HDD2/UFRGS/TCC/Dados/'), '')
# Directory paths keep a trailing separator because later cells build file
# names by plain string concatenation (e.g. models_dir_path + model_name + ...).
data_dir_path = os.path.join(abs_path, 'datasets', 'SB_Refinement', '')
models_dir_path = os.path.join(abs_path, 'repos', 'sb_ref_kan', 'models', '')

file_path_gflash = os.path.join(data_dir_path, 'run_GFlash01_100k_10_100GeV_full.npy')
file_path_g4 = os.path.join(data_dir_path, 'run_Geant_100k_10_100GeV_full.npy')

In [None]:
# Read the GFlash and Geant4 shower files (normalize_energy is forwarded as-is);
# the next cell unpacks the result by key.
data = load_data(
    file_path_gflash,
    file_path_g4,
    normalize_energy,
)

In [None]:
# Unpack the arrays produced by load_data. NOTE(review): presumably total
# shower energy, incident-particle energy and per-voxel energies for each
# simulator — confirm against functions.load_data.
energy_gflash = data["energy_gflash"]
energy_particle_gflash = data["energy_particle_gflash"]
energy_voxel_gflash = data["energy_voxel_gflash"]
energy_g4 = data["energy_g4"]
energy_particle_g4 = data["energy_particle_g4"]
energy_voxel_g4 = data["energy_voxel_g4"]

# Every sample is viewed as a single-channel 10x10 image for the conv nets.
GRID_SHAPE = (1, 10, 10)

# Sample count handed to iterate_ipf. NOTE(review): taken from the Geant4 set
# only; assumes the GFlash set has the same number of samples — confirm.
npar = int(energy_voxel_g4.shape[0])

# Only referenced by the (currently disabled) standardization
# `sample = (sample - sample.mean()) / sample.std() * scaling_factor`.
scaling_factor = 7.


def _make_loader(voxels, labels):
    """Build the tensors, TensorDataset and unshuffled DataLoader for one side.

    Args:
        voxels: per-sample voxel energies; each sample must hold 100 values so
            it can be viewed as GRID_SHAPE.
        labels: 2-D conditioning array with one row per sample.

    Returns:
        (sample_tensor, label_tensor, dataset, loader)
    """
    samples = torch.tensor(voxels).view(voxels.shape[0], *GRID_SHAPE)
    label_tensor = torch.tensor(labels)
    dataset = TensorDataset(samples, label_tensor)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return samples, label_tensor, dataset, loader


# 'f' side: GFlash voxels, labels = (energy_gflash, energy_g4, energy_particle_gflash).
# `init_lable` (sic) is kept because later cells refer to it by that name.
X_init = energy_voxel_gflash
Y_init = np.concatenate((energy_gflash, energy_g4, energy_particle_gflash), 1)
init_sample, init_lable, init_ds, init_dl = _make_loader(X_init, Y_init)

# 'b' side: Geant4 voxels, labels = (energy_g4, energy_gflash, energy_particle_g4).
X_final = energy_voxel_g4
Y_final = np.concatenate((energy_g4, energy_gflash, energy_particle_g4), 1)
final_sample, final_label, final_ds, final_dl = _make_loader(X_final, Y_final)

# Per-voxel mean (zeros) and variance (ones) passed to iterate_ipf as
# mean_final / var_final. An earlier variant used a scalar 0 mean and a very
# large ("infinity-like") variance of 1e3.
mean_final = torch.zeros(GRID_SHAPE, device=device)
var_final = torch.ones(GRID_SHAPE, device=device)

dls = {'f': init_dl, 'b': final_dl}

In [None]:
# NOTE(review): this import rebinds the ScoreNetworkConv alias from the first
# cell (ReluScoreKANConv) — every later cell uses BottleneckScoreKAGNConv.
# Consider moving it into the imports cell.
from score_models import BottleneckScoreKAGNConv as ScoreNetworkConv

# Architecture hyperparameters. A descriptive name is used instead of `i`,
# which the training cell reuses as its loop counter.
encoder_width = 16
encoder_layers = [encoder_width, encoder_width]
temb_dim = 8
conv_dof = 2

# Encoded into checkpoint and log file names in the training cell.
model_version = f"_{encoder_layers[0]}_{temb_dim}_{conv_dof}_"

# Throwaway instance used only to report the parameter count; the training
# networks are rebuilt in the next cell.
model_f = ScoreNetworkConv(encoder_layers=encoder_layers,
                           temb_dim=temb_dim,
                           conv_dof=conv_dof,
                           n_cond=init_lable.size(1)).to(device)

sum(p.numel() for p in model_f.parameters())


In [None]:
def build_score_network():
    """Instantiate one score network with the architecture chosen above.

    Reads the module-level encoder_layers, temb_dim and conv_dof, sizes the
    conditioning input from the label width, and moves the model to `device`.
    """
    return ScoreNetworkConv(encoder_layers=encoder_layers,
                            temb_dim=temb_dim,
                            conv_dof=conv_dof,
                            n_cond=init_lable.size(1)).to(device)


model_f = build_score_network()
model_b = build_score_network()

# Class name used to tag checkpoint and log file names. Built from the
# qualified name instead of slicing str(cls) at fixed offsets [21:-2], which
# only worked while the module was exactly `score_models`. For that case the
# result is identical, so existing checkpoints are still found.
_model_cls = type(model_f)
_qualified_name = f"{_model_cls.__module__}.{_model_cls.__qualname__}"
_module_prefix = "score_models."
model_name = (_qualified_name[len(_module_prefix):]
              if _qualified_name.startswith(_module_prefix) else _qualified_name)

model_f = torch.nn.DataParallel(model_f)
model_b = torch.nn.DataParallel(model_b)

opt_f = torch.optim.Adam(model_f.parameters(), lr=lr)
opt_b = torch.optim.Adam(model_b.parameters(), lr=lr)

# EMA wrappers (functions.EMA, decay 0.95). These are what iterate_ipf
# receives and what the training cell checkpoints and restores.
net_f = EMA(model=model_f, decay=0.95).to(device)
net_b = EMA(model=model_b, decay=0.95).to(device)

# 'iter_loss' / 'iter_et' are read by the training cell as per-iteration loss
# and elapsed time; presumably filled by iterate_ipf — confirm.
nets = {'f': net_f, 'b': net_b, 'iter_loss': [], 'iter_et': []}
opts = {'f': opt_f, 'b': opt_b}

nets['f'].train()
nets['b'].train()

d = init_sample[0].shape   # shape of one sample to diffuse, (1, 10, 10)
dy = init_lable[0].shape   # shape of one conditioning label vector
print(d)
print(dy)

In [None]:
# Uncomment to locate NaN/inf sources in the backward pass (slow):
# torch.autograd.set_detect_anomaly(True)

MAX_CHECKPOINT_ITER = 400  # checkpoint indices 1..399 are scanned when resuming
N_ITER_FIRST = 100         # n_iter of the initial forward pass (first=True)
N_ROUNDS = 19              # backward+forward rounds per run.
# NOTE(review): the original range(start_iter+1, start_iter+20) runs 19 rounds;
# confirm 20 was not intended.


def checkpoint_path(iteration, direction):
    """Return the checkpoint file of EMA network `direction` ('f' or 'b') at `iteration`."""
    return (models_dir_path + 'Iter{:d}_net_{}'.format(iteration, direction)
            + suffix + model_name + model_version + '.pth')


def _log_losses(log_file, iteration):
    """Write every (loss, elapsed time) pair in nets as a 'loss;time;iteration' row.

    NOTE(review): assumes iterate_ipf resets nets['iter_loss'] / nets['iter_et']
    on each call; otherwise earlier entries are written again.
    """
    for loss, elapsed in zip(nets['iter_loss'], nets['iter_et']):
        log_file.write(f"{loss:.6f};{elapsed:.2f};{iteration}\n")
    # Flush so the log survives a crash or kernel interrupt mid-training.
    log_file.flush()


# Resume from the newest iteration whose forward AND backward checkpoints both
# exist, so the two networks are never restored from different iterations.
# Real load errors (corrupt file, architecture mismatch) now surface instead of
# being swallowed by a bare `except`.
start_iter = 0
for candidate in range(1, MAX_CHECKPOINT_ITER):
    if (os.path.isfile(checkpoint_path(candidate, 'f'))
            and os.path.isfile(checkpoint_path(candidate, 'b'))):
        start_iter = candidate

if start_iter > 0:
    for direction in ('f', 'b'):
        nets[direction].load_state_dict(
            torch.load(checkpoint_path(start_iter, direction), map_location=device))
    # NOTE(review): Adam state is not checkpointed; its moments restart on resume.

# Keyword arguments shared by every iterate_ipf call.
ipf_kwargs = dict(nets=nets, opts=opts, device=device, dls=dls, gammas=gammas,
                  npar=npar, batch_size=batch_size, num_steps=num_steps, d=d,
                  dy=dy, T=T, mean_final=mean_final, var_final=var_final)

log_path = models_dir_path + model_name + model_version + ".txt"
# Start a fresh log for a new run; append when resuming so earlier losses are
# kept (the original 'w' mode truncated them). Write the header whenever the
# file is new.
write_header = start_iter == 0 or not os.path.isfile(log_path)
with open(log_path, 'w' if start_iter == 0 else 'a', encoding="utf-8") as f:
    if write_header:
        f.write("loss;elapsed time;iteration\n")

    if start_iter == 0:
        iterate_ipf(**ipf_kwargs, n_iter=N_ITER_FIRST,
                    forward_or_backward='f', forward_or_backward_rev='b', first=True)
        _log_losses(f, 0)
        print('--------------- Done iter 0 ---------------')

    nets['f'].train()
    nets['b'].train()

    for i in range(start_iter + 1, start_iter + 1 + N_ROUNDS):
        iterate_ipf(**ipf_kwargs, n_iter=n_iter_glob,
                    forward_or_backward='b', forward_or_backward_rev='f', first=False)
        _log_losses(f, i)
        print('--------------- Done iter B{:d} ---------------'.format(i))

        iterate_ipf(**ipf_kwargs, n_iter=n_iter_glob,
                    forward_or_backward='f', forward_or_backward_rev='b', first=False)
        _log_losses(f, i)
        print('--------------- Done iter F{:d} ---------------'.format(i))

        torch.save(net_f.state_dict(), checkpoint_path(i, 'f'))
        torch.save(net_b.state_dict(), checkpoint_path(i, 'b'))