Contents of main.py

In [9]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from model import CNN_GSGD, GSGDOptimizer
from train import train, test
import os
import torch.nn as nn
from torch.utils.data import random_split

# Data loading, model setup, and the main training loop follow in the cells below.

In [10]:
# Preprocessing pipeline applied to every image: convert to a tensor, then
# normalise with a single-channel (mean, std) pair.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set
# statistics -- confirm before reusing this pipeline on another dataset.
normalize_mean, normalize_std = (0.1307,), (0.3081,)
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
]
transform = transforms.Compose(preprocessing_steps)

In [11]:
# Directory in which torchvision stores (and looks for) the MNIST files.
data_path = './data'

# Always request the download. torchvision checks which files are already on
# disk and skips the network fetch when the dataset is complete, so this costs
# nothing on later runs. The previous "does ./data/MNIST exist?" test passed
# download=False whenever the directory existed, which made loading fail after
# a partial or interrupted download instead of repairing it.
download_data = True

# Preprocessing applied to every image: tensor conversion followed by
# normalisation with a single-channel (mean, std) pair. Defined in this cell
# so the cell can be run on its own.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Training and test datasets share the same storage directory and transform.
train_dataset = datasets.MNIST(data_path, train=True, download=download_data, transform=transform)
test_dataset = datasets.MNIST(data_path, train=False, download=download_data, transform=transform)


In [12]:
# Split the training data into training and validation subsets.
TRAIN_FRACTION = 0.8  # share kept for training; the remainder is validation
SPLIT_SEED = 42       # fixed seed so the same samples land in each subset on every run

# This cell rebinds `train_dataset`. If the cell has already been run,
# `train_dataset` is a Subset; recover the dataset it wraps so that a re-run
# splits the full data again instead of taking 80% of the previous 80%.
full_train_dataset = (
    train_dataset.dataset
    if isinstance(train_dataset, torch.utils.data.Subset)
    else train_dataset
)

train_size = int(TRAIN_FRACTION * len(full_train_dataset))
validation_size = len(full_train_dataset) - train_size  # whatever is left over

# A dedicated, seeded generator keeps the split reproducible and independent
# of whatever else has consumed the global RNG before this cell runs.
train_dataset, validation_dataset = random_split(
    full_train_dataset,
    [train_size, validation_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [13]:
# Batch sizes: small mini-batches for the training and validation subsets,
# larger batches for the test set.
small_batch_size = 64
test_batch_size = 1000

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=small_batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=small_batch_size, shuffle=False)

In [14]:
# Seed PyTorch's global RNG before the model is built so that weight
# initialisation (and any later shuffling that draws from the global RNG)
# is repeatable across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network and move its parameters to the selected device.
model = CNN_GSGD().to(device)

# NOTE(review): the meaning of `rho` and `revisit_batch_num` is defined by
# GSGDOptimizer in model.py and is not visible from this notebook -- confirm
# there before tuning these values.
optimizer = GSGDOptimizer(model.parameters(), lr=0.01, rho=10, revisit_batch_num=2)



In [15]:
# Loss function handed to train().
loss_fn = nn.CrossEntropyLoss()

# train() receives the underlying Dataset objects (taken from the loaders)
# rather than the DataLoaders themselves; test() receives the test loader.
train_data = train_loader.dataset
validation_data = validation_loader.dataset

# Epochs are numbered from 1 and the last one is included.
first_epoch, last_epoch = 1, 4

for epoch in range(first_epoch, last_epoch + 1):
    train(
        model,
        device,
        train_data,
        validation_data,
        optimizer,
        epoch,
        loss_fn,
        verification_set_num=4,
    )
    # Evaluate on the held-out test set after every epoch.
    test(model, device, test_loader)



Train Epoch: 1 [0/47744] Loss: 2.306463
Train Epoch: 1 [640/47744] Loss: 2.221426
Train Epoch: 1 [1280/47744] Loss: 2.177142
Train Epoch: 1 [1920/47744] Loss: 2.067878
Train Epoch: 1 [2560/47744] Loss: 1.954833
Train Epoch: 1 [3200/47744] Loss: 1.841123
Train Epoch: 1 [3840/47744] Loss: 1.692158
Train Epoch: 1 [4480/47744] Loss: 1.603419
Train Epoch: 1 [5120/47744] Loss: 1.310480
Train Epoch: 1 [5760/47744] Loss: 1.107646
Train Epoch: 1 [6400/47744] Loss: 0.922329
Train Epoch: 1 [7040/47744] Loss: 0.916762
Train Epoch: 1 [7680/47744] Loss: 0.968069
Train Epoch: 1 [8320/47744] Loss: 0.622095
Train Epoch: 1 [8960/47744] Loss: 0.790213
Train Epoch: 1 [9600/47744] Loss: 0.741122
Train Epoch: 1 [10240/47744] Loss: 0.678021
Train Epoch: 1 [10880/47744] Loss: 0.864908
Train Epoch: 1 [11520/47744] Loss: 0.538610
Train Epoch: 1 [12160/47744] Loss: 0.535813
Train Epoch: 1 [12800/47744] Loss: 0.994482
Train Epoch: 1 [13440/47744] Loss: 0.526758
Train Epoch: 1 [14080/47744] Loss: 0.587241
Train Ep