In [1]:
import os

import numpy as np
import torch
import torch.nn as nn

from GOLDataset import generateDataset
from GOLCNN import OPNet, train_epoch, test_model
from MinimalSolution import MinNet

# Prefer the GPU, but fall back to the CPU so that the `.to(device)` calls below
# still work on a machine without a usable CUDA runtime.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Seed everything for reproducibility
# Only the NumPy and PyTorch global RNGs are seeded in this cell; Python's
# built-in `random` module is not touched here.
seed = 42
np.random.seed(seed)
# Kept as the cell's final expression, so the notebook echoes the Generator
# object that torch.manual_seed() returns.
torch.manual_seed(seed)

<torch._C.Generator at 0x7f9886d55c90>

In [3]:
# Sanity check for test_model(): evaluate the hand-built minimal-solution network
# (MinNet) on a freshly generated dataset and report the metrics it returns.
dataset_size = 1000
criterion = nn.MSELoss(reduction='mean')

dataloader = generateDataset(
    dataSetSize=dataset_size,
    size=32,
    n_steps=3,
)

min_model = MinNet(3)
min_model.to(device)

# The third positional argument is 1 here, whereas the trained OPNet models
# further down pass their overparameterization factor `m` in that slot.
acc, epoch_test_loss, num_correct, num_wrong = test_model(
    min_model, dataloader, 1, criterion
)
print(
    f'Accuracy: {acc}, Test Loss: {epoch_test_loss}, '
    f'Correct: {num_correct}/{dataset_size}, Incorrect: {num_wrong}/{dataset_size}'
)

Accuracy: 1.0, Test Loss: 3.185951683311531e-18, Correct: 1000/1000, Incorrect: 0/1000


In [4]:
# ---- Data ----
dataset_size, datapoint_size = 1000, 32

# ---- Optimisation ----
learning_rate = 1e-3
batch_size_param = 64
epochs = 10            # epochs per "era" of the training cell
checkpoint_rate = 5    # evaluate and save every `checkpoint_rate` epochs

# ---- Architecture ----
# m: overparameterization factor, n: steps of GOL simulation
m, n = 8, 2

criterion = nn.MSELoss(reduction='mean')

# Two independently constructed networks with identical hyper-parameters.
# The tuple is evaluated left to right, so amber is still built before brian.
model_amber, model_brian = OPNet(m, n), OPNet(m, n)

# One plain SGD optimizer per network, both at the same learning rate.
optimizer_amber, optimizer_brian = (
    torch.optim.SGD(net.parameters(), lr=learning_rate)
    for net in (model_amber, model_brian)
)

In [5]:
# Move both networks onto the configured device (amber first, then brian).
for net in (model_amber, model_brian):
    net.to(device)
print('models loaded to device')

models loaded to device


In [6]:
# Folder that receives the checkpoint files. It is created up front so the first
# torch.save() cannot fail on a missing directory after several epochs of work.
checkpoint_dir = './models'
os.makedirs(checkpoint_dir, exist_ok=True)

# Per-epoch training records: [epoch, training loss]
full_data_amber = []
full_data_brian = []
# Per-checkpoint evaluation records:
# [epoch, checkpoint filename, accuracy, test loss, num correct, num wrong]
checkpoint_data_amber = []
checkpoint_data_brian = []

# One entry per model trained side by side. List order is the processing order
# (amber before brian) for both the training and the checkpoint phase.
training_runs = [
    {'tag': 'amber', 'label': 'Amber', 'model': model_amber,
     'full_data': full_data_amber, 'checkpoint_data': checkpoint_data_amber},
    {'tag': 'brian', 'label': 'Brian', 'model': model_brian,
     'full_data': full_data_brian, 'checkpoint_data': checkpoint_data_brian},
]


def run_era(epoch_numbers, total_epochs, optimizers):
    """Train every model in `training_runs` for the given epochs.

    A fresh dataset is generated for each epoch and shared by all models.
    Every `checkpoint_rate` epochs each model is evaluated with test_model(),
    the metrics are appended to its checkpoint list, a progress line is
    printed, and the model is saved under `checkpoint_dir`.

    Args:
        epoch_numbers: iterable of 1-based global epoch indices to run.
        total_epochs: denominator shown in the progress messages.
        optimizers: dict mapping a run's 'tag' to the optimizer for its model.
    """
    for t in epoch_numbers:
        dataloader = generateDataset(dataSetSize=dataset_size,
                                     size=datapoint_size,
                                     n_steps=n,
                                     batch_size=batch_size_param)

        for run in training_runs:
            epoch_train_loss = train_epoch(run['model'], optimizers[run['tag']],
                                           criterion, dataloader, m)
            run['full_data'].append([t, epoch_train_loss])

        if t % checkpoint_rate != 0:
            continue

        # NOTE(review): test_model() receives the same dataloader that
        # train_epoch() just consumed, so the reported "Test Loss" is not
        # measured on held-out data.
        for run in training_runs:
            acc, epoch_test_loss, num_correct, num_wrong = test_model(
                run['model'], dataloader, m, criterion)
            checkpoint_name = f"{run['tag']}_m{m}_n{n}_checkpoint{t}.pt"
            run['checkpoint_data'].append(
                [t, checkpoint_name, acc, epoch_test_loss, num_correct, num_wrong])
            print(f"{run['label']}: Epoch: {t}/{total_epochs}, Test Loss: {epoch_test_loss}, "
                  f"Incorrect: {num_wrong}/{dataset_size} examples")
            # Save the model that was just evaluated. Previously the brian
            # checkpoint files were written from model_amber.
            torch.save(run['model'], f'{checkpoint_dir}/{checkpoint_name}')


# Era 1: `epochs` epochs at the base learning rate.
run_era(range(1, epochs + 1), epochs,
        {'amber': optimizer_amber, 'brian': optimizer_brian})
print("END OF ERA 1")

# Era 2: the same number of epochs again, with new SGD optimizers at a tenth of
# the base learning rate.
optimizer_amber = torch.optim.SGD(model_amber.parameters(), lr=learning_rate*0.1)
optimizer_brian = torch.optim.SGD(model_brian.parameters(), lr=learning_rate*0.1)

run_era(range(epochs + 1, 2 * epochs + 1), 2 * epochs,
        {'amber': optimizer_amber, 'brian': optimizer_brian})
print("END OF ERA 2")
print("DONE!")



Amber: Epoch: 5/10, Test Loss: 0.24753129482269287, Incorrect: 1000/1000 examples
Brian: Epoch: 5/10, Test Loss: 0.24128539860248566, Incorrect: 1000/1000 examples
Amber: Epoch: 10/10, Test Loss: 0.2460220456123352, Incorrect: 1000/1000 examples
Brian: Epoch: 10/10, Test Loss: 0.24002912640571594, Incorrect: 1000/1000 examples
END OF ERA 1
Amber: Epoch: 15/20, Test Loss: 0.24586406350135803, Incorrect: 1000/1000 examples
Brian: Epoch: 15/20, Test Loss: 0.239878848195076, Incorrect: 1000/1000 examples
Amber: Epoch: 20/20, Test Loss: 0.24569861590862274, Incorrect: 1000/1000 examples
Brian: Epoch: 20/20, Test Loss: 0.2397150844335556, Incorrect: 1000/1000 examples
END OF ERA 2
DONE!


In [7]:
# Show the checkpoint records collected for the "amber" model; each row is
# [epoch, checkpoint filename, accuracy, test loss, num correct, num wrong].
checkpoint_data_amber

[[5, 'amber_m8_n2_checkpoint5.pt', 0.0, 0.2475313, 0, 1000],
 [10, 'amber_m8_n2_checkpoint10.pt', 0.0, 0.24602205, 0, 1000],
 [15, 'amber_m8_n2_checkpoint15.pt', 0.0, 0.24586406, 0, 1000],
 [20, 'amber_m8_n2_checkpoint20.pt', 0.0, 0.24569862, 0, 1000]]

In [8]:
# Show the per-epoch training history of the "amber" model; each row is
# [epoch, training loss returned by train_epoch()].
full_data_amber

[[1, 0.24893625],
 [2, 0.24861959],
 [3, 0.24831773],
 [4, 0.24800381],
 [5, 0.24769434],
 [6, 0.24738495],
 [7, 0.24708095],
 [8, 0.24679111],
 [9, 0.24647474],
 [10, 0.24618122],
 [11, 0.24602932],
 [12, 0.24597196],
 [13, 0.24594079],
 [14, 0.24590994],
 [15, 0.24588001],
 [16, 0.2458455],
 [17, 0.24583468],
 [18, 0.24579148],
 [19, 0.24577399],
 [20, 0.24571463]]

In [9]:
# torch.save(model, f'./models/op_m16_n2_model2.pt')