In [1]:
"""
main.ipynb

Main file for recreating lottery ticket experiments done in randomly initialized dense neural networks.

Authors: Jordan Bourdeau, Casey Forey
Date Created: 3/8/24
"""

%load_ext tensorboard
import functools
import numpy as np
import os
import tensorflow as tf

from src.harness import constants as C
from src.harness.dataset import download_data, load_and_process_mnist
from src.harness.experiment import experiment
from src.harness.model import create_model, LeNet300, load_model
from src.harness.pruning import prune_by_percent
from src.harness.training import train
from src.lottery_ticket.foundations import paths

In [2]:
def run_experiment():
    """
    Run the lottery ticket experiment once per model.

    The dataset loader, the training callable (``train`` with its
    ``iterations`` keyword fixed to ``C.TRAINING_ITERATIONS``) and the pruning
    callable (``prune_by_percent`` with ``C.PRUNING_PERCENTS`` fixed as its
    first positional argument) are built once and shared by every run.

    For each index in ``range(C.NUM_MODELS)`` a model factory is created by
    binding that index as the first positional argument of ``LeNet300``, and
    ``experiment`` is invoked with the loader, the factory, the training
    callable, the pruning callable and ``C.TEST_PRUNING_STEPS``.

    Returns nothing.
    """
    # Callables that stay the same across all runs.
    dataset_loader = load_and_process_mnist
    trainer = functools.partial(train, iterations=C.TRAINING_ITERATIONS)
    pruner = functools.partial(prune_by_percent, C.PRUNING_PERCENTS)

    for model_index in range(C.NUM_MODELS):
        # Only the model factory changes between runs: it carries the index.
        experiment(
            dataset_loader,
            functools.partial(LeNet300, model_index),
            trainer,
            pruner,
            C.TEST_PRUNING_STEPS,
        )

# Kick off the full experiment. The captured output below starts with a
# "Pruning Step" header followed by one loss line per training iteration.
run_experiment()

Pruning Step 0
Iteration 1/1000, Loss: 2.3766071796417236
Iteration 2/1000, Loss: 2.27860426902771
Iteration 3/1000, Loss: 2.2031211853027344
Iteration 4/1000, Loss: 2.135913848876953
Iteration 5/1000, Loss: 2.072058916091919
Iteration 6/1000, Loss: 2.008862018585205
Iteration 7/1000, Loss: 1.9449796676635742
Iteration 8/1000, Loss: 1.8799101114273071
Iteration 9/1000, Loss: 1.8137075901031494
Iteration 10/1000, Loss: 1.7467494010925293
Iteration 11/1000, Loss: 1.6797231435775757
Iteration 12/1000, Loss: 1.6132572889328003
Iteration 13/1000, Loss: 1.5479921102523804
Iteration 14/1000, Loss: 1.4843674898147583
Iteration 15/1000, Loss: 1.4228434562683105
Iteration 16/1000, Loss: 1.3637536764144897
Iteration 17/1000, Loss: 1.3072665929794312
Iteration 18/1000, Loss: 1.2535427808761597
Iteration 19/1000, Loss: 1.2027026414871216
Iteration 20/1000, Loss: 1.1547727584838867
Iteration 21/1000, Loss: 1.1097360849380493
Iteration 22/1000, Loss: 1.0675314664840698
Iteration 23/1000, Loss: 1.0280