In [7]:
"""
main.ipynb

Main file for recreating lottery ticket experiments done in randomly initialized dense neural networks.

Authors: Jordan Bourdeau, Casey Forey
Date Created: 3/8/24
"""

%load_ext tensorboard
import functools
import numpy as np
import os
import tensorflow as tf

from src.harness import constants as C
from src.harness.dataset import download_data, load_and_process_mnist
from src.harness.experiment import experiment
from src.harness.model import create_model, LeNet300, load_model
from src.harness.pruning import prune_by_percent
from src.harness.training import train
from src.lottery_ticket.foundations import paths

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [8]:
def run_experiment(num_models=None, training_iterations=None, pruning_steps=None) -> None:
    """
    Function used to run the full lottery ticket experiment.

    Builds a dataset loader, a training function and a pruning function once,
    then calls `experiment` once per model, each with its own LeNet300 factory.

    Every argument defaults to `None`, which means "use the value from the
    constants module `C`". The default is looked up at call time, so edits to
    `C` are picked up without redefining this function. Calling with no
    arguments reproduces the original (test-sized) configuration.

    Args:
        num_models: Number of models to run the experiment for.
            Defaults to `C.NUM_MODELS`.
        training_iterations: Value bound to `train`'s `iterations` keyword.
            Defaults to `C.TEST_TRAINING_ITERATIONS`.
        pruning_steps: Last positional argument passed to `experiment`
            (presumably the number of pruning steps; TODO confirm against
            `experiment`'s signature). Defaults to `C.TEST_PRUNING_STEPS`.
    """
    # Local import: only needed for the local annotations below, and keeps this
    # cell self-contained without touching the notebook's import cell.
    from typing import Callable

    # Resolve defaults lazily so the current contents of `C` are used.
    if num_models is None:
        num_models = C.NUM_MODELS
    if training_iterations is None:
        training_iterations = C.TEST_TRAINING_ITERATIONS
    if pruning_steps is None:
        pruning_steps = C.TEST_PRUNING_STEPS

    make_dataset: Callable = load_and_process_mnist
    train_model: Callable = functools.partial(train, iterations=training_iterations)
    # Binds the pruning percentages as prune_by_percent's first positional argument.
    prune_masks: Callable = functools.partial(prune_by_percent, C.PRUNING_PERCENTS)

    for i in range(num_models):
        print(f'Model {i + 1}')
        # `i` is bound as LeNet300's first positional argument, so each model
        # gets a distinct value (presumably used as a seed; TODO confirm).
        make_model: Callable = functools.partial(LeNet300, i)
        experiment(make_dataset, make_model, train_model, prune_masks, pruning_steps)

# Run the experiment for every model, using the settings from the constants module `C`.
run_experiment()

Model 1
Pruning Step 0
Iteration 1/10, Loss: 2.3550610542297363
Iteration 2/10, Loss: 2.271308422088623
Iteration 3/10, Loss: 2.199246883392334
Iteration 4/10, Loss: 2.1306018829345703
Iteration 5/10, Loss: 2.0620598793029785
Iteration 6/10, Loss: 1.9923481941223145
Iteration 7/10, Loss: 1.9212288856506348
Iteration 8/10, Loss: 1.8489227294921875
Iteration 9/10, Loss: 1.7759305238723755
Iteration 10/10, Loss: 1.7027148008346558
Pruning Step 1
Iteration 1/10, Loss: 2.288278341293335
Iteration 2/10, Loss: 2.2094662189483643
Iteration 3/10, Loss: 2.1392059326171875
Iteration 4/10, Loss: 2.0713226795196533
Iteration 5/10, Loss: 2.003674030303955
Iteration 6/10, Loss: 1.9356462955474854
Iteration 7/10, Loss: 1.8672055006027222
Iteration 8/10, Loss: 1.798560380935669
Iteration 9/10, Loss: 1.7300267219543457
Iteration 10/10, Loss: 1.6620514392852783
Pruning Step 2
Iteration 1/10, Loss: 2.2270307540893555
Iteration 2/10, Loss: 2.152894973754883
Iteration 3/10, Loss: 2.084874153137207
Iteration