In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
from torchsummary import summary

In [3]:
import numpy as np
import matplotlib.pyplot as plt

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [5]:
## computes output shape for both convolutions and pooling layers
def output_shape(in_dim,stride,padding,kernel,dilation=1):
    out_dim = np.floor((in_dim + 2*padding - dilation*(kernel-1)-1)/stride+1).astype(int)
    return out_dim

def output_shape_transpose(in_dim,stride,padding,kernel,output_padding, dilation=1):
    out_dim = (in_dim-1)*stride-2*padding+dilation*(kernel-1)+output_padding+1
    return out_dim

In [6]:
output_shape(np.arange(5),stride=1,padding=0,kernel=1)

array([0, 1, 2, 3, 4])

In [7]:
def get_dilation(out_dim,in_dim,stride,padding,kernel,output_padding):
    dilation = np.floor((out_dim-(in_dim-1)*stride+2*padding-output_padding-1)/(kernel-1))
    new_dim  = output_shape_transpose(in_dim,stride,padding,kernel,output_padding, dilation)
    if new_dim == out_dim:
        pass
    else:
        output_padding = (out_dim-new_dim)
    return dilation, output_padding

def get_output_padding(in_dim,out_dim, stride,padding,kernel,dilation=1):
    return out_dim-(in_dim-1)*stride+2*padding-dilation*(kernel-1)-1

In [8]:
class Reshape(nn.Module):
    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

In [9]:
class FCEncoder(nn.Module):
    def __init__(self, params, nparams):
        super(FCEncoder, self).__init__()
        if params['dim'] == '1D':
            self.N    = 1
        elif params['dim'] == '2D':
            self.N    = 2
        else:
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")
            
            
        if nparams['spec_norm']:
            spec_norm = nn.utils.spectral_norm
        else:
            spec_norm = nn.Identity()
            
        self.model = nn.ModuleList()
    
        self.model.append(nn.Flatten())
        
        current_dim = params['input_dim']**self.N*params['input_c']
        
        for ii in range(nparams['n_layers']):
            
            lin = nn.Linear(current_dim, nparams['out_sizes'][ii])
            self.model.append(spec_norm(lin))
            
            current_dim      =  nparams['out_sizes'][ii]
            
            if nparams['layer_norm'][ii]:
                norm = nn.LayerNorm(current_dim,elementwise_affine=nparams['affine'])
                self.model.append(norm)
            
            gate = getattr(nn, nparams['activations'][ii])()
            self.model.append(gate)
            
            dropout = nn.Dropout(nparams['dropout_rate'][ii])
            self.model.append(dropout)
        
        lin = nn.Linear(current_dim,params['latent_dim'])
        self.model.append(spec_norm(lin))
        
    def forward(self, x):
        for i, l in enumerate(self.model):
            print(x.shape)
            x = l(x)
            print(x.shape)
        return x

In [10]:
class ConvEncoder(nn.Module):
    def __init__(self, params, nparams):
        super(ConvEncoder, self).__init__()
        if params['dim'] == '1D':
            self.conv = nn.Conv1d
            self.pool = nn.AdaptiveMaxPool1d#nn.MaxPool1d
            self.N    = 1
        elif params['dim'] == '2D':
            self.conv = nn.Conv2d
            self.pool = nn.AdaptiveMaxPool2d
            self.N    = 2
        else:
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")
            
        if nparams['spec_norm']:
            spec_norm = nn.utils.spectral_norm
        else:
            spec_norm = nn.Identity()
            
        self.model = nn.ModuleList()
    
        current_channels   = params['input_c']
        current_dim        = params['input_dim']
        self.out_dims      = []
        
        for ii in range(nparams['n_layers']):
            
            conv = self.conv(current_channels, nparams['out_channels'][ii], nparams['kernel_sizes'][ii], nparams['strides'][ii], nparams['paddings'][ii])
            self.out_dims.append(current_dim)
            self.model.append(spec_norm(conv))
            
            current_channels =  nparams['out_channels'][ii]
            current_dim      =  output_shape(current_dim, nparams['strides'][ii], nparams['paddings'][ii],nparams['kernel_sizes'][ii])
            
            if nparams['layer_norm'][ii]:
                norm = nn.LayerNorm([current_channels]+[current_dim]*self.N,elementwise_affine=nparams['affine'])
                self.model.append(norm)
                
            gate = getattr(nn, nparams['activations'][ii])()
            self.model.append(gate)
            
            pool = self.pool([current_dim//nparams['scale_facs'][ii]]*self.N)
            self.model.append(pool)
            
            current_dim = current_dim//nparams['scale_facs'][ii]

        self.final_dim = current_dim
        self.final_c   = current_channels
        
        self.model.append(nn.Flatten())
        current_shape = current_channels*current_dim**self.N
        linear        = nn.Linear(current_shape,params['latent_dim'])
        self.model.append(spec_norm(linear))
            
    def forward(self, x):
        for i, l in enumerate(self.model):
            print(x.shape)
            x = l(x)
            print(x.shape)
        return x

In [11]:
class ConvDecoder(nn.Module):
    def __init__(self, params, nparams):
        super(ConvDecoder, self).__init__()
        
        if params['dim'] == '1D':
            self.conv = nn.ConvTranspose1d
            self.N    = 1
        elif params['dim'] == '2D':
            self.conv = nn.ConvTranspose2d
            self.N    = 2
        else:
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")
            
        if nparams['spec_norm']:
            spec_norm = nn.utils.spectral_norm
        else:
            spec_norm = nn.Identity()
        
        self.pool   = nn.Upsample
        
        self.model  = nn.ModuleList()
        
        final_shape = nparams['final_c']*nparams['final_dim']**self.N
    
        self.model.append(nn.Flatten())
        lin         = nn.Linear(params['latent_dim'],final_shape)
        self.model.append(spec_norm(lin))

        if params['dim'] == '1D':
            self.model.append(Reshape((-1, nparams['final_c'],nparams['final_dim'])))
        else:
            self.model.append(Reshape((-1, nparams['final_c'],nparams['final_dim'],nparams['final_dim'])))
                              
        current_dim      = nparams['final_dim']
        current_channels = nparams['final_c']
            
        for jj in range(1,nparams['n_layers']+1):
            ii = nparams['n_layers'] - jj 
            gate = getattr(nn, nparams['activations'][ii])()
            self.model.append(gate)
                  
            upsample    = nn.Upsample(scale_factor=nparams['scale_facs'][ii])
            self.model.append(upsample)
            current_dim = current_dim*nparams['scale_facs'][ii]
                              
            output_padding = get_output_padding(current_dim,nparams['out_dims'][ii],nparams['strides'][ii],nparams['paddings'][ii],nparams['kernel_sizes'][ii],dilation=1)
            conv           = self.conv(current_channels, nparams['out_channels'][ii], kernel_size=nparams['kernel_sizes'][ii], stride=nparams['strides'][ii], padding=nparams['paddings'][ii], output_padding=output_padding)
            self.model.append(spec_norm(conv))
            
            current_channels = nparams['out_channels'][ii]
            current_dim      = output_shape_transpose(current_dim, stride=nparams['strides'][ii], padding=nparams['paddings'][ii],kernel=nparams['kernel_sizes'][ii],output_padding=output_padding)
                
            if nparams['layer_norm'][ii]:
                norm = nn.LayerNorm([current_channels]+[current_dim]*self.N,elementwise_affine=nparams['affine'])
                self.model.append(norm)    
                
        
        conv = self.conv(current_channels, 1, kernel_size=1, stride=1)
        self.model.append(spec_norm(conv))
        
    def forward(self, x):
        for i, l in enumerate(self.model):
            print(x.shape)
            x = l(x)
            print(x.shape)
        return x

In [12]:
class FCDecoder(nn.Module):
    def __init__(self, params, nparams):
        super(FCDecoder, self).__init__()
        if params['dim'] == '1D':
            self.N    = 1
        elif params['dim'] == '2D':
            self.N    = 2
        else:
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")
            
            
        if nparams['spec_norm']:
            spec_norm = nn.utils.spectral_norm
        else:
            spec_norm = nn.Identity()
            
        self.model = nn.ModuleList()
    
        self.model.append(nn.Flatten())
        
        current_dim = params['latent_dim']
        
        for jj in range(1,nparams['n_layers']+1):
            ii = nparams['n_layers'] - jj 
            
            lin = nn.Linear(current_dim, nparams['out_sizes'][ii])
            self.model.append(spec_norm(lin))
            
            current_dim      =  nparams['out_sizes'][ii]
            
            if nparams['layer_norm'][ii]:
                norm = nn.LayerNorm(current_dim,elementwise_affine=nparams['affine'])
                self.model.append(norm)
                
            gate = getattr(nn, nparams['activations'][ii])()
            self.model.append(gate)
            
            dropout = nn.Dropout(nparams['dropout_rate'][ii])
            self.model.append(dropout)
        
        lin = nn.Linear(current_dim,params['input_dim']**self.N*params['input_c'])
        self.model.append(spec_norm(lin))
        
        self.model.append(Reshape([-1]+[params['input_c']]+[params['input_dim']]*self.N))
        
    def forward(self, x):
        for i, l in enumerate(self.model):
            x = l(x)
        return x



In [13]:
class Autoencoder(nn.Module):
    def __init__(self, params, network_params_enc, network_params_dec):
        super(Autoencoder, self).__init__()
        
        if params['encoder_type'] == 'conv':
            self.encoder = ConvEncoder(params, network_params_enc)
            network_params_enc['out_dims']  = self.encoder.out_dims
            network_params_enc['final_dim'] = self.encoder.final_dim
            network_params_enc['final_c']   = self.encoder.final_c
        elif params['encoder_type'] == 'fc':
            self.encoder = FCEncoder(params, network_params_enc)
        else:
            raise Exception('invalid encoder type')
            
        if params['decoder_type'] == 'conv':
            self.decoder = ConvDecoder(params, network_params_dec)
        elif params['decoder_type'] == 'fc':
            self.decoder = FCDecoder(params, network_params_dec)
        else:
            raise Exception('invalid decoder type')
        
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [27]:
n_layers     = 3
out_channels = [2,4,8]
kernel_sizes = [2,4,2]
### stride,padding,kernel; [1,0,1] is identity
#pooling_layers = [[1,0,1], [1,0,2]]
scale_facs   = [2,2,1] 
paddings     = [0,0,0]
strides      = [2,4,2]
layer_norm   = [True,True,True]
dropout_rate = [0.,0.,0.]
spec_norm    = True
dim          = '2D'
activations  = ['ReLU', 'ReLU','ReLU']
latent_dim   = 4
input_c      = 1 
input_dim    = 1000
encoder_type = 'conv'
decoder_type = 'conv'
affine       = False
out_sizes    = [256,128,64]

In [31]:
general_params      = {'input_c': input_c, 'input_dim': input_dim, 'latent_dim': latent_dim, 'encoder_type': encoder_type, 'decoder_type': decoder_type, 'dim': dim}
conv_network_params = {'n_layers': n_layers, 'out_channels': out_channels, 'kernel_sizes': kernel_sizes, 'scale_facs': scale_facs, 'paddings': paddings,\
                       'strides': strides,'activations': activations, 'spec_norm': spec_norm, 'dropout_rate':dropout_rate, 'layer_norm': layer_norm, 'affine': affine}
fc_network_params   = {'n_layers': n_layers, 'out_sizes': out_sizes,'activations': activations, 'spec_norm': spec_norm, 'dropout_rate':dropout_rate, \
                       'layer_norm': layer_norm, 'affine': affine}

In [32]:
AE = Autoencoder(general_params, conv_network_params, conv_network_params)

In [33]:
AE.to(device)

Autoencoder(
  (encoder): ConvEncoder(
    (model): ModuleList(
      (0): Conv2d(1, 2, kernel_size=(2, 2), stride=(2, 2))
      (1): LayerNorm((2, 500, 500), eps=1e-05, elementwise_affine=False)
      (2): ReLU()
      (3): AdaptiveMaxPool2d(output_size=[250, 250])
      (4): Conv2d(2, 4, kernel_size=(4, 4), stride=(4, 4))
      (5): LayerNorm((4, 62, 62), eps=1e-05, elementwise_affine=False)
      (6): ReLU()
      (7): AdaptiveMaxPool2d(output_size=[31, 31])
      (8): Conv2d(4, 8, kernel_size=(2, 2), stride=(2, 2))
      (9): LayerNorm((8, 15, 15), eps=1e-05, elementwise_affine=False)
      (10): ReLU()
      (11): AdaptiveMaxPool2d(output_size=[15, 15])
      (12): Flatten(start_dim=1, end_dim=-1)
      (13): Linear(in_features=1800, out_features=4, bias=True)
    )
  )
  (decoder): ConvDecoder(
    (model): ModuleList(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=4, out_features=1800, bias=True)
      (2): Reshape()
      (3): ReLU()
      (4): Upsamp

In [34]:
data =  torch.randn(1,1,1000,1000).to('cuda')


res = AE.forward(data)

torch.Size([1, 1, 1000, 1000])
torch.Size([1, 2, 500, 500])
torch.Size([1, 2, 500, 500])
torch.Size([1, 2, 500, 500])
torch.Size([1, 2, 500, 500])
torch.Size([1, 2, 500, 500])
torch.Size([1, 2, 500, 500])
torch.Size([1, 2, 250, 250])
torch.Size([1, 2, 250, 250])
torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 62, 62])
torch.Size([1, 4, 31, 31])
torch.Size([1, 4, 31, 31])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 1800])
torch.Size([1, 1800])
torch.Size([1, 4])
torch.Size([1, 4])
torch.Size([1, 4])
torch.Size([1, 4])
torch.Size([1, 1800])
torch.Size([1, 1800])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.Size([1, 8, 15, 15])
torch.

In [35]:
summary(AE, data.shape[1::])

torch.Size([2, 1, 1000, 1000])
torch.Size([2, 2, 500, 500])
torch.Size([2, 2, 500, 500])
torch.Size([2, 2, 500, 500])
torch.Size([2, 2, 500, 500])
torch.Size([2, 2, 500, 500])
torch.Size([2, 2, 500, 500])
torch.Size([2, 2, 250, 250])
torch.Size([2, 2, 250, 250])
torch.Size([2, 4, 62, 62])
torch.Size([2, 4, 62, 62])
torch.Size([2, 4, 62, 62])
torch.Size([2, 4, 62, 62])
torch.Size([2, 4, 62, 62])
torch.Size([2, 4, 62, 62])
torch.Size([2, 4, 31, 31])
torch.Size([2, 4, 31, 31])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 1800])
torch.Size([2, 1800])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 1800])
torch.Size([2, 1800])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.Size([2, 8, 15, 15])
torch.

In [None]:
res.shape