Contents of main.py

In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from model import CNN_GSGD, GSGDOptimizer
from train import train, test
import os
import torch.nn as nn
from torch.utils.data import random_split

# Data loading, model setup, and the main training loop follow in the cells below.

In [2]:
# Preprocessing pipeline for the images: convert each one to a tensor, then
# normalise its single channel with mean 0.1307 and standard deviation 0.3081.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

In [3]:
# Directory under which torchvision stores (and looks for) the MNIST files
data_path = './data'

# Always request the download: torchvision checks for the dataset files itself
# and skips the transfer when they are already present. The earlier
# directory-existence test passed download=False whenever './data/MNIST'
# existed, which raises "Dataset not found" if that directory is empty or only
# partially populated (e.g. after an interrupted download).
download_data = True

# `transform` is the preprocessing pipeline defined in the previous cell; it is
# reused here instead of being redefined, and is applied to both splits.
train_dataset = datasets.MNIST(root=data_path, train=True, download=download_data, transform=transform)
test_dataset = datasets.MNIST(root=data_path, train=False, download=download_data, transform=transform)


In [4]:
# Split the full training set into training (80%) and validation (20%) subsets.
#
# This cell rebinds `train_dataset`. If the cell is re-run, `train_dataset` is
# already a Subset produced by the previous split, so recover the dataset it
# was drawn from; otherwise every re-run would split the 80% subset again and
# keep shrinking the training data.
full_train_dataset = (
    train_dataset.dataset
    if isinstance(train_dataset, torch.utils.data.Subset)
    else train_dataset
)

train_size = int(0.8 * len(full_train_dataset))  # 80% for training
validation_size = len(full_train_dataset) - train_size  # remaining 20% for validation

# A dedicated, seeded generator makes the split identical on every run without
# touching torch's global random state.
split_generator = torch.Generator().manual_seed(42)
train_dataset, validation_dataset = random_split(
    full_train_dataset,
    [train_size, validation_size],
    generator=split_generator,
)

In [5]:
# Batch loaders for the three splits. Only the training loader shuffles;
# the validation and test loaders iterate in a fixed order.
validation_loader = DataLoader(dataset=validation_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

In [6]:
# Seed torch's global RNG before the model is built so that runs are repeatable
# across kernel restarts (weight initialisation and any later torch-driven
# shuffling draw from this generator).
# NOTE(review): this does not seed Python's `random` or NumPy -- confirm
# whether the GSGD training code draws from either of them.
torch.manual_seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN_GSGD().to(device)

# Exactly one optimizer configuration should be active; the commented-out
# lines are the alternative GSGD variants.

# For Standard Guided SGD
# optimizer = GSGDOptimizer(model.parameters(), lr=0.01, method='sgd')

# For Guided SGD with Momentum
optimizer = GSGDOptimizer(model.parameters(), lr=0.01, method='momentum', momentum=0.9)

# For Guided Adam
# optimizer = GSGDOptimizer(model.parameters(), lr=0.001, method='adam', beta1=0.9, beta2=0.999)



In [10]:
# Classification objective handed to train()
loss_fn = nn.CrossEntropyLoss()

# train() receives the Dataset objects wrapped by the loaders rather than the
# loaders themselves; test() receives the DataLoader.
training_data = train_loader.dataset
validation_data = validation_loader.dataset

# Epochs are numbered from 1; each epoch trains once and then evaluates on the
# test set.
num_epochs = 4
for epoch in range(1, num_epochs + 1):
    train(
        model,
        device,
        training_data,
        validation_data,
        optimizer,
        epoch,
        loss_fn,
        # GSGD-specific settings forwarded to train(); see train.py for their meaning.
        verification_set_num=4,
        rho=10,
        log_interval=99,
    )
    test(model, device, test_loader)



Epoch: 1, Iteration: 1, Loss: 0.016474
Epoch: 1, Iteration: 101, Loss: 0.008952
Epoch: 1, Iteration: 201, Loss: 0.009342
Epoch: 1, Iteration: 301, Loss: 0.001544
Epoch: 1, Iteration: 401, Loss: 0.004520
Epoch: 1, Iteration: 501, Loss: 0.005254
Epoch: 1, Iteration: 601, Loss: 0.034630
Epoch: 1, Iteration: 701, Loss: 0.004891

Test set: Average loss: 0.0000, Accuracy: 9910/10000 (99%)

Epoch: 2, Iteration: 1, Loss: 0.002105
Epoch: 2, Iteration: 101, Loss: 0.000875
Epoch: 2, Iteration: 201, Loss: 0.030991
Epoch: 2, Iteration: 301, Loss: 0.101624
Epoch: 2, Iteration: 401, Loss: 0.006409
Epoch: 2, Iteration: 501, Loss: 0.000988
Epoch: 2, Iteration: 601, Loss: 0.009548
Epoch: 2, Iteration: 701, Loss: 0.000765

Test set: Average loss: 0.0000, Accuracy: 9908/10000 (99%)

Epoch: 3, Iteration: 1, Loss: 0.007994
Epoch: 3, Iteration: 101, Loss: 0.001677
Epoch: 3, Iteration: 201, Loss: 0.013523
Epoch: 3, Iteration: 301, Loss: 0.011326
Epoch: 3, Iteration: 401, Loss: 0.054506
Epoch: 3, Iteration: 50