TODO:

In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import seaborn as sns
from tqdm.notebook import tqdm
from tqdm import trange

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets, transforms
from torchvision.datasets import FashionMNIST

from boring_utils.utils import *

%matplotlib inline 
# Notebook-wide setup. These helpers come from the `boring_utils.utils` star import;
# their exact behaviour is not visible in this notebook — TODO confirm against that package.
init_graph()  # presumably configures plotting defaults
device = get_device()  # used below for `.to(device)` and as `map_location` when loading checkpoints
set_seed(42, strict=True)  # presumably seeds the RNGs for reproducibility; meaning of `strict` unverified

In [2]:
# Root directory handed to the FashionMNIST constructors (they are called with download=True).
DATASET_PATH = "../data"
# Directory for model checkpoints; created eagerly if it does not exist yet.
CHECKPOINT_PATH = "../model/optm_func/"
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

In [3]:
# https://pytorch.org/vision/main/generated/torchvision.transforms.Normalize.html
# transforms.Normalize(mean, std)
# `transform`: convert to a tensor, then normalize with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
# The preview cells below report min -1.0 / max 1.0 for data loaded through it.
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,))
])

# `transform_b`: tensor conversion only (no normalization); used for the un-normalized
# comparison loader built in a later cell.
transform_b = transforms.Compose([
                transforms.ToTensor(),
])

# Official FashionMNIST training split, with the normalizing transform.
train_set = FashionMNIST(
    root=DATASET_PATH, 
    train=True, 
    download=True, 
    transform=transform
)

# Official FashionMNIST test split, with the same normalizing transform.
test_set = FashionMNIST(
    root=DATASET_PATH, 
    train=False, 
    download=True, 
    transform=transform
)

# Carve a 10,000-sample validation set out of the training split (50,000 remain for training).
# No explicit generator is passed, so the split draws from the global torch RNG state.
train_set, val_set = torch.utils.data.random_split(train_set, [50000, 10000])

train_loader = DataLoader(
    train_set, batch_size=1024, shuffle=True, drop_last=False)

# NOTE(review): the validation and test loaders also shuffle. Accuracy does not depend on
# batch order, so shuffle=False would normally be expected here — confirm this is intentional.
val_loader = DataLoader(
    val_set, batch_size=1024, shuffle=True, drop_last=False)

test_loader = DataLoader(
    test_set, batch_size=1024, shuffle=True, drop_last=False)

In [4]:
# Same FashionMNIST test split as `test_set`, but loaded through `transform_b`
# (ToTensor only, no Normalize) so raw and normalized statistics can be compared.
test_set_no_trans = FashionMNIST(
    root=DATASET_PATH, 
    train=False, 
    download=True, 
    transform=transform_b
)

# Loader settings mirror `test_loader` (batch size 1024, shuffled, last partial batch kept).
test_loader_no_trans = DataLoader(
    test_set_no_trans, batch_size=1024, shuffle=True, drop_last=False)

Batch Preview

In [5]:
def print_data(dataset, data_loader):
    """
    Print summary statistics of a dataset before and after its transform.

    First reports the mean and std of `dataset.data` scaled by 1/255 (the stored
    images, bypassing any transform), then the mean/std and max/min of the first
    batch drawn from `data_loader` (images after the dataset's transform).

    Inputs:
        dataset - Dataset object exposing a `.data` tensor (assumed to hold 0-255
                  pixel values — TODO confirm); a Subset from random_split would
                  presumably not work here since `.data` is read directly
        data_loader - DataLoader yielding (images, labels) batches for that dataset

    Note: `cprint` comes from the `boring_utils` star import; judging by the stored
    output it labels every value with the literal argument expression, so the
    expressions below should not be refactored into temporaries.
    """
    # raw data, untransformed
    cprint((dataset.data.float() / 255.0).mean().item())
    cprint((dataset.data.float() / 255.0).std().item())

    # transformed data
    # Only the first batch is inspected; with a shuffling loader the exact numbers
    # depend on which samples land in that batch.
    imgs, _ = next(iter(data_loader))
    cprint(imgs.mean().item(), imgs.std().item())
    cprint(imgs.max().item(), imgs.min().item())

In [6]:
# Statistics for the test split loaded with Normalize((0.5,), (0.5,)).
print_data(test_set, test_loader)

[93mprint_data -> (dataset.data.float() / 255.0).mean().item():[0m
0.2868492901325226
[93mprint_data -> (dataset.data.float() / 255.0).std().item():[0m
0.3524441719055176
[93mprint_data -> imgs.mean().item():[0m
-0.42312583327293396
[93mprint_data -> imgs.std().item():[0m
0.7069889307022095
[93mprint_data -> imgs.max().item():[0m
1.0
[93mprint_data -> imgs.min().item():[0m
-1.0


In [7]:
# Same statistics for the test split loaded with ToTensor only (no normalization).
print_data(test_set_no_trans, test_loader_no_trans)

[93mprint_data -> (dataset.data.float() / 255.0).mean().item():[0m
0.2868492901325226
[93mprint_data -> (dataset.data.float() / 255.0).std().item():[0m
0.3524441719055176
[93mprint_data -> imgs.mean().item():[0m
0.2812195122241974
[93mprint_data -> imgs.std().item():[0m
0.34957683086395264
[93mprint_data -> imgs.max().item():[0m
1.0
[93mprint_data -> imgs.min().item():[0m
0.0


Define activation functions

In [8]:
# Registry mapping an activation name (str) to its nn.Module class (not an instance);
# each class below registers itself right after its definition.
act_fn_by_name = {}

class Tanh(nn.Module):
    '''
    Hand-written hyperbolic tangent activation.

    https://pytorch.org/docs/master/generated/torch.nn.Tanh.html#torch.nn.Tanh

    The textbook form (e^x - e^-x) / (e^x + e^-x) overflows once |x| is large
    (exp(x) becomes inf and inf / inf is NaN). This implementation instead uses

        tanh(|x|) = (1 - e^(-2|x|)) / (1 + e^(-2|x|))

    and restores the sign afterwards, so the exponential only ever sees
    non-positive arguments and stays within (0, 1].
    '''
    def forward(self, x):
        """
        Apply tanh element-wise.

        Inputs:
            x - Input tensor of any shape
        Returns:
            Tensor of the same shape with values in [-1, 1]
        """
        non_negative = x >= 0
        # |x| expressed through torch.where (rather than abs) so the gradient
        # at exactly x == 0 follows the `x` branch and comes out as 1.
        magnitude = torch.where(non_negative, x, -x)
        decay = torch.exp(-2 * magnitude)
        saturating = (1 - decay) / (1 + decay)
        # Both branches are finite, so no NaN can leak through the selection.
        return torch.where(non_negative, saturating, -saturating)

act_fn_by_name['tanh'] = Tanh


class ReLU(nn.Module):
    '''
    Hand-written rectified linear unit: max(0, x) applied element-wise.

    https://pytorch.org/docs/master/generated/torch.nn.ReLU.html#torch.nn.ReLU
    '''
    def forward(self, x):
        """
        Inputs:
            x - Input tensor of any shape
        Returns:
            Tensor of the same shape and dtype with non-positive entries set to 0
        """
        # Select instead of multiplying by a 0/1 mask: `-inf * 0` is NaN, and a
        # float mask would silently change the dtype of half/integer inputs.
        # NaN inputs fail the `<= 0` test and therefore propagate unchanged.
        return torch.where(x <= 0, torch.zeros_like(x), x)

act_fn_by_name['relu'] = ReLU


class LeakyReLU(nn.Module):
    '''
    Leaky ReLU: identity for positive inputs, a small linear slope otherwise.

    https://pytorch.org/docs/master/generated/torch.nn.LeakyReLU.html#torch.nn.LeakyReLU
    '''
    def __init__(self, negative_slope=0.1):
        """
        Inputs:
            negative_slope - Multiplier applied to non-positive inputs
        """
        super().__init__()
        self.neg_slop = negative_slope

    def forward(self, x):
        """Return x where x > 0, and neg_slop * x everywhere else."""
        is_positive = x > 0
        damped = self.neg_slop * x
        return torch.where(is_positive, x, damped)

act_fn_by_name['leakyrelu'] = LeakyReLU


class Identity(nn.Module):
    '''
    Pass-through activation: returns its input untouched.
    '''
    def forward(self, x):
        # No computation and no copy: the very same tensor object is handed back.
        passthrough = x
        return passthrough

act_fn_by_name['identity'] = Identity

# Neural network (fully connected classifier)

In [13]:
class BaseNN(nn.Module):
    """
    Fully connected classifier: a stack of Linear layers, each followed by the
    given activation, and a final Linear layer producing the class logits.

    Inputs:
        act_fn - Activation to insert after every hidden Linear layer. Either an
                 nn.Module instance or an nn.Module subclass (instantiated with
                 no arguments). The same instance is reused after every layer.
        input_size - Number of features per sample after flattening
        hidden_sizes - Iterable with the width of each hidden layer;
                       defaults to [512, 256, 256, 128]
        num_classes - Number of output logits
    """
    def __init__(self, act_fn, input_size=784, hidden_sizes=None, num_classes=10):
        super().__init__()

        # A list literal as default argument would be shared by every instance
        # (and stored on self), so build a fresh list per instance instead.
        # Copying also decouples us from later mutation of the caller's list.
        if hidden_sizes is None:
            hidden_sizes = [512, 256, 256, 128]
        else:
            hidden_sizes = list(hidden_sizes)

        # Accept an activation class (e.g. an entry of a name -> class registry)
        # as well as a ready-made module instance.
        if isinstance(act_fn, type) and issubclass(act_fn, nn.Module):
            act_fn = act_fn()

        self.act_fn = act_fn
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.num_classes = num_classes

        # Create the network based on the specified hidden sizes
        layers = []
        layer_sizes = [input_size] + hidden_sizes
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers += [nn.Linear(fan_in, fan_out), self.act_fn]
        layers += [nn.Linear(layer_sizes[-1], num_classes)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """
        Inputs:
            x - Tensor whose elements can be regrouped into rows of `input_size`
        Returns:
            Tensor of shape (rows, num_classes) with the class logits
        """
        # Flatten each sample to a vector. reshape (unlike view) also copes with
        # non-contiguous inputs such as permuted or sliced tensors.
        x = x.reshape(-1, self.input_size)
        x = self.layers(x)
        return x
        


In [14]:
_get_file_name = lambda model_path, model_name, extension='.tar': os.path.join(model_path, model_name + extension)


def load_model(model_path, model_name, act_fn, net=None, **kargs):
    """
    Restore network weights from a checkpoint on disk.

    Inputs:
        model_path - Path of the checkpoint directory
        model_name - Name of the model (str)
        act_fn - Activation passed to BaseNN when a new network has to be built;
                 not used when `net` is supplied
        net - Optional existing network to load the weights into
        kargs - Extra keyword arguments forwarded to BaseNN when `net` is None
    Returns:
        The network with the stored state_dict loaded
    """
    checkpoint_file = _get_file_name(model_path, model_name)
    # Build a fresh network only when the caller did not hand one in.
    net = BaseNN(act_fn=act_fn, **kargs) if net is None else net
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(checkpoint_file, map_location=device)
    net.load_state_dict(state_dict)
    return net


def save_model(model, model_path, model_name):
    """
    Write a model's state_dict to the checkpoint directory.

    Only the state_dict is stored; constructor hyperparameters are not saved.

    Inputs:
        model - Network object to save parameters from
        model_path - Path of the checkpoint directory (created if missing)
        model_name - Name of the model (str)
    """
    os.makedirs(model_path, exist_ok=True)
    weights = model.state_dict()
    torch.save(weights, _get_file_name(model_path, model_name))


def test_model(net, data_loader):
    """
    Compute the classification accuracy of a model on a dataset.

    The network is switched to eval mode and left in that mode afterwards.

    Inputs:
        net - Trained model (e.g. a BaseNN) returning one row of logits per sample
        data_loader - DataLoader object of the dataset to test on (validation or test)
    Returns:
        Fraction of samples whose arg-max prediction equals the label
    """
    net.eval()
    num_correct = 0.
    num_seen = 0
    # Gradients are never needed for evaluation, so disable tracking for the whole pass.
    with torch.no_grad():
        for imgs, labels in data_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            predicted = net(imgs).argmax(dim=-1)
            num_correct += (predicted == labels).sum().item()
            num_seen += labels.shape[0]
    return num_correct / num_seen