In [2]:
import torch
from torch import nn, optim 
from torch.autograd import Variable
from torch.nn import functional as F 
import numpy as np
import os

In [432]:
class n_convs(nn.Module):
    def __init__(self, n_layer=2, in_channels=1, out_channels=64, imsize=64, kernelsize=2, stride=1, padding=0):
        
        super(n_convs, self).__init__()
        
        self.n_layer = n_layer 
        self.layer_list = []
        self.layer_list.append(nn.Conv2d(in_channels, out_channels, kernelsize, stride, padding))
        
        for i in range(n_layer-1):
            self.layer_list.append(nn.LeakyReLU(0.2, inplace=True))
            self.layer_list.append(nn.Conv2d(out_channels, out_channels, kernelsize, stride, padding))
        self.layer_list.append(nn.LeakyReLU(0.2, inplace=True))
        
    def forward(self, x):
        
        for i in range(self.n_layer):
            x = self.layer_list[i](x)
        
        return x

In [433]:
double_conv = n_convs()
a = torch.rand(1, 1, 64, 64)
x = torch.autograd.Variable(a)
# double_conv(x)
print(double_conv.layer_list)

[Conv2d(1, 64, kernel_size=(2, 2), stride=(1, 1)), LeakyReLU(0.2, inplace), Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1)), LeakyReLU(0.2, inplace)]


In [441]:
class ConvSet(nn.Module):
    def __init__(self, in_channels, out_channels, kernersize, stride, padding, bias=False, 
                 batchNorm=True, activation='ReLU', name='Default_net'):
        
        super(ConvSet, self).__init__()
        
        assert activation in ['ReLU', 'LeakyReLU'], "Activation methods not implemented!!!"
        self.convSet = nn.Sequential()
        self.nameList = []
        
        conv_name = name + '.{}-{}.conv'.format(in_channels, out_channels)
        activ_name = name + '.{}.' + activation
        activ_name = activ_name.format(out_channels)
        if batchNorm:
            batch_name = name + '.{}.batchnorm'.format(out_channels)
            
        self.convSet.add_module(conv_name,
                               nn.Conv2d(in_channels, out_channels, kernersize, stride, padding, bias=bias))
        self.nameList.append(conv_name)
        
        if batchNorm:
            self.convSet.add_module(batch_name,
                                   nn.BatchNorm2d(out_channels))
            self.nameList.append(batch_name)
        
        if activation == 'ReLU':
            self.convSet.add_module(activ_name,
                                   nn.ReLU(inplace=True))
        elif activation == 'LeakyReLU':
            self.convSet.add_module(activ_name,
                                   nn.LeakyReLU(0.2, inplace=True))
        self.nameList.append(activ_name)

    def forward(self, x):
        x = self.convSet(x)
        return x
    
    def getEntriesAndNames(self):
        moduleList = list(self.convSet)
        return moduleList, self.nameList

In [448]:
conv = ConvSet(1, 64, 4, 2, 1)
a = torch.rand(2, 1, 64, 64)
x = torch.autograd.Variable(a)
y = conv(x)
print(y.shape)

torch.Size([2, 64, 32, 32])


In [465]:
from collections import OrderedDict

class GeomEncoder(nn.Module):
    def __init__(self, encode_dim=20, in_channels=1, imsize=64, nfeat=64,  extra_layer=0, batchNorm=True, activation='ReLU'):
        
        super(GeomEncoder, self).__init__()  
        # initial input
        self.initial = ConvSet(in_channels, nfeat, 4, 2, 1, False, 
                               batchNorm=batchNorm, activation=activation, name='initial')
        
        # pyramid structure
        encoder_list = []
        name_list = []
        c_imsize, c_feat = imsize / 2, nfeat
        
        ind = 0       
        while c_imsize >= 4:
            in_feat = c_feat
            out_feat = c_feat * 2
            
            ind += 1
            layer_name = 'pyramid_' + str(ind)
            convnet = ConvSet(in_feat, out_feat, 4, 2, 1, bias=False, 
                                            batchNorm=batchNorm, activation=activation, name=layer_name)
            
            entries, names = convnet.getEntriesAndNames()
            encoder_list.extend(entries)
            name_list.extend(names)
        
            c_feat *= 2
            c_imsize = c_imsize / 4
        # Tensor[None, 256, 8, 8]
 
        # final convolutional layer, out_feat=20 is to be changed
        final_conv = ConvSet(c_feat, 20, 4, 2, 1, bias=False, 
                                            batchNorm=batchNorm, activation=activation, name='final_conv')
        entries, names = final_conv.getEntriesAndNames()
        encoder_list.extend(entries)
        name_list.extend(names)
        
        encoder_dict = OrderedDict(zip(name_list, encoder_list))
        self.encoder =  nn.Sequential(encoder_dict)
        # Tensor[None, 20, 4, 4] -> 320
        
        self.final_layer = nn.Sequential()
        self.final_layer.add_module('encoded',
                                    nn.Linear(320, encode_dim))
        
    def forward(self, x):
        x = self.initial(x)
        x = self.encoder(x)
        x = self.final_layer(x.view(2, 320))
        return x
    
    def get_num_params(self):
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return params

In [466]:
encoder = GeomEncoder(encode_dim=50, activation='LeakyReLU', batchNorm=False)
        
print(encoder)
a = torch.rand(2, 1, 64, 64)
x = torch.autograd.Variable(a)
y = encoder(x)
print(y.shape)
# for p in encoder.parameters():
#     print(p)in_channels
print(encoder.get_num_params())

GeomEncoder(
  (initial): ConvSet(
    (convSet): Sequential(
      (initial.1-64.conv): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (initial.64.LeakyReLU): LeakyReLU(0.2, inplace)
    )
  )
  (encoder): Sequential(
    (pyramid_1.64-128.conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid_1.128.LeakyReLU): LeakyReLU(0.2, inplace)
    (pyramid_2.128-256.conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid_2.256.LeakyReLU): LeakyReLU(0.2, inplace)
    (final_conv.256-20.conv): Conv2d(256, 20, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (final_conv.20.LeakyReLU): LeakyReLU(0.2, inplace)
  )
  (final_layer): Sequential(
    (encoded): Linear(in_features=320, out_features=50, bias=True)
  )
)
torch.Size([2, 50])
754354


In [537]:
class TConvSet(nn.Module):
    def __init__(self, in_channels, out_channels, kernersize, stride, padding, bias=False, 
                 batchNorm=True, activation='ReLU', name='Default_net'):
        
        super(TConvSet, self).__init__()
        
        assert activation in ['ReLU', 'LeakyReLU'], "Activation methods not implemented!!!"
        self.convSet = nn.Sequential()
        self.nameList = []
        
        conv_name = name + '.{}-{}.transconv'.format(in_channels, out_channels)
        activ_name = name + '.{}.' + activation
        activ_name = activ_name.format(out_channels)
        if batchNorm:
            batch_name = name + '.{}.batchnorm'.format(out_channels)
            
        self.convSet.add_module(conv_name,
                               nn.ConvTranspose2d(in_channels, out_channels, kernersize, stride, padding, bias=bias))
        self.nameList.append(conv_name)
        
        if batchNorm:
            self.convSet.add_module(batch_name,
                                   nn.BatchNorm2d(out_channels))
            self.nameList.append(batch_name)
        
        if activation == 'ReLU':
            self.convSet.add_module(activ_name,
                                   nn.ReLU(inplace=True))
        elif activation == 'LeakyReLU':
            self.convSet.add_module(activ_name,
                                   nn.LeakyReLU(0.2, inplace=True))
        self.nameList.append(activ_name)

    def forward(self, x):
        x = self.convSet(x)
        return x
    
    def getEntriesAndNames(self):
        moduleList = list(self.convSet)
        return moduleList, self.nameList

In [538]:
tconv = TConvSet(1, 64, 4, 2, 1)
a = torch.rand(2, 1, 64, 64)
x = torch.autograd.Variable(a)
y = tconv(x)
print(y.shape)

torch.Size([2, 64, 128, 128])


In [539]:
tconv = TConvSet(20, 128, 4, 1, 0)
a = torch.rand(2, 20, 1, 1)
x = torch.autograd.Variable(a)
y = tconv(x)
print(y.shape)

torch.Size([2, 128, 4, 4])


In [592]:
class GeomDecoder(nn.Module):
    def __init__(self, encode_dim=20, noise_dim=20, out_channels=1, imsize=64, nfeat=64, extra_layer=0, batchNorm=True, activation='ReLU'):
        
        super(GeomDecoder, self).__init__() 
        
        self.noise_dim = noise_dim
        self.encode_dim = encode_dim
        
        cngf, tisize = nfeat//2, 4
        while tisize != imsize:
            cngf = cngf * 2
            tisize = tisize * 2
            
        nz = noise_dim + encode_dim
        
        # initial input
        self.initial = TConvSet(nz, cngf, 4, 1, 0, False, 
                               batchNorm=batchNorm, activation=activation, name='initial')
        
        # pyramid structure
        decoder_list = []
        name_list = []        
        c_imsize, c_feat = 4, cngf
        
        ind = 0
        while c_imsize < imsize // 2:
            ind += 1
            layer_name = 'pyramid_' + str(ind)
            tconvnet = TConvSet(c_feat, c_feat//2, 4, 2, 1, bias=False,
                               batchNorm=batchNorm, activation=activation, name=layer_name)
            
            entries, names = tconvnet.getEntriesAndNames()
            decoder_list.extend(entries)
            name_list.extend(names)
            
            c_feat = c_feat // 2
            c_imsize = c_imsize * 2
        
        # extra layer
        ind = 0
        for i in range(extra_layer):
            ind += 1
            layer_name = 'extra_layer_' + str(ind)
            extra_tconvnet = ConvSet(c_feat, c_feat, 3, 1, 1, bias=False, name=layer_name)
            
            entries, names = extra_tconvnet.getEntriesAndNames()
            decoder_list.extend(entries)
            name_list.extend(names)       
        
        # final output layer
        final_tconvnet = TConvSet(c_feat, out_channels, 4, 2, 1, bias=False,
                   batchNorm=batchNorm, activation=activation, name='final_layer')

        entries, names = final_tconvnet.getEntriesAndNames()
        decoder_list.extend(entries)
        name_list.extend(names)
        
        decoder_dict = OrderedDict(zip(name_list, decoder_list))
        self.decoder =  nn.Sequential(decoder_dict)   
        
        
    def forward(self, *args):
        if len(args) == 2:
            x, noise = args[0], args[1]
            assert noise.shape[1] == self.noise_dim and x.shape[1] == self.encode_dim, "Dimension of noise or encoded vector does not match"
            x = torch.cat((x, noise), dim=1)
        else:
            x = args[0]
            assert self.noise_dim == 0, "Model needs noise with dimension dim={} as input".format(self.noise_dim)
            assert x.shape[1] == self.encode_dim, "Dimension of encoded vector does not match"
        
        x = x.view((x.shape[0], x.shape[1], 1, 1))
        x = self.initial(x)        
        x = self.decoder(x)
        return x
    
    def get_num_params(self):
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return params


In [593]:
decoder = GeomDecoder(20, 0, 1, 64, 64, 4, True)
print(decoder)
a = torch.ones(6, 20)
noise = torch.rand(6, 10)
x = torch.autograd.Variable(a)
noise = torch.autograd.Variable(noise)
y = decoder(x)
print(y.shape)

GeomDecoder(
  (initial): TConvSet(
    (convSet): Sequential(
      (initial.20-512.transconv): ConvTranspose2d(20, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
      (initial.512.batchnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (initial.512.ReLU): ReLU(inplace)
    )
  )
  (decoder): Sequential(
    (pyramid_1.512-256.transconv): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid_1.256.batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (pyramid_1.256.ReLU): ReLU(inplace)
    (pyramid_2.256-128.transconv): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid_2.128.batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (pyramid_2.128.ReLU): ReLU(inplace)
    (pyramid_3.128-64.transconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid_3.64.batchnorm): BatchNorm2d(64, ep