In [1]:
import sys
sys.path.insert(0, './../../../Models')
from sphere_points import generate_points

import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt

import torch
torch.manual_seed(0)
import torch.nn as nn
from tqdm import tqdm
from torch.optim import SGD
from torch.nn.functional import normalize
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

In [2]:
# Pick the compute device once for the whole notebook: the first CUDA GPU when
# PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
def plot_losses(losses):
    """Plot one loss-vs-epoch curve per layer, side by side.

    Args:
        losses: array-like of shape ``(n_layers, epochs)`` (e.g. the per-layer
            loss lists returned first by ``Net.train``), or a flat ``(epochs,)``
            sequence for a single layer.
    """
    # atleast_2d lets a single layer's flat loss history be plotted instead of
    # failing to unpack a 1-D shape into (n_dims, epochs).
    losses = np.atleast_2d(np.asarray(losses))
    n_dims, epochs = losses.shape
    # squeeze=False keeps `axes` 2-D even when there is only one subplot.
    fig, axes = plt.subplots(1, n_dims, figsize=(12, 5), squeeze=False)
    epoch_axis = 1 + np.arange(epochs)  # epochs are numbered from 1
    for l, ax in enumerate(axes[0]):
        ax.plot(epoch_axis, losses[l])
        ax.set_title(f"Layer {l + 1} Loss")
        ax.set_xlabel("Epochs")
        ax.set_ylabel("Loss")
    fig.tight_layout()
    plt.show()

In [4]:
theta = 1  # scale applied to the row distance inside the log in loss_layer
mse_loss = nn.MSELoss()  # NOTE(review): not referenced by the active code in this notebook


def loss_layer(vec1, vec2):
    """Mean log-distance between the L2-normalized rows of two batches.

    Each row of ``vec1`` and ``vec2`` is projected onto the unit sphere, the
    Euclidean distance between corresponding rows is taken, and the batch loss
    is ``mean(log(1 + theta * distance))``.

    Args:
        vec1: ``(num_data, num_features)`` tensor, e.g. a layer's activations.
        vec2: tensor of the same shape, e.g. each sample's target direction.

    Returns:
        A 0-dim tensor holding the batch-mean loss.
    """
    vec1 = normalize(vec1, p=2, dim=1)
    vec2 = normalize(vec2, p=2, dim=1)
    # vector_norm instead of sqrt(sum(pow(diff, 2))): same forward value, but
    # sqrt's derivative is infinite at 0, so a row that exactly matched its
    # target back-propagated NaN (0 * inf). vector_norm's backward uses 0 there.
    distance = torch.linalg.vector_norm(vec1 - vec2, ord=2, dim=1)
    # log1p(x) == log(1 + x), computed more accurately for small x.
    return torch.mean(torch.log1p(theta * distance), dim=0)

# Module-level state shared with Layer: every Layer constructor overwrites
# `initial` with its generated directions and reads `num_classes`.
initial, num_classes = None, 10

# Inputs and outputs are 2-D tensors of shape (num_data, num_features);
# there is no extra batch axis.
class Layer(nn.Linear):
    """A linear layer + LeakyReLU trained locally toward per-class target vectors.

    The targets (``self.directions``) come from ``generate_points`` at
    construction time and are not updated by this class. Each call to
    ``train`` takes one SGD step on ``loss_layer`` between the layer's
    activations and the target vector of each sample's label.
    """

    # Negative slope of the LeakyReLU applied after the affine map.
    _NEGATIVE_SLOPE = 0.001

    def __init__(self, in_features, out_features, bias, device, lr, apply_dropout=False):
        super().__init__(in_features, out_features, bias, device)
        self.bias_flag = bias
        self.lr = lr  # read on every train() call, so external decay takes effect
        # Taken from the module-level `num_classes` at construction time.
        self.num_classes = num_classes
        self.dimension = out_features
        # He-style uniform init: U(-sqrt(6 / fan_in), +sqrt(6 / fan_in)).
        limit = np.sqrt(6.0 / in_features)
        torch.nn.init.uniform_(self.weight, a=-limit, b=limit)
        self.dropout = nn.Dropout(0.1)
        self.apply_dropout = apply_dropout
        global initial
        # NOTE(review): generate_points is external; presumably it returns one
        # `dimension`-long vector per class - confirm in sphere_points.
        self.directions = generate_points(self.num_classes, self.dimension, steps=10000)
        initial = np.array(self.directions)
        self.directions = [torch.tensor(t, dtype=torch.float32).to(device) for t in self.directions]

    def _direction_matrix(self, like):
        """Stack ``self.directions`` row-wise on the device/dtype of ``like``."""
        return torch.stack(self.directions).to(device=like.device, dtype=like.dtype)

    def _activate(self, x):
        """Affine map (``nn.Linear.forward``) followed by LeakyReLU."""
        return nn.functional.leaky_relu(self.forward(x), negative_slope=self._NEGATIVE_SLOPE)

    def train(self, x, label):
        """Take one SGD step pulling this layer's outputs toward their class targets.

        NOTE: this shadows ``nn.Module.train(mode)``; ``nn.Module.eval()`` calls
        ``self.train(False)`` and therefore does not work on this class.

        Args:
            x: ``(num_data, in_features)`` input batch.
            label: integer class index per row of ``x``; selects each row's target.

        Returns:
            ``(loss, y)``: the batch loss as a Python float, and the activations
            computed before the update (still attached to the autograd graph).
        """
        if self.apply_dropout:
            x = self.dropout(x)
        y = self._activate(x)
        # Vectorized gather of each sample's target (was a per-row Python loop).
        targets = self._direction_matrix(y)[label]
        loss = loss_layer(y, targets)
        # Plain SGD keeps no state, so a fresh optimizer per call is equivalent
        # to reusing one, and it always sees the current self.lr.
        opt = SGD(self.parameters(), lr=self.lr)
        opt.zero_grad()
        # The graph is back-propagated exactly once, so it is not retained.
        loss.backward()
        opt.step()
        return loss.detach().item(), y

    def test(self, x):
        """Predict, per row, the index of the nearest class target.

        Activations and targets are both L2-normalized first, the same geometry
        ``loss_layer`` trains in. The prediction therefore does not depend on the
        activation's scale. When all targets share one norm, this picks the same
        class as raw Euclidean nearest-target. Ties go to the lowest class index.

        Returns:
            ``(preds, y)``: int64 class indices on ``y``'s device, and the
            unnormalized activations (the next layer's input).
        """
        with torch.no_grad():
            y = self._activate(x)
            y_unit = normalize(y, p=2, dim=1)
            targets_unit = normalize(self._direction_matrix(y), p=2, dim=1)
            # (num_data, num_classes) distances via broadcasting.
            dists = torch.linalg.vector_norm(y_unit.unsqueeze(1) - targets_unit.unsqueeze(0), dim=2)
            preds = torch.argmin(dists, dim=1)
        return preds, y

In [5]:
class Net(nn.Module):
    """Stack of ``Layer``s, each trained greedily on its own local loss.

    On every batch the layers are trained in order. Layer ``i`` receives the
    detached output of layer ``i - 1``, so no gradient flows between layers.

    Args:
        dims_list: layer widths; ``[784, 1024, 10]`` builds 784->1024 and 1024->10.
        bias: whether each layer has a bias term.
        epochs: number of passes over the training loader.
        lr: initial learning rate given to every layer.
        device: device each batch is moved to before it reaches the layers.
        lr_decay: amount subtracted from every layer's ``lr`` at each decay step.
            The result is clamped at 0, because ``torch.optim.SGD`` rejects a
            negative learning rate.
        decay_every: decay every this many epochs (0 or None disables decay).

    NOTE: ``train``/``test`` shadow ``nn.Module.train``, so ``net.eval()`` does
    not work on this class.
    """

    def __init__(self, dims_list, bias, epochs, lr, device, lr_decay=0.1, decay_every=10):
        super().__init__()
        self.dims_list = dims_list
        self.bias = bias
        self.epochs = epochs
        self.lr = lr
        self.device = device
        self.lr_decay = lr_decay
        self.decay_every = decay_every
        # NOTE(review): a plain list, not nn.ModuleList, so the layers are not
        # registered submodules (net.parameters() / net.to() do not see them).
        self.layers = []
        num_layers = len(self.dims_list) - 1
        for d in range(num_layers):
            print(f"Initialization {d + 1} / {num_layers}")
            self.layers.append(Layer(self.dims_list[d], self.dims_list[d + 1], self.bias, self.device, self.lr))
            print("Complete\n")

    def _decay_learning_rates(self):
        """Lower every layer's learning rate by ``lr_decay``, never below 0."""
        for layer in self.layers:
            layer.lr = max(layer.lr - self.lr_decay, 0.0)
            print('lr decreased to ', layer.lr)

    def train(self, train_loader, test_loader):
        """Train all layers for ``self.epochs`` epochs, then evaluate once.

        Returns:
            ``[layer_loss_list, acc_train, acc_test, layer_w]``:
            - ``layer_loss_list``: per layer, the mean batch loss of each epoch.
            - ``acc_train``: a one-element list with the ``test()`` result on
              ``train_loader``.
            - ``acc_test``: the same for ``test_loader``.
            - ``layer_w``: per layer, a one-element list with the final L2 norm
              of its weight.
        """
        num_layers = len(self.layers)
        layer_loss_list = [[] for _ in range(num_layers)]
        acc_train = []
        acc_test = []
        layer_w = [[] for _ in range(len(self.dims_list) - 1)]

        with tqdm(total=self.epochs * len(train_loader) * num_layers,
                  desc="Training", position=0, leave=True) as pbar:
            for epoch in range(self.epochs):
                if self.decay_every and epoch and epoch % self.decay_every == 0:
                    self._decay_learning_rates()

                loss_agg = [0] * num_layers
                for dat in train_loader:
                    x, label = dat
                    x = x.to(self.device)
                    label = label.to(self.device)
                    for i, layer in enumerate(self.layers):
                        loss, y = layer.train(x, label)
                        layer.zero_grad(set_to_none=True)
                        # Detach: the next layer trains on this output without
                        # back-propagating into this layer.
                        x = y.detach()
                        loss_agg[i] += loss / len(train_loader)
                        del y
                        pbar.update(1)
                pbar.set_postfix(epoch=epoch + 1, loss=loss_agg)
                for i in range(num_layers):
                    layer_loss_list[i].append(loss_agg[i])

            # Final evaluation on this network and the loaders passed in (this
            # previously read the notebook globals `net`, `trainloader`, `testloader`).
            with torch.no_grad():
                for i, layer in enumerate(self.layers):
                    layer_w[i].append(torch.norm(layer.weight, p=2).item())
                acc_train.append(self.test(train_loader))
                acc_test.append(self.test(test_loader))

        return [layer_loss_list, acc_train, acc_test, layer_w]

    def test(self, data_loader):
        """Per-layer classification accuracy over ``data_loader``.

        Each layer's prediction is compared with the label. The input is
        propagated through the layers using each layer's activations.

        Returns:
            A one-element list holding a numpy array with one accuracy per layer.
        """
        num_layers = len(self.layers)
        correct = [0] * num_layers
        total = [0] * num_layers
        for dat in data_loader:
            x = dat[0].to(self.device)
            label = dat[1].to(self.device)
            preds = []
            for layer in self.layers:
                pred, x = layer.test(x)
                preds.append(pred)

            for i, pred in enumerate(preds):
                correct[i] += (pred == label).sum().item()
                total[i] += label.shape[0]

        return [np.array(correct) / total[-1]]

In [6]:
# ToTensor() turns each MNIST image into a (C, H, W) float tensor scaled to
# [0, 1]. The Lambda flattens each channel and squeezes away the singleton
# channel axis, giving a flat per-image vector (784 values for MNIST).
flatten_transform = transforms.Lambda(lambda img: img.view(img.size(0), -1).squeeze())
transform = transforms.Compose([transforms.ToTensor(), flatten_transform])

data_root = './../../Data'
trainset = torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform=transform)

# Training batches are reshuffled every epoch; test order is fixed.
batch_size = 50
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

In [7]:
# Experiment configuration, identical for every run.
ratio = [0.2]  # NOTE(review): only referenced by the disabled random-split variant
num_runs = 5
dims_list = [784, 1024, 10]
bias = True
epochs = 200
lr = 2.5
num_classes = 10  # read by every Layer constructor

# One entry per run.
test_acc = []
train_acc = []
w_norm = []

# Each run trains a freshly initialized network. In the recorded output a run
# takes roughly 50-75 minutes on the GPU.
for _ in range(num_runs):
    net = Net(dims_list, bias, epochs, lr, device)
    # Net.train returns [per-layer losses, train accuracy, test accuracy, weight norms].
    layer_loss_list = net.train(trainloader, testloader)
    train_acc.append(layer_loss_list[1])
    test_acc.append(layer_loss_list[2])
    w_norm.append(layer_loss_list[3])

Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24012/480000 [03:27<1:06:31, 114.24it/s, epoch=10, loss=[0.6553428988158702, 0.592

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48017/480000 [06:55<1:02:33, 115.08it/s, epoch=20, loss=[0.6493785619735714, 0.585

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72011/480000 [10:23<1:00:43, 111.96it/s, epoch=30, loss=[0.6472577501336736, 0.574

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96018/480000 [13:51<57:38, 111.01it/s, epoch=40, loss=[0.6460361969967682, 0.57173

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120022/480000 [17:20<52:24, 114.48it/s, epoch=50, loss=[0.6452021724482383, 0.5640

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144019/480000 [20:48<48:44, 114.88it/s, epoch=60, loss=[0.6442776533961299, 0.5584

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168021/480000 [24:16<45:27, 114.39it/s, epoch=70, loss=[0.6437238265573979, 0.5615

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192016/480000 [27:45<41:51, 114.66it/s, epoch=80, loss=[0.6434063560764004, 0.5543

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216014/480000 [31:12<39:27, 111.51it/s, epoch=90, loss=[0.6431732052564635, 0.5589

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240019/480000 [34:42<35:03, 114.10it/s, epoch=100, loss=[0.642945532798767, 0.5504

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264012/480000 [38:10<31:41, 113.59it/s, epoch=110, loss=[0.6427400544285776, 0.552

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288016/480000 [41:38<28:22, 112.75it/s, epoch=120, loss=[0.6425910020371273, 0.550

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312018/480000 [45:18<26:08, 107.10it/s, epoch=130, loss=[0.6424371635417144, 0.560

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336021/480000 [49:04<22:44, 105.54it/s, epoch=140, loss=[0.6423011656602224, 0.553

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360016/480000 [52:35<17:30, 114.17it/s, epoch=150, loss=[0.6421739368637396, 0.546

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384011/480000 [56:03<13:39, 117.07it/s, epoch=160, loss=[0.6420605531334873, 0.548

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408013/480000 [59:31<10:38, 112.76it/s, epoch=170, loss=[0.6419316655894117, 0.545

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432012/480000 [1:02:59<07:01, 113.81it/s, epoch=180, loss=[0.6418246992429103, 0.5

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456013/480000 [1:06:26<03:24, 117.15it/s, epoch=190, loss=[0.6417217097679774, 0.5

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [1:13:32<00:00, 108.78it/s, epoch=200, loss=[0.6416391539076965, 0.5


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%|████▌                                                                                     | 24019/480000 [03:27<1:06:25, 114.42it/s, epoch=10, loss=[0.6529848640660443, 0.5909688905874892]]

lr decreased to  2.4
lr decreased to  2.4


Training:  10%|█████████                                                                                  | 48018/480000 [06:55<1:03:34, 113.24it/s, epoch=20, loss=[0.648261278669039, 0.5873679953316843]]

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|█████████████▋                                                                             | 72017/480000 [10:22<1:00:30, 112.38it/s, epoch=30, loss=[0.646106342871983, 0.5742216816047824]]

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|██████████████████▌                                                                          | 96021/480000 [13:50<54:27, 117.50it/s, epoch=40, loss=[0.6447878367702171, 0.600509520967801]]

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|██████████████████████▊                                                                    | 120016/480000 [17:18<51:49, 115.78it/s, epoch=50, loss=[0.6441093440353869, 0.5928481996059408]]

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|███████████████████████████▎                                                               | 144019/480000 [20:47<48:47, 114.76it/s, epoch=60, loss=[0.6435852554937207, 0.5764533331741899]]

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|███████████████████████████████▊                                                           | 168014/480000 [24:15<43:27, 119.63it/s, epoch=70, loss=[0.6432150666415701, 0.5717564151684436]]

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|████████████████████████████████████▍                                                      | 192015/480000 [27:43<42:00, 114.27it/s, epoch=80, loss=[0.6429174472888308, 0.5715342854460087]]

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|████████████████████████████████████████▉                                                  | 216014/480000 [31:11<39:08, 112.42it/s, epoch=90, loss=[0.6426726848880437, 0.5709149112304054]]

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|█████████████████████████████████████████████                                             | 240020/480000 [34:39<34:59, 114.32it/s, epoch=100, loss=[0.6424809518953158, 0.5706737015644706]]

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|█████████████████████████████████████████████████▌                                        | 264022/480000 [38:07<31:55, 112.74it/s, epoch=110, loss=[0.6422978472709654, 0.5704346669216946]]

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|██████████████████████████████████████████████████████                                    | 288021/480000 [41:35<27:35, 115.95it/s, epoch=120, loss=[0.6421419585744542, 0.5702980683743948]]

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|██████████████████████████████████████████████████████████▌                               | 312015/480000 [45:13<26:31, 105.55it/s, epoch=130, loss=[0.6420040280123557, 0.5725580142438411]]

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|███████████████████████████████████████████████████████████████                           | 336013/480000 [48:59<23:01, 104.25it/s, epoch=140, loss=[0.6418614455064137, 0.5700460476179922]]

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|███████████████████████████████████████████████████████████████████▌                      | 360013/480000 [52:31<17:42, 112.92it/s, epoch=150, loss=[0.6417628277341528, 0.5698550741374496]]

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|████████████████████████████████████████████████████████████████████████                  | 384019/480000 [55:59<13:41, 116.80it/s, epoch=160, loss=[0.6415979833900924, 0.5697609798113492]]

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|████████████████████████████████████████████████████████████████████████████▌             | 408018/480000 [59:27<10:31, 114.02it/s, epoch=170, loss=[0.6414624707400799, 0.5699656383693223]]

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|███████████████████████████████████████████████████████████████████████████████▏        | 432015/480000 [1:02:55<06:45, 118.39it/s, epoch=180, loss=[0.6413605441153054, 0.5695695076882834]]

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|███████████████████████████████████████████████████████████████████████████████████▌    | 456013/480000 [1:06:23<03:29, 114.45it/s, epoch=190, loss=[0.6412836704154808, 0.5694021052618815]]

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 480000/480000 [1:13:29<00:00, 108.86it/s, epoch=200, loss=[0.6411921808123584, 0.5693044828871896]]


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%|████▌                                                                                     | 24013/480000 [03:28<1:05:17, 116.39it/s, epoch=10, loss=[0.6540159984926387, 0.6615129662056753]]

lr decreased to  2.4
lr decreased to  2.4


Training:  10%|█████████                                                                                 | 48011/480000 [06:57<1:03:54, 112.65it/s, epoch=20, loss=[0.6488550058007254, 0.6219376094639294]]

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|█████████████▊                                                                              | 72010/480000 [10:24<59:16, 114.72it/s, epoch=30, loss=[0.6464506544172762, 0.6080295737087728]]

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|██████████████████▍                                                                         | 96011/480000 [13:53<55:43, 114.83it/s, epoch=40, loss=[0.6451079874734083, 0.6060807757079589]]

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|██████████████████████▊                                                                    | 120020/480000 [17:21<52:31, 114.24it/s, epoch=50, loss=[0.6443844569226113, 0.6050933768848588]]

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|███████████████████████████▎                                                               | 144019/480000 [20:49<48:43, 114.91it/s, epoch=60, loss=[0.6438825979332135, 0.6044294511278482]]

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|███████████████████████████████▊                                                           | 168018/480000 [24:17<46:23, 112.10it/s, epoch=70, loss=[0.6435328510403641, 0.6041139776508015]]

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|████████████████████████████████████▊                                                       | 192018/480000 [27:46<41:13, 116.43it/s, epoch=80, loss=[0.6432322623829053, 0.603733605146409]]

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|████████████████████████████████████████▉                                                  | 216021/480000 [31:14<38:40, 113.76it/s, epoch=90, loss=[0.6430012953281402, 0.6034432170788456]]

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|█████████████████████████████████████████████                                             | 240013/480000 [34:42<35:46, 111.81it/s, epoch=100, loss=[0.6427562170227368, 0.6031943100194144]]

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|█████████████████████████████████████████████████▌                                        | 264018/480000 [38:10<31:06, 115.74it/s, epoch=110, loss=[0.6425508972009022, 0.6029978342851005]]

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|██████████████████████████████████████████████████████                                    | 288022/480000 [41:38<27:35, 115.96it/s, epoch=120, loss=[0.6424081067740919, 0.6027371471623577]]

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|██████████████████████████████████████████████████████████▌                               | 312019/480000 [45:17<26:30, 105.62it/s, epoch=130, loss=[0.6422431970636039, 0.6055779828131194]]

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|███████████████████████████████████████████████████████████████                           | 336016/480000 [49:05<22:54, 104.72it/s, epoch=140, loss=[0.6421197689572973, 0.6182637095948075]]

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|███████████████████████████████████████████████████████████████████▌                      | 360015/480000 [52:37<17:30, 114.18it/s, epoch=150, loss=[0.6420056859155495, 0.6050882396101962]]

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|████████████████████████████████████████████████████████████████████████                  | 384018/480000 [56:04<13:58, 114.42it/s, epoch=160, loss=[0.6419038177529972, 0.6024557732542347]]

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|█████████████████████████████████████████████████████████████████████████████▎             | 408019/480000 [59:32<10:21, 115.75it/s, epoch=170, loss=[0.6417930161456271, 0.602265350619952]]

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|███████████████████████████████████████████████████████████████████████████████▏        | 432017/480000 [1:03:00<06:58, 114.62it/s, epoch=180, loss=[0.6417097158730027, 0.6028017697731656]]

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|█████████████████████████████████████████████████████████████████████████████████████▌    | 456014/480000 [1:06:28<03:34, 111.76it/s, epoch=190, loss=[0.64161197518309, 0.6020113998651494]]

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████| 480000/480000 [1:13:34<00:00, 108.74it/s, epoch=200, loss=[0.6415318080782894, 0.6018385614454757]]


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%|████▌                                                                                     | 24017/480000 [03:27<1:09:59, 108.59it/s, epoch=10, loss=[0.6521790900826465, 0.6091703770558042]]

lr decreased to  2.4
lr decreased to  2.4


Training:  10%|█████████                                                                                 | 48019/480000 [06:54<1:02:58, 114.32it/s, epoch=20, loss=[0.6468723669648154, 0.6127167159815621]]

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|█████████████▉                                                                               | 72016/480000 [10:22<57:55, 117.39it/s, epoch=30, loss=[0.6445449229578192, 0.606449668953816]]

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|██████████████████▍                                                                         | 96014/480000 [13:51<55:19, 115.67it/s, epoch=40, loss=[0.6434807773927846, 0.6015180480480199]]

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|██████████████████████▊                                                                    | 120016/480000 [17:19<52:38, 113.98it/s, epoch=50, loss=[0.6427953597903253, 0.6007721364001436]]

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|███████████████████████████▎                                                               | 144022/480000 [20:48<47:44, 117.31it/s, epoch=60, loss=[0.6420521697402007, 0.6016592567165684]]

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|███████████████████████████████▊                                                           | 168017/480000 [24:15<47:30, 109.44it/s, epoch=70, loss=[0.6416647867858408, 0.5999247074623898]]

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|████████████████████████████████████▍                                                      | 192017/480000 [27:43<40:17, 119.11it/s, epoch=80, loss=[0.6413789925972617, 0.5998035029570263]]

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|████████████████████████████████████████▉                                                  | 216019/480000 [31:12<37:13, 118.18it/s, epoch=90, loss=[0.6411558256546649, 0.6002838052312539]]

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|█████████████████████████████████████████████▌                                             | 240022/480000 [34:40<34:08, 117.14it/s, epoch=100, loss=[0.640938648084799, 0.6040968875090283]]

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|██████████████████████████████████████████████████                                         | 264023/480000 [38:09<31:15, 115.15it/s, epoch=110, loss=[0.640768302132686, 0.6032051773369314]]

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|██████████████████████████████████████████████████████▌                                    | 288015/480000 [41:36<27:16, 117.31it/s, epoch=120, loss=[0.6406151480972769, 0.607504500746726]]

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|██████████████████████████████████████████████████████████▌                               | 312021/480000 [45:14<26:46, 104.59it/s, epoch=130, loss=[0.6404723316927744, 0.5996934116383383]]

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|███████████████████████████████████████████████████████████████                           | 336073/480000 [48:22<06:13, 385.32it/s, epoch=140, loss=[0.6403457280496753, 0.5987796793381372]]

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|███████████████████████████████████████████████████████████████████▌                      | 360081/480000 [49:00<03:08, 635.64it/s, epoch=150, loss=[0.6402230962117506, 0.6003740801910551]]

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|████████████████████████████████████████████████████████████████████████                  | 384091/480000 [49:37<02:30, 639.14it/s, epoch=160, loss=[0.6401287677884099, 0.5988090102374553]]

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|████████████████████████████████████████████████████████████████████████████▌             | 408084/480000 [50:14<01:56, 617.33it/s, epoch=170, loss=[0.6400138595203549, 0.5995178354283169]]

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|█████████████████████████████████████████████████████████████████████████████████         | 432109/480000 [50:52<01:15, 632.49it/s, epoch=180, loss=[0.6399207880099611, 0.5991475158929837]]

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|██████████████████████████████████████████████████████████████████████████████████████▍    | 456069/480000 [51:30<00:37, 638.12it/s, epoch=190, loss=[0.6398334068059925, 0.613566650748253]]

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 480000/480000 [52:35<00:00, 152.13it/s, epoch=200, loss=[0.6397511166334148, 0.6014960409700872]]


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%|████▌                                                                                       | 24117/480000 [00:37<11:54, 638.29it/s, epoch=10, loss=[0.6548744416236872, 0.5852151828010876]]

lr decreased to  2.4
lr decreased to  2.4


Training:  10%|█████████▏                                                                                  | 48078/480000 [01:15<11:15, 639.24it/s, epoch=20, loss=[0.6495182452102509, 0.5848494683206077]]

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|█████████████▊                                                                              | 72095/480000 [01:52<10:44, 632.84it/s, epoch=30, loss=[0.6472965709865095, 0.5824507440129909]]

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|██████████████████▍                                                                         | 96126/480000 [02:30<09:55, 644.31it/s, epoch=40, loss=[0.6460049837330976, 0.5812475899358599]]

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|██████████████████████▊                                                                    | 120101/480000 [03:07<09:34, 626.69it/s, epoch=50, loss=[0.6450901635984574, 0.5782382149497661]]

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|███████████████████████████▎                                                               | 144105/480000 [03:46<09:15, 604.38it/s, epoch=60, loss=[0.6446051827569809, 0.5774072311321905]]

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|████████████████████████████████▏                                                           | 168067/480000 [04:24<08:20, 622.86it/s, epoch=70, loss=[0.6442340926826003, 0.580077866713207]]

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|████████████████████████████████████▍                                                      | 192074/480000 [05:02<07:35, 632.13it/s, epoch=80, loss=[0.6439343414207298, 0.5848560388882946]]

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|█████████████████████████████████████████▍                                                  | 216064/480000 [05:39<06:55, 635.12it/s, epoch=90, loss=[0.643690565327804, 0.5832836380600939]]

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|█████████████████████████████████████████████                                             | 240069/480000 [06:17<06:14, 640.11it/s, epoch=100, loss=[0.6435094075401617, 0.5906902694205439]]

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|█████████████████████████████████████████████████▌                                        | 264115/480000 [06:55<05:34, 644.88it/s, epoch=110, loss=[0.6433229840298493, 0.5905429908633232]]

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|██████████████████████████████████████████████████████▌                                    | 288123/480000 [07:32<05:01, 637.28it/s, epoch=120, loss=[0.6431795564293861, 0.575420707960925]]

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|██████████████████████████████████████████████████████████▌                               | 312107/480000 [08:10<04:23, 638.02it/s, epoch=130, loss=[0.6430354801813751, 0.5956122536957259]]

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|███████████████████████████████████████████████████████████████▋                           | 336109/480000 [08:47<03:42, 645.99it/s, epoch=140, loss=[0.6429190927247195, 0.575437786678473]]

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|███████████████████████████████████████████████████████████████████▌                      | 360126/480000 [09:25<03:04, 649.54it/s, epoch=150, loss=[0.6428076017896331, 0.5750657939414185]]

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|████████████████████████████████████████████████████████████████████████▊                  | 384124/480000 [10:02<02:28, 646.02it/s, epoch=160, loss=[0.642701413929463, 0.5746634200215337]]

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|████████████████████████████████████████████████████████████████████████████▌             | 408100/480000 [10:39<01:52, 639.98it/s, epoch=170, loss=[0.6426005931198603, 0.5743257495760913]]

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|█████████████████████████████████████████████████████████████████████████████████▉         | 432120/480000 [11:17<01:17, 620.25it/s, epoch=180, loss=[0.6425016787151491, 0.574206484307845]]

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|██████████████████████████████████████████████████████████████████████████████████████▍    | 456117/480000 [11:54<00:38, 622.13it/s, epoch=190, loss=[0.6424186098575598, 0.574779720455408]]

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 480000/480000 [12:59<00:00, 615.72it/s, epoch=200, loss=[0.6423404491941133, 0.5741494803130628]]


In [8]:
# Dump the raw accuracy containers collected during training: train first, then test.
# Each entry wraps a 2-element accuracy array (presumably one value per layer — confirm).
print(*(train_acc, test_acc), sep=" ", end="\n")

[[[array([0.90558333, 0.83845   ])]], [[array([0.90391667, 0.92308333])]], [[array([0.905     , 0.90326667])]], [[array([0.9049    , 0.93716667])]], [[array([0.90415   , 0.95371667])]]] [[[array([0.9035, 0.8341])]], [[array([0.9007, 0.9078])]], [[array([0.9035, 0.8937])]], [[array([0.9025, 0.9253])]], [[array([0.9019, 0.9394])]]]


In [9]:
# Persist the ablation results (normalised-Euclidean layer loss on MNIST) to .npy files.
# The output directory is created first: without it np.save raises FileNotFoundError
# and the results of the long training run above would be lost at the final step.
import os

_results_dir = "./../new_data"
os.makedirs(_results_dir, exist_ok=True)

# (file name, value) pairs; each value is converted to an ndarray before saving,
# exactly as the individual np.save calls did.
for _fname, _value in (
    ("ablation_normeuc_mnist_train_acc.npy", train_acc),
    ("ablation_normeuc_mnist_test_acc.npy", test_acc),
    ("ablation_normeuc_w_mnist_end.npy", w_norm),
):
    np.save(os.path.join(_results_dir, _fname), np.array(_value))