In [10]:
import torch
import torchvision
import torchvision.transforms as transforms
import PIL
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import random
# from Modules import ConvBN, PoolConvBN, PoolLinearBN, SharpCosSim2d, SharpCosSimLinear, LReLU

from ConvBN import ConvBN as ConvBN_BiasTrick
from LinearBN import LinearBN

In [4]:
# Seed every RNG the notebook touches (Python, NumPy, torch CPU, all CUDA devices)
# with one shared value so a Restart & Run All reproduces the same numbers.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

In [5]:
# ToTensor maps pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5
])

batch_size = 1500
num_workers = 2
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it is
# useless and DataLoader warns about it.
pin_memory = torch.cuda.is_available()

# Generator controlling the shuffle order of the training loader.
g = torch.Generator()
g.manual_seed(42)

dataset = torchvision.datasets.MNIST(root='../Data', train=True, download=True, transform=transform)

# Hold out a fixed-size validation set. A dedicated, freshly seeded generator makes
# the split identical on every re-run of this cell, regardless of how much of the
# global torch RNG earlier cells have consumed.
n_val = 2000
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = torch.utils.data.random_split(
    dataset, [len(dataset) - n_val, n_val], generator=split_generator
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, generator=g)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

test_set = torchvision.datasets.MNIST(root='../Data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)


In [6]:
# Query CUDA once, report it, and pick the compute device from the same answer.
cuda_ok = torch.cuda.is_available()
print("CUDA is available" if cuda_ok else "CUDA is not available")
device = torch.device('cuda' if cuda_ok else 'cpu')

CUDA is available


In [7]:
class TanH(nn.Module):
    """Scaled hyperbolic tangent ``tanh(alpha * x)`` with a learnable scale ``alpha``.

    A large ``alpha`` pushes the activation towards ``sign(x)``; presumably this is
    why ``Network.hdc`` substitutes ``torch.sign`` for this module.

    Args:
        alpha_init: initial value of the learnable scale (default 10.0, the
            original hard-coded value).
    """

    def __init__(self, alpha_init: float = 10.0):
        super().__init__()
        # Kept as a Parameter named ``alpha`` so existing state dicts still load.
        self.alpha = nn.Parameter(torch.tensor(float(alpha_init)))

    def forward(self, x):
        """Return ``tanh(alpha * x)`` element-wise."""
        return torch.tanh(self.alpha * x)

In [11]:
class Network(nn.Module):
    """Two conv blocks + one linear block + a normalized final projection to 10 classes.

    Besides the ordinary ``forward`` pass, the model exposes an HDC
    (hyperdimensional computing) inference path: ``init_hdc`` draws random
    projections, ``hdc`` produces sign-binarized hypervectors, and
    ``classification_layer`` scores them against the projected final weights.
    """

    def __init__(self):
        super(Network, self).__init__()

        self.conv1_out = 32
        self.conv1_size = 5
        self.conv1_padding = 2

        self.conv2_out = 64
        self.conv2_size = 5
        self.conv2_padding = 2

        self.fc1_out = 512
        self.fc2_out = 10

        # Small epsilon guarding the L2 normalizations in ``forward``.
        self.q = 1e-6
        # Learnable offset added to the penultimate features ("bias trick").
        self.bias_trick_par = nn.Parameter(torch.tensor(0.00005))

        # Convolutional blocks (project-local ConvBN; presumably conv + batch norm).
        self.block1 = ConvBN_BiasTrick(in_channels=1, out_channels=self.conv1_out, kernel_size=self.conv1_size, padding=self.conv1_padding, std = .05, bias_par_init=0.001)
        self.block2 = ConvBN_BiasTrick(in_channels=self.conv1_out, out_channels=self.conv2_out, kernel_size=self.conv2_size, padding=self.conv2_padding, std = .05, bias_par_init=0.01)

        # Fully connected block. in_features assumes 28x28 inputs shrunk by the two
        # 2x2 max-pools in ``forward`` (28 -> 14 -> 7).
        self.block3 = LinearBN(in_features = self.conv2_out * (28//2//2) * (28//2//2), out_features=self.fc1_out, std=.3)

        # Final projection weights, shape (fc1_out, fc2_out) = (512, 10).
        self.w2 = nn.Parameter(torch.randn(self.fc1_out, self.fc2_out))
        nn.init.normal_(self.w2, mean=0.0, std=.6)

        self.dropout = nn.Dropout(0.5)

        self.tanh = TanH()

    def forward(self, x):
        """Standard (non-HDC) forward pass; returns raw logits of shape (batch, 10)."""
        x = F.max_pool2d(self.tanh(self.block1(x)), (2,2), padding=0)
        x = F.max_pool2d(self.tanh(self.block2(x)), (2,2), padding=0)

        x = x.view(x.size(0), -1)

        x = self.tanh(self.block3(x))
        x = self.dropout(x)

        x = x + self.bias_trick_par
        x_norm = x / (x.norm(p=2, dim=1, keepdim=True) + self.q)  # L2-normalize each sample
        # NOTE(review): w2 is (512, 10), so dim=1 normalizes each input-feature row
        # across the 10 classes rather than each class column (dim=0), which is what a
        # per-class cosine similarity would need. Left unchanged because the saved
        # checkpoint was trained with this form -- confirm intent before changing.
        w2_norm = self.w2 / (self.w2.norm(p=2, dim=1, keepdim=True) + self.q)
        x = torch.matmul(x_norm, w2_norm)

        # Return raw logits (no softmax here, CrossEntropyLoss handles it)
        return x

    def custom_round(self, n):
        """Round ``n`` to a multiple of 1000, rounding up only when ``n % 1000 > 100``.

        Works for ints and floats. The original ``>= 101`` / ``<= 100`` branches
        left non-integer remainders in (100, 101) unhandled and returned ``None``.
        """
        remainder = n % 1000
        base = n - remainder
        return base + 1000 if remainder > 100 else base

    def init_hdc(self, ratio, seed):
        """Initialise the random projections used by the HDC inference path.

        Args:
            ratio: a number applied to all four stages, or a 4-tuple of per-stage
                ratios (block1, block2, block3, final layer). For the final layer a
                value < 1000 is a multiplier of ``w2.size(0)`` (rounded with
                ``custom_round``); a value >= 1000 is the hyperdimension itself.
            seed: 4-tuple of RNG seeds, one per stage.

        Raises:
            TypeError: if ``ratio`` is neither a number nor a tuple, or ``seed`` is
                not a tuple.
            ValueError: if either tuple does not have exactly 4 entries.

        Side effects: reseeds the global torch RNG with ``seed[3]``; sets
        ``self.nHDC_last``, ``self.g`` and ``self.wg``.
        """
        if isinstance(ratio, (int, float)):
            ratio = (ratio, ratio, ratio, ratio)
        elif not isinstance(ratio, tuple):
            raise TypeError("ratio must be a tuple of size 4 or a number")
        if len(ratio) != 4:
            raise ValueError(f"ratio must have 4 entries, got {len(ratio)}")

        if not isinstance(seed, tuple):
            raise TypeError("seed must be a tuple of size 4")
        if len(seed) != 4:
            raise ValueError(f"seed must have 4 entries, got {len(seed)}")

        # Validation happens above so a bad argument cannot leave some blocks
        # re-initialised and others not.
        self.block1.init_hdc(ratio = ratio[0], seed = seed[0])
        self.block2.init_hdc(ratio = ratio[1], seed = seed[1])
        self.block3.init_hdc(ratio = ratio[2], seed = seed[2])

        n_last = self.w2.size(0)
        self.nHDC_last = int(self.custom_round(ratio[3] * n_last)) if ratio[3] < 1000 else int(ratio[3])
        torch.manual_seed(seed[3])
        # Random projection of shape (n_last, nHDC_last), stored in half precision.
        self.g = (torch.randn(self.w2.size(0), self.nHDC_last, device=self.w2.device)).to(torch.half)
        # Sign of the projected final weights: (nHDC_last, n_last) @ (n_last, 10).
        self.wg = torch.sign(torch.matmul(self.g.t(), self.w2.to(torch.half)))

    def hdc(self, x):
        """HDC forward pass: sign activations in place of TanH, then random projection.

        Requires ``init_hdc`` to have been called. Returns a tensor of shape
        (batch, nHDC_last) whose entries come from ``torch.sign``.
        """
        x = F.max_pool2d(torch.sign(self.block1.hdc(x)), (2,2), padding=0)
        x = F.max_pool2d(torch.sign(self.block2.hdc(x)), (2,2), padding=0)

        x = x.view(x.size(0), -1)
        x = torch.sign(self.block3.hdc(x))

        # Same bias trick as in ``forward``, applied before the final projection.
        x = x + self.bias_trick_par
        x = torch.sign(torch.matmul(x.to(torch.half), self.g))

        return x

    def classification_layer(self, x):
        """Score HDC hypervectors against the projected class weights ``self.wg``."""
        x = x @ self.wg
        return x


In [12]:
from tqdm import tqdm
import time
from torch.nn.parallel import data_parallel
from torch.utils.data import Subset


torch.cuda.empty_cache()
model = Network().to(device)
# map_location lets the checkpoint load on CPU-only machines as well.
model.load_state_dict(torch.load('MNIST_GNet_Training_99.15.pth', map_location=device, weights_only = True))

# The HDC path runs in half precision.
model.to(torch.half).to(device)
model.eval()

n_splits = 20
split_size = len(test_set) // n_splits

scales = range(1000, 21000, 1000)
accuracies = np.zeros((len(scales), n_splits))
hyperdims = np.zeros((len(scales), 4))
# Contiguous, unshuffled splits of the test set; invariant across scales.
indices = list(range(len(test_set)))


def evaluate_hdc_accuracy(net, loader, device):
    """Return top-1 accuracy (percent) of ``net``'s HDC path over ``loader``.

    Assumes ``net.init_hdc`` has already been called. Raises ValueError for an
    empty loader instead of dividing by zero.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            output = net.hdc(images.to(torch.half))
            output = net.classification_layer(output.to(torch.half))
            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("cannot compute accuracy on an empty split")
    return 100 * correct / total


for i, scale in enumerate(scales):
    for split_idx in tqdm(range(n_splits)):
        start_idx = split_idx * split_size
        end_idx = start_idx + split_size
        split_subset = Subset(test_set, indices[start_idx:end_idx])
        split_loader = torch.utils.data.DataLoader(split_subset, batch_size=3, shuffle=False,
                                                   num_workers=num_workers, pin_memory=pin_memory)
        # Same hyperdimension ratio for all four stages.
        ratio = scale
        # Per-split seeds are derived deterministically from the split index.
        torch.manual_seed(split_idx+4)
        random_seeds = tuple(torch.randint(0, 1000, (1,)).item() for _ in range(4))
        torch.cuda.empty_cache()

        model.init_hdc(ratio, random_seeds)
        hyperdims[i] = np.array([model.block1.nHDC, model.block2.nHDC, model.block3.nHDC, model.nHDC_last])
        accuracies[i, split_idx] = evaluate_hdc_accuracy(model, split_loader, device)

    print(f'Block1: {model.block1.nHDC}, Block2: {model.block2.nHDC}, Block3: {model.block3.nHDC} Classification Layer: {model.nHDC_last}, Average Accuracy: {np.mean(accuracies[i]):.2f}%')


100%|██████████| 20/20 [00:20<00:00,  1.03s/it]


Block1: 1000, Block2: 1000, Block3: 1000 Classification Layer: 1000, Average Accuracy: 47.26%


100%|██████████| 20/20 [00:36<00:00,  1.83s/it]


Block1: 2000, Block2: 2000, Block3: 2000 Classification Layer: 2000, Average Accuracy: 74.18%


100%|██████████| 20/20 [00:53<00:00,  2.66s/it]


Block1: 3000, Block2: 3000, Block3: 3000 Classification Layer: 3000, Average Accuracy: 87.85%


100%|██████████| 20/20 [01:10<00:00,  3.50s/it]


Block1: 4000, Block2: 4000, Block3: 4000 Classification Layer: 4000, Average Accuracy: 92.35%


100%|██████████| 20/20 [01:26<00:00,  4.33s/it]


Block1: 5000, Block2: 5000, Block3: 5000 Classification Layer: 5000, Average Accuracy: 94.22%


100%|██████████| 20/20 [01:43<00:00,  5.17s/it]


Block1: 6000, Block2: 6000, Block3: 6000 Classification Layer: 6000, Average Accuracy: 95.92%


100%|██████████| 20/20 [02:00<00:00,  6.03s/it]


Block1: 7000, Block2: 7000, Block3: 7000 Classification Layer: 7000, Average Accuracy: 96.72%


100%|██████████| 20/20 [02:16<00:00,  6.84s/it]


Block1: 8000, Block2: 8000, Block3: 8000 Classification Layer: 8000, Average Accuracy: 96.91%


100%|██████████| 20/20 [02:34<00:00,  7.70s/it]


Block1: 9000, Block2: 9000, Block3: 9000 Classification Layer: 9000, Average Accuracy: 97.42%


100%|██████████| 20/20 [02:50<00:00,  8.53s/it]


Block1: 10000, Block2: 10000, Block3: 10000 Classification Layer: 10000, Average Accuracy: 97.47%


100%|██████████| 20/20 [03:06<00:00,  9.33s/it]


Block1: 11000, Block2: 11000, Block3: 11000 Classification Layer: 11000, Average Accuracy: 97.59%


100%|██████████| 20/20 [03:23<00:00, 10.19s/it]


Block1: 12000, Block2: 12000, Block3: 12000 Classification Layer: 12000, Average Accuracy: 97.86%


100%|██████████| 20/20 [03:40<00:00, 11.00s/it]


Block1: 13000, Block2: 13000, Block3: 13000 Classification Layer: 13000, Average Accuracy: 97.82%


100%|██████████| 20/20 [03:57<00:00, 11.90s/it]


Block1: 14000, Block2: 14000, Block3: 14000 Classification Layer: 14000, Average Accuracy: 97.94%


100%|██████████| 20/20 [04:14<00:00, 12.72s/it]


Block1: 15000, Block2: 15000, Block3: 15000 Classification Layer: 15000, Average Accuracy: 97.93%


100%|██████████| 20/20 [04:30<00:00, 13.53s/it]


Block1: 16000, Block2: 16000, Block3: 16000 Classification Layer: 16000, Average Accuracy: 98.29%


100%|██████████| 20/20 [04:47<00:00, 14.38s/it]


Block1: 17000, Block2: 17000, Block3: 17000 Classification Layer: 17000, Average Accuracy: 98.12%


100%|██████████| 20/20 [05:03<00:00, 15.17s/it]


Block1: 18000, Block2: 18000, Block3: 18000 Classification Layer: 18000, Average Accuracy: 98.19%


100%|██████████| 20/20 [05:19<00:00, 15.98s/it]


Block1: 19000, Block2: 19000, Block3: 19000 Classification Layer: 19000, Average Accuracy: 98.21%


100%|██████████| 20/20 [05:37<00:00, 16.88s/it]

Block1: 20000, Block2: 20000, Block3: 20000 Classification Layer: 20000, Average Accuracy: 98.33%





In [13]:
# Average over the 20 test splits (accuracy) and over the 4 stages (hyperdimension),
# giving one value per scale.
mean_accuracy_per_scale = accuracies.mean(axis=1)
mean_hyperdim_per_scale = hyperdims.mean(axis=1)
print(mean_accuracy_per_scale)
print(mean_hyperdim_per_scale)

[47.26 74.18 87.85 92.35 94.22 95.92 96.72 96.91 97.42 97.47 97.59 97.86
 97.82 97.94 97.93 98.29 98.12 98.19 98.21 98.33]
[ 1000.  2000.  3000.  4000.  5000.  6000.  7000.  8000.  9000. 10000.
 11000. 12000. 13000. 14000. 15000. 16000. 17000. 18000. 19000. 20000.]


In [14]:
from scipy.io import savemat
# Export results for MATLAB: output file -> {variable name: array}.
mat_exports = {
    'HDCGNet_MNIST.mat': {'HDCGNet_MNIST': accuracies},
    'HDCGNet_MNIST_nHDC.mat': {'HDCGNet_MNIST_nHDC': hyperdims},
}
for mat_path, mat_contents in mat_exports.items():
    savemat(mat_path, mat_contents)