In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import torch
from hydra.utils import instantiate
from hydra import initialize, compose
import hydra

import wandb

from data.dataManager import DataManager
from model.modelCreator import ModelCreator
from omegaconf import OmegaConf
from scripts.run import setup_model, load_model_instance

from utils.plots import vae_plots
from utils.rbm_plots import plot_rbm_histogram, plot_rbm_params

from scripts.run import set_device

[1m[06:25:15.230][0m [1;95mINFO [1;0m  [1mCaloQuVAE                                         [0mLoading configuration.


In [2]:
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="config")
config=compose(config_name="config.yaml")
wandb.init(tags = [config.data.dataset_name], project=config.wandb.project, entity=config.wandb.entity, config=OmegaConf.to_container(config, resolve=True), mode='disabled')

In [3]:
new_model = True
if new_model:
    self = setup_model(config)
    # self.model = self.model.double()  # sets all model parameters to float64
else:
    self = load_model_instance(config)
    # self.model = self.model.double()


[1m[06:25:21.782][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0mLoading other dataset: CaloChallenge2
[1m[06:25:21.787][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0mKeys: ['incident_energies', 'showers']
[1m[06:25:26.940][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0mdict_keys(['incident_energies', 'showers'])
[1m[06:25:26.943][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0m<torch.utils.data.dataloader.DataLoader object at 0x7fd8154f73e0>: 79999 events, 157 batches
[1m[06:25:26.943][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0m<torch.utils.data.dataloader.DataLoader object at 0x7fd812ff0080>: 10001 events, 10 batches
[1m[06:25:26.944][0m [1;95mINFO [1;0m  [1mdata.dataManager                                  [0m<torch.utils.data.dataloader.DataLoader object at 0x7fd8136f3470>: 9999 events, 10 batch

cuda:4
encoder._networks.0.seq1.0.conv.weight True
encoder._networks.0.seq1.0.conv.bias True
encoder._networks.0.seq1.1.weight True
encoder._networks.0.seq1.1.bias True
encoder._networks.0.seq1.2.weight True
encoder._networks.0.seq1.3.conv.weight True
encoder._networks.0.seq1.3.conv.bias True
encoder._networks.0.seq1.4.weight True
encoder._networks.0.seq1.4.bias True
encoder._networks.0.seq1.5.weight True
encoder._networks.0.seq1.6.conv.weight True
encoder._networks.0.seq1.6.conv.bias True
encoder._networks.0.seq1.7.weight True
encoder._networks.0.seq1.7.bias True
encoder._networks.0.seq1.8.weight True
encoder._networks.0.seq2.0.conv.weight True
encoder._networks.0.seq2.0.conv.bias True
encoder._networks.0.seq2.1.weight True
encoder._networks.0.seq2.1.bias True
encoder._networks.0.seq2.2.weight True
encoder._networks.0.seq2.3.conv.weight True
encoder._networks.0.seq2.3.conv.bias True
encoder._networks.0.seq2.4.weight True
encoder._networks.1.seq1.0.conv.weight True
encoder._networks.1.

In [7]:
# from model.decoder.transformer import Multiheadv2
from model.decoder.decoderhierarchy0 import PeriodicConvTranspose3d
import torch.nn as nn
from einops import rearrange

In [179]:
#D1, D2, D3
class DecoderAtt(nn.Module):
    def __init__(self, cfg, input_size):
        super(DecoderAtt, self).__init__()
        self._config = cfg

        self.n_latent_hierarchy_lvls=self._config.rbm.partitions

        self.n_latent_nodes=self._config.rbm.latent_nodes_per_p * self._config.rbm.partitions

        self.z = self._config.data.z
        self.r = self._config.data.r
        self.phi = self._config.data.phi

        # output_size_z = int( output_size / (self.r * self.phi))

        self._layers =  nn.Sequential(
                   nn.Unflatten(1, (input_size, 1, 1, 1)),

                   PeriodicConvTranspose3d(input_size, 512, (3,3,2), (1,1,1), 0),
                   nn.BatchNorm3d(512),
                   nn.PReLU(512, 0.02),
                   

                   PeriodicConvTranspose3d(512, 128, (3,3,2), (1,1,1), 0),
                   nn.BatchNorm3d(128),
                   nn.PReLU(128, 0.02),
                                   )
        
        self._layers2 = nn.Sequential(
                   PeriodicConvTranspose3d(129, 64, (3,3,2), (1,1,1), 1),
                   nn.BatchNorm3d(64),
                   nn.PReLU(64, 0.02),

                   PeriodicConvTranspose3d(64, 32, (2,2,2), (1,1,1), 1),
                   nn.BatchNorm3d(32),
                   nn.PReLU(32, 1.0),

                #    PeriodicConvTranspose3d(32, 1, (5,3,3), (1,1,1), 0),
                #    PeriodicConv3d(1, 1, (self.z - output_size_z + 1, 1, 1), (1,1,1), 0),
                   nn.PReLU(1, 1.0)
                                   )
        
        self._layers3 = nn.Sequential(
                   PeriodicConvTranspose3d(129, 64, (3,3,2), (1,1,1), 1),
                   nn.BatchNorm3d(64),
                   nn.PReLU(64, 0.02),

                   PeriodicConvTranspose3d(64, 32, (2,2,2), (1,1,1), 1),
                   nn.BatchNorm3d(32),
                   nn.PReLU(32, 0.02),

                #    PeriodicConvTranspose3d(32, 1, (5,3,3), (1,1,1), 0),
                #    PeriodicConv3d(1, 1, (self.z - output_size_z + 1, 1, 1), (1,1,1), 0),
                   nn.PReLU(1, 0.02),
                                   )
        
    def forward(self, x, x0):
                
        x = self._layers(x)
        x0 = self.trans_energy(x0)
        xx0 = torch.cat((x, x0.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1,1,torch.tensor(x.shape[-3:-2]).item(),torch.tensor(x.shape[-2:-1]).item(), torch.tensor(x.shape[-1:]).item())), 1)
        x1 = self._layers2(xx0) #hits
        x2 = self._layers3(xx0)
        return rearrange(x1, "b c l h w -> b (l h w) c"), rearrange(x1, "b c l h w -> b (l h w) c")
    
    def trans_energy(self, x0, log_e_max=14.0, log_e_min=6.0, s_map = 1.0):
        return ((torch.log(x0) - log_e_min)/(log_e_max - log_e_min)) * s_map
    
class Skip(nn.Module):
    def __init__(self, cfg):
        super(Skip, self).__init__()
        self._config = cfg
        self.head_size = self._config.model.head_size
        self.seq = nn.Sequential(
            nn.Unflatten(1, (self._config.rbm.latent_nodes_per_p*2,1,1,1)),
            PeriodicConvTranspose3d(self._config.rbm.latent_nodes_per_p*2, self.head_size,(3,3,3),(1,1,1),0),
        )
        self.query = nn.Linear(27,self._config.model.skip_output_size, bias=False)
        self.value = nn.Linear(27,self._config.model.skip_output_size, bias=False)
        self.linear = nn.Linear(self.head_size, 1, bias=False)

    def forward(self, x, keys):
        x = self.seq(x)
        x = rearrange(x, "b c l h w -> b c (l h w)")
        x_query = self.query(x).transpose(-2,-1)
        x_value = self.value(x).transpose(-2,-1)

        wei = x_query @ keys.transpose(-2,-1) * self.head_size**-0.5
        wei = F.softmax(wei,dim=-1)
        out = self.linear(wei @ x_value).reshape(-1,self._config.model.skip_output_size)

        
        return out
    
class Decoder(nn.Module):
    def __init__(self, cfg, input_size):
        super(Decoder, self).__init__()
        self._config = cfg

        self.n_latent_hierarchy_lvls=self._config.rbm.partitions

        self.n_latent_nodes=self._config.rbm.latent_nodes_per_p * self._config.rbm.partitions

        self.z = self._config.data.z
        self.r = self._config.data.r
        self.phi = self._config.data.phi
        
        self._layers =  nn.Sequential(
                   nn.Unflatten(1, (input_size, 1, 1, 1)),

                   PeriodicConvTranspose3d(input_size, 512, (3,3,2), (2,1,1), 0),
                   nn.BatchNorm3d(512),
                   nn.PReLU(512, 0.02),
                   

                   PeriodicConvTranspose3d(512, 128, (5,4,2), (2,1,1), 0),
                   nn.BatchNorm3d(128),
                   nn.PReLU(128, 0.02),
                                   )
        
        self._layers2 = nn.Sequential(
                   PeriodicConvTranspose3d(129, 64, (3,3,2), (2,1,1), 1),
                   nn.BatchNorm3d(64),
                   nn.PReLU(64, 0.02),

                   PeriodicConvTranspose3d(64, 32, (5,3,2), (2,1,1), 1),
                   nn.BatchNorm3d(32),
                   nn.PReLU(32, 1.0),

                   PeriodicConvTranspose3d(32, 1, (5,3,3), (1,1,1), 0),
                   nn.PReLU(1, 1.0)
                                   )
        
        self._layers3 = nn.Sequential(
                   PeriodicConvTranspose3d(129, 64, (3,3,2), (2,1,1), 1),
                   nn.BatchNorm3d(64),
                   nn.PReLU(64, 0.02),

                   PeriodicConvTranspose3d(64, 32, (5,3,2), (2,1,1), 1),
                   nn.BatchNorm3d(32),
                   nn.PReLU(32, 0.02),

                   PeriodicConvTranspose3d(32, 1, (5,3,3), (1,1,1), 0),
                   nn.PReLU(1, 0.02),
                                   )
        
    def forward(self, x, x0):
                
        x = self._layers(x)
        x0 = self.trans_energy(x0)
        xx0 = torch.cat((x, x0.unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1,1,torch.tensor(x.shape[-3:-2]).item(),torch.tensor(x.shape[-2:-1]).item(), torch.tensor(x.shape[-1:]).item())), 1)
        x1 = self._layers2(xx0) #hits
        x2 = self._layers3(xx0)
        return x1.reshape(x1.shape[0],self.z*self.r*self.phi), x2.reshape(x1.shape[0],self.z*self.r*self.phi)
    
    def trans_energy(self, x0, log_e_max=16.0, log_e_min=5.0, s_map = 1.0):
        return ((torch.log(x0) - log_e_min)/(log_e_max - log_e_min)) * s_map

In [180]:
d = Decoder(config, config.model.decoder_input[-1])
# sk = Skip(config)
d(torch.rand(2,4832), torch.rand(2,1)).shape

AttributeError: 'tuple' object has no attribute 'shape'

In [129]:
# print(d(torch.rand(2,1208), torch.rand(2,1))[0].shape)
# print(sk(torch.rand(2,604))[0].shape)

In [45]:
xx = d(torch.rand(2,1208), torch.rand(2,1))[0]

In [None]:
import torch.nn.functional as F

class DecoderHierarchyTF(nn.Module):
    def __init__(self, cfg):
        super(DecoderHierarchyTF, self).__init__()
        self._config = cfg
        self.head_size = self._config.model.head_size
        self._create_hierarchy_network()
        self._create_skipcon_decoders()

    def _create_hierarchy_network(self):
        self.latent_nodes = self._config.rbm.latent_nodes_per_p * self._config.rbm.partitions
        self.hierarchical_lvls = self._config.rbm.partitions

        inp_layers = self._config.model.decoder_input

        self.moduleLayers = nn.ModuleList([])
        for i in range(self.hierarchical_lvls-1):
            self.moduleLayers.append(DecoderAtt(self._config, inp_layers[i])) 
        self.moduleLayers.append(Decoder(self._config, inp_layers[-1]))   

    def _create_skipcon_decoders(self):
        self.lnpp = self._config.rbm.latent_nodes_per_p
        self.subdecs = nn.ModuleList([])
        for i in range(self.hierarchical_lvls-1):
            self.subdecs.append(Skip(config))
    
    def forward(self, z, x0):
        z_prime = z
        for i in range(len(self.moduleLayers)-1):
            x1, x2 = self.moduleLayers[i](z_prime, x0)
            keys = x1 * x2
            z_skip = torch.cat((z_prime[:,:self.lnpp], z_prime[:,self.lnpp*(3-i):self.lnpp*(4-i)]), dim=1)

            out = self.subdecs[i](z_skip, keys)
            z_prime = torch.cat((z_prime,z),dim=1)
                
        x1, x2 = self.moduleLayers[-1](z_prime, x0)
        return x1,x2,out,z_prime

In [182]:
# dev = set_device(config)
dh._config.model.decoder_input[-1]

4832

In [183]:
dh = DecoderHierarchyTF(config)#.to(dev)
# dh

In [184]:
x1,x2,out,z_prime = dh(torch.rand(2,1208), torch.rand(2,1))

0 torch.Size([2, 672, 32]) torch.Size([2, 672, 32])
torch.Size([2, 604]) torch.Size([2, 672, 32])
torch.Size([2, 672]) torch.Size([2, 2416])
1 torch.Size([2, 672, 32]) torch.Size([2, 672, 32])
torch.Size([2, 604]) torch.Size([2, 672, 32])
torch.Size([2, 672]) torch.Size([2, 3624])
2 torch.Size([2, 672, 32]) torch.Size([2, 672, 32])
torch.Size([2, 604]) torch.Size([2, 672, 32])
torch.Size([2, 672]) torch.Size([2, 4832])
Final z_prime shape: torch.Size([2, 4832])


In [185]:
x1.shape, x2.shape, out.shape, z_prime.shape

(torch.Size([2, 6480]),
 torch.Size([2, 6480]),
 torch.Size([2, 672]),
 torch.Size([2, 4832]))

In [59]:
nn.Linear(32,1)(out).shape

torch.Size([2, 672, 1])

In [32]:
z[:,0:302*4].shape

torch.Size([2, 1208])