In [1]:
import numpy as np
from scipy.special import expit

np.random.seed(123)
n = 100000
p = 100
x = np.random.normal(size=(n, p))
beta = np.random.normal(size=p)
prob = expit(-0.5 + x.dot(beta))  # p = 1 / (1 + exp(-x * beta))
y = np.random.binomial(1, prob, size=n)
one = np.ones(n).reshape(n,1)
dat = np.hstack((y.reshape(n, 1),one, x))
np.savetxt("logistic_data.txt", dat, fmt="%.8f", delimiter="\t")

In [2]:
import findspark
findspark.init("/Users/xinby/Library/Spark")
from pyspark.sql import SparkSession
# 本地模式
spark = SparkSession.builder.\
    master("local[4]").\
    appName("PySpark RDD").\
    getOrCreate()
sc = spark.sparkContext
sc.setLogLevel("ERROR")
print(spark)
print(sc)

23/05/03 19:03:44 WARN Utils: Your hostname, XinBys-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 10.65.116.79 instead (on interface en0)
23/05/03 19:03:44 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/05/03 19:03:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
<pyspark.sql.session.SparkSession object at 0x7fc681fcd1c0>
<SparkContext master=local[4] appName=PySpark RDD>


In [3]:
# load softplus
test = False
def softplus(x):
    # 此处插入代码
    if (test): 
        print(f"call func: softplus")

    ans = np.where(x>=0,np.log(1+np.exp(-x))+x,np.log(1+np.exp(x)))
    
    if (test): 
        print(f"return.shape:{ans.shape}")
        print(f"end func: softplus \n")
        
    return ans

# load data to rdd
def string_to_vector(line):
    vector = line.split("\t")
    return np.array(vector, dtype=float)

def partition_to_matrix (iterator):
    iterator_vec = map(string_to_vector, iterator)
    data = list(iterator_vec)
    if len(data) < 1:
        matrix = np.array([])
    else:
        matrix = np.vstack(data)
    yield matrix

file = sc.textFile("logistic_data.txt")

data_partition = file.mapPartitions(partition_to_matrix)
data_partition_nonempty = data_partition.filter(lambda x: x.shape[0] > 0)
data_partition_nonempty.cache()
data_partition_nonempty.count()


                                                                                

4

In [15]:
def compute_stats(part_mat, beta_old):
    # 提取 X 和 y
    y = part_mat[:, 0]
    x = part_mat[:,1:]

    # X * beta
    xb = x.dot(beta_old)

    # rho(X * beta)
    prob = expit(xb)

    # W 的对角线元素
    w = prob * (1.0 - prob) + 1e-6

    # X'W，数组广播操作，避免生成完整的 W
    xtw = x.transpose() * w

    # X'WX
    xtwx = xtw.dot(x)

    # X'Wz
    z = xb + (y - prob) / w

    xtwz = xtw.dot(z)

    # 目标函数：sum(y * log(prob) + (1 - y) * log(1 - prob))
    objfn = -np.sum(y * np.log(prob + 1e-8) + (1.0 - y) * np.log(1.0 - prob + 1e-8))
    return xtwx, xtwz, objfn

In [16]:


# iter computation

p = data_partition_nonempty.first().shape[1]-1 #subtract y
beta_hat = np.zeros(p)#initialization
object_values = [] #init

MaxIteration = 100 #iter settings
epsilon = 1e-6 #iter settings

for i in range(MaxIteration):
    if (test):
        print(f"start iter:{i}")
    # 完整数据的 X'WX 和 X'Wz 是各分区的加和
    xtwx, xtwz, objfn = data_partition_nonempty.\
        map(lambda part: compute_stats(part, beta_hat)).\
        reduce(lambda x, y: (x[0] + y[0], x[1] + y[1], x[2] + y[2]))
    # 计算新 beta
    beta_new = np.linalg.solve(xtwx, xtwz)
    if (test):
        print(f"bn{beta_new.shape}")
    # 计算 beta 的变化
    resid = np.linalg.norm(beta_new - beta_hat)
    if (np.mod(i,5)==0 or i==MaxIteration-1):
        print(f"Iteration {i}, objfn = {objfn}, resid = {resid}")
    object_values.append(objfn)
    # 如果 beta 几乎不再变化，退出循环
    if resid < epsilon:
        print(f"Iteration {i}, objfn = {objfn}, resid = {resid}")
        break
    # 更新 beta
    beta_hat = beta_new


Iteration 0, objfn = 69314.71605599453, resid = 1.569852188525537
Iteration 5, objfn = 12424.603668247559, resid = 1.3106763115298585
Iteration 9, objfn = 12248.19519231431, resid = 6.056474674176223e-08


In [18]:
x = dat[:,1:]
print(beta_hat)
xbetahat = x.dot(beta_hat)
probhat = expit(xbetahat)
print(probhat)

[-0.52097596 -0.59176196 -1.10420803  1.1546564   0.67280846  0.63931464
 -1.68220071  0.86041452 -0.69834251  1.22446393 -0.21062583 -0.60143279
  0.44213217  1.57506739  0.93504494  0.34283382 -0.63115954 -0.16737269
  1.03564847  0.9885046  -0.21736314  0.26608044 -1.9546613   0.93399147
 -0.44097986 -1.32382408 -1.06955406 -0.93365571 -0.47879284 -0.40975603
  0.13045673  0.72407648  0.43211279  0.78064486  0.12355277 -0.20116152
  1.34425232 -0.8467126  -1.57109621 -0.02174217  0.04202859  0.01757209
 -0.33735364 -1.74371023 -1.32742343 -1.60007017 -1.28377679  0.93921256
  0.93254572 -0.84857908 -1.08700601 -0.65543369 -1.52634259 -1.46037052
 -1.41541159  0.06736292 -2.06484347  0.25380448 -1.44377969 -0.45925857
 -1.12439824  1.24274755  0.72115264  0.46168381 -0.20588244  1.19789065
 -0.17370143  0.42621562  0.49622293 -0.29831759 -0.93076696 -2.52159476
  1.21260763 -0.40380623  0.41771335  0.75208194  1.5969521  -0.36537673
  0.40531527 -1.43161884 -0.46412512 -0.29281007 -1