In [1]:
import sys
sys.path.insert(0, './../../../Models')
from sphere_points import generate_points

import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt

import torch
torch.manual_seed(0)
import torch.nn as nn
from tqdm import tqdm
from torch.optim import SGD
from torch.nn.functional import normalize, one_hot
# import torch.nn.functional as F
import torchvision
from torchvision import transforms
# from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
def plot_losses(losses):
    """Plot one training-loss curve per layer, in side-by-side panels.

    Parameters
    ----------
    losses : array-like
        Per-layer, per-epoch losses shaped ``(n_layers, epochs)``, e.g. the
        first element returned by ``Net.train``.  A single flat sequence of
        per-epoch losses (or an empty list) is also accepted and drawn as a
        single panel instead of failing on the 2-D shape unpacking.
    """
    # atleast_2d turns a flat (epochs,) curve into (1, epochs).
    losses = np.atleast_2d(np.asarray(losses, dtype=float))
    n_dims, epochs = losses.shape
    # squeeze=False keeps `axes` 2-D even for a single panel.
    fig, axes = plt.subplots(1, n_dims, figsize=(12, 5), squeeze=False)
    epoch_axis = 1 + np.arange(epochs)
    for layer_idx, ax in enumerate(axes[0]):
        ax.plot(epoch_axis, losses[layer_idx])
        ax.set(title=f"Layer {layer_idx + 1} Loss", xlabel="Epochs", ylabel="Loss")
    fig.tight_layout()
    plt.show()

In [4]:
# Number of target classes (the ten MNIST digits). `Layer.__init__` reads
# this module-level global when it draws its per-class directions.
num_classes = 10

In [5]:
def classifier_head_train(inp_embedding, classifier_weights, labels, theta=1):
    """Layer-local training loss: penalise misalignment with the true-class direction.

    Each row of `inp_embedding` is L2-normalised and projected onto the
    columns of `classifier_weights`.  Only each sample's true-class score is
    kept, and the loss is ``mean(log(2 - theta * score))``.  If the columns
    are unit-norm (as `Layer` builds them), the score is a cosine similarity
    in [-1, 1].  With ``theta=1`` the log argument then lies in [1, 3], so the
    loss is >= 0 and is 0 only when every embedding points along its class
    direction.

    Parameters
    ----------
    inp_embedding : torch.Tensor, shape (batch, dim)
    classifier_weights : torch.Tensor, shape (dim, n_classes)
        One class direction per column.
    labels : torch.LongTensor, shape (batch,)
        True class index of each row.
    theta : float, default 1
        Scale on the similarity term (1 is the previously hard-coded value).

    Returns
    -------
    torch.Tensor
        Scalar loss, differentiable with respect to `inp_embedding`.
    """
    unit_embedding = normalize(inp_embedding, p=2, dim=-1)
    scores = torch.mm(unit_embedding, classifier_weights)  # (batch, n_classes)
    # Select each row's true-class score directly.  The class count comes from
    # the weight matrix itself, not from a one-hot sized by the `num_classes`
    # global.
    true_class_score = scores.gather(1, labels.unsqueeze(1)).squeeze(1)
    return torch.mean(torch.log(2 - theta * true_class_score))

In [6]:
def classifier_head(inp_embedding, classifier_weights, labels):
    """Predict, per row, the class whose direction best matches the embedding.

    Rows of `inp_embedding` are L2-normalised, projected onto the columns of
    `classifier_weights`, and the highest-scoring column index is returned.
    `labels` is accepted only for signature symmetry with
    `classifier_head_train`; it is not used.

    Returns
    -------
    list[int]
        One predicted class index per input row.
    """
    unit_rows = normalize(inp_embedding, p=2, dim=-1)
    winners = torch.mm(unit_rows, classifier_weights).argmax(dim=1)
    return winners.tolist()

In [7]:
# Class-direction matrix (num_classes, out_features) drawn by the most
# recently constructed `Layer`; overwritten by every construction and kept
# only for inspection.
initial = None


class Layer(nn.Linear):
    """A fully-connected layer trained locally against fixed class directions.

    The layer's LeakyReLU output is scored by `classifier_head_train` against
    `num_classes` fixed, random, unit-norm class directions.  `train` then
    updates only this layer's own parameters with plain SGD.  Inputs are 2-D,
    ``(num_data, in_features)``, with no separate batch axis.

    Parameters
    ----------
    in_features, out_features : int
    bias : bool
        Whether the affine map has a bias term.
    device : torch.device
        Device for the weights and the class directions.
    lr : float
        SGD learning rate.  It is read on every `train` call, so callers may
        change `self.lr` between calls.
    apply_dropout : bool, default False
        If True, dropout (p=0.1) is applied to the layer *input* in `train`.
    """

    # LeakyReLU slope shared by the training and evaluation passes.
    _NEGATIVE_SLOPE = 0.001

    def __init__(self, in_features, out_features, bias, device, lr, apply_dropout=False):
        global initial
        super().__init__(in_features, out_features, bias, device)
        self.out_features = out_features
        self.bias_flag = bias
        self.lr = lr
        # `num_classes` is the module-level global from the config cell.
        self.num_classes = num_classes
        self.dimension = out_features

        # He-uniform initialisation of the weight; the bias keeps nn.Linear's default init.
        fc1_limit = np.sqrt(6.0 / in_features)
        torch.nn.init.uniform_(self.weight, a=-fc1_limit, b=fc1_limit)

        self.dropout = nn.Dropout(0.1)
        self.apply_dropout = apply_dropout

        # Random Gaussian class directions (this ablation's replacement for the
        # `generate_points` directions), drawn from NumPy's global RNG.
        self.directions = np.random.randn(self.num_classes, self.dimension)
        initial = np.array(self.directions)
        self.directions = [torch.tensor(t, dtype=torch.float32).to(device) for t in self.directions]
        # (dimension, num_classes): column i is the unit vector for class i.
        self.direction_weights = torch.stack(
            [normalize(d, p=2, dim=-1) for d in self.directions], dim=1
        )

    def _activate(self, x):
        """Affine map followed by LeakyReLU (shared by `train` and `test`)."""
        if self.bias_flag:
            z = torch.mm(x, self.weight.T) + self.bias.unsqueeze(0)
        else:
            z = torch.mm(x, self.weight.T)
        return nn.functional.leaky_relu(z, negative_slope=self._NEGATIVE_SLOPE)

    def train(self, x=True, labels=None):
        """Run one local SGD step on a batch and return ``(loss, output)``.

        This method overrides `nn.Module.train(mode)`, which `nn.Module.eval()`
        and parent modules call as ``self.train(bool)``.  When called that way
        (no `labels`), it defers to `nn.Module.train` so mode switching keeps
        working instead of raising TypeError.

        Parameters
        ----------
        x : torch.Tensor, shape (num_data, in_features)
        labels : torch.LongTensor, shape (num_data,)

        Returns
        -------
        (float, torch.Tensor)
            The batch loss, and the layer output `y`.  `y` is still attached to
            the autograd graph; callers detach it before reusing it.
        """
        if labels is None:
            return super().train(x)
        if self.apply_dropout:
            x = self.dropout(x)
        # A fresh optimizer per call picks up any change to `self.lr`.  Plain
        # SGD (no momentum) keeps no state, so nothing is lost between calls.
        opt = SGD(self.parameters(), lr=self.lr)
        y = self._activate(x)
        loss = classifier_head_train(y, self.direction_weights, labels)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        return loss.detach().item(), y

    def test(self, x, labels):
        """Predict classes for a batch without updating the layer.

        `labels` is only forwarded to `classifier_head`, which ignores it.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Predicted class indices on this layer's device, and the layer output.
        """
        with torch.no_grad():
            y = self._activate(x)
        max_idx_list = classifier_head(y, self.direction_weights, labels)
        # Build predictions on this layer's own device rather than on the
        # notebook-global `device`.
        return torch.tensor(max_idx_list, device=self.weight.device), y

In [8]:
class Net(nn.Module):
    """A stack of `Layer`s, each trained on the detached output of the previous one.

    Every layer makes its own class prediction, so losses and accuracies are
    reported per layer.

    Parameters
    ----------
    dims_list : list[int]
        Feature sizes ``[input, hidden_1, ..., last]``.  One `Layer` is built
        per consecutive pair.
    bias : bool
    epochs : int
    lr : float
        Initial SGD learning rate for every layer.
    device : torch.device
        Device the layers are built on and that batches are moved to.
    lr_decay : float, default 0.1
        Amount subtracted from every layer's learning rate each
        `lr_decay_every` epochs.  The rate is clamped at 0, because
        `torch.optim.SGD` rejects negative learning rates.
    lr_decay_every : int, default 10
        Decay period in epochs; 0 disables decay.
    """

    def __init__(self, dims_list, bias, epochs, lr, device, lr_decay=0.1, lr_decay_every=10):
        super().__init__()
        self.dims_list = dims_list
        self.bias = bias
        self.epochs = epochs
        self.lr = lr
        self.device = device
        self.lr_decay = lr_decay
        self.lr_decay_every = lr_decay_every
        # NOTE: a plain list, so the layers are not registered as submodules
        # (`self.parameters()` does not include their weights).
        self.layers = []
        n_layers = len(self.dims_list) - 1
        for d in range(n_layers):
            print(f"Initialization {d + 1} / {n_layers}")
            self.layers += [Layer(self.dims_list[d], self.dims_list[d + 1], self.bias, self.device, self.lr)]
            print("Complete\n")

    def _snapshot(self, train_loader, test_loader, layer_w, acc_train, acc_test):
        """Append the current per-layer weight norms and train/test accuracies."""
        with torch.no_grad():
            for i, layer in enumerate(self.layers):
                layer_w[i].append(torch.norm(layer.weight, p=2).item())
            acc_train.append(self.test(train_loader))
            acc_test.append(self.test(test_loader))

    def _decay_lr(self):
        """Lower every layer's learning rate by `lr_decay`, never below 0."""
        for layer in self.layers:
            layer.lr = max(layer.lr - self.lr_decay, 0.0)
            print('lr decreased to ', layer.lr)

    def train(self, train_loader, test_loader):
        """Train all layers for `self.epochs` epochs; evaluate before and after.

        Returns
        -------
        list
            ``[layer_loss_list, acc_train, acc_test, layer_w]``, where:

            - ``layer_loss_list[i]`` holds layer i's mean loss per epoch.
            - ``acc_train`` / ``acc_test`` each hold two `test` results
              (before and after training).
            - ``layer_w[i]`` holds layer i's weight norm at those two points.
        """
        n_layers = len(self.layers)
        layer_loss_list = [[] for _ in range(n_layers)]
        acc_train = []
        acc_test = []
        layer_w = [[] for _ in range(n_layers)]
        pbar = tqdm(total=self.epochs * len(train_loader) * n_layers,
                    desc="Training", position=0, leave=True)

        # Evaluate the untrained network on the loaders passed in.
        self._snapshot(train_loader, test_loader, layer_w, acc_train, acc_test)

        for epoch in range(self.epochs):
            if self.lr_decay_every and epoch and not (epoch % self.lr_decay_every):
                self._decay_lr()

            loss_agg = [0] * n_layers
            for dat in train_loader:
                x = dat[0].to(self.device)
                label = dat[1].to(self.device)
                for i, layer in enumerate(self.layers):
                    loss, y = layer.train(x, label)
                    layer.zero_grad(set_to_none=True)
                    # Detach: the next layer treats this output as a fixed input.
                    x = y.detach()
                    loss_agg[i] += loss / len(train_loader)
                    del y
                    pbar.update(1)
            pbar.set_postfix(epoch=epoch + 1, loss=loss_agg)
            for i in range(n_layers):
                layer_loss_list[i].append(loss_agg[i])

        # Evaluate the trained network.
        self._snapshot(train_loader, test_loader, layer_w, acc_train, acc_test)
        pbar.close()
        return [layer_loss_list, acc_train, acc_test, layer_w]

    def test(self, data_loader):
        """Per-layer classification accuracy over `data_loader`.

        Each layer receives the previous layer's output and makes its own
        prediction.

        Returns
        -------
        list[numpy.ndarray]
            A one-element list holding the array of per-layer accuracies.
        """
        correct = [0] * len(self.layers)
        seen = 0
        for dat in data_loader:
            x = dat[0].to(self.device)
            label = dat[1].to(self.device)
            for i, layer in enumerate(self.layers):
                pred, x = layer.test(x, label)
                correct[i] += (pred == label).sum().item()
            seen += label.shape[0]
        return [np.array(correct) / seen]

In [9]:
# Reshape each (C, H, W) image tensor to (C, H*W), then squeeze away size-1
# axes.  For single-channel MNIST this yields a flat 784-vector.
flatten_transform = transforms.Lambda(lambda x: x.view(x.size(0), -1).squeeze())

# ToTensor scales pixel values into [0, 1]; no further normalisation is applied.
transform = transforms.Compose([transforms.ToTensor(), flatten_transform])

DATA_ROOT = './../../../Data'
trainset = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)
batch_size = 50
# NOTE(review): the training loader is not shuffled, so every epoch visits the
# batches in the same order.  Keep this in mind when comparing with shuffled runs.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

In [10]:
# Split ratios for a (currently disabled) train/test-split sweep.  The runs
# below use the standard MNIST train/test split.
ratio = [0.2]

# Settings shared by every run (hoisted out of the loop).
num_runs = 5
dims_list = [784, 1024, 10]
bias = True
epochs = 200
lr = 2.5
num_classes = 10  # module-level global read by Layer.__init__

# One entry per run.
test_acc = []
train_acc = []
w_layers = []
loss = []

for _ in range(num_runs):
    # `net` is deliberately kept as a notebook global: the Net/Layer cells may
    # refer to it (and to `trainloader` / `testloader`).
    net = Net(dims_list, bias, epochs, lr, device)
    layer_losses, run_train_acc, run_test_acc, run_weight_norms = net.train(trainloader, testloader)
    loss.append(layer_losses)
    train_acc.append(run_train_acc)
    test_acc.append(run_test_acc)
    w_layers.append(run_weight_norms)

Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24114/480000 [00:31<08:52, 856.64it/s, epoch=10, loss=[0.3561824440956118, 0.23230

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48143/480000 [01:00<08:33, 840.80it/s, epoch=20, loss=[0.3515272478759292, 0.22684

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72165/480000 [01:29<07:40, 885.29it/s, epoch=30, loss=[0.34952623223265006, 0.2125

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96116/480000 [01:57<07:27, 857.54it/s, epoch=40, loss=[0.3483367256075153, 0.20979

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120100/480000 [02:26<07:49, 766.71it/s, epoch=50, loss=[0.3474910945942006, 0.2077

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144137/480000 [02:55<06:44, 830.28it/s, epoch=60, loss=[0.3466495410352947, 0.2061

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168084/480000 [03:23<06:20, 818.99it/s, epoch=70, loss=[0.3461094703276954, 0.2050

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192129/480000 [03:52<05:33, 864.03it/s, epoch=80, loss=[0.34564033615092427, 0.204

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216105/480000 [04:21<05:12, 843.67it/s, epoch=90, loss=[0.34520834505558, 0.203835

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240125/480000 [04:50<04:39, 858.41it/s, epoch=100, loss=[0.3447985035429403, 0.203

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264097/480000 [05:18<04:12, 854.38it/s, epoch=110, loss=[0.3444007125993567, 0.202

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288131/480000 [05:46<03:39, 872.29it/s, epoch=120, loss=[0.3440093382944667, 0.202

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312097/480000 [06:16<03:21, 833.61it/s, epoch=130, loss=[0.34361588155229944, 0.20

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336152/480000 [06:45<02:56, 814.96it/s, epoch=140, loss=[0.3431191883981227, 0.200

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360101/480000 [07:15<02:41, 740.82it/s, epoch=150, loss=[0.34267425681153896, 0.19

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384165/480000 [07:45<01:53, 847.01it/s, epoch=160, loss=[0.3421789552023015, 0.196

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408091/480000 [08:13<01:28, 812.91it/s, epoch=170, loss=[0.34170375108718937, 0.19

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432091/480000 [08:43<00:59, 809.02it/s, epoch=180, loss=[0.34121362557013885, 0.19

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456144/480000 [09:13<00:31, 769.09it/s, epoch=190, loss=[0.3406966129938758, 0.195

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [09:47<00:00, 816.82it/s, epoch=200, loss=[0.34014328057567267, 0.19


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24131/480000 [00:34<09:28, 801.53it/s, epoch=10, loss=[0.3441522765407955, 0.23816

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48154/480000 [01:07<09:22, 767.86it/s, epoch=20, loss=[0.33956024152537184, 0.2104

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72106/480000 [01:39<08:53, 764.27it/s, epoch=30, loss=[0.3372780714929103, 0.20679

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96112/480000 [02:11<08:58, 713.22it/s, epoch=40, loss=[0.33613164422412745, 0.2035

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120089/480000 [02:42<07:51, 762.85it/s, epoch=50, loss=[0.3353217093646527, 0.2026

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144101/480000 [03:13<06:52, 814.50it/s, epoch=60, loss=[0.3346425771216551, 0.1997

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168088/480000 [03:45<07:15, 715.92it/s, epoch=70, loss=[0.33404268259803466, 0.199

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192079/480000 [04:16<06:17, 762.53it/s, epoch=80, loss=[0.33339400798082314, 0.198

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216083/480000 [04:46<05:30, 798.56it/s, epoch=90, loss=[0.3329100214441613, 0.1973

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240121/480000 [05:16<05:12, 766.66it/s, epoch=100, loss=[0.33250738526384044, 0.19

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264083/480000 [05:45<04:26, 810.96it/s, epoch=110, loss=[0.33212064060072066, 0.19

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288139/480000 [06:15<03:55, 814.84it/s, epoch=120, loss=[0.33174044139683295, 0.19

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312123/480000 [06:45<03:32, 790.30it/s, epoch=130, loss=[0.33136263837416946, 0.19

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336144/480000 [07:15<02:57, 808.67it/s, epoch=140, loss=[0.3309768450508511, 0.195

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360154/480000 [07:44<02:23, 832.30it/s, epoch=150, loss=[0.33058120988309353, 0.19

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384127/480000 [08:14<02:03, 776.72it/s, epoch=160, loss=[0.3301698070019486, 0.194

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408111/480000 [08:43<01:34, 763.20it/s, epoch=170, loss=[0.3297380407154556, 0.194

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432091/480000 [09:14<01:07, 714.98it/s, epoch=180, loss=[0.32925293947259543, 0.19

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456140/480000 [09:45<00:31, 756.28it/s, epoch=190, loss=[0.32872267526884863, 0.19

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [10:19<00:00, 775.26it/s, epoch=200, loss=[0.32811354969938644, 0.19


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24113/480000 [00:34<09:51, 771.28it/s, epoch=10, loss=[0.35487181750436636, 0.2872

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48109/480000 [01:04<08:56, 805.09it/s, epoch=20, loss=[0.3503768819570536, 0.28526

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72141/480000 [01:35<09:11, 740.01it/s, epoch=30, loss=[0.3477618865172068, 0.27593

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96133/480000 [02:05<07:39, 835.59it/s, epoch=40, loss=[0.34644333941241073, 0.2572

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120141/480000 [02:36<07:45, 773.34it/s, epoch=50, loss=[0.3456436532487466, 0.2571

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144116/480000 [03:06<07:14, 773.00it/s, epoch=60, loss=[0.34505395464599115, 0.257

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168102/480000 [03:37<06:55, 750.06it/s, epoch=70, loss=[0.3445688636600976, 0.2553

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192089/480000 [04:07<06:20, 756.61it/s, epoch=80, loss=[0.34415216612319166, 0.254

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216085/480000 [04:38<05:41, 773.43it/s, epoch=90, loss=[0.3437537362178167, 0.2511

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240134/480000 [05:09<05:45, 694.78it/s, epoch=100, loss=[0.343397381231188, 0.2504

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264096/480000 [05:44<05:13, 688.66it/s, epoch=110, loss=[0.34306408400336924, 0.25

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288121/480000 [06:17<05:17, 603.84it/s, epoch=120, loss=[0.3426495249321066, 0.250

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312105/480000 [06:51<04:04, 685.31it/s, epoch=130, loss=[0.34232116344074365, 0.24

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336122/480000 [07:26<03:24, 702.38it/s, epoch=140, loss=[0.34199631700913136, 0.24

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360149/480000 [08:01<02:35, 773.19it/s, epoch=150, loss=[0.3416855694105228, 0.248

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384137/480000 [08:34<02:10, 732.55it/s, epoch=160, loss=[0.34137428847452, 0.24659

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408122/480000 [09:07<01:32, 773.22it/s, epoch=170, loss=[0.3410591076314452, 0.246

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432126/480000 [09:40<01:01, 778.80it/s, epoch=180, loss=[0.34073520968357734, 0.24

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456111/480000 [10:12<00:32, 738.51it/s, epoch=190, loss=[0.3404021689295771, 0.246

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [10:49<00:00, 739.30it/s, epoch=200, loss=[0.34006077103316745, 0.24


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24081/480000 [00:36<10:35, 716.97it/s, epoch=10, loss=[0.3547314678629237, 0.27464

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48113/480000 [01:10<09:59, 721.00it/s, epoch=20, loss=[0.34902509873112025, 0.2664

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72119/480000 [01:44<09:08, 744.26it/s, epoch=30, loss=[0.34685662460823824, 0.2492

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96081/480000 [02:17<08:40, 737.63it/s, epoch=40, loss=[0.3456349778920406, 0.24733

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120097/480000 [02:50<08:18, 721.66it/s, epoch=50, loss=[0.34466944443682856, 0.232

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144145/480000 [03:25<07:30, 745.62it/s, epoch=60, loss=[0.3440371569246058, 0.2281

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168136/480000 [04:00<07:30, 692.10it/s, epoch=70, loss=[0.34350567889710293, 0.226

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192099/480000 [04:34<06:57, 688.89it/s, epoch=80, loss=[0.34304589527348706, 0.225

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216121/480000 [05:07<06:18, 696.43it/s, epoch=90, loss=[0.3426215178271137, 0.2238

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240125/480000 [05:41<05:24, 738.80it/s, epoch=100, loss=[0.3422091176360848, 0.222

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264117/480000 [06:15<04:41, 767.18it/s, epoch=110, loss=[0.34176856438318914, 0.22

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288115/480000 [06:48<04:25, 721.58it/s, epoch=120, loss=[0.34137105380495414, 0.22

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312155/480000 [07:22<03:35, 778.72it/s, epoch=130, loss=[0.340850746830305, 0.2215

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336114/480000 [07:56<03:56, 609.63it/s, epoch=140, loss=[0.3404648749033605, 0.220

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360069/480000 [08:31<02:59, 669.03it/s, epoch=150, loss=[0.3400794092814128, 0.220

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384084/480000 [09:05<02:02, 784.68it/s, epoch=160, loss=[0.3396755628287793, 0.219

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408101/480000 [09:39<01:39, 724.13it/s, epoch=170, loss=[0.33920790709555104, 0.21

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432097/480000 [10:13<01:02, 766.80it/s, epoch=180, loss=[0.33876452095806636, 0.21

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456108/480000 [10:45<00:30, 778.45it/s, epoch=190, loss=[0.33828947968780926, 0.21

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [11:22<00:00, 703.55it/s, epoch=200, loss=[0.33776932500302764, 0.21


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24127/480000 [00:37<11:16, 673.98it/s, epoch=10, loss=[0.35416261459390314, 0.2805

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48101/480000 [01:10<09:43, 740.41it/s, epoch=20, loss=[0.34859205807248866, 0.2596

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72080/480000 [01:45<09:43, 699.61it/s, epoch=30, loss=[0.34659073628485243, 0.2505

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96115/480000 [02:20<08:59, 711.94it/s, epoch=40, loss=[0.3453984050701065, 0.24484

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120092/480000 [02:54<08:55, 672.65it/s, epoch=50, loss=[0.3445744137962654, 0.2442

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144133/480000 [03:26<07:10, 780.57it/s, epoch=60, loss=[0.34391107757886225, 0.241

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168128/480000 [03:58<06:56, 749.41it/s, epoch=70, loss=[0.34341941354175454, 0.238

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192104/480000 [04:30<05:54, 812.82it/s, epoch=80, loss=[0.3429883492241302, 0.2367

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216072/480000 [05:03<06:16, 700.24it/s, epoch=90, loss=[0.3425185816735031, 0.2387

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240098/480000 [05:37<05:22, 743.05it/s, epoch=100, loss=[0.34203396526475704, 0.23

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264065/480000 [06:11<04:42, 765.22it/s, epoch=110, loss=[0.3416070530066896, 0.232

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288093/480000 [06:44<04:19, 738.50it/s, epoch=120, loss=[0.34115083744128555, 0.23

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312142/480000 [07:17<03:50, 727.83it/s, epoch=130, loss=[0.3407800513009231, 0.230

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336116/480000 [07:49<03:15, 737.24it/s, epoch=140, loss=[0.34043524414300996, 0.23

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360101/480000 [08:23<03:07, 638.46it/s, epoch=150, loss=[0.3400971301148337, 0.230

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384066/480000 [08:57<02:35, 618.32it/s, epoch=160, loss=[0.3397557555387423, 0.230

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408137/480000 [09:31<01:31, 781.44it/s, epoch=170, loss=[0.33940594022472675, 0.23

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432100/480000 [10:04<01:02, 765.34it/s, epoch=180, loss=[0.3390451100468638, 0.228

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456131/480000 [10:38<00:31, 748.27it/s, epoch=190, loss=[0.338668550848961, 0.2252

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [11:15<00:00, 710.49it/s, epoch=200, loss=[0.33827623739838586, 0.22


In [11]:
# Raw per-run test accuracies, then the single largest value across runs,
# evaluation points (before/after training) and layers.
print(test_acc, np.max(test_acc), sep="\n")

[[[array([0.0911, 0.0831])], [array([0.9088, 0.9355])]], [[array([0.1112, 0.0903])], [array([0.9092, 0.9343])]], [[array([0.127 , 0.0365])], [array([0.9095, 0.6759])]], [[array([0.0937, 0.0815])], [array([0.9077, 0.88  ])]], [[array([0.133 , 0.0966])], [array([0.9102, 0.9091])]]]
0.9355


In [12]:
# Raw per-run train accuracies, then the single largest value across runs,
# evaluation points (before/after training) and layers.
print(train_acc, np.max(train_acc), sep="\n")

[[[array([0.08425   , 0.08618333])], [array([0.91215   , 0.94918333])]], [[array([0.11678333, 0.09033333])], [array([0.91266667, 0.94681667])]], [[array([0.12925   , 0.03968333])], [array([0.91336667, 0.68178333])]], [[array([0.0952, 0.0874])], [array([0.91108333, 0.89215   ])]], [[array([0.13366667, 0.09918333])], [array([0.91345   , 0.91698333])]]]
0.9491833333333334


In [13]:
import os

# np.save does not create missing directories, so make sure the output folder exists first.
os.makedirs("./../new_data", exist_ok=True)
np.save("./../new_data/ablation_gaussiandir_mnist_train_acc.npy", np.array(train_acc))
np.save("./../new_data/ablation_gaussiandir_mnist_test_acc.npy", np.array(test_acc))