### <span style="color:blue"> *---------------------------------------*</span>
# <span style="color:blue"> *UDV project - SVD based pruning*</span>
### <span style="color:blue"> *---------------------------------------*</span>

In [None]:
# Import libraries

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd

from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.impute import SimpleImputer

import pickle
import time
import matplotlib.pyplot as plt

In [None]:
# Set the global seed for reproducibility.
# public_seed is applied to torch's global RNG here and is re-applied right
# before the train/validation split in the pruning cell.

public_seed = 529
torch.manual_seed(public_seed)
print("Current Seed is {0}".format(public_seed))

In [None]:
# Obtain device
# `device` is later handed to reg_valLoop for every validation pass.
# NOTE(review): get_device is a project helper; presumably it returns the
# torch device (GPU if available, otherwise CPU) - confirm in udvFunctions.udvDevice.

from udvFunctions.udvDevice import get_device
device = get_device()

In [None]:
# Pre-process dataset

from udvFunctions.udvDatasetPreprocessing import HousingPriceDataset
from udvFunctions.udvDatasetPreprocessing import NYCDataset

In [None]:
# Prepare the single-connection layer that carries the weight matrix 'w'

from udvFunctions.udvDiagonalLayer import D_singleConnection   

In [None]:
# Prepare networks

from udvFunctions.udvRegNetworks import UDV_net_1
from udvFunctions.udvRegNetworks import relu_net_1, fc_net_1

In [None]:
# Prepare loss function (optional)

from udvFunctions.udvLoss import UDV_Loss

In [None]:
# Set the validation loop

from udvFunctions.udvRegVal import reg_valLoop

In [None]:
# Other customised functions

from udvFunctions.udvOtherFunctions import store_metrics, take_avg, check_shapes

### <span style="color:blue"> *---------------------------------------*</span>
# <span style="color:blue"> *Set SVD-based pruning test (Without re-train):*</span>
### <span style="color:blue"> *---------------------------------------*</span>

In [None]:
# Dataset selection
#
# dataset_choice: 0 -> housing price dataset (HP); 1 -> NYC taxi duration dataset
# is_full:        1 -> baseline results trained with full-batch size;
#                 any other value -> mini-batch (batch size 128) results

dataset_choice = 1
is_full = 0

# The result folders of the two batch regimes differ only in their parent
# directory and in the batch-size tag embedded in each run name.
if is_full == 1:
    pre_path = "./02_FullBatch/"
    batch_tag = "BSfull"
else:
    pre_path = "./01_Orig/"
    batch_tag = "BS128"

# Each run folder is named "<optimiser>_<learning rate>" + run_suffix.
# HP
if dataset_choice == 0:
    data_path = "./HP_Orig.csv"
    dataset = HousingPriceDataset(data_path)
    print("load the housing price dataset")
    run_prefixes = ["Adam_0.001", "NAdam_0.001", "SGD_0.1", "SGDM_0.1"]
    run_suffix = "_H1_26_H2_5_{0}_E200_S1000"

# NYC
elif dataset_choice == 1:
    data_path = "./NYC_Orig.csv"
    dataset = NYCDataset(data_path)
    print("load the NYC taxi duration dataset")
    run_prefixes = ["Adam_0.0001", "NAdam_0.0001", "SGD_1", "SGDM_3"]
    run_suffix = "_H1_10_H2_2_{0}_E50_S100"

else:
    # Fail here instead of letting a later cell crash with a NameError on
    # `dataset` / `file_list`, which would hide the real cause.
    raise ValueError(
        "Wrong selection on dataset: dataset_choice must be 0 or 1, got {0!r}".format(dataset_choice)
    )

# One baseline result folder per optimiser
file_list = [prefix + run_suffix.format(batch_tag) for prefix in run_prefixes]


In [None]:
# Warning: this cell raises an error if an entry of file_list does not match
# a stored baseline result folder.
# Revise file_list according to the baseline results before running.
from udvFunctions.udvOtherFunctions import CPU_Unpickler


def _load_udv_weights(model, fc1_weight, diag_weight, fc2_weight):
    """Shape-check the given tensors, then install them as the weights of `model`.

    `model` is a UDV_net_1; its fc1, diag1 and fc2 weights are replaced by new
    nn.Parameter objects wrapping the given tensors.
    """
    with torch.no_grad():
        check_shapes(model_weight = model.fc1.weight, list_sample = fc1_weight)
        check_shapes(model_weight = model.diag1.weight, list_sample = diag_weight)
        check_shapes(model_weight = model.fc2.weight, list_sample = fc2_weight)
        model.fc1.weight = nn.Parameter(fc1_weight)
        model.diag1.weight = nn.Parameter(diag_weight)
        model.fc2.weight = nn.Parameter(fc2_weight)


def _accumulate_svd_pruning_losses(u_weight, w_weight, v_weight, losses,
                                   num_input, num_hidden_1, num_output,
                                   loss_fn, val_loader, device):
    """Add the validation losses of one trained network, before and after SVD pruning, to `losses`.

    u_weight, w_weight, v_weight: trained fc1 / diag1 / fc2 weights of one seed.
    losses: list of length num_hidden_1, updated in place. losses[0] receives
        the loss of the unpruned network; losses[k] receives the loss after
        the hidden width has been reduced by k (k = 1 .. num_hidden_1 - 1).
    The remaining arguments are forwarded to UDV_net_1 / reg_valLoop.
    """
    # Fold the single-connection weights into the first layer and transpose,
    # then factorise the product.
    # NOTE(review): assumes w_weight is (1, num_hidden_1) and u_weight is
    # (num_hidden_1, num_input), and that num_input >= num_hidden_1 so the
    # reduced SVD keeps num_hidden_1 singular values - confirm against UDV_net_1.
    uw_product = (torch.mul(w_weight.t(), u_weight)).t()
    U, S, Vh = torch.linalg.svd(uw_product, full_matrices = False)

    # Baseline: original structure with the original (trained) weights
    model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_1, num_output = num_output)
    _load_udv_weights(model,
                      u_weight.clone().detach(),
                      w_weight.clone().detach(),
                      v_weight.clone().detach())
    losses[0] += reg_valLoop(model = model,
                             loss_fn = loss_fn,
                             val_loader = val_loader,
                             device = device)
    del model

    # Gradually decrease the number of hidden neurons and accumulate the
    # validation loss after each truncation (no re-training).
    for num_hidden_new in range(num_hidden_1 - 1, 0, -1):
        U_drop = U.clone().detach()[:, :num_hidden_new]
        S_drop = S.clone().detach()[:num_hidden_new]
        Vh_drop = Vh.clone().detach()[:num_hidden_new, :]

        # Narrower network whose weights are rebuilt from the truncated SVD:
        # fc1 <- kept left singular vectors, diag1 <- kept singular values,
        # fc2 <- original fc2 weights multiplied by the kept right singular vectors.
        model = UDV_net_1(num_input = num_input, num_hidden_1 = num_hidden_new, num_output = num_output)
        _load_udv_weights(model,
                          U_drop.t(),
                          S_drop.unsqueeze(0),
                          (v_weight.clone().detach()) @ (Vh_drop.t()))

        losses[num_hidden_1 - num_hidden_new] += reg_valLoop(model = model,
                                                             loss_fn = loss_fn,
                                                             val_loader = val_loader,
                                                             device = device)
        del model, U_drop, S_drop, Vh_drop


for file_name in file_list:
    # Set path and open file
    store_file_path = pre_path + file_name + '/SingleLayer/results.pkl'
    print("current file is: ", store_file_path)

    # NOTE(security): unpickling can execute arbitrary code. Only load
    # results.pkl files produced by your own baseline runs.
    with open(store_file_path, 'rb') as file:
        variables = CPU_Unpickler(file).load()

    # Read parameters
    num_epochs = variables['num_epochs']
    num_seeds = variables['num_seeds']
    batch_size = variables['batch_size']

    optimiser_name = variables['optimiser_name']
    learning_rate = variables['learning_rate']

    num_input = variables['num_input']
    num_hidden_1 = variables['num_hidden_1']
    num_output = variables['num_output']

    # Trained weights, one entry per seed.
    # model_0: UDV-v1, model_1: UDV-v2, model_2: UDV, model_3: UDV-s
    u_1_V_model_0 = variables['u_1_V_model_0']
    w_1_V_model_0 = variables['w_1_V_model_0']
    v_1_V_model_0 = variables['v_1_V_model_0']

    u_1_V_model_1 = variables['u_1_V_model_1']
    w_1_V_model_1 = variables['w_1_V_model_1']
    v_1_V_model_1 = variables['v_1_V_model_1']

    u_1_M_model_2 = variables['u_1_M_model_2']
    w_1_M_model_2 = variables['w_1_M_model_2']
    u_2_M_model_2 = variables['u_2_M_model_2']

    u_1_M_model_3 = variables['u_1_M_model_3']
    w_1_M_model_3 = variables['w_1_M_model_3']
    u_2_M_model_3 = variables['u_2_M_model_3']

    # Weights of models 4 and 5 are loaded but not used by the pruning below
    l_1_model_4 = variables['l_1_model_4']
    l_2_model_4 = variables['l_2_model_4']

    l_1_model_5 = variables['l_1_model_5']
    l_2_model_5 = variables['l_2_model_5']

    seed_index_list = list(range(num_seeds))

    # Use the same way as the baseline code (make sure the same validation dataset):
    # re-seed, then split, so every file is evaluated on the same validation subset.
    torch.manual_seed(public_seed)
    training_ratio = 0.8
    train_size = int(training_ratio * len(dataset))
    validation_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, validation_size])

    # If the result comes from full batch size experiment, feed each split as one batch.
    # train_dataloader is not used in this cell; it is kept so the names this
    # cell defines stay the same.
    if is_full == 0:
        train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
        val_dataloader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)
    else:
        train_dataloader = DataLoader(train_dataset, batch_size = len(train_dataset), shuffle = True)
        val_dataloader = DataLoader(val_dataset, batch_size = len(val_dataset), shuffle = False)

    # loss_fn = UDV_Loss()    # MSE/2
    loss_fn = nn.MSELoss()  # MSE

    # Re-producible setting
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Create blank space for saving validation loss.
    # Index k holds the loss after removing k hidden neurons (index 0 = unpruned baseline).
    val_losses_m_0 = [0] * num_hidden_1
    val_losses_m_1 = [0] * num_hidden_1
    val_losses_m_2 = [0] * num_hidden_1
    val_losses_m_3 = [0] * num_hidden_1

    # (fc1 weights, diag1 weights, fc2 weights, loss accumulator) per model
    udv_runs = [
        (u_1_V_model_0, w_1_V_model_0, v_1_V_model_0, val_losses_m_0),  # UDV-v1
        (u_1_V_model_1, w_1_V_model_1, v_1_V_model_1, val_losses_m_1),  # UDV-v2
        (u_1_M_model_2, w_1_M_model_2, u_2_M_model_2, val_losses_m_2),  # UDV
        (u_1_M_model_3, w_1_M_model_3, u_2_M_model_3, val_losses_m_3),  # UDV-s
    ]

    # The SVD-based pruning results will be averaged by the same number of seeds as baseline code.
    # Models are processed in order 0..3 within each seed.
    for seed_index in seed_index_list:
        for u_list, w_list, v_list, losses in udv_runs:
            _accumulate_svd_pruning_losses(u_list[seed_index],
                                           w_list[seed_index],
                                           v_list[seed_index],
                                           losses,
                                           num_input = num_input,
                                           num_hidden_1 = num_hidden_1,
                                           num_output = num_output,
                                           loss_fn = loss_fn,
                                           val_loader = val_dataloader,
                                           device = device)

    # Average the accumulated validation loss over the seeds (in place)
    for _, _, _, losses in udv_runs:
        for avg_index in range(len(losses)):
            losses[avg_index] /= num_seeds

    # Convert data to the relative change (a fraction, not multiplied by 100) compared to the baseline
    val_losses_m_0_plot = [(x - val_losses_m_0[0]) / val_losses_m_0[0] for x in val_losses_m_0]
    val_losses_m_1_plot = [(x - val_losses_m_1[0]) / val_losses_m_1[0] for x in val_losses_m_1]
    val_losses_m_2_plot = [(x - val_losses_m_2[0]) / val_losses_m_2[0] for x in val_losses_m_2]
    val_losses_m_3_plot = [(x - val_losses_m_3[0]) / val_losses_m_3[0] for x in val_losses_m_3]

    print("The file has been run: ", store_file_path)

    # One tab-separated row per hidden width:
    # 4 averaged losses, remaining hidden neurons, 4 relative changes.
    rows = []
    for write_index in range(num_hidden_1):
        fields = [val_losses_m_0[write_index],
                  val_losses_m_1[write_index],
                  val_losses_m_2[write_index],
                  val_losses_m_3[write_index],
                  num_hidden_1 - write_index,
                  val_losses_m_0_plot[write_index],
                  val_losses_m_1_plot[write_index],
                  val_losses_m_2_plot[write_index],
                  val_losses_m_3_plot[write_index]]
        rows.append("".join("{0}\t".format(value) for value in fields) + "\n")

    # Append mode creates the file if needed and always writes at its end
    with open('XXXXXX.txt', 'a') as file:
        file.writelines(rows)
        file.write("\n\n\n\n\n")  # blank lines separate the results of different files

print("all done")