In [None]:
import os

import torch
import numpy as np
import matplotlib.pyplot as plt

from models.GAN import Lit_GAN
from models.WGAN import Lit_WGAN
from models.WGAN_gp import Lit_WGAN_gp
from helpers.PhotonsDataModule import PhotonsDataModule
from helpers.plot_helper import get_subplot_adjustment

In [None]:
# --- Paths ---------------------------------------------------------------------
# The defaults are absolute paths on the original author's machines. Override them
# through the environment (GAN_CHECKPOINT_PATH / PHSP_DATA_PATH) rather than
# editing the notebook when running elsewhere.
LOAD_CHECKPOINT_PATH = os.environ.get(
    'GAN_CHECKPOINT_PATH',
    '/home/jakmic/Projekty/dose3d-phsp/GAN/Lightning_GANs/results/WGAN_gp/version_1/checkpoints/last.ckpt',
)
DATA_PATH = os.environ.get(
    'PHSP_DATA_PATH',
    '/data1/dose-3d-generative/data/training-data/PHSPs_without_VR/Filtered_E5.6_s0.0.npy',
)

# --- PhotonsDataModule settings --------------------------------------------------
BATCH_SIZE = 400000
NUM_WORKERS = 0
TEST_FRACTION = 0.0
VALIDATION_FRACTION = 0.4
SHUFFLE_TRAIN = False
RANDOM_SEED = 123

# --- Generator settings ----------------------------------------------------------
# NOTE: the misspelled name is kept on purpose because other cells read LANTENT_SPACE_DIM.
LANTENT_SPACE_DIM = 8
NUM_SUBPLOT_ROWS, NUM_SUBPLOT_COLUMNS = get_subplot_adjustment(LANTENT_SPACE_DIM)
print(NUM_SUBPLOT_ROWS, NUM_SUBPLOT_COLUMNS)

# Names of the photon-record columns, in column order; used as histogram titles.
KEYS = ['E', 'X', 'Y', 'dX', 'dY', 'dZ']

In [None]:
# LightningModule class matching the checkpoint at LOAD_CHECKPOINT_PATH
# (one of Lit_GAN, Lit_WGAN, Lit_WGAN_gp). When switching checkpoints,
# change only this line.
MODEL_CLASS = Lit_WGAN_gp

dm = PhotonsDataModule(
    data_path=DATA_PATH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    test_fraction=TEST_FRACTION,
    validation_fraction=VALIDATION_FRACTION,
    shuffle_train=SHUFFLE_TRAIN,
    random_seed=RANDOM_SEED,
)
model = MODEL_CLASS.load_from_checkpoint(LOAD_CHECKPOINT_PATH)

# No Trainer drives the data module here, so setup() is called by hand.
# Presumably it builds the splits and the `stdcs` scaler read later (TODO confirm).
dm.setup()

In [None]:
# Run the generator once per validation batch. Each batch draws as many latent
# vectors as there are real photons in it, so the original and generated sets end
# up with the same number of rows. Both are passed back through the standardizer
# (dm.stdcs) before plotting.


def _stack_rows(batches, n_columns):
    """Concatenate 2-D arrays row-wise into a single float64 array.

    Returns an empty ``(0, n_columns)`` array when ``batches`` is empty, so the
    plotting cells still receive a 2-D array.
    """
    if not batches:
        return np.empty((0, n_columns))
    # Fixed float64 dtype, independent of what inverse_transform returns.
    return np.concatenate(batches, axis=0).astype(np.float64, copy=False)


original_batches = []
generated_batches = []

# Dedicated, seeded RNG for the latent noise. The generated sample is then
# reproducible on a fresh kernel, and torch's global RNG state is left untouched.
noise_generator = torch.Generator().manual_seed(RANDOM_SEED)

model.eval()
with torch.no_grad():
    for photon_batch in dm.val_dataloader():
        # Latent noise from a standard normal distribution. It is drawn with the
        # CPU generator, then moved to the model's device so the forward pass
        # also works when the model sits on a GPU.
        noise_batch = torch.randn(
            size=(photon_batch.size(0), LANTENT_SPACE_DIM),
            generator=noise_generator,
        ).to(model.device)

        generated_batch = model(noise_batch)

        # Standardizer inverse transform. No .detach() is needed because autograd
        # is disabled inside torch.no_grad().
        original_batches.append(dm.stdcs.inverse_transform(photon_batch.cpu().numpy()))
        generated_batches.append(dm.stdcs.inverse_transform(generated_batch.cpu().numpy()))

# Concatenate once at the end. Calling np.append per batch would re-copy the whole
# accumulated array on every iteration.
orginal_photons = _stack_rows(original_batches, len(KEYS))
generated = _stack_rows(generated_batches, len(KEYS))
print(generated.shape)

In [None]:
# Linear-scale histograms of every photon variable: original (validation) data vs.
# generator output. Both are drawn on shared bin edges so they overlay directly.
n_bins = 300
# Per-variable lower bin edges that replace the data minimum. Values below the
# override fall outside the bins and are not drawn.
lower_edge_overrides = {'dZ': 0.8}

fig, axs = plt.subplots(2, 3)
fig.set_size_inches(20, 14)
for column, (ax, key) in enumerate(zip(axs.flat, KEYS)):
    original_values = orginal_photons[:, column]
    generated_values = generated[:, column]

    lower = lower_edge_overrides.get(
        key, min(original_values.min(), generated_values.min())
    )
    upper = max(original_values.max(), generated_values.max())
    bins = np.linspace(lower, upper, n_bins)

    ax.hist(original_values, bins, alpha=.5, label='original')
    ax.hist(generated_values, bins, alpha=.5, label='generated')
    ax.set_title(key)
    ax.set_xlabel(key)
    ax.set_ylabel('counts')
    ax.legend()
fig.tight_layout()

In [None]:
# Normalised, log-scale histograms of every photon variable: original (validation)
# data vs. generator output. density=True scales each histogram to unit area, and
# log=True puts the y axis on a log scale so the tails are visible.
n_bins = 300
# Per-variable lower bin edges that replace the data minimum. Values below the
# override fall outside the bins and are not drawn.
lower_edge_overrides = {'E': -0.1, 'dZ': 0.8}

fig, axs = plt.subplots(2, 3)
fig.set_size_inches(20, 14)
for column, (ax, key) in enumerate(zip(axs.flat, KEYS)):
    original_values = orginal_photons[:, column]
    generated_values = generated[:, column]

    lower = lower_edge_overrides.get(
        key, min(original_values.min(), generated_values.min())
    )
    upper = max(original_values.max(), generated_values.max())
    bins = np.linspace(lower, upper, n_bins)

    ax.hist(original_values, bins, alpha=.5, label='original', density=True, log=True)
    ax.hist(generated_values, bins, alpha=.5, label='generated', density=True, log=True)
    ax.set_title(key)
    ax.set_xlabel(key)
    ax.set_ylabel('density (log scale)')
    ax.legend()
fig.tight_layout()