# In [None]:
import numpy as np
import pandas as pd
import multiprocessing as mp
from multiprocessing import Pool, freeze_support
from itertools import repeat
import time

def softT(m, t, zeta):
    """
    Soft-threshold operator used to solve the univariate lasso.

    @Input:
        m: float; m_j in the thesis
        t: float; truncation threshold
        zeta: float; penalty factor, the same as the `lambda` in the thesis
    @Output:
        0 when `zeta` is not strictly below |m|; otherwise `-m` moved towards
        zero by `zeta` and divided by `2 * t`.
    """
    # Penalty dominates: the coefficient is thresholded to zero.
    if not (zeta < abs(m)):
        return 0

    # Shrink -m towards zero by zeta; the direction depends on the sign of m.
    if m < 0:
        shrunk = -m - zeta
    else:
        shrunk = -m + zeta
    return shrunk / (2 * t)

def unpack_args(X, Y, beta, beta0, zeta, chunks = 4):
    """
    Split X and Y row-wise into `chunks` consecutive slices and pair each
    slice with the shared parameters, giving one argument tuple per chunk.

    @Input:
        X: 2-d, row-sliceable (np.array or dataframe); n * p
        Y: row-sliceable with the same row order as X
        beta, beta0, zeta: passed through unchanged into every tuple
        chunks: int; number of tuples to build
    @Output:
        list of `chunks` tuples (X_chunk, Y_chunk, beta, beta0, zeta)
    """
    n, _ = X.shape
    # Plain integer arithmetic: the `np.int` alias no longer exists in
    # current NumPy releases. The +1 guarantees that chunks * chunk_size >= n,
    # so every row is covered; trailing chunks may be shorter or even empty.
    chunk_size = n // chunks + 1
    return [
        (
            X[i * chunk_size:(i + 1) * chunk_size],
            Y[i * chunk_size:(i + 1) * chunk_size],
            beta,
            beta0,
            zeta,
        )
        for i in range(chunks)
    ]

def calculate_object_function(X, Y, beta, beta0, zeta):
    """
    Calculate the objective function value: the negative logistic
    log-likelihood scaled by 1 / (2 * n), plus an L1 penalty on beta.

    @Input:
        X: dataframe or 2-d np.array; n * p
        Y: 2-d np.array; n * 1
        beta: 2-d np.array; p*1
        beta0: float
        zeta: float
    @Output:
        fval: the objective value, reduced over all rows by np.sum. This is a
              scalar for np.array inputs; pandas inputs may come back as a
              pandas object depending on how np.sum reduces them
              (TODO confirm against the pandas version in use).
    """
    # The sample count comes from the data passed in, so the function does
    # not depend on any name defined outside it.
    n_samples = np.shape(X)[0]
    linear_predictor = np.matmul(X, beta) + beta0
    # logaddexp(0, x) == log(1 + exp(x)) but does not overflow for large x.
    log_likelihood = np.sum(
        Y * linear_predictor - np.logaddexp(0, linear_predictor)
    )
    penalty = zeta * np.sum(np.abs(beta))
    if n_samples == 0:
        # No rows: the likelihood sum is empty, so only the penalty remains
        # (also avoids a 0 / 0 division).
        return penalty
    fval = (-log_likelihood) / (2 * n_samples) + penalty
    return fval


if __name__ == "__main__":
    # Benchmark script: evaluate the objective once serially and, when
    # `parallel` is enabled, once more over row chunks in a process pool.
    data = pd.read_csv("./toyexample.csv")
    # "Unnamed: 0" is presumably an index column saved with the CSV.
    data = data.drop(["Unnamed: 0"], axis=1)
    X = data.drop('y', axis=1)
    # The response is read from the last column (presumably 'y' - verify
    # against the CSV layout).
    Y = data.iloc[:,-1].values.reshape(-1,1)
    n, p = X.shape
    # Alternative synthetic data set:
    # rand_seed = np.random.RandomState(1234)
    # n, p = 1000, 100
    # X = rand_seed.beta(1,2,size=(n, p))
    # Y = rand_seed.binomial(n=1, p=0.2, size=n).reshape(-1,1)
    beta = np.ones(p).reshape(-1,1)
    beta0 = 1
    zeta = 0.1
    parallel = False
    print("Shape of X is n={}, p={}".format(n,p))

    serial_start = time.time()
    serial_result = calculate_object_function(X, Y, beta, beta0, zeta)
    serial_end = time.time()
    print("Serial took: [{}] s".format(serial_end - serial_start))

    if parallel:
        num_cores = mp.cpu_count()
        # pool_start also covers the time spent building the chunk arguments;
        # map_start covers only pool creation plus the mapped work.
        pool_start = time.time()
        args = unpack_args(X, Y, beta, beta0, zeta, chunks=10)
        map_start = time.time()
        with Pool(processes = num_cores) as pool:
            pool_result = pool.starmap(calculate_object_function, args)
        map_end = time.time()
        pool_end = time.time()
        print("Pool took: [{}] s".format(pool_end - pool_start))
        # Reported only when the parallel run actually happened, instead of
        # probing for `pool_result` through a swallowed NameError.
        print("Difference betwwen two results: {}".format((np.sum(pool_result) - np.sum(serial_result))/n))
        print(map_end - map_start)