In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from torch import pca_lowrank
from sklearn.decomposition import PCA

In [None]:
import sys

# Make the parent directory importable (presumably needed for the
# `share.configs` imports further down — confirm).
sys.path.extend(["../"])

In [None]:
from queue import Queue

# Bounded FIFO handed to the StreamLoader below and drained with get() later;
# with the default blocking put(), at most 8 items are held at once.
openPMDBuffer = Queue(maxsize=8)

In [None]:
from inSituML.ks_producer_openPMD_streaming import StreamLoader
from inSituML.ks_transform_policies import AbsoluteSquare, BoxesAttributesParticles

# Mean / standard deviation of the phase-space quantities; passed to the
# StreamLoader as "normalization" (momDenorm() inverts the momentum part).
normalization_values = {
    "momentum_mean": 1.2091940752668797e-08,
    "momentum_std": 0.11923234769525472,
    "force_mean": -2.7682006649827533e-09,
    "force_std": 7.705477610810592e-05,
}


streamLoader_config = {
    # First / last time step to stream (inclusivity: see StreamLoader — TODO confirm).
    "t0": 900,
    "t1": 998,
    # "t0": 1800,
    # "t1": 1810,
    "streaming_config": None,
    # Earlier input locations (files on hemera):
    #pathpattern1 = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/03-30_learning-rate-scaling-with-ranks_chamfersdistance_fix-gpu-volume/24-nodes_full-picongpu-data/simOutput/openPMD/simData_%T.bp", # files on hemera
    #pathpattern2 = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/03-30_learning-rate-scaling-with-ranks_chamfersdistance_fix-gpu-volume/24-nodes_full-picongpu-data/simOutput/radiationOpenPMD/e_radAmplitudes%T.bp", # files on hemera
    #pathpattern1 = "/bigdata/hplsim/production/KHI_for_GB_MR/runs/014_KHI_007_noWindowFunction/simOutput/openPMD/simData_%T.bp", # files on hemera
    #pathpattern2 = "/bigdata/hplsim/production/KHI_for_GB_MR/runs/014_KHI_007_noWindowFunction/simOutput/radiationOpenPMD/e_radAmplitudes%T.bp", # files on hemera
    #pathpattern1 = "/bigdata/hplsim/production/KHI_for_GB_MR/runs/015_KHI_009_noWindowFunction/simOutput/openPMD/simData_%T.bp", # files on hemera
    #pathpattern2 = "/bigdata/hplsim/production/KHI_for_GB_MR/runs/015_KHI_009_noWindowFunction/simOutput/radiationOpenPMD/e_radAmplitudes%T.bp", # files on hemera
    "particle_pathpattern": "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/24-nodes_full-picongpu-data/04-01_1013/simOutput/openPMD/simData_%T.bp5",
    "radiation_pathpattern": "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/24-nodes_full-picongpu-data/04-01_1013/simOutput/radiationOpenPMD/e_radAmplitudes_%T.bp5",
    # Single observation direction of the radiation signal; max is N_observer-1,
    # where N_observer is defined in PIConGPU's radiation plugin.
    "amplitude_direction": 0,
    # Allowed: "position", "momentum", "force". "force" requires "momentum".
    "phase_space_variables": ["momentum", "force"],
    "number_particles_per_gpu": 30000,
    "verbose": False,
    # Offline training parameters.
    "num_epochs": .01,  # .0625
    "normalization": normalization_values,
}

# Start the loader that fills openPMDBuffer (presumably a background
# thread/process, given .start() — TODO confirm) using the two policies below.
boxes_policy = BoxesAttributesParticles()
abs_square_policy = AbsoluteSquare()
timeBatchLoader = StreamLoader(
    openPMDBuffer,
    streamLoader_config,
    boxes_policy,
    abs_square_policy,
)

timeBatchLoader.start()

In [None]:
# Number of items currently waiting in the buffer (Queue.qsize() is approximate).
openPMDBuffer.qsize()

In [None]:
# Take 10 items from the buffer; Queue.get() blocks until the loader has produced each one.
data = [openPMDBuffer.get() for _ in range(10)]

In [None]:
import share.configs.model_config as model_config
import share.configs.io_config_hemera as io_config

import torch.nn as nn
from torch import optim

from inSituML.ks_models import INNModel
from inSituML.utilities import MMD_multiscale, fit, load_checkpoint
from inSituML.args_transform import MAPPING_TO_LOSS
# Encoder was previously imported twice; one consolidated import is enough.
from inSituML.encoder_decoder import Encoder, Conv3DDecoder, MLPDecoder
from inSituML.loss_functions import EarthMoversLoss
from inSituML.networks import VAE, ConvAutoencoder

# Hyper-parameters read by load_objects() (e.g. "lr", "ndim_tot", "lambd_AE").
config = model_config.config

# Number of ranks used in load_objects() when rescaling the learning rate.
world_size = 1

class ModelFinal(nn.Module):
    """Base network (autoencoder) with an inner model operating on its latent code.

    ``forward`` returns a dict of loss terms; the AE and inner-model losses are
    scaled by ``weight_AE`` / ``weight_IM`` and summed into ``total_loss``.
    ``reconstruct`` obtains latents from the inner model conditioned on ``y``
    and decodes them with ``base_network.decoder``.
    """

    def __init__(self,
                 base_network,
                 inner_model,
                 loss_function_IM=None,
                 weight_AE=1.0,
                 weight_IM=1.0):
        super().__init__()

        self.base_network = base_network
        self.inner_model = inner_model
        # Kept as an attribute; forward()/reconstruct() do not use it.
        self.loss_function_IM = loss_function_IM
        self.weight_AE = weight_AE
        self.weight_IM = weight_IM

    def forward(self, x, y):
        """Return the loss dict for input ``x`` and condition ``y``.

        Keys: total_loss, loss_AE, loss_IM (both weighted), loss_ae_reconst,
        kl_loss; for an INNModel inner model additionally l_fit, l_latent, l_rev.
        """
        loss_AE, loss_ae_reconst, kl_loss, _, encoded = self.base_network(x)

        inn_extras = None
        if isinstance(self.inner_model, INNModel):
            loss_IM, l_fit, l_latent, l_rev = self.inner_model.compute_losses(encoded, y)
            inn_extras = {'l_fit': l_fit, 'l_latent': l_latent, 'l_rev': l_rev}
        else:
            # Other inner models (e.g. MAF) return their loss from the call itself.
            loss_IM = self.inner_model(inputs=encoded, context=y)

        weighted_AE = loss_AE * self.weight_AE
        weighted_IM = loss_IM * self.weight_IM

        losses = {
            'total_loss': weighted_AE + weighted_IM,
            'loss_AE': weighted_AE,
            'loss_IM': weighted_IM,
            'loss_ae_reconst': loss_ae_reconst,
            'kl_loss': kl_loss,
        }
        if inn_extras is not None:
            losses.update(inn_extras)
        return losses

    def reconstruct(self, x, y, num_samples=1):
        """Return ``(decoded, latent)`` with the latent produced by the inner model given ``y``."""
        if isinstance(self.inner_model, INNModel):
            latent = self.inner_model(x, y, rev=True)
        else:
            latent = self.inner_model.sample_pointcloud(num_samples=num_samples, cond=y)
        return self.base_network.decoder(latent), latent


# Constructor keyword arguments for the VAE encoder / decoder built in load_objects().
VAE_encoder_kwargs = dict(
    ae_config="non_deterministic",
    z_dim=model_config.latent_space_dims,
    input_dim=io_config.ps_dims,
    conv_layer_config=[16, 32, 64, 128, 256, 608],
    conv_add_bn=False,
    fc_layer_config=[544],
)

VAE_decoder_kwargs = dict(
    z_dim=model_config.latent_space_dims,
    input_dim=io_config.ps_dims,
    initial_conv3d_size=[16, 4, 4, 4],
    add_batch_normalisation=False,
    fc_layer_config=[1024],
)
def load_objects(rank):

    torch.cuda.set_device(rank)
    torch.cuda.empty_cache()

    loss_fn_for_VAE = MAPPING_TO_LOSS[model_config.config['loss_function']](**model_config.config['loss_kwargs'])

    VAE_obj = VAE(encoder = Encoder,
            encoder_kwargs = VAE_encoder_kwargs,
            decoder = Conv3DDecoder,
            z_dim=model_config.latent_space_dims,
            decoder_kwargs = VAE_decoder_kwargs,
            loss_function = loss_fn_for_VAE,
            property_="momentum_force",
            particles_to_sample = io_config.number_of_particles,
            ae_config="non_deterministic",
            use_encoding_in_decoder=False,
            weight_kl=model_config.config["lambd_kl"],
            device=rank)

    # conv_AE
#     conv_AE_encoder_kwargs = {"ae_config":"simple",
#                     "z_dim":model_config.latent_space_dims,
#                     "input_dim":io_config.ps_dims,
#                     "conv_layer_config":[16, 32, 64, 128, 256, 512],
#                     "conv_add_bn": False}

#     conv_AE_decoder_kwargs = {"z_dim":model_config.latent_space_dims,
#                     "input_dim":io_config.ps_dims,
#                     "add_batch_normalisation":False}

#     conv_AE = ConvAutoencoder(encoder = Encoder,
#                             encoder_kwargs = conv_AE_encoder_kwargs,
#                             decoder = Conv3DDecoder,
#                             decoder_kwargs = conv_AE_decoder_kwargs,
#                             loss_function = EarthMoversLoss(),
#                             )

    # MAF inner model (not used in final runs)
    # inner_model = PC_MAF(dim_condition=config["dim_condition"],
    #                         dim_input=config["dim_input"],
    #                         num_coupling_layers=config["num_coupling_layers"],
    #                         hidden_size=config["hidden_size"],
    #                         device=rank,
    #                         num_blocks_mat = config["num_blocks_mat"],
    #                         activation = config["activation"]
    #                         )

    # INN
    inner_model = INNModel(ndim_tot=config["ndim_tot"],
                    ndim_x=config["ndim_x"],
                    ndim_y=config["ndim_y"],
                    ndim_z=config["ndim_z"],
                    loss_fit=fit,
                    loss_latent=MMD_multiscale,
                    loss_backward=MMD_multiscale,
                    lambd_predict=config["lambd_predict"],
                    lambd_latent=config["lambd_latent"],
                    lambd_rev=config["lambd_rev"],
                    zeros_noise_scale=config["zeros_noise_scale"],
                    y_noise_scale=config["y_noise_scale"],
                    hidden_size=config["hidden_size"],
                    activation=config["activation"],
                    num_coupling_layers=config["num_coupling_layers"],
                    device = rank)

    #model = ModelFinal(VAE_obj, inner_model, EarthMoversLoss())
    #model = ModelFinal(conv_AE, inner_model, EarthMoversLoss())
    model = ModelFinal(VAE_obj,
                       inner_model,
                       EarthMoversLoss(),
                       weight_AE=config["lambd_AE"],
                       weight_IM=config["lambd_IM"])


    #Load a pre-trained model
   
    #map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    
    # updated_state_dict = {key.replace('VAE.', 'base_network.'): value for key, value in original_state_dict.items()}
    updated_state_dict = {key.replace('module.', ''): value for key, value in ckpt["model"].items()}
    model.load_state_dict(updated_state_dict)

    lr = config["lr"]
    bs_factor = io_config.trainBatchBuffer_config["training_bs"] / 2 * world_size
    lr = lr * config["lr_scaling"](bs_factor)
    print("Skaling learning rate from {} to {} due to bs factor {}".format(config["lr"], lr, bs_factor))
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=config["betas"],
                         eps=config["eps"], weight_decay=config["weight_decay"])
    if ( "lr_annealingRate" not in config ) or config["lr_annealingRate"] is None:
        scheduler = None
    else:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=config["lr_annealingRate"])

    return optimizer, scheduler, model



In [None]:
# Show the values produced by range(0, 100, 30), one per line.
print(*range(0, 100, 30), sep="\n")

In [None]:
# Plot data[5][1][i] for boxes 0..89 that pass two thresholds
# (> -15 at index 410 and > -5 at index 210), and remember those box indices.
spectra = data[5][1]
boxlist = []
for i in range(90):
    keep = spectra[i][410] > -15 and spectra[i][210] > -5
    if not keep:
        continue
    plt.plot(spectra[i], label=str(i))
    boxlist.append(i)
plt.legend()
#plt.xlim(400,430)

In [None]:
# --- Checkpoint selection ----------------------------------------------------
# Only the final active `chkptfn` assignment matters; candidates from earlier
# experiments are kept as comments so they can be re-enabled by hand.
#
# Local runs:
#ckptfn = "/home/kelling/checkout/FWKT/InSituML/main/ModelHelpers/cINN/slurm-6921762/model_8081"
#ckptfn = "/home/kelling/checkout/FWKT/InSituML/main/ModelHelpers/cINN/slurm-6921753/model_809"
#chkptfn = "/home/kelling/checkout/FWKT/InSituML/main/ModelHelpers/cINN/runs_y/slurm-6921781/model_19000"
#ckptfn = "/home/kelling/checkout/FWKT/InSituML/main/ModelHelpers/cINN/slurm-6921842/model_8081"
#ckptfn = "/home/kelling/checkout/FWKT/InSituML/main/ModelHelpers/cINN/trained_models/inn_vae_latent_544_sim014_859eopan/model_950"
#ckptfn = "/home/kelling/checkout/FWKT/InSituML/main/ModelHelpers/cINN/runs014_30k/slurm-{}/model_{}"
#ckptfn = chkptfn.format(6921756, 809) # 950, continue at 150k
#ckptfn = chkptfn.format(6921771, 8081) # 950, tuned at 30k
#chkptfn = "/home/kelling/checkout/FWKT/InSituML/main/ModelHelpers/cINN/slurm-{}/model_{}"
#chkptfn = chkptfn.format(6923061, 800) # 150, tuned at 30k, continual_bs=4, buffersize=5,training_bs=4, offline
#chkptfn = chkptfn.format(6923062, 800) # 150, tuned at 30k, continual_bs=0, buffersize=4,training_bs=4, offline
#chkptfn = chkptfn.format(6923051, 1600) # 950, tuned at 30k, Y, continual_bs=4, buffersize=5,training_bs=4, offline
#chkptfn = chkptfn.format(6923053, 2000) # 950, tuned at 30k, continual_bs=4, buffersize=5,training_bs=4, offline
#chkptfn = chkptfn.format(6923281, 800) # None, continual_bs=4, buffersize=5,training_bs=4, offline
#chkptfn = chkptfn.format(6923976, 12000) # chamfers streaming, rep 6, lr.001, Y
#chkptfn = chkptfn.format(6923987, 25600) # chamfers streaming, rep 8, lr.0001, Y
#chkptfn = chkptfn.format(6924125, 21600) # chamfers streaming, rep 8, lr.001, Y
#chkptfn = chkptfn.format(6924126, 21600) # chamfers streaming, rep 8, lr.001, Y
#
# Chamfers runs on scratch:
#chkptfn = "/bigdata/hplsim/scratch/kelling/chamfers/slurm-{}/model_{}"
#chkptfn = chkptfn.format(6923899, 5600) # chamfers streaming, rep 0
#chkptfn = chkptfn.format(6923925, 24000) # chamfers streaming, rep 4 "red curve"
#
# SC24 scaling / offline runs:
#chkptfn = "/bigdata/hplsim/production/steinigk/008-nodes_lr-0.0005_verbose/simOutput/model_376"
#chkptfn = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/03-30_learning-rate-scaling-with-ranks_chamfersdistance_fix-gpu-volume/48-nodes_lr-0.0001_min-tb-16/simOutput/model_1479"
#chkptfn = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/03-30_learning-rate-scaling-with-ranks_chamfersdistance_fix-gpu-volume/24-nodes_lr-0.0001_min-tb-16/simOutput/model_1479"
#chkptfn = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/03-30_learning-rate-scaling-with-ranks_chamfersdistance_fix-gpu-volume/{}-nodes_lr-{}_min-tb-{}/simOutput/model_{}"
#chkptfn = chkptfn.format(96,"0.0005", 16, 1400)
#chkptfn = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/04-01_rerun-independent-AE-scaling_chamfersdistance_fix-gpu-volume_scaling/8-nodes_lr-0.0001_min-tb-4_lrAE-20/04-01_1645/simOutput/model_350"
#chkptfn = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/04-02_single-gpu-offline-training-from-24-node_hemera/trainingOutput/model_9000"
#
# parallelMMD runs (active):
chkptfn = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/parallelMMD/slurm-{}/model_{}"
#chkptfn = chkptfn.format(7778146, 4000)
#chkptfn = chkptfn.format(7778301, 5600)
#chkptfn = chkptfn.format(7779695, 2400)
chkptfn = chkptfn.format(7779773, 10400)
#chkptfn = chkptfn.format(7779780, 1600)
#chkptfn = chkptfn.format(7784141, 4000)

print(chkptfn)
# torch.load can unpickle arbitrary objects: only load checkpoints from trusted paths.
ckpt = torch.load(chkptfn)

In [None]:
# Build optimizer, scheduler and model on GPU 0, then report the model size.
opt, sched, model = load_objects(0)
model.cuda()
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print("Total number of parameters: {}".format(total_params))

In [None]:
def plot(boxList, xlim=(-3,3), ylim=(-.4,.4), centerHistSamples=0, lbllist=None):
    """Compare ground truth, AE output and INN backward samples for selected boxes.

    Iterates over the last four entries of the notebook-global ``data`` and, for
    every box in ``boxList``, draws one row of panels:

    * ``ax[0]``  -- hexbin of the ground truth ``x[bi, 0, :]`` vs ``x[bi, 1, :]``;
    * ``ax[1]``  -- hexbin of the base-network output ``dec[3]`` ("AE");
    * ``ax[2:]`` -- hexbins of ``model.reconstruct`` samples ("INN bw sample");
    * ``ax[-2]`` -- only if ``centerHistSamples > 0``: histogram over samples of
      the per-box mean of component 0;
    * ``ax[-1]`` -- latents projected onto the first two ``pca_lowrank``
      directions: gray dots are this box's reconstruct latents, crosses are the
      encoder latents of all selected boxes.

    Args:
        boxList: box indices selected from each data entry.
        xlim, ylim: axis limits of the point-cloud panels.
        centerHistSamples: number of reconstruct samples for the center
            histogram (0 disables it).  At least 4 samples are always drawn;
            with the histogram enabled only the first 3 are shown as panels.
        lbllist: optional per-box colour values for the latent crosses.

    Returns:
        List of ``[(it, bi), rgen, gt]``: ``it`` indexes ``data[-4:]``, ``bi``
        indexes ``boxList``, ``rgen`` holds component 0 of each shown sample and
        ``gt`` is ``x[bi, 0, :]`` as numpy.
    """
    plt.rcParams.update({'font.size': 30})

    results = []

    with torch.inference_mode():
        for it, tt in enumerate(data[-4:]):
            x = tt[0][boxList]
            y = tt[1][boxList]
            dec = model.base_network.forward(x.transpose(1,2).cuda())
            encoded = dec[4]

            # First two right-singular directions (V[:, :2]) returned by pca_lowrank.
            pca_dirs = pca_lowrank(encoded)[-1][:, :2]
            inz = torch.matmul(encoded, pca_dirs).detach().cpu()
            decoded = dec[3].detach().cpu()

            pc_pr = []
            lat_z_plot = []
            centerDist = []

            for i in range(max(centerHistSamples, 4)):
                p, l = model.reconstruct(x.cuda(), y.cuda())
                box_center = torch.mean(p[:,:,0], axis=1)
                l = torch.matmul(l, pca_dirs)

                if i < 3 or centerHistSamples == 0:
                    pc_pr.append(p.detach().cpu().numpy())
                if centerHistSamples > 0:
                    centerDist.append(box_center.detach().cpu().numpy())

                lat_z_plot.append(l.detach().cpu())

            # Indexed below as [sample, box, pca_component].
            lat_z_plot = np.stack(lat_z_plot)
            if centerHistSamples > 0:
                centerDist = np.stack(centerDist)

            # GT + AE + one panel per shown sample + optional histogram + latent scatter.
            numPlotCol = 2 + len(pc_pr) + (centerHistSamples > 0) + 1
            for bi in range(len(boxList)):
                rgen = []
                fig, ax = plt.subplots(1, numPlotCol, figsize=(10*numPlotCol, 10), squeeze=True)
                for i, p in enumerate(pc_pr):
                    rgen.append(p[bi,:,0])
                    ax[2+i].hexbin(p[bi,:,0], p[bi,:,1], bins="log")
                    ax[2+i].set(xlim=xlim, ylim=ylim, title="INN bw sample")
                    ax[2+i].sharey(ax[0])

                if centerHistSamples > 0:
                    chist, binbounds = np.histogram(centerDist[:, bi])
                    ax[-2].bar((binbounds[1:]+binbounds[:-1])/2, chist)
                    ax[-2].set(xlim=xlim, title="Histogram")

                # x = PCA component 0, y = PCA component 1 (component 0 was
                # previously plotted on both axes, collapsing the scatter onto a line).
                ax[-1].scatter(lat_z_plot[:, bi, 0], lat_z_plot[:, bi, 1], c="gray", s=500)
                ax[-1].scatter(inz[:,0], inz[:,1], c=lbllist, marker="+", s=1000)

                ax[0].hexbin(x[bi,0,:].numpy(), x[bi,1,:].numpy(), bins="log")
                ax[0].set(xlim=xlim, ylim=ylim, title="GT@{} box {}".format(it, boxList[bi]))
                ax[1].hexbin(decoded[bi,:,0].numpy(), decoded[bi,:,1].numpy(), bins="log")
                ax[1].set(xlim=xlim, ylim=ylim, title="AE")
                ax[1].sharey(ax[0])

                results.append([(it, bi), rgen, x[bi,0,:].numpy()])

                fig.show()
    return results

In [None]:
# Boxes to inspect, grouped; the group index is the colour label used for the
# latent-space crosses in plot().
box_groups = [
    [60, 61, 81],
    [0, 1, 2],
    [3, 4, 5],
    [87, 88, 89],
]
boxlist = [box for group in box_groups for box in group]
lbllist = [label for label, group in enumerate(box_groups) for _ in group]
#boxlist = [87,88,89]
part = plot(boxList=boxlist, xlim=(-3, 3), ylim=(-.1, .1), centerHistSamples=10, lbllist=lbllist)

In [None]:
def center(bins):
    """Return the midpoints between consecutive histogram bin edges."""
    lower, upper = bins[:-1], bins[1:]
    return (lower + upper) / 2

def momDenorm(x):
    """Undo momentum standardisation: x * momentum_std + momentum_mean."""
    std = normalization_values["momentum_std"]
    mean = normalization_values["momentum_mean"]
    return x * std + mean

# Compare the denormalised component-0 distribution of the ground truth with
# one INN backward sample for the last (time, box) entry returned by plot().
i = -1
pickPxGT = momDenorm(part[i][2])      # ground truth x[bi, 0, :]
pickPxGen = momDenorm(part[i][1][2])  # third shown INN sample

gthist, gtbins = np.histogram(pickPxGT)
gtx = center(gtbins)
genhist, genbins = np.histogram(pickPxGen)
print(part[i][0])
genx = center(genbins)
plt.plot(gtx, gthist, label="GT")
# Generated counts against their own bin centres (the GT counts were plotted here before).
plt.plot(genx, genhist, label="INN sample")
plt.legend();

In [None]:
def plotRad(boxList, xlim=(-3,3), ylim=(-.4,.4), centerHistSamples=0):
    """Overlay the INN forward prediction on the target ``y`` for each selected box.

    For each of the last four entries of the notebook-global ``data`` one
    figure is drawn with a panel per box in ``boxList``.  Each panel shows the
    last 512 entries of the INN forward output and ``y[bi]``, titled with the
    batch-wide ``l_fit`` returned by ``compute_losses``.

    ``xlim``, ``ylim`` and ``centerHistSamples`` are accepted for call
    compatibility with ``plot`` but are not used.

    Returns:
        List of ``[(it, bi), prediction, target]`` with ``prediction`` =
        ``rad[bi, -512:]`` (numpy) and ``target`` = ``y[bi]``.
    """
    plt.rcParams.update({'font.size': 30})

    results = []

    with torch.inference_mode():
        for it, tt in enumerate(data[-4:]):
            x = tt[0][boxList]
            y = tt[1][boxList]

            ae = model.base_network.forward(x.transpose(1,2).cuda())
            encoded = ae[4]
            rad_pred = model.inner_model.forward(encoded)
            # The latent lives on the GPU while y comes straight from `data`;
            # pass a copy of y on the latent's device to the loss computation
            # and keep the original y for plotting.
            loss_IM, l_fit, l_latent, l_rev = model.inner_model.compute_losses(encoded, y.to(encoded.device))

            rad_pred = rad_pred.detach().cpu().numpy()

            numPlotCol = rad_pred.shape[0]
            # squeeze=False keeps `ax` 2-D so selecting a single box still works.
            fig, ax = plt.subplots(1, numPlotCol, figsize=(10*numPlotCol, 10), squeeze=False, sharey=True)
            for bi in range(len(boxList)):
                panel = ax[0, bi]
                panel.set_title("l_fit={:0.5} b={}".format(l_fit.item(), boxList[bi]))
                results.append([(it, bi), rad_pred[bi, -512:], y[bi]])
                panel.plot(rad_pred[bi, -512:])
                panel.plot(y[bi])

            fig.show()

    return results

In [None]:
# Entries of `rad` are [(it, bi), prediction, target], one per (time slice, box).
rad = plotRad(boxlist)

In [None]:
# Entries of `rad` (from plotRad) are [(it, bi), prediction, target].
plt.plot(rad[-1][1], label="INN prediction")
plt.plot(rad[-1][2], label="target")
plt.legend();

pickRadPred = rad[-1][1]
# Index 2 is the target; index 1 (used before) is the prediction.
pickRadGT = rad[-1][2]

In [None]:
# Persist the picked momentum and radiation curves for later plotting.
picks = dict(pickPxGen=pickPxGen, pickPxGT=pickPxGT, pickRadPred=pickRadPred, pickRadGT=pickRadGT)
np.savez("picks.npz", **picks)

In [None]:
#plotRad(boxList = [1,2,5,6,9])

In [None]:
# Total-radiation files: run 014 (fill with time step) and the scaling runs
# (fill with nodes, min-tb, time step).
_scaling_root = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/03-30_learning-rate-scaling-with-ranks_chamfersdistance_fix-gpu-volume"
rad014fn = "/bigdata/hplsim/production/KHI_for_GB_MR/runs/014_KHI_007_noWindowFunction/simOutput/totalRad/e_radiation_{}.dat"
radFfn = _scaling_root + "/{}-nodes_lr-0.001_min-tb-{}/simOutput/totalRad/e_radiation_{}.dat"

In [None]:
# Time step for run 014 (t) and the scaling runs (tF), and the observer row index.
t, tF, obs = 900, 900, 0

In [None]:
# Log total radiation of run 014 vs. the scaling runs for observer row `obs`.
plt.rcParams.update({'font.size': 10})

fn = rad014fn.format(t)
rad014 = np.loadtxt(fn)
plt.plot(np.log(rad014[obs]), linestyle="-", label="014")
# `n_nodes` instead of `nn`: the old loop variable overwrote the `torch.nn as nn`
# import, which broke re-running the model-definition cell afterwards.
for n_nodes in (8,):
    for tb in (4, 8, 16):
        fn = radFfn.format(n_nodes, tb, tF)
        radF = np.loadtxt(fn)
        # Include min-tb in the label so the three curves are distinguishable.
        plt.plot(np.log(radF[obs]), linestyle=":", label="Frontier {} nodes, min-tb {}".format(n_nodes, tb))
plt.legend()
plt.show()

In [None]:
# Radiation written during the 96-node (lr 0.0001, min-tb 16) in-situ run.
_streamed_rad_dir = "/bigdata/hplsim/aipp/SC24_PIConGPU-Continual-Learning/03-30_learning-rate-scaling-with-ranks_chamfersdistance_fix-gpu-volume/96-nodes_lr-0.0001_min-tb-16/simOutput/streamedRadiation"
r = np.load(_streamed_rad_dir + "/ts_1.npy")

#r = np.load("/home/kelling/checkout/FWKT/InSituML/main/ModelHelpers/cINN/slurm-6924589/ts_1.npy")



In [None]:
#plt.ylim((-1e2,0))
#plt.xlim((280,400))
print(r.shape)
# Log of the negated first 150 values of r[0, 0] (assumes those values are negative).
neg_head = -r[0, 0, :150]
plt.plot(np.log(neg_head))
#plt.plot(r[0,0,:150])