In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
from torchsummary import summary

In [49]:
import numpy as np
import matplotlib.pyplot as plt
from functools import partial

In [99]:
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [5]:
## computes output shape for both convolutions and pooling layers
def output_shape(in_dim,stride,padding,kernel,dilation=1):
    """Spatial output size of a convolution or pooling layer.

    Evaluates floor((in_dim + 2*padding - dilation*(kernel-1) - 1)/stride + 1).

    Args:
        in_dim: input size; a scalar or a NumPy array of sizes.
        stride, padding, kernel, dilation: layer hyper-parameters.

    Returns:
        The output size(s) as NumPy integer(s).
    """
    receptive_extent = dilation*(kernel-1) + 1
    padded_dim       = in_dim + 2*padding
    n_positions      = (padded_dim - receptive_extent)/stride + 1
    return np.floor(n_positions).astype(int)

def output_shape_transpose(in_dim,stride,padding,kernel,output_padding, dilation=1):
    """Spatial output size of a transposed convolution.

    Evaluates (in_dim-1)*stride - 2*padding + dilation*(kernel-1) + output_padding + 1.
    """
    strided_extent = (in_dim-1)*stride
    kernel_extent  = dilation*(kernel-1) + 1
    return strided_extent - 2*padding + kernel_extent + output_padding

In [6]:
# Sanity check: with stride 1, no padding and a kernel of 1 every size maps to itself.
output_shape(np.arange(5),stride=1,padding=0,kernel=1)

array([0, 1, 2, 3, 4])

In [130]:
def get_dilation(out_dim,in_dim,stride,padding,kernel,output_padding):
    """Pick a dilation (and matching output padding) for a transposed convolution.

    The dilation is the largest integer for which the transposed-convolution
    output, computed with the requested ``output_padding``, does not exceed
    ``out_dim``. The output padding is then recomputed so that the layer
    produces exactly ``out_dim``.

    Args:
        out_dim: desired output size (scalar).
        in_dim: input size (scalar).
        stride, padding, kernel: transposed-convolution hyper-parameters.
        output_padding: output padding assumed while solving for the dilation.

    Returns:
        Tuple ``(dilation, output_padding)``. The dilation is a Python int.
        No range check is applied to either value.
    """
    if kernel == 1:
        # A kernel of size 1 has no extent to dilate, and the general formula
        # would divide by zero; any dilation is equivalent, so use 1.
        dilation = 1
    else:
        dilation = int(np.floor((out_dim-(in_dim-1)*stride+2*padding-output_padding-1)/(kernel-1)))
    # Solve for the padding directly instead of using `out_dim - new_dim`,
    # which dropped the padding already contained in `new_dim`.
    output_padding = get_output_padding(in_dim,out_dim,stride,padding,kernel,dilation)
    return dilation, output_padding

def get_output_padding(in_dim,out_dim, stride,padding,kernel,dilation=1):
    """Output padding a transposed convolution needs to map ``in_dim`` to ``out_dim``.

    This is the transposed-convolution size formula solved for the output
    padding; the result is not checked for validity (it may be negative).
    """
    return out_dim-(in_dim-1)*stride+2*padding-dilation*(kernel-1)-1

In [131]:
class Reshape(nn.Module):
    """Module that reshapes its input to a fixed target shape.

    Args:
        shape: sequence of target sizes, unpacked into ``Tensor.reshape``
            (one entry may be -1).
    """
    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        # reshape returns a view whenever one is possible and otherwise copies,
        # so non-contiguous inputs no longer raise as they did with view().
        return x.reshape(*self.shape)

    def extra_repr(self):
        # Makes the target shape visible when the parent model is printed.
        return f'shape={tuple(self.shape)}'

In [132]:
class FCEncoder(nn.Module):
    """Fully connected encoder: flattens the input and maps it to ``params['latent_dim']`` features.

    Each of the ``nparams['n_layers']`` hidden layers is
    Linear -> (optional LayerNorm) -> activation -> Dropout; a final Linear
    layer produces the latent vector.

    Args:
        params: general settings; reads 'dim' ('1D' or '2D'), 'input_dim',
            'input_c' and 'latent_dim'.
        nparams: network settings; reads 'spec_norm', 'n_layers', 'out_sizes',
            'layer_norm', 'affine', 'activations' and 'dropout_rate'.

    Raises:
        Exception: if ``params['dim']`` is neither '1D' nor '2D'.
    """
    def __init__(self, params, nparams):
        super().__init__()

        if params['dim'] not in ('1D', '2D'):
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")
        self.N = 1 if params['dim'] == '1D' else 2

        # Calling an nn.Identity instance on a layer simply returns that layer.
        wrap = nn.utils.spectral_norm if nparams['spec_norm'] else nn.Identity()

        layers = [nn.Flatten()]
        width  = params['input_dim']**self.N*params['input_c']

        for idx in range(nparams['n_layers']):
            next_width = nparams['out_sizes'][idx]
            layers.append(wrap(nn.Linear(width, next_width)))
            width = next_width

            if nparams['layer_norm'][idx]:
                layers.append(nn.LayerNorm(width,elementwise_affine=nparams['affine']))

            layers.append(getattr(nn, nparams['activations'][idx])())
            layers.append(nn.Dropout(nparams['dropout_rate'][idx]))

        # projection onto the latent space
        layers.append(wrap(nn.Linear(width,params['latent_dim'])))

        self.model = nn.ModuleList(layers)

    def forward(self, x):
        """Apply all layers in order and return the latent representation."""
        for layer in self.model:
            x = layer(x)
        return x

In [222]:
class ConvEncoder(nn.Module):
    """Convolutional encoder mapping an input to ``params['latent_dim']`` features.

    Each of the ``nparams['n_layers']`` stages is
    convolution -> (optional LayerNorm) -> activation -> adaptive max pooling
    to ``size // scale_fac``; the result is flattened and projected with a
    Linear layer.

    Attributes set for use by a matching decoder:
        out_dims:  spatial size entering each convolution.
        final_dim: spatial size after the last pooling layer.
        final_c:   channel count after the last convolution.

    Args:
        params: general settings; reads 'dim' ('1D' or '2D'), 'input_c',
            'input_dim' and 'latent_dim'.
        nparams: network settings; reads 'spec_norm', 'n_layers',
            'out_channels', 'kernel_sizes', 'strides', 'paddings',
            'layer_norm', 'affine', 'activations' and 'scale_facs'.

    Raises:
        Exception: if ``params['dim']`` is neither '1D' nor '2D'.
    """
    def __init__(self, params, nparams):
        super().__init__()

        dim = params['dim']
        if dim == '1D':
            self.conv, self.pool, self.N = nn.Conv1d, nn.AdaptiveMaxPool1d, 1
        elif dim == '2D':
            self.conv, self.pool, self.N = nn.Conv2d, nn.AdaptiveMaxPool2d, 2
        else:
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")

        # Calling an nn.Identity instance on a layer simply returns that layer.
        wrap = nn.utils.spectral_norm if nparams['spec_norm'] else nn.Identity()

        self.model    = nn.ModuleList()
        self.out_dims = []

        channels = params['input_c']
        size     = params['input_dim']

        for idx in range(nparams['n_layers']):
            out_c   = nparams['out_channels'][idx]
            kernel  = nparams['kernel_sizes'][idx]
            stride  = nparams['strides'][idx]
            padding = nparams['paddings'][idx]

            # remember the size entering this convolution
            self.out_dims.append(size)
            self.model.append(wrap(self.conv(channels, out_c, kernel, stride, padding)))

            channels = out_c
            size     = output_shape(size, stride, padding, kernel)

            if nparams['layer_norm'][idx]:
                self.model.append(nn.LayerNorm([channels]+[size]*self.N,elementwise_affine=nparams['affine']))

            self.model.append(getattr(nn, nparams['activations'][idx])())

            # pool down by the integer scale factor of this stage
            size = size//nparams['scale_facs'][idx]
            self.model.append(self.pool([size]*self.N))

        self.final_dim = size
        self.final_c   = channels

        self.model.append(nn.Flatten())
        self.model.append(wrap(nn.Linear(channels*size**self.N,params['latent_dim'])))

    def forward(self, x):
        """Apply all layers in order and return the latent representation."""
        for layer in self.model:
            x = layer(x)
        return x

In [223]:
class ConvDecoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector back to data space.

    Layer stack: Flatten -> Linear(latent_dim, final_c*final_dim**N) -> Reshape,
    then for every layer index from last to first: activation -> Upsample ->
    transposed convolution -> (optional LayerNorm), and finally a
    kernel-size-1 transposed convolution producing ``params['input_c']``
    channels, followed by a Sigmoid when ``nparams['image_data']`` is true.

    Args:
        params: general settings; reads 'dim' ('1D' or '2D'), 'latent_dim'
            and 'input_c'.
        nparams: network settings; reads 'spec_norm', 'n_layers',
            'activations', 'scale_facs', 'strides', 'paddings',
            'kernel_sizes', 'out_channels', 'layer_norm', 'affine',
            'image_data' and the geometry keys 'final_c', 'final_dim' and
            'out_dims' (the target size of each transposed convolution).

    Raises:
        Exception: if ``params['dim']`` is neither '1D' nor '2D'.
        ValueError: if a layer would need an output padding that a transposed
            convolution cannot realise (negative, or not smaller than the stride).
    """
    def __init__(self, params, nparams):
        super(ConvDecoder, self).__init__()

        if params['dim'] == '1D':
            self.conv = nn.ConvTranspose1d
            self.N    = 1
        elif params['dim'] == '2D':
            self.conv = nn.ConvTranspose2d
            self.N    = 2
        else:
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")

        if nparams['spec_norm']:
            spec_norm = nn.utils.spectral_norm
        else:
            # Calling an nn.Identity instance on a layer simply returns that layer.
            spec_norm = nn.Identity()

        self.pool   = nn.Upsample

        self.model  = nn.ModuleList()

        final_c     = nparams['final_c']
        final_dim   = nparams['final_dim']
        final_shape = final_c*final_dim**self.N

        self.model.append(nn.Flatten())
        lin         = nn.Linear(params['latent_dim'],final_shape)
        self.model.append(spec_norm(lin))

        # (batch, channels, size) in 1D, (batch, channels, size, size) in 2D
        self.model.append(Reshape([-1, final_c] + [final_dim]*self.N))

        current_dim      = final_dim
        current_channels = final_c

        # Walk the layer settings from last to first.
        for ii in reversed(range(nparams['n_layers'])):
            stride  = nparams['strides'][ii]
            padding = nparams['paddings'][ii]
            kernel  = nparams['kernel_sizes'][ii]

            gate = getattr(nn, nparams['activations'][ii])()
            self.model.append(gate)

            upsample    = self.pool(scale_factor=nparams['scale_facs'][ii])
            self.model.append(upsample)
            current_dim = current_dim*nparams['scale_facs'][ii]

            # Padding that makes the transposed convolution hit out_dims[ii] exactly.
            output_padding = int(get_output_padding(current_dim,nparams['out_dims'][ii],stride,padding,kernel,dilation=1))
            # With dilation 1, PyTorch only accepts 0 <= output_padding < max(stride, 1);
            # fail here with a readable message instead of at the first forward pass.
            if not 0 <= output_padding < max(stride, 1):
                raise ValueError(
                    f"Layer {ii}: cannot map size {current_dim} to {nparams['out_dims'][ii]} with "
                    f"kernel={kernel}, stride={stride}, padding={padding} "
                    f"(would need output_padding={output_padding})."
                )

            conv = self.conv(current_channels, nparams['out_channels'][ii], kernel_size=kernel, stride=stride, padding=padding, output_padding=output_padding)
            self.model.append(spec_norm(conv))

            current_channels = nparams['out_channels'][ii]
            current_dim      = output_shape_transpose(current_dim, stride=stride, padding=padding,kernel=kernel,output_padding=output_padding)

            if nparams['layer_norm'][ii]:
                norm = nn.LayerNorm([current_channels]+[current_dim]*self.N,elementwise_affine=nparams['affine'])
                self.model.append(norm)

        # Project back to the number of input channels (previously hard-coded to 1,
        # which only matched single-channel data).
        conv = self.conv(current_channels, params['input_c'], kernel_size=1, stride=1)
        self.model.append(spec_norm(conv))

        if nparams['image_data']:
            # squash the output into [0, 1]
            self.model.append(nn.Sigmoid())

    def forward(self, x):
        """Apply all layers in order and return the reconstruction."""
        for layer in self.model:
            x = layer(x)
        return x

In [224]:
class FCDecoder(nn.Module):
    """Fully connected decoder mapping a latent vector back to the input shape.

    The hidden layers use the per-layer settings from last index to first; each
    is Linear -> (optional LayerNorm) -> activation -> Dropout. A final Linear
    layer produces ``input_dim**N * input_c`` features, optionally followed by
    a Sigmoid, and the result is reshaped to (-1, input_c, input_dim[, input_dim]).

    Args:
        params: general settings; reads 'dim' ('1D' or '2D'), 'latent_dim',
            'input_dim' and 'input_c'.
        nparams: network settings; reads 'spec_norm', 'n_layers', 'out_sizes',
            'layer_norm', 'affine', 'activations', 'dropout_rate' and 'image_data'.

    Raises:
        Exception: if ``params['dim']`` is neither '1D' nor '2D'.
    """
    def __init__(self, params, nparams):
        super().__init__()

        if params['dim'] not in ('1D', '2D'):
            raise Exception("Invalid data dimensionality (must be either 1D or 2D).")
        self.N = 1 if params['dim'] == '1D' else 2

        # Calling an nn.Identity instance on a layer simply returns that layer.
        wrap = nn.utils.spectral_norm if nparams['spec_norm'] else nn.Identity()

        layers = [nn.Flatten()]
        width  = params['latent_dim']

        # traverse the layer settings back to front
        for idx in reversed(range(nparams['n_layers'])):
            next_width = nparams['out_sizes'][idx]
            layers.append(wrap(nn.Linear(width, next_width)))
            width = next_width

            if nparams['layer_norm'][idx]:
                layers.append(nn.LayerNorm(width,elementwise_affine=nparams['affine']))

            layers.append(getattr(nn, nparams['activations'][idx])())
            layers.append(nn.Dropout(nparams['dropout_rate'][idx]))

        n_outputs = params['input_dim']**self.N*params['input_c']
        layers.append(wrap(nn.Linear(width,n_outputs)))

        if nparams['image_data']:
            layers.append(nn.Sigmoid())

        target_shape = [-1]+[params['input_c']]+[params['input_dim']]*self.N
        layers.append(Reshape(target_shape))

        self.model = nn.ModuleList(layers)

    def forward(self, x):
        """Apply all layers in order and return the reconstruction."""
        for layer in self.model:
            x = layer(x)
        return x



In [225]:
def get_data(data, loc, batchsize):
    """Build shuffled training and validation loaders for a torchvision dataset.

    Args:
        data: name of a dataset class in ``torchvision.datasets`` (e.g. 'MNIST').
        loc: root directory passed to the dataset; the data is downloaded
            there if necessary.
        batchsize: batch size used by both loaders.

    Returns:
        Tuple ``(train_dataloader, valid_dataloader)``; the training loader is
        built with ``train=True``, the validation loader with ``train=False``.
        Samples are converted with ``ToTensor()``.

    Raises:
        ValueError: if ``data`` is not a name found in ``torchvision.datasets``.
    """
    if data not in dir(datasets):
        # Previously this case fell through and failed later with an
        # UnboundLocalError on `training_data`.
        raise ValueError(f"Unknown dataset '{data}': expected the name of a torchvision.datasets class.")

    dataset = getattr(datasets,data)

    training_data = dataset(root=loc,train=True,download=True,transform=ToTensor())
    valid_data    = dataset(root=loc,train=False,download=True,transform=ToTensor())

    train_dataloader = DataLoader(training_data, batch_size=batchsize, shuffle=True)
    valid_dataloader = DataLoader(valid_data, batch_size=batchsize, shuffle=True)

    return train_dataloader, valid_dataloader


In [236]:
class Autoencoder(nn.Module):
    """Encoder/decoder pair bundled with its optimizer, LR scheduler, loss and data loaders.

    Args:
        params: general settings; reads 'encoder_type' and 'decoder_type'
            ('conv' or 'fc') and is forwarded to the encoder and decoder.
        dparams: data settings; reads 'dataset' and 'loc'.
        nparams_enc: network settings for the encoder. For a convolutional
            encoder the keys 'out_dims', 'final_dim' and 'final_c' are written
            into this dict.
        nparams_dec: network settings for the decoder. When both encoder and
            decoder are convolutional the same three geometry keys are written
            into this dict as well; with any other encoder a convolutional
            decoder expects them to be present already.
        tparams: training settings; reads 'optimizer', 'initial_lr',
            'scheduler', 'scheduler_params', 'criterion' and 'batchsize'.

    Raises:
        Exception: for an unknown encoder or decoder type.
    """
    def __init__(self, params, dparams, nparams_enc, nparams_dec, tparams):
        super(Autoencoder, self).__init__()

        # Kept on the instance so the update_* methods can rebuild the optimizer
        # and scheduler without relying on a global variable.
        self.tparams = tparams

        if params['encoder_type'] == 'conv':
            self.encoder = ConvEncoder(params, nparams_enc)
            geometry = {'out_dims':  self.encoder.out_dims,
                        'final_dim': self.encoder.final_dim,
                        'final_c':   self.encoder.final_c}
            nparams_enc.update(geometry)
            if params['decoder_type'] == 'conv':
                # The decoder reads the geometry from its own dict, which need
                # not be the same object as the encoder's.
                nparams_dec.update(geometry)
        elif params['encoder_type'] == 'fc':
            self.encoder = FCEncoder(params, nparams_enc)
        else:
            raise Exception('invalid encoder type')

        if params['decoder_type'] == 'conv':
            self.decoder = ConvDecoder(params, nparams_dec)
        elif params['decoder_type'] == 'fc':
            self.decoder = FCDecoder(params, nparams_dec)
        else:
            raise Exception('invalid decoder type')

        self.optimizer = self._make_optimizer(tparams['optimizer'], tparams['initial_lr'])
        self.scheduler = self._make_scheduler(tparams['scheduler'], tparams['scheduler_params'])

        self.criterion = getattr(nn, tparams['criterion'])()

        self.train_loader, self.valid_loader = get_data(dparams['dataset'],dparams['loc'],tparams['batchsize'])


    def _make_optimizer(self, name, lr):
        """Instantiate the ``torch.optim`` optimizer ``name`` over this model's parameters."""
        return getattr(optim, name)(self.parameters(), lr)


    def _make_scheduler(self, name, scheduler_params):
        """Instantiate the ``torch.optim.lr_scheduler`` class ``name`` for the current optimizer."""
        return getattr(torch.optim.lr_scheduler, name)(self.optimizer, **scheduler_params)


    def forward(self, x):
        """Encode ``x`` and decode the result."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x


    def update_lr(self,lr):
        """Recreate the configured optimizer with learning rate ``lr`` and the configured scheduler.

        Both are rebuilt from the training settings given at construction.
        Always returns True.
        """
        self.optimizer = self._make_optimizer(self.tparams['optimizer'], lr)
        self.scheduler = self._make_scheduler(self.tparams['scheduler'], self.tparams['scheduler_params'])

        return True


    def update_scheduler(self,scheduler, scheduler_params):
        """Replace the scheduler with ``scheduler`` (a ``torch.optim.lr_scheduler`` class name).

        ``scheduler_params`` are passed to it as keyword arguments.
        Always returns True.
        """
        self.scheduler = self._make_scheduler(scheduler, scheduler_params)

        return True


    def update_optimizer(self,optimizer):
        """Replace the optimizer with ``optimizer`` (a ``torch.optim`` class name).

        The new optimizer starts at the configured initial learning rate, and
        the configured scheduler is recreated for it. Always returns True.
        """
        self.optimizer = self._make_optimizer(optimizer, self.tparams['initial_lr'])
        self.scheduler = self._make_scheduler(self.tparams['scheduler'], self.tparams['scheduler_params'])

        return True

    def train(self, nepochs=True):
        """Train for ``nepochs`` epochs, or switch train/eval mode when given a bool.

        This method shadows ``nn.Module.train(mode)``, which ``eval()`` relies
        on, so boolean arguments are forwarded to the base class.

        Args:
            nepochs: number of epochs to train for. A bool is treated as the
                ``mode`` argument of ``nn.Module.train``.

        Returns:
            A list with the detached loss tensor of every mini-batch, or the
            module itself when called with a bool.
        """
        if isinstance(nepochs, bool):
            return super(Autoencoder, self).train(nepochs)

        super(Autoencoder, self).train(True)

        # Feed the data to wherever the model currently lives.
        model_device = next(self.parameters()).device

        losses  = []
        for epoch in range(nepochs):
            loss = float('nan')   # reported if the loader yields no batches
            for (data, _) in self.train_loader:
                data = data.to(model_device)
                recon = self(data)
                loss = self.criterion(recon, data)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # detach so the stored history does not keep autograd graphs alive
                losses.append(loss.detach())
            self.scheduler.step()
            # the reported value is the loss of the epoch's last mini-batch
            print(f'epoch: {epoch:d}, training loss: {float(loss):.4e}')
        return losses


In [246]:
# --- network architecture (all per-layer lists are indexed by layer) ---
n_layers     = 3          # number of hidden layers built by each encoder/decoder
out_channels = [8,8,8]    # output channels of each convolutional layer
kernel_sizes = [4,2,2]    # kernel size of each convolutional layer
# Resampling factor per layer: the conv encoder pools each layer's output to size//factor,
# the conv decoder upsamples by the same factor (1 leaves the size unchanged).
scale_facs   = [1,1,1] 
paddings     = [0,0,0]    # padding of each convolutional layer
strides      = [2,1,1]    # stride of each convolutional layer
layer_norm   = [True,True,True]   # whether a LayerNorm follows each layer
dropout_rate = [0.,0.,0.]         # dropout probability per layer (only read by the fully connected networks)
spec_norm    = True       # wrap conv/linear layers in spectral normalisation
dim          = '2D'       # data dimensionality: '1D' or '2D'
activations  = ['ReLU', 'ReLU','ReLU']   # names of torch.nn activation classes, one per layer
latent_dim   = 4          # size of the latent vector
input_c      = 1 
input_dim    = 28         # input size per spatial axis (2D inputs are treated as input_dim x input_dim)
encoder_type = 'conv'     # 'conv' or 'fc'
decoder_type = 'conv'     # 'conv' or 'fc'
affine       = False      # elementwise_affine setting of the LayerNorm layers
out_sizes    = [256,128,64]   # widths of the fully connected hidden layers
image_data   = True       # append a Sigmoid to the decoder output

# --- training ---
nepochs       = 4         # not referenced by the other visible cells; the training cell passes its own epoch count
batchsize     = 128
initial_lr    = 1e-3

optimizer     = 'Adam'      # name of a torch.optim class
criterion     = 'MSELoss'   # name of a torch.nn loss class

scheduler     = 'ExponentialLR'   # name of a torch.optim.lr_scheduler class
scheduler_params = {'gamma':0.95} # keyword arguments for the scheduler

dataset       = 'MNIST'     # name of a torchvision.datasets class

In [247]:
# Settings shared by encoder and decoder.
general_params = dict(
    input_c=input_c,
    input_dim=input_dim,
    latent_dim=latent_dim,
    encoder_type=encoder_type,
    decoder_type=decoder_type,
    dim=dim,
)

# Settings read by the convolutional networks.
conv_network_params = dict(
    n_layers=n_layers,
    out_channels=out_channels,
    kernel_sizes=kernel_sizes,
    scale_facs=scale_facs,
    paddings=paddings,
    strides=strides,
    activations=activations,
    spec_norm=spec_norm,
    dropout_rate=dropout_rate,
    layer_norm=layer_norm,
    affine=affine,
    image_data=image_data,
)

# Settings read by the fully connected networks.
fc_network_params = dict(
    n_layers=n_layers,
    out_sizes=out_sizes,
    activations=activations,
    spec_norm=spec_norm,
    dropout_rate=dropout_rate,
    layer_norm=layer_norm,
    affine=affine,
    image_data=image_data,
)

# Optimisation settings.
training_params = dict(
    batchsize=batchsize,
    initial_lr=initial_lr,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    scheduler_params=scheduler_params,
)

# Which dataset to load and where it is stored / downloaded to.
data_params = dict(
    dataset=dataset,
    loc='/global/cscratch1/sd/vboehm/Datasets',
)

In [248]:
# The same dict is passed as encoder and decoder settings: Autoencoder writes the
# encoder's out_dims/final_dim/final_c into the encoder dict, and ConvDecoder reads
# those keys from the dict it is given.
AE = Autoencoder(general_params,data_params,conv_network_params, conv_network_params, training_params)

In [249]:
# Pull a single mini-batch from the training loader to inspect it.
train_features, train_labels = next(iter(AE.train_loader))

In [250]:
# Shape of the sampled feature batch.
train_features.shape

torch.Size([128, 1, 28, 28])

In [251]:
# Move the model's parameters to the selected device.
AE.to(device)

Autoencoder(
  (encoder): ConvEncoder(
    (model): ModuleList(
      (0): Conv2d(1, 8, kernel_size=(4, 4), stride=(2, 2))
      (1): LayerNorm((8, 13, 13), eps=1e-05, elementwise_affine=False)
      (2): ReLU()
      (3): AdaptiveMaxPool2d(output_size=[13, 13])
      (4): Conv2d(8, 8, kernel_size=(2, 2), stride=(1, 1))
      (5): LayerNorm((8, 12, 12), eps=1e-05, elementwise_affine=False)
      (6): ReLU()
      (7): AdaptiveMaxPool2d(output_size=[12, 12])
      (8): Conv2d(8, 8, kernel_size=(2, 2), stride=(1, 1))
      (9): LayerNorm((8, 11, 11), eps=1e-05, elementwise_affine=False)
      (10): ReLU()
      (11): AdaptiveMaxPool2d(output_size=[11, 11])
      (12): Flatten(start_dim=1, end_dim=-1)
      (13): Linear(in_features=968, out_features=4, bias=True)
    )
  )
  (decoder): ConvDecoder(
    (model): ModuleList(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=4, out_features=968, bias=True)
      (2): Reshape()
      (3): ReLU()
      (4): Upsample(sca

In [252]:
# Dummy input created on the selected device (the literal 'cuda' failed on CPU-only machines).
data = torch.randn(1,1,28,28).to(device)

# Smoke test: push one real batch through the model. Calling the module rather than
# .forward() also runs any registered hooks.
res = AE(train_features.to(device))

In [253]:
# Layer-by-layer summary; the input size passed is the shape of `data` without its first axis.
summary(AE, data.shape[1::])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 13, 13]             136
         LayerNorm-2            [-1, 8, 13, 13]               0
              ReLU-3            [-1, 8, 13, 13]               0
 AdaptiveMaxPool2d-4            [-1, 8, 13, 13]               0
            Conv2d-5            [-1, 8, 12, 12]             264
         LayerNorm-6            [-1, 8, 12, 12]               0
              ReLU-7            [-1, 8, 12, 12]               0
 AdaptiveMaxPool2d-8            [-1, 8, 12, 12]               0
            Conv2d-9            [-1, 8, 11, 11]             264
        LayerNorm-10            [-1, 8, 11, 11]               0
             ReLU-11            [-1, 8, 11, 11]               0
AdaptiveMaxPool2d-12            [-1, 8, 11, 11]               0
          Flatten-13                  [-1, 968]               0
           Linear-14                   

In [None]:
# Train for 20 epochs; after each epoch the loss of its last mini-batch is printed.
AE.train(20)

epoch: 0, loss: 5.5544e-02
epoch: 1, loss: 4.6534e-02
epoch: 2, loss: 4.4221e-02
epoch: 3, loss: 4.2146e-02
epoch: 4, loss: 3.5483e-02
epoch: 5, loss: 3.9296e-02
epoch: 6, loss: 3.6239e-02
epoch: 7, loss: 3.8043e-02
epoch: 8, loss: 3.4639e-02
epoch: 9, loss: 3.9481e-02
