In [1]:
"""
testing.ipynb

Notebook for testing the components used to implement lottery ticket experiments.

Authors: Jordan Bourdeau, Casey Forey
Date Created: 3/8/24
"""

%load_ext tensorboard
import functools
import numpy as np
import os
import tensorflow as tf

from src.harness import constants as C
from src.harness.dataset import download_data, load_and_process_mnist
from src.harness.experiment import experiment
from src.harness.model import create_model, LeNet300, load_model
from src.harness.pruning import prune_by_percent
from src.harness.training import train
from src.lottery_ticket.foundations import paths

In [2]:
# Create a model
# Load the MNIST splits once; later cells reuse X_train / Y_train when rebuilding models.
X_train, Y_train, X_test, Y_test = load_and_process_mnist()
# NOTE(review): the leading 0 is presumably the random seed / model index - confirm against create_model.
model = create_model(0, X_train, Y_train)
# Keep the untrained weights (keyed by layer name) so the training cell can compare against them.
initial_weights1: dict[str, np.ndarray] = model.get_current_weights()

In [3]:
# Run a short training pass and confirm that it actually changes the weights
make_dataset: callable = load_and_process_mnist
optimizer = C.OPTIMIZER()
initial_weights2, final_weights = train(
    make_dataset,
    model,
    0,
    optimizer,
    C.TEST_TRAINING_ITERATIONS,
)

for i in range(3):
    key: str = f'layer{i}'
    # The initial weights reported by train() must be the ones captured when the model was created
    starts_from_created_weights = np.array_equal(initial_weights1[key], initial_weights2[key])
    assert starts_from_created_weights
    # Training must have moved this layer away from its starting point
    layer_unchanged = np.array_equal(initial_weights2[key], final_weights[key])
    assert not layer_unchanged


Iteration 1/10, Loss: 2.3580663204193115
Iteration 2/10, Loss: 2.2764296531677246
Iteration 3/10, Loss: 2.2057595252990723
Iteration 4/10, Loss: 2.1410763263702393
Iteration 5/10, Loss: 2.079235076904297
Iteration 6/10, Loss: 2.018141984939575
Iteration 7/10, Loss: 1.9564679861068726
Iteration 8/10, Loss: 1.8935106992721558
Iteration 9/10, Loss: 1.829115867614746
Iteration 10/10, Loss: 1.763397455215454


In [4]:
# Try creating a model from the initial weights as a preset, prune it, and verify that
# a further round of training leaves the pruned (masked-off) weights untouched.
make_model: callable = functools.partial(LeNet300, 0)
# Start from all-ones masks (nothing pruned) with the same shape as each layer's weights
starting_masks: dict[str, np.ndarray] = {
    f'layer{i}': np.ones(initial_weights2[f'layer{i}'].shape) for i in range(3)
}

# Create pruned masks
masks = prune_by_percent(C.PRUNING_PERCENTS, starting_masks, final_weights)

# Open question from the original author: passing `masks` here seems to turn the initial
# weights into tensors when they should stay NumPy arrays.
model = make_model(X_train, Y_train, presets=initial_weights2, masks=masks)

# Sanity check that the weights are correctly loaded and are masked off accordingly
for i in range(3):
    key = f'layer{i}'
    layer_weights = model.get_current_weights()[key]
    layer_mask = model.masks[key]
    expected_weights: np.ndarray = initial_weights2[key] * layer_mask
    assert np.array_equal(expected_weights, layer_weights), f'Expected {expected_weights} but received {layer_weights}'
    assert np.array_equal(model.masks[key], masks[key])

# Save the actual weight values (these include the masked off weights) as NumPy copies.
# A plain `model.weights.copy()` is only a shallow copy: if the values are variables that
# training updates in place (presumably the case - TODO confirm), the "before" and "after"
# entries would be the same objects and the comparison below could never fail.
pretrained_weights: dict[str, np.ndarray] = {
    f'layer{i}': np.array(model.weights[f'layer{i}']) for i in range(3)
}

# Try doing a simulated round of pruning
initial_weights3, final_weights2 = train(make_dataset, model, 1, optimizer, C.TEST_TRAINING_ITERATIONS)

# Make sure the masked off weights don't receive any updates in the actual tensorflow tensor
trained_weights = model.weights

# Compare the masked weights before and after training
for i in range(3):
    key = f'layer{i}'
    pretrained_layer_weights = pretrained_weights[key]
    trained_layer_weights = np.asarray(trained_weights[key])

    # Invert the mask, to only look at the weights which WERE masked off
    inverted_mask: np.ndarray = 1 - masks[key]
    masked_pretrained_weights = pretrained_layer_weights * inverted_mask
    masked_trained_weights = trained_layer_weights * inverted_mask

    # Assert that the masked weights remain unchanged after training
    assert np.array_equal(masked_pretrained_weights, masked_trained_weights), f'Weights changed after training for layer {key}'
    # Compare per layer: calling np.array_equal on the two whole dicts does not compare the arrays
    assert not np.array_equal(initial_weights3[key], final_weights2[key]), f'Training did not change the weights for layer {key}'


Iteration 1/10, Loss: 2.2923543453216553
Iteration 2/10, Loss: 2.219968795776367
Iteration 3/10, Loss: 2.154776096343994
Iteration 4/10, Loss: 2.0936989784240723
Iteration 5/10, Loss: 2.0346004962921143
Iteration 6/10, Loss: 1.9759339094161987
Iteration 7/10, Loss: 1.9168710708618164
Iteration 8/10, Loss: 1.8570656776428223
Iteration 9/10, Loss: 1.7964187860488892
Iteration 10/10, Loss: 1.7350682020187378


In [5]:
# Load a saved model's initial weights and masks straight from disk, then check that
# load_model() returns the same weights.
MODEL_INDEX: int = 0
# Get initial weights
dir: str = f'models/model_{MODEL_INDEX}/initial/'

weight_files = []
for i in range(3):
    weight_files.append(paths.weights(dir) + f'/layer{i}.npy')
mask_files = []
for i in range(3):
    mask_files.append(paths.masks(dir) + f'/layer{i}.npy')

# Read each layer's .npy file into a dict keyed by layer name
layer_weights = {}
for i, npy_file in enumerate(weight_files):
    layer_weights[f'layer{i}'] = np.load(npy_file)
masks = {}
for i, npy_file in enumerate(mask_files):
    masks[f'layer{i}'] = np.load(npy_file)

# Test loading a model
# NOTE(review): the meaning of the `0` and `True` arguments is not visible here - confirm against load_model.
model: LeNet300 = load_model(MODEL_INDEX, 0, True)

for i in range(3):
    key: str = f'layer{i}'
    # The loaded model must carry exactly the weights stored on disk
    assert np.array_equal(model.weights[key], layer_weights[key])
    # Sum equal to element count means every entry of a 0/1 mask is still 1
    saved_mask = masks[key]
    assert np.sum(saved_mask) == saved_mask.size


In [6]:
# Test pruning: after pruning a fixed fraction of every layer, the number of surviving
# mask entries must match the expected count.
print([(key, layer.shape) for key, layer in layer_weights.items()])
# Single source of truth for the fraction pruned, used both to prune and to compute the expectation
prune_fraction: float = 0.5
percents: dict[str, float] = {key: prune_fraction for key in layer_weights}
new_masks: dict[str, np.ndarray] = prune_by_percent(percents, masks, layer_weights)
for key in new_masks:
    new_mask: np.ndarray = new_masks[key]
    old_mask: np.ndarray = masks[key]
    expected_remaining: float = old_mask.sum() * (1 - prune_fraction)
    # Two-sided tolerance of one entry (rounding at the cutoff). A one-sided check would
    # also pass when too little, or nothing at all, had been pruned.
    assert abs(expected_remaining - new_mask.sum()) <= 1, f'Doesn\'t match for key {key}'

[('layer0', (784, 300)), ('layer1', (300, 100)), ('layer2', (100, 10))]


In [7]:
# Run the full experiment loop end to end on the small test configuration
# Each callable handed to experiment() is prepared up front:
make_dataset: callable = load_and_process_mnist
# - model factory with its random seed already bound
make_model: callable = functools.partial(LeNet300, 0)
# - pruning step with the per-layer pruning percentages already bound
prune_masks: callable = functools.partial(prune_by_percent, C.PRUNING_PERCENTS)
# - training step limited to the test iteration count
train_model: callable = functools.partial(train, iterations=C.TEST_TRAINING_ITERATIONS)
experiment(
    make_dataset,
    make_model,
    train_model,
    prune_masks,
    C.TEST_PRUNING_STEPS,
)

Pruning Step 0
Iteration 1/10, Loss: 2.3972244262695312
Iteration 2/10, Loss: 2.298975706100464
Iteration 3/10, Loss: 2.2233521938323975
Iteration 4/10, Loss: 2.155620574951172
Iteration 5/10, Loss: 2.090500593185425
Iteration 6/10, Loss: 2.025655746459961
Iteration 7/10, Loss: 1.9601891040802002
Iteration 8/10, Loss: 1.8937861919403076
Iteration 9/10, Loss: 1.8262752294540405
Iteration 10/10, Loss: 1.7576806545257568
Pruning Step 1
Iteration 1/10, Loss: 2.328902006149292
Iteration 2/10, Loss: 2.2418630123138428
Iteration 3/10, Loss: 2.170030117034912
Iteration 4/10, Loss: 2.1039531230926514
Iteration 5/10, Loss: 2.0400655269622803
Iteration 6/10, Loss: 1.9768481254577637
Iteration 7/10, Loss: 1.9134869575500488
Iteration 8/10, Loss: 1.8495373725891113
Iteration 9/10, Loss: 1.784929633140564
Iteration 10/10, Loss: 1.7198681831359863
Pruning Step 2
Iteration 1/10, Loss: 2.2701938152313232
Iteration 2/10, Loss: 2.1930620670318604
Iteration 3/10, Loss: 2.1257569789886475
Iteration 4/10, L