In [1]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1, 2

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1, 2


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from tqdm.auto import tqdm

In [3]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
batch_size=100

class lstm(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, n_layers, batch_size):
        super(lstm, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.n_layers = n_layers
        self.embed = nn.Linear(input_size, hidden_size)
        self.lstm = nn.ModuleList([nn.LSTMCell(hidden_size, hidden_size) for i in range(self.n_layers)])
        self.output = nn.Sequential(
                nn.Linear(hidden_size, output_size),
                #nn.BatchNorm1d(output_size),
                nn.Tanh())
        self.hidden = self.init_hidden(self.batch_size)

    def init_hidden(self, batch_size=1):
        hidden = []
        for i in range(self.n_layers):
            hidden.append((Variable(torch.zeros(batch_size, self.hidden_size).to(device)),
                           Variable(torch.zeros(batch_size, self.hidden_size).to(device))))
        self.hidden = hidden
        return hidden

    def init_hidden_(self, batch_size):
        hidden = []
        for i in range(self.n_layers):
            hidden.append((Variable(torch.zeros(batch_size, self.hidden_size).to(device)),
                           Variable(torch.zeros(batch_size, self.hidden_size).to(device))))
        self.hidden = hidden
        #return hidden

    def forward(self, input):
        embedded = self.embed(input.view(-1, self.input_size))
        h_in = embedded
        for i in range(self.n_layers):
            self.hidden[i] = self.lstm[i](h_in, self.hidden[i])
            h_in = self.hidden[i][0]

        return self.output(h_in)

class gaussian_lstm(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, n_layers, batch_size):
        super(gaussian_lstm, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.batch_size = batch_size
        self.embed = nn.Linear(input_size, hidden_size)
        self.lstm = nn.ModuleList([nn.LSTMCell(hidden_size, hidden_size) for i in range(self.n_layers)])
        self.mu_net = nn.Linear(hidden_size, output_size)
        self.logvar_net = nn.Linear(hidden_size, output_size)
        self.hidden = self.init_hidden()

    def init_hidden(self, batch_size=1):
        hidden = []
        for i in range(self.n_layers):
            hidden.append((Variable(torch.zeros(batch_size, self.hidden_size).to(device)),
                           Variable(torch.zeros(batch_size, self.hidden_size).to(device))))
        self.hidden = hidden
        return hidden

    def init_hidden_(self):
        hidden = []
        for i in range(self.n_layers):
            hidden.append((Variable(torch.zeros(self.batch_size, self.hidden_size).to(device)),
                           Variable(torch.zeros(self.batch_size, self.hidden_size).to(device))))
        self.hidden = hidden
        #return hidden

    def reparameterize(self, mu, logvar):
        logvar = logvar.mul(0.5).exp_()
        eps = Variable(logvar.data.new(logvar.size()).normal_())
        #return eps.add_(mu)
        #return eps.mul(logvar)
        return eps.mul(logvar).add_(mu)

    def forward(self, input):
        #import pdb
        #pdb.set_trace()
        embedded = self.embed(input.view(-1, self.input_size))
        h_in = embedded
        for i in range(self.n_layers):
            self.hidden[i] = self.lstm[i](h_in, self.hidden[i])
            h_in = self.hidden[i][0]
        mu = self.mu_net(h_in)
        logvar = self.logvar_net(h_in)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

In [4]:
class dcgan_conv(nn.Module):
    def __init__(self, nin, nout):
        super(dcgan_conv, self).__init__()
        self.main = nn.Sequential(
                nn.Conv2d(nin, nout, 4, 2, 1),
                nn.BatchNorm2d(nout),
                nn.LeakyReLU(0.2),
                )

    def forward(self, input):
        return self.main(input)

class dcgan_upconv(nn.Module):
    def __init__(self, nin, nout):
        super(dcgan_upconv, self).__init__()
        self.main = nn.Sequential(
                nn.ConvTranspose2d(nin, nout, 4, 2, 1),
                nn.BatchNorm2d(nout),
                nn.LeakyReLU(0.2),
                )

    def forward(self, input):
        return self.main(input)

class encoder(nn.Module):
    def __init__(self, dim, nc=1):
        super(encoder, self).__init__()
        self.dim = dim
        nf = 64
        # input is (nc) x 64 x 64
        self.c1 = dcgan_conv(nc, nf)
        # state size. (nf) x 32 x 32
        self.c2 = dcgan_conv(nf, nf * 2)
        # state size. (nf*2) x 16 x 16
        self.c3 = dcgan_conv(nf * 2, nf * 4)
        # state size. (nf*4) x 8 x 8
        self.c4 = dcgan_conv(nf * 4, nf * 8)
        # state size. (nf*8) x 4 x 4
        self.c5 = nn.Sequential(
                nn.Conv2d(nf * 8, dim, 4, 1, 0),
                nn.BatchNorm2d(dim),
                nn.Tanh()
                )

    def forward(self, input):
        h1 = self.c1(input)
        h2 = self.c2(h1)
        h3 = self.c3(h2)
        h4 = self.c4(h3)
        h5 = self.c5(h4)
        return h5.view(-1, self.dim), [h1, h2, h3, h4]


class decoder(nn.Module):
    def __init__(self, dim, nc=1):
        super(decoder, self).__init__()
        self.dim = dim
        nf = 64
        self.upc1 = nn.Sequential(
                # input is Z, going into a convolution
                nn.ConvTranspose2d(dim, nf * 8, 4, 1, 0),
                nn.BatchNorm2d(nf * 8),
                nn.LeakyReLU(0.2)
                )
        # state size. (nf*8) x 4 x 4
        self.upc2 = dcgan_upconv(nf * 8 * 2, nf * 4)
        # state size. (nf*4) x 8 x 8
        self.upc3 = dcgan_upconv(nf * 4 * 2, nf * 2)
        # state size. (nf*2) x 16 x 16
        self.upc4 = dcgan_upconv(nf * 2 * 2, nf)
        # state size. (nf) x 32 x 32
        self.upc5 = nn.Sequential(
                nn.ConvTranspose2d(nf * 2, nc, 4, 2, 1),
                nn.Sigmoid()
                # state size. (nc) x 64 x 64
                )

    def forward(self, input):
        vec, skip = input 
        d1 = self.upc1(vec.view(-1, self.dim, 1, 1))
        d2 = self.upc2(torch.cat([d1, skip[3]], 1))
        d3 = self.upc3(torch.cat([d2, skip[2]], 1))
        d4 = self.upc4(torch.cat([d3, skip[1]], 1))
        output = self.upc5(torch.cat([d4, skip[0]], 1))
        return output

In [5]:
class KLCriterion(nn.Module):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def forward(self, mu1, logvar1, mu2, logvar2):
        """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2))"""
        sigma1 = logvar1.mul(0.5).exp() 
        sigma2 = logvar2.mul(0.5).exp() 
        kld = torch.log(sigma2/sigma1) + (torch.exp(logvar1) + (mu1 - mu2)**2)/(2*torch.exp(logvar2)) - 1/2
        return kld.sum() / self.batch_size

In [6]:
def init_weights(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1 or classname.find("Linear") != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find("BatchNorm") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [7]:
class P2PModel(nn.Module):
    def __init__(
        self,
        batch_size=100,
        channels=1,
        g_dim=128,
        z_dim=10,
        rnn_size=256,
        prior_rnn_layers=1,
        posterior_rnn_layers=1,
        predictor_rnn_layers=2,
        learning_rate=0.002,
        skip_prob: float = 0.1,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100,
    ):

        super().__init__()
        self.batch_size = batch_size
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers
        
        # Training parameters
        self.learning_rate = learning_rate
        self.skip_prob = skip_prob
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc

        # LSTMs
        self.frame_predictor = lstm(
            self.g_dim + self.z_dim + 1 + 1,
            self.g_dim,
            self.rnn_size,
            self.predictor_rnn_layers,
            self.batch_size,
        )
        self.posterior = gaussian_lstm(
            self.g_dim + self.g_dim + 1 + 1,
            self.z_dim,
            self.rnn_size,
            self.posterior_rnn_layers,
            self.batch_size,
        )
        self.prior = gaussian_lstm(
            self.g_dim + self.g_dim + 1 + 1,
            self.z_dim,
            self.rnn_size,
            self.prior_rnn_layers,
            self.batch_size,
        )

        # encoder & decoder
        self.encoder = encoder(self.g_dim, self.channels)
        self.decoder = decoder(self.g_dim, self.channels)

        # optimizer
        self.optimizer = optim.Adam

        # criterions
        self.mse_criterion = nn.MSELoss()  # recon and cpc
        self.kl_criterion = KLCriterion(batch_size=batch_size)
        self.align_criterion = nn.MSELoss()

        self.init_weight()
        self.init_optimizer()

    def init_optimizer(self):
        self.frame_predictor_optimizer = self.optimizer(
            self.frame_predictor.parameters(), lr=self.learning_rate, betas=(self.beta, 0.999)
        )
        self.posterior_optimizer = self.optimizer(
            self.posterior.parameters(), lr=self.learning_rate, betas=(self.beta, 0.999)
        )
        self.prior_optimizer = self.optimizer(
            self.prior.parameters(), lr=self.learning_rate, betas=(self.beta, 0.999)
        )
        self.encoder_optimizer = self.optimizer(
            self.encoder.parameters(), lr=self.learning_rate, betas=(self.beta, 0.999)
        )
        self.decoder_optimizer = self.optimizer(
            self.decoder.parameters(), lr=self.learning_rate, betas=(self.beta, 0.999)
        )

    def init_hidden(self, batch_size=1):
        self.frame_predictor.hidden = self.frame_predictor.init_hidden(
            batch_size=batch_size
        )
        self.posterior.hidden = self.posterior.init_hidden(batch_size=batch_size)
        self.prior.hidden = self.prior.init_hidden(batch_size=batch_size)

    def init_weight(self):
        self.frame_predictor.apply(init_weights)
        self.posterior.apply(init_weights)
        self.prior.apply(init_weights)
        self.encoder.apply(init_weights)
        self.decoder.apply(init_weights)

    def get_global_descriptor(self, x, start_ix=0, cp_ix=None):
        """Get the global descriptor based on x, start_ix, cp_ix."""
        if cp_ix is None:
            cp_ix = len(x) - 1
        x_cp = x[cp_ix]
        h_cp = self.encoder(x_cp)[0]  # 1 is input for skip-connection

        return x_cp, h_cp

    def p2p_generate(
        self,
        x,
        len_output,
        eval_cp_ix,
        start_ix=0,
        cp_ix=-1,
        model_mode="full",
        skip_frame=False,
        init_hidden=True,
    ):
        """Point-to-Point Generation given input sequence. Generate *1* sample for each input sequence.

        params:
            x: input sequence
            len_output: length of the generated sequence
            eval_cp_ix: cp_ix of the output sequence. usually it is len_output-1
            model_mode:
                - full:      post then prior
                - posterior: all use posterior
                - prior:     all use prior

        """

        if type(x) == tuple:
            # h36m
            (pose_2d, pose_3d, camera_view) = x
            # T, bs = len(pose_3d), pose_3d[0].shape[0]
            # x = pose_3d.view(T, bs, -1)
            x = pose_3d
            batch_size, coor, n_dim = x[0].shape
            dim_shape = (coor, n_dim)
        else:
            batch_size, channels, h, w = x[0].shape
            dim_shape = (channels, h, w)

        # gen_seq will collect the generated frames
        gen_seq = [x[0]]
        x_in = x[0]

        # NOTE: for visualization
        # init lstm
        if init_hidden:
            self.init_hidden(batch_size=batch_size)

        # get global descriptor
        seq_len = len(x)
        cp_ix = seq_len - 1
        x_cp, global_z = self.get_global_descriptor(
            x, cp_ix=cp_ix
        )  # here global_z is h_cp

        ###### time skipping
        skip_prob = self.skip_prob

        prev_i = 0
        max_skip_count = seq_len * skip_prob
        skip_count = 0
        probs = np.random.uniform(0, 1, len_output - 1)

        # for each sample, generate *n_eval* frames
        for i in range(1, len_output):
            # if np.random.uniform(0, 1) <= skip_prob and i > 1 and skip_count < max_skip_count and i != cp_ix:
            if (
                probs[i - 1] <= skip_prob
                and i >= self.n_past
                and skip_count < max_skip_count
                and i != 1
                and i != (len_output - 1)
                and skip_frame
            ):
                skip_count += 1
                gen_seq.append(torch.zeros_like(x_in))
                continue

            time_until_cp = (
                torch.zeros(batch_size, 1)
                .fill_((eval_cp_ix - i + 1) / eval_cp_ix)
                .to(x_cp)
            )
            delta_time = (
                torch.zeros(batch_size, 1).fill_((i - prev_i) / eval_cp_ix).to(x_cp)
            )

            prev_i = i

            h = self.encoder(x_in)
            # if opt.last_frame_skip or i < opt.n_past: # original
            if self.last_frame_skip or i == 1 or i < self.n_past:
                h, skip = h
            else:
                h, _ = h

            h_cpaw = torch.cat([h, global_z, time_until_cp, delta_time], 1).detach()

            if i < self.n_past:
                h_target = self.encoder(x[i])[0]
                h_target_cpaw = torch.cat(
                    [h_target, global_z, time_until_cp, delta_time], 1
                ).detach()
                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior" or model_mode == "full":
                    self.frame_predictor(
                        torch.cat([h, zt, time_until_cp, delta_time], 1)
                    )
                elif model_mode == "prior":
                    self.frame_predictor(
                        torch.cat([h, zt_p, time_until_cp, delta_time], 1)
                    )

                x_in = x[i]
                gen_seq.append(
                    x_in
                )  # NOTE: gen_seq can append the decoded x_in for comparing with gt
            else:
                if i < len(x):  # for posterior
                    h_target = self.encoder(x[i])[0]
                    h_target_cpaw = torch.cat(
                        [h_target, global_z, time_until_cp, delta_time], 1
                    ).detach()
                else:
                    h_target_cpaw = h_cpaw

                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior":
                    h = self.frame_predictor(
                        torch.cat([h, zt, time_until_cp, delta_time], 1)
                    )
                elif model_mode == "prior" or model_mode == "full":
                    h = self.frame_predictor(
                        torch.cat([h, zt_p, time_until_cp, delta_time], 1)
                    )

                x_in = self.decoder([h, skip]).detach()
                gen_seq.append(
                    x_in
                )  # NOTE: gen_seq can append the decoded x_in for comparing with gt
        return gen_seq

    def forward(self, x, start_ix=0, cp_ix=-1):
        """training"""
        if type(x) == tuple:  # h36m # NOTE: TODO
            (pose_2d, pose_3d, camera_view) = x
            x = pose_3d

        batch_size = x[0].shape[0]

        # initialize the hidden state
        self.init_hidden(batch_size=batch_size)

        # losses
        mse_loss = 0
        kld_loss = 0
        cpc_loss = 0
        align_loss = 0

        # get global descriptor
        seq_len = len(x)
        start_ix = 0
        cp_ix = seq_len - 1
        x_cp, global_z = self.get_global_descriptor(
            x, start_ix, cp_ix
        )  # here global_z is h_cp

        # time skipping
        skip_prob = self.skip_prob

        prev_i = 0
        max_skip_count = seq_len * skip_prob
        skip_count = 0
        probs = np.random.uniform(0, 1, seq_len - 1)

        for i in range(1, seq_len):
            # if np.random.uniform(0, 1) <= skip_prob and i > 1 and skip_count < max_skip_count and i != cp_ix:
            # if probs[i-1] <= skip_prob and i >= opt.n_past and skip_count < max_skip_count and i != cp_ix:
            if (
                probs[i - 1] <= skip_prob
                and i >= self.n_past
                and skip_count < max_skip_count
                and i != 1
                and i != cp_ix
            ):
                skip_count += 1
                continue

            if i > 1:
                align_loss += self.align_criterion(h[0], h_pred)

            time_until_cp = (
                torch.zeros(batch_size, 1).fill_((cp_ix - i + 1) / cp_ix).to(x_cp)
            )
            delta_time = torch.zeros(batch_size, 1).fill_((i - prev_i) / cp_ix).to(x_cp)
            prev_i = i

            h = self.encoder(x[i - 1])
            h_target = self.encoder(x[i])[0]

            # if opt.last_frame_skip or i < opt.n_past: # original
            if self.last_frame_skip or i <= self.n_past:
                h, skip = h
            else:
                h = h[0]

            # cp aware
            h_cpaw = torch.cat([h, global_z, time_until_cp, delta_time], 1)
            h_target_cpaw = torch.cat(
                [h_target, global_z, time_until_cp, delta_time], 1
            )

            zt, mu, logvar = self.posterior(h_target_cpaw)
            zt_p, mu_p, logvar_p = self.prior(h_cpaw)

            frame_predictor_input = torch.cat([h, zt, time_until_cp, delta_time], 1)

            h_pred = self.frame_predictor( frame_predictor_input)
            x_pred = self.decoder([h_pred, skip])

            # loss
            if i == (cp_ix):  # the gen-cp-frame should be exactly as x_cp
                h_pred_p = self.frame_predictor(
                    torch.cat([h, zt_p, time_until_cp, delta_time], 1)
                )
                x_pred_p = self.decoder([h_pred_p, skip])
                cpc_loss = self.mse_criterion(x_pred_p, x_cp)

            mse_loss += self.mse_criterion(x_pred, x[i])
            kld_loss += self.kl_criterion(mu, logvar, mu_p, logvar_p)

        # backward
        # update model without prior
        # loss = torch.tensor(
        #     [mse_loss + kld_loss * opt.beta + align_loss * opt.weight_align],
        #     requires_grad=True,
        # )
        loss = mse_loss + kld_loss * self.beta + align_loss * self.weight_align + cpc_loss * self.weight_cpc

        loss.backward()
        #self.update_model_without_prior()
        self.update_model()
        

        self.prior.zero_grad()
        self.posterior.zero_grad()
        self.frame_predictor.zero_grad()
        self.encoder.zero_grad()
        self.decoder.zero_grad()
        # update model with prior due to loss_on_prior
        #self.prior.zero_grad()
        # prior_loss = torch.tensor(
        #     [kld_loss + cpc_loss * opt.weight_cpc], requires_grad=True
        # )
        #prior_loss = kld_loss + cpc_loss * self.weight_cpc
        #prior_loss.backward()
        #self.update_prior()

        # mse_loss = torch.tensor(mse_loss)
        # kld_loss = torch.tensor(kld_loss)
        # cpc_loss = torch.tensor(cpc_loss)
        # align_loss = torch.tensor(align_loss)

        return (
            mse_loss.data.cpu().numpy() / seq_len,
            kld_loss.data.cpu().numpy() / seq_len,
            cpc_loss.data.cpu().numpy() / seq_len,
            align_loss.data.cpu().numpy() / seq_len,
        )

    def update_prior(self):
        self.prior_optimizer.step()

    def update_model_without_prior(self):
        self.frame_predictor_optimizer.step()
        self.posterior_optimizer.step()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

    def update_model(self):
        self.frame_predictor_optimizer.step()
        self.posterior_optimizer.step()
        self.prior_optimizer.step()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

In [8]:
from ganime.data.base import load_dataset
train_ds, test_ds, input_shape = load_dataset("moving_mnist_vae", "../../data", batch_size=100)# * strategy.num_replicas_in_sync)

2022-05-03 03:08:45.972070: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-05-03 03:08:49.074346: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13,460 MB memory:  -> device: 0, name: NVIDIA RTX A4000, pci bus id: 0000:25:00.0, compute capability: 8.6
2022-05-03 03:08:49.075255: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13,460 MB memory:  -> device: 1, name: NVIDIA RTX A4000, pci bus id: 0000:41:00.0, compute capability: 8.6


In [9]:
model = P2PModel().to(device)

In [10]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f216418e6d0>

In [11]:
iterator = train_ds.as_numpy_iterator()
for epoch in tqdm(range(200)):
    epoch_mse = 0
    epoch_kld = 0
    epoch_align = 0
    epoch_cpc = 0

    for x in train_ds.as_numpy_iterator():
        x = np.moveaxis(x[0], 0, 1)
        x = np.moveaxis(x, -1, 2)
        x = torch.tensor(x).to(device)
        
        start_ix = 0
        cp_ix = -1
        cp_ix = x.shape[1] - 1

        mse, kld, cpc, align = model(x, start_ix, cp_ix)
        epoch_mse += mse
        epoch_kld += kld
        epoch_cpc += cpc
        epoch_align += align
    print("EPOCH", epoch)
    print("epoch mse", epoch_mse)
    print("epoch kld", epoch_kld)
    print("epoch cpc", epoch_cpc)
    print("epoch align", epoch_align)

  0%|          | 0/200 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


EPOCH 0
epoch mse 3.582608070969583
epoch kld 24.33021411448717
epoch cpc 0.20322215035557753
epoch align 25.216130965948096
EPOCH 1
epoch mse 2.5908395290374737
epoch kld 39.14953933954239
epoch cpc 0.14576730309054256
epoch align 11.182117357850075
EPOCH 2
epoch mse 2.3822577044367783
epoch kld 64.45099029541015
epoch cpc 0.1331605749204755
epoch align 7.76171635836363
EPOCH 3
epoch mse 2.2770199179649353
epoch kld 86.81971297264099
epoch cpc 0.12619281327351925
epoch align 6.274997822940354
EPOCH 4
epoch mse 2.230066674947739
epoch kld 96.96366109848024
epoch cpc 0.12227157680317764
epoch align 5.353533518314361
EPOCH 5
epoch mse 2.167253895103931
epoch kld 96.01606941223146
epoch cpc 0.11924541797488929
epoch align 4.539233151078223
EPOCH 6
epoch mse 2.1367987036705007
epoch kld 95.5141518592834
epoch cpc 0.11690673390403389
epoch align 3.987957359105349
EPOCH 7
epoch mse 2.1204962849617015
epoch kld 101.11894369125365
epoch cpc 0.11568514769896864
epoch align 3.356449429690838
EPO

KeyboardInterrupt: 

In [12]:
generated = model.p2p_generate(x, 20, 19)
generated = torch.stack(generated, axis=1).cpu()
generated = torch.moveaxis(generated, 2, -1)

In [13]:
from ganime.visualization.videos import display_videos
display_videos(generated, 1, 3)