In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import PIL
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim

# from Modules import ConvBN, PoolConvBN, PoolLinearBN, SharpCosSim2d, SharpCosSimLinear, LReLU

from ConvBN import ConvBN as ConvBN_BiasTrick
from LinearBN import LinearBN
torch.manual_seed(0)

<torch._C.Generator at 0x14e1524731f0>

In [2]:
class LReLU(nn.Module):
    def __init__(self):
        super(LReLU, self).__init__()
        self.alpha = nn.Parameter(torch.tensor(5.0)) 
    def forward(self, x):
        return torch.nn.functional.relu(self.alpha*x)

In [3]:
train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(p=0.5),
     transforms.RandomAffine(degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.9, 1.1)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size= 2000
num_workers=2
pin_memory=True

dataset = torchvision.datasets.CIFAR10(root='../', train=True, download=True, transform=train_transform)
train_set, val_set = torch.utils.data.random_split(dataset, [49500, 500])

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

test_set = torchvision.datasets.CIFAR10(root='../', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 33.0MB/s] 


Extracting ../cifar-10-python.tar.gz to ../
Files already downloaded and verified


In [4]:
if torch.cuda.is_available():
    print("CUDA is available")
else:
    print("CUDA is not available")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

CUDA is available


In [5]:
class Network(nn.Module):
    def __init__(self, input_channels=3, num_classes=10, image_size = 24):
        super(Network, self).__init__()

        self.image_size = image_size
        self.in_dim = input_channels
        self.conv1_out = 32
        self.conv1_size = 5
        self.conv1_padding = 2


        self.conv2_out = 64
        self.conv2_size = 3
        self.conv2_padding = 1

        self.fc1_out = 512
        self.fc2_out = num_classes

        self.q = 1e-6
        self.bias_trick_par = nn.Parameter(torch.tensor(0.00005))

        # First Convolutional Block

        self.block1 = ConvBN_BiasTrick(in_channels=self.in_dim, out_channels=self.conv1_out,
                             kernel_size=self.conv1_size, padding=self.conv1_padding, std = .1)

        # Second Convolutional Block

        self.block2 = ConvBN_BiasTrick(in_channels=self.conv1_out, out_channels=self.conv2_out,
                 kernel_size=3, stride=1, padding=1, std = .05)
       
        self.block3 = LinearBN(in_features = self.conv2_out * (32//2 //2) * (32//2 //2), 
                                         out_features=self.fc1_out)
        
        
        torch.manual_seed(0)
        self.w2 = nn.Parameter(torch.randn(self.fc1_out, self.fc2_out))
        nn.init.normal_(self.w2, mean=0.0, std=.6)

        self.dropout = nn.Dropout(0.5)
        self.dropout2d = nn.Dropout2d(0.3)

        self.relu = LReLU()




    def forward(self, x):
        
        x = F.max_pool2d(self.relu(self.block1(x)), (2,2), padding=0)
        x = F.max_pool2d(self.relu(self.block2(x)), (2,2), padding=0)
        x = self.dropout2d(x)
        x = x.view(x.size(0), -1)
        
        x = self.relu(self.block3(x))

        x = x + self.bias_trick_par
        x_norm = x / (x.norm(p=2, dim=1, keepdim=True) + self.q)  # Normalize input x
        w2_norm = self.w2 / (self.w2.norm(p=2, dim=1, keepdim=True) + self.q)  # Normalize weights
        x = torch.matmul(x_norm, w2_norm) # Matrix multiplication 

        # Return raw logits (no softmax here, CrossEntropyLoss handles it)
        return x

    def init_hdc(self, nHDC):
        self.block1.init_hdc(nHDC)
        self.block2.init_hdc(nHDC)
        self.block3.init_hdc(nHDC)
        
        self.g = torch.randn(self.w2.size(0), nHDC, device=self.w2.device).to(torch.half)
        self.wg = torch.sign(torch.matmul(self.g.t(), self.w2.to(torch.half)))

    def hdc(self, x):
        x = F.max_pool2d(self.relu(self.block1.hdc(x)), (2,2), padding=0)
        x = F.max_pool2d(self.relu(self.block2.hdc(x)), (2,2), padding=0)

        x = x.view(x.size(0), -1)
        x = self.relu(self.block3.hdc(x))

        x = x + self.bias_trick_par
        x = torch.sign(torch.matmul(x.to(torch.half), self.g))

        return x
        
    def classification_layer(self, x):
        x = x @ self.wg
        return x


In [6]:
import torch.optim as optim
import time  # Import time module

train = True
model = Network(input_channels=3, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
# Trained 100 Epochs with lr=0.025, 50 epochs with 0.005 and 50 epochs with 0.001
optimizer = optim.Adam(model.parameters(), lr=0.02, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = torch.nn.DataParallel(model)

if train:
    
    loss_hist, acc_hist = [], []
    
    best_train_loss = float('inf')
    
    for epoch in range(500):
        start_time = time.time() 
    
        running_loss = 0.0
        correct = 0
        for data in train_loader:
            batch, labels = data
            batch, labels = batch.to(device), labels.to(device)
    
            optimizer.zero_grad()
            outputs = model(batch)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()
    
        avg_loss = running_loss / len(train_set)
        avg_acc = correct / len(train_set)
        loss_hist.append(avg_loss)
        acc_hist.append(avg_acc)
    
        if avg_loss < best_train_loss:
            best_train_loss = avg_loss
            torch.save(model.state_dict(), 'best_model_cifar10.pt')
    
        scheduler.step(avg_loss)
    
        elapsed_time = time.time() - start_time  # Calculate the time taken for this epoch
    
        print('[epoch %d] loss: %.5f accuracy: %.4f time: %.2f seconds' %
              (epoch + 1, avg_loss, avg_acc, elapsed_time))

[epoch 1] loss: 0.00102 accuracy: 0.2604 time: 8.22 seconds
[epoch 2] loss: 0.00082 accuracy: 0.4116 time: 7.02 seconds
[epoch 3] loss: 0.00074 accuracy: 0.4746 time: 6.89 seconds
[epoch 4] loss: 0.00069 accuracy: 0.5054 time: 6.88 seconds
[epoch 5] loss: 0.00066 accuracy: 0.5349 time: 6.89 seconds
[epoch 6] loss: 0.00062 accuracy: 0.5602 time: 6.90 seconds
[epoch 7] loss: 0.00060 accuracy: 0.5827 time: 6.91 seconds
[epoch 8] loss: 0.00058 accuracy: 0.5907 time: 6.90 seconds
[epoch 9] loss: 0.00057 accuracy: 0.6040 time: 6.92 seconds
[epoch 10] loss: 0.00056 accuracy: 0.6075 time: 7.25 seconds
[epoch 11] loss: 0.00056 accuracy: 0.6124 time: 7.10 seconds
[epoch 12] loss: 0.00054 accuracy: 0.6230 time: 6.87 seconds
[epoch 13] loss: 0.00054 accuracy: 0.6259 time: 6.96 seconds
[epoch 14] loss: 0.00054 accuracy: 0.6251 time: 6.89 seconds
[epoch 15] loss: 0.00053 accuracy: 0.6279 time: 6.93 seconds
[epoch 16] loss: 0.00053 accuracy: 0.6266 time: 6.91 seconds
[epoch 17] loss: 0.00053 accuracy

[epoch 135] loss: 0.00033 accuracy: 0.7837 time: 6.89 seconds
[epoch 136] loss: 0.00033 accuracy: 0.7829 time: 6.92 seconds
[epoch 137] loss: 0.00033 accuracy: 0.7848 time: 6.89 seconds
[epoch 138] loss: 0.00033 accuracy: 0.7853 time: 6.98 seconds
[epoch 139] loss: 0.00033 accuracy: 0.7856 time: 7.00 seconds
[epoch 140] loss: 0.00033 accuracy: 0.7867 time: 6.94 seconds
[epoch 141] loss: 0.00033 accuracy: 0.7885 time: 6.87 seconds
[epoch 142] loss: 0.00033 accuracy: 0.7860 time: 6.91 seconds
[epoch 143] loss: 0.00033 accuracy: 0.7846 time: 6.88 seconds
[epoch 144] loss: 0.00033 accuracy: 0.7855 time: 6.93 seconds
[epoch 145] loss: 0.00033 accuracy: 0.7879 time: 6.90 seconds
[epoch 146] loss: 0.00033 accuracy: 0.7866 time: 6.94 seconds
[epoch 147] loss: 0.00033 accuracy: 0.7868 time: 6.83 seconds
[epoch 148] loss: 0.00033 accuracy: 0.7867 time: 6.84 seconds
[epoch 149] loss: 0.00033 accuracy: 0.7897 time: 6.88 seconds
[epoch 150] loss: 0.00033 accuracy: 0.7895 time: 6.86 seconds
[epoch 1

[epoch 268] loss: 0.00028 accuracy: 0.8305 time: 6.89 seconds
[epoch 269] loss: 0.00028 accuracy: 0.8289 time: 6.86 seconds
[epoch 270] loss: 0.00028 accuracy: 0.8290 time: 6.86 seconds
[epoch 271] loss: 0.00028 accuracy: 0.8304 time: 6.86 seconds
[epoch 272] loss: 0.00028 accuracy: 0.8316 time: 6.83 seconds
[epoch 273] loss: 0.00028 accuracy: 0.8312 time: 6.83 seconds
[epoch 274] loss: 0.00028 accuracy: 0.8328 time: 6.89 seconds
[epoch 275] loss: 0.00028 accuracy: 0.8296 time: 6.89 seconds
[epoch 276] loss: 0.00028 accuracy: 0.8294 time: 6.92 seconds
[epoch 277] loss: 0.00028 accuracy: 0.8323 time: 6.89 seconds
[epoch 278] loss: 0.00028 accuracy: 0.8287 time: 6.93 seconds
[epoch 279] loss: 0.00028 accuracy: 0.8318 time: 6.88 seconds
[epoch 280] loss: 0.00028 accuracy: 0.8318 time: 6.88 seconds
[epoch 281] loss: 0.00028 accuracy: 0.8315 time: 6.89 seconds
[epoch 282] loss: 0.00027 accuracy: 0.8326 time: 6.86 seconds
[epoch 283] loss: 0.00028 accuracy: 0.8300 time: 6.89 seconds
[epoch 2

[epoch 401] loss: 0.00027 accuracy: 0.8406 time: 7.03 seconds
[epoch 402] loss: 0.00027 accuracy: 0.8399 time: 6.89 seconds
[epoch 403] loss: 0.00027 accuracy: 0.8420 time: 6.87 seconds
[epoch 404] loss: 0.00027 accuracy: 0.8401 time: 6.95 seconds
[epoch 405] loss: 0.00027 accuracy: 0.8413 time: 6.97 seconds
[epoch 406] loss: 0.00027 accuracy: 0.8411 time: 6.99 seconds
[epoch 407] loss: 0.00027 accuracy: 0.8424 time: 6.89 seconds
[epoch 408] loss: 0.00027 accuracy: 0.8399 time: 6.93 seconds
[epoch 409] loss: 0.00026 accuracy: 0.8410 time: 6.88 seconds
[epoch 410] loss: 0.00027 accuracy: 0.8393 time: 6.84 seconds
[epoch 411] loss: 0.00027 accuracy: 0.8414 time: 6.94 seconds
[epoch 412] loss: 0.00027 accuracy: 0.8396 time: 6.89 seconds
[epoch 413] loss: 0.00027 accuracy: 0.8423 time: 6.87 seconds
[epoch 414] loss: 0.00027 accuracy: 0.8392 time: 6.94 seconds
[epoch 415] loss: 0.00027 accuracy: 0.8423 time: 6.89 seconds
[epoch 416] loss: 0.00027 accuracy: 0.8420 time: 6.89 seconds
[epoch 4

In [7]:
# Load the best model saved during training
model.load_state_dict(torch.load('best_model_cifar10.pt', weights_only=True))
model.eval()  # Set the model to evaluation mode

test_loss = 0.0
correct_test = 0

test_loader = torch.utils.data.DataLoader(test_set, batch_size=2000, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)


# Evaluate on the test dataset
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        
        # Get predictions and update the correct count
        _, predicted = torch.max(outputs, 1)
        correct_test += (predicted == labels).sum().item()

# Compute average loss and accuracy for the test set
avg_test_loss = test_loss / len(test_set)
avg_test_acc = correct_test / len(test_set)

print(f"Test loss: {avg_test_loss:.5f}, Test accuracy: {avg_test_acc:.4f}")

Test loss: 0.00025, Test accuracy: 0.8485


In [8]:
torch.save(model.state_dict(), 'GNet_Trained_Model_84.59%.pth')