In [1]:
import sys
sys.path.insert(0, './../../../Models')
from sphere_points import generate_points

import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt

import torch
torch.manual_seed(0)
import torch.nn as nn
from tqdm import tqdm
from torch.optim import SGD
from torch.nn.functional import normalize, one_hot
# import torch.nn.functional as F
import torchvision
from torchvision import transforms
# from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

In [2]:
# Pick the compute device: CUDA when PyTorch reports it as available, else the CPU.
# (An earlier variant that also tried Apple's "mps" backend is disabled.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
def plot_losses(losses):
    """Show one loss-versus-epoch panel per layer, laid out side by side.

    Args:
        losses: 2-D array-like; row ``i`` holds the per-epoch loss values of
            layer ``i + 1`` (the value is converted with ``np.array`` and
            unpacked as ``(layers, epochs)``).
    """
    loss_matrix = np.array(losses)
    num_layers, num_epochs = loss_matrix.shape
    # Epochs are numbered from 1 on the x-axis.
    epoch_axis = np.arange(num_epochs) + 1

    plt.figure(figsize=(12, 5))
    for panel, layer_curve in enumerate(loss_matrix, start=1):
        plt.subplot(1, num_layers, panel)
        plt.plot(epoch_axis, layer_curve)
        plt.title(f"Layer {panel} Loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
    plt.tight_layout()
    plt.show()

In [4]:
# model_loss = nn.CrossEntropyLoss()
# Number of target classes (the 10 MNIST digits). This is a notebook-level global:
# Layer.__init__ reads it to decide how many class directions to draw, so it must be
# defined before any Layer is constructed.
num_classes = 10

In [5]:
def classifier_head_train(inp_embedding, classifier_weights, labels, theta=1):
    """Local training loss: mean of ``log(2 - theta * cos)`` for the target class.

    Each row of ``inp_embedding`` is L2-normalised and multiplied with
    ``classifier_weights``. When the columns of ``classifier_weights`` are unit
    vectors the products are cosine similarities, and the loss is zero exactly
    when every sample is aligned with its own class direction.

    The number of classes is taken from ``classifier_weights`` itself, so the
    function no longer depends on a notebook-level ``num_classes`` global.

    Args:
        inp_embedding: 2-D float tensor, one row per sample.
        classifier_weights: 2-D float tensor with one column per class; its row
            count must match the embedding width (``torch.mm`` requirement).
        labels: 1-D integer (int64) tensor of class indices, one per sample.
        theta: scale applied to the target-class similarity before the log.
            Defaults to 1, the value previously hard-coded. With unit-norm
            columns the log argument stays positive only while ``theta < 2``.

    Returns:
        Scalar tensor holding the mean loss (differentiable w.r.t. the embedding).
    """
    unit_embedding = normalize(inp_embedding, p=2, dim=-1)
    similarities = torch.mm(unit_embedding, classifier_weights)
    # Select each sample's target-class similarity directly. This equals masking
    # with a one-hot matrix and summing over classes, without building the mask.
    target_similarity = similarities.gather(1, labels.unsqueeze(1)).squeeze(1)
    return torch.mean(torch.log(2 - (theta * target_similarity)))

In [6]:
def classifier_head(inp_embedding, classifier_weights, labels):
    """Predict a class index for every row of ``inp_embedding``.

    Rows are L2-normalised, multiplied with ``classifier_weights`` (one column
    per class) and the column with the largest score is chosen.

    Args:
        inp_embedding: 2-D float tensor, one row per sample.
        classifier_weights: 2-D float tensor, one column per class.
        labels: accepted for signature parity with the training head; unused.

    Returns:
        A plain Python list with one predicted class index per sample.
    """
    unit_rows = normalize(inp_embedding, p=2, dim=-1)
    class_scores = unit_rows.mm(classifier_weights)
    return class_scores.argmax(dim=1).tolist()

In [7]:
# Notebook-level slot that every Layer.__init__ overwrites with the raw
# (un-normalised) class-direction matrix it just sampled, as a NumPy array.
# It therefore reflects only the most recently constructed Layer; None until then.
initial = None
# num_classes = 10

# Data layout expected by Layer: 2-D tensors of shape (num_data, num_features).
# Layer multiplies its input with torch.mm, which only accepts 2-D operands, so
# inputs must not carry an additional batch dimension.
class Layer(nn.Linear):
    """Linear layer that is trained on its own local loss.

    The layer owns a fixed (non-trainable) set of unit-norm class directions.
    Training pushes the LeakyReLU output of each sample towards the direction of
    its class via ``classifier_head_train``; prediction picks the best-matching
    direction via ``classifier_head``.

    NOTE: ``train`` and ``test`` below are batch-level step/predict methods.
    ``train`` shadows ``nn.Module.train(mode)``, so ``.eval()`` (which calls
    ``self.train(False)``) cannot be used on this class. The names are kept
    because callers rely on them.

    Args:
        in_features: width of the input rows.
        out_features: width of the output rows (also the direction dimension).
        bias: whether the affine map includes the bias term.
        device: device for the weights and the class directions.
        lr: learning rate used for each SGD step; read on every ``train`` call,
            so external changes to ``self.lr`` take effect immediately.
        apply_dropout: if true, dropout (p=0.1) is applied to the input in ``train``.
    """

    # Negative slope of the LeakyReLU applied after the affine map.
    NEGATIVE_SLOPE = 0.001

    def __init__(self, in_features, out_features, bias, device, lr, apply_dropout=False):
        global initial
        super().__init__(in_features, out_features, bias, device)
        self.out_features = out_features
        self.bias_flag = bias
        self.lr = lr
        # Class count comes from the notebook-level ``num_classes`` global.
        self.num_classes = num_classes
        self.dimension = out_features

        # Uniform weight initialisation in [-sqrt(6 / fan_in), +sqrt(6 / fan_in)].
        fc1_limit = np.sqrt(6.0 / in_features)
        torch.nn.init.uniform_(self.weight, a=-fc1_limit, b=fc1_limit)

        self.dropout = nn.Dropout(0.1)
        self.apply_dropout = apply_dropout

        # Class directions are sampled uniformly from [-1, 1]^dimension here; the
        # generate_points(...) alternative is disabled in this notebook.
        raw_directions = np.random.uniform(-1, 1, (self.num_classes, self.dimension))
        # Expose the raw (un-normalised) directions through the notebook global.
        initial = np.array(raw_directions)
        self.directions = [torch.tensor(row, dtype=torch.float32).to(device) for row in raw_directions]
        # Shape (dimension, num_classes): one unit-norm column per class. Built from
        # plain tensors, so it carries no gradient and is never updated.
        self.direction_weights = torch.stack(
            [normalize(direction, p=2, dim=-1) for direction in self.directions], dim=1
        )

    def _activate(self, x):
        """Affine map (bias only if enabled) followed by LeakyReLU; shared by train/test."""
        pre_activation = torch.mm(x, self.weight.T)
        if self.bias_flag:
            pre_activation = pre_activation + self.bias.unsqueeze(0)
        return nn.functional.leaky_relu(pre_activation, negative_slope=self.NEGATIVE_SLOPE)

    def train(self, x, labels):
        """Run one SGD step on this layer's local loss for a single batch.

        Args:
            x: 2-D input tensor (num_data, in_features).
            labels: 1-D integer tensor of class indices for the batch.

        Returns:
            ``(loss, y)``: the loss as a Python float and the layer output computed
            *before* the weight update. ``y`` is still attached to the (already
            back-propagated) graph; callers should ``detach()`` it before reuse.
        """
        if self.apply_dropout:
            x = self.dropout(x)

        # A fresh momentum-free SGD optimiser per step carries no state and always
        # picks up the current value of ``self.lr``.
        opt = SGD(self.parameters(), lr=self.lr)

        y = self._activate(x)
        loss = classifier_head_train(y, self.direction_weights, labels)

        opt.zero_grad(set_to_none=True)
        loss.backward(retain_graph=False)
        opt.step()

        return loss.detach().item(), y

    def test(self, x, labels):
        """Predict classes for a batch without tracking gradients.

        Args:
            x: 2-D input tensor (num_data, in_features).
            labels: forwarded to ``classifier_head`` (which does not use it).

        Returns:
            ``(predictions, y)``: an integer tensor of predicted class indices on
            the same device as the layer output, and the layer output itself.
        """
        with torch.no_grad():
            y = self._activate(x)
        predicted = classifier_head(y, self.direction_weights, labels)
        # Place predictions on the layer's own device rather than on whatever the
        # notebook-level ``device`` global happens to be.
        return torch.tensor(predicted, dtype=torch.long, device=y.device), y

In [8]:
class Net(nn.Module):
    """Stack of ``Layer`` objects, each trained on its own local loss.

    During training every batch is passed through the layers in order; each layer
    performs its own update and hands a detached output to the next one, so no
    gradient flows between layers.

    NOTE: ``train`` shadows ``nn.Module.train(mode)``, so ``.eval()`` cannot be used
    on this class. ``self.layers`` is a plain Python list, so the layers are not
    registered as sub-modules (``parameters()`` / ``state_dict()`` of the Net do
    not include them).

    Args:
        dims_list: layer widths, e.g. ``[784, 1024, 10]`` builds two layers.
        bias: whether every layer uses a bias term.
        epochs: number of passes over the training loader.
        lr: initial learning rate given to every layer.
        device: device the layers live on and batches are moved to.
    """

    # Every LR_DECAY_EVERY epochs each layer's learning rate drops by LR_DECAY_STEP.
    # The decay is linear and unclamped: it reaches zero after
    # (lr / LR_DECAY_STEP) * LR_DECAY_EVERY epochs.
    LR_DECAY_EVERY = 10
    LR_DECAY_STEP = 0.1

    def __init__(self, dims_list, bias, epochs, lr, device):
        super().__init__()
        self.dims_list = dims_list
        self.bias = bias
        self.epochs = epochs
        self.lr = lr
        self.device = device
        self.layers = []
        num_layers = len(self.dims_list) - 1
        for d in range(num_layers):
            print(f"Initialization {d + 1} / {num_layers}")
            self.layers.append(
                Layer(self.dims_list[d], self.dims_list[d + 1], self.bias, self.device, self.lr)
            )
            print("Complete\n")

    def _record_checkpoint(self, train_loader, test_loader, layer_w, acc_train, acc_test):
        """Append current weight norms and train/test accuracies to the history lists."""
        with torch.no_grad():
            for i, layer in enumerate(self.layers):
                layer_w[i].append(torch.norm(layer.weight, p=2).item())
            acc_train.append(self.test(train_loader))
            acc_test.append(self.test(test_loader))

    def train(self, train_loader, test_loader):
        """Train all layers and record metrics before and after training.

        Args:
            train_loader: iterable of ``(x, label)`` batches used for the updates.
            test_loader: iterable of batches used only for evaluation.

        Returns:
            ``[layer_loss_list, acc_train, acc_test, layer_w]`` where
            ``layer_loss_list[i]`` holds the mean loss of layer ``i`` per epoch,
            ``acc_train`` / ``acc_test`` hold the ``test()`` results before and
            after training, and ``layer_w[i]`` holds the weight 2-norm of layer
            ``i`` before and after training.
        """
        num_layers = len(self.layers)
        layer_loss_list = [[] for _ in range(num_layers)]
        acc_train = []
        acc_test = []
        layer_w = [[] for _ in range(len(self.dims_list) - 1)]

        pbar = tqdm(total=self.epochs * len(train_loader) * num_layers,
                    desc="Training", position=0, leave=True)
        try:
            # Baseline metrics of the untrained network. Uses this instance and the
            # loaders passed in, not same-named notebook globals.
            self._record_checkpoint(train_loader, test_loader, layer_w, acc_train, acc_test)

            for epoch in range(self.epochs):
                if epoch and epoch % self.LR_DECAY_EVERY == 0:
                    # learning rate decay
                    for layer in self.layers:
                        layer.lr = layer.lr - self.LR_DECAY_STEP
                        print('lr decreased to ', layer.lr)

                loss_agg = [0] * num_layers
                for x, label in train_loader:
                    x = x.to(self.device)
                    label = label.to(self.device)
                    for i, layer in enumerate(self.layers):
                        loss, y = layer.train(x, label)
                        layer.zero_grad(set_to_none=True)
                        # Detach so the next layer trains on a constant input.
                        x = y.detach()
                        loss_agg[i] += loss / len(train_loader)
                        pbar.update(1)
                pbar.set_postfix(epoch=epoch + 1, loss=loss_agg)
                for i in range(num_layers):
                    layer_loss_list[i].append(loss_agg[i])

            # Metrics of the trained network.
            self._record_checkpoint(train_loader, test_loader, layer_w, acc_train, acc_test)
        finally:
            # Always release the progress bar, even on error or interrupt.
            pbar.close()
        return [layer_loss_list, acc_train, acc_test, layer_w]

    def test(self, data_loader):
        """Compute the classification accuracy of every layer on ``data_loader``.

        Each layer predicts from its own output, so the result has one accuracy
        per layer.

        Returns:
            A one-element list containing a NumPy array of per-layer accuracies
            (fraction of correctly classified samples).
        """
        correct = [0 for _ in range(len(self.layers))]
        num_samples = 0
        for dat in data_loader:
            x = dat[0].to(self.device)
            label = dat[1].to(self.device)
            for i, layer in enumerate(self.layers):
                pred, x = layer.test(x, label)
                correct[i] += (pred == label).sum().item()
            num_samples += label.shape[0]

        return [np.array(correct) / num_samples]

In [9]:
# Collapse each image tensor to a flat vector: keep the leading axis, merge the
# remaining axes, then squeeze away singleton axes.
flatten_transform = transforms.Lambda(lambda img: img.view(img.size(0), -1).squeeze())

# ToTensor scales pixel values to [0, 1]; no further normalisation is applied.
transform = transforms.Compose([transforms.ToTensor(), flatten_transform])

batch_size = 50

trainset = torchvision.datasets.MNIST(root='./../../../Data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./../../../Data', train=False, download=True, transform=transform)

# Both loaders iterate in dataset order (shuffle=False), including the training one.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=False)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [10]:
# ratio = [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
ratio = [0.2]  # not used by the loop below
# One entry is appended to each history per run.
test_acc, train_acc, w_layers, loss = [], [], [], []

num_runs = 5

for _ in range(num_runs):
    # Hyper-parameters of a single run; a fresh network is built every time.
    dims_list = [784, 1024, 10]
    bias = True
    epochs = 200
    lr = 2.5
    num_classes = 10
    net = Net(dims_list, bias, epochs, lr, device)

    # Net.train returns [per-layer losses, train accuracies, test accuracies, weight norms].
    layer_loss_list = net.train(trainloader, testloader)

    # Route each returned component into its matching history list.
    for history, run_values in zip((loss, train_acc, test_acc, w_layers), layer_loss_list):
        history.append(run_values)

Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24086/480000 [00:33<10:27, 726.28it/s, epoch=10, loss=[0.35788345019022655, 0.3422

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48114/480000 [01:04<08:42, 826.21it/s, epoch=20, loss=[0.3524727308501803, 0.31056

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72044/480000 [01:34<09:24, 722.16it/s, epoch=30, loss=[0.3504557439188163, 0.31287

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96143/480000 [02:07<08:26, 757.95it/s, epoch=40, loss=[0.3491389197111127, 0.31286

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120111/480000 [02:38<07:28, 803.14it/s, epoch=50, loss=[0.34822310619056207, 0.300

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144119/480000 [03:09<07:44, 723.70it/s, epoch=60, loss=[0.3475761376569666, 0.3002

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168157/480000 [03:41<06:17, 825.58it/s, epoch=70, loss=[0.34707773951192755, 0.298

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192117/480000 [04:11<05:48, 826.13it/s, epoch=80, loss=[0.34667658708989674, 0.295

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216149/480000 [04:42<05:39, 776.86it/s, epoch=90, loss=[0.34632001218696407, 0.296

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240123/480000 [05:13<05:04, 787.90it/s, epoch=100, loss=[0.3459960677226385, 0.294

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264149/480000 [05:43<04:31, 796.43it/s, epoch=110, loss=[0.34568647406995273, 0.29

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288089/480000 [06:13<03:48, 839.60it/s, epoch=120, loss=[0.3454021844267851, 0.292

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312085/480000 [06:42<03:45, 744.57it/s, epoch=130, loss=[0.3451333628594881, 0.294

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336164/480000 [07:12<02:53, 829.64it/s, epoch=140, loss=[0.34487355719010104, 0.29

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360133/480000 [07:41<02:27, 810.93it/s, epoch=150, loss=[0.3446144264688095, 0.291

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384087/480000 [08:11<01:54, 840.36it/s, epoch=160, loss=[0.3443574448923269, 0.289

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408137/480000 [08:40<01:25, 840.76it/s, epoch=170, loss=[0.3441100503007577, 0.289

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432104/480000 [09:10<00:57, 832.24it/s, epoch=180, loss=[0.3438661649823189, 0.289

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456073/480000 [09:39<00:29, 799.62it/s, epoch=190, loss=[0.3436238668113943, 0.288

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [10:12<00:00, 784.06it/s, epoch=200, loss=[0.3433839780340595, 0.288


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24083/480000 [00:34<09:53, 768.06it/s, epoch=10, loss=[0.3565834158410635, 0.33499

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48133/480000 [01:04<08:56, 804.68it/s, epoch=20, loss=[0.3512010189642505, 0.33089

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72151/480000 [01:35<08:18, 818.45it/s, epoch=30, loss=[0.34925895114739725, 0.3288

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96152/480000 [02:05<07:54, 809.30it/s, epoch=40, loss=[0.34817848113675925, 0.3278

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120128/480000 [02:35<07:55, 756.36it/s, epoch=50, loss=[0.34744269122679977, 0.326

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144139/480000 [03:06<06:44, 830.64it/s, epoch=60, loss=[0.34688511349260787, 0.324

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168137/480000 [03:36<06:39, 781.13it/s, epoch=70, loss=[0.3464082166055845, 0.3242

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192084/480000 [04:05<05:50, 820.68it/s, epoch=80, loss=[0.3459998751183352, 0.3182

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216152/480000 [04:36<05:16, 833.29it/s, epoch=90, loss=[0.34553507305681713, 0.312

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240157/480000 [05:07<05:11, 769.11it/s, epoch=100, loss=[0.34515724254151176, 0.30

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264078/480000 [05:37<04:33, 789.97it/s, epoch=110, loss=[0.3448592361062762, 0.303

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288116/480000 [06:08<03:52, 825.20it/s, epoch=120, loss=[0.34443923289577216, 0.30

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312122/480000 [06:39<03:51, 724.09it/s, epoch=130, loss=[0.3441334551821151, 0.300

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336118/480000 [07:13<03:12, 748.81it/s, epoch=140, loss=[0.34387163947025984, 0.29

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360068/480000 [07:46<02:39, 750.40it/s, epoch=150, loss=[0.343614449004332, 0.3016

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384106/480000 [08:21<02:16, 704.48it/s, epoch=160, loss=[0.3433667075633999, 0.298

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408143/480000 [08:54<01:34, 761.77it/s, epoch=170, loss=[0.3431278044730425, 0.296

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432099/480000 [09:28<01:07, 709.98it/s, epoch=180, loss=[0.3428915406266851, 0.296

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456101/480000 [10:01<00:30, 771.53it/s, epoch=190, loss=[0.3426565447201331, 0.297

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [10:38<00:00, 751.28it/s, epoch=200, loss=[0.3424243593961003, 0.295


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24063/480000 [00:35<12:11, 623.17it/s, epoch=10, loss=[0.3561248213549452, 0.30807

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48145/480000 [01:07<09:30, 756.52it/s, epoch=20, loss=[0.35120722825328526, 0.2882

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72142/480000 [01:39<09:29, 715.94it/s, epoch=30, loss=[0.3488254377742608, 0.27918

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96084/480000 [02:12<09:59, 640.55it/s, epoch=40, loss=[0.3476698582122725, 0.25840

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120106/480000 [02:46<08:41, 690.31it/s, epoch=50, loss=[0.34690781523784014, 0.251

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144114/480000 [03:19<07:58, 701.61it/s, epoch=60, loss=[0.3461419993390638, 0.2496

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168131/480000 [03:52<06:43, 771.97it/s, epoch=70, loss=[0.3456621850033602, 0.2434

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192113/480000 [04:25<06:07, 783.66it/s, epoch=80, loss=[0.3451101167003314, 0.2416

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216081/480000 [04:59<06:04, 724.06it/s, epoch=90, loss=[0.34474378059307725, 0.242

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240130/480000 [05:33<05:51, 682.65it/s, epoch=100, loss=[0.34443402086695046, 0.23

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264059/480000 [06:07<05:02, 714.96it/s, epoch=110, loss=[0.3441490836938224, 0.238

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288131/480000 [06:42<04:10, 766.83it/s, epoch=120, loss=[0.34388008331259085, 0.23

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312110/480000 [07:15<04:01, 694.53it/s, epoch=130, loss=[0.3436190708974993, 0.236

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336137/480000 [07:49<03:04, 780.74it/s, epoch=140, loss=[0.3433683648457127, 0.235

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360093/480000 [08:22<02:45, 723.48it/s, epoch=150, loss=[0.34312401754160693, 0.23

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384153/480000 [08:55<02:03, 779.09it/s, epoch=160, loss=[0.34288200723628215, 0.23

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408139/480000 [09:29<01:44, 690.07it/s, epoch=170, loss=[0.34264536408086627, 0.23

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432066/480000 [10:04<01:05, 733.70it/s, epoch=180, loss=[0.3424111030002441, 0.232

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456100/480000 [10:37<00:31, 757.77it/s, epoch=190, loss=[0.3421779735138021, 0.229

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [11:15<00:00, 711.08it/s, epoch=200, loss=[0.3419469749182463, 0.229


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24073/480000 [00:37<10:56, 694.92it/s, epoch=10, loss=[0.3529360722005368, 0.28089

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48069/480000 [01:09<09:23, 767.12it/s, epoch=20, loss=[0.3468556572000184, 0.27328

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72147/480000 [01:41<08:56, 760.88it/s, epoch=30, loss=[0.3448295316348472, 0.26955

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96129/480000 [02:14<08:54, 718.49it/s, epoch=40, loss=[0.34360520102083675, 0.2668

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120119/480000 [02:46<08:41, 690.07it/s, epoch=50, loss=[0.3425713525960845, 0.2632

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144079/480000 [03:21<07:50, 714.03it/s, epoch=60, loss=[0.3418348341683542, 0.2627

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168151/480000 [03:54<06:42, 774.81it/s, epoch=70, loss=[0.3411261147260667, 0.2610

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192057/480000 [04:28<06:29, 739.26it/s, epoch=80, loss=[0.34066704906523243, 0.263

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216115/480000 [05:01<06:21, 692.24it/s, epoch=90, loss=[0.34010954163968543, 0.261

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240087/480000 [05:33<05:06, 782.18it/s, epoch=100, loss=[0.33968554399907597, 0.26

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264099/480000 [06:05<04:33, 790.58it/s, epoch=110, loss=[0.3393773475040996, 0.258

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288105/480000 [06:37<04:15, 750.10it/s, epoch=120, loss=[0.3390921676407263, 0.261

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312111/480000 [07:10<03:47, 737.52it/s, epoch=130, loss=[0.3388180109113457, 0.256

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336083/480000 [07:44<03:24, 704.57it/s, epoch=140, loss=[0.33854326749841357, 0.25

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360101/480000 [08:17<02:31, 789.11it/s, epoch=150, loss=[0.33827827066183114, 0.25

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384134/480000 [08:49<02:05, 765.00it/s, epoch=160, loss=[0.3380192431559169, 0.247

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408060/480000 [09:22<01:43, 693.86it/s, epoch=170, loss=[0.33776336995263867, 0.24

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432109/480000 [09:56<01:14, 645.48it/s, epoch=180, loss=[0.3375079656392337, 0.241

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456143/480000 [10:29<00:33, 709.61it/s, epoch=190, loss=[0.3372529566039641, 0.240

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [11:07<00:00, 719.24it/s, epoch=200, loss=[0.3369972040007513, 0.238


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24135/480000 [00:37<10:41, 710.13it/s, epoch=10, loss=[0.3543235517044862, 0.30740

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48097/480000 [01:10<09:24, 764.47it/s, epoch=20, loss=[0.34903745469947683, 0.2978

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72134/480000 [01:43<09:07, 744.86it/s, epoch=30, loss=[0.34710354134440435, 0.2947

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96157/480000 [02:13<07:24, 863.87it/s, epoch=40, loss=[0.3458327526599167, 0.29077

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120143/480000 [02:41<06:54, 868.23it/s, epoch=50, loss=[0.3449544765303535, 0.2876

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144163/480000 [03:09<06:26, 867.91it/s, epoch=60, loss=[0.34430335755149505, 0.287

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168107/480000 [03:36<05:57, 871.83it/s, epoch=70, loss=[0.34379795325299045, 0.285

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192141/480000 [04:04<05:33, 864.27it/s, epoch=80, loss=[0.34323552300532695, 0.283

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216154/480000 [04:32<05:04, 867.56it/s, epoch=90, loss=[0.34284157104790164, 0.265

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240103/480000 [05:00<04:36, 868.58it/s, epoch=100, loss=[0.3424764965474612, 0.265

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264173/480000 [05:28<04:08, 868.30it/s, epoch=110, loss=[0.34210164763033407, 0.26

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288135/480000 [05:55<03:36, 884.57it/s, epoch=120, loss=[0.34177194423973556, 0.26

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312153/480000 [06:23<03:22, 829.69it/s, epoch=130, loss=[0.34144400934378255, 0.26

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336113/480000 [06:51<03:06, 769.52it/s, epoch=140, loss=[0.34111632801592373, 0.26

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360163/480000 [07:19<02:25, 826.33it/s, epoch=150, loss=[0.3407441272834936, 0.258

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384131/480000 [07:47<01:56, 825.73it/s, epoch=160, loss=[0.34036813015739087, 0.25

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408117/480000 [08:15<01:21, 877.25it/s, epoch=170, loss=[0.34000825673341734, 0.25

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432141/480000 [08:43<00:55, 859.24it/s, epoch=180, loss=[0.3396662204215926, 0.256

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456123/480000 [09:12<00:28, 846.30it/s, epoch=190, loss=[0.33931695672372975, 0.25

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [09:43<00:00, 822.13it/s, epoch=200, loss=[0.33895663052797337, 0.25


In [11]:
# test_acc[run][checkpoint][0][layer]: checkpoint 0 is before training, 1 is after.
# np.max takes the single largest value over every run, checkpoint and layer.
print(test_acc, np.max(test_acc), sep="\n")

[[[array([0.0894, 0.0981])], [array([0.9033, 0.8394])]], [[array([0.1275, 0.177 ])], [array([0.9051, 0.8074])]], [[array([0.1036, 0.1147])], [array([0.9027, 0.895 ])]], [[array([0.0784, 0.0837])], [array([0.9024, 0.8998])]], [[array([0.0982, 0.1147])], [array([0.9001, 0.7645])]]]
0.9051


In [12]:
# train_acc[run][checkpoint][0][layer]: checkpoint 0 is before training, 1 is after.
# np.max takes the single largest value over every run, checkpoint and layer.
print(train_acc, np.max(train_acc), sep="\n")

[[[array([0.07913333, 0.10041667])], [array([0.90551667, 0.85211667])]], [[array([0.12671667, 0.17183333])], [array([0.90648333, 0.81748333])]], [[array([0.1036, 0.1175])], [array([0.90633333, 0.90271667])]], [[array([0.07768333, 0.07968333])], [array([0.90443333, 0.9047    ])]], [[array([0.09081667, 0.11583333])], [array([0.90355   , 0.76553333])]]]
0.9064833333333333


In [13]:
# Persist the per-run accuracy histories (train first, then test) as .npy arrays.
for out_path, acc_history in (
    ("./../new_data/ablation_uniformdir_mnist_train_acc.npy", train_acc),
    ("./../new_data/ablation_uniformdir_mnist_test_acc.npy", test_acc),
):
    np.save(out_path, np.array(acc_history))