In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

In [2]:
from torchsummary import summary

In [3]:
import numpy as np

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [5]:
## computes output shape for both convolutions and pooling layers
def output_shape(in_dim,stride,padding,kernel,dilation=1):
    out_dim = np.floor((in_dim + 2*padding - dilation*(kernel-1)-1)/stride+1).astype(int)
    return out_dim

def output_shape_transpose(in_dim,stride,padding,kernel,output_padding, dilation=1):
    out_dim = (in_dim-1)*stride-2*padding+dilation*(kernel-1)+output_padding+1
    return out_dim

In [6]:
output_shape(np.arange(5),stride=1,padding=0,kernel=1)

array([0, 1, 2, 3, 4])

In [7]:
def get_dilation(out_dim,in_dim,stride,padding,kernel,output_padding):
    dilation = np.floor((out_dim-(in_dim-1)*stride+2*padding-output_padding-1)/(kernel-1))
    new_dim  = output_shape_transpose(in_dim,stride,padding,kernel,output_padding, dilation)
    if new_dim == out_dim:
        pass
    else:
        output_padding = (out_dim-new_dim)
    return dilation, output_padding

def get_output_padding(in_dim,out_dim, stride,padding,kernel,dilation=1):
    return out_dim-(in_dim-1)*stride+2*padding-dilation*(kernel-1)-1

In [8]:
class Reshape(nn.Module):
    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

In [125]:
class Encoder(nn.Module):
    def __init__(self, n_layers, out_channels, kernel_sizes, scale_facs, paddings, strides, dim, activations, input_c, input_dim, latent_dim):
        super(Encoder, self).__init__()
        if dim == '1D':
            self.conv = nn.Conv1d
            self.pool = nn.AdaptiveMaxPool1d#nn.MaxPool1d
            self.N    = 1
        elif dim == '2D':
            self.conv = nn.Conv2d
            self.pool = nn.AdaptiveMaxPool2d
            self.N    = 2
        else:
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")
            
        self.model = nn.ModuleList()
    
        current_channels   = input_c
        current_dim        = input_dim
        self.out_dims      = []
        
        for ii in range(n_layers):
            
            conv = self.conv(current_channels, out_channels[ii], kernel_sizes[ii], strides[ii], paddings[ii])
            self.out_dims.append(current_dim)
            self.model.append(conv)
            
            current_channels =  out_channels[ii]
            current_dim      =  output_shape(current_dim, strides[ii], paddings[ii],kernel_sizes[ii])
            
            gate = getattr(nn, activations[ii])()
            self.model.append(gate)
            
            pool = self.pool([current_dim//scale_facs[ii]]*self.N)
            self.model.append(pool)
            
            current_dim = current_dim//scale_facs[ii]

        self.final_dim = current_dim
        self.final_c   = current_channels
        
        self.model.append(nn.Flatten())
        current_shape = current_channels*current_dim**self.N
        self.model.append(nn.Linear(current_shape,latent_dim))

            
    def forward(self, x):
        for i, l in enumerate(self.model):
            print(l, x.shape)
            x = l(x)
            print(x.shape)
        return x

In [241]:
class Decoder(nn.Module):
    def __init__(self, n_layers, out_channels, kernel_sizes, scale_facs, paddings, strides, dim, activations, latent_dim, final_dim, final_c, out_dims):
        super(Decoder, self).__init__()
        
        if dim == '1D':
            self.conv = nn.ConvTranspose1d
            self.N    = 1
        elif dim == '2D':
            self.conv = nn.ConvTranspose2d
            self.N    = 2
        else:
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")
        
        self.pool  = nn.Upsample
        
        self.model = nn.ModuleList()
        
        final_shape = final_c*final_dim**self.N
    
        self.model.append(nn.Flatten())
        self.model.append(nn.Linear(latent_dim,final_shape))

        if dim == '1D':
            self.model.append(Reshape((-1, final_c,final_dim)))
        else:
            self.model.append(Reshape((-1, final_c,final_dim,final_dim)))
                              
        current_dim      = final_dim
        current_channels = final_c
            
        for jj in range(1,n_layers+1):
            ii = n_layers - jj 
            gate = getattr(nn, activations[ii])()
            self.model.append(gate)
                  
            upsample    = nn.Upsample(scale_factor=scale_facs[ii])
            self.model.append(upsample)
            current_dim = current_dim*scale_facs[ii]
                              
            output_padding = get_output_padding(current_dim,out_dims[ii],strides[ii],paddings[ii],kernel_sizes[ii],dilation=1)
            conv           = self.conv(current_channels, out_channels[ii], kernel_size=kernel_sizes[ii], stride=strides[ii], padding=paddings[ii], output_padding=output_padding)
            self.model.append(conv)
            current_channels = out_channels[ii]
            current_dim      = output_shape_transpose(current_dim, stride=strides[ii], padding=paddings[ii],kernel=kernel_sizes[ii],output_padding=output_padding)
                
        conv = self.conv(current_channels, 1, kernel_size=1, stride=1)
        self.model.append(conv)
        
    def forward(self, x):
        print(x.shape)
        for i, l in enumerate(self.model):
            print(x.shape, l)
            x = l(x)
            print(x.shape)
        return x

In [231]:
class Autoencoder(nn.Module):
    def __init__(self, n_layers, out_channels, kernel_sizes, scale_facs, paddings, strides, dim, activations, input_c, input_dim, latent_dim):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(n_layers, out_channels, kernel_sizes, scale_facs, paddings, strides, dim, activations, input_c, input_dim, latent_dim)
        self.decoder = Decoder(n_layers, out_channels, kernel_sizes, scale_facs, paddings, strides, dim, activations, latent_dim, self.encoder.final_dim, self.encoder.final_c, self.encoder.out_dims)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [232]:
n_layers     = 2
out_channels = [2,4]
kernel_sizes = [2,4]
scale_facs   = [1,4]
### stride,padding,kernel; [1,0,1] is identity
#pooling_layers = [[1,0,1], [1,0,2]]
scale_facs   = [2,1] 
paddings     = [0,0]
strides      = [2,4]
dim          = '2D'
activations  = ['ReLU', 'ReLU']
latent_dim   = 4
input_c      = 1 
input_dim    = 1000

In [233]:
AE = Autoencoder(n_layers, out_channels, kernel_sizes, scale_facs, paddings, strides, dim, activations, input_c, input_dim, latent_dim)

In [234]:
Enc = Encoder(n_layers, out_channels, kernel_sizes, scale_facs, paddings, strides, dim, activations, input_c, input_dim, latent_dim)

In [235]:
Enc.to(device)

Encoder(
  (model): ModuleList(
    (0): Conv2d(1, 2, kernel_size=(2, 2), stride=(2, 2))
    (1): ReLU()
    (2): AdaptiveMaxPool2d(output_size=[250, 250])
    (3): Conv2d(2, 4, kernel_size=(4, 4), stride=(4, 4))
    (4): ReLU()
    (5): AdaptiveMaxPool2d(output_size=[62, 62])
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=15376, out_features=4, bias=True)
  )
)

In [236]:
data =  torch.randn(1,1,1000,1000).to('cuda')

In [237]:


res = Enc.forward(data)

Conv2d(1, 2, kernel_size=(2, 2), stride=(2, 2)) torch.Size([1, 1, 1000, 1000])
torch.Size([1, 2, 500, 500])
ReLU() torch.Size([1, 2, 500, 500])
torch.Size([1, 2, 500, 500])
AdaptiveMaxPool2d(output_size=[250, 250]) torch.Size([1, 2, 500, 500])
torch.Size([1, 2, 250, 250])
Conv2d(2, 4, kernel_size=(4, 4), stride=(4, 4)) torch.Size([1, 2, 250, 250])
torch.Size([1, 4, 62, 62])
ReLU() torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62])
AdaptiveMaxPool2d(output_size=[62, 62]) torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62])
Flatten(start_dim=1, end_dim=-1) torch.Size([1, 4, 62, 62])
torch.Size([1, 15376])
Linear(in_features=15376, out_features=4, bias=True) torch.Size([1, 15376])
torch.Size([1, 4])


In [238]:
Dec =  Decoder(n_layers, out_channels, kernel_sizes, scale_facs, paddings, strides, dim, activations, latent_dim, Enc.final_dim, Enc.final_c, Enc.out_dims)

In [239]:
Dec.to(device)

Decoder(
  (model): ModuleList(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=4, out_features=15376, bias=True)
    (2): Reshape()
    (3): ReLU()
    (4): Upsample(scale_factor=1.0, mode=nearest)
    (5): ConvTranspose2d(4, 4, kernel_size=(4, 4), stride=(4, 4), output_padding=(2, 2))
    (6): ReLU()
    (7): Upsample(scale_factor=2.0, mode=nearest)
    (8): ConvTranspose2d(4, 2, kernel_size=(2, 2), stride=(2, 2))
    (9): ConvTranspose2d(2, 1, kernel_size=(2, 2), stride=(0, 0))
  )
)

In [240]:
Dec.forward(res)

torch.Size([1, 4])
torch.Size([1, 4]) Flatten(start_dim=1, end_dim=-1)
torch.Size([1, 4])
torch.Size([1, 4]) Linear(in_features=4, out_features=15376, bias=True)
torch.Size([1, 15376])
torch.Size([1, 15376]) Reshape()
torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62]) ReLU()
torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62]) Upsample(scale_factor=1.0, mode=nearest)
torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62]) ConvTranspose2d(4, 4, kernel_size=(4, 4), stride=(4, 4), output_padding=(2, 2))
torch.Size([1, 4, 250, 250])
torch.Size([1, 4, 250, 250]) ReLU()
torch.Size([1, 4, 250, 250])
torch.Size([1, 4, 250, 250]) Upsample(scale_factor=2.0, mode=nearest)
torch.Size([1, 4, 500, 500])
torch.Size([1, 4, 500, 500]) ConvTranspose2d(4, 2, kernel_size=(2, 2), stride=(2, 2))
torch.Size([1, 2, 1000, 1000])
torch.Size([1, 2, 1000, 1000]) ConvTranspose2d(2, 1, kernel_size=(2, 2), stride=(0, 0))


RuntimeError: non-positive stride is not supported

In [217]:
summary(Dec, (4,1))


torch.Size([2, 4, 1])
torch.Size([2, 4, 1]) Flatten(start_dim=1, end_dim=-1)
torch.Size([2, 4])
torch.Size([2, 4]) Linear(in_features=4, out_features=15376, bias=True)
torch.Size([2, 15376])
torch.Size([2, 15376]) Reshape()
torch.Size([2, 4, 62, 62])
torch.Size([2, 4, 62, 62]) ReLU()
torch.Size([2, 4, 62, 62])
torch.Size([2, 4, 62, 62]) Upsample(scale_factor=1.0, mode=nearest)
torch.Size([2, 4, 62, 62])
torch.Size([2, 4, 62, 62]) ConvTranspose2d(4, 4, kernel_size=(4, 4), stride=(4, 4), output_padding=(2, 2))
torch.Size([2, 4, 250, 250])
torch.Size([2, 4, 250, 250]) ReLU()
torch.Size([2, 4, 250, 250])
torch.Size([2, 4, 250, 250]) Upsample(scale_factor=2.0, mode=nearest)
torch.Size([2, 4, 500, 500])
torch.Size([2, 4, 500, 500]) ConvTranspose2d(4, 2, kernel_size=(2, 2), stride=(2, 2))
torch.Size([2, 2, 1000, 1000])
torch.Size([2, 2, 1000, 1000]) ConvTranspose2d(2, 1, kernel_size=(1, 1), stride=(1, 1))
torch.Size([2, 1, 1000, 1000])
---------------------------------------------------------