In [1]:
import sys
sys.path.insert(0, './../../../Models')
from sphere_points import generate_points

import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt

import torch
torch.manual_seed(0)
import torch.nn as nn
from tqdm import tqdm
from torch.optim import SGD
from torch.nn.functional import normalize, one_hot
# import torch.nn.functional as F
import torchvision
from torchvision import transforms
# from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

In [2]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
# (The commented-out alternatives in the original cell were an additional "mps"
# fallback and a forced-CPU setting; neither is active.)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
def plot_losses(losses):
    """Draw one loss-versus-epoch panel per layer, side by side in a single row.

    Args:
        losses: 2-D array-like whose first axis indexes layers and whose second
            axis indexes epochs (it is converted with ``np.array`` and its
            ``shape`` is unpacked into exactly two values).

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    loss_matrix = np.array(losses)
    layer_count, epoch_count = loss_matrix.shape
    # Epochs are numbered from 1 on the x axis.
    epoch_axis = 1 + np.arange(epoch_count)

    plt.figure(figsize=(12, 5))
    for panel, curve in enumerate(loss_matrix, start=1):
        plt.subplot(1, layer_count, panel)
        plt.plot(epoch_axis, curve)
        plt.title(f"Layer {panel} Loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
    plt.tight_layout()
    plt.show()

In [4]:
# model_loss = nn.CrossEntropyLoss()
# Number of target classes. This is a notebook-level global: code in later cells
# reads it by name (e.g. Layer.__init__ copies it into self.num_classes), so it
# must be defined before those cells run.
num_classes = 10

In [5]:
def classifier_head_train(inp_embedding, classifier_weights, labels):
    """Angular-similarity training loss against fixed class directions.

    Each embedding row is L2-normalised and multiplied with
    ``classifier_weights`` to obtain one cosine per class. The cosine is mapped
    to an angular similarity ``1 - acos(cos) / pi`` and the loss is
    ``mean(log(2 - theta * s))`` with ``theta = 1``, where ``s`` is the
    similarity of each sample to the column selected by its label.

    Args:
        inp_embedding: 2-D float tensor, one embedding per row (``torch.mm``
            requires 2-D operands).
        classifier_weights: 2-D float tensor with one column per class. The
            columns are not normalised here; the result is only a true cosine
            if the caller passes unit-norm columns.
        labels: 1-D integer (int64) tensor of class indices, one per row of
            ``inp_embedding``.

    Returns:
        0-dim tensor holding the mean loss, attached to the autograd graph.
    """
    # acos is only defined on [-1, 1] and its derivative is infinite at +/-1.
    # Rounding in normalize/mm can push a cosine just past those bounds (NaN
    # loss), and a perfectly aligned sample sits exactly on them (NaN grad).
    # Clamping slightly inside the interval keeps both finite. The margin is
    # chosen so that 1 - eps is still representable below 1.0 in float32.
    eps = 1e-7
    theta = 1

    # Take the class count from the weight matrix itself instead of relying on
    # a notebook-level global staying in sync with it.
    n_classes = classifier_weights.shape[1]

    inp_embedding = normalize(inp_embedding, p=2, dim=-1)
    cosine = torch.mm(inp_embedding, classifier_weights)
    cosine = torch.clamp(cosine, -1.0 + eps, 1.0 - eps)
    similarity = 1 - torch.acos(cosine) / np.pi

    # Zero out every class except the labelled one, then collapse the class
    # axis so only the true-class similarity remains per sample.
    target_mask = one_hot(labels, num_classes=n_classes).type(torch.float32)
    true_class_similarity = torch.sum(similarity * target_mask, 1)

    loss = torch.mean(torch.log(2 - (theta * true_class_similarity)))
    return loss

In [6]:
def classifier_head(inp_embedding, classifier_weights, labels):
    """Predict the class whose direction is angularly closest to each embedding.

    Args:
        inp_embedding: 2-D float tensor, one embedding per row.
        classifier_weights: 2-D float tensor with one column per class.
        labels: Unused. Kept so the signature matches the existing call sites.

    Returns:
        Python list of predicted class indices, one per row of
        ``inp_embedding``.
    """
    inp_embedding = normalize(inp_embedding, p=2, dim=-1)
    cosine = torch.mm(inp_embedding, classifier_weights)
    # acos returns NaN outside [-1, 1], and argmax would then select the NaN
    # entry. Clamping keeps tiny rounding overshoots from corrupting the
    # prediction.
    cosine = torch.clamp(cosine, -1.0, 1.0)
    similarity = 1 - torch.acos(cosine) / np.pi
    return torch.argmax(similarity, dim=1).tolist()

In [7]:
initial = None
# num_classes = 10

# Data dimension
# (num_data, num_features) => no dimension for batch size please
class Layer(nn.Linear):
    """Linear layer trained on its own against fixed per-class target directions.

    The layer computes ``LeakyReLU(x @ W.T + b)`` and is optimised locally with
    plain SGD on the loss returned by ``classifier_head_train``, comparing its
    activations with ``direction_weights`` (one L2-normalised column per class).

    NOTE: ``train`` deliberately keeps the original signature ``(x, labels)``
    and therefore shadows ``nn.Module.train(mode)``; ``.eval()`` / ``.train()``
    in the usual PyTorch sense must not be called on this class.
    """

    # Negative slope of the LeakyReLU used in both train() and test().
    _NEGATIVE_SLOPE = 0.001

    def __init__(self, in_features, out_features, bias, device, lr, apply_dropout=False):
        """Create the layer, initialise its weights and build the class directions.

        Args:
            in_features: Input width.
            out_features: Output width; also the dimension of the class directions.
            bias: Whether the layer has (and applies) a bias term.
            device: Device for the parameters and the class directions.
            lr: Learning rate for the per-layer SGD step. Read from ``self.lr``
                at every training step, so callers may change it between steps.
            apply_dropout: If True, ``train`` applies dropout (p=0.1) to its input.

        Side effects:
            Reads the notebook-level global ``num_classes`` and overwrites the
            global ``initial`` with the raw output of ``generate_points``.
        """
        global initial
        super().__init__(in_features, out_features, bias, device)
        self.out_features = out_features
        self.bias_flag = bias
        self.lr = lr
        self.num_classes = num_classes
        self.dimension = out_features

        # Uniform init in [-sqrt(6 / fan_in), +sqrt(6 / fan_in)].
        fc1_limit = np.sqrt(6.0 / in_features)
        torch.nn.init.uniform_(self.weight, a=-fc1_limit, b=fc1_limit)

        self.dropout = nn.Dropout(0.1)
        self.apply_dropout = apply_dropout

        # Optimizer cache: rebuilt only when self.lr changes (see _optimizer).
        self._opt = None
        self._opt_lr = None

        # generate_points comes from the project module sphere_points (not
        # visible here); presumably it returns num_classes points of length
        # `dimension` -- verify against that module.
        self.directions = generate_points(self.num_classes, self.dimension, steps = 10000)
        initial = np.array(self.directions)
        self.directions = [torch.tensor(t, dtype = torch.float32).to(device) for t in self.directions]
        # Column i holds the unit-norm direction for class i.
        self.direction_weights = torch.zeros((len(self.directions[0]), len(self.directions)), device=device,
                                             requires_grad=False)
        for i in range(len(self.directions)):
            self.direction_weights[:, i] = normalize(self.directions[i], p = 2, dim=-1)

    def _activate(self, x):
        """Affine map followed by LeakyReLU; shared by train() and test()."""
        pre_activation = torch.mm(x, self.weight.T)
        if self.bias_flag:
            pre_activation = pre_activation + self.bias.unsqueeze(0)
        return nn.functional.leaky_relu(pre_activation, negative_slope=self._NEGATIVE_SLOPE)

    def _optimizer(self):
        """Return an SGD optimizer for the current ``self.lr``.

        Plain SGD (no momentum) keeps no state between steps, so reusing one
        instance gives the same updates as building a new one per minibatch.
        It is rebuilt whenever ``self.lr`` has been changed from outside, which
        also keeps SGD's own validation of the learning rate in effect.
        """
        if self._opt is None or self._opt_lr != self.lr:
            self._opt = SGD(self.parameters(), lr=self.lr)
            self._opt_lr = self.lr
        return self._opt

    def train(self, x, labels):
        """Run one local SGD step on a minibatch.

        Args:
            x: 2-D float tensor (one sample per row) on the layer's device.
            labels: Integer class indices for the rows of ``x``.

        Returns:
            Tuple ``(loss, y)``: the loss as a Python float and the layer's
            activations for this batch, computed before the weight update and
            still attached to the autograd graph (callers detach them before
            feeding the next layer).
        """
        if self.apply_dropout:
            x = self.dropout(x)
        y = self._activate(x)
        loss = classifier_head_train(y, self.direction_weights, labels)

        opt = self._optimizer()
        opt.zero_grad(set_to_none=True)
        loss.backward(retain_graph = False)
        opt.step()

        return loss.detach().item(), y

    def test(self, x, labels):
        """Predict classes for a minibatch without tracking gradients.

        Args:
            x: 2-D float tensor (one sample per row) on the layer's device.
            labels: Forwarded to ``classifier_head``.

        Returns:
            Tuple ``(pred, y)``: predicted class indices as a tensor on the same
            device as the activations, and the activations themselves.
        """
        with torch.no_grad():
            y = self._activate(x)
        max_idx_list = classifier_head(y, self.direction_weights, labels)
        # Place predictions next to the activations instead of depending on a
        # notebook-level `device` variable.
        return torch.tensor(max_idx_list, device=y.device), y

In [8]:
class Net(nn.Module):
    """A stack of ``Layer`` objects that are each trained with their own local loss.

    NOTE: ``train`` keeps the original signature ``(train_loader, test_loader)``
    and therefore shadows ``nn.Module.train(mode)``. ``self.layers`` is a plain
    Python list, so the layers' parameters are not registered on this module;
    every layer performs its own optimisation step.
    """

    def __init__(self, dims_list, bias, epochs, lr, device):
        """Build one ``Layer`` per consecutive pair of widths in ``dims_list``.

        Args:
            dims_list: Layer widths, input first; ``len(dims_list) - 1`` layers
                are created.
            bias: Passed to every layer.
            epochs: Number of passes over the training loader in ``train``.
            lr: Initial learning rate passed to every layer.
            device: Device passed to every layer and used for the input batches.
        """
        super(Net, self).__init__()
        self.dims_list = dims_list
        self.bias = bias
        self.epochs = epochs
        self.lr = lr
        self.device = device
        self.layers = []
        for d in range(len(self.dims_list) - 1):
            print(f"Initialization {d + 1} / {len(self.dims_list) - 1}")
            self.layers += [Layer(self.dims_list[d], self.dims_list[d + 1], self.bias, self.device, self.lr)]
            print("Complete\n")

    def _snapshot(self, train_loader, test_loader, layer_w, acc_train, acc_test):
        """Append the current weight norms and train/test accuracies to the given lists."""
        with torch.no_grad():
            for i, layer in enumerate(self.layers):
                layer_w[i].append(torch.norm(layer.weight, p=2).item())
            acc_train.append(self.test(train_loader))
            acc_test.append(self.test(test_loader))

    def train(self, train_loader, test_loader):
        """Train all layers for ``self.epochs`` epochs.

        For every minibatch the layers are updated in order; each layer receives
        the detached activations of the previous one. Accuracy and weight norms
        are recorded once before training and once after.

        Args:
            train_loader: Iterable of ``(x, label)`` batches that supports ``len()``.
            test_loader: Iterable of ``(x, label)`` batches used for evaluation only.

        Returns:
            ``[layer_loss_list, acc_train, acc_test, layer_w]`` where
            ``layer_loss_list[i]`` holds the mean loss of layer ``i`` per epoch,
            ``acc_train`` / ``acc_test`` hold the results of ``self.test`` before
            and after training, and ``layer_w[i]`` holds the weight norm of
            layer ``i`` before and after training.
        """
        n_layers = len(self.layers)
        n_batches = len(train_loader)
        layer_loss_list = [[] for _ in range(n_layers)]
        acc_train = []
        acc_test = []
        layer_w = [[] for _ in range(len(self.dims_list) - 1)]

        pbar = tqdm(total = self.epochs * n_batches * n_layers,
                    desc = f"Training", position = 0, leave = True)
        try:
            # Baseline before any update. This uses self and the loaders passed
            # in, not notebook-level variables that merely happen to exist.
            self._snapshot(train_loader, test_loader, layer_w, acc_train, acc_test)

            for epoch in range(self.epochs):

                if epoch and not (epoch % 10):
                    # Linear learning-rate decay: subtract 0.1 every 10 epochs.
                    # The decay stops while lr is still positive, because
                    # torch.optim.SGD rejects a negative learning rate.
                    for layer in self.layers:
                        decayed_lr = layer.lr - 0.1
                        if decayed_lr > 0:
                            layer.lr = decayed_lr
                            print('lr decreased to ', layer.lr)

                loss_agg = [0] * n_layers
                for x, label in train_loader:
                    x = x.to(self.device)
                    label = label.to(self.device)
                    for i, layer in enumerate(self.layers):
                        loss, y = layer.train(x, label)
                        layer.zero_grad(set_to_none=True)
                        # Cut the graph so the next layer cannot backpropagate
                        # into this one.
                        x = y.detach()
                        loss_agg[i] += loss / n_batches
                        pbar.update(1)
                pbar.set_postfix(epoch = epoch + 1, loss = loss_agg)
                for i in range(n_layers):
                    layer_loss_list[i].append(loss_agg[i])

            # Final state after training.
            self._snapshot(train_loader, test_loader, layer_w, acc_train, acc_test)
        finally:
            # Release the progress bar even if training raises.
            pbar.close()
        return [layer_loss_list, acc_train, acc_test, layer_w]

    def test(self, data_loader):
        """Measure the classification accuracy of every layer on ``data_loader``.

        Args:
            data_loader: Iterable of ``(x, label)`` batches.

        Returns:
            A one-element list containing a NumPy array with one accuracy per
            layer (fraction of correctly classified samples).
        """
        n_layers = len(self.layers)
        correct = [0 for _ in range(n_layers)]
        total = [0 for _ in range(n_layers)]
        for x, label in data_loader:
            x = x.to(self.device)
            label = label.to(self.device)
            for i, layer in enumerate(self.layers):
                # Each layer classifies its own activations and passes them on.
                pred, x = layer.test(x, label)
                correct[i] += (pred == label).sum().item()
                total[i] += label.shape[0]

        # Every layer sees the same samples, so the last total applies to all.
        return [np.array(correct) / total[-1]]

In [9]:
# Flatten every image tensor: keep the leading axis, merge all remaining axes,
# then drop singleton axes.
flatten_transform = transforms.Lambda(lambda img: img.view(img.size(0), -1).squeeze())

# ToTensor normalizes pixel values to [0, 1]; no further normalisation is active.
# (A Normalize((0.,), (0.5,)) step was present only as commented-out code.)
transform = transforms.Compose([transforms.ToTensor(), flatten_transform])

batch_size = 50

# MNIST training and test splits, downloaded on first use.
trainset, testset = (
    torchvision.datasets.MNIST(root='./../../../Data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Neither loader shuffles, so batches are visited in dataset order.
trainloader, testloader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (trainset, testset)
)

# full_data_set = torch.utils.data.ConcatDataset([trainset, testset])

In [10]:
# Train/test split ratios. The sweep over several ratios (and the random_split
# based loaders that went with it) is disabled; `ratio` is not read below.
ratio = [0.2]

# Result containers: one entry is appended per training run.
test_acc = []
train_acc = []
w_layers = []
loss = []

num_runs = 5

for _ in range(num_runs):

    # Hyper-parameters of this run. `num_classes` is a notebook-level global
    # that Layer reads when it is constructed, so it is set before Net().
    dims_list = [784, 1024, 10]
    bias = True
    epochs = 200
    lr  = 2.5
    num_classes = 10
    net = Net(dims_list, bias, epochs, lr, device)

    # Net.train returns [loss history, train accuracy, test accuracy, weight norms].
    layer_loss_list = net.train(trainloader, testloader)
    run_loss, run_train_acc, run_test_acc, run_w = layer_loss_list

    loss.append(run_loss)
    train_acc.append(run_train_acc)
    test_acc.append(run_test_acc)
    w_layers.append(run_w)

Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24122/480000 [00:31<09:03, 839.38it/s, epoch=10, loss=[0.2719540886705119, 0.24306

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48128/480000 [00:59<09:35, 749.92it/s, epoch=20, loss=[0.2680148441344498, 0.24131

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72094/480000 [01:28<08:15, 822.65it/s, epoch=30, loss=[0.26575603881229964, 0.2477

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96166/480000 [01:58<07:41, 831.49it/s, epoch=40, loss=[0.2646499524265529, 0.24005

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120157/480000 [02:28<07:19, 818.45it/s, epoch=50, loss=[0.264011628019313, 0.23962

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144123/480000 [02:57<06:49, 820.93it/s, epoch=60, loss=[0.26360758413871127, 0.239

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168087/480000 [03:27<06:37, 784.44it/s, epoch=70, loss=[0.2632534280667703, 0.2391

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192121/480000 [03:57<05:47, 829.19it/s, epoch=80, loss=[0.26295103808244097, 0.238

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216097/480000 [04:27<05:41, 772.83it/s, epoch=90, loss=[0.2627611873795593, 0.2411

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240094/480000 [04:57<05:00, 799.30it/s, epoch=100, loss=[0.2626074877381325, 0.238

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264157/480000 [05:27<04:25, 811.68it/s, epoch=110, loss=[0.2624794700990123, 0.238

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288173/480000 [05:57<03:45, 850.22it/s, epoch=120, loss=[0.2623700154696902, 0.238

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312116/480000 [06:26<03:21, 832.02it/s, epoch=130, loss=[0.26227501614640153, 0.23

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336125/480000 [06:56<02:57, 811.12it/s, epoch=140, loss=[0.262115303017199, 0.2382

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360159/480000 [07:26<02:32, 784.25it/s, epoch=150, loss=[0.26201385067154964, 0.23

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384145/480000 [07:56<01:57, 818.81it/s, epoch=160, loss=[0.2619472159569466, 0.238

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408067/480000 [08:28<01:34, 763.36it/s, epoch=170, loss=[0.26189071151117443, 0.23

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432151/480000 [08:58<00:58, 821.79it/s, epoch=180, loss=[0.2618344849099714, 0.237

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456090/480000 [09:28<00:29, 808.64it/s, epoch=190, loss=[0.2617880406603219, 0.237

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [10:02<00:00, 796.28it/s, epoch=200, loss=[0.26174864434947565, 0.23


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24087/480000 [00:34<11:35, 655.90it/s, epoch=10, loss=[0.2712905178591609, 0.25158

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48158/480000 [01:07<09:07, 788.47it/s, epoch=20, loss=[0.26767924144864147, 0.2320

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72115/480000 [01:39<09:27, 719.18it/s, epoch=30, loss=[0.2662043203537666, 0.22706

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96129/480000 [02:13<08:16, 773.12it/s, epoch=40, loss=[0.26520479060709495, 0.2258

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120099/480000 [02:46<08:03, 745.10it/s, epoch=50, loss=[0.2645714812974137, 0.2253

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144106/480000 [03:19<07:30, 745.37it/s, epoch=60, loss=[0.26403879585365425, 0.224

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168155/480000 [03:52<06:39, 780.30it/s, epoch=70, loss=[0.2636387916157643, 0.2244

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192101/480000 [04:24<06:33, 732.34it/s, epoch=80, loss=[0.26335622556507593, 0.224

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216111/480000 [04:57<05:47, 759.61it/s, epoch=90, loss=[0.2631378780926264, 0.2246

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240092/480000 [05:28<05:51, 682.26it/s, epoch=100, loss=[0.26290081934382525, 0.22

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264106/480000 [05:59<04:18, 836.56it/s, epoch=110, loss=[0.2627458857993281, 0.223

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288142/480000 [06:29<04:01, 793.36it/s, epoch=120, loss=[0.2626272910088307, 0.224

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312109/480000 [07:00<03:41, 756.51it/s, epoch=130, loss=[0.2625315154592193, 0.222

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336151/480000 [07:31<03:04, 781.21it/s, epoch=140, loss=[0.2624483958631752, 0.222

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360083/480000 [08:01<02:28, 805.35it/s, epoch=150, loss=[0.2623754036550721, 0.222

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384082/480000 [08:32<02:07, 752.08it/s, epoch=160, loss=[0.26230491490413704, 0.22

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408106/480000 [09:03<01:39, 722.27it/s, epoch=170, loss=[0.2622464815278854, 0.222

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432083/480000 [09:34<01:00, 786.80it/s, epoch=180, loss=[0.2621961635351182, 0.222

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456071/480000 [10:05<00:30, 772.73it/s, epoch=190, loss=[0.2621510114272432, 0.222

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [10:40<00:00, 748.90it/s, epoch=200, loss=[0.2620964052031439, 0.222


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24148/480000 [00:35<09:46, 777.89it/s, epoch=10, loss=[0.2708904414375626, 0.25494

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48142/480000 [01:06<09:09, 785.57it/s, epoch=20, loss=[0.2669670740266638, 0.24920

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72084/480000 [01:38<08:37, 788.76it/s, epoch=30, loss=[0.2651521207640562, 0.23699

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96132/480000 [02:10<09:17, 688.47it/s, epoch=40, loss=[0.2640984834109743, 0.23587

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120121/480000 [02:41<07:31, 797.22it/s, epoch=50, loss=[0.2635538043330119, 0.2353

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144108/480000 [03:12<07:13, 774.06it/s, epoch=60, loss=[0.26295361225803676, 0.234

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168099/480000 [03:44<06:47, 765.38it/s, epoch=70, loss=[0.26254471416274705, 0.244

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192144/480000 [04:15<06:21, 754.03it/s, epoch=80, loss=[0.2621819250533979, 0.2345

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216097/480000 [04:47<05:26, 808.39it/s, epoch=90, loss=[0.2618553873648245, 0.2343

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240137/480000 [05:19<05:06, 782.96it/s, epoch=100, loss=[0.26170006344715774, 0.23

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264127/480000 [05:51<05:10, 695.82it/s, epoch=110, loss=[0.2615731998905536, 0.239

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288102/480000 [06:27<04:15, 751.12it/s, epoch=120, loss=[0.26146562165270254, 0.23

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312110/480000 [07:01<04:17, 651.08it/s, epoch=130, loss=[0.2613659557575977, 0.233

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336080/480000 [07:38<03:27, 695.11it/s, epoch=140, loss=[0.2612810571988426, 0.234

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360131/480000 [08:13<02:59, 666.40it/s, epoch=150, loss=[0.2612048596764606, 0.233

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384149/480000 [08:49<02:09, 741.22it/s, epoch=160, loss=[0.2611367277180154, 0.242

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408083/480000 [09:22<01:49, 658.40it/s, epoch=170, loss=[0.2610815660407145, 0.233

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432097/480000 [09:57<01:14, 645.27it/s, epoch=180, loss=[0.2610330693796276, 0.234

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456069/480000 [10:31<00:33, 705.79it/s, epoch=190, loss=[0.2609900898113847, 0.233

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [11:09<00:00, 717.27it/s, epoch=200, loss=[0.26095190111547717, 0.23


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24090/480000 [00:38<10:35, 717.54it/s, epoch=10, loss=[0.27020601443946374, 0.2330

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48105/480000 [01:12<10:31, 683.90it/s, epoch=20, loss=[0.26629322999467464, 0.2281

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72117/480000 [01:47<09:54, 686.54it/s, epoch=30, loss=[0.264578169174492, 0.227145

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96121/480000 [02:22<08:48, 726.02it/s, epoch=40, loss=[0.2637188201025127, 0.22652

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120053/480000 [02:58<09:32, 628.20it/s, epoch=50, loss=[0.26295784983783954, 0.228

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144103/480000 [03:33<09:29, 589.80it/s, epoch=60, loss=[0.26239996482928585, 0.227

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168117/480000 [04:08<07:22, 705.47it/s, epoch=70, loss=[0.26208010622610606, 0.225

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192087/480000 [04:45<07:08, 671.52it/s, epoch=80, loss=[0.26182542927563185, 0.229

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216085/480000 [05:21<06:21, 692.25it/s, epoch=90, loss=[0.26163175442566494, 0.225

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240079/480000 [05:56<05:47, 689.74it/s, epoch=100, loss=[0.2614751886824766, 0.232

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264068/480000 [06:30<05:40, 633.46it/s, epoch=110, loss=[0.26132206230113914, 0.22

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288127/480000 [07:05<04:43, 676.23it/s, epoch=120, loss=[0.2611729248240587, 0.225

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312105/480000 [07:40<03:59, 700.89it/s, epoch=130, loss=[0.2610750535751384, 0.214

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336060/480000 [08:15<04:11, 572.76it/s, epoch=140, loss=[0.2609923766305048, 0.213

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360079/480000 [08:50<02:49, 706.34it/s, epoch=150, loss=[0.2609208938976129, 0.213

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384077/480000 [09:25<02:12, 723.33it/s, epoch=160, loss=[0.2608582346762218, 0.213

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408067/480000 [10:01<01:35, 755.00it/s, epoch=170, loss=[0.26080056409041036, 0.21

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432132/480000 [10:36<01:04, 742.38it/s, epoch=180, loss=[0.2607499104241528, 0.211

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456123/480000 [11:10<00:33, 711.09it/s, epoch=190, loss=[0.26070443862428283, 0.21

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [11:48<00:00, 677.31it/s, epoch=200, loss=[0.2606645544866724, 0.210


Initialization 1 / 2
Complete

Initialization 2 / 2
Complete



Training:   5%| | 24100/480000 [00:38<11:52, 640.26it/s, epoch=10, loss=[0.26980146636565505, 0.2389

lr decreased to  2.4
lr decreased to  2.4


Training:  10%| | 48121/480000 [01:14<10:33, 681.74it/s, epoch=20, loss=[0.2658895528813201, 0.23749

lr decreased to  2.3
lr decreased to  2.3


Training:  15%|▏| 72077/480000 [01:49<09:06, 746.70it/s, epoch=30, loss=[0.26450082338104647, 0.2341

lr decreased to  2.1999999999999997
lr decreased to  2.1999999999999997


Training:  20%|▏| 96091/480000 [02:24<08:55, 717.15it/s, epoch=40, loss=[0.263695835135877, 0.233899

lr decreased to  2.0999999999999996
lr decreased to  2.0999999999999996


Training:  25%|▎| 120093/480000 [02:59<08:18, 721.94it/s, epoch=50, loss=[0.26319366617749135, 0.230

lr decreased to  1.9999999999999996
lr decreased to  1.9999999999999996


Training:  30%|▎| 144145/480000 [03:33<07:37, 733.48it/s, epoch=60, loss=[0.26276788475612767, 0.214

lr decreased to  1.8999999999999995
lr decreased to  1.8999999999999995


Training:  35%|▎| 168117/480000 [04:06<07:50, 662.23it/s, epoch=70, loss=[0.26244727114836397, 0.211

lr decreased to  1.7999999999999994
lr decreased to  1.7999999999999994


Training:  40%|▍| 192142/480000 [04:39<06:35, 727.63it/s, epoch=80, loss=[0.26222173523157816, 0.211

lr decreased to  1.6999999999999993
lr decreased to  1.6999999999999993


Training:  45%|▍| 216127/480000 [05:14<06:15, 703.30it/s, epoch=90, loss=[0.26194669820368327, 0.210

lr decreased to  1.5999999999999992
lr decreased to  1.5999999999999992


Training:  50%|▌| 240096/480000 [05:49<06:00, 665.48it/s, epoch=100, loss=[0.2617562402163943, 0.210

lr decreased to  1.4999999999999991
lr decreased to  1.4999999999999991


Training:  55%|▌| 264088/480000 [06:23<04:43, 761.12it/s, epoch=110, loss=[0.261624083357553, 0.2113

lr decreased to  1.399999999999999
lr decreased to  1.399999999999999


Training:  60%|▌| 288136/480000 [06:57<04:20, 735.58it/s, epoch=120, loss=[0.26148129843175416, 0.21

lr decreased to  1.299999999999999
lr decreased to  1.299999999999999


Training:  65%|▋| 312078/480000 [07:32<03:54, 717.19it/s, epoch=130, loss=[0.26136433947831406, 0.20

lr decreased to  1.1999999999999988
lr decreased to  1.1999999999999988


Training:  70%|▋| 336075/480000 [08:06<03:42, 646.09it/s, epoch=140, loss=[0.26127520954857253, 0.21

lr decreased to  1.0999999999999988
lr decreased to  1.0999999999999988


Training:  75%|▊| 360067/480000 [08:42<03:04, 651.45it/s, epoch=150, loss=[0.2612026208018268, 0.207

lr decreased to  0.9999999999999988
lr decreased to  0.9999999999999988


Training:  80%|▊| 384103/480000 [09:18<02:29, 640.12it/s, epoch=160, loss=[0.2611395369221763, 0.207

lr decreased to  0.8999999999999988
lr decreased to  0.8999999999999988


Training:  85%|▊| 408075/480000 [09:53<01:45, 684.40it/s, epoch=170, loss=[0.2610837454100449, 0.207

lr decreased to  0.7999999999999988
lr decreased to  0.7999999999999988


Training:  90%|▉| 432076/480000 [10:28<01:09, 694.09it/s, epoch=180, loss=[0.2610342938080434, 0.208

lr decreased to  0.6999999999999988
lr decreased to  0.6999999999999988


Training:  95%|▉| 456077/480000 [11:03<00:36, 662.02it/s, epoch=190, loss=[0.26098817143589204, 0.20

lr decreased to  0.5999999999999989
lr decreased to  0.5999999999999989


Training: 100%|█| 480000/480000 [11:37<00:00, 687.72it/s, epoch=200, loss=[0.26094933015604793, 0.20


In [11]:
# Show the test accuracies of all runs, followed by the single best value.
print(test_acc, np.max(test_acc), sep="\n")

[[[array([0.1157, 0.1119])], [array([0.9093, 0.9514])]], [[array([0.0879, 0.0695])], [array([0.9103, 0.9477])]], [[array([0.0735, 0.1125])], [array([0.9111, 0.9507])]], [[array([0.1146, 0.1355])], [array([0.9105, 0.9428])]], [[array([0.0973, 0.1051])], [array([0.9103, 0.9267])]]]
0.9514


In [12]:
# Show the train accuracies of all runs, followed by the single best value.
print(train_acc, np.max(train_acc), sep="\n")

[[[array([0.11608333, 0.10806667])], [array([0.9114    , 0.96366667])]], [[array([0.0875    , 0.06763333])], [array([0.91121667, 0.96086667])]], [[array([0.0707    , 0.11006667])], [array([0.91185   , 0.96133333])]], [[array([0.11181667, 0.13035   ])], [array([0.91201667, 0.95295   ])]], [[array([0.09515, 0.10715])], [array([0.91195   , 0.93618333])]]]
0.9636666666666667


In [13]:
# Persist the accuracy histories of all runs as .npy files.
for out_path, history in (
    ("./../new_data/ablation_acos_mnist_train_acc.npy", train_acc),
    ("./../new_data/ablation_acos_mnist_test_acc.npy", test_acc),
):
    np.save(out_path, np.array(history))
# Weight norms and loss histories are currently not saved:
# np.save("./../new_data/w_mnist_bestconfig_cos.npy", np.array(w_layers))
# np.save("./../new_data/loss_mnist_bestconfig_cos.npy", np.array(loss))