In [1]:
import sys
sys.path.insert(0, './../../../Models')
from sphere_points import generate_points

import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt

import torch
torch.manual_seed(0)
import torch.nn as nn
from tqdm import tqdm
from torch.optim import SGD
from torch.nn.functional import normalize, one_hot
# import torch.nn.functional as F
import torchvision
from torchvision import transforms
# from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

In [2]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch reports
# an available GPU, otherwise the CPU. No other backend is considered here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
def plot_losses(losses):
    """Draw one loss-vs-epoch curve per layer, side by side in a single row.

    Args:
        losses: 2-D array-like indexed as ``losses[layer][epoch]``; every
            layer must have the same number of epochs so that it converts
            to a rectangular array.
    """
    loss_matrix = np.array(losses)
    num_layers, num_epochs = loss_matrix.shape
    # Epochs are shown 1-based on the x axis.
    epoch_axis = 1 + np.arange(num_epochs)

    plt.figure(figsize = (12, 5))
    for panel, layer_curve in enumerate(loss_matrix, start=1):
        plt.subplot(1, num_layers, panel)
        plt.plot(epoch_axis, layer_curve)
        plt.title(f"Layer {panel} Loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
    plt.tight_layout()
    plt.show()

In [4]:
# Number of target classes. Read as a module-level global by code defined in
# later cells (e.g. Layer.__init__), so it must be set before those run.
num_classes = 10

In [5]:
def classifier_head_train(inp_embedding, classifier_weights, labels):
    """Compute the direction-alignment training loss for a batch.

    Each embedding row is L2-normalised and projected onto every column of
    ``classifier_weights``. Only the projection onto the column of the true
    class is kept, and the loss is ``mean(log(2 - theta * score))`` with
    ``theta = 1``.

    When the columns of ``classifier_weights`` are unit vectors (``Layer``
    normalises them when it builds the matrix), the kept score is a cosine
    similarity in [-1, 1]: the log argument stays in [1, 3] and the loss is 0
    for perfect alignment.

    The number of classes is taken from ``classifier_weights.shape[1]`` rather
    than from a module-level global, so the function cannot silently disagree
    with the weight matrix it is given.

    Args:
        inp_embedding: 2-D float tensor, one embedding per row.
        classifier_weights: 2-D float tensor with one column per class and as
            many rows as ``inp_embedding`` has columns.
        labels: 1-D integer (int64) tensor of true class indices, one per row
            of ``inp_embedding``.

    Returns:
        Scalar tensor holding the mean loss (differentiable w.r.t.
        ``inp_embedding``).
    """
    unit_embedding = normalize(inp_embedding, p=2, dim=-1)
    class_scores = torch.mm(unit_embedding, classifier_weights)

    # Mask out every score except the one belonging to the true class.
    n_classes = classifier_weights.shape[1]
    target_mask = one_hot(labels, num_classes=n_classes).type(torch.float32)
    target_score = torch.sum(class_scores * target_mask, 1)

    theta = 1  # scale applied to the similarity inside the log
    return torch.mean(torch.log(2 - (theta * target_score)))

In [6]:
def classifier_head(inp_embedding, classifier_weights, labels):
    """Predict a class index for every embedding row.

    Rows of ``inp_embedding`` are L2-normalised, multiplied by
    ``classifier_weights`` (one column per class) and the index of the
    largest resulting score is taken per row.

    Args:
        inp_embedding: 2-D float tensor, one embedding per row.
        classifier_weights: 2-D float tensor, one column per class.
        labels: Unused; accepted so the signature mirrors
            ``classifier_head_train``.

    Returns:
        A Python list with one predicted class index (int) per row.
    """
    unit_rows = normalize(inp_embedding, p=2, dim=-1)
    scores = unit_rows.mm(classifier_weights)
    winners = scores.argmax(dim=1)
    return winners.tolist()

In [7]:
# NumPy snapshot of the class directions generated by the most recently
# constructed Layer. Overwritten every time a Layer is created.
initial = None


# Data layout expected by Layer: (num_data, num_features) -- a plain 2-D batch.
class Layer(nn.Linear):
    """Linear layer trained locally against fixed per-class target directions.

    On construction the layer obtains ``num_classes`` direction vectors in its
    output space from ``generate_points`` and stores them, L2-normalised, as
    the columns of ``direction_weights``. Training minimises
    ``classifier_head_train`` between the layer's LeakyReLU activations and
    the direction of each sample's class; prediction picks the direction with
    the largest score via ``classifier_head``.

    Caveats:
        * ``train`` has the signature ``train(x, labels)`` and therefore
          shadows ``nn.Module.train(mode)``. Calling ``.eval()`` (which calls
          ``self.train(False)``) on a Layer raises ``TypeError``.
        * ``num_classes`` is read from the module-level global at construction
          time, and the module-level ``initial`` is overwritten as a side
          effect.
        * ``direction_weights`` is a plain tensor attribute (not a registered
          buffer), so it is not part of ``state_dict()``.
    """

    NEGATIVE_SLOPE = 0.001  # LeakyReLU slope used in both train and test
    DROPOUT_P = 0.1         # input dropout probability (only if apply_dropout)

    def __init__(self, in_features, out_features, bias, device, lr, apply_dropout=False):
        """Create the layer, re-initialise its weights and build class directions.

        Args:
            in_features: Number of input features.
            out_features: Number of output features; also the dimension of the
                class direction vectors.
            bias: Whether the layer has a bias term.
            device: Device for the parameters and direction tensors.
            lr: Learning rate used by the per-step SGD optimiser in ``train``.
            apply_dropout: If True, ``train`` applies dropout to its input.
        """
        super().__init__(in_features, out_features, bias, device)
        self.out_features = out_features
        self.bias_flag = bias
        self.lr = lr
        self.num_classes = num_classes  # module-level global
        self.dimension = out_features

        # Replace nn.Linear's default init with U(-sqrt(6/in), +sqrt(6/in)).
        fc1_limit = np.sqrt(6.0 / in_features)
        torch.nn.init.uniform_(self.weight, a=-fc1_limit, b=fc1_limit)

        self.dropout = nn.Dropout(self.DROPOUT_P)
        self.apply_dropout = apply_dropout

        # generate_points is project code; presumably it returns num_classes
        # points of dimension `dimension` -- TODO confirm. The code below only
        # relies on it being a sequence of equal-length numeric vectors.
        global initial
        self.directions = generate_points(self.num_classes, self.dimension, steps = 10000)
        initial = np.array(self.directions)
        self.directions = [torch.tensor(t, dtype = torch.float32).to(device) for t in self.directions]

        # Shape (dimension, num_directions): column i is the unit-norm
        # direction of class i. Built from non-grad tensors, so it is constant.
        self.direction_weights = torch.stack(
            [normalize(d, p = 2, dim=-1) for d in self.directions], dim=1
        )

    def _activate(self, x):
        """Return LeakyReLU(x @ W^T [+ b]) for a 2-D batch ``x``."""
        pre_activation = torch.mm(x, self.weight.T)
        if self.bias_flag:
            pre_activation = pre_activation + self.bias.unsqueeze(0)
        return nn.functional.leaky_relu(pre_activation, negative_slope=self.NEGATIVE_SLOPE)

    def train(self, x, labels):
        """Run one local SGD step on a batch.

        NOTE: shadows ``nn.Module.train(mode)``; see the class docstring.

        Args:
            x: 2-D float tensor of inputs, ``(num_data, in_features)``.
            labels: 1-D integer tensor of class indices for the batch.

        Returns:
            ``(loss, y)`` where ``loss`` is the batch loss as a Python float
            and ``y`` is the activation tensor computed *before* the weight
            update. ``y`` is still attached to the (already back-propagated)
            graph, so callers should ``detach()`` it before reuse.
        """
        if self.apply_dropout:
            x = self.dropout(x)

        # A fresh optimiser per step is fine for momentum-free SGD (it carries
        # no state) and makes external changes to self.lr take effect at once.
        opt = SGD(self.parameters(), lr=self.lr)

        y = self._activate(x)
        loss = classifier_head_train(y, self.direction_weights, labels)

        opt.zero_grad(set_to_none=True)
        loss.backward(retain_graph = False)
        opt.step()

        return loss.detach().item(), y

    def test(self, x, labels):
        """Predict classes for a batch without tracking gradients.

        Dropout is never applied here.

        Args:
            x: 2-D float tensor of inputs, ``(num_data, in_features)``.
            labels: Passed through to ``classifier_head`` (which ignores it).

        Returns:
            ``(predictions, y)``: a 1-D integer tensor of predicted class
            indices located on the same device as the activations, and the
            activation tensor ``y`` to feed into the next layer.
        """
        with torch.no_grad():
            y = self._activate(x)
        predictions = classifier_head(y, self.direction_weights, labels)
        # Use the data's own device rather than a module-level global so the
        # result can always be compared with labels living next to `x`.
        return torch.tensor(predictions, device=y.device), y

In [8]:
class Net(nn.Module):
    """Stack of ``Layer`` objects that are each trained with their own local loss.

    During training every batch is pushed through the layers in order; each
    layer takes one SGD step on its own loss and hands a *detached* activation
    to the next layer, so no gradient flows between layers.

    Caveat: ``train`` has the signature ``train(train_loader, test_loader)``
    and therefore shadows ``nn.Module.train(mode)``; ``.eval()`` on a Net
    raises ``TypeError``.
    """

    LR_DECAY_EVERY = 10   # epochs between learning-rate reductions
    LR_DECAY_STEP = 0.1   # amount subtracted from each layer's lr per reduction

    def __init__(self, dims_list, bias, epochs, lr, device):
        """Build ``len(dims_list) - 1`` layers.

        Args:
            dims_list: Feature sizes ``[in, hidden..., out]``; layer ``d`` maps
                ``dims_list[d]`` to ``dims_list[d + 1]`` features.
            bias: Whether every layer has a bias term.
            epochs: Number of passes over the training loader in ``train``.
            lr: Initial learning rate given to every layer.
            device: Device for the layers and for the batches in train/test.
        """
        super().__init__()
        self.dims_list = dims_list
        self.bias = bias
        self.epochs = epochs
        self.lr = lr
        self.device = device
        # ModuleList (not a plain list) so the layers' parameters are
        # registered with this module (parameters(), state_dict(), .to()).
        self.layers = nn.ModuleList()
        num_layers = len(self.dims_list) - 1
        for d in range(num_layers):
            print(f"Initialization {d + 1} / {num_layers}")
            self.layers.append(
                Layer(self.dims_list[d], self.dims_list[d + 1], self.bias, self.device, self.lr)
            )
            print("Complete\n")

    def _record_snapshot(self, train_loader, test_loader, layer_w, acc_train, acc_test):
        """Append current weight norms and train/test accuracies to the histories."""
        with torch.no_grad():
            for idx, layer in enumerate(self.layers):
                layer_w[idx].append(torch.norm(layer.weight, p=2).item())
            acc_train.append(self.test(train_loader))
            acc_test.append(self.test(test_loader))

    def train(self, train_loader, test_loader):
        """Train all layers and record metrics before and after training.

        NOTE: shadows ``nn.Module.train(mode)``; see the class docstring.

        Every ``LR_DECAY_EVERY`` epochs (except epoch 0) each layer's ``lr`` is
        reduced by ``LR_DECAY_STEP``. The rate is not clamped, so a small
        initial ``lr`` combined with many epochs can reach zero or go negative.

        Args:
            train_loader: Iterable of ``(inputs, labels)`` batches supporting
                ``len()``; used for training and for train accuracy.
            test_loader: Iterable of ``(inputs, labels)`` batches used for
                test accuracy.

        Returns:
            ``[layer_loss_list, acc_train, acc_test, layer_w]`` where
            ``layer_loss_list[i]`` holds the mean loss of layer ``i`` per
            epoch, ``acc_train`` / ``acc_test`` each hold two ``test()``
            results (before and after training), and ``layer_w[i]`` holds the
            weight L2 norm of layer ``i`` before and after training.
        """
        num_layers = len(self.layers)
        layer_loss_list = [[] for _ in range(num_layers)]
        layer_w = [[] for _ in range(num_layers)]
        acc_train = []
        acc_test = []

        pbar = tqdm(total = self.epochs * len(train_loader) * num_layers,
                    desc = "Training", position = 0, leave = True)
        try:
            # Baseline metrics of the untrained network.
            self._record_snapshot(train_loader, test_loader, layer_w, acc_train, acc_test)

            for epoch in range(self.epochs):
                if epoch and epoch % self.LR_DECAY_EVERY == 0:
                    # Linear learning-rate decay.
                    for layer in self.layers:
                        layer.lr = layer.lr - self.LR_DECAY_STEP
                        print('lr decreased to ', layer.lr)

                loss_agg = [0] * num_layers
                for x, label in train_loader:
                    x = x.to(self.device)
                    label = label.to(self.device)
                    for idx, layer in enumerate(self.layers):
                        loss, y = layer.train(x, label)
                        layer.zero_grad(set_to_none=True)
                        # Detach so the next layer trains on a constant input.
                        x = y.detach()
                        loss_agg[idx] += loss / len(train_loader)
                        pbar.update(1)
                pbar.set_postfix(epoch = epoch + 1, loss = loss_agg)
                for idx in range(num_layers):
                    layer_loss_list[idx].append(loss_agg[idx])

                # Binarise the weights of the LAST layer to {-1, 0, +1} after
                # every epoch.
                # NOTE(review): the original code used the loop variable left
                # over from the loop above, so only the last layer was ever
                # binarised. That behaviour is kept because recorded results
                # depend on it -- confirm whether all layers were intended.
                if num_layers:
                    last_layer = self.layers[-1]
                    last_layer.weight.data = last_layer.weight.data.sign()

            # Metrics of the trained network.
            self._record_snapshot(train_loader, test_loader, layer_w, acc_train, acc_test)
        finally:
            pbar.close()
        return [layer_loss_list, acc_train, acc_test, layer_w]

    def test(self, data_loader):
        """Measure the classification accuracy of every layer's own prediction.

        Args:
            data_loader: Iterable of batches whose first two items are the
                inputs and the integer labels.

        Returns:
            A one-element list containing a NumPy array with one accuracy in
            [0, 1] per layer. If the loader yields no samples the array is
            filled with NaN.
        """
        num_layers = len(self.layers)
        correct = [0] * num_layers
        num_samples = 0
        for dat in data_loader:
            x = dat[0].to(self.device)
            label = dat[1].to(self.device)
            for idx, layer in enumerate(self.layers):
                # Each layer classifies from its own activation, which then
                # becomes the next layer's input.
                pred, x = layer.test(x, label)
                correct[idx] += (pred == label).sum().item()
            num_samples += label.shape[0]

        if num_samples == 0:
            accuracy = np.full(num_layers, np.nan)
        else:
            accuracy = np.array(correct) / num_samples
        return [accuracy]

In [9]:
# Collapse each image tensor to a flat feature vector: keep the leading axis,
# merge the rest, then squeeze away singleton axes. For single-channel MNIST
# this presumably turns (1, 28, 28) into (784,) -- TODO confirm.
flatten_transform = transforms.Lambda(lambda img: img.view(img.size(0), -1).squeeze())

# ToTensor scales pixel values to [0, 1]; no further normalisation is applied.
transform = transforms.Compose([
    transforms.ToTensor(),
    flatten_transform,
])

# MNIST train / test splits (downloaded into ./../../../Data on first use).
trainset = torchvision.datasets.MNIST(
    root='./../../../Data', train=True, download=True, transform=transform
)
testset = torchvision.datasets.MNIST(
    root='./../../../Data', train=False, download=True, transform=transform
)

# Both loaders iterate in dataset order (no shuffling), 50 samples per batch.
batch_size = 50
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=False)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

# full_data_set = torch.utils.data.ConcatDataset([trainset, testset])

In [10]:
# ---- Experiment configuration ------------------------------------------------
# Left over from a train/test-ratio sweep; not referenced by the code in this cell.
ratio = [0.2]
num_runs = 5          # independent trainings, each with a freshly built network

# Network / optimisation hyper-parameters (identical for every run).
dims_list = [784, 1024, 10]
bias = True
epochs = 200
lr  = 2.5
num_classes = 10      # module-level global read by Layer.__init__

# ---- Per-run histories (one entry appended per run) --------------------------
loss = []             # per-layer, per-epoch mean training loss
train_acc = []        # train accuracy before and after training
test_acc = []         # test accuracy before and after training
w_layers = []         # per-layer weight norms before and after training

for _ in range(num_runs):
    net = Net(dims_list, bias, epochs, lr, device)

    # Net.train returns [losses, train accuracies, test accuracies, weight norms].
    layer_loss_list = net.train(trainloader, testloader)

    loss.append(layer_loss_list[0])
    train_acc.append(layer_loss_list[1])
    test_acc.append(layer_loss_list[2])
    w_layers.append(layer_loss_list[3])

Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24099/480000 [00:37<13:01, 583.13it/s, epoch=10, loss=[0.3572706779589254, 0.45695

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48101/480000 [01:09<10:04, 714.35it/s, epoch=20, loss=[0.351431381429235, 0.455149

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72081/480000 [01:44<10:16, 661.44it/s, epoch=30, loss=[0.3492180775851008, 0.45371

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96145/480000 [02:18<08:57, 713.82it/s, epoch=40, loss=[0.348088416457176, 0.452717

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120111/480000 [02:52<08:38, 693.66it/s, epoch=50, loss=[0.34736941496531165, 0.452

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144057/480000 [03:25<08:00, 699.61it/s, epoch=60, loss=[0.3468642653524877, 0.4522

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168111/480000 [03:58<07:16, 714.91it/s, epoch=70, loss=[0.346474691107869, 0.45255

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192156/480000 [04:31<06:02, 794.49it/s, epoch=80, loss=[0.34608132379750395, 0.452

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216100/480000 [05:02<06:56, 633.40it/s, epoch=90, loss=[0.3458164561539885, 0.4528

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240108/480000 [05:35<05:07, 780.84it/s, epoch=100, loss=[0.34557899400591874, 0.45

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264145/480000 [06:08<04:38, 773.99it/s, epoch=110, loss=[0.3453917266180119, 0.453

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288141/480000 [06:40<04:27, 717.15it/s, epoch=120, loss=[0.3452269441882768, 0.453

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312077/480000 [07:13<03:53, 719.05it/s, epoch=130, loss=[0.34507961392402603, 0.45

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336083/480000 [07:47<03:10, 756.91it/s, epoch=140, loss=[0.34494687519967565, 0.45

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360142/480000 [08:20<02:42, 738.96it/s, epoch=150, loss=[0.3448244584600134, 0.454

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384151/480000 [08:53<02:10, 733.09it/s, epoch=160, loss=[0.3447115044047433, 0.454

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408129/480000 [09:28<01:44, 684.70it/s, epoch=170, loss=[0.3446059469630325, 0.454

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432080/480000 [10:02<01:18, 611.08it/s, epoch=180, loss=[0.34450676751633513, 0.45

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456102/480000 [10:36<00:33, 716.15it/s, epoch=190, loss=[0.3444132030258571, 0.455

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [11:13<00:00, 712.52it/s, epoch=200, loss=[0.34432396873831767, 0.45


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24129/480000 [00:38<11:31, 659.39it/s, epoch=10, loss=[0.3568354274084172, 0.31470

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48078/480000 [01:11<09:33, 753.04it/s, epoch=20, loss=[0.35079385953644904, 0.3138

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72111/480000 [01:45<10:44, 632.93it/s, epoch=30, loss=[0.34855132800837324, 0.3137

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96117/480000 [02:17<08:11, 780.44it/s, epoch=40, loss=[0.34724270336329927, 0.3139

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120127/480000 [02:51<08:02, 745.20it/s, epoch=50, loss=[0.3463428369164469, 0.3141

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144109/480000 [03:25<07:42, 725.77it/s, epoch=60, loss=[0.34570891958971883, 0.314

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168089/480000 [03:59<07:48, 665.81it/s, epoch=70, loss=[0.3452851156642037, 0.3144

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192123/480000 [04:32<06:19, 757.77it/s, epoch=80, loss=[0.3449702477951842, 0.3145

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216082/480000 [05:04<05:49, 755.72it/s, epoch=90, loss=[0.3447084694852431, 0.3147

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240103/480000 [05:37<05:24, 739.78it/s, epoch=100, loss=[0.34449077300727293, 0.31

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264150/480000 [06:10<04:38, 774.09it/s, epoch=110, loss=[0.3443062928318977, 0.315

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288091/480000 [06:43<04:12, 760.27it/s, epoch=120, loss=[0.3441442070404686, 0.315

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312095/480000 [07:17<03:56, 709.84it/s, epoch=130, loss=[0.34399872156480943, 0.31

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336122/480000 [07:50<03:06, 773.22it/s, epoch=140, loss=[0.34386688642203767, 0.31

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360145/480000 [08:24<02:40, 748.62it/s, epoch=150, loss=[0.3436622239400944, 0.316

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384113/480000 [08:57<02:38, 604.19it/s, epoch=160, loss=[0.3435461077590786, 0.316

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408082/480000 [09:30<01:41, 711.26it/s, epoch=170, loss=[0.3434389328459898, 0.317

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432149/480000 [10:02<01:01, 774.91it/s, epoch=180, loss=[0.343339508349697, 0.3179

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456155/480000 [10:34<00:30, 776.89it/s, epoch=190, loss=[0.34324545991917443, 0.31

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [11:11<00:00, 715.10it/s, epoch=200, loss=[0.34315620936453345, 0.31


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24050/480000 [00:36<10:31, 721.98it/s, epoch=10, loss=[0.3590212762604159, 0.34555

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48084/480000 [01:11<10:08, 709.72it/s, epoch=20, loss=[0.3525826326012602, 0.34380

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72131/480000 [01:43<09:38, 704.94it/s, epoch=30, loss=[0.34992095487813124, 0.3435

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96145/480000 [02:16<08:44, 732.28it/s, epoch=40, loss=[0.3486275202284262, 0.34309

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120103/480000 [02:48<08:13, 729.04it/s, epoch=50, loss=[0.34791849878927067, 0.342

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144103/480000 [03:22<07:03, 793.52it/s, epoch=60, loss=[0.3473614178349572, 0.3431

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168084/480000 [03:56<08:09, 637.62it/s, epoch=70, loss=[0.3469825396190087, 0.3431

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192123/480000 [04:30<06:39, 720.54it/s, epoch=80, loss=[0.34668162492414323, 0.343

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216107/480000 [05:04<06:33, 671.03it/s, epoch=90, loss=[0.34624188435574416, 0.343

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240093/480000 [05:35<05:17, 756.33it/s, epoch=100, loss=[0.34602617909510947, 0.34

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264091/480000 [06:09<04:34, 785.37it/s, epoch=110, loss=[0.3458432068427407, 0.343

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288077/480000 [06:38<04:06, 778.11it/s, epoch=120, loss=[0.34568338287373435, 0.34

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312129/480000 [07:05<03:15, 857.05it/s, epoch=130, loss=[0.3455400188763935, 0.343

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336117/480000 [07:33<03:00, 797.62it/s, epoch=140, loss=[0.3454095150778693, 0.344

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360095/480000 [08:01<02:27, 813.07it/s, epoch=150, loss=[0.3452888227750853, 0.344

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384109/480000 [08:29<01:49, 879.35it/s, epoch=160, loss=[0.34517125405371174, 0.34

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408142/480000 [08:57<01:22, 872.02it/s, epoch=170, loss=[0.3450658465673526, 0.344

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432099/480000 [09:25<00:55, 870.05it/s, epoch=180, loss=[0.3449684266497688, 0.344

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456095/480000 [09:53<00:27, 873.30it/s, epoch=190, loss=[0.3448771833380062, 0.345

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [10:23<00:00, 769.54it/s, epoch=200, loss=[0.3447903964171809, 0.345


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24102/480000 [00:30<08:50, 859.12it/s, epoch=10, loss=[0.3572113467752932, 0.31244

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48118/480000 [00:58<07:49, 919.92it/s, epoch=20, loss=[0.3525571124255656, 0.31147

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72103/480000 [01:26<07:53, 860.68it/s, epoch=30, loss=[0.3503121778120597, 0.31120

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96141/480000 [01:54<07:22, 866.53it/s, epoch=40, loss=[0.34917403993507257, 0.3111

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120090/480000 [02:22<06:56, 864.72it/s, epoch=50, loss=[0.3484399341543515, 0.3111

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144112/480000 [02:50<06:41, 836.13it/s, epoch=60, loss=[0.34793372586369536, 0.311

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168160/480000 [03:19<06:17, 826.42it/s, epoch=70, loss=[0.34754016615450345, 0.311

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192091/480000 [03:47<05:28, 876.54it/s, epoch=80, loss=[0.34723994158208327, 0.311

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216156/480000 [04:14<05:00, 878.61it/s, epoch=90, loss=[0.3469962798058985, 0.3116

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240109/480000 [04:41<04:34, 875.25it/s, epoch=100, loss=[0.3467909666150806, 0.311

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264115/480000 [05:09<04:05, 878.67it/s, epoch=110, loss=[0.34661248507599046, 0.31

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288121/480000 [05:38<03:51, 828.89it/s, epoch=120, loss=[0.34645517314473806, 0.31

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312105/480000 [06:06<03:25, 818.37it/s, epoch=130, loss=[0.34631439660986324, 0.31

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336125/480000 [06:34<02:50, 841.84it/s, epoch=140, loss=[0.34618577666580685, 0.31

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360090/480000 [07:02<02:16, 876.93it/s, epoch=150, loss=[0.3460680062572159, 0.313

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384137/480000 [07:30<01:56, 821.91it/s, epoch=160, loss=[0.3459583848714833, 0.313

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408152/480000 [07:58<01:17, 921.14it/s, epoch=170, loss=[0.34585662700235836, 0.31

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432120/480000 [08:27<00:59, 803.81it/s, epoch=180, loss=[0.3457606443266076, 0.314

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456150/480000 [08:55<00:29, 804.23it/s, epoch=190, loss=[0.3456698849052191, 0.315

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [09:27<00:00, 846.29it/s, epoch=200, loss=[0.3455846465627354, 0.315


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24101/480000 [00:32<10:03, 755.79it/s, epoch=10, loss=[0.3577173341810702, 0.32332

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48171/480000 [01:00<08:09, 881.46it/s, epoch=20, loss=[0.35192250495155664, 0.3221

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72148/480000 [01:28<07:53, 860.79it/s, epoch=30, loss=[0.34930564043422474, 0.3217

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96121/480000 [01:56<07:31, 849.55it/s, epoch=40, loss=[0.34793980896472926, 0.3217

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120135/480000 [02:25<07:12, 831.20it/s, epoch=50, loss=[0.34708666026592255, 0.321

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144159/480000 [02:52<06:10, 906.30it/s, epoch=60, loss=[0.3465672795474526, 0.3220

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168123/480000 [03:20<06:22, 815.91it/s, epoch=70, loss=[0.346160881544153, 0.32217

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192153/480000 [03:48<05:31, 867.65it/s, epoch=80, loss=[0.34585445508360924, 0.322

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216111/480000 [04:15<04:48, 913.83it/s, epoch=90, loss=[0.34560592693587117, 0.322

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240122/480000 [04:43<04:28, 894.60it/s, epoch=100, loss=[0.345395428240299, 0.3227

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264128/480000 [05:11<03:58, 904.38it/s, epoch=110, loss=[0.34520269138117665, 0.32

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288134/480000 [05:40<03:42, 860.89it/s, epoch=120, loss=[0.34503526220719055, 0.32

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312089/480000 [06:08<03:14, 864.64it/s, epoch=130, loss=[0.3448890287925805, 0.323

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336157/480000 [06:35<02:47, 859.30it/s, epoch=140, loss=[0.344754379366835, 0.3239

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360143/480000 [07:03<02:17, 868.70it/s, epoch=150, loss=[0.3446322319904957, 0.324

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384119/480000 [07:32<01:49, 874.28it/s, epoch=160, loss=[0.34451902699967235, 0.32

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408097/480000 [08:00<01:24, 855.55it/s, epoch=170, loss=[0.34441345209876795, 0.32

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432077/480000 [08:28<00:57, 840.26it/s, epoch=180, loss=[0.34431351656715076, 0.32

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456137/480000 [08:56<00:27, 881.68it/s, epoch=190, loss=[0.3442184956123432, 0.326

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [09:27<00:00, 846.15it/s, epoch=200, loss=[0.34412681033213943, 0.32


In [11]:
# test_acc[run] is the acc_test list returned by Net.train for that run:
# [accuracy before training, accuracy after training], where each entry is a
# one-element list holding an array with one accuracy per layer.
print(test_acc)
# Best test accuracy over all runs, both snapshots and all layers.
print(np.max(test_acc))

[[[array([0.0654, 0.1204])], [array([0.9114, 0.5637])]], [[array([0.1205, 0.0536])], [array([0.9097, 0.8979])]], [[array([0.0677, 0.0805])], [array([0.9112, 0.836 ])]], [[array([0.1125, 0.1094])], [array([0.9091, 0.9106])]], [[array([0.1146, 0.1036])], [array([0.9107, 0.7615])]]]
0.9114


In [12]:
# train_acc[run] is the acc_train list returned by Net.train for that run:
# [accuracy before training, accuracy after training], where each entry is a
# one-element list holding an array with one accuracy per layer.
print(train_acc)
# Best train accuracy over all runs, both snapshots and all layers.
print(np.max(train_acc))

[[[array([0.06591667, 0.12481667])], [array([0.91386667, 0.56776667])]], [[array([0.11825   , 0.05016667])], [array([0.9128    , 0.90318333])]], [[array([0.06358333, 0.08213333])], [array([0.9137    , 0.83526667])]], [[array([0.11461667, 0.112     ])], [array([0.91305, 0.91135])]], [[array([0.11808333, 0.10226667])], [array([0.91296667, 0.76473333])]]]
0.9138666666666667


In [13]:
# Persist the per-run accuracy histories as NumPy arrays. The nested lists are
# converted with np.array first; for the run above that is presumably a
# (runs, 2 snapshots, 1, layers) array -- TODO confirm.
# NOTE(review): the ./../new_data directory is assumed to exist already.
np.save("./../new_data/ablation_bnn_mnist_train_acc.npy", np.array(train_acc))
np.save("./../new_data/ablation_bnn_mnist_test_acc.npy", np.array(test_acc))