In [1]:
import sys
sys.path.insert(0, './../../../Models')
from sphere_points import generate_points

import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt

import torch
torch.manual_seed(0)
import torch.nn as nn
from tqdm import tqdm
from torch.optim import SGD
from torch.nn.functional import normalize, one_hot
# import torch.nn.functional as F
import torchvision
from torchvision import transforms
# from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

In [2]:
# Compute device for the whole notebook: CUDA when it is available, CPU
# otherwise (no MPS fallback is attempted).
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
print(device)

cuda


In [3]:
def plot_losses(losses):
    """Draw one loss-versus-epoch panel per layer, side by side.

    Args:
        losses: rectangular 2-D array-like indexed as ``[layer][epoch]``; it
            is converted with ``np.array`` before plotting.
    """
    curves = np.array(losses)
    num_layers, num_epochs = curves.shape
    epoch_axis = np.arange(1, num_epochs + 1)  # epochs are numbered from 1

    plt.figure(figsize=(12, 5))
    for panel, curve in enumerate(curves, start=1):
        plt.subplot(1, num_layers, panel)
        plt.plot(epoch_axis, curve)
        plt.title(f"Layer {panel} Loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
    plt.tight_layout()
    plt.show()

In [4]:
# model_loss = nn.CrossEntropyLoss()
# Number of target classes. This is read as a module-level global (e.g. by
# `Layer.__init__`), so it must be defined before any layer is constructed.
num_classes = 10

In [5]:
def classifier_head_train(inp_embedding, classifier_weights, labels, theta=1.0):
    """Alignment loss between each embedding and its true-class direction.

    Every row of ``inp_embedding`` is L2-normalised and projected onto the
    columns of ``classifier_weights``. Only the projection onto the column of
    the sample's own class is kept, and the loss is
    ``mean(log(2 - theta * projection))``. When the columns of
    ``classifier_weights`` are unit vectors the projection is a cosine
    similarity, so with ``theta == 1`` the log argument lies in ``[1, 3]``.

    The number of classes is taken from ``classifier_weights`` itself, so the
    function does not depend on any notebook-level ``num_classes`` variable.

    Args:
        inp_embedding: 2-D float tensor, one embedding per row.
        classifier_weights: 2-D float tensor with one column per class.
        labels: 1-D int64 tensor holding the class index of every row.
        theta: multiplier applied to the true-class projection. Values that
            let ``theta * projection`` reach 2 make the logarithm undefined.

    Returns:
        Scalar tensor with the mean loss over the batch (differentiable with
        respect to ``inp_embedding``).
    """
    inp_embedding = normalize(inp_embedding, p=2, dim=-1)
    class_scores = torch.mm(inp_embedding, classifier_weights)
    # Select each sample's true-class score directly. This yields the same
    # value as masking with a one-hot matrix and summing over classes, without
    # materialising the mask or needing the class count up front.
    target_scores = class_scores.gather(1, labels.unsqueeze(1)).squeeze(1)
    return torch.mean(torch.log(2 - theta * target_scores))

In [6]:
def classifier_head(inp_embedding, classifier_weights, labels):
    """Predict, for every embedding, the class whose direction scores highest.

    Rows of ``inp_embedding`` are L2-normalised and multiplied by
    ``classifier_weights`` (one column per class); the index of the largest
    score in each row is the prediction.

    Args:
        inp_embedding: 2-D float tensor, one embedding per row.
        classifier_weights: 2-D float tensor with one column per class.
        labels: not used by the computation; accepted so the signature matches
            ``classifier_head_train``.

    Returns:
        A Python list with one predicted class index per row.
    """
    unit_rows = normalize(inp_embedding, p=2, dim=-1)
    scores = torch.mm(unit_rows, classifier_weights)
    predicted = torch.argmax(scores, dim=1)
    return predicted.tolist()

In [7]:
# Module-level snapshot (NumPy array) of the direction set generated by the
# most recently constructed `Layer`; every `Layer.__init__` overwrites it.
initial = None
# num_classes = 10

# Data dimension
# (num_data, num_features) => no dimension for batch size please
class Layer(nn.Linear):
    """Linear layer trained on its own against fixed per-class directions.

    The layer computes ``LeakyReLU(x @ W.T + b)``. At construction it obtains
    one direction per class from ``generate_points`` and stores the
    L2-normalised directions as the columns of ``direction_weights``. Training
    minimises ``classifier_head_train`` between the layer output and these
    fixed directions with plain SGD; prediction picks the best-scoring
    direction via ``classifier_head``.

    NOTE: ``train(x, labels)`` replaces ``nn.Module.train(mode)``, so the usual
    ``.train()`` / ``.eval()`` mode switches cannot be used on this class.

    Args:
        in_features: size of each input row.
        out_features: size of each output row (also the direction dimension).
        bias: whether the linear map has a bias term.
        device: device for the parameters and the class directions.
        lr: SGD learning rate; callers may change ``self.lr`` between steps.
        apply_dropout: when True, ``train`` applies dropout (p=0.1) to its
            input. ``test`` never applies dropout.
    """

    def __init__(self, in_features, out_features, bias, device, lr, apply_dropout=False):
        super().__init__(in_features, out_features, bias, device)
        self.out_features = out_features
        self.bias_flag = bias
        self.lr = lr
        # Read from the notebook-level global of the same name.
        self.num_classes = num_classes
        self.dimension = out_features

        # Re-initialise the weights uniformly in [-sqrt(6 / fan_in), +sqrt(6 / fan_in)].
        fc1_limit = np.sqrt(6.0 / in_features)
        torch.nn.init.uniform_(self.weight, a=-fc1_limit, b=fc1_limit)

        self.dropout = nn.Dropout(0.1)
        self.apply_dropout = apply_dropout
        # Shared by train() and test() so both use the same non-linearity.
        self.activation = nn.LeakyReLU(negative_slope=0.001)
        # Created lazily on the first training step and reused afterwards.
        self._sgd = None

        global initial
        self.directions = generate_points(self.num_classes, self.dimension, steps = 10000)
        initial = np.array(self.directions)
        self.directions = [torch.tensor(t, dtype = torch.float32).to(device) for t in self.directions]
        # (dimension, num_classes): column i is the unit vector of class i.
        # Built from plain tensors, so it carries no gradient.
        self.direction_weights = torch.stack(
            [normalize(d, p = 2, dim=-1) for d in self.directions], dim=1)

    def _activate(self, x):
        """Return ``LeakyReLU(x @ W.T [+ b])`` for a 2-D input ``x``."""
        pre_activation = torch.mm(x, self.weight.T)
        if self.bias_flag:
            pre_activation = pre_activation + self.bias.unsqueeze(0)
        return self.activation(pre_activation)

    def _optimizer(self):
        """Return the cached SGD optimizer with its lr synced to ``self.lr``."""
        if self.lr < 0:
            # Same check SGD performs at construction time; kept because the
            # optimizer is no longer rebuilt on every step.
            raise ValueError(f"Invalid learning rate: {self.lr}")
        if self._sgd is None:
            self._sgd = SGD(self.parameters(), lr=self.lr)
        for group in self._sgd.param_groups:
            group["lr"] = self.lr
        return self._sgd

    def train(self, x, labels):
        """Run one SGD step on this layer only.

        Args:
            x: 2-D input tensor, one sample per row.
            labels: class indices, one per row of ``x``.

        Returns:
            ``(loss, y)``: the loss as a Python float and the layer output
            computed before the update. ``y`` still belongs to the (already
            back-propagated) graph, so callers should ``detach()`` it before
            feeding it to the next layer.
        """
        if self.apply_dropout:
            x = self.dropout(x)
        y = self._activate(x)
        loss = classifier_head_train(y, self.direction_weights, labels)

        opt = self._optimizer()
        opt.zero_grad(set_to_none=True)
        loss.backward(retain_graph = False)
        opt.step()

        return loss.detach().item(), y

    def test(self, x, labels):
        """Predict classes for ``x`` without tracking gradients.

        Args:
            x: 2-D input tensor, one sample per row.
            labels: forwarded to ``classifier_head``.

        Returns:
            ``(predictions, y)``: a tensor of predicted class indices placed on
            the same device as the layer output, and the layer output ``y``.
        """
        with torch.no_grad():
            y = self._activate(x)
            max_idx_list = classifier_head(y, self.direction_weights, labels)
        # Use the device of the data itself instead of a notebook-level global.
        return torch.tensor(max_idx_list, device=y.device), y

In [8]:
class Net(nn.Module):
    """Stack of `Layer` objects that are trained one by one, layer-locally.

    Each mini-batch is pushed through the layers in order; every layer takes
    its own SGD step and hands a detached copy of its output to the next one,
    so no gradient crosses layer boundaries.

    NOTE: the layers live in a plain Python list, so they are not registered
    as sub-modules, and ``train(train_loader, test_loader)`` replaces
    ``nn.Module.train(mode)``.

    Args:
        dims_list: layer widths ``[in, hidden..., out]``; one `Layer` is built
            per consecutive pair.
        bias: whether the layers use a bias term.
        epochs: number of passes over the training loader.
        lr: initial learning rate given to every layer.
        device: device the layers are created on and batches are moved to.
    """

    # Linear learning-rate schedule: every LR_DECAY_EVERY epochs each layer's
    # lr is reduced by LR_DECAY_STEP. The schedule has no lower bound, so
    # `epochs` and `lr` must be chosen such that the lr stays positive.
    LR_DECAY_EVERY = 10
    LR_DECAY_STEP = 0.1

    def __init__(self, dims_list, bias, epochs, lr, device):
        super(Net, self).__init__()
        self.dims_list = dims_list
        self.bias = bias
        self.epochs = epochs
        self.lr = lr
        self.device = device
        self.layers = []
        num_layers = len(self.dims_list) - 1
        for d in range(num_layers):
            print(f"Initialization {d + 1} / {num_layers}")
            self.layers.append(
                Layer(self.dims_list[d], self.dims_list[d + 1], self.bias, self.device, self.lr))
            print("Complete\n")

    def _record_snapshot(self, train_loader, test_loader, layer_w, acc_train, acc_test):
        """Append current weight norms and train/test accuracies to the histories."""
        with torch.no_grad():
            for i, layer in enumerate(self.layers):
                layer_w[i].append(torch.norm(layer.weight, p=2).item())
            acc_train.append(self.test(train_loader))
            acc_test.append(self.test(test_loader))

    def train(self, train_loader, test_loader):
        """Train all layers and record metrics before and after training.

        Args:
            train_loader: iterable of ``(inputs, labels)`` batches with a length.
            test_loader: iterable of batches used only for evaluation.

        Returns:
            ``[layer_loss_list, acc_train, acc_test, layer_w]`` where
            ``layer_loss_list[i]`` holds layer ``i``'s mean loss per epoch,
            ``acc_train`` / ``acc_test`` hold the result of ``test`` before and
            after training, and ``layer_w[i]`` holds layer ``i``'s weight norm
            before and after training.
        """
        num_layers = len(self.layers)
        layer_loss_list = [[] for _ in range(num_layers)]
        acc_train = []
        acc_test = []
        layer_w = [[] for _ in range(num_layers)]

        pbar = tqdm(total = self.epochs * len(train_loader) * num_layers,
                    desc = f"Training", position = 0, leave = True)
        try:
            # Baseline metrics of the untrained network.
            self._record_snapshot(train_loader, test_loader, layer_w, acc_train, acc_test)

            for epoch in range(self.epochs):

                if epoch and not (epoch % self.LR_DECAY_EVERY):
                    # learning rate decay
                    for layer in self.layers:
                        layer.lr = layer.lr - self.LR_DECAY_STEP
                        print('lr decreased to ', layer.lr)

                loss_agg = [0] * num_layers
                for dat in train_loader:
                    x, label = dat
                    x = x.to(self.device)
                    label = label.to(self.device)
                    for i, layer in enumerate(self.layers):
                        loss, y = layer.train(x, label)
                        layer.zero_grad(set_to_none=True)
                        # Cut the graph: the next layer learns from fixed inputs.
                        x = y.detach()
                        loss_agg[i] += loss / len(train_loader)
                        pbar.update(1)
                pbar.set_postfix(epoch = epoch + 1, loss = loss_agg)
                for i in range(num_layers):
                    layer_loss_list[i].append(loss_agg[i])

            # Metrics of the trained network.
            self._record_snapshot(train_loader, test_loader, layer_w, acc_train, acc_test)
        finally:
            # Release the progress bar even if training is interrupted.
            pbar.close()
        return [layer_loss_list, acc_train, acc_test, layer_w]

    def test(self, data_loader):
        """Compute the accuracy of every layer's prediction on ``data_loader``.

        Args:
            data_loader: iterable of batches whose first two items are the
                inputs and the integer labels.

        Returns:
            A one-element list containing a NumPy array with one accuracy per
            layer (layer ``i`` predicts from its own output).
        """
        num_layers = len(self.layers)
        correct = [0] * num_layers
        total = [0] * num_layers
        for dat in data_loader:
            x = dat[0].to(self.device)
            label = dat[1].to(self.device)

            for i, layer in enumerate(self.layers):
                # Each layer consumes the previous layer's output.
                pred, x = layer.test(x, label)
                correct[i] += (pred == label).sum().item()
                total[i] += label.shape[0]

        return [np.array(correct) / np.array(total)]

In [9]:
# Per-sample preprocessing: ToTensor (values in [0, 1]) followed by a Lambda
# that keeps the first axis, flattens the remaining ones and then squeezes
# singleton axes. No mean/std Normalize step is applied.
flatten_transform = transforms.Lambda(lambda img: img.view(img.size(0), -1).squeeze())
transform = transforms.Compose([transforms.ToTensor(), flatten_transform])

# MNIST train and test splits (downloaded into the shared Data folder if missing).
batch_size = 50
trainset = torchvision.datasets.MNIST(root='./../../../Data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./../../../Data', train=False, download=True, transform=transform)

# Both loaders walk their dataset in order (no shuffling).
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=False)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [10]:
# Experiment configuration; identical for every run.
ratio = [0.2]  # only referenced by the disabled train/test re-split code
num_runs = 5
dims_list = [784, 1024, 10]
bias = True
epochs = 200
lr  = 2.5
num_classes = 10

# Result histories, one entry appended per run.
loss = []
train_acc = []
test_acc = []
w_layers = []

for _ in range(num_runs):
    # A fresh network per run, trained on the fixed MNIST train/test loaders.
    net = Net(dims_list, bias, epochs, lr, device)

    # Net.train returns [per-layer losses, train accuracies, test accuracies,
    # weight norms]; accuracies and norms are recorded before and after training.
    layer_loss_list = net.train(trainloader, testloader)

    loss.append(layer_loss_list[0])
    train_acc.append(layer_loss_list[1])
    test_acc.append(layer_loss_list[2])
    w_layers.append(layer_loss_list[3])

Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24113/480000 [00:31<09:33, 795.19it/s, epoch=10, loss=[0.35989250689744895, 0.3110

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48157/480000 [01:00<07:58, 903.09it/s, epoch=20, loss=[0.35386967102686545, 0.3214

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72161/480000 [01:28<08:06, 837.48it/s, epoch=30, loss=[0.3516720188905798, 0.29584

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96137/480000 [01:57<07:38, 836.58it/s, epoch=40, loss=[0.3501994325965645, 0.29454

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120171/480000 [02:25<06:55, 865.49it/s, epoch=50, loss=[0.34942907457550426, 0.294

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144141/480000 [02:53<06:28, 865.30it/s, epoch=60, loss=[0.34875621465345197, 0.293

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168143/480000 [03:21<06:27, 805.10it/s, epoch=70, loss=[0.34833494683106725, 0.294

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192149/480000 [03:50<05:34, 861.50it/s, epoch=80, loss=[0.34803126749893065, 0.295

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216097/480000 [04:19<05:11, 846.56it/s, epoch=90, loss=[0.34777008267740395, 0.292

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240138/480000 [04:48<04:41, 851.24it/s, epoch=100, loss=[0.34755338758230137, 0.29

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264154/480000 [05:16<04:09, 864.05it/s, epoch=110, loss=[0.34737006485462285, 0.29

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288117/480000 [05:43<03:45, 852.55it/s, epoch=120, loss=[0.34720926962792886, 0.29

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312153/480000 [06:12<03:03, 913.33it/s, epoch=130, loss=[0.3470644739021852, 0.291

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336120/480000 [06:40<02:40, 895.52it/s, epoch=140, loss=[0.3469331206132973, 0.293

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360131/480000 [07:08<02:26, 818.09it/s, epoch=150, loss=[0.34681236719091707, 0.29

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384096/480000 [07:36<01:49, 879.68it/s, epoch=160, loss=[0.34669934198260305, 0.29

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408087/480000 [08:03<01:22, 874.96it/s, epoch=170, loss=[0.3465924130628506, 0.290

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432155/480000 [08:32<00:54, 870.14it/s, epoch=180, loss=[0.34649043055872153, 0.29

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456153/480000 [09:00<00:29, 815.03it/s, epoch=190, loss=[0.3463950495173535, 0.298

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [09:31<00:00, 840.46it/s, epoch=200, loss=[0.34630627835790295, 0.29


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24133/480000 [00:30<08:48, 863.06it/s, epoch=10, loss=[0.3586518155038358, 0.24286

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48085/480000 [00:58<08:21, 861.21it/s, epoch=20, loss=[0.3519886412719886, 0.22617

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72094/480000 [01:26<07:21, 924.35it/s, epoch=30, loss=[0.34966737598180747, 0.2163

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96138/480000 [01:53<07:55, 807.97it/s, epoch=40, loss=[0.3485680189728736, 0.21021

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120144/480000 [02:21<07:40, 781.03it/s, epoch=50, loss=[0.34777996818224577, 0.208

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144166/480000 [02:49<06:37, 844.09it/s, epoch=60, loss=[0.3472013551741837, 0.2042

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168105/480000 [03:17<05:48, 894.35it/s, epoch=70, loss=[0.346712753921747, 0.20307

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192115/480000 [03:45<05:33, 862.06it/s, epoch=80, loss=[0.3463983477652072, 0.2076

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216094/480000 [04:13<04:45, 923.69it/s, epoch=90, loss=[0.34614976552625454, 0.203

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240136/480000 [04:40<04:33, 875.85it/s, epoch=100, loss=[0.34594083252052465, 0.20

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264149/480000 [05:08<04:04, 882.65it/s, epoch=110, loss=[0.345760867844025, 0.2089

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288131/480000 [05:35<03:39, 875.76it/s, epoch=120, loss=[0.3456016384065151, 0.199

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312151/480000 [06:02<03:11, 876.77it/s, epoch=130, loss=[0.3454588057845828, 0.198

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336157/480000 [06:30<02:43, 877.58it/s, epoch=140, loss=[0.34532431994875273, 0.19

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360145/480000 [06:57<02:16, 877.16it/s, epoch=150, loss=[0.3452033113439875, 0.196

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384176/480000 [07:24<01:48, 879.44it/s, epoch=160, loss=[0.34508730466167187, 0.19

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408144/480000 [07:51<01:20, 898.12it/s, epoch=170, loss=[0.34497964866459313, 0.19

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432114/480000 [08:19<00:54, 878.82it/s, epoch=180, loss=[0.3447302165627477, 0.195

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456141/480000 [08:46<00:27, 873.78it/s, epoch=190, loss=[0.3446313551813361, 0.194

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [09:16<00:00, 863.00it/s, epoch=200, loss=[0.34454193080465023, 0.19


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24116/480000 [00:29<08:38, 878.75it/s, epoch=10, loss=[0.3601487936079503, 0.30703

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48139/480000 [00:57<08:12, 876.08it/s, epoch=20, loss=[0.3531176572044686, 0.30284

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72165/480000 [01:24<07:50, 866.79it/s, epoch=30, loss=[0.3506968917449309, 0.30569

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96107/480000 [01:51<07:16, 878.82it/s, epoch=40, loss=[0.34922413369019806, 0.3003

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120124/480000 [02:18<06:48, 880.90it/s, epoch=50, loss=[0.3484926300744214, 0.2995

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144103/480000 [02:45<06:23, 876.07it/s, epoch=60, loss=[0.3479938039680323, 0.2987

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168147/480000 [03:13<05:57, 873.22it/s, epoch=70, loss=[0.34762015357613524, 0.298

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192109/480000 [03:40<05:29, 872.72it/s, epoch=80, loss=[0.347322388266524, 0.29851

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216109/480000 [04:07<05:01, 874.25it/s, epoch=90, loss=[0.34707764126360413, 0.299

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240161/480000 [04:35<04:35, 871.32it/s, epoch=100, loss=[0.3466935826092953, 0.299

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264123/480000 [05:02<04:06, 875.56it/s, epoch=110, loss=[0.3465029895802336, 0.297

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288165/480000 [05:30<03:39, 872.75it/s, epoch=120, loss=[0.34634012425939253, 0.29

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312111/480000 [05:57<03:12, 872.46it/s, epoch=130, loss=[0.3461924427747731, 0.296

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336119/480000 [06:25<02:48, 855.93it/s, epoch=140, loss=[0.34605354157586926, 0.29

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360127/480000 [06:52<02:17, 872.43it/s, epoch=150, loss=[0.3459294751038159, 0.296

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384091/480000 [07:19<01:49, 872.37it/s, epoch=160, loss=[0.34581743553280875, 0.29

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408151/480000 [07:47<01:21, 876.68it/s, epoch=170, loss=[0.3456609039505324, 0.296

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432096/480000 [08:14<00:54, 875.03it/s, epoch=180, loss=[0.3455608108143013, 0.296

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456131/480000 [08:41<00:28, 843.99it/s, epoch=190, loss=[0.3454687838256357, 0.295

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [09:11<00:00, 870.79it/s, epoch=200, loss=[0.34533030447860574, 0.29


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24147/480000 [00:29<08:38, 879.90it/s, epoch=10, loss=[0.357384678001205, 0.267694

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48114/480000 [00:57<08:12, 876.53it/s, epoch=20, loss=[0.351270492846767, 0.254521

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72103/480000 [01:24<07:43, 880.11it/s, epoch=30, loss=[0.34889219798147675, 0.2549

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96115/480000 [01:51<07:15, 882.34it/s, epoch=40, loss=[0.3477576161424319, 0.25219

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120127/480000 [02:18<06:50, 876.18it/s, epoch=50, loss=[0.3468764327218139, 0.2449

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144093/480000 [02:46<06:22, 877.50it/s, epoch=60, loss=[0.3463771959642566, 0.2438

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168132/480000 [03:13<05:53, 881.49it/s, epoch=70, loss=[0.3460052681962651, 0.2442

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192130/480000 [03:40<05:27, 879.70it/s, epoch=80, loss=[0.3457066559543216, 0.2427

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216161/480000 [04:08<05:01, 873.81it/s, epoch=90, loss=[0.3454620789239802, 0.2420

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240094/480000 [04:35<04:33, 876.06it/s, epoch=100, loss=[0.3452549100667237, 0.239

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264167/480000 [05:02<03:58, 906.32it/s, epoch=110, loss=[0.34507507349054023, 0.23

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288173/480000 [05:30<03:32, 902.16it/s, epoch=120, loss=[0.3449161860346793, 0.237

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312105/480000 [05:57<03:11, 876.21it/s, epoch=130, loss=[0.34477317355573156, 0.23

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336145/480000 [06:24<02:39, 904.14it/s, epoch=140, loss=[0.34464319926996984, 0.23

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360117/480000 [06:51<02:16, 879.63it/s, epoch=150, loss=[0.34452190573016767, 0.23

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384110/480000 [07:18<01:48, 882.73it/s, epoch=160, loss=[0.3444090247154234, 0.235

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408140/480000 [07:45<01:21, 877.08it/s, epoch=170, loss=[0.34430369849006365, 0.23

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432155/480000 [08:13<00:54, 876.56it/s, epoch=180, loss=[0.34420494894186593, 0.23

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456093/480000 [08:40<00:27, 883.02it/s, epoch=190, loss=[0.34411005529264627, 0.23

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [09:09<00:00, 872.93it/s, epoch=200, loss=[0.34402099631726746, 0.23


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24103/480000 [00:29<08:38, 879.10it/s, epoch=10, loss=[0.35702761933207555, 0.2696

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48153/480000 [00:57<08:16, 869.86it/s, epoch=20, loss=[0.35094152706364834, 0.2578

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72143/480000 [01:24<07:45, 876.87it/s, epoch=30, loss=[0.34873412022988015, 0.2499

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96109/480000 [01:51<07:17, 877.22it/s, epoch=40, loss=[0.34735359959304324, 0.2445

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120146/480000 [02:19<06:51, 874.78it/s, epoch=50, loss=[0.3466178977737829, 0.2375

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144161/480000 [02:46<06:24, 874.05it/s, epoch=60, loss=[0.3461138614267112, 0.2438

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168165/480000 [03:13<05:57, 872.91it/s, epoch=70, loss=[0.3457282591114439, 0.2333

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192168/480000 [03:40<05:31, 868.36it/s, epoch=80, loss=[0.34542061090469356, 0.231

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216147/480000 [04:07<04:59, 880.70it/s, epoch=90, loss=[0.3451669081300501, 0.2332

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240114/480000 [04:35<04:34, 875.33it/s, epoch=100, loss=[0.344954571425915, 0.2324

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264158/480000 [05:02<04:04, 881.23it/s, epoch=110, loss=[0.34476777717471124, 0.22

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288141/480000 [05:29<03:39, 875.15it/s, epoch=120, loss=[0.344602781583866, 0.2284

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312118/480000 [05:57<03:15, 860.03it/s, epoch=130, loss=[0.34440896034240737, 0.22

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336103/480000 [06:24<02:46, 862.83it/s, epoch=140, loss=[0.3442651216934127, 0.227

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360121/480000 [06:52<02:16, 875.72it/s, epoch=150, loss=[0.34414122139414133, 0.22

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384157/480000 [07:19<01:49, 878.10it/s, epoch=160, loss=[0.34402695313096143, 0.22

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408113/480000 [07:46<01:22, 875.49it/s, epoch=170, loss=[0.3439207591364783, 0.226

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432128/480000 [08:14<00:53, 899.45it/s, epoch=180, loss=[0.34382112453381164, 0.22

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456105/480000 [08:41<00:27, 875.29it/s, epoch=190, loss=[0.34372727103531375, 0.22

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [09:10<00:00, 871.63it/s, epoch=200, loss=[0.3436376667519412, 0.223


In [11]:
# Test accuracies of every run (before and after training, one value per
# layer), followed by the single best value across all runs and layers.
print(test_acc)
print(np.max(test_acc))

[[[array([0.091 , 0.1028])], [array([0.9091, 0.9473])]], [[array([0.0573, 0.0874])], [array([0.9127, 0.9424])]], [[array([0.1076, 0.0793])], [array([0.9102, 0.9513])]], [[array([0.105 , 0.1056])], [array([0.9113, 0.9367])]], [[array([0.1005, 0.0762])], [array([0.9111, 0.943 ])]]]
0.9513


In [12]:
# Train accuracies of every run (before and after training, one value per
# layer), followed by the single best value across all runs and layers.
print(train_acc)
print(np.max(train_acc))

[[[array([0.09076667, 0.0997    ])], [array([0.91275   , 0.96038333])]], [[array([0.05856667, 0.08715   ])], [array([0.91411667, 0.95235   ])]], [[array([0.10648333, 0.0795    ])], [array([0.9134, 0.9642])]], [[array([0.10083333, 0.10225   ])], [array([0.91388333, 0.9476    ])]], [[array([0.10016667, 0.08095   ])], [array([0.91416667, 0.95056667])]]]
0.9642


In [13]:
# Persist the accuracy histories as .npy arrays.
# NOTE(review): nothing in this notebook creates ./../new_data — confirm the
# directory exists before running this cell.
np.save("./../new_data/ablation_pure_mnist_train_acc.npy", np.array(train_acc))
np.save("./../new_data/ablation_pure_mnist_test_acc.npy", np.array(test_acc))