## Import libraries

In [16]:
# !pip install tensorly

import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.decomposition import tucker
from tensorly.decomposition import tensor_train
from math import ceil
from tensorly import tt_to_tensor
from tensorly.decomposition import matrix_product_state
import torch.nn.init as init

In [2]:
"""
5 runs of 50 epochs, seed = 10, 20, 30, 40, 50;
validation accuracies: 0.9492, 0.9457, 0.9463, 0.9439, 0.9455
"""
#from __future__ import print_function, division
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
import os
import copy
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import balanced_accuracy_score

# Report the library versions in use and whether a CUDA device is visible.
print("PyTorch Version:", torch.__version__)
print("Torchvision Version:", torchvision.__version__)
print("GPU is available?", torch.cuda.is_available())

# Notebook-wide numeric precision and compute device (first GPU when present, else CPU).
dtype = torch.float64
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

PyTorch Version: 1.13.1
Torchvision Version: 0.14.1
GPU is available? True


## train data processing

In [3]:
df_train_1 = pd.read_csv('/home/c/cl237/Datasets/Flare_LSTM_dataset/M/normalized_training.csv')
df_train_2 = pd.read_csv('/home/c/cl237/Datasets/Flare_LSTM_dataset/M/normalized_validation.csv')
# Merge the training and validation splits into a single training frame (no
# separate validation set is used). pd.concat replaces DataFrame.append, which
# was deprecated in pandas 1.4 and removed in pandas 2.0; ignore_index=True
# gives the same fresh 0..n-1 index as the former .reset_index(drop=True).
df_train = pd.concat([df_train_1, df_train_2], ignore_index=True)

# Undersample the "Negative" class down to the size of the "Positive" class.
negative_df = df_train[df_train['label'] == 'Negative']
positive_df = df_train[df_train['label'] == 'Positive']

# Number of "Positive" samples = target size for the "Negative" class
num_positive = len(positive_df)

# Fixed random_state keeps the drawn subset identical across re-runs.
balanced_negative_df = negative_df.sample(n=num_positive, random_state=10086)

# Positives first, then the sampled negatives (rows are not shuffled here).
df_train_balanced = pd.concat([positive_df, balanced_negative_df])
print(df_train_balanced['label'].value_counts())

Positive    4057
Negative    4057
Name: label, dtype: int64


In [4]:
# Split the balanced frame positionally: column 0 is taken as the target and
# every column from position 5 onward as a feature.
pd_y_train = df_train_balanced.iloc[:, 0]
pd_X_train = df_train_balanced.iloc[:, 5:]

K = 2                   # number of classes
N = len(pd_X_train)     # number of training samples

# Features are stored column-wise: after the transpose X_train is (features, N).
pd_X_train = pd_X_train.values
X_train = torch.tensor(pd_X_train, dtype=dtype, device=device).t()

# scaler = MinMaxScaler()
# x = pd_X_train.values
# x_scaled = scaler.fit_transform(x)
# X_train = torch.tensor(x_scaled, dtype=dtype, device=device)
# X_train = torch.t(X_train)

# Map the string labels to integer class ids.
encoder = LabelEncoder()
y = encoder.fit_transform(pd_y_train.values)
y_train = torch.tensor(y, dtype=torch.long, device=device).flatten()

# One-hot targets, built as (N, K) and then transposed to (K, N) like X_train.
y_one_hot = torch.zeros(N, K, device=device)
y_one_hot.scatter_(1, y_train.unsqueeze(1), 1)
y_one_hot = y_one_hot.t().to(device=device)

print(list(encoder.classes_))

['Negative', 'Positive']


## test data processing

In [5]:
# Load the test split. It is used as loaded: the undersampling steps that
# follow in this cell are commented out.
df_test = pd.read_csv('/home/c/cl237/Datasets/Flare_LSTM_dataset/M/normalized_testing.csv')

# Undersampling for "negative" samples
#negative_df = df_test[df_test['label'] == 'Negative']
#positive_df = df_test[df_test['label'] == 'Positive']

# Get number of "positive" samples
#num_positive = len(positive_df)

# Take the same number of "negative" samples as there are "positive" samples
#balanced_negative_df = negative_df.sample(n=num_positive, random_state=10086)

#df_test_balanced = pd.concat([positive_df, balanced_negative_df])
#print(df_test_balanced['label'].value_counts())

In [6]:
# Same positional split as for the training frame: column 0 is the target,
# columns from position 5 onward are the features.
pd_X_test = df_test.iloc[:, 5:]
pd_y_test = df_test.iloc[:, 0]

N_test = len(pd_X_test)
K = 2

# Features are stored column-wise: after the transpose X_test is (features, N_test).
pd_X_test = pd_X_test.values
X_test = torch.tensor(pd_X_test, dtype=dtype, device=device)
X_test = torch.t(X_test)
# x = pd_X_test.values
# x_scaled = scaler.transform(x)  # only transform x, don't fit the scaler again
# X_test = torch.tensor(x_scaled, dtype=dtype, device=device)
# X_test = torch.t(X_test)


# Encode the test labels with the encoder that was fitted on the TRAINING
# labels. Re-fitting a fresh LabelEncoder here would derive the class -> id
# mapping from the test labels alone, which silently disagrees with training
# whenever the two label sets differ; transform() raises on unseen labels.
y = encoder.transform(pd_y_test.values)
y_test=torch.tensor(y, dtype=torch.long, device=device)
y_test = torch.flatten(y_test)

# One-hot targets, built as (N_test, K) and then transposed to (K, N_test).
y_test_one_hot = torch.zeros(N_test, K, device=device).scatter_(1, y_test.unsqueeze(1), 1)
y_test_one_hot = torch.t(y_test_one_hot).to(device=device)

In [7]:
# Show how many test samples carry each label.
print(df_test['label'].value_counts())

Negative    43411
Positive     1278
Name: label, dtype: int64


## Main algorithm

### Define functions for updating blocks

In [8]:
# def updateMask(W, sparsity):
#     torch.dist(V1,nn.ReLU()(U1),2),2).cpu().numpy()
#     Mask
#     Wsparse
#     return Mask, Wsparse

# Wsquare = torch.square(W3)
# Threshold = torch.quantile(torch.reshape(Wsquare, (-1,)), 0.5, interpolation='linear')
# Wsparse = W3
# Wsparse[Wsquare < Threshold] =  0

# Wsparse

In [9]:
def updateV(U1,U2,W,b,rho,gamma):
    """Closed-form update of an activation block V.

    Solves the linear system

        (rho * W^T W + gamma * I) V = rho * W^T (U2 - b) + gamma * ReLU(U1)

    Args:
        U1: tensor of shape (d, N); ReLU is applied to it inside this function.
        U2: tensor of shape (m, N), the pre-activation of the following layer.
        W:  weight matrix of shape (m, d) of the following layer.
        b:  bias of shape (m, 1); broadcast across the N columns of U2.
        rho, gamma: scalar penalty weights.

    Returns:
        Tensor of shape (d, N) with the dtype and device of ``W``.
    """
    d = W.size(1)
    # Take dtype/device from W so the function does not rely on a module-level
    # `device` and does not mix a float32 identity into float64 arithmetic.
    I = torch.eye(d, dtype=W.dtype, device=W.device)
    lhs = rho*torch.mm(torch.t(W), W) + gamma*I
    rhs = rho*torch.mm(torch.t(W), U2 - b) + gamma*torch.relu(U1)
    # Solve the system directly rather than forming the explicit inverse.
    Vstar = torch.linalg.solve(lhs, rhs)
    return Vstar

In [10]:
def updateWb_org(U, V, W, b, alpha, rho):
    """Closed-form update of a dense layer's weights (no sparsity coupling).

    Computes

        W* = (alpha * W + rho * (U - b) V^T) (alpha * I + rho * V V^T)^{-1}

    Args:
        U: pre-activation block of shape (m, N).
        V: input activation block of shape (d, N).
        W: previous weight iterate of shape (m, d).
        b: bias of shape (m, 1); broadcast across the N columns of U.
        alpha, rho: scalar proximal / penalty weights.

    Returns:
        (Wstar, bstar). The leading ``0*`` in the bias expression forces the
        returned bias to zero (for finite inputs); drop that factor to use the
        closed-form bias update.
    """
    d,N = V.size()
    # Take dtype/device from W so the function does not rely on a module-level
    # `device` and does not mix a float32 identity into float64 arithmetic.
    I = torch.eye(d, dtype=W.dtype, device=W.device)
    rhs = alpha*W + rho*torch.mm(U - b, torch.t(V))
    gram = alpha*I + rho*torch.mm(V, torch.t(V))
    # X @ gram = rhs  <=>  gram @ X^T = rhs^T because gram is symmetric; solving
    # the system avoids forming the explicit inverse.
    Wstar = torch.t(torch.linalg.solve(gram, torch.t(rhs))).contiguous()
    bstar = 0*(alpha*b+rho*torch.sum(U-torch.mm(Wstar,V), dim=1).reshape(b.size()))/(rho*N+alpha)
    return Wstar, bstar

In [11]:
def updateWb(U, V, W, b, W_tensor_rec, alpha, rho,tau):
    """Closed-form weight update coupled to a reconstructed tensor.

    Computes

        W* = (alpha * W + tau * R + rho * (U - b) V^T)
             ((alpha + tau) * I + rho * V V^T)^{-1}

    where R is ``W_tensor_rec`` reshaped to the shape of ``W``.

    Args:
        U: pre-activation block of shape (m, N).
        V: input activation block of shape (d, N).
        W: previous weight iterate of shape (m, d).
        b: bias of shape (m, 1); broadcast across the N columns of U.
        W_tensor_rec: array-like or tensor with m*d elements.
        alpha, rho, tau: scalar proximal / penalty weights.

    Returns:
        (Wstar, bstar). The leading ``0*`` in the bias expression forces the
        returned bias to zero (for finite inputs).
    """
    # Convert to W's own dtype/device. The previous `.float()` silently rounded
    # the reconstruction to float32 even when W is float64.
    W_tensor_rec = torch.as_tensor(W_tensor_rec, dtype=W.dtype, device=W.device)
    W_tensor2matrix = W_tensor_rec.reshape(W.shape)
    d,N = V.size()
    I = torch.eye(d, dtype=W.dtype, device=W.device)
    rhs = alpha*W + tau*W_tensor2matrix + rho*torch.mm(U - b, torch.t(V))
    gram = (alpha+tau)*I + rho*torch.mm(V, torch.t(V))
    # X @ gram = rhs  <=>  gram @ X^T = rhs^T because gram is symmetric; solving
    # the system avoids forming the explicit inverse.
    Wstar = torch.t(torch.linalg.solve(gram, torch.t(rhs))).contiguous()
    bstar = 0*(alpha*b+rho*torch.sum(U-torch.mm(Wstar,V), dim=1).reshape(b.size()))/(rho*N+alpha)
    return Wstar, bstar

In [12]:
def updateWbsparse(U, V, W, b,  W_tensor2matrix, alpha, rho,tau):
    """Closed-form weight update coupled to a sparse companion matrix.

    Computes

        W* = (alpha * W + tau * S + rho * (U - b) V^T)
             ((alpha + tau) * I + rho * V V^T)^{-1}

    where S is ``W_tensor2matrix``.

    Args:
        U: pre-activation block of shape (m, N).
        V: input activation block of shape (d, N).
        W: previous weight iterate of shape (m, d).
        b: bias of shape (m, 1); broadcast across the N columns of U.
        W_tensor2matrix: matrix of shape (m, d) that W is pulled towards.
        alpha, rho, tau: scalar proximal / penalty weights.

    Returns:
        (Wstar, bstar). The leading ``0*`` in the bias expression forces the
        returned bias to zero (for finite inputs).
    """
    d,N = V.size()
    # Take dtype/device from W so the function does not rely on a module-level
    # `device` and does not mix a float32 identity into float64 arithmetic.
    I = torch.eye(d, dtype=W.dtype, device=W.device)
    rhs = alpha*W + tau*W_tensor2matrix + rho*torch.mm(U - b, torch.t(V))
    gram = (alpha+tau)*I + rho*torch.mm(V, torch.t(V))
    # X @ gram = rhs  <=>  gram @ X^T = rhs^T because gram is symmetric; solving
    # the system avoids forming the explicit inverse.
    Wstar = torch.t(torch.linalg.solve(gram, torch.t(rhs))).contiguous()
    bstar = 0*(alpha*b+rho*torch.sum(U-torch.mm(Wstar,V), dim=1).reshape(b.size()))/(rho*N+alpha)
    return Wstar, bstar

### Define the proximal operator of the ReLU activation function

In [13]:
def relu_prox(a, b, gamma, d, N):
    """Element-wise piecewise update used for the ReLU pre-activation blocks.

    For every entry the result is one of four candidates, chosen by the sign
    tests below: ``min(b, 0)``, ``0``, ``(a + gamma*b) / (1 + gamma)`` or ``b``.
    Later ``torch.where`` calls override earlier ones.
    NOTE(review): presumably the minimiser of
    ``(a - relu(u))**2 / 2 + gamma * (u - b)**2 / 2`` over u; confirm against
    the derivation this notebook follows.

    Args:
        a, b: tensors of identical shape.
        gamma: positive Python/NumPy scalar.
        d, N: kept for backward compatibility and no longer used; the working
            shape, dtype and device are taken from ``b``.

    Returns:
        Tensor with the shape of ``b``.
    """
    # zeros_like keeps dtype/device consistent with the inputs; this replaces
    # float32 zero tensors placed on a module-level `device` and removes a
    # torch.empty buffer that was allocated but never read.
    zeros = torch.zeros_like(b)
    s = a + gamma*b
    # Threshold a*(gamma - sqrt(gamma*(gamma+1))) shared by the last two tests.
    kink = a*(gamma - np.sqrt(gamma*(gamma + 1)))
    x = s/(1 + gamma)
    y = torch.min(b, zeros)

    val = torch.where(s < 0, y, zeros)
    val = torch.where(((s >= 0) & (b >= 0)) | ((kink <= gamma*b) & (b < 0)), x, val)
    val = torch.where((-a <= gamma*b) & (gamma*b <= kink), b, val)
    return val

### Effective Sparsity

In [14]:
def process_weights(W1, W2, W3):
    """Propagate structurally dead hidden units through the three weight matrices.

    Works on copies, so the arguments are left untouched. The three rules
    below are applied repeatedly until a full pass changes nothing:

    * an all-zero row of W1  -> the matching column of W2 is zeroed;
    * an all-zero column of W2 -> the matching row of W1 is zeroed;
    * an all-zero row of W2  -> the matching column of W3 is zeroed.

    Returns:
        The processed copies ``(W1, W2, W3)``.
    """
    pruned_1, pruned_2, pruned_3 = (w.clone() for w in (W1, W2, W3))

    changed = True
    while changed:
        # Snapshot taken before the pass, used to detect convergence.
        before = (pruned_1.clone(), pruned_2.clone(), pruned_3.clone())

        # First-hidden-layer units with no incoming weights.
        no_input_1 = (pruned_1 == 0).all(dim=1)
        pruned_2[:, no_input_1] = 0

        # First-hidden-layer units with no outgoing weights.
        no_output_1 = (pruned_2 == 0).all(dim=0)
        pruned_1[no_output_1, :] = 0

        # Second-hidden-layer units with no incoming weights.
        no_input_2 = (pruned_2 == 0).all(dim=1)
        pruned_3[:, no_input_2] = 0

        after = (pruned_1, pruned_2, pruned_3)
        changed = not all(torch.equal(now, prev) for now, prev in zip(after, before))

    return pruned_1, pruned_2, pruned_3

### Parameter initialization

In [24]:
#df = pd.DataFrame()
#df.to_csv('/home/c/cl237/TenBCD/Sparse/LeNet300_100/Flare(BCD method)/differernt gamma rho/result.csv')

# --- run configuration -------------------------------------------------------
niter = 1000      # number of outer iterations of the training loop
sparsity = 0.9    # quantile passed to torch.quantile when thresholding W1 / W2

# Penalty weights consumed by the block updates and the loss bookkeeping.
tau, alpha, rho, gamma = 200, 1, 1, 100

# Echo the configuration so it is recorded in the cell output.
print(sparsity, tau, gamma, rho, alpha)

# One independent, uninitialised buffer of length niter per tracked metric.
(loss1, loss2, accuracy_train, accuracy_test, time1,
 bacc_train, bacc_test, true_sparsity, effective_sparsity) = (np.empty(niter) for _ in range(9))

# results[run, metric, iteration]; a single run and nine metrics.
results = torch.zeros(1, 9, niter)

def _hard_threshold_(W, sparsity):
    """Zero, IN PLACE, every entry of ``W`` whose square is below the ``sparsity`` quantile.

    Returns ``W`` itself (not a copy), so the caller's dense and "sparse" names
    refer to the same tensor.
    """
    W_sq = torch.square(W)
    threshold = torch.quantile(torch.reshape(W_sq, (-1,)), sparsity, interpolation='linear')
    W[W_sq < threshold] = 0
    return W


def _sq_dist(A, B):
    """Squared 2-norm distance between two tensors, as a Python float."""
    return torch.pow(torch.dist(A, B, 2), 2).item()


for Out_iter in range(1):
    rank_initial = 700  # not used in this cell; kept in case later cells read it
    seed = 10 + 10*Out_iter
    # torch.manual_seed seeds the generators of all devices, so the former
    # second call under `if torch.cuda.is_available()` was redundant.
    torch.manual_seed(seed)

    d0 = 40
    d1 =  300
    d2 =  100
    d3 = K # Layers: input + 2 hidden + output

    # NOTE(review): the original code did `W1sparse = W1` followed by in-place
    # masking, so W1 and W1sparse are the SAME tensor (likewise W2 / W2sparse).
    # That aliasing is preserved here to keep results reproducible, but it makes
    # the tau/2*||W - Wsparse||^2 terms below identically zero and feeds already
    # thresholded weights into every block update. If separate dense and sparse
    # iterates are intended, threshold a `.clone()` instead and count zeros on
    # the *sparse* tensors.
    W1 = 0.25*init.kaiming_normal_(torch.empty(d1, d0, device=device, dtype=dtype),a=0, mode='fan_in', nonlinearity='leaky_relu')
    b1 = 0*torch.ones(d1, 1, dtype=torch.float64,device=device)
    W1sparse = _hard_threshold_(W1, sparsity)

    W2 = 0.25*init.kaiming_normal_(torch.empty(d2, d1, device=device, dtype=dtype),a=0, mode='fan_in', nonlinearity='leaky_relu')
    b2 = 0*torch.ones(d2, 1, dtype=torch.float64,device=device)
    W2sparse = _hard_threshold_(W2, sparsity)

    # The output layer stays dense (no thresholding).
    W3 = 0.25*init.kaiming_normal_(torch.empty(d3, d2, device=device, dtype=dtype),a=0, mode='fan_in', nonlinearity='leaky_relu')
    b3 = 0*torch.ones(d3, 1, dtype=torch.float64,device=device)

    # Initialise the auxiliary blocks with one forward pass.
    U1 = torch.addmm(b1.repeat(1, N), W1, X_train)
    V1 = nn.ReLU()(U1)
    U2 = torch.addmm(b2.repeat(1, N), W2, V1)
    V2 = nn.ReLU()(U2)
    U3 = torch.addmm(b3.repeat(1, N), W3, V2)
    V3 = U3  # identity activation on the output layer

    print('Train on', N, 'samples, validate on', N_test, 'samples')

    for k in range(niter):

        start = time.time()

        # ---- backward sweep over the blocks, output layer first ----
        # update V3
        V3 = (y_one_hot + gamma*U3 + alpha*V3)/(1 + gamma + alpha)

        # update U3
        U3 = (gamma*V3 + rho*(torch.mm(W3,V2) + b3.repeat(1,N)))/(gamma + rho)

        # update W3 and b3 (dense update, no sparsity coupling)
        W3, b3 = updateWb_org(U3,V2,W3,b3, alpha,rho)

        # update V2
        V2 = updateV(U2,U3,W3,b3,rho,gamma)

        # update U2
        U2 = relu_prox(V2,(rho*torch.addmm(b2.repeat(1,N), W2, V1) + alpha*U2)/(rho + alpha),(rho + alpha)/gamma,d2,N)

        # update W2 and b2, then re-threshold (in place, see NOTE above)
        W2, b2 = updateWbsparse(U2,V1,W2,b2,W2sparse, alpha,rho,tau)
        W2sparse = _hard_threshold_(W2, sparsity)

        # update V1
        V1 = updateV(U1,U2,W2,b2,rho,gamma)

        # update U1
        U1 = relu_prox(V1,(rho*torch.addmm(b1.repeat(1,N), W1, X_train) + alpha*U1)/(rho + alpha),(rho + alpha)/gamma,d1,N)

        # update W1 and b1, then re-threshold (in place, see NOTE above)
        W1, b1 = updateWbsparse(U1,X_train,W1,b1,W1sparse, alpha,rho,tau)
        W1sparse = _hard_threshold_(W1, sparsity)

        # ---- evaluation: plain forward pass with the thresholded weights ----
        a1_train = nn.ReLU()(torch.addmm(b1.repeat(1, N), W1sparse, X_train)).double()
        a2_train = nn.ReLU()(torch.addmm(b2.repeat(1, N),  W2sparse, a1_train)).double()
        pred = torch.argmax(torch.addmm(b3.repeat(1, N), W3, a2_train), dim=0)

        a1_test = nn.ReLU()(torch.addmm(b1.repeat(1, N_test), W1sparse, X_test)).double()
        a2_test = nn.ReLU()(torch.addmm(b2.repeat(1, N_test),  W2sparse, a1_test)).double()
        pred_test = torch.argmax(torch.addmm(b3.repeat(1, N_test), W3, a2_test), dim=0)

        # ---- objective bookkeeping ----
        loss1[k] = gamma/2*_sq_dist(V3, y_one_hot)
        # Parenthesised sum; the original relied on a backslash continuation
        # that ran into a comment line.
        loss2[k] = (
            loss1[k]
            + rho/2*_sq_dist(torch.addmm(b1.repeat(1,N), W1sparse, X_train), U1)
            + rho/2*_sq_dist(torch.addmm(b2.repeat(1,N),  W2sparse, V1), U2)
            + rho/2*_sq_dist(torch.addmm(b3.repeat(1,N), W3, V2), U3)
            + gamma/2*_sq_dist(V1, nn.ReLU()(U1))
            + gamma/2*_sq_dist(V2, nn.ReLU()(U2))
            + gamma/2*_sq_dist(V3, U3)
            + tau/2*_sq_dist(W1, W1sparse)
            + tau/2*_sq_dist(W2, W2sparse)
        )

        # compute training accuracy
        correct_train = pred == y_train
        accuracy_train[k] = np.mean(correct_train.cpu().numpy())

        # compute validation accuracy
        correct_test = pred_test == y_test
        accuracy_test[k] = np.mean(correct_test.cpu().numpy())

        pred_train_np = pred.cpu().numpy()
        y_train_np = y_train.cpu().numpy()
        pred_test_np = pred_test.cpu().numpy()
        y_test_np = y_test.cpu().numpy()

        bacc_train[k] = balanced_accuracy_score(y_train_np, pred_train_np)
        bacc_test[k] = balanced_accuracy_score(y_test_np, pred_test_np)

        # Wall-clock time of the iteration; note it includes the evaluation above.
        stop = time.time()
        duration = stop - start
        time1[k] = duration

        # ---- sparsity ----
        total_weights = d0*d1+d1*d2+d2*d3

        # Fraction of exactly-zero weights in W1, W2, W3.
        total_zeros_old = sum(torch.sum(W == 0).item() for W in (W1, W2, W3))
        true_sparsity[k] = total_zeros_old / total_weights

        # Same fraction after also zeroing weights attached to dead hidden units.
        new_W1, new_W2, new_W3 = process_weights(W1, W2, W3)
        total_zeros = sum(torch.sum(W == 0).item() for W in (new_W1, new_W2, new_W3))
        effective_sparsity[k] = total_zeros / total_weights

        # print results
        print('Epoch', k + 1, '/', niter, '\n',
              '-', 'time:', time1[k], '-', 'sq_loss:', loss1[k], '-', 'tot_loss:', loss2[k],
              '-', 'acc:', accuracy_train[k], '-', 'val_acc:', accuracy_test[k],
              '-', 'bacc_train:', bacc_train[k], '-', 'bacc_test:', bacc_test[k],'-', 'true_sparsity:', true_sparsity[k],
             '-', 'effective_sparsity:', effective_sparsity[k])

    # Store the per-iteration series of this run; the row order defines the
    # metric index inside `results`.
    for row, series in enumerate((loss1, loss2, accuracy_train, accuracy_test, time1,
                                  bacc_train, bacc_test, true_sparsity, effective_sparsity)):
        results[Out_iter, row, :] = torch.tensor(series)

#             df=pd.read_csv('/home/c/cl237/TenBCD/Sparse/LeNet300_100/Flare(BCD method)/differernt gamma rho/result.csv')
#             new_row = {
#                         'tau': tau, 
#                         'gamma': gamma, 
#                         'rho': rho, 
#                         'alpha': alpha,
#                         'loss1': loss1[niter-1], 
#                         'loss2': loss2[niter-1], 
#                         'accuracy_train': accuracy_train[niter-1],
#                         'accuracy_test': accuracy_test[niter-1], 
#                         'max_accuracy_train': max(accuracy_train),
#                         'max_accuracy_test': max(accuracy_test),
#                         'time': time1[niter-1], 
#                         'BACC_train': bacc_train[niter-1],
#                         'BACC_test': bacc_test[niter-1],
#                         'max_BACC_train': max(bacc_train),
#                         'max_BACC_test': max(bacc_test),
#                         'Sparsity': sparsity,
#                         'seed' : seed
#                     }

#             df=df.append(new_row,ignore_index=True)
#             df.to_csv('/home/c/cl237/TenBCD/Sparse/LeNet300_100/Flare(BCD method)/differernt gamma rho/result.csv',index=False)

filename="Kaimingnormal_" + "niter_"+ str(niter) + "Sparsity_" + str(sparsity) +"tau_" + str(tau) + "gamma_" + str(gamma) + \
                "rho_" + str(rho) + "alpha_" + str(alpha)+ ".mat"
from scipy.io import savemat
%cd /home/c/cl237/TenBCD/Sparse/LeNet300_100/Flare(BCD method)/different init/
savemat (filename, {'results': torch.Tensor.numpy(results)})

0.9 200 100 1 1
Train on 8114 samples, validate on 44689 samples
Epoch 1 / 1000 
 - time: 0.029567241668701172 - sq_loss: 399452.0118585444 - tot_loss: 399452.1151005697 - acc: 0.5391915208281982 - val_acc: 0.8883617892546264 - bacc_train: 0.5391915208281982 - bacc_test: 0.5483898144420029 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 2 / 1000 
 - time: 0.028216123580932617 - sq_loss: 391734.0451209586 - tot_loss: 391734.4277132784 - acc: 0.6078383041656397 - val_acc: 0.9028172480923717 - bacc_train: 0.6078383041656397 - bacc_test: 0.5930427368008419 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 3 / 1000 
 - time: 0.02836322784423828 - sq_loss: 384187.2621532559 - tot_loss: 384188.0817737659 - acc: 0.6747596746364308 - val_acc: 0.8995502248875562 - bacc_train: 0.6747596746364308 - bacc_test: 0.6452811607538083 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 4 / 1000 
 - time: 0.

Epoch 31 / 1000 
 - time: 0.02673506736755371 - sq_loss: 223556.6673686985 - tot_loss: 223557.54073656417 - acc: 0.8595020951441952 - val_acc: 0.8861017252567746 - bacc_train: 0.8595020951441952 - bacc_test: 0.8430272174872995 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 32 / 1000 
 - time: 0.02673196792602539 - sq_loss: 219246.85629199425 - tot_loss: 219247.71827458558 - acc: 0.8592556075918166 - val_acc: 0.8860569715142429 - bacc_train: 0.8592556075918165 - bacc_test: 0.8433839003398351 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 33 / 1000 
 - time: 0.02682948112487793 - sq_loss: 215019.75933286813 - tot_loss: 215020.61194950185 - acc: 0.859378851368006 - val_acc: 0.8859674640291795 - bacc_train: 0.8593788513680058 - bacc_test: 0.8444769845335711 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 34 / 1000 
 - time: 0.026870250701904297 - sq_loss: 210873.8248725808 - tot_loss:

Epoch 63 / 1000 
 - time: 0.0269162654876709 - sq_loss: 119858.34761160554 - tot_loss: 119859.29532814545 - acc: 0.8646783337441459 - val_acc: 0.8863478708406991 - bacc_train: 0.8646783337441459 - bacc_test: 0.8659370228058926 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 64 / 1000 
 - time: 0.02673196792602539 - sq_loss: 117545.42779183548 - tot_loss: 117546.37995646511 - acc: 0.8648015775203353 - val_acc: 0.8863031170981673 - bacc_train: 0.8648015775203353 - bacc_test: 0.8670531426357577 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 65 / 1000 
 - time: 0.026798009872436523 - sq_loss: 115277.12579179538 - tot_loss: 115278.0823117609 - acc: 0.8648015775203353 - val_acc: 0.8862359864843697 - bacc_train: 0.8648015775203353 - bacc_test: 0.867018589181564 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 66 / 1000 
 - time: 0.026707172393798828 - sq_loss: 113052.5819115699 - tot_loss:

Epoch 95 / 1000 
 - time: 0.026767253875732422 - sq_loss: 64246.82870083919 - tot_loss: 64247.881453217844 - acc: 0.8644318461917673 - val_acc: 0.8866611470384211 - bacc_train: 0.8644318461917673 - bacc_test: 0.8714143311001024 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 96 / 1000 
 - time: 0.027025461196899414 - sq_loss: 63006.95812374341 - tot_loss: 63008.01330276155 - acc: 0.864308602415578 - val_acc: 0.8866163932958894 - bacc_train: 0.8643086024155779 - bacc_test: 0.8713912954639732 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 97 / 1000 
 - time: 0.02695012092590332 - sq_loss: 61791.016135717175 - tot_loss: 61792.073747438844 - acc: 0.8645550899679566 - val_acc: 0.8866387701671552 - bacc_train: 0.8645550899679566 - bacc_test: 0.8714028132820377 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 98 / 1000 
 - time: 0.026760578155517578 - sq_loss: 60598.54120009312 - tot_loss:

Epoch 127 / 1000 
 - time: 0.026737689971923828 - sq_loss: 34438.0264551282 - tot_loss: 34439.5157079375 - acc: 0.8650480650727138 - val_acc: 0.886840162008548 - bacc_train: 0.8650480650727138 - bacc_test: 0.8749239400426012 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 128 / 1000 
 - time: 0.026764631271362305 - sq_loss: 33773.47740726781 - tot_loss: 33775.04959917994 - acc: 0.8650480650727138 - val_acc: 0.8867730313947504 - bacc_train: 0.8650480650727138 - bacc_test: 0.8748893865884075 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 129 / 1000 
 - time: 0.02679157257080078 - sq_loss: 33121.754811298786 - tot_loss: 33123.427855662514 - acc: 0.8651713088489031 - val_acc: 0.8868177851372822 - bacc_train: 0.865171308848903 - bacc_test: 0.8749124222245366 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 130 / 1000 
 - time: 0.026979684829711914 - sq_loss: 32482.611220157713 - tot_loss

Epoch 159 / 1000 
 - time: 0.026841402053833008 - sq_loss: 18461.247158210008 - tot_loss: 18470.064698621063 - acc: 0.8652945526250925 - val_acc: 0.8879590055718409 - bacc_train: 0.8652945526250924 - bacc_test: 0.8789172973438109 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 160 / 1000 
 - time: 0.026898622512817383 - sq_loss: 18105.06664216469 - tot_loss: 18113.553978924807 - acc: 0.8654177964012817 - val_acc: 0.8878694980867775 - bacc_train: 0.8654177964012817 - bacc_test: 0.8792509445602175 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 161 / 1000 
 - time: 0.026715993881225586 - sq_loss: 17755.76051366337 - tot_loss: 17763.984335195964 - acc: 0.8652945526250925 - val_acc: 0.8877799906017141 - bacc_train: 0.8652945526250924 - bacc_test: 0.8792048732879593 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 162 / 1000 
 - time: 0.026759862899780273 - sq_loss: 17413.196152373057 - t

Epoch 191 / 1000 
 - time: 0.026831388473510742 - sq_loss: 9897.985941537212 - tot_loss: 9907.522664939333 - acc: 0.8645550899679566 - val_acc: 0.8873324531763969 - bacc_train: 0.8645550899679566 - bacc_test: 0.882012264835986 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 192 / 1000 
 - time: 0.02673506736755371 - sq_loss: 9707.073871353037 - tot_loss: 9717.519487024605 - acc: 0.8646783337441459 - val_acc: 0.8872205688200676 - bacc_train: 0.8646783337441459 - bacc_test: 0.8819546757456633 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 193 / 1000 
 - time: 0.026723146438598633 - sq_loss: 9519.845977000883 - tot_loss: 9553.172016477172 - acc: 0.8641853586393887 - val_acc: 0.8871758150775358 - bacc_train: 0.8641853586393886 - bacc_test: 0.8815519216208696 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 194 / 1000 
 - time: 0.026761293411254883 - sq_loss: 9336.231323248827 - tot_loss

Epoch 222 / 1000 
 - time: 0.030335187911987305 - sq_loss: 5412.262827328432 - tot_loss: 5419.023000938687 - acc: 0.8641853586393887 - val_acc: 0.8870863075924724 - bacc_train: 0.8641853586393887 - bacc_test: 0.8841638797692644 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 223 / 1000 
 - time: 0.028178930282592773 - sq_loss: 5307.910224815365 - tot_loss: 5314.587020516791 - acc: 0.864308602415578 - val_acc: 0.8871758150775358 - bacc_train: 0.8643086024155779 - bacc_test: 0.8842099510415226 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 224 / 1000 
 - time: 0.028078556060791016 - sq_loss: 5205.571001985343 - tot_loss: 5212.169497244133 - acc: 0.864308602415578 - val_acc: 0.8871310613350042 - bacc_train: 0.8643086024155779 - bacc_test: 0.8841869154053934 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 225 / 1000 
 - time: 0.02816319465637207 - sq_loss: 5105.206349401141 - tot_loss:

Epoch 252 / 1000 
 - time: 0.02681565284729004 - sq_loss: 3018.3833647951938 - tot_loss: 3025.544591894997 - acc: 0.8646783337441459 - val_acc: 0.8871981919488017 - bacc_train: 0.8646783337441459 - bacc_test: 0.8838417503709224 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 253 / 1000 
 - time: 0.026754140853881836 - sq_loss: 2960.2130995974962 - tot_loss: 2967.589545617326 - acc: 0.8646783337441459 - val_acc: 0.8871758150775358 - bacc_train: 0.8646783337441459 - bacc_test: 0.8838302325528579 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 254 / 1000 
 - time: 0.026673555374145508 - sq_loss: 2903.1647842805287 - tot_loss: 2910.7993194907926 - acc: 0.8648015775203353 - val_acc: 0.8872429456913334 - bacc_train: 0.8648015775203353 - bacc_test: 0.8838647860070515 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 255 / 1000 
 - time: 0.02670454978942871 - sq_loss: 2847.216827486584 - tot_

Epoch 284 / 1000 
 - time: 0.026742219924926758 - sq_loss: 1619.6060777510713 - tot_loss: 1637.3971972981785 - acc: 0.8644318461917673 - val_acc: 0.8871310613350042 - bacc_train: 0.8644318461917673 - bacc_test: 0.8845666338940582 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 285 / 1000 
 - time: 0.02672433853149414 - sq_loss: 1588.4131597203764 - tot_loss: 1607.490303540938 - acc: 0.8644318461917673 - val_acc: 0.8872205688200676 - bacc_train: 0.8644318461917673 - bacc_test: 0.8846127051663164 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 286 / 1000 
 - time: 0.026674270629882812 - sq_loss: 1557.8217754773373 - tot_loss: 1578.4308110158877 - acc: 0.8645550899679566 - val_acc: 0.8872653225625993 - bacc_train: 0.8645550899679566 - bacc_test: 0.8846357408024454 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 287 / 1000 
 - time: 0.026739120483398438 - sq_loss: 1527.8202835529844 - t

Epoch 316 / 1000 
 - time: 0.026754379272460938 - sq_loss: 869.5004756806986 - tot_loss: 947.3936772383762 - acc: 0.8651713088489031 - val_acc: 0.8880261361856385 - bacc_train: 0.865171308848903 - bacc_test: 0.8835084726619812 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 317 / 1000 
 - time: 0.026746034622192383 - sq_loss: 852.7711698276913 - tot_loss: 930.2330532174944 - acc: 0.8650480650727138 - val_acc: 0.8879366287005751 - bacc_train: 0.8650480650727138 - bacc_test: 0.883462401389723 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 318 / 1000 
 - time: 0.0266873836517334 - sq_loss: 836.364293980135 - tot_loss: 914.0845262661949 - acc: 0.8649248212965245 - val_acc: 0.8878918749580433 - bacc_train: 0.8649248212965246 - bacc_test: 0.8834393657535939 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 319 / 1000 
 - time: 0.027049541473388672 - sq_loss: 820.273658533423 - tot_loss: 89

Epoch 348 / 1000 
 - time: 0.0267179012298584 - sq_loss: 467.1300443100166 - tot_loss: 641.0388534896189 - acc: 0.8638156273108208 - val_acc: 0.8856541878314574 - bacc_train: 0.8638156273108208 - bacc_test: 0.8830470209244687 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 349 / 1000 
 - time: 0.026746749877929688 - sq_loss: 458.15407230150026 - tot_loss: 627.3417663730014 - acc: 0.8639388710870101 - val_acc: 0.8854527959900647 - bacc_train: 0.8639388710870101 - bacc_test: 0.8829433605618879 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 350 / 1000 
 - time: 0.02673625946044922 - sq_loss: 449.350993832269 - tot_loss: 616.994387161604 - acc: 0.8639388710870101 - val_acc: 0.885408042247533 - bacc_train: 0.8639388710870101 - bacc_test: 0.8833000434144234 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 351 / 1000 
 - time: 0.026782512664794922 - sq_loss: 440.717456782388 - tot_loss: 60

Epoch 380 / 1000 
 - time: 0.026752710342407227 - sq_loss: 251.19130972312385 - tot_loss: 512.2213929474065 - acc: 0.8639388710870101 - val_acc: 0.8854975497325964 - bacc_train: 0.8639388710870102 - bacc_test: 0.882966396198017 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 381 / 1000 
 - time: 0.026746034622192383 - sq_loss: 246.37259836803227 - tot_loss: 521.6249236224747 - acc: 0.8639388710870101 - val_acc: 0.8854975497325964 - bacc_train: 0.8639388710870102 - bacc_test: 0.882966396198017 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 382 / 1000 
 - time: 0.026751995086669922 - sq_loss: 241.64659030930707 - tot_loss: 531.0731156013987 - acc: 0.8639388710870101 - val_acc: 0.885564680346394 - bacc_train: 0.8639388710870102 - bacc_test: 0.8833806681408752 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 383 / 1000 
 - time: 0.026868820190429688 - sq_loss: 237.0114774663354 - tot_lo

Epoch 412 / 1000 
 - time: 0.026749134063720703 - sq_loss: 135.23880378652055 - tot_loss: 429.86603250048586 - acc: 0.8666502341631748 - val_acc: 0.8873324531763969 - bacc_train: 0.8666502341631748 - bacc_test: 0.883531138790645 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 413 / 1000 
 - time: 0.026718616485595703 - sq_loss: 132.65066765097032 - tot_loss: 377.33781575147674 - acc: 0.8666502341631748 - val_acc: 0.8874443375327262 - bacc_train: 0.8666502341631748 - bacc_test: 0.8835887278809678 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 414 / 1000 
 - time: 0.026741981506347656 - sq_loss: 130.1123083065766 - tot_loss: 359.2957335478892 - acc: 0.8666502341631748 - val_acc: 0.8874219606614603 - bacc_train: 0.8666502341631748 - bacc_test: 0.8835772100629031 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 415 / 1000 
 - time: 0.02679896354675293 - sq_loss: 127.62273547499645 - tot

Epoch 444 / 1000 
 - time: 0.026767492294311523 - sq_loss: 72.94171761247996 - tot_loss: 301.2532392570668 - acc: 0.8670199654917427 - val_acc: 0.8852737810199378 - bacc_train: 0.8670199654917427 - bacc_test: 0.8843700919720303 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 445 / 1000 
 - time: 0.02670598030090332 - sq_loss: 71.55030555551272 - tot_loss: 301.6572325322468 - acc: 0.8670199654917427 - val_acc: 0.885251404148672 - bacc_train: 0.8670199654917427 - bacc_test: 0.8843585741539658 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 446 / 1000 
 - time: 0.026690244674682617 - sq_loss: 70.18562005600695 - tot_loss: 302.50562386492487 - acc: 0.867143209267932 - val_acc: 0.8851842735348744 - bacc_train: 0.8671432092679319 - bacc_test: 0.8843240206997722 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 447 / 1000 
 - time: 0.026840686798095703 - sq_loss: 68.84715035250028 - tot_loss

Epoch 476 / 1000 
 - time: 0.026740550994873047 - sq_loss: 39.44399579029347 - tot_loss: 240.2257626977844 - acc: 0.8676361843726892 - val_acc: 0.8852961578912036 - bacc_train: 0.8676361843726892 - bacc_test: 0.8847613282787596 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 477 / 1000 
 - time: 0.026710033416748047 - sq_loss: 38.695755550964186 - tot_loss: 238.7796687271361 - acc: 0.8676361843726892 - val_acc: 0.8853409116337354 - bacc_train: 0.8676361843726892 - bacc_test: 0.8847843639148887 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 478 / 1000 
 - time: 0.02680206298828125 - sq_loss: 37.961864473520805 - tot_loss: 237.6388121842966 - acc: 0.8677594281488785 - val_acc: 0.8853185347624695 - bacc_train: 0.8677594281488785 - bacc_test: 0.8843931276081594 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 479 / 1000 
 - time: 0.026720046997070312 - sq_loss: 37.24204573860388 - tot_l

Epoch 508 / 1000 
 - time: 0.026796579360961914 - sq_loss: 21.4157638073552 - tot_loss: 126.99584238114507 - acc: 0.869608084791718 - val_acc: 0.8924791335675446 - bacc_train: 0.869608084791718 - bacc_test: 0.8812438965928492 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 509 / 1000 
 - time: 0.02673196792602539 - sq_loss: 21.012458053771756 - tot_loss: 124.92839203401714 - acc: 0.869608084791718 - val_acc: 0.8925238873100763 - bacc_train: 0.869608084791718 - bacc_test: 0.8812669322289783 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 510 / 1000 
 - time: 0.026829242706298828 - sq_loss: 20.61685316449287 - tot_loss: 122.90963373332956 - acc: 0.869608084791718 - val_acc: 0.8925462641813422 - bacc_train: 0.869608084791718 - bacc_test: 0.8812784500470427 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 511 / 1000 
 - time: 0.026724576950073242 - sq_loss: 20.228802786331705 - tot_loss:

Epoch 540 / 1000 
 - time: 0.026751279830932617 - sq_loss: 11.689487401941093 - tot_loss: 76.57571878750952 - acc: 0.8692383534631501 - val_acc: 0.8920763498847591 - bacc_train: 0.86923835346315 - bacc_test: 0.8810365758676874 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 541 / 1000 
 - time: 0.027034521102905273 - sq_loss: 11.471658928795373 - tot_loss: 75.39560913976457 - acc: 0.8692383534631501 - val_acc: 0.8920763498847591 - bacc_train: 0.86923835346315 - bacc_test: 0.8810365758676874 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 542 / 1000 
 - time: 0.026721715927124023 - sq_loss: 11.257971554855981 - tot_loss: 74.24615805752588 - acc: 0.8692383534631501 - val_acc: 0.8920539730134932 - bacc_train: 0.86923835346315 - bacc_test: 0.8810250580496228 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 543 / 1000 
 - time: 0.026732683181762695 - sq_loss: 11.048350082751913 - tot_loss

Epoch 572 / 1000 
 - time: 0.026774883270263672 - sq_loss: 6.427690647017238 - tot_loss: 47.43070807757784 - acc: 0.8701010598964752 - val_acc: 0.8910022600639979 - bacc_train: 0.8701010598964751 - bacc_test: 0.8812431575779186 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 573 / 1000 
 - time: 0.02681255340576172 - sq_loss: 6.309593802152789 - tot_loss: 46.71243489648153 - acc: 0.8701010598964752 - val_acc: 0.890979883192732 - bacc_train: 0.8701010598964751 - bacc_test: 0.8812316397598541 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 574 / 1000 
 - time: 0.026721477508544922 - sq_loss: 6.19373386611669 - tot_loss: 46.0181419842649 - acc: 0.8701010598964752 - val_acc: 0.8909575063214661 - bacc_train: 0.8701010598964751 - bacc_test: 0.8812201219417894 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959478672985782
Epoch 575 / 1000 
 - time: 0.026723384857177734 - sq_loss: 6.080068197804852 - tot_loss: 

Epoch 604 / 1000 
 - time: 0.0267336368560791 - sq_loss: 3.572218972355411 - tot_loss: 30.739604087360373 - acc: 0.8701010598964752 - val_acc: 0.8900176777283 - bacc_train: 0.8701010598964751 - bacc_test: 0.8826349660264021 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 605 / 1000 
 - time: 0.02675342559814453 - sq_loss: 3.508042073387209 - tot_loss: 30.367992060443008 - acc: 0.8702243036726646 - val_acc: 0.8899505471145025 - bacc_train: 0.8702243036726645 - bacc_test: 0.8822206940835438 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 606 / 1000 
 - time: 0.026766061782836914 - sq_loss: 3.4450706641827895 - tot_loss: 29.98629238267715 - acc: 0.8702243036726646 - val_acc: 0.8899505471145025 - bacc_train: 0.8702243036726645 - bacc_test: 0.8822206940835438 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 607 / 1000 
 - time: 0.026760101318359375 - sq_loss: 3.3832850825744263 - tot_loss

Epoch 636 / 1000 
 - time: 0.02670884132385254 - sq_loss: 2.017876907746639 - tot_loss: 21.009485695728472 - acc: 0.8707172787774218 - val_acc: 0.8898386627581731 - bacc_train: 0.8707172787774218 - bacc_test: 0.8825428234818857 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 637 / 1000 
 - time: 0.026695966720581055 - sq_loss: 1.982852908881829 - tot_loss: 20.779299200221367 - acc: 0.8707172787774218 - val_acc: 0.8898386627581731 - bacc_train: 0.8707172787774218 - bacc_test: 0.8825428234818857 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 638 / 1000 
 - time: 0.026711702346801758 - sq_loss: 1.9484822492760383 - tot_loss: 20.55922061196733 - acc: 0.8707172787774218 - val_acc: 0.8898162858869073 - bacc_train: 0.8707172787774218 - bacc_test: 0.8825313056638212 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 639 / 1000 
 - time: 0.026706457138061523 - sq_loss: 1.9147532171671595 - tot

Epoch 668 / 1000 
 - time: 0.0267641544342041 - sq_loss: 1.167397539525556 - tot_loss: 15.7327153456357 - acc: 0.870840522553611 - val_acc: 0.8896596477880463 - bacc_train: 0.8708405225536111 - bacc_test: 0.8824506809373693 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 669 / 1000 
 - time: 0.026703357696533203 - sq_loss: 1.1481591029926559 - tot_loss: 15.622958120975143 - acc: 0.870840522553611 - val_acc: 0.8896372709167805 - bacc_train: 0.8708405225536111 - bacc_test: 0.8824391631193049 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 670 / 1000 
 - time: 0.026750564575195312 - sq_loss: 1.1292761600893437 - tot_loss: 15.515831372203252 - acc: 0.870840522553611 - val_acc: 0.8896820246593121 - bacc_train: 0.8708405225536111 - bacc_test: 0.8828419172440987 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 671 / 1000 
 - time: 0.026731252670288086 - sq_loss: 1.1107407791026043 - tot_los

Epoch 700 / 1000 
 - time: 0.026739120483398438 - sq_loss: 0.6985158850630515 - tot_loss: 17.28548891484504 - acc: 0.8710870101059897 - val_acc: 0.8895253865604511 - bacc_train: 0.8710870101059897 - bacc_test: 0.8827612925176469 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 701 / 1000 
 - time: 0.026702404022216797 - sq_loss: 0.6878525191389813 - tot_loss: 14.092229677512012 - acc: 0.8710870101059897 - val_acc: 0.889704401530578 - bacc_train: 0.8710870101059897 - bacc_test: 0.883233153550828 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 702 / 1000 
 - time: 0.026692867279052734 - sq_loss: 0.6773775833323934 - tot_loss: 13.357145836147321 - acc: 0.8713334976583682 - val_acc: 0.8896596477880463 - bacc_train: 0.8713334976583682 - bacc_test: 0.8828303994260341 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 703 / 1000 
 - time: 0.026726961135864258 - sq_loss: 0.6670893452412118 - to

Epoch 732 / 1000 
 - time: 0.027016878128051758 - sq_loss: 0.4364751931416489 - tot_loss: 12.028290054252357 - acc: 0.8719497165393147 - val_acc: 0.8859450871579135 - bacc_train: 0.8719497165393147 - bacc_test: 0.8835764710479725 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 733 / 1000 
 - time: 0.0267181396484375 - sq_loss: 0.4304699426909804 - tot_loss: 12.020508717407006 - acc: 0.8719497165393147 - val_acc: 0.8859227102866477 - bacc_train: 0.8719497165393147 - bacc_test: 0.8835649532299079 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 734 / 1000 
 - time: 0.026847362518310547 - sq_loss: 0.42457100774652196 - tot_loss: 12.013844587111954 - acc: 0.8719497165393147 - val_acc: 0.8859227102866477 - bacc_train: 0.8719497165393147 - bacc_test: 0.8835649532299079 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 735 / 1000 
 - time: 0.02673792839050293 - sq_loss: 0.41877571184089507 - 

Epoch 764 / 1000 
 - time: 0.026718616485595703 - sq_loss: 0.28878165646199083 - tot_loss: 12.901541834237982 - acc: 0.873305398077397 - val_acc: 0.8865492626820918 - bacc_train: 0.873305398077397 - bacc_test: 0.8838874521357153 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 765 / 1000 
 - time: 0.026708126068115234 - sq_loss: 0.28538326736255587 - tot_loss: 12.975191501490446 - acc: 0.8731821543012078 - val_acc: 0.8865716395533576 - bacc_train: 0.8731821543012077 - bacc_test: 0.8838989699537798 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 766 / 1000 
 - time: 0.027031421661376953 - sq_loss: 0.2820438951379556 - tot_loss: 13.043266923261553 - acc: 0.873305398077397 - val_acc: 0.8866387701671552 - bacc_train: 0.873305398077397 - bacc_test: 0.8839335234079735 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 767 / 1000 
 - time: 0.02704930305480957 - sq_loss: 0.278761869841346 - tot

Epoch 796 / 1000 
 - time: 0.026973724365234375 - sq_loss: 0.20432959976860676 - tot_loss: 16.570436552255384 - acc: 0.8730589105250185 - val_acc: 0.886683523909687 - bacc_train: 0.8730589105250185 - bacc_test: 0.8839565590441025 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 797 / 1000 
 - time: 0.026752948760986328 - sq_loss: 0.2023524956704118 - tot_loss: 16.735774244327768 - acc: 0.8731821543012078 - val_acc: 0.886683523909687 - bacc_train: 0.8731821543012077 - bacc_test: 0.8839565590441025 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 798 / 1000 
 - time: 0.026816606521606445 - sq_loss: 0.200407381119063 - tot_loss: 16.90474309520856 - acc: 0.8731821543012078 - val_acc: 0.8867282776522186 - bacc_train: 0.8731821543012077 - bacc_test: 0.8839795946802316 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 799 / 1000 
 - time: 0.02676248550415039 - sq_loss: 0.1984939075274227 - tot

Epoch 828 / 1000 
 - time: 0.026694774627685547 - sq_loss: 0.15450101874262698 - tot_loss: 28.64352449743585 - acc: 0.8735518856297757 - val_acc: 0.8859898409004453 - bacc_train: 0.8735518856297757 - bacc_test: 0.8843589436614311 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 829 / 1000 
 - time: 0.026729345321655273 - sq_loss: 0.15331228826993445 - tot_loss: 29.38088941278987 - acc: 0.8735518856297757 - val_acc: 0.8859227102866477 - bacc_train: 0.8735518856297757 - bacc_test: 0.8843243902072374 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 830 / 1000 
 - time: 0.026689529418945312 - sq_loss: 0.15214175346046513 - tot_loss: 30.1707252686064 - acc: 0.8735518856297757 - val_acc: 0.885877956544116 - bacc_train: 0.8735518856297757 - bacc_test: 0.8843013545711084 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 831 / 1000 
 - time: 0.02672123908996582 - sq_loss: 0.1509890959818027 - to

Epoch 860 / 1000 
 - time: 0.0267486572265625 - sq_loss: 0.12411186116839364 - tot_loss: 24.135108785330907 - acc: 0.8740448607345329 - val_acc: 0.8861464789993063 - bacc_train: 0.8740448607345329 - bacc_test: 0.8851990053652123 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 861 / 1000 
 - time: 0.026897430419921875 - sq_loss: 0.12337429054522209 - tot_loss: 22.67776210993953 - acc: 0.8740448607345329 - val_acc: 0.8861912327418381 - bacc_train: 0.8740448607345329 - bacc_test: 0.8852220410013414 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 862 / 1000 
 - time: 0.026642322540283203 - sq_loss: 0.12264725785931327 - tot_loss: 21.437983846475777 - acc: 0.8740448607345329 - val_acc: 0.8861912327418381 - bacc_train: 0.8740448607345329 - bacc_test: 0.8848423225126767 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 863 / 1000 
 - time: 0.026738882064819336 - sq_loss: 0.12193063828792329 

Epoch 892 / 1000 
 - time: 0.026824235916137695 - sq_loss: 0.10494026089383371 - tot_loss: 10.525479334414305 - acc: 0.8756470298249939 - val_acc: 0.8861241021280405 - bacc_train: 0.8756470298249939 - bacc_test: 0.886326643013142 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 893 / 1000 
 - time: 0.026761293411254883 - sq_loss: 0.10446408797596185 - tot_loss: 10.336479243713322 - acc: 0.8756470298249939 - val_acc: 0.8861241021280405 - bacc_train: 0.8756470298249939 - bacc_test: 0.886326643013142 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 894 / 1000 
 - time: 0.026722431182861328 - sq_loss: 0.10399406519490557 - tot_loss: 10.132220666075382 - acc: 0.8756470298249939 - val_acc: 0.8861688558705721 - bacc_train: 0.8756470298249939 - bacc_test: 0.8859699601606064 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 895 / 1000 
 - time: 0.026701927185058594 - sq_loss: 0.10353012164534066

Epoch 924 / 1000 
 - time: 0.026688814163208008 - sq_loss: 0.09232247714116157 - tot_loss: 6.178462308494213 - acc: 0.8767562238106975 - val_acc: 0.8877128599879165 - bacc_train: 0.8767562238106976 - bacc_test: 0.8852458156524011 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 925 / 1000 
 - time: 0.026730775833129883 - sq_loss: 0.09200274776567084 - tot_loss: 6.083493422980023 - acc: 0.8767562238106975 - val_acc: 0.8876457293741189 - bacc_train: 0.8767562238106976 - bacc_test: 0.885970699175537 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 926 / 1000 
 - time: 0.02689838409423828 - sq_loss: 0.09168680636713397 - tot_loss: 6.004042821911706 - acc: 0.8766329800345083 - val_acc: 0.8875338450177896 - bacc_train: 0.8766329800345083 - bacc_test: 0.8862928285738789 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 927 / 1000 
 - time: 0.02671527862548828 - sq_loss: 0.0913744944000968 - to

Epoch 956 / 1000 
 - time: 0.026693344116210938 - sq_loss: 0.08365144293334591 - tot_loss: 5.211224276234031 - acc: 0.8770027113630762 - val_acc: 0.8879813824431068 - bacc_train: 0.8770027113630762 - bacc_test: 0.8819665630711933 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 957 / 1000 
 - time: 0.02691650390625 - sq_loss: 0.08342456656450999 - tot_loss: 5.299381905153642 - acc: 0.8770027113630762 - val_acc: 0.8880037593143727 - bacc_train: 0.8770027113630762 - bacc_test: 0.8819780808892577 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 958 / 1000 
 - time: 0.026710033416748047 - sq_loss: 0.08319998818716975 - tot_loss: 5.355242509354136 - acc: 0.8770027113630762 - val_acc: 0.8879813824431068 - bacc_train: 0.8770027113630762 - bacc_test: 0.8819665630711933 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 959 / 1000 
 - time: 0.026737451553344727 - sq_loss: 0.0829775564455196 - tot

Epoch 988 / 1000 
 - time: 0.026881933212280273 - sq_loss: 0.07736881317140132 - tot_loss: 5.800689491198796 - acc: 0.8772491989154547 - val_acc: 0.8875114681465237 - bacc_train: 0.8772491989154547 - bacc_test: 0.8862813107558143 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 989 / 1000 
 - time: 0.026762008666992188 - sq_loss: 0.07720029389710614 - tot_loss: 5.623702923969514 - acc: 0.8771259551392655 - val_acc: 0.8876233525028531 - bacc_train: 0.8771259551392654 - bacc_test: 0.8863388998461371 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 990 / 1000 
 - time: 0.026848554611206055 - sq_loss: 0.07703318583013055 - tot_loss: 5.428400705353716 - acc: 0.8770027113630762 - val_acc: 0.8877128599879165 - bacc_train: 0.8770027113630762 - bacc_test: 0.8863849711183953 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8962322274881517
Epoch 991 / 1000 
 - time: 0.026758909225463867 - sq_loss: 0.07686764633136753 

### Visualization of training results

In [None]:
# Plot the curves recorded during training: loss on a log2 y-axis, then
# train/test accuracy on a shared axis so the two can be compared directly.
# Relies on niter, loss2, accuracy_train and accuracy_test from earlier cells.
# NOTE(review): the x-axis is labelled "epoch" on the assumption that niter is
# the epoch count printed by the training loop ("Epoch k / N") - confirm.
epochs = np.arange(1, niter + 1)  # 1-based, built once and shared by all curves

fig_loss, ax_loss = plt.subplots()
ax_loss.plot(epochs, loss2)
ax_loss.set_yscale('log', base=2)
ax_loss.set_xlabel('epoch')
ax_loss.set_ylabel('loss (log2 scale)')
ax_loss.set_title('training loss')

fig_acc, ax_acc = plt.subplots()
# Label each curve; without a legend the two lines are indistinguishable.
ax_acc.plot(epochs, accuracy_train, label='train')
ax_acc.plot(epochs, accuracy_test, label='test')
ax_acc.set_xlabel('epoch')
ax_acc.set_ylabel('accuracy')
ax_acc.set_title('accuracy')
ax_acc.legend()

# Render explicitly so the cell does not echo the repr of the last call.
plt.show()