In [1]:
import numpy as np
import torch
import torch.nn as nn

from GOLDataset import generateDataset
from GOLCNN import OPNet, train_epoch, test_model
from MinimalSolution import MinNet

device = "cuda"

In [2]:
# Ensure test_model() works on the minimal solution CNN
dataset_size = 1000
dataloader = generateDataset(dataSetSize=dataset_size, size=32, n_steps=3)
min_model = MinNet(3)
min_model.to(device)
criterion = nn.MSELoss(reduction='mean')
acc, epoch_test_loss, num_correct, num_wrong = test_model(min_model, dataloader, 1, criterion)
print(f'Accuracy: {acc}, Test Loss: {epoch_test_loss}, Correct: {num_correct}/{dataset_size}, Incorrect: {num_wrong}/{dataset_size}')

Accuracy: 1.0, Test Loss: 3.1864204879236953e-18, Correct: 1000/1000, Incorrect: 0/1000


In [4]:
# Training Parameters
learning_rate = 1e-3
batch_size_param = 64
epochs = 10
checkpoint_rate = 5

m = 16 # Overparameterization Factor
n = 3  # Steps of GOL simulation
model = OPNet(m, n)
criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [5]:
loss = []

model.to(device)

for t in range(1, epochs + 1):
    dataloader = generateDataset(dataSetSize=1000, size=32, n_steps=n, batch_size=batch_size_param)
    epoch_train_loss = train_epoch(model, optimizer, criterion, dataloader, m)
    
    if t % checkpoint_rate == 0:
        acc, epoch_test_loss, num_correct, num_wrong = test_model(model, dataloader, m, criterion)
        loss.append(epoch_test_loss)
        print(f'Epoch: {t}/{epochs}, Test Loss: {epoch_test_loss}, Incorrect: {num_wrong}/1000 examples')
        
print("END OF ERA 1")

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate*0.1)

for t in range(1, epochs + 1):
    dataloader = generateDataset(dataSetSize=1000, size=32, n_steps=n, batch_size=batch_size_param)
    epoch_train_loss = train_epoch(model, optimizer, criterion, dataloader, m)
    
    if t % checkpoint_rate == 0:
        acc, epoch_test_loss, num_correct, num_wrong = test_model(model, dataloader, m, criterion)
        loss.append(epoch_test_loss)
        print(f'Epoch: {t}/{epochs}, Test Loss: {epoch_test_loss}, Incorrect: {num_wrong}/1000 examples')
        
print("END OF ERA 2")
print("DONE!")

Epoch: 6/10, Test Loss: 0.24100607633590698, Incorrect: 1000/1000 examples
Epoch: 11/10, Test Loss: 0.23955798149108887, Incorrect: 1000/1000 examples
END OF ERA 1
Epoch: 6/10, Test Loss: 0.23940441012382507, Incorrect: 1000/1000 examples
Epoch: 11/10, Test Loss: 0.23923274874687195, Incorrect: 1000/1000 examples
END OF ERA 2
DONE!


In [None]:
torch.save(model, f'./models/op_m16_n2_model2.pt')