## Import libraries

In [1]:
# !pip install tensorly

import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.decomposition import tucker
from tensorly.decomposition import tensor_train
from math import ceil
from tensorly import tt_to_tensor
from tensorly.decomposition import matrix_product_state


In [2]:
"""
5 runs of 50 epochs, seed = 10, 20, 30, 40, 50;
validation accuracies: 0.9492, 0.9457, 0.9463, 0.9439, 0.9455
"""
#from __future__ import print_function, division
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
import os
import copy
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import balanced_accuracy_score

print("PyTorch Version:", torch.__version__)
print("Torchvision Version:", torchvision.__version__)
print("GPU is available?", torch.cuda.is_available())

dtype = torch.float64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

PyTorch Version: 1.13.1
Torchvision Version: 0.14.1
GPU is available? True


## Train data processing

In [3]:
df_train_1 = pd.read_csv('/home/c/cl237/Datasets/Flare_LSTM_dataset/M/normalized_training.csv')
df_train_2 = pd.read_csv('/home/c/cl237/Datasets/Flare_LSTM_dataset/M/normalized_validation.csv')
# Merge the training and validation splits into one training pool (no
# separate validation set is used). `DataFrame.append` was deprecated in
# pandas 1.4 and removed in 2.0, so use `pd.concat` with a fresh RangeIndex.
df_train = pd.concat([df_train_1, df_train_2], ignore_index=True)

# Undersample the majority "Negative" class to balance the training set.
negative_df = df_train[df_train['label'] == 'Negative']
positive_df = df_train[df_train['label'] == 'Positive']

# Number of "Positive" samples sets the size of each class after balancing.
num_positive = len(positive_df)

# Draw the same number of "Negative" samples (fixed seed for reproducibility).
balanced_negative_df = negative_df.sample(n=num_positive, random_state=10086)

df_train_balanced = pd.concat([positive_df, balanced_negative_df])
print(df_train_balanced['label'].value_counts())

Positive    4057
Negative    4057
Name: label, dtype: int64


In [4]:
# Column 0 holds the class label; features start at column 5.
# NOTE(review): columns 1-4 are assumed to be non-feature metadata — confirm.
pd_X_train = df_train_balanced.iloc[:, 5:]
pd_y_train = df_train_balanced.iloc[:, 0]

N = len(pd_X_train)
K = 2

pd_X_train = pd_X_train.values
# Feature-major layout: X_train has shape (n_features, N).
X_train = torch.tensor(pd_X_train, dtype=dtype, device=device).t()

# Map the string labels to integer class indices (fit on the training labels).
encoder = LabelEncoder()
y = encoder.fit_transform(pd_y_train.values)
y_train = torch.tensor(y, dtype=torch.long, device=device).flatten()

# One-hot targets in class-major layout (K, N), using the same float dtype as
# the features so later losses (e.g. torch.dist against V3) do not mix
# float32 and float64.
y_one_hot = F.one_hot(y_train, num_classes=K).t().to(dtype)

print(list(encoder.classes_))

print(list(encoder.classes_))

['Negative', 'Positive']


## Test data processing

In [5]:
# Held-out test split; the training loop below uses it only for evaluation.
test_csv_path = '/home/c/cl237/Datasets/Flare_LSTM_dataset/M/normalized_testing.csv'
df_test = pd.read_csv(test_csv_path)


In [6]:
# Same column layout as the training split: label in column 0, features from column 5.
pd_X_test = df_test.iloc[:, 5:]
pd_y_test = df_test.iloc[:, 0]

N_test = len(pd_X_test)
K = 2

pd_X_test = pd_X_test.values
# Feature-major layout: X_test has shape (n_features, N_test).
X_test = torch.tensor(pd_X_test, dtype=dtype, device=device).t()

# Encode with the encoder fitted on the *training* labels so the
# class -> index mapping is guaranteed to match the model's outputs.
# Re-fitting a new encoder on the test split would silently remap classes if
# the split's label set differed; `transform` raises on unseen labels instead.
y = encoder.transform(pd_y_test.values)
y_test = torch.tensor(y, dtype=torch.long, device=device).flatten()

# One-hot targets, class-major (K, N_test), in the shared float dtype.
y_test_one_hot = F.one_hot(y_test, num_classes=K).t().to(dtype)

In [7]:
# Class balance of the (unbalanced) test split.
test_label_counts = df_test['label'].value_counts()
print(test_label_counts)

Negative    43411
Positive     1278
Name: label, dtype: int64


## Main algorithm

### Define functions for updating blocks

In [9]:
def updateV(U1, U2, W, b, rho, gamma):
    """Closed-form update of a hidden activation block V.

    Minimises over V
        rho/2 * ||U2 - W V - b||_F^2 + gamma/2 * ||V - relu(U1)||_F^2,
    i.e. solves the normal equations
        (rho W^T W + gamma I) V = rho W^T (U2 - b) + gamma relu(U1).

    Args:
        U1: pre-activation of the current layer, shape (d, n).
        U2: pre-activation of the next layer, shape (m, n).
        W: weights of the next layer, shape (m, d).
        b: bias of the next layer, shape (m, 1) (broadcast over columns).
        rho, gamma: penalty weights (positive scalars).

    Returns:
        The minimising V, shape (d, n), in W's dtype/device.
    """
    _, d = W.shape
    # Identity follows W's dtype/device instead of a float32 eye on a global device.
    eye = torch.eye(d, dtype=W.dtype, device=W.device)
    lhs = rho * (W.t() @ W) + gamma * eye
    rhs = rho * (W.t() @ (U2 - b)) + gamma * torch.relu(U1)
    # Solving the SPD system is cheaper and more accurate than forming inv(lhs).
    return torch.linalg.solve(lhs, rhs)

In [10]:
def updateWb_org(U, V, W, b, alpha, rho, update_bias=False):
    """Proximal closed-form update of a dense layer's weights (and optional bias).

    Minimises over W'
        alpha/2 * ||W' - W||_F^2 + rho/2 * ||U - W' V - b||_F^2,
    whose solution is
        W' = (alpha W + rho (U - b) V^T) (alpha I + rho V V^T)^{-1}.

    Args:
        U: target pre-activation, shape (m, n).
        V: layer input, shape (d, n).
        W: current weights, shape (m, d).
        b: current bias, shape (m, 1).
        alpha: proximal weight on the previous W.
        rho: weight of the fit term.
        update_bias: if False (default, the historical behaviour) the bias is
            held at zero; if True, the closed-form bias update is returned.

    Returns:
        (Wstar, bstar): updated weights (m, d) and bias with b's shape.
    """
    d, n_samples = V.shape
    eye = torch.eye(d, dtype=V.dtype, device=V.device)
    gram = alpha * eye + rho * (V @ V.t())
    target = alpha * W + rho * ((U - b) @ V.t())
    # gram is symmetric, so target @ inv(gram) == solve(gram, target^T)^T.
    Wstar = torch.linalg.solve(gram, target.t()).t()
    if update_bias:
        residual = torch.sum(U - Wstar @ V, dim=1).reshape(b.shape)
        bstar = (alpha * b + rho * residual) / (rho * n_samples + alpha)
    else:
        # Previously computed as 0 * (...), which wasted work and turned any
        # non-finite intermediate into NaN; the bias is simply kept at zero.
        bstar = torch.zeros_like(b)
    return Wstar, bstar

In [11]:
def updateWb(U, V, W, b, W_tensor_rec, alpha, rho, tau, update_bias=False):
    """Proximal weight update pulled toward an externally reconstructed W.

    Minimises over W'
        alpha/2 ||W' - W||^2 + tau/2 ||W' - R||^2 + rho/2 ||U - W' V - b||^2,
    where R is `W_tensor_rec` reshaped to W's shape. Solution:
        W' = (alpha W + tau R + rho (U - b) V^T) ((alpha + tau) I + rho V V^T)^{-1}.

    Args:
        U: target pre-activation, shape (m, n).
        V: layer input, shape (d, n).
        W: current weights, shape (m, d).
        b: current bias, shape (m, 1).
        W_tensor_rec: array-like (e.g. a numpy array or tensor) with W.numel()
            elements; reshaped to W.shape.
        alpha, rho, tau: penalty weights.
        update_bias: if False (default) the bias is held at zero.

    Returns:
        (Wstar, bstar).
    """
    # Convert in W's dtype: the previous `.float()` cast downcast the target to
    # float32 before mixing it with float64 weights, losing precision.
    W_tensor2matrix = torch.as_tensor(
        W_tensor_rec, dtype=W.dtype, device=W.device).reshape(W.shape)
    d, n_samples = V.shape
    eye = torch.eye(d, dtype=V.dtype, device=V.device)
    gram = (alpha + tau) * eye + rho * (V @ V.t())
    target = alpha * W + tau * W_tensor2matrix + rho * ((U - b) @ V.t())
    # gram is symmetric, so target @ inv(gram) == solve(gram, target^T)^T.
    Wstar = torch.linalg.solve(gram, target.t()).t()
    if update_bias:
        residual = torch.sum(U - Wstar @ V, dim=1).reshape(b.shape)
        bstar = (alpha * b + rho * residual) / (rho * n_samples + alpha)
    else:
        # Bias is kept at zero (formerly written as 0 * (...)).
        bstar = torch.zeros_like(b)
    return Wstar, bstar

In [12]:
def updateWbsparse(U, V, W, b, W_tensor2matrix, alpha, rho, tau, update_bias=False):
    """Proximal weight update pulled toward a (pruned) target matrix.

    Minimises over W'
        alpha/2 ||W' - W||^2 + tau/2 ||W' - W_tensor2matrix||^2
        + rho/2 ||U - W' V - b||^2,
    whose solution is
        W' = (alpha W + tau W_tensor2matrix + rho (U - b) V^T)
             ((alpha + tau) I + rho V V^T)^{-1}.

    Args:
        U: target pre-activation, shape (m, n).
        V: layer input, shape (d, n).
        W: current weights, shape (m, d).
        b: current bias, shape (m, 1).
        W_tensor2matrix: target matrix with W's shape (e.g. the pruned weights).
        alpha, rho, tau: penalty weights.
        update_bias: if False (default) the bias is held at zero.

    Returns:
        (Wstar, bstar).
    """
    d, n_samples = V.shape
    eye = torch.eye(d, dtype=V.dtype, device=V.device)
    gram = (alpha + tau) * eye + rho * (V @ V.t())
    target = alpha * W + tau * W_tensor2matrix + rho * ((U - b) @ V.t())
    # gram is symmetric, so target @ inv(gram) == solve(gram, target^T)^T.
    Wstar = torch.linalg.solve(gram, target.t()).t()
    if update_bias:
        residual = torch.sum(U - Wstar @ V, dim=1).reshape(b.shape)
        bstar = (alpha * b + rho * residual) / (rho * n_samples + alpha)
    else:
        # Bias is kept at zero (formerly written as 0 * (...)).
        bstar = torch.zeros_like(b)
    return Wstar, bstar

### Define the proximal operator of the ReLU activation function

In [13]:
def relu_prox(a, b, gamma, d, N):
    """Elementwise proximal step used for the pre-activation (U) updates.

    Candidates per entry:
      x = (a + gamma*b) / (1 + gamma)   -- minimiser of (a - u)^2 + gamma (u - b)^2
      y = min(b, 0)                     -- minimiser of a^2 + gamma (u - b)^2 on u <= 0
    and the selection conditions below choose between x, y and b.
    NOTE(review): this appears to be argmin_u (a - relu(u))^2 + gamma (u - b)^2;
    confirm the branch conditions against the derivation.

    Args:
        a, b: tensors of shape (d, N).
        gamma: positive scalar.
        d, N: output shape.

    Returns:
        Tensor of shape (d, N).
    """
    # Zeros follow b's dtype/device rather than a float32 tensor on a global device.
    zeros = torch.zeros(d, N, dtype=b.dtype, device=b.device)
    slope = gamma - np.sqrt(gamma * (gamma + 1))
    mixed = a + gamma * b
    x = mixed / (1 + gamma)
    y = torch.minimum(b, zeros)

    # Later `where`s override earlier ones, so the order matters.
    val = torch.where(mixed < 0, y, zeros)
    val = torch.where(((mixed >= 0) & (b >= 0)) | ((a * slope <= gamma * b) & (b < 0)), x, val)
    val = torch.where((-a <= gamma * b) & (gamma * b <= a * slope), b, val)
    return val

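### Effective sparsity

A hidden unit whose incoming weights are all zero (or whose outgoing weights are all zero) contributes nothing. Its remaining weights can therefore be counted as pruned as well. `process_weights` propagates such zeros between layers until nothing changes.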

In [14]:
def process_weights(W1, W2, W3):
    """Propagate structural zeros through a 3-layer weight stack.

    Repeats until nothing changes:
      * a hidden unit with an all-zero row in W1 gets its W2 column zeroed;
      * a hidden unit with an all-zero column in W2 gets its W1 row zeroed;
      * a unit with an all-zero row in W2 gets its W3 column zeroed.

    The inputs are not modified; zeroed copies are returned.

    Returns:
        (W1, W2, W3) copies with the propagated zeros.
    """
    layer1, layer2, layer3 = (w.clone() for w in (W1, W2, W3))

    converged = False
    while not converged:
        snapshot = [w.clone() for w in (layer1, layer2, layer3)]

        layer2[:, torch.all(layer1 == 0, dim=1)] = 0
        layer1[torch.all(layer2 == 0, dim=0), :] = 0
        layer3[:, torch.all(layer2 == 0, dim=1)] = 0

        converged = all(
            torch.equal(current, previous)
            for current, previous in zip((layer1, layer2, layer3), snapshot)
        )

    return layer1, layer2, layer3

### Hyperparameter sweep: initialization and training

In [None]:


niter = 1000
n_runs = 10  # random restarts per setup; run r uses seed 10 + 10*r

# Architecture: input -> d1 -> d2 -> output (two hidden ReLU layers).
d0 = X_train.shape[0]  # number of input features (40 for this dataset)
d1 = 300
d2 = 100
d3 = K
# W3 is never thresholded but is counted here, so `true_sparsity` lands
# slightly below the target `sparsity` (e.g. ~0.4976 for 0.5).
total_weights = d0*d1 + d1*d2 + d2*d3

# Labels as numpy arrays for sklearn metrics (constant across iterations).
y_train_np = y_train.cpu().numpy()
y_test_np = y_test.cpu().numpy()

# Hyperparameters per setup. In terms of the objective tracked as `loss2`:
# gamma weights the ||V - relu(U)||^2 and ||V3 - y||^2 terms, rho the
# ||W V + b - U||^2 terms and tau the ||W - W_sparse||^2 terms; alpha weights
# the previous iterate in the W / U / V3 updates.
setups = [
    {'sparsity': 0.5, 'tau': 50, 'gamma': 0.01, 'rho': 0.1, 'alpha': 0.5},
    {'sparsity': 0.75, 'tau': 300, 'gamma': 0.1, 'rho': 0.5, 'alpha': 1},
    {'sparsity': 0.9, 'tau': 200, 'gamma': 100, 'rho': 0.5, 'alpha': 0.5},
    {'sparsity': 0.95, 'tau': 50, 'gamma': 10, 'rho': 5, 'alpha': 0.5},
    {'sparsity': 0.99, 'tau': 500, 'gamma': 100, 'rho': 100, 'alpha': 1}
    #{'sparsity': 0.995, 'tau': 100, 'gamma': 5, 'rho': 1, 'alpha': 1},
    #{'sparsity': 0.995, 'tau': 5, 'gamma': 0.5, 'rho': 0.1, 'alpha': 0.1},
]

# One entry per setup: {'setup': ..., 'results': tensor (n_runs, 9, niter)}.
# `results` is re-created for every setup, so without this list only the last
# setup's metrics would survive the cell.
all_results = []

for setup in setups:
    sparsity = setup['sparsity']
    tau = setup['tau']
    gamma = setup['gamma']
    rho = setup['rho']
    alpha = setup['alpha']

    print(sparsity, tau, gamma, rho, alpha)

    loss1 = np.empty(niter)
    loss2 = np.empty(niter)
    accuracy_train = np.empty(niter)
    accuracy_test = np.empty(niter)
    time1 = np.empty(niter)
    bacc_train = np.empty(niter)
    bacc_test = np.empty(niter)
    true_sparsity = np.empty(niter)
    effective_sparsity = np.empty(niter)

    # results[run, metric, iteration]; metric order: loss1, loss2, acc,
    # val_acc, time, bacc_train, bacc_test, true_sparsity, effective_sparsity.
    results = torch.zeros(n_runs, 9, niter)

    for Out_iter in range(n_runs):
        rank_initial = 700  # NOTE(review): unused in this cell — drop if no later cell reads it
        seed = 10 + 10*Out_iter
        # torch.manual_seed seeds the CPU and every CUDA device.
        torch.manual_seed(seed)

        W1 = 0.01*torch.randn(d1, d0, dtype=torch.float64, device=device)
        b1 = torch.zeros(d1, 1, dtype=torch.float64, device=device)
        W1square = torch.square(W1)
        Threshold = torch.quantile(torch.reshape(W1square, (-1,)), sparsity, interpolation='linear')
        # NOTE(review): `W1sparse = W1` is an alias, not a copy, so the masking
        # also zeroes W1 itself. W1 and W1sparse remain the same tensor for the
        # whole run, and the tau/2*||W1 - W1sparse||^2 term in loss2 is always 0.
        # Use W1.clone() if a separate pruned copy was intended (this would
        # change the algorithm and the reported true_sparsity).
        W1sparse = W1
        W1sparse[W1square < Threshold] = 0

        W2 = 0.01*torch.randn(d2, d1, dtype=torch.float64, device=device)
        b2 = torch.zeros(d2, 1, dtype=torch.float64, device=device)
        W2square = torch.square(W2)
        Threshold = torch.quantile(torch.reshape(W2square, (-1,)), sparsity, interpolation='linear')
        W2sparse = W2  # same aliasing as W1sparse above
        W2sparse[W2square < Threshold] = 0

        # Output layer: kept dense (never thresholded).
        W3 = 0.01*torch.randn(d3, d2, dtype=torch.float64, device=device)
        b3 = torch.zeros(d3, 1, dtype=torch.float64, device=device)

        # Forward pass to initialise pre-activations U_l and activations V_l.
        U1 = torch.addmm(b1.repeat(1, N), W1, X_train)
        V1 = nn.ReLU()(U1)
        U2 = torch.addmm(b2.repeat(1, N), W2, V1)
        V2 = nn.ReLU()(U2)
        U3 = torch.addmm(b3.repeat(1, N), W3, V2)
        V3 = U3

        print('Train on', N, 'samples, validate on', N_test, 'samples')

        for k in range(niter):

            start = time.time()

            # update V3 (network output)
            V3 = (y_one_hot + gamma*U3 + alpha*V3)/(1 + gamma + alpha)

            # update U3
            U3 = (gamma*V3 + rho*(torch.mm(W3,V2) + b3.repeat(1,N)))/(gamma + rho)

            # update W3 and b3 (dense layer, no sparsity target)
            W3, b3 = updateWb_org(U3,V2,W3,b3, alpha,rho)

            # update V2
            V2 = updateV(U2,U3,W3,b3,rho,gamma)

            # update U2
            U2 = relu_prox(V2,(rho*torch.addmm(b2.repeat(1,N), W2, V1) + alpha*U2)/(rho + alpha),(rho + alpha)/gamma,d2,N)

            # update W2 and b2, then re-prune at the `sparsity` quantile of W2**2
            W2, b2 = updateWbsparse(U2,V1,W2,b2,W2sparse, alpha,rho,tau)
            W2square = torch.square(W2)
            Threshold = torch.quantile(torch.reshape(W2square, (-1,)), sparsity, interpolation='linear')
            W2sparse = W2
            W2sparse[W2square < Threshold] = 0

            # update V1
            V1 = updateV(U1,U2,W2,b2,rho,gamma)

            # update U1
            U1 = relu_prox(V1,(rho*torch.addmm(b1.repeat(1,N), W1, X_train) + alpha*U1)/(rho + alpha),(rho + alpha)/gamma,d1,N)

            # update W1 and b1, then re-prune at the `sparsity` quantile of W1**2
            W1, b1 = updateWbsparse(U1,X_train,W1,b1,W1sparse, alpha,rho,tau)
            W1square = torch.square(W1)
            Threshold = torch.quantile(torch.reshape(W1square, (-1,)), sparsity, interpolation='linear')
            W1sparse = W1
            W1sparse[W1square < Threshold] = 0

            # Evaluate the pruned network with a plain forward pass.
            a1_train = nn.ReLU()(torch.addmm(b1.repeat(1, N), W1sparse, X_train)).double()
            a2_train = nn.ReLU()(torch.addmm(b2.repeat(1, N),  W2sparse, a1_train)).double()
            pred = torch.argmax(torch.addmm(b3.repeat(1, N), W3, a2_train), dim=0)

            a1_test = nn.ReLU()(torch.addmm(b1.repeat(1, N_test), W1sparse, X_test)).double()
            a2_test = nn.ReLU()(torch.addmm(b2.repeat(1, N_test),  W2sparse, a1_test)).double()
            pred_test = torch.argmax(torch.addmm(b3.repeat(1, N_test), W3, a2_test), dim=0)

            loss1[k] = gamma/2*torch.pow(torch.dist(V3,y_one_hot,2),2).cpu().numpy()
            loss2[k] = (loss1[k]
                        + rho/2*torch.pow(torch.dist(torch.addmm(b1.repeat(1,N), W1sparse, X_train),U1,2),2).cpu().numpy()
                        + rho/2*torch.pow(torch.dist(torch.addmm(b2.repeat(1,N),  W2sparse, V1),U2,2),2).cpu().numpy()
                        + rho/2*torch.pow(torch.dist(torch.addmm(b3.repeat(1,N), W3, V2),U3,2),2).cpu().numpy()
                        + gamma/2*torch.pow(torch.dist(V1,nn.ReLU()(U1),2),2).cpu().numpy()
                        + gamma/2*torch.pow(torch.dist(V2,nn.ReLU()(U2),2),2).cpu().numpy()
                        + gamma/2*torch.pow(torch.dist(V3,U3,2),2).cpu().numpy()
                        + tau/2*torch.pow(torch.dist(W1,W1sparse,2),2).cpu().numpy()
                        + tau/2*torch.pow(torch.dist(W2,W2sparse,2),2).cpu().numpy())

            # training accuracy
            correct_train = pred == y_train
            accuracy_train[k] = np.mean(correct_train.cpu().numpy())

            # test ("validation") accuracy
            correct_test = pred_test == y_test
            accuracy_test[k] = np.mean(correct_test.cpu().numpy())

            pred_train_np = pred.cpu().numpy()
            pred_test_np = pred_test.cpu().numpy()

            bacc_train[k] = balanced_accuracy_score(y_train_np, pred_train_np)
            bacc_test[k] = balanced_accuracy_score(y_test_np, pred_test_np)

            # wall time of this iteration (updates + evaluation)
            stop = time.time()
            duration = stop - start
            time1[k] = duration

            # fraction of exactly-zero weights across all three layers
            num_zeros_W1 = torch.sum(W1 == 0).item()
            num_zeros_W2 = torch.sum(W2 == 0).item()
            num_zeros_W3 = torch.sum(W3 == 0).item()
            total_zeros_old = num_zeros_W1 + num_zeros_W2 + num_zeros_W3
            true_sparsity[k] = total_zeros_old / total_weights

            # same count after propagating zeros through dead units
            new_W1, new_W2, new_W3 = process_weights(W1, W2, W3)
            num_zeros_W1_new = torch.sum(new_W1 == 0).item()
            num_zeros_W2_new = torch.sum(new_W2 == 0).item()
            num_zeros_W3_new = torch.sum(new_W3 == 0).item()
            total_zeros = num_zeros_W1_new + num_zeros_W2_new + num_zeros_W3_new
            effective_sparsity[k] = total_zeros / total_weights

            # print results
            print('Epoch', k + 1, '/', niter, '\n',
                  '-', 'time:', time1[k], '-', 'sq_loss:', loss1[k], '-', 'tot_loss:', loss2[k],
                  '-', 'acc:', accuracy_train[k], '-', 'val_acc:', accuracy_test[k],
                  '-', 'bacc_train:', bacc_train[k], '-', 'bacc_test:', bacc_test[k],'-', 'true_sparsity:', true_sparsity[k],
                 '-', 'effective_sparsity:', effective_sparsity[k])

        results[Out_iter,0,:] = torch.tensor(loss1)
        results[Out_iter,1,:] = torch.tensor(loss2)
        results[Out_iter,2,:] = torch.tensor(accuracy_train)
        results[Out_iter,3,:] = torch.tensor(accuracy_test)
        results[Out_iter,4,:] = torch.tensor(time1)
        results[Out_iter,5,:] = torch.tensor(bacc_train)
        results[Out_iter,6,:] = torch.tensor(bacc_test)
        results[Out_iter,7,:] = torch.tensor(true_sparsity)
        results[Out_iter,8,:] = torch.tensor(effective_sparsity)

    all_results.append({'setup': setup, 'results': results})



0.5 50 0.01 0.1 0.5
Train on 8114 samples, validate on 44689 samples
Epoch 1 / 1000 
 - time: 1.104189157485962 - sq_loss: 4.6281087727211565 - tot_loss: 19.646084322847685 - acc: 0.7162928272122258 - val_acc: 0.5763610731947459 - bacc_train: 0.7162928272122258 - bacc_test: 0.7762488910720471 - true_sparsity: 0.4976303317535545 - effective_sparsity: 0.4976303317535545
Epoch 2 / 1000 
 - time: 0.06824183464050293 - sq_loss: 0.5654967182919859 - tot_loss: 24.224376075780096 - acc: 0.7509243283214198 - val_acc: 0.6537179171608226 - bacc_train: 0.7509243283214198 - bacc_test: 0.8137876771891939 - true_sparsity: 0.4976303317535545 - effective_sparsity: 0.4976303317535545
Epoch 3 / 1000 
 - time: 0.06369447708129883 - sq_loss: 0.08182189554913562 - tot_loss: 23.894221266564035 - acc: 0.7593049051022923 - val_acc: 0.6690236971066705 - bacc_train: 0.7593049051022924 - bacc_test: 0.8216658647453432 - true_sparsity: 0.4976303317535545 - effective_sparsity: 0.4976303317535545
Epoch 4 / 1000 
 - t