# In [1]:
import numpy as np
import torch
from geomloss import SamplesLoss
from Utils import nanmean, MAE, RMSE
from CMI import compute_all_cmi_methods
import logging

# In [2]:
class SinkhornImputation_CMI():
    """Impute missing entries by minimising a Sinkhorn loss between random batches.

    The missing entries of the input are treated as free parameters and
    optimised so that randomly drawn row batches of the filled matrix are
    close under ``geomloss.SamplesLoss("sinkhorn")``.  Three conditional
    mutual information (CMI) estimates are computed at every iteration and
    recorded next to the Sinkhorn loss.

    NOTE(review): the CMI estimates are computed on a detached NumPy copy and
    re-wrapped in a fresh leaf tensor, so they contribute **no gradient** to
    the imputed values; only the Sinkhorn term drives the optimisation.  The
    "total loss" is therefore a monitoring quantity.  Making the CMI term
    differentiable would require a torch implementation of the estimator.

    Parameters
    ----------
    eps : float
        Passed to ``SamplesLoss`` as ``blur``.
    lr : float
        Learning rate handed to the optimiser.
    opt : callable
        Optimiser factory, called as ``opt([imps], lr=lr)``.
    niter : int
        Number of optimisation iterations.
    highest_lamda_cmi : float
        Upper bound of the CMI weight schedule (spelling kept for
        backward compatibility).
    batchsize : int
        Number of rows sampled (without replacement) for each batch.
    n_pairs : int
        Number of batch pairs whose Sinkhorn losses are summed per iteration.
    noise : float
        Scale of the Gaussian noise added to the initial imputations.
    scaling : float
        Passed to ``SamplesLoss`` as ``scaling``.
    lambda_cmi : float
        Initial CMI weight.  ``fit_transform`` overwrites this attribute at
        every iteration with ``min(highest_lamda_cmi, i / 100)``.
    """

    def __init__(self,
                 eps=0.01,
                 lr=1e-2,
                 opt=torch.optim.RMSprop,
                 niter=2000,
                 highest_lamda_cmi=100,
                 batchsize=128,
                 n_pairs=1,
                 noise=0.1,
                 scaling=.9,
                 lambda_cmi=0.1):
        self.eps = eps
        self.lr = lr
        self.opt = opt
        self.niter = niter
        self.batchsize = batchsize
        self.n_pairs = n_pairs
        self.noise = noise
        self.sk = SamplesLoss("sinkhorn", p=2, blur=eps, scaling=scaling, backend="tensorized")
        self.lambda_cmi = lambda_cmi
        self.highest_lamda_cmi = highest_lamda_cmi

    def fit_transform(self, X, verbose=True, report_interval=500, X_true=None, X_cols=None, Y_cols=None, Z_cols=None, bucket_specs=None,
                      encoder=None, discrete_columns=None, continuous_columns=None):
        """Optimise the missing entries of ``X`` and return the filled matrix.

        Parameters
        ----------
        X : torch.Tensor
            2-D tensor whose missing entries are NaN.  It is cloned, so the
            caller's tensor is not modified.
        verbose : bool
            When true, log progress every ``report_interval`` iterations.
        report_interval : int
            Logging period in iterations.
        X_true : torch.Tensor, optional
            Ground truth; when given, MAE and RMSE are tracked per iteration.
        X_cols, Y_cols, Z_cols : sequence
            Column specifications; only the first element of each is used and
            forwarded as a triple to ``compute_all_cmi_methods``.  Required.
        bucket_specs : optional
            Accepted for interface compatibility; not used by this method.
        encoder, discrete_columns, continuous_columns : optional
            Forwarded unchanged to ``compute_all_cmi_methods``.

        Returns
        -------
        tuple
            ``(X_filled, history)`` or, when ``X_true`` is given,
            ``(X_filled, maes, rmses, history)``.  ``X_filled`` is detached
            and reflects the imputations after the last optimiser step.
            ``maes`` and ``rmses`` are lists of three arrays of length
            ``niter`` (one per CMI estimate; entries stay zero for iterations
            that were not run).  ``history`` holds the per-iteration
            ``"losses"``, ``"cmi_values"``, ``"lambda_cmi"`` and ``"sinkhorn"``.

        Raises
        ------
        ValueError
            If any of ``X_cols``, ``Y_cols`` or ``Z_cols`` is missing or empty.
        """
        # Fail early with a clear message instead of a TypeError on `None[0]`
        # deep inside the optimisation loop.
        for arg_name, cols in (("X_cols", X_cols), ("Y_cols", Y_cols), ("Z_cols", Z_cols)):
            if cols is None or len(cols) == 0:
                raise ValueError(
                    f"{arg_name} must be a non-empty sequence; its first element is used for the CMI computation"
                )
        cmi_columns = (X_cols[0], Y_cols[0], Z_cols[0])

        # Fixed seeds make the initialisation and batch sampling reproducible.
        # NOTE: this resets the *global* torch and NumPy random state.
        torch.manual_seed(42)
        np.random.seed(42)

        X = X.clone()
        n, _ = X.shape

        # Sampling without replacement needs batchsize <= n; fall back to half
        # of the rows (so the two batches of a pair can differ) instead of
        # letting np.random.choice raise.
        batchsize = self.batchsize
        if batchsize > n:
            batchsize = max(1, n // 2)
            logging.warning("batchsize %d exceeds the number of rows %d; using %d instead",
                            self.batchsize, n, batchsize)

        history = {
            "losses": [[], [], []],
            "cmi_values": [[], [], []],
            "lambda_cmi": [],
            "sinkhorn": [],
        }

        mask = torch.isnan(X).double()
        missing = mask.bool()
        # Start from noisy values around nanmean(X, 0)
        # (presumably the per-column mean ignoring NaNs - confirm in Utils).
        imps = (self.noise * torch.randn(mask.shape).double() + nanmean(X, 0))[missing]
        imps.requires_grad = True

        optimizer = self.opt([imps], lr=self.lr)

        if X_true is not None:
            maes = [np.zeros(self.niter) for _ in range(3)]
            rmses = [np.zeros(self.niter) for _ in range(3)]

        for i in range(self.niter):
            X_filled = X.detach().clone()
            X_filled[missing] = imps
            sk_loss = 0

            for _ in range(self.n_pairs):
                idx1 = np.random.choice(n, batchsize, replace=False)
                idx2 = np.random.choice(n, batchsize, replace=False)
                sk_loss += self.sk(X_filled[idx1], X_filled[idx2])

            # A non-finite loss would poison `imps` through the optimiser and
            # every later iteration; stop and keep the last finite values.
            if not torch.isfinite(sk_loss).all():
                logging.warning("Non-finite Sinkhorn loss at iteration %d; stopping early", i)
                break

            history["sinkhorn"].append(sk_loss.item())
            # Linear warm-up of the CMI weight, capped at highest_lamda_cmi.
            self.lambda_cmi = min(self.highest_lamda_cmi, i / 100.0)

            # compute 3 CMI values
            X_np = X_filled.detach().cpu().numpy()
            cmi1, cmi2, cmi3 = compute_all_cmi_methods(X_np, cmi_columns, encoder, discrete_columns, continuous_columns)

            if X_true is not None:
                # The metrics depend only on X_filled, so compute them once and
                # store the same value in all three slots.
                mae_i = MAE(X_filled, X_true, mask).item()
                rmse_i = RMSE(X_filled, X_true, mask).item()
                for idx in range(3):
                    maes[idx][i] = mae_i
                    rmses[idx][i] = rmse_i

            for idx, cmi_val in enumerate([cmi1, cmi2, cmi3]):
                # NOTE(review): the CMI term is a fresh leaf tensor, so the
                # gradient reaching `imps` comes from sk_loss alone.
                total_loss = sk_loss + self.lambda_cmi * torch.tensor(cmi_val, dtype=torch.float64, requires_grad=True)
                history["losses"][idx].append(total_loss.item())
                history["cmi_values"][idx].append(cmi_val)

                # Only the loss built from the first CMI estimate is optimised.
                if idx == 0:
                    optimizer.zero_grad()
                    total_loss.backward()
                    optimizer.step()

            history["lambda_cmi"].append(self.lambda_cmi)

            if verbose and (i % report_interval == 0):
                logging.info(f"Iteration {i}: Sinkhorn={sk_loss.item():.4f}, CMI1={cmi1:.4f}, CMI2={cmi2:.4f}, CMI3={cmi3:.4f}")

        # Rebuild the result from the final imputations: the X_filled left over
        # from the loop predates the last optimiser step, still carries an
        # autograd graph, and does not exist at all when niter == 0.
        X_filled = X.detach().clone()
        X_filled[missing] = imps.detach()

        if X_true is not None:
            return X_filled, maes, rmses, history
        else:
            return X_filled, history