In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import torch
from hydra.utils import instantiate
from hydra import initialize, compose
import hydra

import wandb

from data.dataManager import DataManager
from model.modelCreator import ModelCreator
from omegaconf import OmegaConf
from scripts.run import setup_model, load_model_instance

from utils.plots import vae_plots
from utils.rbm_plots import plot_rbm_histogram

[1m[14:45:23.422][0m [1;95mINFO [1;0m  [1mCaloQuVAE                                         [0mLoading configuration.


In [79]:
# Hydra keeps a process-wide GlobalHydra singleton; clear it first so this
# cell can be executed repeatedly in the same kernel before initialize().
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="config")
config = compose(config_name="config.yaml")

# The composed OmegaConf tree is resolved into plain Python containers before
# being passed to wandb. mode='disabled' is used for this interactive session.
wandb.init(
    tags=[config.data.dataset_name],
    project=config.wandb.project,
    entity=config.wandb.entity,
    config=OmegaConf.to_container(config, resolve=True),
    mode='disabled',
)

In [80]:
# Toggle between the two entry points imported from scripts.run:
#   True  -> setup_model(config)
#   False -> load_model_instance(config.config_path)
# The result is bound to the global name `self` because the following cells
# reference it that way (self.model, self.data_mgr, ...).
new_model = False
self = setup_model(config) if new_model else load_model_instance(config.config_path)
# Optional for either branch (currently disabled):
# self.model = self.model.double()  # sets all model parameters to float64


[1m[03:40:38.168][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0mLoading other dataset: CaloChallenge2
[1m[03:40:38.170][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0mKeys: ['incident_energies', 'showers']
[1m[03:40:42.776][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0mdict_keys(['incident_energies', 'showers'])
[1m[03:40:42.777][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0m<torch.utils.data.dataloader.DataLoader object at 0x7fa824eaf590>: 79999 events, 157 batches
[1m[03:40:42.778][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0m<torch.utils.data.dataloader.DataLoader object at 0x7fa85c6b8920>: 10001 events, 10 batches
[1m[03:40:42.778][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0m<torch.utils.data.dataloader.DataLoader object at 0x7fac8229b530>: 9999 events, 10 batch

[1m[03:40:43.686][0m [1;95mINFO [1;0m  [1mdwave.cloud.client.base                           [0mReceived solver data for 7 solver(s).
[1m[03:40:43.752][0m [1;95mINFO [1;0m  [1mdwave.cloud.client.base                           [0mAdding solver StructuredSolver(id='Advantage_system4.1')
[1m[03:40:43.794][0m [1;95mINFO [1;0m  [1mdwave.cloud.client.base                           [0mAdding solver StructuredSolver(id='Advantage_system6.4')
[1m[03:40:43.834][0m [1;95mINFO [1;0m  [1mdwave.cloud.client.base                           [0mAdding solver StructuredSolver(id='Advantage2_system1.3')
[1m[03:40:44.236][0m [1;95mINFO [1;0m  [1mscripts.run                                       [0mRequesting GPUs. GPU list :[1]
[1m[03:40:44.237][0m [1;95mINFO [1;0m  [1mscripts.run                                       [0mMain GPU : cuda:1
[1m[03:40:44.237][0m [1;95mINFO [1;0m  [1mscripts.run                                       [0mCUDA available


cuda:1


[1m[03:40:44.837][0m [1;95mINFO [1;0m  [1mmodel.modelCreator                                [0mLoading state


encoder._networks.0.seq1.0.conv.weight True
encoder._networks.0.seq1.0.conv.bias True
encoder._networks.0.seq1.1.weight True
encoder._networks.0.seq1.1.bias True
encoder._networks.0.seq1.2.weight True
encoder._networks.0.seq1.3.conv.weight True
encoder._networks.0.seq1.3.conv.bias True
encoder._networks.0.seq1.4.weight True
encoder._networks.0.seq1.4.bias True
encoder._networks.0.seq1.5.weight True
encoder._networks.0.seq1.6.conv.weight True
encoder._networks.0.seq1.6.conv.bias True
encoder._networks.0.seq1.7.weight True
encoder._networks.0.seq1.7.bias True
encoder._networks.0.seq1.8.weight True
encoder._networks.0.seq2.0.conv.weight True
encoder._networks.0.seq2.0.conv.bias True
encoder._networks.0.seq2.1.weight True
encoder._networks.0.seq2.1.bias True
encoder._networks.0.seq2.2.weight True
encoder._networks.0.seq2.3.conv.weight True
encoder._networks.0.seq2.3.conv.bias True
encoder._networks.0.seq2.4.weight True
encoder._networks.1.seq1.0.conv.weight True
encoder._networks.1.seq1.0.

[1m[03:40:45.522][0m [1;95mINFO [1;0m  [1mmodel.modelCreator                                [0mLoading weights from file : /fast_scratch_1/caloqvae/jtoledo/wandb/run-20250712_031134-be4t7ksb/files/autoencoderbase_0.pth


Loading weights for module =  _hit_smoothing_dist_mod
Loading weights for module =  _bce_loss
Loading weights for module =  encoder
Loading weights for module =  decoder
Loading weights for module =  prior


In [82]:
# Handy expressions to inspect while debugging (left disabled):
#   config.model.decoder_input
#   self.model.decoder
# Start from an empty dict; presumably this is what self.aggr_loss(...)
# accumulates into — TODO confirm.
self.total_loss_dict = dict()

In [92]:
# Run one batch through the model with gradients disabled and print the
# individual loss terms plus their weighted total.
with torch.no_grad():
    # eval() is called first, then only the top-level `training` flag is set
    # back to True, while the listed sub-modules are explicitly kept at False.
    # NOTE(review): the top-level training=True right after eval() looks
    # deliberate — confirm which code path in the model depends on it.
    self.model.eval()
    self.model.training = True
    for frozen_part in (
        self.model.encoder,
        self.model.decoder,
        self.model._hit_smoothing_dist_mod,
    ):
        frozen_part.training = False

    for i, (x, x0) in enumerate(self.data_mgr.test_loader):
        # Alternative single-batch fetch:
        # x, x0 = next(iter(self.data_mgr.val_loader))
        x, x0 = x.to(self.device), x0.to(self.device)
        x = self._reduce(x, x0)

        # Forward pass followed by the per-term losses.
        output = self.model((x, x0), self.beta, 0)
        loss_dict = self.model.loss(x, output)

        # Total loss = sum of every term scaled by its configured coefficient.
        loss_weights = self._config.model.loss_coeff
        weighted_terms = [
            term * loss_weights[name]
            for name, term in loss_dict.items()
            if name != "loss"
        ]
        loss_dict["loss"] = torch.stack(weighted_terms).sum()

        # Rename every key to 'val_<key>' in place; the key list is
        # snapshotted because the dict is mutated while renaming.
        for key in list(loss_dict):
            loss_dict['val_' + key] = loss_dict.pop(key)

        # NOTE(review): the batch comes from test_loader but aggregation is
        # given val_loader (and keys are prefixed 'val_') — confirm intended.
        self.aggr_loss(self.data_mgr.val_loader, loss_dict)
        print(loss_dict)
        break  # only the first batch is inspected
        # NaN hunt (drop the break above to use):
        # if torch.isnan(loss_dict['val_ae_loss']):
            # print(i)

{'val_ae_loss': tensor(4064.0273, device='cuda:1'), 'val_kl_loss': tensor(-406.7999, device='cuda:1'), 'val_hit_loss': tensor(2196.3430, device='cuda:1'), 'val_entropy': tensor(-442.5678, device='cuda:1'), 'val_pos_energy': tensor(-246.3113, device='cuda:1'), 'val_neg_energy': tensor(282.0793, device='cuda:1'), 'val_loss': tensor(5853.5703, device='cuda:1')}


In [106]:
# NOTE(review): 6480 is an unexplained constant — presumably a count of
# units/voxels; TODO confirm and name it in a config cell.
6480 * np.log(2)

4491.593730028446

In [109]:
# NOTE(review): 1600 and 6480 are unexplained constants — TODO document
# what ratio this is meant to check.
1 / np.exp(1600 / 6480)

0.7812082024342466

In [110]:
# NOTE(review): unexplained constants — TODO document where 7951 comes from.
7951 / 6480

1.227006172839506

In [111]:
# NOTE(review): unexplained constants — 1.22 looks like the rounded result of
# the 7951 / 6480 cell; TODO confirm.
2905 / 1.22

2381.1475409836066

In [112]:
# Mean over dim 0 of a 10x20 uniform sample -> 20 values.
# Unseeded, so the displayed numbers will differ on every run.
torch.rand(10, 20).mean(0)

tensor([0.5481, 0.4403, 0.4675, 0.3510, 0.5460, 0.4444, 0.4249, 0.5824, 0.5637,
        0.5127, 0.5563, 0.6025, 0.5552, 0.6179, 0.6048, 0.4942, 0.4330, 0.5279,
        0.4374, 0.3947])

In [113]:
# First row (index 0 along dim 0) of a fresh 10x20 uniform sample -> 20 values.
# Unseeded, so the displayed numbers will differ on every run.
torch.rand(10, 20)[0]

tensor([0.2538, 0.1370, 0.0428, 0.2076, 0.9547, 0.3663, 0.1745, 0.8925, 0.5019,
        0.1748, 0.3585, 0.2889, 0.9780, 0.4932, 0.8277, 0.4773, 0.4890, 0.1559,
        0.5237, 0.0838])