## Definitions
   * $\mu_t$ is the next frame predicted by an autoregressive model, which carries a hidden state $h$ across time steps.
   * This prediction is corrected by $y_0^t$, an additive residual: instead of the frame itself, we predict $y_0^t = (x_t - \mu_t)/\sigma$, where $\sigma$ is a scaling hyperparameter.
   * At sampling time, the diffusion model generates $y_0^t$, and the next frame is recovered as $x_t = \mu_t + \sigma\, y_0^t$, combining the autoregressive $\mu_t$ with the diffusion model's residual.
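
A minimal sketch of the residual parameterization, assuming frames are PyTorch tensors and `sigma` is a plain float. The helpers `to_residual`, `from_residual` and `noise_residual` are hypothetical names. Encoding and decoding are exact inverses, so the diffusion model only ever has to produce the scaled residual. `noise_residual` shows one standard way (DDPM-style forward noising with an assumed linear beta schedule) to corrupt the residual for training; the schedule and step count are assumptions, not something fixed by the notes above.

```python
import torch

def to_residual(x_t: torch.Tensor, mu_t: torch.Tensor, sigma: float) -> torch.Tensor:
    """Scaled residual y_0 = (x_t - mu_t) / sigma that the diffusion model learns."""
    return (x_t - mu_t) / sigma

def from_residual(y0: torch.Tensor, mu_t: torch.Tensor, sigma: float) -> torch.Tensor:
    """Invert the residual: x_t = mu_t + sigma * y_0."""
    return mu_t + sigma * y0

def noise_residual(y0: torch.Tensor, k: int, alpha_bar: torch.Tensor):
    """DDPM-style forward noising of the residual at diffusion step k (assumed schedule)."""
    eps = torch.randn_like(y0)
    y_k = alpha_bar[k].sqrt() * y0 + (1 - alpha_bar[k]).sqrt() * eps
    return y_k, eps  # a denoiser would be trained to recover eps from (y_k, k, context)

# Hypothetical linear beta schedule with 1000 diffusion steps.
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

x_t = torch.rand(2, 1, 32, 32)   # true next frame (batch of 2, 32x32 frames)
mu_t = torch.rand(2, 1, 32, 32)  # stand-in for the autoregressive prediction
sigma = 0.5                      # assumed value of the scaling hyperparameter

y0 = to_residual(x_t, mu_t, sigma)
y_k, eps = noise_residual(y0, k=500, alpha_bar=alpha_bar)
print(torch.allclose(from_residual(y0, mu_t, sigma), x_t, atol=1e-6))  # True
```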

# Takeaways
   * The forward (deterministic) prediction is just an autoregressive RNN.
   * Most of the training happens in the diffusion model, which learns to correct the errors of the autoregressive prediction.
   * At generation time, we add these sampled corrections to the autoregressive predictions, so the sequence model is corrected step by step as we sample forward in time.
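
The generation procedure above can be sketched as a loop. Everything here is a stand-in: `ToyAR` is a hypothetical one-layer recurrent predictor over flattened frames (not the CRNN built below), and `sample_residual` is a placeholder for the diffusion sampler that simply returns Gaussian noise so the loop runs end to end. The point is the data flow: predict `mu_t`, sample a residual, form `x_t = mu_t + sigma * y_0`, and (as an assumption about how the correction propagates) feed `x_t` back in as the next input.

```python
import torch
import torch.nn as nn

class ToyAR(nn.Module):
    """Hypothetical autoregressive predictor: a GRU cell over flattened frames."""
    def __init__(self, frame_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.cell = nn.GRUCell(frame_dim, hidden_dim)
        self.readout = nn.Linear(hidden_dim, frame_dim)

    def forward(self, x_prev, h):
        h = self.cell(x_prev, h)   # update hidden state h with the previous frame
        return self.readout(h), h  # mu_t and the new hidden state

def sample_residual(mu_t, h):
    """Placeholder for the diffusion sampler (would denoise from noise, conditioned on context)."""
    return torch.randn_like(mu_t)

@torch.no_grad()
def rollout(model, x0, n_steps, sigma=0.5):
    frames, x_prev, h = [], x0, None
    for _ in range(n_steps):
        mu_t, h = model(x_prev, h)     # deterministic autoregressive prediction
        y0 = sample_residual(mu_t, h)  # stochastic correction from the diffusion model
        x_t = mu_t + sigma * y0        # corrected frame
        frames.append(x_t)
        x_prev = x_t                   # feed the corrected frame back in
    return torch.stack(frames, dim=1)  # (batch, n_steps, frame_dim)

model = ToyAR(frame_dim=32 * 32)
video = rollout(model, torch.zeros(2, 32 * 32), n_steps=5)
print(video.shape)  # torch.Size([2, 5, 1024])
```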

```python
# Project-local modules: building blocks and the dataset.
import ResBlock_ConvGRU_Downsample
import UpSampleBlock
import MovingMNIST

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchsummary import summary
```


# CRNN

Our first step is to build a CRNN (convolutional recurrent neural network), which lets us generate the data autoregressively. This should capture the general motion of the video well enough, but on its own it will not capture much beyond that coarse motion.
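
The recurrent part of a CRNN is commonly a convolutional GRU: the usual GRU gates, computed with convolutions so the hidden state keeps its spatial layout. The cell below is a generic sketch of that idea. The internals of `ResBlock_ConvGRU_Downsample` are not shown in this notebook, so `ConvGRUCell` is a hypothetical illustration, not necessarily how that module is implemented.

```python
import torch
import torch.nn as nn

class ConvGRUCell(nn.Module):
    """Generic convolutional GRU cell (illustrative; not the notebook's module)."""
    def __init__(self, in_ch: int, hidden_ch: int, kernel_size: int = 3):
        super().__init__()
        pad = kernel_size // 2
        self.hidden_ch = hidden_ch
        # One conv produces both the update (z) and reset (r) gates.
        self.gates = nn.Conv2d(in_ch + hidden_ch, 2 * hidden_ch, kernel_size, padding=pad)
        # Candidate state, computed from the input and the reset-scaled hidden state.
        self.candidate = nn.Conv2d(in_ch + hidden_ch, hidden_ch, kernel_size, padding=pad)

    def forward(self, x, h=None):
        if h is None:
            h = x.new_zeros((x.size(0), self.hidden_ch, x.size(2), x.size(3)))
        z, r = torch.sigmoid(self.gates(torch.cat([x, h], dim=1))).chunk(2, dim=1)
        h_tilde = torch.tanh(self.candidate(torch.cat([x, r * h], dim=1)))
        return (1 - z) * h + z * h_tilde  # new hidden state, same H x W as the input

cell = ConvGRUCell(in_ch=1, hidden_ch=8)
h = None
for t in range(3):  # run over a short sequence of 32x32 frames
    h = cell(torch.rand(2, 1, 32, 32), h)
print(h.shape)  # torch.Size([2, 8, 32, 32])
```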

```python
import torch.nn as nn


class MyModel(nn.Module):
    """U-Net-style CRNN: four ConvGRU downsampling blocks, a transposed-conv bottleneck,
    and four upsampling blocks that consume the matching skip connections."""

    def __init__(self):
        super().__init__()
        Down = ResBlock_ConvGRU_Downsample.ResBlock_ConvGRU_Downsample
        Up = UpSampleBlock.ResBlockUpsample

        # Encoder: channels 1 -> 2 -> 4 -> 8 -> 16 (2**0 ... 2**4).
        self.DownSample1 = Down(1, 2)
        self.DownSample2 = Down(2, 4)
        self.DownSample3 = Down(4, 8)
        self.DownSample4 = Down(8, 16)

        # Bottleneck: doubles the spatial size of the deepest feature map.
        self.bottom = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2)

        # Decoder: each block maps 2**(i+1) channels down to 2**i.
        self.UpSample1 = Up(2, 1)
        self.UpSample2 = Up(4, 2)
        self.UpSample3 = Up(8, 4)
        self.UpSample4 = Up(16, 8)

    def forward(self, x, hidden_state=None):
        # NOTE: hidden_state is accepted (summary() passes a second input) but is
        # not yet forwarded to the ConvGRU blocks.
        residual = [None] * 4

        # Each downsampling block returns (pooled output, skip connection).
        out, residual[0] = self.DownSample1(x)
        out, residual[1] = self.DownSample2(out)
        out, residual[2] = self.DownSample3(out)
        out, residual[3] = self.DownSample4(out)

        out = self.bottom(out)

        # Decode, consuming the skips from deepest to shallowest.
        out = self.UpSample4(out, residual[3])
        out = self.UpSample3(out, residual[2])
        out = self.UpSample2(out, residual[1])
        out = self.UpSample1(out, residual[0])

        return out
```
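
Because the two block modules are imported from local files, the cell above cannot be run on its own. The sketch below traces the encoder shapes with a hypothetical stand-in, `DownStub`, which assumes each downsampling block returns `(pooled, skip)` with the skip taken before a 2x2 max-pool; that matches the `ResBlock_ConvGRU_Downsample-9` row in the summary output further down. It shows the channel schedule 1 → 2 → 4 → 8 → 16, the spatial halving 32 → 16 → 8 → 4 → 2, and that `bottom` brings the 2x2 bottleneck back to 4x4, the same shape as the deepest skip that `UpSample4` consumes.

```python
import torch
import torch.nn as nn

class DownStub(nn.Module):
    """Hypothetical stand-in for ResBlock_ConvGRU_Downsample: conv block, then 2x2 max-pool."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        skip = self.block(x)          # full-resolution features kept for the skip connection
        return self.pool(skip), skip  # (pooled, skip)

downs = nn.ModuleList([DownStub(1, 2), DownStub(2, 4), DownStub(4, 8), DownStub(8, 16)])
bottom = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2)

out, skips = torch.rand(2, 1, 32, 32), []
for down in downs:
    out, skip = down(out)
    skips.append(skip)
    print(tuple(out.shape), tuple(skip.shape))
out = bottom(out)
print(tuple(out.shape))  # (2, 16, 4, 4)
```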
 

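```python
# `^` is bitwise XOR in Python, so this evaluates to 6, not 16.
# Use 2**4 (or pow(2, 4)) for exponentiation.
2^4
```

```
6
```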

```python
# Two inputs: the frame x and the hidden_state argument, each 1x32x32.
summary(MyModel(), [(1, 32, 32), (1, 32, 32)])
```

```
torch.Size([2, 8, 4, 4]) torch.Size([2, 16, 4, 4])
torch.Size([2, 4, 8, 8]) torch.Size([2, 8, 8, 8])
torch.Size([2, 2, 16, 16]) torch.Size([2, 4, 16, 16])
torch.Size([2, 1, 32, 32]) torch.Size([2, 2, 32, 32])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 2, 32, 32]              20
       BatchNorm2d-2            [-1, 2, 32, 32]               4
              ReLU-3            [-1, 2, 32, 32]               0
            Conv2d-4            [-1, 2, 32, 32]              38
       BatchNorm2d-5            [-1, 2, 32, 32]               4
              ReLU-6            [-1, 3, 32, 32]               0
            Conv2d-7            [-1, 2, 32, 32]              56
         MaxPool2d-8            [-1, 2, 16, 16]               0
ResBlock_ConvGRU_Downsample-9  [[-1, 2, 16, 16], [-1, 2, 32, 32]]               0
           Conv2d-10            [-1, 4, 16, 16]              76
```
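
The parameter counts in the summary can be checked by hand. A `Conv2d` with bias has `in_ch * out_ch * k * k + out_ch` parameters, and `BatchNorm2d` has 2 trainable parameters per channel (scale and shift; running statistics are buffers, not parameters). Assuming 3x3 kernels, which is inferred from the numbers since the block source is not shown, this reproduces the rows above. Note that 56 parameters for `Conv2d-7` implies that layer sees 3 input channels.

```python
import torch.nn as nn

def conv2d_params(in_ch: int, out_ch: int, k: int = 3) -> int:
    """Weights plus one bias per output channel."""
    return in_ch * out_ch * k * k + out_ch

def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# (in_ch, out_ch) pairs inferred from the summary rows; 3x3 kernels are an assumption.
for name, (cin, cout), expected in [
    ("Conv2d-1", (1, 2), 20),
    ("Conv2d-4", (2, 2), 38),
    ("Conv2d-7", (3, 2), 56),
    ("Conv2d-10", (2, 4), 76),
]:
    layer = nn.Conv2d(cin, cout, kernel_size=3, padding=1)
    assert conv2d_params(cin, cout) == n_params(layer) == expected, name

print(n_params(nn.BatchNorm2d(2)))  # 4, matching BatchNorm2d-2 and BatchNorm2d-5
```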
    