In [86]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import pandas as pd

### Global Variables

In [252]:
batch_size = 128
alpha = 100
epoch = 10000
p_hint = 0.9
miss_prob = 0.2

### Data Loading

In [269]:
def load_data(filename):
    return pd.read_csv(filename, sep=",").values

def mask_data(missed_data):
    mask = np.random.binomial(n=1, p=1-miss_prob, size = data.shape)
    """
    n_rows, n_cols = data.shape
    v = np.random.uniform(size=(n_rows, n_cols))
    missed_cols = np.random.choice(n_cols, int(n_cols / 2), replace=False)
    c = np.zeros(n_cols, dtype=bool)
    c[missed_cols] = True
    mask = (v <= 0.2) * c
    mask = 1-mask
    """
    return mask

def normalize(data):
    params = [np.nanmin(data, axis=0), np.nanmax(data, axis=0)]
    data = (data - params[0] + 1e-6)/(params[1] - params[0] + 1e-6)
    return data, params

def denormalize(data, params):
    data_back = data*(params[1] - params[0]) + params[0]
    return data

data = load_data("./Dataset/magic.csv")
n,d = data.shape
data = np.array(data, dtype=np.float64)
n,d = data.shape
mask = mask_data(data)
data_masked = np.copy(data)
data_masked[mask == 0] = np.nan
data_masked, params = normalize(data_masked)

### Build Model

In [270]:
def init_weights(shape):
    value = np.random.normal(scale=1/np.sqrt((shape[0]/2)), size=shape)
    return torch.tensor(value, requires_grad=True)

def generator(X_tilde,M):
    inputs = torch.cat([X_tilde,M], dim=1)
    G_h1 = F.relu(torch.matmul(inputs, G_W1) + G_b1)
    G_h2 = F.relu(torch.matmul(G_h1, G_W2) + G_b2)
    G_prob = torch.sigmoid(torch.matmul(G_h2, G_W3) + G_b3)
    return G_prob

def discriminator(X_hat,H):
    inputs = torch.cat([X_hat,H], dim=1)
    D_h1 = F.relu(torch.matmul(inputs, D_W1) + D_b1)
    D_h2 = F.relu(torch.matmul(D_h1, D_W2) + D_b2)
    D_prob = torch.sigmoid(torch.matmul(D_h2, D_W3) + D_b3)
    return D_prob

def discriminator_loss(X_tilde, M, H):
    X_bar = generator(X_tilde, M)
    X_hat = X_tilde * M + X_bar * (1-M)
    M_hat = discriminator(X_hat, H)
    loss_D = -torch.mean(M * torch.log(M_hat+1e-8) + (1-M) * torch.log(1-M_hat+1e-8))
    return loss_D

def generator_loss(X_tilde, M, H):
    X_bar = generator(X_tilde, M)
    X_hat = X_tilde * M + X_bar * (1-M)
    M_hat = discriminator(X_hat, H)
    loss_G_first = -torch.mean((1-M) * torch.log(M_hat+1e-8))
    loss_G_second = torch.mean((M * X_tilde - M * X_bar)**2)/torch.mean(M)
    loss_G = loss_G_first + alpha*loss_G_second
    return loss_G, loss_G_second

hidden_dim1 = d
hidden_dim2 = d
G_W1 = init_weights([d*2, hidden_dim1])
G_b1 = torch.zeros([hidden_dim1,],requires_grad=True)
G_W2 = init_weights([hidden_dim1, hidden_dim2])
G_b2 = torch.zeros([hidden_dim2,],requires_grad=True)
G_W3 = init_weights([hidden_dim2, d])
G_b3 = torch.zeros([d,],requires_grad=True)
optimizer_G = torch.optim.Adam([G_W1,G_b1,G_W2,G_b2,G_W3,G_b3])

D_W1 = init_weights([d*2, hidden_dim1])
D_b1 = torch.zeros([hidden_dim1,],requires_grad=True)
D_W2 = init_weights([hidden_dim1, hidden_dim2])
D_b2 = torch.zeros([hidden_dim2,],requires_grad=True)
D_W3 = init_weights([hidden_dim2, d])
D_b3 = torch.zeros([d,],requires_grad=True)
optimizer_D = torch.optim.Adam([D_W1,D_b1,D_W2,D_b2,D_W3,D_b3])

### Train and Test Phase

In [271]:
mask = 1-np.isnan(data_masked)

data_masked[np.isnan(data_masked)] = 0.0

for iteration in range(1,epoch+1):
    batch_idx = np.random.choice(n,batch_size)
    X_tilde = data_masked[batch_idx, :]
    M = mask[batch_idx, :]
    Z = np.random.uniform(0, 0.01, size = (batch_size,d))
    B = np.random.binomial(n=1, p=p_hint, size = (batch_size,d))
    H = M*B
    X_tilde = M * X_tilde + (1-M) * Z
    
    X_tilde = torch.tensor(X_tilde, dtype=torch.float64)
    M = torch.tensor(M, dtype=torch.float64)
    H = torch.tensor(H, dtype=torch.float64)
    
    optimizer_D.zero_grad()
    loss_D = discriminator_loss(X_tilde, M, H)
    loss_D.backward()
    optimizer_D.step()
    
    optimizer_G.zero_grad()
    loss_G, loss_G_second = generator_loss(X_tilde, M, H)
    loss_G.backward()
    optimizer_G.step()
    
    if iteration % 500 == 0:
        print("iteration: {}".format(iteration),end='\t')
        print("generator_loss: {:.5}".format(loss_G.item()),end='\n')
        
        
Z = np.random.uniform(0, 0.01, size = (n,d))
X_tilde = mask*data_masked + (1-mask)*Z
imputed = generator(torch.tensor(X_tilde), torch.tensor(mask))
imputed = torch.tensor(data_masked) * torch.tensor(mask) + imputed * (1-torch.tensor(mask))
imputed = imputed.detach().numpy()
rmse = np.sqrt(np.sum(((1-mask)*imputed - (1-mask)*normalize(data)[0])**2)/np.sum(1-mask))
print("Test RMSE: {:.5}".format(rmse))

iteration: 500	generator_loss: 1.5257
iteration: 1000	generator_loss: 0.47477
iteration: 1500	generator_loss: 0.38571
iteration: 2000	generator_loss: 0.29332
iteration: 2500	generator_loss: 0.30029
iteration: 3000	generator_loss: 0.26186
iteration: 3500	generator_loss: 0.22572
iteration: 4000	generator_loss: 0.22413
iteration: 4500	generator_loss: 0.27507
iteration: 5000	generator_loss: 0.21994
iteration: 5500	generator_loss: 0.26575
iteration: 6000	generator_loss: 0.23273
iteration: 6500	generator_loss: 0.22404
iteration: 7000	generator_loss: 0.24043
iteration: 7500	generator_loss: 0.26228
iteration: 8000	generator_loss: 0.21805
iteration: 8500	generator_loss: 0.23087
iteration: 9000	generator_loss: 0.20654
iteration: 9500	generator_loss: 0.23887
iteration: 10000	generator_loss: 0.2106
Test RMSE: 0.22932
