# Training many models initialized via the linear-programming (LP) method

There will be some amount of code duplication between this and investigation.ipynb, and the code is *rough* in places. For now, just getting it all to run is good enough.

In [12]:
import math
import time

from direct_helpers import *

from scipy.optimize import linprog
import numpy as np

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
from torchvision.datasets import MNIST


In [13]:
def _lp_objective_samples(loader, preceding_layers, count):
    """
    Build `count` LP objective vectors for the layer that follows `preceding_layers`.

    Each vector starts as one draw from get_avg_MNIST(loader=loader, sample_size=1),
    flattened, then passed through every layer in `preceding_layers` (in order)
    with a ReLU after each one. The result is multiplied by -1 because linprog
    minimizes and we want to maximize.

    With an empty `preceding_layers` the flattened sample itself is negated,
    which is what the first layer needs.
    """
    relu = nn.ReLU()
    objectives = []
    for _ in range(count):
        sample = torch.flatten(get_avg_MNIST(loader=loader, sample_size=1), 0)  # get a random sample
        for layer in preceding_layers:
            sample = relu(layer(sample))
        objectives.append(sample * -1)
    return objectives


def get_mnist_lp_model():
    """
    Get an MnistReluCountModel with weights initialized by LP method

    Layers fc1..fc4 are initialized front to back. For each layer, one objective
    vector per output unit is sampled (propagated through the layers that have
    already been initialized) and handed to linprog_parameter together with the
    layer's input size; the returned value replaces the layer's `weight`.

    NOTE(review): linprog_parameter and get_avg_MNIST come from direct_helpers;
    their exact semantics are not visible here.
    """
    loader = DataLoader(MNIST('images/', train=True, download=True,
                              transform=torchvision.transforms.Compose([
                                  torchvision.transforms.ToTensor(),
                                  torchvision.transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))
                              ])),
                        batch_size=1, shuffle=True, pin_memory=False)
    opt_model = MnistReluCountModel()

    # (layer, in_size, out_size) in forward order; each in_size is the previous out_size.
    layer_specs = [
        (opt_model.fc1, 784, 750),
        (opt_model.fc2, 750, 320),
        (opt_model.fc3, 320, 50),
        (opt_model.fc4, 50, 10),
    ]

    initialized = []  # layers whose weights have already been replaced
    for layer, in_size, out_size in layer_specs:
        objectives = _lp_objective_samples(loader, initialized, out_size)
        layer.weight = linprog_parameter(objectives, in_size)
        # Only after its weights are set may this layer feed samples to later layers.
        initialized.append(layer)

    return opt_model

In [14]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """
    Train `model` for one pass over `dataloader`.

    Parameters:
        dataloader: any iterable yielding (X, y) batches.
        model:      callable module; switched to training mode here.
        loss_fn:    callable taking (prediction, target) and returning a scalar
                    tensor that supports .backward() and .item().
        optimizer:  object exposing step() and zero_grad().

    Returns:
        The mean of the per-batch loss values.

    Raises:
        ZeroDivisionError: if `dataloader` yields no batches.
    """
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()

    total_loss = 0
    batches = 0

    for X, y in dataloader:
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        batches += 1

    return total_loss / batches


def test_loop(dataloader, model, loss_fn, quiet=False):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0
    correct = 0
    
    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for (X, y) in dataloader:
            out = model(X)
            test_loss += loss_fn(out, y).item()
            pred = out.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).sum().item()

    test_loss /= num_batches
    accuracy = 100.0 * correct / len(dataloader.dataset)
    if not quiet:
        print(f"Avg loss: {test_loss:>8f}")
        print(f"Accuracy: {correct}/{len(dataloader.dataset)} = {accuracy}")
        print()
    return test_loss, accuracy

In [20]:
class TrainingResults:
    """
    Record of one training run.

    Attributes:
        train_losses:  value passed as `train_losses`
        test_losses:   value passed as `test_losses`
        test_accs:     value passed as `test_accs`
        initial_zeros: value passed as `init_zeros`
        trained_zeros: value passed as `trained_zeros`
    """

    def __init__(self, train_losses, test_losses, test_accs, init_zeros, trained_zeros):
        # Zero counts (before / after training)
        self.initial_zeros = init_zeros
        self.trained_zeros = trained_zeros
        # Metric series
        self.train_losses = train_losses
        self.test_losses = test_losses
        self.test_accs = test_accs

    def __repr__(self):
        """Multi-line summary: both zero counts, then the three metric series."""
        report = [
            f"Initial 0 count:        {self.initial_zeros}\n",
            f"0 count after training: {self.trained_zeros}\n",
            f"Train losses:    {self.train_losses}\n",
            f"Test losses:     {self.test_losses}\n",
            f"Test Accuracies: {self.test_accs}\n",
        ]
        return "".join(report)


def train_lp_mnist_model(train_loader:DataLoader, test_loader:DataLoader, epochs=5, lr=.01):
    """
    Build one LP-initialized model, train it with SGD, and report its metrics.

    The model comes from get_mnist_lp_model(). count_relu_0s(test_loader, model)
    is recorded once before training and once after the final epoch. Every epoch
    runs train_loop on `train_loader` followed by a quiet test_loop on
    `test_loader`, using cross-entropy loss.

    Returns:
        TrainingResults holding the per-epoch train losses, test losses and
        test accuracies, plus the two count_relu_0s values.
    """
    model = get_mnist_lp_model()
    zeros_before = count_relu_0s(test_loader, model)

    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(model.parameters(), lr=lr)

    epoch_train_losses, epoch_test_losses, epoch_test_accs = [], [], []
    for _ in range(epochs):
        epoch_train_losses.append(train_loop(train_loader, model, criterion, sgd))

        eval_loss, eval_acc = test_loop(test_loader, model, criterion, quiet=True)
        epoch_test_losses.append(eval_loss)
        epoch_test_accs.append(eval_acc)

    zeros_after = count_relu_0s(test_loader, model)
    return TrainingResults(epoch_train_losses, epoch_test_losses, epoch_test_accs,
                           zeros_before, zeros_after)

In [25]:
class CollectedResults():
    """
    Results of several training runs, with CSV export.

    `all_results` is kept exactly as passed in (callers iterate over it). Each
    element is expected to expose `train_losses`, `test_losses`, `test_accs`,
    `initial_zeros` and `trained_zeros`.

    File layout (no header, comma-separated): one row per run. In the three
    metric files each column is one epoch; in the zero-count file the columns
    are `initial_zeros, trained_zeros`.
    """

    def __init__(self, results):
        self.all_results = results
        # Results are converted to ndarrays only at write time so that
        # all_results stays a plain list of result objects.
        # TODO: should also allow construction by reading files

    @staticmethod
    def _write_table(rows, file_path, fmt):
        """Write `rows` (a list of equal-length rows) to `file_path` as CSV."""
        np.savetxt(file_path, np.asarray(rows), delimiter=",", fmt=fmt)

    def write_train_losses(self, file_path:str="train_losses.csv"):
        """Write every run's per-epoch training losses to `file_path`."""
        # %.17g keeps enough digits to round-trip a float64.
        self._write_table([r.train_losses for r in self.all_results], file_path, "%.17g")

    def write_test_losses(self, file_path:str="test_losses.csv"):
        """Write every run's per-epoch test losses to `file_path`."""
        self._write_table([r.test_losses for r in self.all_results], file_path, "%.17g")

    def write_test_accs(self, file_path:str="test_accs.csv"):
        """Write every run's per-epoch test accuracies to `file_path`."""
        self._write_table([r.test_accs for r in self.all_results], file_path, "%.17g")

    def write_zero_counts(self, file_path:str="zero_counts.csv"):
        """Write every run's (initial_zeros, trained_zeros) pair to `file_path`."""
        rows = [[r.initial_zeros, r.trained_zeros] for r in self.all_results]
        self._write_table(rows, file_path, "%d")

    def write_to_default_files(self):
        """Write all four tables to their default paths in the working directory."""
        self.write_train_losses()
        self.write_test_losses()
        self.write_test_accs()
        self.write_zero_counts()


def train_many(num_to_train=2, epochs=5, lr=.01, train_batch_size=None, test_batch_size=None):
    """
    Train `num_to_train` independent LP-initialized MNIST models.

    Fresh MNIST train/test loaders are built once and shared by every run; each
    run goes through train_lp_mnist_model(train_loader, test_loader, epochs, lr).

    Parameters:
        num_to_train:     number of models to train.
        epochs, lr:       forwarded to train_lp_mnist_model.
        train_batch_size: batch size of the training loader. When None, the
                          notebook-level `batch_size_train` is used.
        test_batch_size:  batch size of the test loader. When None, the
                          notebook-level `batch_size_test` is used.

    Returns:
        CollectedResults wrapping the list of per-run results.

    Raises:
        NameError: if a batch size is left as None and the corresponding
                   notebook-level variable has not been defined yet.
    """
    # Fall back to the notebook globals only when the caller did not choose a size,
    # so existing calls such as train_many() behave exactly as before.
    if train_batch_size is None:
        train_batch_size = batch_size_train
    if test_batch_size is None:
        test_batch_size = batch_size_test

    # Same preprocessing for both splits.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))
    ])

    train_loader = DataLoader(MNIST('images/', train=True, download=True, transform=transform),
                              batch_size=train_batch_size, shuffle=True, pin_memory=False)

    test_loader = DataLoader(MNIST('images/', train=False, download=True, transform=transform),
                             batch_size=test_batch_size, shuffle=True)

    results = [train_lp_mnist_model(train_loader, test_loader, epochs, lr)
               for _ in range(num_to_train)]

    return CollectedResults(results)

In [16]:
# Batch sizes for the notebook-level loaders; train_many also reads these names.
batch_size_train = 64
batch_size_test = 1000

# Both splits share the same preprocessing: to tensor, then normalize with the MNIST statistics.
_mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
])

train_loader = DataLoader(
    MNIST('images/', train=True, download=True, transform=_mnist_transform),
    batch_size=batch_size_train,
    shuffle=True,
    pin_memory=False,
)

test_loader = DataLoader(
    MNIST('images/', train=False, download=True, transform=_mnist_transform),
    batch_size=batch_size_test,
    shuffle=True,
)

In [21]:
# Time one LP-initialized training run with default epochs and learning rate.
# perf_counter is monotonic, so the interval cannot be skewed by system clock adjustments.
start = time.perf_counter()
results = train_lp_mnist_model(train_loader, test_loader)
stop = time.perf_counter()
print(f"Took {stop - start} seconds!")

Took 123.58857035636902 seconds!


In [22]:
# Show the zero counts and per-epoch metrics of the single run (TrainingResults.__repr__).
print(results)

Initial 0 count:        518531
0 count after training: 1185827
Train losses:    [1.3455590092614769, 0.7725152661487746, 0.7070979224339223, 0.6758331570988779, 0.6539220727646529]
Test losses:     [0.8996824741363525, 0.726496160030365, 0.7187384366989136, 0.67099609375, 0.6722821354866028]
Test Accuracies: [73.91, 81.5, 81.72, 83.96, 84.25]



In [26]:
# Time train_many() with its defaults.
# perf_counter is monotonic, so the interval cannot be skewed by system clock adjustments.
start = time.perf_counter()
cr = train_many()
stop = time.perf_counter()
print(f"Took {stop - start} seconds!")

Took 224.89110589027405 seconds!


In [27]:
# One summary per trained model, each followed by a blank line.
for result in cr.all_results:
    print(result, end="\n\n")

Initial 0 count:        433476
0 count after training: 1367100
Train losses:    [1.3062069416681588, 1.0145195541478425, 0.8343594954339172, 0.7296054085243994, 0.6997699674957597]
Test losses:     [1.0711086869239808, 0.9713546991348266, 0.7518130004405975, 0.7102375149726867, 0.6880674421787262]
Test Accuracies: [60.44, 62.68, 72.06, 73.18, 73.59]


Initial 0 count:        483937
0 count after training: 1196600
Train losses:    [1.5089067071358533, 1.1702006868462065, 1.1165456987901536, 1.0902361566705236, 1.0724759242936237]
Test losses:     [1.1998318195343018, 1.1066844463348389, 1.1232999205589294, 1.0551402986049652, 1.044065886735916]
Test Accuracies: [64.5, 66.2, 65.95, 67.62, 67.8]


