In [None]:
# default_exp model.rnnvae

In [None]:
# hide
%load_ext autoreload
%autoreload 2

In [None]:
# export

import torch
from torch import nn, optim
from deeptool.architecture import Encoder, Decoder, DownUpConv
from deeptool.utils import Tracker

# RNN VAE

> Structure for an approach that maintains a pseudo-spatial relation

In [None]:
# Load a single test batch so the architecture can be checked on concrete data.
# These two imports are only used by the check cells of this notebook.
from deeptool.parameters import get_all_args
from deeptool.dataloader import load_test_batch
args = get_all_args()
# select the model type handled in this notebook
args.model_type = "rnnvae"
# one sample per batch: mod_batch (defined below) only keeps index 0 of the leading axis
args.batch_size = 1
batch = load_test_batch(args)
# shape of the image tensor before it is modified
batch["img"].shape

torch.Size([1, 3, 16, 256, 256])

In [None]:
# export 

def mod_batch(batch, key="img"):
    """
    Reshape the tensor stored under `key` so the network can consume it.

    The entry at index 0 of the leading axis of the 5-d tensor is selected and
    the first two axes of the remaining 4-d tensor are swapped. For the test
    batch used in this notebook this maps [1, 3, 16, 256, 256] to
    [16, 3, 256, 256].

    The dictionary is modified in place and the very same object is returned.
    """
    # explicit 5-index slice: only a tensor with five axes is accepted here
    first_sample = batch[key][0, :, :, :, :]
    # swap axis 0 and axis 1, keep the two trailing axes where they are
    batch[key] = first_sample.permute(1, 0, 2, 3)
    return batch

In [None]:
# Bring the test batch into the layout the network expects.
# NOTE(review): mod_batch overwrites batch["img"] in place, so this cell is not
# idempotent -- running it a second time hands a 4-d tensor to the 5-index slice
# inside mod_batch, which is expected to fail. Reload the batch before re-running.
batch = mod_batch(batch)
batch["img"].shape

torch.Size([16, 3, 256, 256])

In [None]:
# NOTE(review): presumably switches the building blocks to 2d operations --
# confirm against DownUpConv. This mutates the shared `args` used by later cells.
args.dim = 2
# stand-alone conv block on a 256 pixel picture size with 3 input features,
# created only to inspect the `min_size` attribute it exposes
enc_part = DownUpConv(args, pic_size=256, n_fea_in=3, n_fea_next=8, depth=1, )
enc_part.min_size

4

In [None]:
class RNN_VAE(nn.Module):
    """
    Recurrent autoencoder for compressing 3d data.

    The data is compressed in 2d while (hopefully) maintaining the spatial
    relation between layers.

    Current state: the GRU transition and both decoder parts are constructed in
    `__init__`, but `forward` stops after the fully connected encoder and
    `rnn_transition` is still a placeholder.
    """

    def __init__(self, device, args):
        """
        Build all sub-networks and move each of them to `device`.

        device: target passed to `.to(...)` for every sub-network; also used
                in `forward` for the input tensor.
        args:   settings object; the attributes read here are `pic_size`,
                `perspectives`, `n_fea_up`, `n_fea_down` and `n_z` (it is also
                handed on to DownUpConv unchanged).
        """
        super().__init__()
        self.device = device

        # 1. convolutional encoder
        self.conv_part_enc = DownUpConv(
            args,
            pic_size=args.pic_size,
            n_fea_in=len(args.perspectives),
            n_fea_next=args.n_fea_up,
            depth=1,
        ).to(self.device)

        # sizes reported by the conv encoder determine every FC width below
        max_fea = self.conv_part_enc.max_fea
        min_size = self.conv_part_enc.min_size
        flat_dim = max_fea * min_size**2
        mid_dim = max_fea * min_size
        latent_dim = args.n_z

        # target shape used in forward to flatten the conv output
        self.view_arr = [-1, flat_dim]

        # 2. fully connected encoder: flat_dim -> mid_dim -> max_fea -> n_z
        enc_layers = [
            nn.Linear(flat_dim, mid_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(mid_dim, max_fea),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(max_fea, latent_dim),
        ]
        self.fc_part_enc = nn.Sequential(*enc_layers).to(self.device)

        # 3. transition layer: single-layer GRU working on the latent size
        self.transition = nn.GRU(
            input_size=latent_dim, hidden_size=latent_dim, num_layers=1
        ).to(self.device)

        # 4. fully connected decoder: mirror of the FC encoder
        dec_layers = [
            nn.Linear(latent_dim, max_fea),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(max_fea, mid_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(mid_dim, flat_dim),
        ]
        self.fc_part_dec = nn.Sequential(*dec_layers).to(self.device)

        # 5. convolutional decoder
        self.conv_part_dec = DownUpConv(
            args,
            pic_size=args.pic_size,
            n_fea_in=len(args.perspectives),
            n_fea_next=args.n_fea_down,
            depth=1,
            move='up',
        ).to(self.device)

    def rnn_transition(self, x):
        """
        Intended to take the matrix of encoded input slices and apply the RNN
        part. Not implemented yet: `x` is ignored and None is returned.
        """
        return None

    def forward(self, batch):
        """
        Encode the images stored in batch['img'].

        The tensor is moved to the model device, passed through the conv
        encoder, flattened to `self.view_arr` and passed through the FC
        encoder. The result of the FC encoder is returned.
        """
        # move to the model device
        img = batch['img'].to(self.device)

        # encode: conv part -> flatten -> fully connected part
        features = self.conv_part_enc(img)
        flat = features.reshape(self.view_arr)
        latent = self.fc_part_enc(flat)

        # the GRU transition (and the decoder) are not applied here yet
        return latent

In [None]:
# Use the first GPU only when CUDA is available and args.n_gpu asks for one;
# otherwise fall back to the CPU.
device = torch.device("cuda:0" if (
        torch.cuda.is_available() and args.n_gpu > 0) else "cpu")
# smoke test: build the model and push the modified test batch through it
rnn_vae = RNN_VAE(device, args)
# forward currently returns the output of the FC encoder, one row per slice
rnn_vae(batch).shape

torch.Size([16, 100])

In [None]:
# hide
# Export the flagged notebook cells to the library modules.
# Import only the one name that is used instead of a star import, so the
# notebook namespace is not polluted at the end of the run.
from nbdev.export import notebook2script
notebook2script()

Converted 00_dataloader.ipynb.
Converted 01_architecture.ipynb.
Converted 02_utils.ipynb.
Converted 03_parameters.ipynb.
Converted 04_train_loop.ipynb.
Converted 10_diagnosis.ipynb.
Converted 20_dcgan.ipynb.
Converted 21_introvae.ipynb.
Converted 22_vqvae.ipynb.
Converted 23_rnn_vae.ipynb.
Converted 99_index.ipynb.


In [None]:
# NOTE(review): leftover interactive source lookup (IPython `??`); consider
# removing this cell from the finished notebook.
?? DownUpConv