In [1]:
# default_exp model.rnnvae
# default_cls_lvl 3

In [2]:
# hide
%load_ext autoreload
%autoreload 2

In [3]:
# export

import torch
from torch import nn, optim
import torch.nn.functional as F
from deeptool.architecture import Encoder, Decoder, DownUpConv
from deeptool.abs_model import AbsModel

# RNN VAE

> Structure for an approach that maintains a pseudo-spatial relation between the slices

<img src="img/rnn_vae_arch.png" alt="Drawing" style="width: 600px;"/>

In [4]:
# Load one test batch to confirm the architecture against real data shapes.
from deeptool.parameters import get_all_args
from deeptool.dataloader import load_test_batch

args = get_all_args()
args.model_type = "rnnvae"
# a single volume per batch: mod_batch_3d only keeps the first batch element
args.batch_size = 1
args.track = False
batch = load_test_batch(args)
# show the raw shape of the image tensor before it is reshaped for the network
batch["img"].shape

torch.Size([1, 3, 16, 256, 256])

In [5]:
# export
def mod_batch_2d(batch):
    """
    Identity hook for 2d data.

    The batch is handed back untouched (same object); this exists so that 2d and 3d
    inputs can share the same `mod_batch` call site.
    """
    untouched = batch
    return untouched


def mod_batch_3d(batch, key="img"):
    """
    Turn a 3d volume into a stack of 2d slices the network can consume.

    If `batch[key]` has more than four axes, only its first entry along axis 0 is kept
    and the first two remaining axes are swapped. The entry is replaced in place inside
    `batch`; tensors with four or fewer axes are left as they are.

    Returns the (possibly modified) `batch` object itself.
    """
    volume = batch[key]
    # nothing to do for data that is already a stack of 2d slices
    if len(volume.shape) <= 4:
        return batch

    # keep the first element of the leading axis only
    batch[key] = volume = volume[0]
    # swap the two leading axes of what is left
    batch[key] = volume.permute(1, 0, 2, 3)
    return batch

In [6]:
# mod_batch_3d modifies `batch` in place and returns the same dict
batchmod = mod_batch_3d(batch)
# shape after dropping the leading axis and swapping the next two axes
batchmod["img"].shape

torch.Size([16, 3, 256, 256])

In [7]:
# export


class Transition(nn.Module):
    """
    Transition Network with Recurrence / 1D Convolutions / Identity Function

    The input is treated as a matrix of latent codes with one row per slice, i.e. shape
    (n_slices, n_z). Depending on `args.rnn_transition` the rows are

      * "ident": returned unchanged,
      * "rnn":   fed as one sequence through a single-layer GRU,
      * "cnn":   convolved along the slice axis (kernel size 3) followed by a ReLU.

    All variants return a tensor of shape (n_slices, n_z).
    """

    def __init__(self, args):
        """
        args: namespace providing `n_z` (latent size) and `rnn_transition`
              (one of "ident", "rnn", "cnn").

        Raises ValueError for an unknown `rnn_transition`.
        """
        super(Transition, self).__init__()
        self.n_z = args.n_z
        self.define_switches(args)

    def define_switches(self, args):
        """subfunc of init: build the selected sub-network and select its forward function"""
        # builders for the trainable part; only the selected one is instantiated
        switcher_part = {
            # ident = Identity mapping
            "ident": lambda: nn.Sequential(),
            # rnn = Recurrence
            "rnn": lambda: nn.Sequential(nn.GRU(args.n_z, args.n_z, 1),),
            # cnn = 1d Convolution
            "cnn": lambda: nn.Sequential(
                nn.Conv1d(
                    args.n_z, args.n_z, kernel_size=3, dilation=1, padding=1, stride=1
                ),
                nn.ReLU(),
            ),
        }

        # switcher for the functionality
        switcher_func = {
            "ident": self.forward_ident,
            "rnn": self.forward_rnn,
            "cnn": self.forward_cnn,
        }

        # fail early and clearly instead of installing a placeholder that breaks later
        if args.rnn_transition not in switcher_func:
            raise ValueError(
                "Invalid Transition Type: {!r} (expected one of {})".format(
                    args.rnn_transition, sorted(switcher_func)
                )
            )

        self.main_part = switcher_part[args.rnn_transition]()
        self._forward_impl = switcher_func[args.rnn_transition]

        print("Transition: " + args.rnn_transition)

    def forward(self, x):
        """apply the transition selected at construction time"""
        return self._forward_impl(x)

    def forward_ident(self, x):
        """do not apply anything"""
        return x

    def forward_rnn(self, x):
        """
        take the matrix of encoded input slices and apply the RNN part
        """
        # nn.GRU (batch_first=False) expects (seq_len, batch, n_z): the slices form
        # ONE sequence, so the slice axis must be the sequence axis and batch is 1
        seq = x.reshape([-1, 1, self.n_z])
        # apply GRU layer (the final hidden state is not needed)
        out, _ = self.main_part(seq)
        # back to one row per slice
        return out.reshape([-1, self.n_z])

    def forward_cnn(self, x):
        """apply cnn functionality: convolve every latent feature along the slice axis"""
        # nn.Conv1d expects (batch, channels, length) = (1, n_z, n_slices).
        # This needs a transpose; a plain reshape would mix slices and features.
        seq = x.reshape([-1, self.n_z]).t().unsqueeze(0)
        # apply cnn layer
        out = self.main_part(seq)
        # back to one row per slice
        return out[0].t().contiguous()

In [8]:
# hide
from deeptool.train_loop import test_one_batch
from deeptool.parameters import get_all_args, compat_args

# fresh default arguments for the model tests further down
args = get_all_args()
# use small 32 x 32 pictures in the tests
args.pic_size = 32

## The Simple Autoencoder with Recurrence

In [9]:
# export


class RNNAE(AbsModel):
    """
    Recurrent autoencoder for compressing 3d data.

    Pipeline: conv encoder -> fc encoder -> Transition -> fc decoder -> conv decoder.
    The reconstruction loss is a summed MSE and all parameters are trained with Adam.
    """

    def __init__(self, device, args):
        """
        The recurrent autoencoder for compressing 3d data.
        It compresses in 2d while (hopefully) maintaining the spatial relation between layers

        device: torch device all sub-networks are moved to
        args:   parameter namespace; reads dim, pic_size, perspectives, n_fea_up,
                n_fea_down, n_z, lr (and whatever Transition / DownUpConv read)
        """
        super(RNNAE, self).__init__(args)
        self.device = device

        # `args.dim` is set to 2 while the sub-networks are built and restored afterwards
        # (presumably so that DownUpConv builds 2d convolutions -- confirm in architecture).
        # The finally-clause guarantees the caller's args are restored even on failure.
        self.true_dim = args.dim
        args.dim = 2
        try:
            # 1. create the convolutional Encoder
            self.conv_part_enc = DownUpConv(
                args,
                pic_size=args.pic_size,
                n_fea_in=len(args.perspectives),
                n_fea_next=args.n_fea_up,
                depth=1,
            ).to(self.device)

            # save important features
            max_fea, min_size = self.conv_part_enc.max_fea, self.conv_part_enc.min_size
            self.n_z, self.max_fea, self.min_size = args.n_z, max_fea, min_size

            self.view_arr = [-1, max_fea * min_size ** 2]  # as flat vector
            self.view_conv = [-1, max_fea, min_size, min_size]  # as conv block
            self.view_track = [
                1,
                len(args.perspectives),
                -1,
                args.pic_size,
                args.pic_size,
            ]

            # 2. Apply FC- Encoder Part
            self.fc_part_enc = nn.Sequential(
                nn.Linear(max_fea * min_size * min_size, max_fea * min_size),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(max_fea * min_size, max_fea),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(max_fea, args.n_z),
            ).to(self.device)

            # 3. The Transition Part
            self.transition = Transition(args).to(self.device)

            # 4. Apply FC-Decoder Part
            self.fc_part_dec = nn.Sequential(
                nn.Linear(args.n_z, max_fea),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(max_fea, max_fea * min_size),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(max_fea * min_size, max_fea * min_size * min_size),
            ).to(self.device)

            # 5. create the convolutional Decoder
            self.conv_part_dec = DownUpConv(
                args,
                pic_size=args.pic_size,
                n_fea_in=len(args.perspectives),
                n_fea_next=args.n_fea_down,
                depth=1,
                move="up",
            ).to(self.device)

            # the standard loss
            self.mse_loss = nn.MSELoss(reduction="sum")

            # the optimizer
            self.optimizer = optim.Adam(self.parameters(), lr=args.lr)
        finally:
            # reset the dimension
            args.dim = self.true_dim

        # 3d volumes are converted to a stack of 2d slices, 2d data is passed through
        self.mod_batch = mod_batch_3d if args.dim == 3 else mod_batch_2d

    def encode(self, x):
        """conv encoder -> flatten -> fc encoder; returns one latent row per input slice"""
        x = self.conv_part_enc(x)
        x = x.reshape(self.view_arr)
        x = self.fc_part_enc(x)
        return x

    def decode(self, x):
        """transition -> fc decoder -> reshape to conv block -> conv decoder"""
        # apply transition
        x = self.transition(x)
        # decode
        x = self.fc_part_dec(x)
        x = x.reshape(self.view_conv)
        x = self.conv_part_dec(x)
        return x

    def ae_forward(self, img):
        """encode and decode `img`; returns (summed mse loss, reconstruction)"""
        # encode:
        x = self.encode(img)
        # decode
        x = self.decode(x)
        # calc loss
        loss = self.mse_loss(img, x)
        return loss, x

    def forward(self, batch, update=True):
        """
        calculate the forward pass

        update=True:  do one optimizer step and return the reconstruction
        update=False: leave the weights untouched and return (reconstruction, tr_data)
                      where tr_data holds the scalar "loss" for tracking
        """
        # prepare (self.prep is not defined in this class; presumably inherited from AbsModel)
        batch = self.mod_batch(batch)
        img = self.prep(batch).to(self.device)

        # autoencoder
        loss, x = self.ae_forward(img)

        if update:
            # clear gradients left over from the previous step, otherwise they accumulate
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return x

        tr_data = {}
        tr_data["loss"] = loss.item()
        return x, tr_data

In [10]:
# An empty nn.Sequential hands its input back unchanged -- the Transition class
# uses exactly this as its "ident" part.
trans = nn.Sequential()
a = 5
trans(a) == a

True

In [11]:
# Smoke test of the cnn transition on random input; the output is reshaped to rows of
# length args.n_z (relies on `args` from the setup cell above).
args.rnn_transition = "cnn"
x = torch.randn(100, 34)
tran = Transition(args)
print(tran(x).shape)
# show the module structure
tran

Transition: cnn
torch.Size([34, 100])


Transition(
  (main_part): Sequential(
    (0): Conv1d(100, 100, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): ReLU()
  )
)

In [12]:
# 3 dim test of the plain recurrent autoencoder (rnn_type "ae") on the MRNet dataset
args.model_type = "rnnvae"
args.dataset_type = "MRNet"
args.rnn_type = "ae"
args.dim = 3
# presumably reconciles interdependent settings -- see deeptool.parameters.compat_args
args = compat_args(args)
test_one_batch(args)

Model-Type: rnnvae
ae
Transition: cnn


## The Variational Autoencoder in RNN Mode

In [13]:
# export


class RNNVAE(RNNAE):
    """
    inherit from RNN_AE and add the variational part

    The fc encoder is replaced by one that outputs 2 * n_z values per slice, which are
    split into mean and log-variance of the latent distribution.
    """

    def __init__(self, device, args):
        """
        device: torch device the networks live on
        args:   parameter namespace; additionally reads n_z, gamma (KL weight) and lr
        """
        super(RNNVAE, self).__init__(device, args)
        # 2. rewrite FC- Encoder Part
        max_fea, min_size = self.max_fea, self.min_size
        # NOTE(review): the trailing Sigmoid limits both the mean and the log-variance
        # to (0, 1); kept as is because changing it alters the architecture -- confirm
        # that this restriction is intended.
        self.fc_part_enc = nn.Sequential(
            nn.Linear(max_fea * min_size * min_size, max_fea * min_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(max_fea * min_size, max_fea),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(max_fea, 2 * args.n_z),
            nn.Sigmoid(),
        ).to(self.device)
        # get the kl factor
        self.gamma = args.gamma

        # reset the optimizer so that it covers the replaced encoder part
        self.optimizer = optim.Adam(self.parameters(), lr=args.lr)

    def vae_sampling(self, x):
        """
        reparameterisation trick: split `x` into (mu, log_sig2) along dim 1 and sample
        z = mu + sigma * eps with eps ~ N(0, I)

        returns (z, mu, log_sig2)
        """
        mu, log_sig2 = x.chunk(2, dim=1)
        # standard-normal noise (randn); uniform noise (rand) would not match the
        # Gaussian assumed by the KL term
        eps = torch.randn_like(mu, device=self.device)
        # sample together
        z = mu + torch.exp(torch.mul(0.5, log_sig2)) * eps
        return z, mu, log_sig2

    def kl_loss(self, mu, log_sig2):
        """KL divergence between N(mu, sigma^2) and N(0, I), summed over all elements"""
        return -0.5 * torch.sum(1 - torch.pow(mu, 2) - torch.exp(log_sig2) + log_sig2)

    def forward(self, batch, update=True):
        """
        calculate the forward pass of the variational autoencoder

        update=True:  do one optimizer step and return the reconstruction
        update=False: return (reconstruction, tr_data) with the scalar entries
                      "loss", "ae_loss" and "vae_loss"
        """
        # prepare
        batch = self.mod_batch(batch)
        img = self.prep(batch).to(self.device)
        # encode
        x = self.encode(img)
        # apply the vae sampling
        x, mu, log_sig2 = self.vae_sampling(x)
        # decode
        x = self.decode(x)

        # get loss
        ae_loss = self.mse_loss(img, x)
        vae_loss = self.kl_loss(mu, log_sig2)
        loss = ae_loss + self.gamma * vae_loss

        if update:
            # clear gradients left over from the previous step, otherwise they accumulate
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return x

        tr_data = {}
        tr_data["loss"] = loss.item()
        tr_data["ae_loss"] = ae_loss.item()
        tr_data["vae_loss"] = vae_loss.item()
        return x, tr_data

In [14]:
# 3 dim test of the variational variant (rnn_type "vae"); reuses `args` from above
args.model_type = "rnnvae"
args.rnn_type = "vae"
args.dim = 3
args = compat_args(args)
test_one_batch(args)

Model-Type: rnnvae
vae
Transition: cnn


## The Intro-VAE with RNN

In [15]:
# export


class RNNINTROVAE(RNNVAE):
    """
    inherit from RNN_VAE and add the GAN part

    The encoder doubles as discriminator: it is trained to give reconstructions and
    generated samples a KL of at least `m`, while the decoder is trained to lower the
    KL of those samples again. Encoder and decoder have separate Adam optimizers.
    """

    def __init__(self, device, args):
        """
        device: torch device the networks live on
        args:   parameter namespace; additionally reads alpha, beta, gamma, m,
                n_pretrain and lr
        """
        super(RNNINTROVAE, self).__init__(device, args)
        # add extra parameters
        self.alpha = args.alpha
        self.beta = args.beta
        self.gamma = args.gamma
        self.m = args.m
        self.n_pretrain = args.n_pretrain

        # the single optimizer of the parent classes is replaced by two separate ones
        self.optimizer = None

        enc_params = (
            list(self.conv_part_enc.parameters())
            + list(self.fc_part_enc.parameters())
            + list(self.transition.parameters())
        )
        self.optimizerEnc = optim.Adam(enc_params, lr=args.lr)

        dec_params = (
            list(self.conv_part_dec.parameters())
            + list(self.fc_part_dec.parameters())
            + list(self.transition.parameters())
        )
        self.optimizerDec = optim.Adam(dec_params, lr=args.lr)

    def _introspective_pass(self, x, detach):
        """
        reconstruct x, generate a fake from noise and encode both again

        detach=True cuts the gradient between the decoder output and the second
        encoding (used for the encoder update).
        returns (x_re, loss_rec, kl_real, kl_rec, kl_fake)
        """
        # real
        z, z_mu, z_log_sig2 = self.vae_sampling(self.encode(x))
        x_re = self.decode(z)

        # fake
        noise = torch.randn_like(z, device=self.device)
        fake = self.decode(noise)

        # encode again
        re_in, fake_in = (x_re.detach(), fake.detach()) if detach else (x_re, fake)
        _, z_mu_re, z_log_sig2_re = self.vae_sampling(self.encode(re_in))
        _, z_mu_fake, z_log_sig2_fake = self.vae_sampling(self.encode(fake_in))

        # get losses
        loss_rec = self.mse_loss(x, x_re)
        kl_real = self.kl_loss(z_mu, z_log_sig2)
        kl_rec = self.kl_loss(z_mu_re, z_log_sig2_re)
        kl_fake = self.kl_loss(z_mu_fake, z_log_sig2_fake)
        return x_re, loss_rec, kl_real, kl_rec, kl_fake

    def forward(self, batch, update=True):
        """
        Get the different relevant outputs for Intro VAE training
        update=True to allow updating, update=False to keep networks constant

        update=True:  return x_re (reconstructed)
        update=False: return (x_re, tr_data) with the tracked scalar losses
        """
        batch = self.mod_batch(batch)
        x = self.prep(batch).to(self.device)

        # =========== Update E ================
        self.optimizerEnc.zero_grad()

        (
            _,
            loss_rec,
            loss_e_real_kl,
            loss_e_rec_kl,
            loss_e_fake_kl,
        ) = self._introspective_pass(x, detach=True)

        # combine losses: real KL plus hinge margins on reconstructed / fake samples
        loss_margin_e = (
            loss_e_real_kl
            + (F.relu(self.m - loss_e_rec_kl) + F.relu(self.m - loss_e_fake_kl))
            * self.alpha
        )
        loss_e = loss_rec * self.beta + loss_margin_e * self.gamma

        if update:
            loss_e.backward()
            self.optimizerEnc.step()

        # ========= Update G ==================
        self.optimizerDec.zero_grad()

        (
            x_re,
            loss_rec,
            loss_g_real_kl,
            loss_g_rec_kl,
            loss_g_fake_kl,
        ) = self._introspective_pass(x, detach=False)

        # combine losses: the KL terms are summed, mirroring the encoder margin above
        # (the previous product of real KL and sample KLs had no counterpart there)
        loss_margin_g = loss_g_real_kl + (loss_g_rec_kl + loss_g_fake_kl) * self.alpha
        loss_g = loss_rec * self.beta + loss_margin_g * self.gamma

        if update:
            loss_g.backward()
            self.optimizerDec.step()
            # return the reconstruction like the other rnn models do
            return x_re

        # setup dictionary for Tracking
        tr_data = {}
        tr_data["loss_rec"] = loss_rec.item()
        tr_data["loss_e_real_kl"] = loss_e_real_kl.item()
        tr_data["loss_margin_e"] = loss_margin_e.item()
        tr_data["loss_margin_g"] = loss_margin_g.item()
        tr_data["loss_e"] = loss_e.item()
        tr_data["loss_g"] = loss_g.item()

        # Return output and tracking data
        return x_re, tr_data

In [16]:
# 3 dim test of the IntroVAE variant (rnn_type "introvae"); reuses `args` from above
args.model_type = "rnnvae"
args.rnn_type = "introvae"
args.dim = 3
args = compat_args(args)
test_one_batch(args)

Model-Type: rnnvae
introvae
Transition: cnn


### Architecture:
<img src="img/arch_biggan.png" alt="Drawing" style="width: 800px;"/>

### Loss:
<img src="img/biggan.png" alt="Drawing" style="width: 800px;"/>


(Source: Donahue & Simonyan, "Large Scale Adversarial Representation Learning", https://arxiv.org/pdf/1907.02544.pdf)

In [17]:
# export


class RNNBIGAN(RNNVAE):
    """
    apply the Bidirectional-GAN part, inherit from the variational autoencoder

    On top of the inherited encoder / decoder a discriminator is built that scores
    the image alone (s_x), the latent code alone (s_z) and both together (s_xz).
    Encoder, decoder and discriminator each have their own Adam optimizer.
    """

    def __init__(self, device, args):
        """
        init the networks and the discriminator

        device: torch device the networks live on
        args:   parameter namespace; additionally reads lam, bi_ae_scale and lr
        """
        # init the vae architecture
        super(RNNBIGAN, self).__init__(device, args)
        # we need a discriminator!
        # switch to 2 dim for the init and restore the caller's value afterwards
        # (the value is restored, not hard-coded, so 2d configurations stay 2d)
        # -----------
        prev_dim = args.dim
        args.dim = 2
        try:
            self.conv_part_dis = DownUpConv(
                args,
                pic_size=args.pic_size,
                n_fea_in=len(args.perspectives),
                n_fea_next=args.n_fea_up,
                depth=1,
            ).to(self.device)
        finally:
            args.dim = prev_dim
        # -----------

        # take saved params
        max_fea, min_size, n_z = self.max_fea, self.min_size, args.n_z

        # add the fc part(s)
        # score of the image features alone
        self.fc_part_dis_x = nn.Sequential(
            # layer 1
            nn.Linear(max_fea * min_size * min_size, max_fea * min_size),
            nn.LeakyReLU(0.2, inplace=True),
            # layer 2
            nn.Linear(max_fea * min_size, max_fea),
            nn.LeakyReLU(0.2, inplace=True),
            # layer 3
            nn.Linear(max_fea, max_fea),
            nn.LeakyReLU(0.2, inplace=True),
            # layer 4
            nn.Linear(max_fea, 1),
        ).to(self.device)

        # score of the latent code alone
        self.fc_part_dis_z = nn.Sequential(
            # layer 1
            nn.Linear(n_z, max_fea),
            nn.LeakyReLU(0.2, inplace=True),
            # layer 2
            nn.Linear(max_fea, max_fea),
            nn.LeakyReLU(0.2, inplace=True),
            # layer 3
            nn.Linear(max_fea, 1),
        ).to(self.device)

        # joint score of image features and latent code
        self.fc_part_dis_xz = nn.Sequential(
            # layer 1
            nn.Linear(n_z + max_fea * min_size * min_size, max_fea * min_size),
            nn.LeakyReLU(0.2, inplace=True),
            # layer 2
            nn.Linear(max_fea * min_size, max_fea),
            nn.LeakyReLU(0.2, inplace=True),
            # layer 3
            nn.Linear(max_fea, max_fea),
            nn.LeakyReLU(0.2, inplace=True),
            # layer 4
            nn.Linear(max_fea, 1),
        ).to(self.device)

        # the single optimizer of the parent classes is replaced by three separate ones
        self.optimizer = None

        enc_params = list(self.conv_part_enc.parameters()) + list(
            self.fc_part_enc.parameters()
        )
        self.optimizerEnc = optim.Adam(enc_params, lr=args.lr)

        dec_params = (
            list(self.conv_part_dec.parameters())
            + list(self.fc_part_dec.parameters())
            + list(self.transition.parameters())
        )
        self.optimizerDec = optim.Adam(dec_params, lr=args.lr)

        dis_params = (
            list(self.conv_part_dis.parameters())
            + list(self.fc_part_dis_x.parameters())
            + list(self.fc_part_dis_z.parameters())
            + list(self.fc_part_dis_xz.parameters())
        )
        self.optimizerDis = optim.Adam(dis_params, lr=args.lr)

        # parameters
        self.lam = args.lam
        self.bi_ae_scale = args.bi_ae_scale

    def encode_non_d(self, x):
        """
        non-deterministic encoding from the paper:
        z = mu + softplus(sig_hat) * eps with eps ~ N(0, I)
        """
        x = self.encode(x)
        # get mu and sigma
        mu, sig_hat = x.chunk(2, dim=1)
        # softplus = log(1 + exp(.)), in its numerically stable library form
        sig = F.softplus(sig_hat)
        # standard-normal noise (randn, not the uniform rand)
        eps = torch.randn_like(mu, device=self.device)
        # sample together
        z = mu + sig * eps
        return z

    def get_s(self, x, z):
        """apply discriminator and output sx, sz and sxz (each flattened to 1d)"""
        # shape inputs
        x = self.conv_part_dis(x)
        x = x.reshape(self.view_arr)
        xz = torch.cat([x, z], dim=1)

        # apply fc decisions to generate out-dis
        s_x = self.fc_part_dis_x(x).view(-1)
        s_z = self.fc_part_dis_z(z).view(-1)
        s_xz = self.fc_part_dis_xz(xz).view(-1)

        return s_x, s_z, s_xz

    def hinge(self, x):
        """the hinge loss: max(0, 1-x)"""
        return F.relu(1 - x)

    def decide(self, x, z, y, ed=False):
        """
        generate dis-loss
        y  -> sign of the pair: +1 for (data, encoding), -1 for (generated, noise)
        ed -> ENCODE-DECODE Learning: return y * (s_x + s_z + s_xz) without hinge
        """
        # get decisions from Discriminator
        s_x, s_z, s_xz = self.get_s(x, z)

        # apply y for encoder-decoder
        if ed:
            return y * (s_x + s_z + s_xz)

        # apply hinge losses for discriminator
        hs_x = self.hinge(y * s_x)
        hs_z = self.hinge(y * s_z)
        hs_xz = self.hinge(y * s_xz)

        return hs_x + hs_z + hs_xz

    def ae_part(self, x, update):
        """
        simple forward pass of autoencoder, scaled by bi_ae_scale

        The latent code is sampled with `encode_non_d`: the inherited `ae_forward`
        would pass the raw 2 * n_z encoder output on to the decoder, which only
        accepts n_z features.
        returns the scaled reconstruction loss as python float
        """
        z = self.encode_non_d(x)
        x_r = self.decode(z)
        ae_loss = self.mse_loss(x, x_r) * self.bi_ae_scale

        if update:
            ae_loss.backward()
        return ae_loss.mean().item()

    def forward(self, batch, update=True):
        """
        main function

        update=True:  update discriminator, encoder and decoder; return the generated x_p
        update=False: return (reconstruction of the batch, tr_data with tracked losses)
        """
        # zero all gradients
        self.optimizerEnc.zero_grad()
        self.optimizerDec.zero_grad()
        self.optimizerDis.zero_grad()

        # (0) Train Autoencoder
        # -------------------------------
        # currently disabled; `self.ae_part(x, update)` would have to be called after
        # the batch is loaded below
        ae_loss = 0

        # (1) Train Discriminator
        # -------------------------------
        # load batch
        batch = self.mod_batch(batch)
        x = self.prep(batch).to(self.device)
        # generate original z
        z = self.encode_non_d(x)

        # fake
        z_p = torch.randn_like(z, device=self.device)
        # decode
        x_p = self.decode(z_p)

        # real (inputs are detached: only the discriminator is trained here)
        errd_real = self.decide(x.detach(), z.detach(), +1).mean()
        if update:
            errd_real.backward()

        # fake
        errd_fake = self.decide(x_p.detach(), z_p.detach(), -1).mean()

        if update:
            errd_fake.backward()
            self.optimizerDis.step()

        errd = (errd_real + errd_fake).mean().item()

        # (2) Train Encoder / Decoder
        # -------------------------------
        # encode (fresh graph, evaluated against the updated discriminator)
        z = self.encode_non_d(x)

        # decode
        x_p = self.decode(z_p)

        # real
        err_enc = self.decide(x, z, +1, ed=True).mean()

        if update:
            err_enc.backward()

        # fake
        err_dec = self.decide(x_p, z_p, -1, ed=True).mean()

        err_enc_dec = (err_enc + err_dec).mean().item()

        # Update Generator
        if update:
            err_dec.backward()
            self.optimizerEnc.step()
            self.optimizerDec.step()
            return x_p

        # Track all relevant losses
        tr_data = {}
        tr_data["ae_loss"] = ae_loss
        tr_data["errDis"] = errd
        tr_data["errEncDec"] = err_enc_dec

        tr_data["errD_real"] = errd_real.mean().item()
        tr_data["errD_fake"] = errd_fake.mean().item()

        tr_data["errEnc"] = err_enc.mean().item()
        tr_data["errDec"] = err_dec.mean().item()

        # generate the autoencoder output:
        x_r = self.decode(z)

        # Return losses and reconstruction data
        return x_r, tr_data

In [18]:
# export


def creator_rnn_ae(device, args):
    """
    return an instance of the class depending on the mode set in args

    device: torch device passed on to the model
    args:   parameter namespace; `args.rnn_type` selects the model
            ("ae", "vae", "introvae" or "bigan")

    Raises ValueError if `args.rnn_type` is not one of the known types.
    """
    switcher = {
        "ae": RNNAE,
        "vae": RNNVAE,
        "introvae": RNNINTROVAE,
        "bigan": RNNBIGAN,
    }
    print(args.rnn_type)
    # Get the model_creator; an unknown type is reported right away instead of
    # failing later with an unrelated TypeError from a placeholder callable
    if args.rnn_type not in switcher:
        raise ValueError(
            "Invalid Model Type: {!r} (expected one of {})".format(
                args.rnn_type, sorted(switcher)
            )
        )
    model_creator = switcher[args.rnn_type]
    return model_creator(device, args)

In [19]:
# Build the BiGAN variant directly and run one evaluation pass (no weight update).
# Use the first GPU only if one is available and requested, otherwise the CPU.
device = torch.device(
    "cuda:0" if (torch.cuda.is_available() and args.n_gpu > 0) else "cpu"
)
rnn_bigan = RNNBIGAN(device, args)
data = load_test_batch(args)
# update=False returns the output images together with the tracked losses
x, tr = rnn_bigan(data, update=False)
x.shape

Transition: cnn


torch.Size([16, 3, 32, 32])

In [20]:
# hide
from nbdev.export import *

# run the nbdev notebook export (see the "Converted ..." lines in the output)
notebook2script()

Converted 00_dataloader.ipynb.
Converted 01_architecture.ipynb.
Converted 02_utils.ipynb.
Converted 03_parameters.ipynb.
Converted 04_train_loop.ipynb.
Converted 05_abstract_model.ipynb.
Converted 10_diagnosis.ipynb.
Converted 20_dcgan.ipynb.
Converted 21_introvae.ipynb.
Converted 22_vqvae.ipynb.
Converted 23_bigan.ipynb.
Converted 24_mocoae.ipynb.
Converted 33_rnn_vae.ipynb.
Converted 99_index.ipynb.
