In [1]:
import numpy as np
from numpy import ndarray
import plotly.express as px

from optimizers import Adam
from losses import BinaryCrossentropy
from metrics import accuracy, gradient_stats, nn_params_stats, activations_stats
from cifar_10_dataset_loading import load_cifar_10
from layers import Convolutional, Linear, Relu, BatchNorm, Flatten, Softmax, MaxPool, Layer

## Setup

### Data extraction

In [2]:
x_train, y_train, x_test, y_test = load_cifar_10()

In [3]:
classes = y_train.argmax(axis=1)

In [4]:
# Change range from [0, 255] to [-1, 1]
x_train = x_train / 255  
x_train.dtype

dtype('float64')

In [5]:
IMGS_IDX = [0, 351, 5673, 5494, 32, 55, 66, 776, 564]
x = x_train[IMGS_IDX]
y = y_train[IMGS_IDX]
px.imshow(x, facet_col=0, facet_col_wrap=4)

### Model declaration

In [6]:
def create_nn() -> list[Layer]:
    return [
        Convolutional((10, 5, 5, 3), 0.005),
        BatchNorm(),
        Relu(),
        MaxPool((2, 2)),
        Convolutional((10, 3, 3, 10), 0.005),
        BatchNorm(),
        Relu(),
        MaxPool((2, 2)),
        Flatten(),
        Linear(360, 64),
        Relu(),
        Linear(64, y.shape[1]),
        Softmax(),
    ]

## Training

In [7]:
training_stats = (
    Adam(
        create_nn(),
        x,
        y,
        BinaryCrossentropy(),
        starting_lr=0.025,
        lr_decay=0.0005,
        momentum_weight=0.9,
        ada_grad_weight=0.9,
    )
    .optimize_nn(
        epochs=10,
        batch_size=128,
        plt_x="epoch",
        plt_ys=["loss", "accuracy"]
    )
)

FigureWidget({
    'data': [{'hovertemplate': 'variable=loss<br>epoch=%{x}<br>value=%{y}<extra></extra>',
              'legendgroup': 'loss',
              'marker': {'color': '#636efa', 'symbol': 'circle'},
              'mode': 'markers',
              'name': 'loss',
              'orientation': 'v',
              'showlegend': True,
              'type': 'scatter',
              'uid': '50a6c05c-e678-47ad-877f-3da30bc2f1d4',
              'x': array([0]),
              'xaxis': 'x2',
              'y': array([0.32508328]),
              'yaxis': 'y2'},
             {'hovertemplate': 'variable=accuracy<br>epoch=%{x}<br>value=%{y}<extra></extra>',
              'legendgroup': 'accuracy',
              'marker': {'color': '#EF553B', 'symbol': 'circle'},
              'mode': 'markers',
              'name': 'accuracy',
              'orientation': 'v',
              'showlegend': True,
              'type': 'scatter',
              'uid': '1d323453-a0e1-4f55-a4b3-f08358bd236b',
     

In [8]:
N_SAMPLES_PER_CLASS = 10
NB_CLASSES = 10

class_masks = y_train.argmax(1, keepdims=True) == np.arange(NB_CLASSES)
take_first_n_of_class = lambda data, class_idx: data[class_masks[:, class_idx]][:N_SAMPLES_PER_CLASS]
take_first_n_of_each_class = lambda data: np.concatenate([take_first_n_of_class(data, class_idx) for class_idx in range(NB_CLASSES)], axis=0)
even_x_train = take_first_n_of_each_class(x_train)
even_y_train = take_first_n_of_each_class(y_train)
even_x_train.shape

(100, 32, 32, 3)

In [9]:
bad_nn = create_nn()
bad_optimizer = Adam(
    bad_nn,
    even_x_train,
    even_y_train,
    BinaryCrossentropy(),
    starting_lr=0.015,
    lr_decay=0.0001,
    momentum_weight=0.99,
    ada_grad_weight=0.99,
)

In [10]:
bad_training_stats = (
    bad_optimizer
    .optimize_nn(
        epochs=50,
        batch_size=50,
        metrics=[accuracy, nn_params_stats, activations_stats],
        plt_x="epoch",
        plt_ys=[
            "loss", 
            "accuracy",
            "Convolutional_0_kernels_mean",
            "Convolutional_0_kernels_std",
            # "Convolutional_0_biases_mean",
            # "Convolutional_0_biases_std",
            "Convolutional_0_kernels_l1",
            "Convolutional_0_kernels_l2",
            # "Convolutional_0_activation_l1",
            "learning_rate",
        ],
        height=850,
    )
)

FigureWidget({
    'data': [{'hovertemplate': 'variable=loss<br>epoch=%{x}<br>value=%{y}<extra></extra>',
              'legendgroup': 'loss',
              'marker': {'color': '#636efa', 'symbol': 'circle'},
              'mode': 'markers',
              'name': 'loss',
              'orientation': 'v',
              'showlegend': True,
              'type': 'scatter',
              'uid': 'c31ae49f-624c-446f-a397-fc841a0009dc',
              'x': array([0]),
              'xaxis': 'x7',
              'y': array([0.32508214]),
              'yaxis': 'y7'},
             {'hovertemplate': 'variable=accuracy<br>epoch=%{x}<br>value=%{y}<extra></extra>',
              'legendgroup': 'accuracy',
              'marker': {'color': '#EF553B', 'symbol': 'circle'},
              'mode': 'markers',
              'name': 'accuracy',
              'orientation': 'v',
              'showlegend': True,
              'type': 'scatter',
              'uid': 'ba959d77-407c-42fc-888b-75cd39ce3fce',
     

In [12]:
from copy import deepcopy

best_accuracy = 0
try:
    iteration = 0
    while True:
        starting_nn = create_nn()
        nn = deepcopy(starting_nn)
        optimizer = Adam(
            nn,
            even_x_train,   
            even_y_train,
            BinaryCrossentropy(),
            starting_lr=0.025,
            lr_decay=0.0005,
            momentum_weight=0.9,
            ada_grad_weight=0.9,
        )
        training_stats = (
            optimizer
            .optimize_nn(
                epochs=50,
                batch_size=100,
                metrics=[accuracy, nn_params_stats, activations_stats],
            )
        )
        print(iteration, "accuracy:", training_stats["accuracy"].max())
        if training_stats["accuracy"].max() > best_accuracy:
            best_starting_nn, best_nn, best_optimizer, best_training_stats = starting_nn, nn, optimizer, training_stats
            best_accuracy = training_stats["accuracy"].max()
        if training_stats["accuracy"].max() > 50:
            break
        iteration += 1
except KeyboardInterrupt:
    print()
print("Best accuracy:", best_accuracy)

0 accuracy: 0.6
1 accuracy: 0.78
2 accuracy: 0.91
3 accuracy: 0.98
4 accuracy: 0.98
5 accuracy: 1.0
6 accuracy: 0.71
7 accuracy: 0.91
8 accuracy: 0.92
9 accuracy: 0.93
10 accuracy: 0.93

Best accuracy: 1.0
