In [None]:
from data_helper import get_dataloaders_and_standarscaler_photons_from_numpy, get_standarized_constrains, get_photons_with_introduced_XY_symmetries
from train_helper import train_vae_mkmmd
from models_architecture_helper import VAE_Linear_2105

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch

In [None]:
# Hyperparameters
RANDOM_SEED = 123
LEARNING_RATE = 0.0005
BATCH_SIZE = 9200
NUM_EPOCHS = 40
LOGGING_INTERVAL = 300
RECONSTRUCTION_TERM_WEIGHT = 1

PLOT_FRACTION = 0.0125
TEST_FRACTION = 0.4
VALIDATION_FRACTION = 0.0
SAVE_MODEL_FILE = 'checkpoint.pth'
NUM_WORKERS = 0
# NOTE(review): absolute, machine-specific input path — consider making it configurable.
path = '/data1/dose-3d-generative/data/training-data/PHSPs_without_VR/Filtered_E5.6_s0.0.npy'

# Optional per-feature bounds in raw (unstandardized) units; they are standardized
# with the data scaler later on. Set both lists or leave both as None.
constrains_min = None
constrains_max = None
# constrains_min=[0, -200, -200, -1.05, -1.05, -0.02]
# constrains_max=[6, 200, 200, 1.05, 1.05, 1.02]

# Seed the global NumPy and PyTorch RNGs (torch.manual_seed also seeds CUDA) so that
# model initialisation and any torch-based shuffling/splitting are reproducible
# under Restart & Run All. RANDOM_SEED was previously only passed to the symmetrization helper.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

In [None]:
# Pick the requested GPU when CUDA is available, otherwise fall back to the CPU.
CUDA_DEVICE_NUM = 0
if torch.cuda.is_available():
    DEVICE = torch.device(f'cuda:{CUDA_DEVICE_NUM}')
else:
    DEVICE = torch.device('cpu')
print('Device:', DEVICE)

In [None]:
# Report CUDA memory usage and the GPU name for the selected DEVICE.
# Guarded because DEVICE may have fallen back to the CPU, where
# torch.cuda.get_device_name raises; uses DEVICE rather than a hard-coded index 0
# so the name matches CUDA_DEVICE_NUM.
if DEVICE.type == 'cuda':
    print(torch.cuda.memory_allocated(device=DEVICE))
    print(torch.cuda.memory_reserved(device=DEVICE))
    print(torch.cuda.get_device_name(DEVICE))
else:
    print('CUDA not available; skipping GPU memory report.')

In [None]:
# Load the photon phase-space array from the configured `path`.
photons = np.load(path)
# Later cells index column 5 as dZ and label six columns (E, X, Y, dX, dY, dZ),
# so fail early with a clear message if the file has a different layout.
if photons.ndim != 2 or photons.shape[1] != 6:
    raise ValueError(
        f'Expected a 2-D array with 6 columns (E, X, Y, dX, dY, dZ), got shape {photons.shape}'
    )
photons.shape

In [None]:
# Remove photons whose dZ (column 5) is negative. The mask only rejects rows with
# dZ < 0, so every other row (including any NaN dZ) is kept in original order.
photons_nodz = photons[~(photons[:, 5] < 0)]
print(photons_nodz.shape)

In [None]:
# Introduce X/Y symmetries into the photon set; RANDOM_SEED is forwarded to the helper.
symmetrized_photons_nodz = get_photons_with_introduced_XY_symmetries(
    random_seed=RANDOM_SEED,
    photons=photons_nodz,
)
print(symmetrized_photons_nodz.shape)

In [None]:
# Wrap the symmetrized photons in a labelled DataFrame and preview the first rows
# (always worth a quick look at the data).
df_data = pd.DataFrame(
    data=symmetrized_photons_nodz,
    columns=['E', 'X', 'Y', 'dX', 'dY', 'dZ'],
)
df_data.head(5)

In [None]:
# One histogram per photon feature, three panels per row. The grid is sized from
# the number of columns (6 columns -> the original 2x3, 16x8 figure) instead of
# being hard-coded, and each axis is paired directly with its column.
n_hist_cols = 3
n_hist_rows = -(-len(df_data.columns) // n_hist_cols)  # ceiling division
fig, axs = plt.subplots(n_hist_rows, n_hist_cols, figsize=(16, 4 * n_hist_rows), squeeze=False)
for ax, column in zip(axs.flat, df_data.columns):
    df_data.hist(column=column, bins=50, ax=ax)
# Hide grid cells left over when the column count is not a multiple of three.
for ax in axs.flat[len(df_data.columns):]:
    ax.set_visible(False)

In [None]:
# Split the symmetrized photons into train/validation/test loaders; the helper
# also returns the fitted scaler (stdcs), reused below to standardize constraints.
train_loader, valid_loader, test_loader, stdcs = get_dataloaders_and_standarscaler_photons_from_numpy(
    tmp_X=symmetrized_photons_nodz,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    test_fraction=TEST_FRACTION,
    validation_fraction=VALIDATION_FRACTION,
)

In [None]:
# Standardize the optional feature bounds with the scaler returned alongside the
# data loaders. Both bounds must be provided together, or neither: previously only
# constrains_min was checked, so a lone max was silently ignored and a lone min
# passed None as the max bound to the helper.
if constrains_min is None and constrains_max is None:
    standarized_constrains_min, standarized_constrains_max = (None, None)
elif constrains_min is None or constrains_max is None:
    raise ValueError('constrains_min and constrains_max must both be set or both be None')
else:
    standarized_constrains_min, standarized_constrains_max = get_standarized_constrains(
        constrains_list_min=constrains_min,
        constrains_list_max=constrains_max,
        stdcs=stdcs,
        device=DEVICE,
    )

In [None]:
# Build the VAE with the (possibly None) standardized bounds and move it to DEVICE.
model = VAE_Linear_2105(
    constrains_std_min=standarized_constrains_min,
    constrains_std_max=standarized_constrains_max,
)
model.to(DEVICE)

# Adam with a small weight decay; no separate loss criterion is defined in this cell.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

In [None]:
# Run training via train_vae_mkmmd (loss_fn=None, so no external criterion is
# supplied) and keep whatever log dictionary it returns in log_dict.
log_dict = train_vae_mkmmd(
    model=model,
    optimizer=optimizer,
    device=DEVICE,
    num_epochs=NUM_EPOCHS,
    train_loader=train_loader,
    test_loader=test_loader,
    loss_fn=None,
    logging_interval=LOGGING_INTERVAL,
    reconstruction_term_weight=RECONSTRUCTION_TERM_WEIGHT,
    constrains_std_min=standarized_constrains_min,
    constrains_std_max=standarized_constrains_max,
    save_model_file=SAVE_MODEL_FILE,
)

In [None]:
# Training-loss plots (disabled).
# NOTE(review): `plot_training_loss` is not imported anywhere in this notebook, and the
# log_dict keys below are not verified against what `train_vae_mkmmd` returns —
# confirm both before uncommenting, otherwise this cell will raise.
# plot_training_loss(log_dict['train_reconstruction_loss_per_batch'], NUM_EPOCHS, custom_label=" (reconstruction)")
# plot_training_loss(log_dict['train_kl_loss_per_batch'], NUM_EPOCHS, custom_label=" (KL)")
# plot_training_loss(log_dict['train_combined_loss_per_batch'], NUM_EPOCHS, custom_label=" (combined)")
# plt.show()