In [6]:
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [7]:
class MNISTDataSet:
    '''
    Bundles the MNIST train/test datasets together with matching DataLoaders.

    Both splits live under ``data_dir`` (``download=True`` fetches them if
    they are missing) and every sample is converted with
    ``transforms.ToTensor()``.

    Attributes:
        trainset, testset: the torchvision MNIST train / test splits.
        classes: class names reported by the training split.
        trainloader: shuffled DataLoader over ``trainset``.
        testloader: DataLoader over ``testset`` in fixed order.
    '''
    def __init__(self, data_dir, batch_size=4, num_workers=2) -> None:
        def make_split(is_train):
            # Each split gets its own ToTensor transform instance.
            return datasets.MNIST(root=data_dir, train=is_train, download=True,
                                  transform=transforms.ToTensor())

        def make_loader(split, shuffle):
            return torch.utils.data.DataLoader(split, batch_size=batch_size,
                                               shuffle=shuffle, num_workers=num_workers)

        self.trainset = make_split(True)
        self.testset = make_split(False)
        self.classes = self.trainset.classes

        # Only the training data is shuffled; evaluation order stays fixed.
        self.trainloader = make_loader(self.trainset, shuffle=True)
        self.testloader = make_loader(self.testset, shuffle=False)

In [8]:
class MNISTNet(nn.Module):
    '''
    Fully connected classifier: input -> hidden -> hidden -> output.

    Every Linear layer (including the output layer) is followed by a sigmoid.
    Forward hooks on the three Linear layers keep a detached copy of each
    layer's most recent output in ``ac1``, ``ac2`` and ``aco`` (``None``
    until the first forward pass). These are the Linear outputs, i.e. the
    values *before* the sigmoid is applied.

    Args:
        input_size: number of input features (784 for flattened 28x28 images).
        hidden1_size: width of the first hidden layer.
        hidden2_size: width of the second hidden layer.
        output_size: number of output units (classes).
    '''
    def __init__(self, input_size, hidden1_size, hidden2_size, output_size):
        super().__init__()

        self.fc1 = nn.Linear(input_size, hidden1_size)
        self.fc2 = nn.Linear(hidden1_size, hidden2_size)
        self.fco = nn.Linear(hidden2_size, output_size)

        self.activation = nn.Sigmoid()

        # Most recent pre-sigmoid layer outputs, filled in by the hooks below.
        # Initialized so they can be inspected before any forward pass.
        # (Plain attributes, not buffers: they are not part of state_dict.)
        self.ac1 = None
        self.ac2 = None
        self.aco = None

        self.fc1.register_forward_hook(self._make_capture_hook('ac1'))
        self.fc2.register_forward_hook(self._make_capture_hook('ac2'))
        self.fco.register_forward_hook(self._make_capture_hook('aco'))

    def _make_capture_hook(self, attr):
        '''Return a forward hook that stores the module's detached output in ``self.<attr>``.'''
        def hook(module, args, output):
            setattr(self, attr, output.detach())
        return hook

    def forward(self, x):
        '''
        Flatten ``x`` to ``(-1, input_size)`` and return the sigmoid outputs
        of the final layer, shape ``(batch, output_size)``.
        '''
        # Use the configured input width rather than a hard-coded 28 * 28.
        x = x.view(-1, self.fc1.in_features)
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        # NOTE(review): a sigmoid on the output layer means nn.CrossEntropyLoss
        # receives values in (0, 1) instead of raw logits. Kept as-is because
        # the saved checkpoint was presumably trained with this head.
        x = self.activation(self.fco(x))
        return x

    def save(self, path):
        '''Write this model's ``state_dict`` to ``path`` via ``torch.save``.'''
        torch.save(self.state_dict(), path)

    def load(self, path, map_location='cpu'):
        '''
        Load a ``state_dict`` from ``path`` into this model and switch to eval mode.

        Args:
            path: checkpoint file written by ``save``.
            map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so
                checkpoints saved on a GPU can be read on CPU-only machines;
                ``load_state_dict`` then copies the tensors onto the device
                this model's parameters already live on.
        '''
        # torch.load unpickles the file: only load checkpoints you trust.
        self.load_state_dict(torch.load(path, map_location=map_location))
        self.eval()

In [9]:
class Trainer:
    '''
    Convenience class that runs training loops for a given model,
    loss criterion and optimizer.

    Args:
        net: the ``nn.Module`` to train.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss tensor.
        optimizer: optimizer constructed over ``net``'s parameters.
    '''
    def __init__(self, net, criterion, optimizer) -> None:
        self.net = net
        self.criterion = criterion
        self.optimizer = optimizer

    def train_step(self, inputs, labels):
        '''
        Run one optimization step on a single batch.

        Returns:
            The loss tensor computed for this batch (before the update).
        '''
        # Clear gradients accumulated by the previous step.
        self.optimizer.zero_grad()

        # forward + backward + optimize
        outputs = self.net(inputs)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()
        return loss

    def train_epoch(self, trainloader):
        '''
        Train for one pass over ``trainloader`` (any iterable of
        ``(inputs, labels)`` pairs).

        Returns:
            The mean of the per-batch losses, or 0.0 if no batches were yielded.
        '''
        # Make sure training-only layers (dropout, batch norm) are active;
        # a model restored for inference may have been left in eval mode.
        self.net.train()
        total_loss = 0.0
        num_batches = 0
        for inputs, labels in trainloader:
            total_loss += self.train_step(inputs, labels).item()
            num_batches += 1
        # Average over the number of batches: each value added above is
        # already one batch's loss, so dividing by batch_size was wrong.
        return total_loss / num_batches if num_batches else 0.0

    def full_train(self, num_epocs, trainloader):
        '''
        Train for ``num_epocs`` epochs over ``trainloader``.

        Returns:
            The mean batch loss of the final epoch, or -1 if ``num_epocs`` < 1.
        '''
        last_loss = -1
        for _ in range(num_epocs):
            last_loss = self.train_epoch(trainloader)
        return last_loss

In [14]:
# Network architecture: must match the checkpoint stored at MODEL,
# otherwise load_state_dict will reject it.
INPUT_SIZE = 28 * 28
HIDDEN1_SIZE = 16
HIDDEN2_SIZE = 9
OUTPUT_SIZE = 10

MODEL = 'data/mnist_89.pth'
DATA_DIR = 'data'
BATCH_SIZE = 4
LEARNING_RATE = 0.0001
SEED = 0

# Seed torch's RNG so data shuffling and any further training are reproducible.
torch.manual_seed(SEED)

net = MNISTNet(INPUT_SIZE, HIDDEN1_SIZE, HIDDEN2_SIZE, OUTPUT_SIZE)
net.load(MODEL)
data_set = MNISTDataSet(DATA_DIR, batch_size=BATCH_SIZE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

# ---------------------------------------------------
# TODO: Modify to reconstruct the desired digit.
# ---------------------------------------------------

#trainer = Trainer(net, criterion, optimizer)
#last_loss = trainer.full_train(1, data_set.trainloader)
#print("Loss after training is", last_loss)