In [80]:
# Silence every Python warning for the remainder of the session.
warnings.filterwarnings("ignore")

# ---- Input geometry ----
print_shapes = False
height = width = IMG_SIZE  # images are square: IMG_SIZE x IMG_SIZE
channels = 1
num_features = height * width * channels

# ---- Regularization ----
L2_reg = 1e-6
DROPOUT = True
# do_p: dropout probability after conv layers.
# do_p2: dropout probability after linear layers.
# NB: the classifier dropout is set manually to 0.3 inside CNN_VAE.classifier.
do_p, do_p2 = 0.05, 0.1
batchnorm_eps, batchnorm_momentum = 1e-5, 0.2

# ---- Conv layers ----
# One row per conv layer: (out_channels, kernel, padding, stride).
# The rows are transposed into the four per-parameter lists used by the model.
conv_out_channels, conv_kernel, conv_padding, conv_stride = map(list, zip(
    (8, 5, 0, 1),
    (32, 5, 2, 1),
    (64, 3, 1, 1),
))

# ---- MaxPool layers (the same pooling follows every conv layer) ----
pool_kernel = pool_stride = 3
pool_padding = 0

# ---- Fully connected layers ----
lin_layer = [1000, 200]

# ---- Auxiliary parameters ----
aux_layer = [200] * 2
aux_variables = 0  # 0 disables the auxiliary network
aux_in = 2  # layer no. where a is included in encoder
aux_decoder_layers = [200] * 2

# ---- Classifier parameters ----
classifier_layer = [1000, 200]
No_classes = len(classes)

# ---- Number of layers of each kind ----
NUM_CONV, NUM_LIN, NUM_AUX, NUM_CLASS, NUM_AUX_DECODER = map(
    len,
    (conv_out_channels, lin_layer, aux_layer, classifier_layer, aux_decoder_layers),
)

# Calculating the dimensions 
def compute_conv_dim(height, width, kernel_size, padding_size, stride_size):
    """Return ``[new_height, new_width]`` after one conv/pool-style layer.

    Each spatial extent is mapped through
    ``int((extent - kernel_size + 2 * padding_size) / stride_size + 1)``;
    the same kernel, padding and stride are used for both axes.
    """
    def _out_extent(extent):
        # int() truncates the fractional part of the true-division result.
        return int((extent - kernel_size + 2 * padding_size) / stride_size + 1)

    return [_out_extent(height), _out_extent(width)]

def compute_final_dimension(height, width, last_num_channels, num_layers):
    """Compute the feature-map size after ``num_layers`` conv + max-pool stages.

    Stage ``i`` applies the size formula of ``compute_conv_dim`` twice: first
    with the module-level ``conv_kernel[i]``, ``conv_padding[i]`` and
    ``conv_stride[i]``, then with the module-level ``pool_kernel``,
    ``pool_padding`` and ``pool_stride``.

    Returns ``[final_dim, final_height, final_width]`` where
    ``final_dim = final_height * final_width * last_num_channels``.
    """
    feat_h, feat_w = height, width
    for layer in range(num_layers):
        # Convolution of this stage.
        feat_h, feat_w = compute_conv_dim(
            feat_h, feat_w, conv_kernel[layer], conv_padding[layer], conv_stride[layer]
        )
        # Max-pooling that follows the convolution.
        feat_h, feat_w = compute_conv_dim(
            feat_h, feat_w, pool_kernel, pool_padding, pool_stride
        )
    return [feat_h * feat_w * last_num_channels, feat_h, feat_w]

def normalize(x):
    """Min-max rescale ``x`` over dims 2 and 3.

    The minimum over dims 2 and 3 is subtracted, then the result is divided by
    its maximum over the same dims plus ``1e-8`` (which guards against a
    division by zero for constant inputs). A message is printed if the shifted
    tensor contains NaNs.

    Presumably ``x`` is laid out as (num, channels, height, width), so every
    image/channel is rescaled independently -- the code only requires that
    dims 2 and 3 exist.
    """
    def _reduce_hw(tensor, reducer):
        # Reduce dim 2 first, then dim 3, keeping both dims for broadcasting.
        over_dim2 = reducer(tensor, dim=2, keepdim=True)[0]
        return reducer(over_dim2, dim=3, keepdim=True)[0]

    shifted = x - _reduce_hw(x, torch.min)
    nan_count = torch.sum(torch.isnan(shifted))
    if nan_count > 0:
        print("nan of tmp", nan_count)
    peak = _reduce_hw(shifted, torch.max)
    return shifted / (peak + 1e-8)

def gaussian_sample(mu, log_var, num_samples, latent_features):
    """Draw reparameterized Gaussian samples ``z = mu + sigma * epsilon``.

    Args:
        mu: mean tensor, either 2-D ``(batch_size, latent_features)`` or
            already expanded to 3-D ``(batch_size, num_samples, latent_features)``.
        log_var: log-variance tensor with the same shape as ``mu``;
            ``sigma`` is computed as ``exp(log_var / 2)``.
        num_samples: number of samples drawn per batch element.
        latent_features: size of the last dimension of the noise tensor.

    Returns:
        Tensor of shape ``(batch_size, num_samples, latent_features)``.
    """
    batch_size = mu.size(0)

    # Don't propagate gradients through randomness: the noise is a constant,
    # gradients reach mu and log_var through the arithmetic below.
    with torch.no_grad():
        epsilon = torch.randn(batch_size, num_samples, latent_features)

    # Put the noise on the same device as mu instead of consulting a global
    # `cuda` flag, so the function cannot hit a device mismatch when the flag
    # and the actual location of mu disagree. The noise is still generated on
    # the CPU first, which keeps the random stream unchanged.
    epsilon = epsilon.to(mu.device)

    sigma = torch.exp(log_var / 2)

    # A 2-D mu/sigma needs a sample dimension to broadcast against epsilon:
    # (batch_size, latent_dim) -> (batch_size, 1, latent_dim)
    if mu.dim() == 2:
        return mu.unsqueeze(1) + epsilon * sigma.unsqueeze(1)
    return mu + epsilon * sigma

def output_recon(x):
    """Split a decoder output into the tensors used for plotting and for the loss.

    ``x`` is chunked in two along dim 2 into a mean half and a log-variance
    half; each half has shape [batch_size, num_samples, channel, height, width].

    Returns:
        A tuple ``(x_hat, x_log_var, x_mean)``:
          * ``x_hat``: the mean half of the sample at index 1, viewed as
            [batch, 1, height, width].
          * ``x_log_var``: softplus of the log-variance half, averaged over
            the sample dimension and viewed as [batch, 1, height, width].
          * ``x_mean``: the mean half averaged over the sample dimension
            (used for the loss).
    """
    # Take the batch size from the tensor itself rather than from the global
    # `batch_size`, so a batch of a different size (e.g. a final partial
    # batch) does not break the views below.
    n_batch = x.size(0)

    # The mean and log_var reconstructions from the decoder.
    x_mean, x_log_var = torch.chunk(x, 2, dim=2)

    # A single sample is kept for x_hat.
    # NOTE(review): index 1 selects the *second* sample, so at least two
    # samples are required -- confirm that index 0 was not intended.
    x_hat = x_mean[:, 1, ].unsqueeze(1)
    x_log_var = softplus(x_log_var)

    # Mean over samples.
    x_log_var = torch.mean(x_log_var, dim=1)
    x_mean = torch.mean(x_mean, dim=1)  # used for the loss

    # Bring both image-like outputs to [batch, 1, height, width].
    x_hat = x_hat.view(n_batch, 1, height, width)
    x_log_var = x_log_var.view(n_batch, 1, height, width)
    return x_hat, x_log_var, x_mean

######## Image has to be: (num, channels, height, width)!!!! #########
class CNN_VAE(nn.Module):
    """Class-conditional convolutional VAE with a classifier head.

    Input images have to be laid out as (num, channels, height, width).

    The architecture is configured through module-level settings
    (``conv_*``, ``pool_*``, ``lin_layer``, ``aux_*``, ``classifier_layer``,
    ``No_classes``, ``DROPOUT``, ``do_p``, ``do_p2``, ``batchnorm_*``,
    ``height``, ``width``, ``channels``). When ``aux_variables > 0`` an
    auxiliary encoder/decoder pair is built and used as well.

    Dropout is only active while the module is in training mode
    (``self.training``); after ``net.eval()`` the forward pass applies none.
    """

    def encoder(self, x):
        """Encode images into a 2-D feature tensor (batch, lin_layer[-1]).

        Appends the width of every conv output to ``self.layer_size`` (which
        ``forward`` resets before calling this method); the decoder uses these
        sizes to upsample back to the matching resolution.
        """
        # Convolutional layers of encoder, stored as (conv, batchnorm, maxpool) triples
        for i in range(0, len(self.Encoder_conv), 3):
            x = self.Encoder_conv[i](x)  # Convolutional layer
            self.layer_size.append(x.shape[-1])
            x = self.Encoder_conv[i + 1](x)  # Batchnorm layer
            x = relu(x)
            if DROPOUT:
                # training=self.training: functional dropout defaults to
                # training=True and would otherwise stay on in eval mode.
                x = dropout2d(x, p=do_p, training=self.training)
            x = self.Encoder_conv[i + 2](x)  # Maxpool layer
        # Flatten for the linear layers; the batch size is read from the
        # tensor so batches of any size work.
        x = x.view(x.size(0), -1)

        # Fully connected layers of encoder, stored as (linear, batchnorm) pairs
        for i in range(0, len(self.Encoder_FC), 2):
            x = self.Encoder_FC[i](x)  # Linear layer
            x = self.Encoder_FC[i + 1](x)  # Batchnorm
            x = relu(x)
            if DROPOUT:
                x = dropout(x, p=do_p2, training=self.training)
        return x

    def decoder(self, z, y):
        """Decode latent samples ``z`` conditioned on labels ``y``.

        ``z`` and ``y`` are 3-D (batch, num_samples, features) and are
        concatenated along the last dim. Relies on ``self.layer_size`` having
        been filled by a previous ``encoder`` call.

        Returns a tensor viewed as (batch, -1, channels*2, height, width).
        """
        n_batch = z.size(0)
        x = torch.cat([z, y], dim=-1)
        # Fully connected layers of decoder, stored as (linear, batchnorm) pairs
        for i in range(0, len(self.Decoder_FC), 2):
            x = self.Decoder_FC[i](x)
            # BatchNorm1d normalizes dim 1 of a 3-D input, so the feature dim
            # is moved there and back.
            x = x.permute(0, 2, 1)
            x = self.Decoder_FC[i + 1](x)
            x = x.permute(0, 2, 1)
            x = relu(x)
            if DROPOUT:
                x = dropout(x, p=do_p2, training=self.training)
        # reshape (not view): x can be non-contiguous after the permutes.
        x = x.reshape(-1, self.Decoder_conv[0].in_channels, self.final_dim[1], self.final_dim[2])

        # Convolutional layers of decoder, stored as (conv-transpose, batchnorm) pairs
        curr_layer = len(self.Decoder_conv) // 2 - 1
        for i in range(0, len(self.Decoder_conv), 2):
            # Undo the encoder's max-pooling by resizing to the recorded conv output size
            x = interpolate(x, size=[self.layer_size[curr_layer], self.layer_size[curr_layer]],
                            mode='bilinear',
                            align_corners=False)
            curr_layer -= 1
            x = self.Decoder_conv[i](x)  # Convolutional layers
            x = self.Decoder_conv[i + 1](x)  # BatchNorm
            x = relu(x)
            if DROPOUT:
                x = dropout2d(x, p=do_p, training=self.training)
        return x.view(n_batch, -1, channels * 2, height, width)

    def encoder_aux(self, a):
        """Run the auxiliary encoder and split its output into (q_a_mu, q_a_log_var)."""
        for i in range(0, len(self.Encoder_aux), 2):
            a = self.Encoder_aux[i](a)
            a = self.Encoder_aux[i + 1](a)
            a = relu(a)
            if DROPOUT:
                a = dropout(a, p=do_p2, training=self.training)
        q_a_mu, q_a_log_var = torch.chunk(a, 2, dim=-1)  # divide to mu and sigma
        return q_a_mu, q_a_log_var

    def decoder_aux(self, xz, y):
        """Run the auxiliary decoder on ``xz`` concatenated with ``y`` (last dim)."""
        a = torch.cat([xz, y], dim=-1)
        for i in range(0, len(self.Decoder_aux), 2):
            a = self.Decoder_aux[i](a)
            a = a.permute(0, 2, 1)
            a = self.Decoder_aux[i + 1](a)
            a = a.permute(0, 2, 1)
            a = relu(a)
            if DROPOUT:
                a = dropout(a, p=do_p2, training=self.training)
        return a

    def additional_layer(self, xa, y):
        """Map features ``xa`` concatenated with labels ``y`` to 2 * latent_features values."""
        z = torch.cat([xa, y], dim=-1)
        z = self.Additional_layer[0](z)
        z = z.permute(0, 2, 1)
        z = self.Additional_layer[1](z)
        z = z.permute(0, 2, 1)
        z = relu(z)
        if DROPOUT:
            z = dropout(z, p=do_p2, training=self.training)
        return z

    def classifier(self, xa):
        """Return class probabilities (softmax over the last dim) for ``xa``."""
        for i in range(0, len(self.Classifier), 2):
            xa = self.Classifier[i](xa)
            # The final linear layer has no batchnorm/activation after it.
            if i < len(self.Classifier) - 1:
                if aux_variables > 0:
                    # 3-D input: batchnorm expects the feature dim at position 1
                    xa = xa.permute(0, 2, 1)
                    xa = self.Classifier[i + 1](xa)
                    xa = xa.permute(0, 2, 1)
                else:
                    xa = self.Classifier[i + 1](xa)
                xa = relu(xa)
                if DROPOUT:
                    xa = dropout(xa, p=0.3, training=self.training)
        return softmax(xa, dim=-1)

    def sample_y(self, batch_size, num_samples, no_classes, i):
        """Return a one-hot label for class ``i`` tiled to (batch_size, num_samples, no_classes)."""
        tmp = torch.zeros(no_classes)
        tmp[i] = 1
        if cuda:
            tmp = tmp.cuda()
        return tmp.repeat(batch_size, num_samples, 1)

    def sample_from_latent(self, x):
        """Decode latent vectors ``x`` under every class label and sum the outputs.

        ``x`` is 2-D (batch, latent); it is repeated ``self.num_samples`` times
        along a new sample dim. Returns only ``x_hat`` from ``output_recon``.
        """
        n_batch = x.size(0)
        z = x.unsqueeze(1).repeat(1, self.num_samples, 1)
        x_UL = []
        for j in range(No_classes):
            tmp = self.decoder(z, self.sample_y(n_batch, self.num_samples, No_classes, j))
            x_UL.append(tmp)
        x_hat, _, _ = output_recon(sum(x_UL))
        return x_hat

    def __init__(self, latent_features, num_samples):
        """Build all sub-networks.

        Args:
            latent_features: dimensionality of the latent variable z.
            num_samples: number of samples drawn from each distribution in
                the forward pass.
        """
        super(CNN_VAE, self).__init__()

        self.latent_features = latent_features
        self.num_samples = num_samples

        # Calculate final size of the CNN
        self.final_dim = compute_final_dimension(height, width, conv_out_channels[-1], NUM_CONV)

        ## Convolutional layers of the encoder
        input_channels = channels
        Encoder_conv = nn.ModuleList()
        for i in range(NUM_CONV):
            Encoder_conv.append(Conv2d(in_channels=input_channels,
                                       out_channels=conv_out_channels[i],
                                       kernel_size=conv_kernel[i],
                                       stride=conv_stride[i],
                                       padding=conv_padding[i]))
            Encoder_conv.append(BatchNorm2d(conv_out_channels[i], eps=batchnorm_eps, momentum=batchnorm_momentum))
            Encoder_conv.append(MaxPool2d(kernel_size=pool_kernel,
                                          stride=pool_stride,
                                          padding=pool_padding,
                                          return_indices=False))
            input_channels = conv_out_channels[i]
        self.add_module("Encoder_conv", Encoder_conv)

        # Fully connected layers of encoder
        Encoder_FC = nn.ModuleList()
        in_weights = self.final_dim[0]
        for i in range(NUM_LIN):
            Encoder_FC.append(Linear(in_features=in_weights, out_features=lin_layer[i]))
            Encoder_FC.append(BatchNorm1d(lin_layer[i], eps=batchnorm_eps, momentum=batchnorm_momentum))
            in_weights = lin_layer[i]
        self.add_module("Encoder_FC", Encoder_FC)

        # map to latent space
        Additional_layer = nn.ModuleList()
        Additional_layer.append(Linear(in_features=lin_layer[-1] + aux_variables + No_classes, out_features=latent_features * 2))
        Additional_layer.append(BatchNorm1d(latent_features * 2, eps=batchnorm_eps, momentum=batchnorm_momentum))
        self.add_module("Additional_layer", Additional_layer)

        # Auxiliary network
        if aux_variables > 0:
            # Auxiliary encoder
            Encoder_aux = nn.ModuleList()
            in_weights = lin_layer[-1]
            for i in range(NUM_AUX):
                Encoder_aux.append(Linear(in_features=in_weights, out_features=aux_layer[i]))
                Encoder_aux.append(BatchNorm1d(aux_layer[i], eps=batchnorm_eps, momentum=batchnorm_momentum))
                in_weights = aux_layer[i]
            Encoder_aux.append(Linear(in_features=aux_layer[-1], out_features=aux_variables * 2))
            Encoder_aux.append(BatchNorm1d(aux_variables * 2, eps=batchnorm_eps, momentum=batchnorm_momentum))
            self.add_module("Encoder_aux", Encoder_aux)
            # Auxiliary decoder
            Decoder_aux = nn.ModuleList()
            for i in range(NUM_AUX_DECODER):
                if i == 0:
                    in_weights = self.latent_features + lin_layer[-1] + No_classes
                else:
                    in_weights = aux_decoder_layers[i - 1]
                Decoder_aux.append(Linear(in_features=in_weights, out_features=aux_decoder_layers[i]))
                Decoder_aux.append(BatchNorm1d(aux_decoder_layers[i], eps=batchnorm_eps, momentum=batchnorm_momentum))
            Decoder_aux.append(Linear(in_features=aux_decoder_layers[-1], out_features=aux_variables * 2))
            Decoder_aux.append(BatchNorm1d(aux_variables * 2, eps=batchnorm_eps, momentum=batchnorm_momentum))
            self.add_module("Decoder_aux", Decoder_aux)

        # Initialize fully connected layers from latent space to convolutional layers
        Decoder_FC = nn.ModuleList()
        Decoder_FC.append(Linear(in_features=latent_features + No_classes, out_features=lin_layer[-1]))
        Decoder_FC.append(BatchNorm1d(lin_layer[-1], eps=batchnorm_eps, momentum=batchnorm_momentum))
        for i in reversed(range(NUM_LIN)):
            if i == 0:
                out_weights = self.final_dim[0]
            else:
                out_weights = lin_layer[i - 1]
            Decoder_FC.append(Linear(in_features=lin_layer[i], out_features=out_weights))
            Decoder_FC.append(BatchNorm1d(out_weights, eps=batchnorm_eps, momentum=batchnorm_momentum))
        self.add_module("Decoder_FC", Decoder_FC)

        # Convolutional layers of the decoder
        Decoder_conv = nn.ModuleList()
        for i in reversed(range(NUM_CONV)):
            if i == 0:
                # The last layer emits a mean and a log-variance map per image channel
                output_channels = channels * 2
            else:
                output_channels = conv_out_channels[i - 1]
            Decoder_conv.append(ConvTranspose2d(in_channels=conv_out_channels[i],
                                                out_channels=output_channels,
                                                kernel_size=conv_kernel[i],
                                                stride=conv_stride[i],
                                                padding=conv_padding[i]))
            Decoder_conv.append(BatchNorm2d(output_channels, eps=batchnorm_eps, momentum=batchnorm_momentum))
        self.add_module("Decoder_conv", Decoder_conv)

        # Fully connected layers from convolutional layers to classification
        Classifier = nn.ModuleList()
        if aux_variables > 0:
            in_weights = lin_layer[-1] + aux_variables
        else:
            in_weights = lin_layer[-1]
        for i in range(NUM_CLASS):
            Classifier.append(Linear(in_features=in_weights, out_features=classifier_layer[i]))
            Classifier.append(BatchNorm1d(classifier_layer[i], eps=1e-4, momentum=batchnorm_momentum))
            in_weights = classifier_layer[i]
        Classifier.append(Linear(in_features=classifier_layer[-1], out_features=No_classes))
        self.add_module("Classifier", Classifier)

        # Initialize weight of layers (xavier_normal_ replaces the deprecated xavier_normal)
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

    ### Forward ####
    def forward(self, x, y=None):
        """Run encoder, classifier, latent sampling and decoder.

        Args:
            x: image batch, (num, channels, height, width).
            y: optional 2-D label tensor (batch, No_classes). When ``None``
                every class is enumerated with a one-hot label instead.

        Returns:
            dict with keys ``x_hat``, ``z``, ``mu``, ``log_var``, ``x_mean``,
            ``x_log_var``, ``y_hat`` and, when ``aux_variables > 0``, ``q_a``,
            ``q_a_mu``, ``q_a_log_var``, ``p_a_mu``, ``p_a_log_var``. When
            ``y`` is ``None`` the reconstruction entries (and ``p_a_*``) are
            lists with one element per class.
        """
        outputs = {}
        self.indices = []
        self.layer_size = []
        # Use the real batch size and the constructor arguments instead of
        # the module-level globals of the same names.
        n_batch = x.size(0)
        num_samples = self.num_samples
        latent_features = self.latent_features

        x = self.encoder(x)
        if aux_variables > 0:
            q_a_mu, q_a_log_var = self.encoder_aux(x)
            q_a = gaussian_sample(q_a_mu, q_a_log_var, num_samples, aux_variables)  # sample auxiliary variables
            outputs["q_a"] = q_a  # Assign to outputs
            xa = torch.cat([x.unsqueeze(1).repeat(1, num_samples, 1), q_a], dim=2)  # Create combined vector of x and q_a
        else:
            xa = x

        # Run through classifier (returns softmax probabilities)
        logits = self.classifier(xa)

        if aux_variables <= 0:
            # Without auxiliary samples everything is still 2-D: add the sample dim
            logits = logits.unsqueeze(1)
            logits = logits.repeat(1, num_samples, 1)
            xa = xa.unsqueeze(1).repeat(1, num_samples, 1)

        # Map x, y, a to latent space
        if y is None:
            x_UL = []
            for j in range(No_classes):
                z = self.additional_layer(xa, self.sample_y(n_batch, num_samples, No_classes, j))
                x_UL.append(z)
            # The per-class outputs are summed into one latent parameter tensor
            lat_in = sum(x_UL)
            del x_UL
        else:
            lat_in = self.additional_layer(xa, y.unsqueeze(1).repeat(1, num_samples, 1))

        # Split into mu and log_var
        mu, log_var = torch.chunk(lat_in, 2, dim=-1)
        # Make sure that the log variance is positive
        log_var = softplus(log_var)
        # Sample from latent space
        z = gaussian_sample(mu, log_var, num_samples, latent_features)

        # aux. decoder
        if aux_variables > 0:
            xz = torch.cat([z, x.unsqueeze(1).repeat(1, z.shape[1], 1)], dim=-1)
            if y is None:
                a_UL = []
                for j in range(No_classes):
                    a_UL.append(self.decoder_aux(xz, self.sample_y(n_batch, num_samples, No_classes, j)))
                a_log_var = []
                a_mean = []
                for j in range(No_classes):
                    tmp1, tmp2 = torch.chunk(a_UL[j], 2, dim=-1)  # the mean and log_var reconstructions from the decoder
                    a_mean.append(tmp1)
                    a_log_var.append(softplus(tmp2))
                del a_UL, tmp1, tmp2
            else:
                a = self.decoder_aux(xz, y.unsqueeze(1).repeat(1, num_samples, 1))
                a_mean, a_log_var = torch.chunk(a, 2, dim=-1)  # the mean and log_var reconstructions from the decoder
                a_log_var = softplus(a_log_var)

        # Decoder
        if y is None:
            x_UL = []
            for j in range(No_classes):
                tmp = self.decoder(z, self.sample_y(n_batch, num_samples, No_classes, j))
                x_UL.append(tmp)
            x_log_var = []
            x_mean = []
            x_hat = []
            for j in range(No_classes):
                tmp1, tmp2, tmp3 = output_recon(x_UL[j])
                x_hat.append(tmp1)
                x_log_var.append(tmp2)
                x_mean.append(tmp3)
            del x_UL, tmp1, tmp2, tmp3
        else:
            x = self.decoder(z, y.unsqueeze(1).repeat(1, num_samples, 1))
            x_hat, x_log_var, x_mean = output_recon(x)

        # Assign variables
        outputs["x_hat"] = x_hat  # This is used for visualizations only
        outputs["z"] = z
        outputs["mu"] = mu
        outputs["log_var"] = log_var

        # image reconstructions (notice they are outputted as matrices)
        outputs["x_mean"] = x_mean  # mean reconstructions (for loss!!!)
        outputs["x_log_var"] = x_log_var  # log var reconstructions (for loss!!!)

        # auxiliary outputs
        if aux_variables > 0:
            outputs["q_a_mu"] = q_a_mu
            outputs["q_a_log_var"] = q_a_log_var
            outputs["p_a_mu"] = a_mean
            outputs["p_a_log_var"] = a_log_var

        # classifier outputs
        outputs["y_hat"] = logits

        return outputs

# num_samples is the number of samples drawn from each distribution;
# it is fixed when the VAE is initialised.
num_samples, latent_features = 5, 32

net = CNN_VAE(latent_features, num_samples)
# print(net)

# Transfer model to GPU if available
if cuda:
    net = net.cuda()

# Disabled smoke test: push one random batch through the network.
if False:
    x = Variable(torch.randn(batch_size, 1, width, height))
    if cuda:
        x = x.cuda()
        y = None
    y = net(x)
    print(y['x_hat'][0].shape)

# Report the GPU memory in use. The cleanup between the two prints
# (del y, x / gc.collect()) is currently switched off.
if cuda:
    print('before: ', torch.cuda.memory_allocated(device=0))
    import gc
    print('after: ', torch.cuda.memory_allocated(device=0))

#for parameter in net.parameters():
#    print(parameter.shape)
#epsilon = torch.randn(batch_size, latent_features).cuda
#samples = torch.sigmoid(net.decoder(net.latent_to_CNN(epsilon))).detach()

64 1 1
CNN_VAE(
  (Encoder_conv): ModuleList(
    (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.2, affine=True, track_running_stats=True)
    (2): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(8, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.2, affine=True, track_running_stats=True)
    (5): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.2, affine=True, track_running_stats=True)
    (8): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  )
  (Encoder_FC): ModuleList(
    (0): Linear(in_features=64, out_features=1000, bias=True)
    (1): BatchNorm1d(1000, eps=1e-05, momentum=0.2, affine=True, track_running_stats=True)
    (2): Linear(in_features=1000, out_feature

