# Pretrained CNN + Encoder-Decoder

## Setup

In [2]:
# Notebook setup: IPython magics, library imports, and plotting configuration.
# Reload imported modules automatically before each cell runs (mode 2 = all modules).
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

from __future__ import print_function, division

import os
import sys
# sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
import time
import copy
import glob
from PIL import Image

# NOTE(review): CNNEncoderDecoderNet (Define Network cell) instantiates
# MNISTNet, so this import -- and possibly the sys.path tweak above -- must be
# re-enabled before that class can be constructed; confirm the module location.
# import pytorch_mnist
# from pytorch_mnist import Net as MNISTNet

plt.ion()   # enable matplotlib interactive mode

## Define Network

In [3]:
class CNNEncoderDecoderNet(torch.nn.Module):
    """
    Pretrained CNN feature extractor followed by a fully connected
    encoder-decoder.

    The CNN is an ``MNISTNet`` whose weights are loaded from disk and whose
    last child module is dropped. Its output feeds a stack of
    ``torch.nn.Linear`` layers with tanh activations between them; there is
    no activation after the output layer.

    NOTE(review): ``MNISTNet`` is not defined in this notebook (its import in
    the setup cell is commented out); it must be in scope before this class
    is instantiated.
    """

    def __init__(self, layer_sizes=(784, 200, 50), scale=None,
                 pretrained_path='./models/pytorch_mnist.model'):
        """
        Creates a pretrained CNN + Encoder-Decoder network.

        layer_sizes -> sequence of layer inputs/outputs (minimum length = 3).
            layer_sizes[0] does not size any layer: the input layer takes the
            output size of the chopped CNN instead.
            example:
                layer_sizes = [784,500,200,50]
                inputLayer   -> torch.nn.Linear(cnn_out_size,500)
                middleLayers -> [torch.nn.Linear(500,200)]
                outputLayer  -> torch.nn.Linear(200,50)
        scale -> stored as self.scale; not used by this class itself.
        pretrained_path -> file holding the MNISTNet state_dict to load.

        Raises ValueError if layer_sizes has fewer than 3 entries.
        """
        if len(layer_sizes) < 3:
            raise ValueError(
                "layer_sizes must have at least 3 entries, got %r"
                % (layer_sizes,))
        super(CNNEncoderDecoderNet, self).__init__()

        # Load the MNIST CNN model and its pretrained weights. map_location
        # lets a checkpoint saved from GPU tensors load on a CPU-only machine;
        # load_state_dict copies the values into the model's own parameters.
        cnn = MNISTNet()
        cnn.load_state_dict(torch.load(pretrained_path, map_location='cpu'))

        # Chop off the last child (the final FC layer). children() yields only
        # the direct submodules; modules() would also yield the net itself
        # (plus every nested submodule), wrapping the full, unchopped net.
        # NOTE(review): nn.Sequential only replays the child modules in
        # registration order; any functional ops in MNISTNet.forward (pooling,
        # activations, flattening) are not reproduced -- confirm this matches.
        self.cnn_model = nn.Sequential(*list(cnn.children())[:-1])

        # Output size of the chopped net: weight.size(0) is the output
        # dimension of its last layer (out_features for a Linear).
        cnn_out_size = self.cnn_model[-1].weight.size(0)

        # Set up the input layer for the encoder-decoder part.
        self.inputLayer = torch.nn.Linear(cnn_out_size, layer_sizes[1])

        # Hidden layers live in a plain list for iteration and are registered
        # one by one via add_module, so their parameters are tracked under
        # "middleLayer_<i>" (saved state_dict keys depend on these names).
        self.middleLayers = []
        for i in range(1, len(layer_sizes) - 2):
            layer = torch.nn.Linear(layer_sizes[i], layer_sizes[i + 1])
            self.middleLayers.append(layer)
            self.add_module("middleLayer_" + str(i), layer)
        self.outputLayer = torch.nn.Linear(layer_sizes[-2], layer_sizes[-1])
        self.scale = scale
        self.loss = 0

    def forward(self, x):
        """
        Defines the layers connections

        forward(x) -> result of forward propagation through network
        x -> input to the Network (passed straight to the chopped CNN)
        """
        # Call the module rather than .forward() so registered hooks run.
        x = self.cnn_model(x)

        x = torch.tanh(self.inputLayer(x))
        for layer in self.middleLayers:
            x = torch.tanh(layer(x))
        # The output layer is left linear (no activation).
        return self.outputLayer(x)

    def isCuda(self):
        """Return True if the input layer's weights are on a CUDA device."""
        return self.inputLayer.weight.is_cuda
