In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import torch
from hydra.utils import instantiate
from hydra import initialize, compose
import hydra

import wandb

from data.dataManager import DataManager
from model.modelCreator import ModelCreator
from omegaconf import OmegaConf
from scripts.run import setup_model, load_model_instance

from utils.plots import vae_plots
from utils.rbm_plots import plot_rbm_histogram

In [None]:
# Reset Hydra's global state first so this cell can be re-run without an
# "already initialized" error, then compose the project configuration.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="config")
config = compose(config_name="config.yaml")

# Weights & Biases in 'disabled' mode: the wandb API stays callable but nothing is logged.
wandb.init(
    tags=[config.data.dataset_name],
    project=config.wandb.project,
    entity=config.wandb.entity,
    config=OmegaConf.to_container(config, resolve=True),
    mode='disabled',
)

In [None]:
# Toggle: True builds a fresh model from `config`; False reloads a saved instance
# from `config.config_path`. (dtype casting is handled in the next cell.)
new_model = True
# The run object is bound to the global name `self` because later cells use it.
self = setup_model(config) if new_model else load_model_instance(config.config_path)


In [None]:
# Cast the model to float64 (assuming self.model is an nn.Module, .double() converts
# its parameters/buffers in place), then call fit with argument 0.
# NOTE(review): the meaning of the 0 argument to fit() is not visible here -- confirm.
# Later cells must create float64 inputs once this cell has run.
self.model.double()
self.fit(0)

In [None]:
# Smoke test: run random inputs through the encoder and then decode.
# Inputs follow the model's own parameter dtype/device: after self.model.double()
# the model is float64, and float32 inputs from plain torch.rand would raise a
# dtype-mismatch error.
model_param = next(self.model.parameters())
x = torch.rand(2, 6480, dtype=model_param.dtype, device=model_param.device)
x0 = torch.rand(2, 1, dtype=model_param.dtype, device=model_param.device)
# NOTE(review): 5.0 (encoder) and 0.05 (decode) are unexplained hyperparameters -- confirm their meaning.
beta, post_logits, post_samples = self.model.encoder(x, x0, 5.0)
self.model.decode(post_samples, x, x0, beta, 0.05)

In [None]:
# self.fit(0)
import torch.nn as nn
import torch
import torch.nn.functional as F

class DecoderHierarchy0(nn.Module):
    """Hierarchical decoder built from a chain of ``Decoder`` sub-networks.

    One ``Decoder`` is created per entry of ``cfg.model.decoder_input`` /
    ``cfg.model.decoder_output``. Between consecutive levels a 1x1x1 ``Conv3d``
    (``self.subdecs``) maps two latent partitions to a vector that is added to
    ``cat(latent, level_output)``; that sum becomes the next level's input.
    Only the last level's outputs are returned.
    """
    def __init__(self, cfg):
        super(DecoderHierarchy0, self).__init__()
        self._config = cfg
        # Order matters: _create_skipcon_decoders reads self.latent_nodes,
        # which _create_hierarchy_network sets.
        self._create_hierarchy_network()
        self._create_skipcon_decoders()

    def _create_hierarchy_network(self):
        """Create ``self.moduleLayers``: one ``Decoder(cfg, decoder_input[i], decoder_output[i])`` per level."""
        # NOTE(review): the number of partitions is hard-coded to 4 here, whereas
        # Decoder uses cfg.rbm.partitions -- presumably these must agree; confirm.
        self.latent_nodes = self._config.rbm.latent_nodes_per_p * 4
        # Design note: for the "mirror" hierarchical decoder, the first three
        # sub-decoders generate z1', z2', z3' and the last sub-decoder generates
        # the entire shower.
        # hierarchical_lvls is not read anywhere in this class.
        self.hierarchical_lvls = 4

        inp_layers = self._config.model.decoder_input
        out_layers = self._config.model.decoder_output

        self.moduleLayers = nn.ModuleList([])
        # Requires decoder_output to have at least as many entries as decoder_input.
        for i in range(len(inp_layers)):
            self.moduleLayers.append(Decoder(self._config, inp_layers[i], out_layers[i]))



    def _create_skipcon_decoders(self):
        """Create ``self.subdecs``: one skip-connection 1x1x1 Conv3d per level except the last.

        Each maps ``2 * latent_nodes_per_p`` channels (two latent partitions) to
        ``latent_nodes + decoder_output[i]`` channels, i.e. the width of
        ``cat(latent, output of level i)`` that it is added to in ``forward``.
        """
        latent_inp = 2 * self._config.rbm.latent_nodes_per_p
        self.subdecs = nn.ModuleList([])
        for i in range(len(self._config.model.decoder_output)-1):
            recon_out = self.latent_nodes + self._config.model.decoder_output[i]
            self.subdecs.append(nn.Conv3d(latent_inp, recon_out, kernel_size=1, stride=1, padding=0))
    
    def forward(self, x, x0):
        """Run every level in order and return the last level's ``(hits, activations)``.

        Args:
            x: flat latent batch. Level 0 slices it up to column
                ``len(self.moduleLayers) * latent_nodes_per_p`` (presumably
                shape ``(batch, 4 * latent_nodes_per_p)`` -- TODO confirm).
            x0: conditioning tensor passed unchanged to every ``Decoder``
                (Decoder.trans_energy takes its log, so presumably positive).

        Returns:
            The two outputs of the last ``Decoder``. They are also cached on
            ``self.x1`` / ``self.x2``.
        """
        # Keep the raw latent: `x` is overwritten below with each level's input.
        x_lat = x
        self.x1, self.x2 = torch.tensor([]).to(x.device), torch.tensor([]).to(x.device) # placeholders, overwritten by the last level's hits / activations
        for lvl in range(len(self.moduleLayers)):
            cur_net = self.moduleLayers[lvl]
            output_hits, output_activations = cur_net(x, x0)
            outputs = output_hits * output_activations
            z = outputs
            if lvl == len(self.moduleLayers) - 1:
                self.x1 = output_hits
                self.x2 = output_activations
            else:
                # Skip connection built from latent partition 0 and partition
                # (L - 1 - lvl), where L = number of levels.
                partition_ind_start = (len(self.moduleLayers) - 1 - lvl) * self._config.rbm.latent_nodes_per_p
                partition_ind_end = (len(self.moduleLayers) - lvl) * self._config.rbm.latent_nodes_per_p
                # NOTE(review): for lvl > 0, `x` is the previous level's combined
                # input (enc_z + cat(x_lat, z)), not the raw latent, so these
                # slices are no longer pure latent partitions. If pure partitions
                # were intended, this should slice `x_lat` -- confirm.
                enc_z = torch.cat((x[:,0:self._config.rbm.latent_nodes_per_p], x[:,partition_ind_start:partition_ind_end]), dim=1)
                # Alternative (disabled): use only partition (L - 1 - lvl).
                # enc_z = x[:,partition_ind_start:partition_ind_end]
                # (batch, 2 * latent_nodes_per_p, 1, 1, 1): on a 1x1x1 volume the
                # 1x1x1 Conv3d acts as an affine map over channels.
                enc_z = torch.unflatten(enc_z, 1, (2 * self._config.rbm.latent_nodes_per_p, 1, 1, 1))
                enc_z = self.subdecs[lvl](enc_z).view(enc_z.size(0), -1)
                xz = torch.cat((x_lat, z), dim=1)
                # Next level's input; presumably the config must satisfy
                # decoder_input[lvl + 1] == xz.shape[1] -- confirm.
                x = enc_z + xz
        return self.x1, self.x2

class Decoder(nn.Module): #use this one
    """Convolutional sub-decoder: flat latent vector + energy -> two flat 3D grids.

    A shared transposed-convolution trunk (``_layers``) lifts the latent batch
    ``(batch, input_size)`` to a ``(batch, 128, 9, 6, 3)`` feature map. The
    transformed energy is broadcast over that map as one extra channel. Two
    heads of identical architecture then each upsample to a 45 x 16 x 9 grid
    and shrink the first (depth) axis with a ``(z - output_size_z + 1, 1, 1)``
    convolution:

    * ``_layers2`` is the "hits" head;
    * ``_layers3`` is the second head.

    The heads differ only in the initial PReLU slope of their last two
    activations.

    By the layer arithmetic each head yields
    ``(45 - z + output_size_z) * 16 * 9`` values per sample. This equals
    ``output_size`` when ``cfg.data.z == 45`` and ``r * phi == 16 * 9``
    (e.g. r=9, phi=16).

    Args:
        cfg: config object; reads ``cfg.rbm.partitions``,
            ``cfg.rbm.latent_nodes_per_p`` and ``cfg.data.{z, r, phi}``.
        input_size: width of the flat latent input.
        output_size: number of output values per sample; must be a multiple
            of ``r * phi``.

    Raises:
        ValueError: if ``output_size`` is not a multiple of ``r * phi``.
            Previously this silently truncated and produced wrong-sized outputs.
    """
    def __init__(self, cfg, input_size, output_size):
        super(Decoder, self).__init__()
        self._config = cfg

        # Kept as public attributes; not read inside this class.
        self.n_latent_hierarchy_lvls = self._config.rbm.partitions
        self.n_latent_nodes = self._config.rbm.latent_nodes_per_p * self._config.rbm.partitions

        self.z = self._config.data.z
        self.r = self._config.data.r
        self.phi = self._config.data.phi

        voxels_per_depth_slice = self.r * self.phi
        if output_size % voxels_per_depth_slice:
            raise ValueError(
                f"output_size={output_size} is not a multiple of r*phi={voxels_per_depth_slice}"
            )
        output_size_z = int(output_size // voxels_per_depth_slice)

        trunk_channels = 128
        self._layers = nn.Sequential(
            nn.Unflatten(1, (input_size, 1, 1, 1)),

            PeriodicConvTranspose3d(input_size, 512, (3, 3, 2), (2, 1, 1), 0),
            nn.BatchNorm3d(512),
            nn.PReLU(512, 0.02),

            PeriodicConvTranspose3d(512, trunk_channels, (5, 4, 2), (2, 1, 1), 0),
            nn.BatchNorm3d(trunk_channels),
            nn.PReLU(trunk_channels, 0.02),
        )

        # +1 input channel for the broadcast energy appended in forward().
        head_in_channels = trunk_channels + 1
        # Creation order (_layers2 then _layers3) and module indices match the
        # original layout, so state_dict keys and RNG-driven init are unchanged.
        self._layers2 = self._make_head(head_in_channels, output_size_z, late_prelu_init=1.0)   # hits
        self._layers3 = self._make_head(head_in_channels, output_size_z, late_prelu_init=0.02)

    def _make_head(self, in_channels, output_size_z, late_prelu_init):
        """Build one output head: three transposed convs, then a depth-reducing conv.

        ``late_prelu_init`` is the initial slope of the last two PReLUs; 1.0
        makes them start out as the identity.
        """
        return nn.Sequential(
            PeriodicConvTranspose3d(in_channels, 64, (3, 3, 2), (2, 1, 1), 1),
            nn.BatchNorm3d(64),
            nn.PReLU(64, 0.02),

            PeriodicConvTranspose3d(64, 32, (5, 3, 2), (2, 1, 1), 1),
            nn.BatchNorm3d(32),
            nn.PReLU(32, late_prelu_init),

            PeriodicConvTranspose3d(32, 1, (5, 3, 3), (1, 1, 1), 0),
            # Collapses the depth axis down to output_size_z slices.
            PeriodicConv3d(1, 1, (self.z - output_size_z + 1, 1, 1), (1, 1, 1), 0),
            nn.PReLU(1, late_prelu_init),
        )

    def forward(self, x, x0):
        """Decode a latent batch conditioned on energy.

        Args:
            x: latent tensor of shape ``(batch, input_size)``.
            x0: energies of shape ``(batch, 1)`` (one extra channel is expected
                by the heads); must be positive because ``trans_energy`` takes
                a log.

        Returns:
            ``(hits, activations)``: the ``_layers2`` and ``_layers3`` outputs,
            each flattened to ``(batch, -1)``.
        """
        features = self._layers(x)
        # (batch, 1) -> (batch, 1, D, H, W): broadcast the transformed energy over the feature map.
        energy = self.trans_energy(x0)[:, :, None, None, None].expand(-1, -1, *features.shape[-3:])
        xx0 = torch.cat((features, energy), 1)
        x1 = self._layers2(xx0)  # hits
        x2 = self._layers3(xx0)
        return x1.flatten(1), x2.flatten(1)

    def trans_energy(self, x0, log_e_max=14.0, log_e_min=6.0, s_map = 1.0):
        """Log-scale and normalise energies.

        Computes ``(ln(x0) - log_e_min) / (log_e_max - log_e_min) * s_map``, which
        maps ``ln(x0)`` in ``[log_e_min, log_e_max]`` linearly onto ``[0, s_map]``.
        Non-positive ``x0`` yields ``-inf``/``nan``.
        """
        return ((torch.log(x0) - log_e_min)/(log_e_max - log_e_min)) * s_map
    


class PeriodicConvTranspose3d(nn.Module):
    """``nn.ConvTranspose3d`` followed by an optional boundary fix-up of its OUTPUT.

    The wrapped transposed convolution is always built with ``padding=0``; the
    ``padding`` argument is interpreted here instead, after the convolution:

    * ``padding == 1``: first, a copy of the slice at index 0 of the last axis,
      rotated by half its length along axis -2, is prepended along the last axis
      (+1 in size). Then axis -2 is circularly padded by one element on each
      side (+2 in size).
    * any other value: the convolution output is returned unchanged.

    Presumably axis -2 is the periodic (phi) axis and index 0 of the last axis
    is the radial centre -- TODO confirm against the data layout.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.padding = padding
        self.conv = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        out = self.conv(x)
        if self.padding != 1:
            return out
        # Slice 0 of the last axis, rotated half a turn along axis -2.
        n_around = out.shape[-2]
        half_turn = torch.roll(out[..., [0]], shifts=-(n_around // 2), dims=-2)
        out = torch.cat((half_turn, out), dim=-1)
        # Wrap axis -2 around on both sides.
        wrap = self.padding
        return F.pad(out, (0, 0, wrap, wrap, 0, 0), mode='circular')
    
class PeriodicConv3d(nn.Module):
    """``nn.Conv3d`` preceded by a boundary fix-up of its INPUT.

    The wrapped convolution is always built with ``padding=0``; ``padding`` is
    applied manually before it:

    * ``padding == 1``: a copy of the last slice along axis -2, rotated by half
      its length along the last axis, is appended along axis -2 (+1 in size).
    * for every ``padding`` value, the last axis is then circularly padded by
      ``padding`` on each side (a no-op for 0).

    NOTE(review): the axis roles here (periodic axis -1, centre slice appended at
    the end of axis -2) are the opposite of ``PeriodicConvTranspose3d`` (periodic
    axis -2, centre slice prepended at index 0 of axis -1). Presumably this
    matches a different tensor layout where it is used -- confirm.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        if self.padding == 1:
            # Last slice of axis -2, rotated half a turn along the last axis.
            n_around = x.shape[-1]
            half_turn = torch.roll(x[..., [-1], :], shifts=-(n_around // 2), dims=-1)
            x = torch.cat((x, half_turn), dim=-2)
        wrap = self.padding
        padded = F.pad(x, (wrap, wrap, 0, 0, 0, 0), mode='circular')
        return self.conv(padded)

In [None]:
# Standalone shape check of a single Decoder on the (z, r, phi) = (45, 9, 16) geometry.
# NOTE: this mutates the shared `config` object in place.
config.data.z = 45
config.data.r = 9
config.data.phi = 16
d = Decoder(config, 1208, 2160)
# One forward pass: both outputs come from the same input. Running it twice would
# double the compute and update BatchNorm running statistics twice.
hits, activations = d(torch.rand(2, 1208), torch.rand(2, 1))
print(hits.shape, activations.shape)
assert hits.shape == activations.shape == (2, 2160)

In [None]:
# Decode random latents with the full model's decoder. Inputs follow the model's
# parameter dtype/device: once self.model.double() has run, float32 inputs would fail.
model_param = next(self.model.parameters())
self.model.decoder(
    torch.rand(2, 1208, dtype=model_param.dtype, device=model_param.device),
    torch.rand(2, 1, dtype=model_param.dtype, device=model_param.device),
)

In [None]:
# Display the model's decoder (as the cell's last expression its repr is shown,
# which for an nn.Module lists the sub-module structure).
self.model.decoder

In [None]:
# Second output of the standalone Decoder `d` (the `_layers3` head) for a fresh random input.
# Note: this rebinds the global `x`.
x = d(torch.rand(2,1208), torch.rand(2,1))[1]

In [None]:
# Sanity check: Decoder.forward already returns (batch, -1) tensors, so this shape equals x.shape.
x.reshape(x.shape[0], -1).shape