In [None]:
"""
experiments.ipynb

File for running lottery ticket experiments.

Authors: Jordan Bourdeau, Casey Forey
Date Created: 3/8/24
"""

from importlib import reload
import os
from typing import Callable

import tensorflow as tf
from tensorflow.python.client import device_lib

from src.experiment_scripts.lenet_300_100_iterative_magnitude_pruning import get_lenet_300_100_experiment_parameters
from src.harness import constants as C
from src.harness import dataset as ds
from src.harness import experiment
from src.harness import history
from src.harness import model as mod
from src.harness import paths
from src.harness import pruning
from src.harness import rewind
from src.harness import training as train

In [None]:
# Report the compute resources visible to this kernel before launching experiments.
hardware_counts = (
    ("Num CPUs Available: ", len(tf.config.list_physical_devices('CPU'))),
    ("Num CPU Cores Available: ", os.cpu_count()),
    ("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU'))),
)
for label, count in hardware_counts:
    print(label, count)

# Full local device list as reported by TensorFlow's device_lib.
print(device_lib.list_local_devices())

## Run Experiments

In [None]:
# importlib.reload does not refresh names another module bound via
# `from x import y`, so dependencies are reloaded before the modules using them.
# NOTE(review): assumes `experiment` depends on the other harness modules
# (hence it is reloaded last) -- confirm against src/harness/experiment.py.
reload(paths)
reload(ds)
reload(pruning)
reload(rewind)
reload(train)
reload(experiment)

experiment_directory: str = os.path.join(C.EXPERIMENTS_DIRECTORY, 'testing_experiment')
# The LeNet-300-100 factory returns the callable that builds per-experiment parameters.
# NOTE(review): the positional arguments (0.2, 0.05, False) are interpreted by
# get_lenet_300_100_experiment_parameters; their meaning is not visible here -- confirm.
get_experiment_parameters: Callable = get_lenet_300_100_experiment_parameters(ds.Datasets.MNIST, 0.2, 0.05, False)

experiment_summary: history.ExperimentSummary = experiment.run_experiments(
    starting_seed=0,
    num_experiments=8,
    experiment_directory=experiment_directory,
    experiment=experiment.run_iterative_pruning_experiment,
    get_experiment_parameters=get_experiment_parameters,
    # os.cpu_count() returns None when the count is undetermined; use one process then.
    max_processes=os.cpu_count() or 1,
)

In [None]:
# Sanity Check
# Print one slice of a single weight tensor from the freshly loaded model and from
# every pruning round of every experiment, so they can be compared by eye.
# Loop variables are deliberately not named `experiment` / `round`: `experiment`
# would rebind the imported `src.harness.experiment` module (breaking a re-run of
# the previous cell), and `round` shadows the builtin.
SANITY_WEIGHT_INDEX = 4  # position in the weights list that gets inspected
SANITY_ROW_INDEX = 0     # index 0 along that tensor's first axis

# NOTE(review): meaning of the positional arguments (0, 1) to load_model is not
# visible here -- confirm against src.harness.model.load_model.
initial_model = mod.load_model(0, 1, directory=experiment_directory)
# The loaded model is not modified inside the loops, so extract its slice once.
initial_weights_slice = initial_model.get_weights()[SANITY_WEIGHT_INDEX][SANITY_ROW_INDEX]

for experiment_result in experiment_summary.experiments.values():
    for pruning_round in experiment_result.pruning_rounds:
        print('Initial Weights:')
        print(initial_weights_slice)
        print('Initial Round Weights:')
        print(pruning_round.initial_weights[SANITY_WEIGHT_INDEX][SANITY_ROW_INDEX])
        print('Final Round Weights:')
        print(pruning_round.final_weights[SANITY_WEIGHT_INDEX][SANITY_ROW_INDEX])
        print('Round Masks:')
        print(pruning_round.masks[SANITY_WEIGHT_INDEX][SANITY_ROW_INDEX])
        print()
        