### Using impyute
python package: https://pypi.org/project/impyute/
doc: https://impyute.readthedocs.io/en/master/

In [2]:
!pip install impyute

You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [4]:
import impyute as impy

In [5]:
data = np.array([[np.nan, 6, 3, 4, np.nan], [1, np.nan, 4, np.nan, np.nan], 
          [3, 7, 5, 6, 7], [2, 8, 2, np.nan, np.nan]])
data

array([[nan,  6.,  3.,  4., nan],
       [ 1., nan,  4., nan, nan],
       [ 3.,  7.,  5.,  6.,  7.],
       [ 2.,  8.,  2., nan, nan]])

#### 找出na的位置

In [7]:
def nan_indices(data):
    """ Finds the indices of all missing values.
    Parameters
    ----------
    data: numpy.ndarray
    Returns
    -------
    List of tuples
        Indices of all missing values in tuple format; (i, j)
    """
    return np.argwhere(np.isnan(data))

In [8]:
np.argwhere(np.isnan(data))

array([[0, 0],
       [0, 4],
       [1, 1],
       [1, 3],
       [1, 4],
       [3, 3],
       [3, 4]])

In [10]:
nan_xy

array([[0, 0],
       [0, 4],
       [1, 1],
       [1, 3],
       [1, 4],
       [3, 3],
       [3, 4]])

In [38]:
def em(data, loops=50):
    """ Imputes given data using expectation maximization.
    E-step: Calculates the expected complete data log likelihood ratio.
    M-step: Finds the parameters that maximize the log likelihood of the
    complete data.
    Parameters
    ----------
    data: numpy.nd.array
        Data to impute.
    loops: int
        Number of em iterations to run before breaking.
    inplace: boolean
        If True, operate on the numpy array reference
    Returns
    -------
    numpy.nd.array
        Imputed data.
    """
    nan_xy = nan_indices(data)
    for x_i, y_i in nan_xy:
        col = data[:, int(y_i)]
        mu = col[~np.isnan(col)].mean()
        std = col[~np.isnan(col)].std()
        col[x_i] = np.random.normal(loc=mu, scale=std)
        previous, i = 1, 1
        for i in range(loops):
            # Expectation
            mu = col[~np.isnan(col)].mean()
            std = col[~np.isnan(col)].std()
            # Maximization
            col[x_i] = np.random.normal(loc=mu, scale=std)
            # Break out of loop if likelihood doesn't change at least 10%
            # and has run at least 5 times
            delta = (col[x_i]-previous)/previous
            if i > 5 and delta < 0.1:
                data[x_i][y_i] = col[x_i]
                break
            data[x_i][y_i] = col[x_i]
            previous = col[x_i]
    return data

In [5]:
import numpy as np

data = np.array([[np.nan, 6, 3, 4, np.nan], [1, np.nan, 4, np.nan, np.nan], 
          [3, 7, 5, 6, 7], [2, 8, 2, np.nan, np.nan]])

# get NaN position
na_position = np.argwhere(np.isnan(data))

# get dims
nrow = data.shape[0]
ncol = data.shape[1]

# get whether na matrix
C = np.isnan(data).astype(int)
O = (np.isnan(data) == False).astype(int)

# get row mean
row_mean = np.nanmean(data, axis=1)

# get col mean after filling NaN with row mean
data2 = data.copy()
data2[np.isnan(data2)] = 0
energy_table = C * np.expand_dims(row_mean, axis=1) + O * data2
col_mean = np.nanmean(energy_table, axis=0)

# get global mean
G = np.sum(energy_table) / np.multiply(*energy_table.shape)

# update function
def get_new_value(pos, energy_table):
    row_part = ncol * (nrow*col_mean[pos[1]]-energy_table[pos[0], pos[1]])
    col_part = nrow * (ncol*row_mean[pos[0]]-energy_table[pos[0], pos[1]])
    all_part = nrow * ncol * G - energy_table[pos[0], pos[1]]
    V = ((row_part + col_part) - all_part) / ((nrow-1) * (ncol-1))
    return V

# calculate new NA element
converge_num = 0
energy_table_iter = np.copy(energy_table)
for i in range(500):
    error = 0
    # 每一次更新完table算好一版给下一代用
    row_mean = np.nanmean(data, axis=1)
    col_mean = np.nanmean(energy_table, axis=0)
    G = np.sum(energy_table) / np.multiply(*energy_table.shape)
    
    for pos in na_position:
        V = get_new_value(pos, energy_table_iter)
        error += np.abs(V - energy_table_iter[pos[0], pos[1]])
        energy_table[pos[0], pos[1]] = V
    energy_table_iter = np.copy(energy_table)
    break
    if error < 0.000001:
        converge_num += 1
        print("Error didn't change for", converge_num, "time.")
    if converge_num >= 5:
        print("Error don't change anymore, converge condition met.")
        break

In [42]:
energy_table

array([[2.80833336, 6.        , 3.        , 4.        , 4.68333333],
       [1.        , 4.26666664, 4.        , 2.51666667, 2.84999999],
       [3.        , 7.        , 5.        , 6.        , 7.        ],
       [2.        , 8.        , 2.        , 4.01666667, 4.34999999]])

In [4]:
energy_table

array([[1.79166667, 6.        , 3.        , 4.        , 4.91666667],
       [1.        , 5.44444444, 4.        , 2.52777778, 3.08333333],
       [3.        , 7.        , 5.        , 6.        , 7.        ],
       [2.        , 8.        , 2.        , 4.02777778, 4.58333333]])