In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import PIL
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import random

from ConvBN import ConvBN as ConvBN
from LinearBN import LinearBN

In [2]:
random.seed(43)
np.random.seed(43)
torch.manual_seed(43)
torch.cuda.manual_seed_all(43)

In [3]:
class LReLU(nn.Module):
    def __init__(self):
        super(LReLU, self).__init__()
        self.alpha = nn.Parameter(torch.tensor(5.0)) 
    def forward(self, x):
        return torch.nn.functional.relu(self.alpha*x)

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)) # Normalize with mean 0.5 and std 0.5
])

batch_size= 1500
num_workers=2
pin_memory=True

g = torch.Generator()
g.manual_seed(42)

dataset = torchvision.datasets.MNIST(root='../Data', train=True, download=True, transform=transform)
train_set, val_set = torch.utils.data.random_split(dataset, [58000, 2000])

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, generator=g)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

test_set = torchvision.datasets.MNIST(root='../Data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)


In [5]:
if torch.cuda.is_available():
    print("CUDA is available")
else:
    print("CUDA is not available")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

CUDA is available


In [6]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()

        self.conv1_out = 32
        self.conv1_size = 3
        self.conv1_padding = 1


        self.conv2_out = 64
        self.conv2_size = 5
        self.conv2_padding = 2

        self.fc1_out = 512
        self.fc2_out = 10

        self.q = 1e-6
        self.bias_trick_par = nn.Parameter(torch.tensor(0.00005))

        # First Convolutional Block

        self.block1 = ConvBN(in_channels=1, out_channels=self.conv1_out, kernel_size=self.conv1_size, padding=self.conv1_padding, std = .05)
        self.block2 = ConvBN(in_channels=self.conv1_out, out_channels=self.conv2_out, kernel_size=self.conv2_size, padding=self.conv2_padding, std = .05)

        # Second Convolutional Block
       
        self.block3 = LinearBN(in_features = self.conv2_out * (28//2//2) * (28//2//2), out_features=self.fc1_out, std=.3)
        
        
        torch.manual_seed(0)
        self.w2 = nn.Parameter(torch.randn(self.fc1_out, self.fc2_out))
        nn.init.normal_(self.w2, mean=0.0, std=.6)

        self.dropout = nn.Dropout(0.5)

        self.relu = LReLU()




    def forward(self, x):
        
        x = F.max_pool2d(self.relu(self.block1(x)), (2,2), padding=0)
        x = F.max_pool2d(self.relu(self.block2(x)), (2,2), padding=0)
        
        x = x.view(x.size(0), -1)
        
        x = self.relu(self.block3(x))
        x = self.dropout(x)

        x = x + self.bias_trick_par
        x_norm = x / (x.norm(p=2, dim=1, keepdim=True) + self.q)  # Normalize input x
        w2_norm = self.w2 / (self.w2.norm(p=2, dim=1, keepdim=True) + self.q)  # Normalize weights
        x = torch.matmul(x_norm, w2_norm) # Matrix multiplication 

        # Return raw logits (no softmax here, CrossEntropyLoss handles it)
        return x

In [7]:
import torch.optim as optim
import time  # Import time module

train = True
model = Network().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.025, weight_decay=0.00001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = torch.nn.DataParallel(model)

if train:
    
    loss_hist, acc_hist = [], []
    loss_hist_val, acc_hist_val = [], []
    
    best_val_acc = -float('inf')  
    
    for epoch in range(80):
        start_time = time.time()  
    
        running_loss = 0.0
        correct = 0
        for data in train_loader:
            batch, labels = data
            batch, labels = batch.to(device), labels.to(device)
    
            optimizer.zero_grad()
            outputs = model(batch)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
            # Compute training statistics
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()
    
        avg_loss = running_loss / len(train_set)
        avg_acc = correct / len(train_set)
        loss_hist.append(avg_loss)
        acc_hist.append(avg_acc)
    
        # Validation statistics
        model.eval()
        with torch.no_grad():
            loss_val = 0.0
            correct_val = 0
            for data in val_loader:
                batch, labels = data
                batch, labels = batch.to(device), labels.to(device)
                outputs = model(batch)
                loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs, 1)
                correct_val += (predicted == labels).sum().item()
                loss_val += loss.item()
            avg_loss_val = loss_val / len(val_set)
            avg_acc_val = correct_val / len(val_set)
            loss_hist_val.append(avg_loss_val)
            acc_hist_val.append(avg_acc_val)
        model.train()
    
        scheduler.step(avg_loss_val)
    
        if avg_acc_val > best_val_acc:
            best_val_acc = avg_acc_val
            if torch.cuda.device_count() > 1:
                torch.save(model.module.state_dict(), 'best_model_mnist.pt')
            else:
                torch.save(model.state_dict(), 'best_model_mnist.pt')
    
        elapsed_time = time.time() - start_time 
    
        print('[epoch %d] loss: %.5f accuracy: %.4f val loss: %.5f val accuracy: %.4f time: %.2f seconds' %
              (epoch + 1, avg_loss, avg_acc, avg_loss_val, avg_acc_val, elapsed_time))


[epoch 1] loss: 0.00069 accuracy: 0.7182 val loss: 0.00017 val accuracy: 0.9550 time: 6.04 seconds
[epoch 2] loss: 0.00013 accuracy: 0.9723 val loss: 0.00008 val accuracy: 0.9805 time: 4.62 seconds
[epoch 3] loss: 0.00008 accuracy: 0.9863 val loss: 0.00005 val accuracy: 0.9880 time: 4.85 seconds
[epoch 4] loss: 0.00007 accuracy: 0.9893 val loss: 0.00005 val accuracy: 0.9875 time: 4.76 seconds
[epoch 5] loss: 0.00006 accuracy: 0.9912 val loss: 0.00004 val accuracy: 0.9865 time: 4.82 seconds
[epoch 6] loss: 0.00007 accuracy: 0.9900 val loss: 0.00006 val accuracy: 0.9850 time: 4.82 seconds
[epoch 7] loss: 0.00006 accuracy: 0.9920 val loss: 0.00005 val accuracy: 0.9885 time: 4.79 seconds
[epoch 8] loss: 0.00005 accuracy: 0.9934 val loss: 0.00005 val accuracy: 0.9880 time: 4.91 seconds
[epoch 9] loss: 0.00005 accuracy: 0.9929 val loss: 0.00005 val accuracy: 0.9890 time: 4.80 seconds
[epoch 10] loss: 0.00008 accuracy: 0.9828 val loss: 0.00012 val accuracy: 0.9690 time: 4.96 seconds
[epoch 11

In [8]:
model.eval()  # Set the model to evaluation mode

test_loss = 0.0
correct_test = 0

test_loader = torch.utils.data.DataLoader(test_set, batch_size=2000, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)


# Evaluate on the test dataset
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        
        # Get predictions and update the correct count
        _, predicted = torch.max(outputs, 1)
        correct_test += (predicted == labels).sum().item()

# Compute average loss and accuracy for the test set
avg_test_loss = test_loss / len(test_set)
avg_test_acc = correct_test / len(test_set)

print(f"Test loss: {avg_test_loss:.5f}, Test accuracy: {avg_test_acc:.4f}")

Test loss: 0.00002, Test accuracy: 0.9930


In [9]:
torch.save(model.state_dict(), 'MNIST_GNet_Training_99.30.pth')