## Import libraries

In [1]:
# !pip install tensorly

import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.decomposition import tucker
from tensorly.decomposition import tensor_train
from math import ceil
from tensorly import tt_to_tensor
from tensorly.decomposition import matrix_product_state
import torch.nn.init as init

In [2]:
"""
5 runs of 50 epochs, seed = 10, 20, 30, 40, 50;
validation accuracies: 0.9492, 0.9457, 0.9463, 0.9439, 0.9455
"""
#from __future__ import print_function, division
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
import os
import copy
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import balanced_accuracy_score

# Report the runtime environment so saved outputs can be tied to library versions.
for _label, _value in (
    ("PyTorch Version:", torch.__version__),
    ("Torchvision Version:", torchvision.__version__),
    ("GPU is available?", torch.cuda.is_available()),
):
    print(_label, _value)

# Global numeric precision and compute device used when building tensors below.
dtype = torch.float64
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

PyTorch Version: 1.13.1
Torchvision Version: 0.14.1
GPU is available? True


## Train data processing

In [3]:
# Load the training and validation CSV splits.
df_train_1 = pd.read_csv('/home/c/cl237/Datasets/Flare_LSTM_dataset/M/normalized_training.csv')
df_train_2 = pd.read_csv('/home/c/cl237/Datasets/Flare_LSTM_dataset/M/normalized_validation.csv')

# Merge train and validation into a single training frame; no separate
# validation set is used. `DataFrame.append` was deprecated in pandas 1.4 and
# removed in pandas 2.0, so `pd.concat` (its documented replacement) is used.
df_train = pd.concat([df_train_1, df_train_2]).reset_index(drop=True)

# Undersample the "Negative" class so both classes have the same size.
negative_df = df_train[df_train['label'] == 'Negative']
positive_df = df_train[df_train['label'] == 'Positive']

# Number of "Positive" samples = target size for the "Negative" class.
num_positive = len(positive_df)

# Draw exactly as many "Negative" rows as there are "Positive" rows.
# The fixed random_state makes the draw reproducible. `sample` raises
# ValueError if there are fewer negatives than positives.
balanced_negative_df = negative_df.sample(n=num_positive, random_state=10086)

df_train_balanced = pd.concat([positive_df, balanced_negative_df])
print(df_train_balanced['label'].value_counts())

Positive    4057
Negative    4057
Name: label, dtype: int64


In [4]:
# Column 0 is used as the class label; the feature columns start at index 5.
pd_y_train = df_train_balanced.iloc[:, 0]
pd_X_train = df_train_balanced.iloc[:, 5:].values

K = 2                      # number of classes
N = pd_X_train.shape[0]    # number of training samples

# Samples are stored column-wise: X_train has shape (n_features, N).
X_train = torch.tensor(pd_X_train, dtype=dtype, device=device).t()

# Optional min-max scaling (disabled):
# scaler = MinMaxScaler()
# x = pd_X_train.values
# x_scaled = scaler.fit_transform(x)
# X_train = torch.tensor(x_scaled, dtype=dtype, device=device)
# X_train = torch.t(X_train)

# Map the string labels to integer class ids.
encoder = LabelEncoder()
y = encoder.fit_transform(pd_y_train.values)
y_train = torch.tensor(y, dtype=torch.long, device=device).flatten()

# One-hot targets, built as (N, K) and then transposed to (K, N) so that
# they line up with the column-wise sample layout of X_train.
y_one_hot = torch.zeros(N, K, device=device)
y_one_hot.scatter_(1, y_train.unsqueeze(1), 1)
y_one_hot = y_one_hot.t().to(device=device)

print(list(encoder.classes_))

['Negative', 'Positive']


## Test data processing

In [5]:
# Load the held-out test split. It is used as-is (imbalanced); the
# undersampling recipe below is kept only as a disabled alternative.
df_test = pd.read_csv('/home/c/cl237/Datasets/Flare_LSTM_dataset/M/normalized_testing.csv')

# Undersampling for "negative" samples
#negative_df = df_test[df_test['label'] == 'Negative']
#positive_df = df_test[df_test['label'] == 'Positive']

# Get number of "positive" samples
#num_positive = len(positive_df)

# Take the same number of "negative" samples as there are "positive" samples
#balanced_negative_df = negative_df.sample(n=num_positive, random_state=10086)

#df_test_balanced = pd.concat([positive_df, balanced_negative_df])
#print(df_test_balanced['label'].value_counts())

In [6]:
# Column 0 is used as the class label; the feature columns start at index 5
# (same layout as the training frame).
pd_X_test = df_test.iloc[:, 5:]
pd_y_test = df_test.iloc[:, 0]

N_test = len(pd_X_test)
K = 2

# Samples are stored column-wise: X_test has shape (n_features, N_test).
pd_X_test = pd_X_test.values
X_test = torch.tensor(pd_X_test, dtype=dtype, device=device)
X_test = torch.t(X_test)
# x = pd_X_test.values
# x_scaled = scaler.transform(x)  # only transform x, don't fit the scaler again
# X_test = torch.tensor(x_scaled, dtype=dtype, device=device)
# X_test = torch.t(X_test)


# Encode the test labels with the encoder that was fitted on the TRAINING
# labels. Re-fitting a new LabelEncoder on the test split can silently yield a
# different label -> id mapping (e.g. if one class is absent from the split),
# which would corrupt every test metric. `transform` raises ValueError on a
# label that was never seen during fitting, which is the desired failure mode.
y = encoder.transform(pd_y_test.values)
y_test=torch.tensor(y, dtype=torch.long, device=device)
y_test = torch.flatten(y_test)

# One-hot test targets, transposed to (K, N_test) to match X_test's layout.
y_test_one_hot = torch.zeros(N_test, K, device=device).scatter_(1, y_test.unsqueeze(1), 1)
y_test_one_hot = torch.t(y_test_one_hot).to(device=device)

In [7]:
# Show the class distribution of the test split (it is not rebalanced above).
print(df_test['label'].value_counts())

Negative    43411
Positive     1278
Name: label, dtype: int64


## Main algorithm

### Define functions for updating blocks

In [8]:
# def updateMask(W, sparsity):
#     torch.dist(V1,nn.ReLU()(U1),2),2).cpu().numpy()
#     Mask
#     Wsparse
#     return Mask, Wsparse

# Wsquare = torch.square(W3)
# Threshold = torch.quantile(torch.reshape(Wsquare, (-1,)), 0.5, interpolation='linear')
# Wsparse = W3
# Wsparse[Wsquare < Threshold] =  0

# Wsparse

In [9]:
def updateV(U1,U2,W,b,rho,gamma):
    """Closed-form update of a hidden activation block V.

    Returns the solution of the linear system

        (rho * W^T W + gamma * I) V = rho * W^T (U2 - b) + gamma * ReLU(U1),

    i.e. the minimiser of
    ``rho/2 * ||W V + b - U2||_F^2 + gamma/2 * ||V - ReLU(U1)||_F^2``.

    Args:
        U1: tensor of shape (d, N); ReLU is applied to it inside the function.
        U2: tensor of shape (d_next, N).
        W: weight matrix of shape (d_next, d).
        b: bias column of shape (d_next, 1); broadcast across the N columns.
        rho: scalar weight of the linear-consistency term.
        gamma: scalar weight of the activation-consistency term.

    Returns:
        Tensor of shape (d, N).
    """
    _, d = W.size()
    # Build the identity with W's dtype/device instead of the float32 default
    # on a notebook-global device, so the function is self-contained.
    I = torch.eye(d, device=W.device, dtype=W.dtype)
    lhs = rho * torch.mm(W.t(), W) + gamma * I
    rhs = rho * torch.mm(W.t(), U2 - b) + gamma * torch.relu(U1)
    # Solve the system directly rather than forming the explicit inverse:
    # cheaper and numerically more stable.
    Vstar = torch.linalg.solve(lhs, rhs)
    return Vstar

In [10]:
def updateWb_org(U, V, W, b, alpha, rho):
    """Closed-form proximal update of a dense layer's weights (no sparsity term).

    Computes

        Wstar = (alpha * W + rho * (U - b) V^T) (alpha * I + rho * V V^T)^{-1},

    i.e. the minimiser over W' of
    ``rho/2 * ||W' V + b - U||_F^2 + alpha/2 * ||W' - W||_F^2``.

    Args:
        U: tensor of shape (d_out, N).
        V: tensor of shape (d, N).
        W: current weights, shape (d_out, d).
        b: current bias column, shape (d_out, 1); broadcast across columns.
        alpha: scalar proximal weight.
        rho: scalar weight of the linear-consistency term.

    Returns:
        (Wstar, bstar): updated weights of shape (d_out, d) and a bias with
        the same shape as ``b``.
    """
    d,N = V.size()
    # Identity in V's dtype/device (not float32 on a notebook-global device).
    I = torch.eye(d, device=V.device, dtype=V.dtype)
    rhs = alpha * W + rho * torch.mm(U - b, V.t())
    gram = alpha * I + rho * torch.mm(V, V.t())
    # gram is symmetric, so rhs @ gram^{-1} == solve(gram, rhs^T)^T. This
    # avoids forming the explicit inverse.
    Wstar = torch.linalg.solve(gram, rhs.t()).t().contiguous()
    # NOTE(review): the leading `0*` multiplies the whole bias update by zero,
    # so the returned bias is all zeros for finite inputs (bias effectively
    # disabled). Kept verbatim to preserve behaviour - confirm it is intended.
    bstar = 0*(alpha*b+rho*torch.sum(U-torch.mm(Wstar,V), dim=1).reshape(b.size()))/(rho*N+alpha)
    return Wstar, bstar

In [11]:
def updateWb(U, V, W, b, W_tensor_rec, alpha, rho,tau):
    """Closed-form weight update with an extra pull towards a reference tensor.

    ``W_tensor_rec`` is reshaped to ``W.shape`` and used as the reference
    ``R`` in

        Wstar = (alpha * W + tau * R + rho * (U - b) V^T)
                ((alpha + tau) * I + rho * V V^T)^{-1},

    i.e. the minimiser over W' of
    ``rho/2 * ||W' V + b - U||_F^2 + alpha/2 * ||W' - W||_F^2
    + tau/2 * ||W' - R||_F^2``.

    Args:
        U: tensor of shape (d_out, N).
        V: tensor of shape (d, N).
        W: current weights, shape (d_out, d).
        b: current bias column, shape (d_out, 1); broadcast across columns.
        W_tensor_rec: array-like or tensor with ``W.numel()`` elements
            (anything accepted by ``torch.as_tensor``).
        alpha: scalar proximal weight.
        rho: scalar weight of the linear-consistency term.
        tau: scalar weight of the pull towards the reference.

    Returns:
        (Wstar, bstar): updated weights of shape (d_out, d) and a bias with
        the same shape as ``b``.
    """
    # Convert to W's dtype/device. The previous `.float()` forced float32 and
    # silently discarded precision whenever W is float64.
    W_tensor_rec = torch.as_tensor(W_tensor_rec, device=W.device, dtype=W.dtype)
    W_tensor2matrix = W_tensor_rec.reshape(W.shape)
    d,N = V.size()
    I = torch.eye(d, device=V.device, dtype=V.dtype)
    rhs = alpha * W + tau * W_tensor2matrix + rho * torch.mm(U - b, V.t())
    gram = (alpha + tau) * I + rho * torch.mm(V, V.t())
    # gram is symmetric, so rhs @ gram^{-1} == solve(gram, rhs^T)^T. This
    # avoids forming the explicit inverse.
    Wstar = torch.linalg.solve(gram, rhs.t()).t().contiguous()
    # NOTE(review): the leading `0*` multiplies the whole bias update by zero,
    # so the returned bias is all zeros for finite inputs (bias effectively
    # disabled). Kept verbatim to preserve behaviour - confirm it is intended.
    bstar = 0*(alpha*b+rho*torch.sum(U-torch.mm(Wstar,V), dim=1).reshape(b.size()))/(rho*N+alpha)
    return Wstar, bstar

In [12]:
def updateWbsparse(U, V, W, b,  W_tensor2matrix, alpha, rho,tau):
    """Closed-form weight update with a pull towards a (sparse) reference matrix.

    Computes

        Wstar = (alpha * W + tau * S + rho * (U - b) V^T)
                ((alpha + tau) * I + rho * V V^T)^{-1},

    with ``S = W_tensor2matrix``, i.e. the minimiser over W' of
    ``rho/2 * ||W' V + b - U||_F^2 + alpha/2 * ||W' - W||_F^2
    + tau/2 * ||W' - S||_F^2``.

    Args:
        U: tensor of shape (d_out, N).
        V: tensor of shape (d, N).
        W: current weights, shape (d_out, d).
        b: current bias column, shape (d_out, 1); broadcast across columns.
        W_tensor2matrix: reference matrix with the same shape as ``W``.
        alpha: scalar proximal weight.
        rho: scalar weight of the linear-consistency term.
        tau: scalar weight of the pull towards the reference.

    Returns:
        (Wstar, bstar): updated weights of shape (d_out, d) and a bias with
        the same shape as ``b``. Neither input matrix is modified.
    """
    d,N = V.size()
    # Identity in V's dtype/device (not float32 on a notebook-global device).
    I = torch.eye(d, device=V.device, dtype=V.dtype)
    rhs = alpha * W + tau * W_tensor2matrix + rho * torch.mm(U - b, V.t())
    gram = (alpha + tau) * I + rho * torch.mm(V, V.t())
    # gram is symmetric, so rhs @ gram^{-1} == solve(gram, rhs^T)^T. This
    # avoids forming the explicit inverse.
    Wstar = torch.linalg.solve(gram, rhs.t()).t().contiguous()
    # NOTE(review): the leading `0*` multiplies the whole bias update by zero,
    # so the returned bias is all zeros for finite inputs (bias effectively
    # disabled). Kept verbatim to preserve behaviour - confirm it is intended.
    bstar = 0*(alpha*b+rho*torch.sum(U-torch.mm(Wstar,V), dim=1).reshape(b.size()))/(rho*N+alpha)
    return Wstar, bstar

### Define the proximal operator of the ReLU activation function

In [13]:
def relu_prox(a, b, gamma, d, N):
    """Element-wise proximal operator associated with the ReLU activation.

    Presumably the per-entry minimiser over u of
    ``1/2 * (ReLU(u) - a)^2 + gamma/2 * (u - b)^2`` (TODO confirm against the
    derivation). It is evaluated by the case analysis below.

    Args:
        a: tensor of shape (d, N).
        b: tensor of shape (d, N).
        gamma: positive Python scalar.
        d, N: shape of the zero tensors used for the comparisons. Expected to
            match the shape of ``a`` and ``b``.

    Returns:
        Tensor of shape (d, N) holding the selected value per entry.
    """
    # Zeros share b's dtype/device, so no float32 tensor or notebook-global
    # device is mixed into the computation.
    zeros = b.new_zeros(d, N)
    s = a + gamma * b
    # a * kappa is the upper bound of the band in which u = b is selected.
    kappa = gamma - np.sqrt(gamma * (gamma + 1))

    x = s / (1 + gamma)      # candidate (a + gamma*b) / (1 + gamma)
    y = torch.min(b, zeros)  # candidate min(b, 0)

    # Case 1: a + gamma*b < 0 -> min(b, 0); otherwise start from 0.
    val = torch.where(s < 0, y, zeros)
    # Case 2: take the weighted average x.
    val = torch.where(((s >= 0) & (b >= 0)) | ((a * kappa <= gamma * b) & (b < 0)), x, val)
    # Case 3: -a <= gamma*b <= a*kappa -> keep b unchanged.
    val = torch.where((-a <= gamma * b) & (gamma * b <= a * kappa), b, val)
    return val

### Effective Sparsity

In [14]:
def process_weights(W1, W2, W3):
    """Propagate structural zeros through the three weight matrices.

    Works on clones, so the inputs are left untouched. The following rules
    are applied repeatedly until a full pass changes nothing:

    * an all-zero row of W1 zeroes the matching column of W2;
    * an all-zero column of W2 zeroes the matching row of W1;
    * an all-zero row of W2 zeroes the matching column of W3.

    Returns:
        The pruned copies ``(W1, W2, W3)``.
    """
    W1, W2, W3 = (w.clone() for w in (W1, W2, W3))

    converged = False
    while not converged:
        # Snapshot taken before this pass, used for the fixed-point check.
        previous = (W1.clone(), W2.clone(), W3.clone())

        # All-zero row of W1 -> zero the corresponding column of W2.
        dead_rows_w1 = (W1 == 0).all(dim=1)
        W2[:, dead_rows_w1] = 0

        # All-zero column of W2 -> zero the corresponding row of W1.
        dead_cols_w2 = (W2 == 0).all(dim=0)
        W1[dead_cols_w2, :] = 0

        # All-zero row of W2 -> zero the corresponding column of W3.
        dead_rows_w2 = (W2 == 0).all(dim=1)
        W3[:, dead_rows_w2] = 0

        converged = all(
            torch.equal(current, before)
            for current, before in zip((W1, W2, W3), previous)
        )

    return W1, W2, W3

### Parameter initialization

In [20]:
#df = pd.DataFrame()
#df.to_csv('/home/c/cl237/TenBCD/Sparse/LeNet300_100/Flare(BCD method)/differernt gamma rho/result.csv')

# --- Experiment configuration ---
niter = 1000      # number of outer iterations (reported as "epochs")
sparsity = 0.9    # quantile of the squared weights used as pruning threshold

# Penalty weights used by the block updates and the reported total loss.
alpha, rho = 1, 1
gamma = 100
tau = 200

# Echo the configuration so that it is stored with the cell output.
print(sparsity, tau, gamma, rho, alpha)

# One uninitialised history buffer of length `niter` per tracked metric.
(loss1, loss2,
 accuracy_train, accuracy_test,
 time1,
 bacc_train, bacc_test,
 true_sparsity, effective_sparsity) = (np.empty(niter) for _ in range(9))

# results[run, metric, iteration]; filled once per outer run.
results = torch.zeros(1, 9, niter)

# Block-coordinate-descent (BCD) training of a 3-layer network
# (d0 -> d1 -> d2 -> d3) with hard-thresholded (sparse) W1 and W2.
# Each outer run re-seeds, re-initialises all blocks, performs `niter` sweeps
# over (V3, U3, W3, V2, U2, W2, V1, U1, W1) and records per-sweep metrics.
for Out_iter in range(1):
    # NOTE(review): rank_initial is assigned but not referenced in the visible
    # part of this cell.
    rank_initial = 700
    seed = 10 + 10*Out_iter
    torch.manual_seed(seed)
    # NOTE(review): this repeats the call above with the same seed, so it has
    # no additional effect; torch.cuda.manual_seed_all may have been intended.
    if torch.cuda.is_available():
        torch.manual_seed(seed)

    # Layer widths: d0 inputs, two hidden layers (d1, d2), d3 = K outputs.
    d0 = 40
    d1 =  300
    d2 =  100
    d3 = K # Layers: input + 2 hidden + output

    # Layer 1: scaled orthogonal init, zero bias, then magnitude pruning.
    W1 = 0.2*init.orthogonal_(torch.empty(d1, d0, device=device,dtype=dtype), gain=1.0)
    b1 = 0*torch.ones(d1, 1, dtype=torch.float64,device=device)
    W1square = torch.square(W1)
    Threshold = torch.quantile(torch.reshape(W1square, (-1,)), sparsity, interpolation='linear')
    # NOTE(review): `W1sparse = W1` binds a second name to the SAME tensor, so
    # the masked assignment below also zeroes entries of W1 itself.
    W1sparse = W1
    W1sparse[W1square < Threshold] =  0

    # Layer 2: same initialisation and pruning (W2sparse aliases W2 as well).
    W2 = 0.2*init.orthogonal_(torch.empty(d2, d1, device=device,dtype=dtype), gain=1.0)
    b2 = 0*torch.ones(d2, 1, dtype=torch.float64,device=device)
    W2square = torch.square(W2)
    Threshold = torch.quantile(torch.reshape(W2square, (-1,)), sparsity, interpolation='linear')
    W2sparse = W2
    W2sparse[W2square < Threshold] =  0


    # Output layer: kept dense (its pruning code is commented out).
    W3 = 0.2*init.orthogonal_(torch.empty(d3, d2, device=device,dtype=dtype), gain=1.0)
    b3 = 0*torch.ones(d3, 1, dtype=torch.float64,device=device)
    # W3square = torch.square(W3)
    # Threshold = torch.quantile(torch.reshape(W3square, (-1,)), sparsity, interpolation='linear')
    # W3sparse = W3
    # W3sparse[W3square < Threshold] =  0


    # Initialise the auxiliary blocks with one forward pass:
    # U_i = W_i V_{i-1} + b_i, V_i = ReLU(U_i); the output layer is linear.
    U1 = torch.addmm(b1.repeat(1, N), W1, X_train)
    V1 = nn.ReLU()(U1)
    U2 = torch.addmm(b2.repeat(1, N), W2, V1)
    V2 = nn.ReLU()(U2)
    U3 = torch.addmm(b3.repeat(1, N), W3, V2)
    V3 = U3
    # U4 = torch.addmm(b4.repeat(1, N), W4, V3)
    # V4 = U4


    print('Train on', N, 'samples, validate on', N_test, 'samples')

    for k in range(niter):

        start = time.time()

        # update V3 (output block): weighted average of the one-hot targets,
        # U3 and the previous V3
        V3 = (y_one_hot + gamma*U3 + alpha*V3)/(1 + gamma + alpha)

        # update U3: weighted average of V3 and the layer-3 affine output
        U3 = (gamma*V3 + rho*(torch.mm(W3,V2) + b3.repeat(1,N)))/(gamma + rho)

        # update W3 and b3 (dense update, no sparsity term)
        W3, b3 = updateWb_org(U3,V2,W3,b3, alpha,rho)
        # W3square = torch.square(W3)
        # Threshold = torch.quantile(torch.reshape(W3square, (-1,)), sparsity, interpolation='linear')
        # W3sparse = W3
        # W3sparse[W3square < Threshold] =  0
        # update V2
        V2 = updateV(U2,U3,W3,b3,rho,gamma)

        # update U2
        U2 = relu_prox(V2,(rho*torch.addmm(b2.repeat(1,N), W2, V1) + alpha*U2)/(rho + alpha),(rho + alpha)/gamma,d2,N)

        # update W2 and b2, then re-apply the magnitude threshold.
        # W2sparse aliases the new W2, so W2 itself is pruned in place.

        W2, b2 = updateWbsparse(U2,V1,W2,b2,W2sparse, alpha,rho,tau)
        W2square = torch.square(W2)
        Threshold = torch.quantile(torch.reshape(W2square, (-1,)), sparsity, interpolation='linear')
        W2sparse = W2
        W2sparse[W2square < Threshold] =  0

        # update V1
        V1 = updateV(U1,U2,W2,b2,rho,gamma)

        # update U1
        U1 = relu_prox(V1,(rho*torch.addmm(b1.repeat(1,N), W1, X_train) + alpha*U1)/(rho + alpha),(rho + alpha)/gamma,d1,N)

        # update W1 and b1, then re-apply the magnitude threshold
        # (W1sparse aliases the new W1, as for layer 2)
        W1, b1 = updateWbsparse(U1,X_train,W1,b1,W1sparse, alpha,rho,tau)
        W1square = torch.square(W1)
        Threshold = torch.quantile(torch.reshape(W1square, (-1,)), sparsity, interpolation='linear')
        W1sparse = W1
        W1sparse[W1square < Threshold] =  0

        # Forward pass with the current weights on the training set ...
        a1_train = nn.ReLU()(torch.addmm(b1.repeat(1, N), W1sparse, X_train)).double()
        a2_train = nn.ReLU()(torch.addmm(b2.repeat(1, N),  W2sparse, a1_train)).double()
        pred = torch.argmax(torch.addmm(b3.repeat(1, N), W3, a2_train), dim=0)

        # ... and on the test set.
        a1_test = nn.ReLU()(torch.addmm(b1.repeat(1, N_test), W1sparse, X_test)).double()
        a2_test = nn.ReLU()(torch.addmm(b2.repeat(1, N_test),  W2sparse, a1_test)).double()
        pred_test = torch.argmax(torch.addmm(b3.repeat(1, N_test), W3, a2_test), dim=0)


        # loss1: gamma-weighted squared distance between V3 and the one-hot targets.
        # loss2: loss1 plus all rho-, gamma- and tau-weighted coupling penalties.
        # NOTE(review): because W1sparse/W2sparse are the same tensors as W1/W2
        # at this point, the two tau terms evaluate to zero.
        loss1[k] = gamma/2*torch.pow(torch.dist(V3,y_one_hot,2),2).cpu().numpy()
        loss2[k] = loss1[k] + rho/2*torch.pow(torch.dist(torch.addmm(b1.repeat(1,N), W1sparse, X_train),U1,2),2).cpu().numpy() \
        +rho/2*torch.pow(torch.dist(torch.addmm(b2.repeat(1,N),  W2sparse, V1),U2,2),2).cpu().numpy() \
        +rho/2*torch.pow(torch.dist(torch.addmm(b3.repeat(1,N), W3, V2),U3,2),2).cpu().numpy() \
        + gamma/2*torch.pow(torch.dist(V1,nn.ReLU()(U1),2),2).cpu().numpy() \
        + gamma/2*torch.pow(torch.dist(V2,nn.ReLU()(U2),2),2).cpu().numpy() \
        + gamma/2*torch.pow(torch.dist(V3,U3,2),2).cpu().numpy() \
        + tau/2*torch.pow(torch.dist(W1,W1sparse,2),2).cpu().numpy()\
        + tau/2*torch.pow(torch.dist(W2,W2sparse,2),2).cpu().numpy()\
        # + tau/2*torch.pow(torch.dist(W3,W3sparse,2),2).cpu().numpy()

        # compute training accuracy
        correct_train = pred == y_train
        accuracy_train[k] = np.mean(correct_train.cpu().numpy())

        # compute test accuracy (printed as "val_acc")
        correct_test = pred_test == y_test
        accuracy_test[k] = np.mean(correct_test.cpu().numpy())


        # Balanced accuracy is computed with scikit-learn on CPU copies.
        pred_train_np = pred.cpu().numpy()
        y_train_np = y_train.cpu().numpy()
        pred_test_np = pred_test.cpu().numpy()
        y_test_np = y_test.cpu().numpy()

        bacc_train[k] = balanced_accuracy_score(y_train_np, pred_train_np)
        bacc_test[k] = balanced_accuracy_score(y_test_np, pred_test_np)

        # compute the iteration time; note that it also covers the forward
        # passes, loss and metric computations above
        stop = time.time()
        duration = stop - start
        time1[k] = duration

        # true sparsity: fraction of exactly-zero entries over all three layers
        num_zeros_W1 = torch.sum(W1 == 0).item()
        num_zeros_W2 = torch.sum(W2 == 0).item()
        num_zeros_W3 = torch.sum(W3 == 0).item()
        total_zeros_old = num_zeros_W1 + num_zeros_W2 + num_zeros_W3
        total_weights = d0*d1+d1*d2+d2*d3
        true_sparsity[k] = total_zeros_old / total_weights

        # effective sparsity: same ratio after process_weights has also zeroed
        # the weights attached to rows/columns that are entirely zero
        new_W1, new_W2, new_W3 = process_weights(W1, W2, W3)
        num_zeros_W1_new = torch.sum(new_W1 == 0).item()
        num_zeros_W2_new = torch.sum(new_W2 == 0).item()
        num_zeros_W3_new = torch.sum(new_W3 == 0).item()
        total_zeros = num_zeros_W1_new + num_zeros_W2_new + num_zeros_W3_new
        total_weights = d0*d1+d1*d2+d2*d3
        effective_sparsity[k] = total_zeros / total_weights


        # print results
        print('Epoch', k + 1, '/', niter, '\n',
              '-', 'time:', time1[k], '-', 'sq_loss:', loss1[k], '-', 'tot_loss:', loss2[k],
              '-', 'acc:', accuracy_train[k], '-', 'val_acc:', accuracy_test[k],
              '-', 'bacc_train:', bacc_train[k], '-', 'bacc_test:', bacc_test[k],'-', 'true_sparsity:', true_sparsity[k],
             '-', 'effective_sparsity:', effective_sparsity[k])


    # Store this run's metric histories: results[run, metric, iteration].
    results[Out_iter,0,:] = torch.tensor(loss1)
    results[Out_iter,1,:] = torch.tensor(loss2)
    results[Out_iter,2,:] = torch.tensor(accuracy_train)
    results[Out_iter,3,:] = torch.tensor(accuracy_test)
    results[Out_iter,4,:] = torch.tensor(time1)
    results[Out_iter,5,:] = torch.tensor(bacc_train)
    results[Out_iter,6,:] = torch.tensor(bacc_test)
    results[Out_iter,7,:] = torch.tensor(true_sparsity)
    results[Out_iter,8,:] = torch.tensor(effective_sparsity)

#             df=pd.read_csv('/home/c/cl237/TenBCD/Sparse/LeNet300_100/Flare(BCD method)/differernt gamma rho/result.csv')
#             new_row = {
#                         'tau': tau, 
#                         'gamma': gamma, 
#                         'rho': rho, 
#                         'alpha': alpha,
#                         'loss1': loss1[niter-1], 
#                         'loss2': loss2[niter-1], 
#                         'accuracy_train': accuracy_train[niter-1],
#                         'accuracy_test': accuracy_test[niter-1], 
#                         'max_accuracy_train': max(accuracy_train),
#                         'max_accuracy_test': max(accuracy_test),
#                         'time': time1[niter-1], 
#                         'BACC_train': bacc_train[niter-1],
#                         'BACC_test': bacc_test[niter-1],
#                         'max_BACC_train': max(bacc_train),
#                         'max_BACC_test': max(bacc_test),
#                         'Sparsity': sparsity,
#                         'seed' : seed
#                     }

#             df=df.append(new_row,ignore_index=True)
#             df.to_csv('/home/c/cl237/TenBCD/Sparse/LeNet300_100/Flare(BCD method)/differernt gamma rho/result.csv',index=False)

filename="Orthogonal_" + "niter_"+ str(niter) + "Sparsity_" + str(sparsity) +"tau_" + str(tau) + "gamma_" + str(gamma) + \
                "rho_" + str(rho) + "alpha_" + str(alpha)+ ".mat"
from scipy.io import savemat
%cd /home/c/cl237/TenBCD/Sparse/LeNet300_100/Flare(BCD method)/different init/
savemat (filename, {'results': torch.Tensor.numpy(results)})

0.9 200 100 1 1
Train on 8114 samples, validate on 44689 samples
Epoch 1 / 1000 
 - time: 0.029474496841430664 - sq_loss: 397717.6989335023 - tot_loss: 397717.90800162096 - acc: 0.8009613014542766 - val_acc: 0.8072903846584171 - bacc_train: 0.8009613014542765 - bacc_test: 0.860558391029671 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 2 / 1000 
 - time: 0.028175830841064453 - sq_loss: 390033.2415263759 - tot_loss: 390033.98336951394 - acc: 0.8009613014542766 - val_acc: 0.8064848172928462 - bacc_train: 0.8009613014542765 - bacc_test: 0.865839526909318 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 3 / 1000 
 - time: 0.028130769729614258 - sq_loss: 382543.9146260285 - tot_loss: 382545.42082533694 - acc: 0.8008380576780872 - val_acc: 0.817494237955649 - bacc_train: 0.8008380576780872 - bacc_test: 0.8665699530444333 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 4 / 1000 
 - time: 0

Epoch 31 / 1000 
 - time: 0.02983546257019043 - sq_loss: 222811.96556541303 - tot_loss: 222812.5386312971 - acc: 0.8175992112398324 - val_acc: 0.8372977690259348 - bacc_train: 0.8175992112398325 - bacc_test: 0.8756240665655622 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 32 / 1000 
 - time: 0.026848316192626953 - sq_loss: 218517.41748029375 - tot_loss: 218517.9962220624 - acc: 0.8179689425684002 - val_acc: 0.8376781758374544 - bacc_train: 0.8179689425684002 - bacc_test: 0.8765793064499889 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 33 / 1000 
 - time: 0.0268857479095459 - sq_loss: 214305.32614237408 - tot_loss: 214305.91061321268 - acc: 0.8180921863445896 - val_acc: 0.8379466982926447 - bacc_train: 0.8180921863445896 - bacc_test: 0.8767175202667634 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 34 / 1000 
 - time: 0.026849985122680664 - sq_loss: 210174.1379847413 - tot_loss:

Epoch 63 / 1000 
 - time: 0.026990652084350586 - sq_loss: 119475.02902210492 - tot_loss: 119476.4694086449 - acc: 0.8233916687207296 - val_acc: 0.8454653270379735 - bacc_train: 0.8233916687207297 - bacc_test: 0.8821063810911097 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 64 / 1000 
 - time: 0.026847362518310547 - sq_loss: 117169.7840027949 - tot_loss: 117171.21092257627 - acc: 0.8236381562731082 - val_acc: 0.8456219651368345 - bacc_train: 0.8236381562731082 - bacc_test: 0.8821870058175616 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 65 / 1000 
 - time: 0.02685260772705078 - sq_loss: 114908.99076620433 - tot_loss: 114910.40459357877 - acc: 0.8245008627064333 - val_acc: 0.8462708944035445 - bacc_train: 0.8245008627064333 - bacc_test: 0.8829007410300982 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 66 / 1000 
 - time: 0.02690911293029785 - sq_loss: 112691.79406232333 - tot_los

Epoch 95 / 1000 
 - time: 0.026818275451660156 - sq_loss: 64043.261421728115 - tot_loss: 64044.775844139294 - acc: 0.8273354695587872 - val_acc: 0.8534314932086196 - bacc_train: 0.8273354695587873 - bacc_test: 0.8865864428107528 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 96 / 1000 
 - time: 0.02685713768005371 - sq_loss: 62807.297176100496 - tot_loss: 62808.901890582965 - acc: 0.8277052008873552 - val_acc: 0.8537223925350758 - bacc_train: 0.8277052008873552 - bacc_test: 0.8867361744455919 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 97 / 1000 
 - time: 0.02677774429321289 - sq_loss: 61595.183020311786 - tot_loss: 61596.90657930393 - acc: 0.8279516884397338 - val_acc: 0.8538790306339368 - bacc_train: 0.8279516884397338 - bacc_test: 0.886437080683379 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 98 / 1000 
 - time: 0.026860952377319336 - sq_loss: 60406.45906249481 - tot_loss

Epoch 127 / 1000 
 - time: 0.026845216751098633 - sq_loss: 34327.69352522964 - tot_loss: 34330.40738216035 - acc: 0.8315257579492236 - val_acc: 0.859741770905592 - bacc_train: 0.8315257579492236 - bacc_test: 0.8898344675049548 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 128 / 1000 
 - time: 0.026801347732543945 - sq_loss: 33665.21591592495 - tot_loss: 33667.917112551164 - acc: 0.8315257579492236 - val_acc: 0.8599655396182506 - bacc_train: 0.8315257579492236 - bacc_test: 0.8903293641742649 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 129 / 1000 
 - time: 0.026855945587158203 - sq_loss: 33015.52526759509 - tot_loss: 33018.19057585327 - acc: 0.8317722455016021 - val_acc: 0.860055047103314 - bacc_train: 0.8317722455016021 - bacc_test: 0.8903754354465231 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 130 / 1000 
 - time: 0.026912927627563477 - sq_loss: 32378.374762106734 - tot_lo

Epoch 159 / 1000 
 - time: 0.026880979537963867 - sq_loss: 18400.838083376824 - tot_loss: 18403.04362532595 - acc: 0.8367019965491743 - val_acc: 0.8625836335563561 - bacc_train: 0.8367019965491742 - bacc_test: 0.8920566673764815 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 160 / 1000 
 - time: 0.026969194412231445 - sq_loss: 18045.774773985257 - tot_loss: 18048.030757951106 - acc: 0.8368252403253635 - val_acc: 0.8625388798138244 - bacc_train: 0.8368252403253635 - bacc_test: 0.8920336317403523 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 161 / 1000 
 - time: 0.027102947235107422 - sq_loss: 17697.564848390593 - tot_loss: 17699.8841454397 - acc: 0.8370717278777422 - val_acc: 0.8626507641701537 - bacc_train: 0.8370717278777422 - bacc_test: 0.8920912208306752 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 162 / 1000 
 - time: 0.02687859535217285 - sq_loss: 17356.07646687038 - tot_

Epoch 191 / 1000 
 - time: 0.026845932006835938 - sq_loss: 9864.67721312632 - tot_loss: 9866.853179555883 - acc: 0.838180921863446 - val_acc: 0.8627850253977489 - bacc_train: 0.838180921863446 - bacc_test: 0.8959575126257096 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 192 / 1000 
 - time: 0.02696061134338379 - sq_loss: 9674.37595595212 - tot_loss: 9676.546883249346 - acc: 0.8383041656396352 - val_acc: 0.8627178947839513 - bacc_train: 0.8383041656396352 - bacc_test: 0.895922959171516 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 193 / 1000 
 - time: 0.026866912841796875 - sq_loss: 9487.747541081857 - tot_loss: 9489.929644654225 - acc: 0.8384274094158245 - val_acc: 0.8628074022690148 - bacc_train: 0.8384274094158245 - bacc_test: 0.8959690304437742 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8959241706161137
Epoch 194 / 1000 
 - time: 0.026819229125976562 - sq_loss: 9304.721100416655 - tot_loss: 93

Epoch 223 / 1000 
 - time: 0.027048349380493164 - sq_loss: 5289.4253078979245 - tot_loss: 5291.846339910345 - acc: 0.8401528222824748 - val_acc: 0.8659849179887668 - bacc_train: 0.8401528222824748 - bacc_test: 0.8995031530522632 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 224 / 1000 
 - time: 0.026843786239624023 - sq_loss: 5187.4208660847435 - tot_loss: 5189.91980598397 - acc: 0.8400295785062855 - val_acc: 0.8661863098301595 - bacc_train: 0.8400295785062855 - bacc_test: 0.899606813414844 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 225 / 1000 
 - time: 0.0268862247467041 - sq_loss: 5087.384790785045 - tot_loss: 5089.985381441134 - acc: 0.8400295785062855 - val_acc: 0.8662981941864888 - bacc_train: 0.8400295785062855 - bacc_test: 0.8996644025051669 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 226 / 1000 
 - time: 0.02691340446472168 - sq_loss: 4989.27913122849 - tot_loss: 

Epoch 255 / 1000 
 - time: 0.026926040649414062 - sq_loss: 2836.9168993151116 - tot_loss: 2839.837947585448 - acc: 0.8410155287157999 - val_acc: 0.8657163955335765 - bacc_train: 0.8410155287157999 - bacc_test: 0.9008838131901475 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 256 / 1000 
 - time: 0.026953935623168945 - sq_loss: 2782.2349523366233 - tot_loss: 2785.2207998708836 - acc: 0.8410155287157999 - val_acc: 0.8656940186623107 - bacc_train: 0.8410155287157999 - bacc_test: 0.900872295372083 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 257 / 1000 
 - time: 0.02696824073791504 - sq_loss: 2728.6079540403375 - tot_loss: 2731.677088424846 - acc: 0.8412620162681784 - val_acc: 0.8657611492761083 - bacc_train: 0.8412620162681785 - bacc_test: 0.9009068488262766 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 258 / 1000 
 - time: 0.02684783935546875 - sq_loss: 2676.0155473686054 - tot_

Epoch 287 / 1000 
 - time: 0.027126789093017578 - sq_loss: 1522.0900676410372 - tot_loss: 1528.3915056228034 - acc: 0.844343110672911 - val_acc: 0.8675065452348453 - bacc_train: 0.844343110672911 - bacc_test: 0.9025646756126406 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 288 / 1000 
 - time: 0.026894569396972656 - sq_loss: 1492.7711783806737 - tot_loss: 1497.935079818066 - acc: 0.8442198668967217 - val_acc: 0.8674394146210477 - bacc_train: 0.8442198668967217 - bacc_test: 0.902530122158447 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 289 / 1000 
 - time: 0.02684164047241211 - sq_loss: 1464.0177422409315 - tot_loss: 1468.7996038806139 - acc: 0.8440966231205325 - val_acc: 0.8673275302647184 - bacc_train: 0.8440966231205325 - bacc_test: 0.9024725330681242 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 290 / 1000 
 - time: 0.026834964752197266 - sq_loss: 1435.8188529114977 - tot_

Epoch 319 / 1000 
 - time: 0.02687549591064453 - sq_loss: 817.0451480750186 - tot_loss: 820.3611658394586 - acc: 0.8450825733300469 - val_acc: 0.8690057956096578 - bacc_train: 0.8450825733300469 - bacc_test: 0.9037160879116299 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 320 / 1000 
 - time: 0.026897430419921875 - sq_loss: 801.3211690953992 - tot_loss: 804.5716798083173 - acc: 0.8450825733300469 - val_acc: 0.8690953030947213 - bacc_train: 0.8450825733300469 - bacc_test: 0.9037621591838881 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 321 / 1000 
 - time: 0.026836872100830078 - sq_loss: 785.9003052366484 - tot_loss: 789.0914471710652 - acc: 0.8450825733300469 - val_acc: 0.8691848105797847 - bacc_train: 0.8450825733300469 - bacc_test: 0.9038082304561463 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 322 / 1000 
 - time: 0.027075767517089844 - sq_loss: 770.7767016792642 - tot_los

Epoch 351 / 1000 
 - time: 0.026959896087646484 - sq_loss: 438.869052962359 - tot_loss: 442.5190497990956 - acc: 0.8463150110919399 - val_acc: 0.8701470160442167 - bacc_train: 0.8463150110919399 - bacc_test: 0.9050629336102513 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 352 / 1000 
 - time: 0.026897192001342773 - sq_loss: 430.43325500240013 - tot_loss: 434.0482647107666 - acc: 0.8464382548681292 - val_acc: 0.870102262301685 - bacc_train: 0.8464382548681292 - bacc_test: 0.9050398979741221 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 353 / 1000 
 - time: 0.026871919631958008 - sq_loss: 422.1599698445205 - tot_loss: 425.7686995885846 - acc: 0.8464382548681292 - val_acc: 0.870102262301685 - bacc_train: 0.8464382548681292 - bacc_test: 0.9050398979741221 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 354 / 1000 
 - time: 0.026995420455932617 - sq_loss: 414.0460625769832 - tot_loss

Epoch 383 / 1000 
 - time: 0.026950597763061523 - sq_loss: 235.93955924145266 - tot_loss: 240.07585090255603 - acc: 0.8473009613014543 - val_acc: 0.8717357739040927 - bacc_train: 0.8473009613014543 - bacc_test: 0.905880698692834 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 384 / 1000 
 - time: 0.026891469955444336 - sq_loss: 231.4116033398531 - tot_loss: 235.39241124663724 - acc: 0.8474242050776436 - val_acc: 0.8717805276466245 - bacc_train: 0.8474242050776436 - bacc_test: 0.9059037343289631 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 385 / 1000 
 - time: 0.02687668800354004 - sq_loss: 226.97080449393297 - tot_loss: 230.8485042895757 - acc: 0.8475474488538329 - val_acc: 0.8717357739040927 - bacc_train: 0.8475474488538328 - bacc_test: 0.905880698692834 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 386 / 1000 
 - time: 0.02693939208984375 - sq_loss: 222.6154817921617 - tot_lo

Epoch 415 / 1000 
 - time: 0.0268557071685791 - sq_loss: 126.99021530208185 - tot_loss: 130.7750299519154 - acc: 0.848410155287158 - val_acc: 0.87218331132941 - bacc_train: 0.848410155287158 - bacc_test: 0.9064907735427896 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 416 / 1000 
 - time: 0.026877164840698242 - sq_loss: 124.55834270482782 - tot_loss: 128.30355560153095 - acc: 0.848410155287158 - val_acc: 0.872160934458144 - bacc_train: 0.848410155287158 - bacc_test: 0.906479255724725 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 417 / 1000 
 - time: 0.02688145637512207 - sq_loss: 122.17322669689503 - tot_loss: 125.87906021158042 - acc: 0.848410155287158 - val_acc: 0.87218331132941 - bacc_train: 0.848410155287158 - bacc_test: 0.9064907735427896 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 418 / 1000 
 - time: 0.02689814567565918 - sq_loss: 119.8339688731924 - tot_loss: 123.501

Epoch 447 / 1000 
 - time: 0.027066469192504883 - sq_loss: 68.45654429799166 - tot_loss: 71.52799524776283 - acc: 0.8491496179442938 - val_acc: 0.8737273154467542 - bacc_train: 0.8491496179442939 - bacc_test: 0.907665221477908 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 448 / 1000 
 - time: 0.026951313018798828 - sq_loss: 67.1494319382077 - tot_loss: 70.20024288095107 - acc: 0.8491496179442938 - val_acc: 0.8738168229318177 - bacc_train: 0.8491496179442939 - bacc_test: 0.9077112927501662 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 449 / 1000 
 - time: 0.026988744735717773 - sq_loss: 65.86741508985253 - tot_loss: 68.89819941011905 - acc: 0.8492728617204831 - val_acc: 0.8739063304168811 - bacc_train: 0.8492728617204831 - bacc_test: 0.9077573640224244 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 450 / 1000 
 - time: 0.02681732177734375 - sq_loss: 64.61001036940655 - tot_loss:

Epoch 479 / 1000 
 - time: 0.027109622955322266 - sq_loss: 36.98193117428581 - tot_loss: 39.838028249761564 - acc: 0.8505052994823762 - val_acc: 0.8742643603571348 - bacc_train: 0.8505052994823761 - bacc_test: 0.9079416491114571 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 480 / 1000 
 - time: 0.027037620544433594 - sq_loss: 36.278625255485494 - tot_loss: 39.19493108773528 - acc: 0.8505052994823762 - val_acc: 0.874241983485869 - bacc_train: 0.8505052994823761 - bacc_test: 0.9079301312933925 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 481 / 1000 
 - time: 0.02687549591064453 - sq_loss: 35.58879674613469 - tot_loss: 38.57465765765141 - acc: 0.8502588119299975 - val_acc: 0.8742643603571348 - bacc_train: 0.8502588119299975 - bacc_test: 0.9079416491114571 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 482 / 1000 
 - time: 0.02704143524169922 - sq_loss: 34.912185657322034 - tot_lo

Epoch 511 / 1000 
 - time: 0.026921987533569336 - sq_loss: 20.03667485363298 - tot_loss: 23.921280708509638 - acc: 0.850012324377619 - val_acc: 0.8735483004766273 - bacc_train: 0.850012324377619 - bacc_test: 0.9071933604447269 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 512 / 1000 
 - time: 0.027278423309326172 - sq_loss: 19.657708880418276 - tot_loss: 23.394689558445084 - acc: 0.850012324377619 - val_acc: 0.8735483004766273 - bacc_train: 0.850012324377619 - bacc_test: 0.9071933604447269 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 513 / 1000 
 - time: 0.026868104934692383 - sq_loss: 19.285985454999558 - tot_loss: 22.8729352092388 - acc: 0.8501355681538082 - val_acc: 0.8735483004766273 - bacc_train: 0.8501355681538082 - bacc_test: 0.9071933604447269 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 514 / 1000 
 - time: 0.02689671516418457 - sq_loss: 18.921365457168715 - tot_los

Epoch 543 / 1000 
 - time: 0.026909828186035156 - sq_loss: 10.898620658352451 - tot_loss: 13.070603468089223 - acc: 0.8511215183633226 - val_acc: 0.8744209984559959 - bacc_train: 0.8511215183633227 - bacc_test: 0.9076425553492442 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 544 / 1000 
 - time: 0.026891708374023438 - sq_loss: 10.69402274877163 - tot_loss: 12.817253737779309 - acc: 0.851244762139512 - val_acc: 0.87439862158473 - bacc_train: 0.851244762139512 - bacc_test: 0.9076310375311796 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 545 / 1000 
 - time: 0.02687811851501465 - sq_loss: 10.493321093889833 - tot_loss: 12.57442252430368 - acc: 0.851244762139512 - val_acc: 0.8744433753272618 - bacc_train: 0.851244762139512 - bacc_test: 0.9076540731673087 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 546 / 1000 
 - time: 0.026856422424316406 - sq_loss: 10.29644088105952 - tot_loss:

Epoch 575 / 1000 
 - time: 0.026988983154296875 - sq_loss: 5.960269182848879 - tot_loss: 7.3601853700469295 - acc: 0.8513680059157013 - val_acc: 0.8752489426928327 - bacc_train: 0.8513680059157013 - bacc_test: 0.9084484331062972 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 576 / 1000 
 - time: 0.02692556381225586 - sq_loss: 5.849546475742573 - tot_loss: 7.239219850199742 - acc: 0.8513680059157013 - val_acc: 0.8752489426928327 - bacc_train: 0.8513680059157013 - bacc_test: 0.9084484331062972 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 577 / 1000 
 - time: 0.027158498764038086 - sq_loss: 5.740922787563467 - tot_loss: 7.121076470602315 - acc: 0.8513680059157013 - val_acc: 0.8752489426928327 - bacc_train: 0.8513680059157013 - bacc_test: 0.9084484331062972 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 578 / 1000 
 - time: 0.026833534240722656 - sq_loss: 5.634357998308755 - tot_lo

Epoch 607 / 1000 
 - time: 0.02688908576965332 - sq_loss: 3.284203417936153 - tot_loss: 4.449851477217679 - acc: 0.8514912496918906 - val_acc: 0.8759426257020744 - bacc_train: 0.8514912496918905 - bacc_test: 0.907666330000304 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 608 / 1000 
 - time: 0.026888608932495117 - sq_loss: 3.224086522328963 - tot_loss: 4.367895219303589 - acc: 0.8514912496918906 - val_acc: 0.8759202488308084 - bacc_train: 0.8514912496918905 - bacc_test: 0.9076548121822394 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 609 / 1000 
 - time: 0.026869773864746094 - sq_loss: 3.1651025043076286 - tot_loss: 4.271419589877799 - acc: 0.8514912496918906 - val_acc: 0.8758978719595426 - bacc_train: 0.8514912496918905 - bacc_test: 0.9076432943641748 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 610 / 1000 
 - time: 0.026801347732543945 - sq_loss: 3.1072297383534186 - tot_lo

Epoch 639 / 1000 
 - time: 0.027169227600097656 - sq_loss: 1.8286890678604808 - tot_loss: 2.804290131964573 - acc: 0.8514912496918906 - val_acc: 0.876323032513594 - bacc_train: 0.8514912496918905 - bacc_test: 0.9078621329074013 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 640 / 1000 
 - time: 0.026872634887695312 - sq_loss: 1.795909947247978 - tot_loss: 2.7692167240462866 - acc: 0.8514912496918906 - val_acc: 0.8764349168699233 - bacc_train: 0.8514912496918905 - bacc_test: 0.907919721997724 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 641 / 1000 
 - time: 0.026838064193725586 - sq_loss: 1.7637433723038245 - tot_loss: 2.7350861053847493 - acc: 0.8513680059157013 - val_acc: 0.8764572937411891 - bacc_train: 0.8513680059157013 - bacc_test: 0.9079312398157885 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 642 / 1000 
 - time: 0.02682805061340332 - sq_loss: 1.7321778118549154 - tot_

Epoch 671 / 1000 
 - time: 0.026841402053833008 - sq_loss: 1.0331861331325518 - tot_loss: 2.018612477534916 - acc: 0.852477199901405 - val_acc: 0.8773971223343552 - bacc_train: 0.8524771999014049 - bacc_test: 0.9080352696858347 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 672 / 1000 
 - time: 0.027080774307250977 - sq_loss: 1.015210228660431 - tot_loss: 2.0048967252884937 - acc: 0.8523539561252157 - val_acc: 0.877441876076887 - bacc_train: 0.8523539561252156 - bacc_test: 0.9080583053219637 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 673 / 1000 
 - time: 0.026847362518310547 - sq_loss: 0.9975665824170548 - tot_loss: 1.9924311884982926 - acc: 0.8523539561252157 - val_acc: 0.8774194992056211 - bacc_train: 0.8523539561252156 - bacc_test: 0.9076670690152344 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 674 / 1000 
 - time: 0.026801347732543945 - sq_loss: 0.980248985191687 - tot_

Epoch 703 / 1000 
 - time: 0.026855945587158203 - sq_loss: 0.5956101776854729 - tot_loss: 2.0275144476749833 - acc: 0.8523539561252157 - val_acc: 0.8798585781735998 - bacc_train: 0.8523539561252156 - bacc_test: 0.9077833557182758 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 704 / 1000 
 - time: 0.02713179588317871 - sq_loss: 0.5856798807930589 - tot_loss: 2.048113448650848 - acc: 0.852477199901405 - val_acc: 0.879970462529929 - bacc_train: 0.8524771999014049 - bacc_test: 0.9078409448085986 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 705 / 1000 
 - time: 0.026858091354370117 - sq_loss: 0.5759305460501664 - tot_loss: 2.071618165225964 - acc: 0.852477199901405 - val_acc: 0.8801047237575242 - bacc_train: 0.8524771999014049 - bacc_test: 0.9079100517169858 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 706 / 1000 
 - time: 0.027116060256958008 - sq_loss: 0.5663587847407893 - tot_l

Epoch 735 / 1000 
 - time: 0.027099132537841797 - sq_loss: 0.3529036504462594 - tot_loss: 1.564195287712384 - acc: 0.851244762139512 - val_acc: 0.875360827049162 - bacc_train: 0.851244762139512 - bacc_test: 0.9062277112646315 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 736 / 1000 
 - time: 0.0269167423248291 - sq_loss: 0.34736384216638605 - tot_loss: 1.5645246612605117 - acc: 0.851244762139512 - val_acc: 0.8753832039204279 - bacc_train: 0.851244762139512 - bacc_test: 0.9062392290826962 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 737 / 1000 
 - time: 0.02685379981994629 - sq_loss: 0.34192316227026925 - tot_loss: 1.5654801000338434 - acc: 0.851244762139512 - val_acc: 0.8753384501778961 - bacc_train: 0.851244762139512 - bacc_test: 0.906216193446567 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 738 / 1000 
 - time: 0.026835203170776367 - sq_loss: 0.3365797561568719 - tot_loss:

Epoch 767 / 1000 
 - time: 0.026886940002441406 - sq_loss: 0.2168142613997059 - tot_loss: 1.2214190858082143 - acc: 0.852477199901405 - val_acc: 0.8768377005527087 - bacc_train: 0.8524771999014049 - bacc_test: 0.9062284502795621 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 768 / 1000 
 - time: 0.026831626892089844 - sq_loss: 0.21368615474194513 - tot_loss: 1.213191464298925 - acc: 0.852477199901405 - val_acc: 0.8768377005527087 - bacc_train: 0.8524771999014049 - bacc_test: 0.9062284502795621 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 769 / 1000 
 - time: 0.026896238327026367 - sq_loss: 0.21061266705360038 - tot_loss: 1.2053119253741242 - acc: 0.852477199901405 - val_acc: 0.8768153236814429 - bacc_train: 0.8524771999014049 - bacc_test: 0.9062169324614976 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 770 / 1000 
 - time: 0.026801347732543945 - sq_loss: 0.20759282233662105 - 

Epoch 799 / 1000 
 - time: 0.02683734893798828 - sq_loss: 0.1394647571328427 - tot_loss: 1.026206819690992 - acc: 0.8528469312299729 - val_acc: 0.8768824542952405 - bacc_train: 0.8528469312299729 - bacc_test: 0.9058717674270265 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 800 / 1000 
 - time: 0.026889324188232422 - sq_loss: 0.13767058310524927 - tot_loss: 1.0215022920469674 - acc: 0.8528469312299729 - val_acc: 0.8769048311665063 - bacc_train: 0.8528469312299729 - bacc_test: 0.9058832852450911 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 801 / 1000 
 - time: 0.02690863609313965 - sq_loss: 0.13590676513091782 - tot_loss: 1.0165472069476613 - acc: 0.8528469312299729 - val_acc: 0.8769495849090381 - bacc_train: 0.8528469312299729 - bacc_test: 0.9059063208812201 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 802 / 1000 
 - time: 0.026865720748901367 - sq_loss: 0.13417278724500106 -

Epoch 831 / 1000 
 - time: 0.02691507339477539 - sq_loss: 0.09473871002974617 - tot_loss: 0.89359191768106 - acc: 0.8534631501109194 - val_acc: 0.8770838461366331 - bacc_train: 0.8534631501109193 - bacc_test: 0.9055957093009427 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 832 / 1000 
 - time: 0.027011632919311523 - sq_loss: 0.09368972736569052 - tot_loss: 0.8901049282926056 - acc: 0.8534631501109194 - val_acc: 0.8770614692653673 - bacc_train: 0.8534631501109193 - bacc_test: 0.9055841914828782 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 833 / 1000 
 - time: 0.02686929702758789 - sq_loss: 0.09265782642930627 - tot_loss: 0.8866975646234662 - acc: 0.8535863938871087 - val_acc: 0.8770390923941015 - bacc_train: 0.8535863938871087 - bacc_test: 0.9055726736648135 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 834 / 1000 
 - time: 0.02682948112487793 - sq_loss: 0.09164269463658983 - 

Epoch 863 / 1000 
 - time: 0.02748703956604004 - sq_loss: 0.06832654272222596 - tot_loss: 0.8009689588237296 - acc: 0.8535863938871087 - val_acc: 0.8775090066906845 - bacc_train: 0.8535863938871087 - bacc_test: 0.905814547844169 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 864 / 1000 
 - time: 0.02703547477722168 - sq_loss: 0.0676985279724186 - tot_loss: 0.798145823736542 - acc: 0.8535863938871087 - val_acc: 0.8775090066906845 - bacc_train: 0.8535863938871087 - bacc_test: 0.905814547844169 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 865 / 1000 
 - time: 0.027098417282104492 - sq_loss: 0.06708026482603294 - tot_loss: 0.7955721333453715 - acc: 0.8535863938871087 - val_acc: 0.8775313835619504 - bacc_train: 0.8535863938871087 - bacc_test: 0.9058260656622337 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 866 / 1000 
 - time: 0.027193069458007812 - sq_loss: 0.06647156499604942 - t

Epoch 895 / 1000 
 - time: 0.02686619758605957 - sq_loss: 0.05232364705596696 - tot_loss: 0.7337863161960845 - acc: 0.8539561252156767 - val_acc: 0.878068428472331 - bacc_train: 0.8539561252156767 - bacc_test: 0.9057227748071179 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 896 / 1000 
 - time: 0.026891469955444336 - sq_loss: 0.05193717962311518 - tot_loss: 0.7318750468998532 - acc: 0.8539561252156767 - val_acc: 0.878068428472331 - bacc_train: 0.8539561252156767 - bacc_test: 0.9057227748071179 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 897 / 1000 
 - time: 0.026996374130249023 - sq_loss: 0.051556339647290575 - tot_loss: 0.7300115913915783 - acc: 0.8539561252156767 - val_acc: 0.8781355590861286 - bacc_train: 0.8539561252156767 - bacc_test: 0.9057573282613116 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 898 / 1000 
 - time: 0.026875734329223633 - sq_loss: 0.05118106024128444

Epoch 927 / 1000 
 - time: 0.026890993118286133 - sq_loss: 0.042340744374382934 - tot_loss: 0.6841484664053481 - acc: 0.8539561252156767 - val_acc: 0.8784935890263823 - bacc_train: 0.8539561252156767 - bacc_test: 0.9059416133503444 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 928 / 1000 
 - time: 0.02683711051940918 - sq_loss: 0.04209540306861786 - tot_loss: 0.682771268947521 - acc: 0.8539561252156767 - val_acc: 0.8785607196401799 - bacc_train: 0.8539561252156767 - bacc_test: 0.905976166804538 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 929 / 1000 
 - time: 0.027025938034057617 - sq_loss: 0.04185339317210379 - tot_loss: 0.6813032773670802 - acc: 0.8539561252156767 - val_acc: 0.8785383427689141 - bacc_train: 0.8539561252156767 - bacc_test: 0.9059646489864734 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 930 / 1000 
 - time: 0.026871204376220703 - sq_loss: 0.04161464695492841

Epoch 959 / 1000 
 - time: 0.026848316192626953 - sq_loss: 0.03590306817479912 - tot_loss: 0.6435692205720964 - acc: 0.8543258565442445 - val_acc: 0.8790530108080288 - bacc_train: 0.8543258565442444 - bacc_test: 0.906229558801958 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 960 / 1000 
 - time: 0.02691364288330078 - sq_loss: 0.035741569985372444 - tot_loss: 0.6424012737699041 - acc: 0.8543258565442445 - val_acc: 0.8790753876792947 - bacc_train: 0.8543258565442444 - bacc_test: 0.9062410766200226 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 961 / 1000 
 - time: 0.026858806610107422 - sq_loss: 0.03558206415802435 - tot_loss: 0.6413259672370918 - acc: 0.8542026127680552 - val_acc: 0.8791425182930923 - bacc_train: 0.8542026127680552 - bacc_test: 0.9062756300742162 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 962 / 1000 
 - time: 0.02723383903503418 - sq_loss: 0.03542450802413671

Epoch 991 / 1000 
 - time: 0.026862382888793945 - sq_loss: 0.031587363782408284 - tot_loss: 0.6102353916876535 - acc: 0.8540793689918659 - val_acc: 0.8795900557184094 - bacc_train: 0.854079368991866 - bacc_test: 0.9057465494581778 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 992 / 1000 
 - time: 0.026906967163085938 - sq_loss: 0.031476666004825904 - tot_loss: 0.6092348442275307 - acc: 0.8540793689918659 - val_acc: 0.8796795632034728 - bacc_train: 0.854079368991866 - bacc_test: 0.9057926207304359 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 993 / 1000 
 - time: 0.0269315242767334 - sq_loss: 0.03136719962006792 - tot_loss: 0.6083092368966566 - acc: 0.8540793689918659 - val_acc: 0.8796795632034728 - bacc_train: 0.854079368991866 - bacc_test: 0.9057926207304359 - true_sparsity: 0.8957345971563981 - effective_sparsity: 0.8961374407582938
Epoch 994 / 1000 
 - time: 0.026932716369628906 - sq_loss: 0.031258940173284254

### Visualization of training results

In [None]:
# Plot the per-epoch training curves recorded by the training loop above.
# Relies on names defined in earlier cells: `niter` (number of epochs run),
# `loss2`, `accuracy_train` and `accuracy_test`.
# NOTE(review): each series is presumably a 1-D sequence of length `niter`
# (one value per epoch) -- confirm against the training cell.

# Shared x-axis: epochs numbered 1..niter, computed once for all curves.
epoch_axis = np.arange(1, niter + 1)

# Figure 1: training loss. The y-axis is log base 2, so a curve that halves
# at a steady rate shows up as a straight line.
fig_loss, ax_loss = plt.subplots()
ax_loss.plot(epoch_axis, loss2)
ax_loss.set_yscale('log', base=2)
ax_loss.set_title('training loss')
ax_loss.set_xlabel('epoch')
ax_loss.set_ylabel('loss')

# Figure 2: train vs. test accuracy on the same axes. The legend is required
# to tell the two curves apart; without it only the default colour order
# distinguishes them.
fig_acc, ax_acc = plt.subplots()
ax_acc.plot(epoch_axis, accuracy_train, label='train')
ax_acc.plot(epoch_axis, accuracy_test, label='test')
ax_acc.set_title('accuracy')
ax_acc.set_xlabel('epoch')
ax_acc.set_ylabel('accuracy')
ax_acc.legend();  # trailing ';' keeps the Legend repr out of the cell output