In [1]:
"""
testing.ipynb

Notebook of sanity checks for the lottery ticket experiment harness: model creation and training, loading saved weights/masks, and pruning.

Authors: Jordan Bourdeau, Casey Forey
Date Created: 3/8/24
"""

%load_ext tensorboard
import functools
import numpy as np
import os
import tensorflow as tf

from src.harness.dataset import download_data, load_and_process_mnist
from src.harness.model import create_model, LeNet300, load_model
from src.harness.pruning import prune_by_percent
from src.harness.training import train
from src.lottery_ticket.foundations import paths

In [2]:
# Build model 0 and capture its weights before any training happens.
X_train, Y_train, X_test, Y_test = load_and_process_mnist()
model = create_model(0, X_train, Y_train)
initial_weights1: dict[str, np.ndarray] = model.get_current_weights()

# Train for 3 iterations with plain SGD (learning rate 0.01).
optimizer = tf.keras.optimizers.legacy.SGD(.01)
initial_weights2, final_weights = train(load_and_process_mnist, model, optimizer, 3)

# Sanity check that the initial weights reported by `train` match the weights
# captured directly from the freshly created model (keys layer0..layer2).
for i in range(3):
    key: str = f'layer{i}'
    assert np.array_equal(initial_weights1[key], initial_weights2[key]), \
        f'Initial weights differ for {key}'

Iteration 1/3, Loss: 2.3580663204193115
Iteration 2/3, Loss: 2.349560499191284
Iteration 3/3, Loss: 2.341191291809082


In [3]:
MODEL_INDEX: int = 0
NUM_LAYERS: int = 3  # weight/mask files are named layer0..layer{NUM_LAYERS - 1}

# Directory holding the saved initial weights and masks for this model.
# (Named `initial_dir` rather than `dir` to avoid shadowing the builtin.)
initial_dir: str = f'models/model_{MODEL_INDEX}/initial/'
weight_files = [os.path.join(paths.weights(initial_dir), f'layer{i}.npy') for i in range(NUM_LAYERS)]
mask_files = [os.path.join(paths.masks(initial_dir), f'layer{i}.npy') for i in range(NUM_LAYERS)]

layer_weights: dict[str, np.ndarray] = {f'layer{i}': np.load(path) for i, path in enumerate(weight_files)}
masks: dict[str, np.ndarray] = {f'layer{i}': np.load(path) for i, path in enumerate(mask_files)}

# Test loading a model.
# NOTE(review): the meaning of the `0, True` arguments is defined in
# src.harness.model.load_model — confirm against that signature.
model: LeNet300 = load_model(MODEL_INDEX, 0, True)

for i in range(NUM_LAYERS):
    key: str = f'layer{i}'
    # The loaded model's weights must match the .npy files on disk.
    assert np.array_equal(model.weights[key], layer_weights[key]), f'Weights mismatch for {key}'
    # Masks must be all ones. `np.all(mask == 1)` is stricter than comparing the
    # sum to the size, which e.g. a mask of [0, 2] would also satisfy.
    assert np.all(masks[key] == 1), f'Mask for {key} is not all ones'


In [4]:
# Test pruning: request a 50% prune of every layer and check that about
# (1 - PRUNE_PERCENT) of the previously-unmasked weights survive.
PRUNE_PERCENT: float = 0.5
print([(key, layer.shape) for key, layer in layer_weights.items()])
percents: dict[str, float] = {key: PRUNE_PERCENT for key in layer_weights}
new_masks: dict[str, np.ndarray] = prune_by_percent(percents, masks, layer_weights)
for key, new_mask in new_masks.items():
    old_mask: np.ndarray = masks[key]
    expected_remaining = old_mask.sum() * (1 - PRUNE_PERCENT)
    # abs(): over-pruning must fail too, not only under-pruning; without it a
    # mask that zeroed everything would pass this check.
    assert abs(expected_remaining - new_mask.sum()) <= 1, f'Doesn\'t match for key {key}'

[('layer0', (784, 300)), ('layer1', (300, 100)), ('layer2', (100, 10))]
