# Import

In [1]:
import csv, os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define model

In [2]:
# Define AutoEncoder model
class AE(nn.Module):
    """Fully connected autoencoder with a sigmoid after every linear layer.

    Layout: input -> hidden -> code (encoder), code -> hidden -> input
    (decoder). All intermediate layers share the same width ``hidden_size``.

    Args:
        hidden_size: Width of the hidden and code layers.
        act_hist: Accepted for backward compatibility; not used by this class.
            Defaults to ``None`` rather than a shared mutable ``{}``.
        **kwargs: Must contain ``input_shape``, the number of input (and
            output) features of the flattened sample.

    Raises:
        KeyError: If ``input_shape`` is not supplied as a keyword argument.

    Attributes (documented here, set in ``__init__``):
        state: Dict mapping a hooked layer name to the detached output that
            layer produced on the most recent forward pass. Empty until
            ``attach_hooks`` has been called and a forward pass has run.
    """

    def __init__(self, hidden_size=50, act_hist=None, **kwargs):
        super().__init__()
        input_shape = kwargs["input_shape"]
        self.encoder_hidden_layer = nn.Linear(
            in_features=input_shape, out_features=hidden_size
        )
        self.encoder_output_layer = nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.decoder_hidden_layer = nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.decoder_output_layer = nn.Linear(
            in_features=hidden_size, out_features=input_shape
        )
        # Filled by the forward hooks created in get_activation()
        self.state = {}

    def forward(self, features):
        """Encode and decode ``features``; returns the reconstruction.

        The last dimension of ``features`` must equal ``input_shape``. The
        output has the same shape as the input and, because of the final
        sigmoid, values in [0, 1].
        """
        activation = self.encoder_hidden_layer(features)
        activation = torch.sigmoid(activation)
        code = self.encoder_output_layer(activation)
        code = torch.sigmoid(code)
        activation = self.decoder_hidden_layer(code)
        activation = torch.sigmoid(activation)
        activation = self.decoder_output_layer(activation)
        reconstructed = torch.sigmoid(activation)
        return reconstructed

    def get_activation(self, module_name):
        """Build a forward hook that records a layer's output in ``self.state``.

        The returned hook runs on every forward pass of the layer it is
        registered on and stores the detached output under ``module_name``.
        Note that the hooked layers are the ``nn.Linear`` modules, so the
        recorded values are the pre-sigmoid outputs.
        """
        def hook(model, input, output):
            self.state[module_name] = output.detach()
        return hook

    def attach_hooks(self, module_names):
        """Register recording hooks on the named sub-modules.

        Args:
            module_names: Iterable of attribute names of layers on this model
                (e.g. ``'encoder_hidden_layer'``).

        Returns:
            List of hook handles, one per name and in the same order. Call
            ``handle.remove()`` on each to stop recording. Calling this method
            again with the same names registers additional hooks, so keep the
            handles if the hooks ever need to be detached.

        Raises:
            AttributeError: If a name does not exist on the model.
        """
        handles = []
        for name in module_names:
            layer = getattr(self, name)
            handles.append(layer.register_forward_hook(self.get_activation(name)))
        return handles

# Train

In [41]:
# Image family to train on: selects the `monsters/{FAM}` input folder and the
# per-family sub-folder that checkpoints are written to.
FAM = 'f4'
# Hidden layer size passed to AE as `hidden_size`; it is also embedded in the
# checkpoint directory name (`model_data{HLS}`).
HLS = 20

# Loader for input data
def get_loader(batch_size):
    """Return a DataLoader over the images stored under ``monsters/{FAM}``.

    Each image is converted to a single-channel grayscale tensor. Samples are
    served in a fixed order (no shuffling) in batches of ``batch_size``.
    """
    # Augmentations that were tried and are currently disabled:
    #   RandomRotation(20), RandomResizedCrop(128), RandomHorizontalFlip()
    steps = [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ]
    image_folder = datasets.ImageFolder(
        f'monsters/{FAM}',
        transform=transforms.Compose(steps),
    )
    return DataLoader(image_folder, shuffle=False, batch_size=batch_size)

# Use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of features per sample once an image is flattened (64x64 pixels,
# one grayscale channel). Used for both the model input size and the reshape.
INPUT_SIZE = 64 * 64

# Create a model from `AE` autoencoder class and
# load it to the specified device, either gpu or cpu
model = AE(hidden_size=HLS, input_shape=INPUT_SIZE).to(device)
# To resume from a saved checkpoint instead of training from scratch:
# model.load_state_dict(torch.load(f'model_data/f4/checkpoint0001'))
# model.eval()

# Record the output of every linear layer in `model.state` on each forward pass
model.attach_hooks(['encoder_hidden_layer', 'encoder_output_layer', 'decoder_hidden_layer', 'decoder_output_layer'])

# Create an optimizer object:
# Adam optimizer with learning rate 1e-3
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Define mean-squared error loss
criterion = nn.MSELoss()

# Training params
epochs = 10000
checkpoint = 500  # save and report every `checkpoint` epochs
batch_size = 1

# Data management routines
loader = get_loader(batch_size)
loss_hist = []  # mean training loss of every epoch, in order

# Checkpoints go to a per-size / per-family folder. torch.save does not create
# missing directories, so make sure the folder exists before training starts.
checkpoint_dir = f'model_data{HLS}/{FAM}'
os.makedirs(checkpoint_dir, exist_ok=True)

# Print numpy arrays in full instead of summarizing them with "..."
np.set_printoptions(threshold=sys.maxsize)
for epoch in range(epochs):
    loss = 0
    for batch_features, _ in loader:
        # Reshape mini-batch data to an [N, 64*64] matrix
        # Load it to the active device
        batch_features = batch_features.view(-1, INPUT_SIZE).to(device)

        # Reset the gradients back to zero
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()

        # Run forward pass (and run hooks)
        outputs = model(batch_features)

        # Compute training reconstruction loss
        train_loss = criterion(outputs, batch_features)

        # Compute accumulated gradients
        train_loss.backward()

        # Perform parameter update based on current gradients
        optimizer.step()

        # Add the mini-batch training loss to epoch loss
        loss += train_loss.item()

    # Compute the epoch training loss (mean over mini-batches)
    loss = loss / len(loader)
    loss_hist.append(loss)

    # Save a checkpoint and display the epoch training loss on the first
    # epoch, every `checkpoint` epochs after that, and on the final epoch
    if (epoch % checkpoint == 0) or (epoch == epochs - 1):
        torch.save(model.state_dict(), f'{checkpoint_dir}/checkpoint{epoch+1:04d}')
        print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))

epoch : 1/10000, loss = 0.176268
epoch : 501/10000, loss = 0.014145
epoch : 1001/10000, loss = 0.002200
epoch : 1501/10000, loss = 0.001875
epoch : 2001/10000, loss = 0.001520
epoch : 2501/10000, loss = 0.001364
epoch : 3001/10000, loss = 0.001255
epoch : 3501/10000, loss = 0.001153
epoch : 4001/10000, loss = 0.001138
epoch : 4501/10000, loss = 0.001071
epoch : 5001/10000, loss = 0.001037
epoch : 5501/10000, loss = 0.001000
epoch : 6001/10000, loss = 0.000953
epoch : 6501/10000, loss = 0.000919
epoch : 7001/10000, loss = 0.000899
epoch : 7501/10000, loss = 0.000889
epoch : 8001/10000, loss = 0.000877
epoch : 8501/10000, loss = 0.000863
epoch : 9001/10000, loss = 0.000957
epoch : 9501/10000, loss = 0.000831
epoch : 10000/10000, loss = 0.000799
