In [48]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from utils import get_data, get_probabilities, estimate_q_Z_given_A, get_probabilities_one_hot
from sklearn.decomposition import NMF  # Placeholder for volmin factorization
from volmin_nmf import *
from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.neural_network import MLPClassifier
import matplotlib.pyplot as plt
import itertools
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

In [49]:
import torch
import numpy as np
from data import get_dataset, get_dataset_new
from scipy.linalg import sqrtm, det, pinv
from sklearn.decomposition import NMF
import matplotlib.pyplot as plt

def get_data(theta_a_z, theta_y_a, theta_y_w, p_source, p_target, total):
    """
    Generates datasets for source and target distributions.
    
    Args:
        theta_xz (torch.Tensor): Transformation matrix for X given Z.
        theta_yx (torch.Tensor): Transformation matrix for Y given X.
        theta_yw (torch.Tensor): Transformation matrix for Y given W.
        p_source (float): Probability parameter for the source distribution.
        p_target (float): Probability parameter for the target distribution.
        total (int): Total number of samples to generate.

    Returns:
        tuple: Two tuples containing the source and target datasets respectively.
    """
    # Source distribution data
    U_source, Z_source, W_source, X_source, Y_source = get_dataset(theta_a_z,theta_y_a,theta_y_w, p_source, total)
    
    # Target distribution data
    U_target, Z_target, W_target, X_target, Y_target = get_dataset(theta_a_z,theta_y_a,theta_y_w, p_target, total)
    
    return (Z_source.numpy(), U_source.numpy(), W_source.numpy(), X_source.numpy(), Y_source.numpy()), \
           (Z_target.numpy(), U_target.numpy(), W_target.numpy(), X_target.numpy(), Y_target.numpy())

def get_data_new(theta_a_z, theta_y_a, theta_y_w, theta_y_epsilon, theta_a_epsilon, p_source, p_target, total):
    """
    Generates datasets for source and target distributions.
    
    Args:
        theta_xz (torch.Tensor): Transformation matrix for X given Z.
        theta_yx (torch.Tensor): Transformation matrix for Y given X.
        theta_yw (torch.Tensor): Transformation matrix for Y given W.
        p_source (float): Probability parameter for the source distribution.
        p_target (float): Probability parameter for the target distribution.
        total (int): Total number of samples to generate.

    Returns:
        tuple: Two tuples containing the source and target datasets respectively.
    """
    # Source distribution data
    U_source, Z_source, W_source, X_source, Y_source = get_dataset_new(theta_a_z, theta_y_a, theta_y_w, theta_y_epsilon, theta_a_epsilon, p_source, total)
    
    # Target distribution data
    U_target, Z_target, W_target, X_target, Y_target = get_dataset_new(theta_a_z, theta_y_a, theta_y_w, theta_y_epsilon, theta_a_epsilon, p_target, total)
    
    return (Z_source.numpy(), U_source.numpy(), W_source.numpy(), X_source.numpy(), Y_source.numpy()), \
           (Z_target.numpy(), U_target.numpy(), W_target.numpy(), X_target.numpy(), Y_target.numpy())


def get_probabilities(model, Z, A):
    """
    Computes the softmax probabilities from the trained model.
    
    Args:
        model (model): Trained model.
        Z (numpy.ndarray): Feature matrix Z.
        A (numpy.ndarray): Feature matrix A.

    Returns:
        numpy.ndarray: Probability matrix reshaped to (|Z|, |A|, |Y|) or (|Z|, |A|, |W|).
    """
    num_Z = Z.shape[1]
    num_A = A.shape[1]
    num_classes = model.linear.out_features

    # Generate all possible one-hot vectors for Z
    possible_Z = np.eye(num_Z)
    possible_A = np.eye(num_A)
    
    probabilities = []
    
    for z in possible_Z:
        for a in possible_A:
            ZA = np.hstack((z.reshape(1, -1), a.reshape(1, -1)))
            ZA_tensor = torch.tensor(ZA, dtype=torch.float32)
            print(ZA_tensor)
            with torch.no_grad():
                logits = model(ZA_tensor)
                print(logits)
                prob_1 = torch.sigmoid(logits).numpy()[0][0]  # Probability of class 1
                prob_0 = 1 - prob_1  # Probability of class 0
                print([prob_0, prob_1])
                probabilities.append([prob_0, prob_1])
    
    probabilities = np.array(probabilities).reshape((num_Z, num_A, 2))
    
    return probabilities

def get_probabilities_one_hot(model, Z, A):
    """
    Computes the softmax probabilities from the trained model.
    
    Args:
        model (model): Trained model.
        Z (numpy.ndarray): Feature matrix Z.
        A (numpy.ndarray): Feature matrix A.

    Returns:
        numpy.ndarray: Probability matrix reshaped to (|Z|, |A|, |Y|) or (|Z|, |A|, |W|).
    """
    num_Z = Z.shape[1]
    num_A = A.shape[1]
    num_classes = model.linear.out_features

    # Generate all possible one-hot vectors for Z
    possible_Z = np.eye(num_Z)
    possible_A = np.eye(num_A)
    
    probabilities = []
    
    for z in possible_Z:
        for a in possible_A:
            ZA = np.hstack((z.reshape(1, -1), a.reshape(1, -1)))
            ZA_tensor = torch.tensor(ZA, dtype=torch.float32)
            with torch.no_grad():
                probs = model(ZA_tensor)
                probs = torch.softmax(probs, dim=1).numpy()
                probabilities.append(probs[0])
    
    probabilities = np.array(probabilities).reshape((num_Z, num_A, num_classes))
    
    return probabilities

def get_probabilities_one_hot_nn(model, Z, A):
    """
    Computes the softmax probabilities from the trained model.
    
    Args:
        model (model): Trained model.
        Z (numpy.ndarray): Feature matrix Z.
        A (numpy.ndarray): Feature matrix A.

    Returns:
        numpy.ndarray: Probability matrix reshaped to (|Z|, |A|, |Y|) or (|Z|, |A|, |W|).
    """
    num_Z = Z.shape[1]
    num_A = A.shape[1]
    num_classes = model.fc3.out_features

    # Generate all possible one-hot vectors for Z
    possible_Z = np.eye(num_Z)
    possible_A = np.eye(num_A)
    
    probabilities = []
    
    for z in possible_Z:
        for a in possible_A:
            ZA = np.hstack((z.reshape(1, -1), a.reshape(1, -1)))
            ZA_tensor = torch.tensor(ZA, dtype=torch.float32)
            with torch.no_grad():
                probs = model(ZA_tensor)
                probs = torch.softmax(probs, dim=1).numpy()
                probabilities.append(probs[0])
    
    probabilities = np.array(probabilities).reshape((num_Z, num_A, num_classes))
    
    return probabilities

def estimate_q_Z_given_A(model, A, num_classes_Z, num_features_A):
    """
    Estimate q(Z|a) using the given model.
    
    Args:
        model (model): Trained model.
        A (numpy.ndarray): Feature matrix A.
        num_classes_Z (int): Number of classes for Z.
        num_features_A (int): Number of features for A.

    Returns:
        numpy.ndarray: Estimated q(Z|a) matrix.
    """
    num_A = A.shape[1]
    num_classes = model.linear.out_features

    # Generate all possible one-hot vectors for A
    possible_A = np.eye(num_A)
    
    probabilities = []
    
    for a in possible_A:
        A_sample = a.reshape(1, -1)
        A_tensor = torch.tensor(A_sample, dtype=torch.float32)
        with torch.no_grad():
            probs = model(A_tensor)
            probs = torch.softmax(probs, dim=1).numpy()
            probabilities.append(probs)
    
    probabilities = np.array(probabilities).reshape((num_features_A, num_classes_Z))
    return probabilities

# linear solve for Q(Epsilon | Z, A) using linalg.lstsq
def solve_for_q_epsilon_given_ZA(p_W_given_epsilon, q_W_given_ZA):
    print("p_W_given_epsilon shape before transpose:", p_W_given_epsilon.shape)  # Debug
    print("q_W_given_ZA shape before transpose:", q_W_given_ZA.shape)  # Debug
    p_W_given_epsilon = p_W_given_epsilon.T
    q_W_given_ZA = q_W_given_ZA.T
    print("p_W_given_epsilon shape after transpose:", p_W_given_epsilon.shape)  # Debug
    print("q_W_given_ZA shape after transpose:", q_W_given_ZA.shape)  # Debug
    
    # Check dimensions and adjust if necessary
    num_ZA = p_W_given_epsilon.shape[1]
    num_epsilon = p_W_given_epsilon.shape[0]
    num_W = q_W_given_ZA.shape[1]
    
    if q_W_given_ZA.shape[0] != num_ZA:
        raise ValueError("Dimensions of q_W_given_ZA and p_W_given_epsilon do not match after transpose.")
    
    # Solve by least squares
    q_epsilon_given_Z_and_A, _, _, _ = np.linalg.lstsq(p_W_given_epsilon, q_W_given_ZA, rcond=None)
    print("q_epsilon_given_Z_and_A shape after lstsq:", q_epsilon_given_Z_and_A.shape)  # Debug
    return q_epsilon_given_Z_and_A.T


def volume_regularized_nmf(B, n_components, w_vol=0.1, delta=1e-8, n_iter=1000, err_cut=1e-8):
    """
    Volume-regularized NMF implementation in Python.
    
    Args:
        B (np.ndarray): Input matrix to factorize.
        n_components (int): Number of components for factorization.
        w_vol (float): Weight for volume regularization.
        delta (float): Regularization term for logdet.
        n_iter (int): Number of iterations.
        err_cut (float): Convergence criterion.
        
    Returns:
        tuple: Factorized matrices C, R.
    """
    #np.random.seed(0)  # For reproducibility

    # Initialize C and R using standard NMF
    nmf = NMF(n_components=n_components, init='random', random_state=0, max_iter=500)
    C = nmf.fit_transform(B)
    R = nmf.components_

    for iteration in range(n_iter):
        # Update R using volume regularization
        err_prev = np.linalg.norm(B - np.dot(C, R), 'fro')**2

        # Update R with regularization
        R = np.maximum(np.dot(pinv(C), B), 0)  # Standard NMF update
        R = R - w_vol * (np.dot(R, R.T) + delta * np.eye(R.shape[0]))  # Volume regularization update
        R = np.maximum(R, 0)  # Ensure non-negativity

        # Update C to satisfy the simplex constraint
        C = np.maximum(np.dot(B, pinv(R)), 0)
        C = C / C.sum(axis=0)

        err_post = np.linalg.norm(B - np.dot(C, R), 'fro')**2

        # Check for convergence
        if np.abs(err_prev - err_post) / err_prev < err_cut:
            break

    return C, R

def volnmf_logdet(C, X, R, R_constraint="pos", majorate=False, extrapolate=True, qmax=100,
                  w_vol=1e-1, delta=1, err_cut=1e-3, n_iter=1000):
    # Initial assignments
    W = np.transpose(R)
    W_update = W.copy()
    H = np.transpose(C)
    FM = np.linalg.inv(np.transpose(W) @ W + delta * np.eye(W.shape[0]))

    # Iteration variables
    iter = 1
    err = 1e5
    q = [1, (1 + np.sqrt(5)) / 5]

    # Main iteration loop
    while err > err_cut and iter < n_iter:
        W_prev = W.copy()
        
        if majorate:
            Y = W_prev
            FM = np.linalg.inv(np.transpose(Y) @ Y + delta * np.eye(Y.shape[0]))
        
        if R_constraint == "pos":
            Lip = np.sqrt(np.sum((H @ H.T + w_vol * FM) ** 2))
            gradF = W_update @ (H @ H.T + w_vol * FM) - X.T @ H
            W = W_update - gradF / Lip
            W[W < 0] = 0
        else:
            W = X.T @ H.T @ np.linalg.inv(H @ H.T + w_vol * FM)

        if extrapolate:
            extr = (q[iter - 1] - 1) / q[iter]
            W_update = W + extr * (W - W_prev)
        else:
            W_update = W

        # Error calculation and iteration increment
        err = np.sum((W - W_prev) ** 2) / np.sum(W ** 2)
        iter += 1
        q.append(min(qmax, (1 + np.sqrt(1 + 4 * q[iter - 1] ** 2)) / 2))

    return W.T

def volnmf_estimate(B, C, R, Q, domain="covariance", volf='logdet', R_majorate=False,
                    wvol=None, delta=1e-8, n_iter=1000, err_cut=1e-8,
                    vol_iter=100, c_iter=100,
                    extrapolate=True, accelerate=True, acc_C=4/5, acc_R=3/4,
                    C_constraint="col", C_bound=1, R_constraint="pos",
                    verbose=True, record=100, Ctrue=None, mutation_run=False):
    
    # Initialization
    iter = 1
    err = 1e5
    rvol = []
    aff_mean = []
    info_record = []
    eigens = 1
    R_update = R.copy()
    C_update = C.copy()
    tot_update_prev = 0
    tot_update = 0

    while iter < n_iter and err > err_cut:
        # Domain check
        if domain == "covariance":
            X = np.dot(B, Q)
        else:
            X = B

        # Update R
        err_prev = np.sum((X - np.dot(C_update, R))**2)
        if volf == "logdet":
            vol_prev = np.log(np.linalg.det(np.dot(R, R.T) + delta * np.eye(R.shape[0])))
        elif volf == "det":
            vol_prev = np.linalg.det(np.dot(R, R.T))
        R_prev = R.copy()

        # Update R based on the volume function
        if volf == "logdet":
            R = volnmf_logdet(C_update, X, R_update, R_constraint=R_constraint, extrapolate=extrapolate, majorate=R_majorate,
                              w_vol=wvol, delta=delta, err_cut=1e-100, n_iter=vol_iter)
        elif volf == "det":
            R = volnmf_det(C_update, X, R_update, posit=False, w_vol=wvol, eigen_cut=1e-20, err_cut=1e-100, n_iter=vol_iter)

        ### Post-update calculations
        err_post = np.sum((X - np.dot(C_update, R))**2)
        if volf == "logdet":
            vol_post = np.log(np.linalg.det(np.dot(R, R.T) + delta * np.eye(R.shape[0])))
        elif volf == "det":
            vol_post = np.linalg.det(np.dot(R, R.T))
        rvol.append(vol_post)

        ### Update C
        C_prev = C.copy()
        if C_constraint == "col":
            C = volnmf_simplex_col(X, R, C_prev=C_update, bound=C_bound, extrapolate=extrapolate, err_cut=1e-100, n_iter=c_iter)
        else:
            C = volnmf_simplex_row(X, R, C_prev=C_update, meq=1)
        err_post_C = np.sum((X - np.dot(C, R_update))**2)

        # Accelerate C if possible
        if accelerate:
            C_update = C + acc_C * (C - C_prev)
            R_update = R + acc_R * (R - R_prev)

            # Ensure non-negativity
            C_update = np.maximum(C_update, 0)
            R_update = np.maximum(R_update, 0)

            err_update = np.sum((X - np.dot(C, R_update))**2)
            vol_update = np.log(np.linalg.det(np.dot(R_update, R_update.T) + delta * np.eye(R_update.shape[0])))
            tot_update = err_update + wvol * vol_update

            if tot_update > tot_update_prev:
                C_update = C.copy()
                R_update = R.copy()
        else:
            C_update = C.copy()
            R_update = R.copy()

        tot_update_prev = tot_update

        ### optimize Q
        if domain == "covariance":
            Q_init = volnmf_procrustes(np.dot(C_update, R_update), B)

        err = np.sum((C_update - C_prev)**2) / np.sum(C_update**2)
        eigens = np.linalg.eigvals(np.dot(R_update, R_update.T))
        aff = 1

        if Ctrue is not None:
            if C_bound is None:
                aff = np.apply_along_axis(lambda x: np.max(np.abs(np.corrcoef(x, Ctrue))), axis=1, arr=C_update)
            else:
                aff = np.apply_along_axis(lambda x: np.max(np.abs(np.corrcoef(x * C_bound, Ctrue))), axis=1,
                                          arr=C_update)

        aff_mean.append(np.mean(aff))

        iter += 1

    return {'C': C_update, 'R': R_update, 'Q': Q_init, 'iter': iter, 'err': err, 'info_record': info_record}

def volnmf_main(vol, B, volnmf=None, n_comp=3, n_reduce=None,
                do_nmf=True, iter_nmf=100, seed=None,
                domain="covariance", volf='logdet',
                wvol=None, delta=1e-8, n_iter=500, err_cut=1e-16,
                vol_iter=20, c_iter=20,
                extrapolate=True, accelerate=False, acc_C=4/5, acc_R=3/4,
                C_constraint="col", C_bound=1, R_constraint="pos", R_majorate=False,
                C_init=None, R_init=None, Q_init=None, anchor=None, Ctrue=None,
                verbose=True, record=100, verbose_nmf=False, record_nmf=None, mutation_run=False):

    # Initialize Q_init if None
    if Q_init is None:
        Q_init = np.eye(n_reduce, n_comp)  # Identity matrix with n_reduce rows and n_comp columns

    # Set random seed if provided
    if seed is not None:
        np.random.seed(seed)

    # Use NMF to initialize C_init and R_init if they are not provided
    if C_init is None or R_init is None:
        nmf_model = NMF(n_components=n_comp, init='random', random_state=seed)
        C_init = nmf_model.fit_transform(B)
        R_init = nmf_model.components_

    C_rand, R_rand, Q_rand = C_init.copy(), R_init.copy(), Q_init.copy()

    if wvol is None:
        wvol = 0

    # Print message indicating the start of volume-regularized NMF
    print('Run volume-regularized NMF...')

    # Run volume-regularized NMF
    vol_solution = volnmf_estimate(B, C_init, R_init, Q_init,
                                   domain=domain, volf=volf, R_majorate=R_majorate,
                                   wvol=wvol, delta=delta, n_iter=n_iter, err_cut=err_cut,
                                   vol_iter=vol_iter, c_iter=c_iter,
                                   extrapolate=extrapolate, accelerate=accelerate, acc_C=acc_C, acc_R=acc_R,
                                   C_constraint=C_constraint, C_bound=C_bound, R_constraint=R_constraint,
                                   verbose=verbose, record=record, Ctrue=Ctrue, mutation_run=mutation_run)
    
    # Print done message
    print('Done')

    # Return the results
    return {
        'C': vol_solution['C'], 'R': vol_solution['R'], 'Q': vol_solution['Q'],
        'C_init': C_init, 'R_init': R_init, 'Q_init': Q_init,
        'C_rand': C_rand, 'R_rand': R_rand, 'Q_rand': Q_rand,
        'rec': vol_solution['info_record']
    }

def plot_histograms(source_data):
    Z_source, epsilon_source, W_source, A_source, Y_source = source_data

    # Convert tensors to numpy arrays
    Z_source_np = Z_source.numpy()
    epsilon_source_np = epsilon_source.numpy()
    W_source_np = W_source.numpy()
    A_source_np = A_source.numpy()
    Y_source_np = Y_source.numpy()

    # Flatten the arrays for histogram plotting
    Z_source_flat = Z_source_np.flatten()
    epsilon_source_flat = epsilon_source_np.flatten()
    W_source_flat = W_source_np.flatten()
    A_source_flat = A_source_np.flatten()
    Y_source_flat = Y_source_np.flatten()

    # Plot histograms
    plt.figure(figsize=(15, 10))

    plt.subplot(2, 3, 1)
    plt.hist(Z_source_flat, bins=20, color='blue', alpha=0.7)
    plt.title('Histogram of Z_source')

    plt.subplot(2, 3, 2)
    plt.hist(epsilon_source_flat, bins=20, color='green', alpha=0.7)
    plt.title('Histogram of epsilon_source')

    plt.subplot(2, 3, 3)
    plt.hist(W_source_flat, bins=20, color='red', alpha=0.7)
    plt.title('Histogram of W_source')

    plt.subplot(2, 3, 4)
    plt.hist(A_source_flat, bins=20, color='purple', alpha=0.7)
    plt.title('Histogram of A_source')

    plt.subplot(2, 3, 5)
    plt.hist(Y_source_flat, bins=20, color='orange', alpha=0.7)
    plt.title('Histogram of Y_source')

    plt.tight_layout()
    plt.show()

In [50]:
import torch

def get_tuple(theta_xz,theta_yx,theta_yw,p=0.2):
    vec1 = torch.tensor([0.1,0.1,0.4,0.4])
    vec2 = torch.tensor([0.4,0.4,0.1,0.1])

    U = torch.bernoulli(torch.tensor([p])).item()

    cat0 = torch.distributions.categorical.Categorical(vec1) 
    cat1 = torch.distributions.categorical.Categorical(vec2) 

    z = torch.nn.functional.one_hot((cat0.sample()*U + cat1.sample()*(1-U)).long(), num_classes = 4).unsqueeze(0).float()
    w = torch.nn.functional.one_hot((cat0.sample()*U + cat1.sample()*(1-U)).long(), num_classes = 4).unsqueeze(0).float()

    probx = (vec1 * U + vec2 * (1-U)).unsqueeze(0)
    proby = (vec1 * U + vec2 * (1-U)).unsqueeze(0)

    probx = probx * (theta_xz @ z.T).T
    #print(probx)
    x = torch.nn.functional.one_hot(torch.distributions.categorical.Categorical(probx).sample().long(), num_classes = 4).float()
    #print(x)
    # x = torch.nn.functional.one_hot(torch.argmax(probx, dim=1).long(), num_classes = 4).float()

    proby = proby * (theta_yx @ x.T + theta_yw @ w.T).T
    y_prob = proby.sum(dim=1)
    y = (y_prob > 0.5).float().unsqueeze(0)

    return U,z,w,x,y

def get_dataset(theta_a_z,theta_y_a,theta_y_w, p, total):
    U,Z,W,X,Y = [],[],[],[],[]
    for _ in range(total):
        u,z,w,x,y = get_tuple(theta_a_z,theta_y_a,theta_y_w,p)
        U.append(torch.tensor([[u]]))
        Z.append(z)
        W.append(w)
        X.append(x)
        Y.append(y)
    # print(u.shape,z.shape,w.shape,x.shape,y.shape)
    U = torch.cat(U,0)
    Z = torch.cat(Z,0)
    W = torch.cat(W,0)
    X = torch.cat(X,0)
    Y = torch.cat(Y,0)

    return U,Z,W,X,Y

if __name__=="__main__":
    get_tuple(torch.rand(4,4),torch.rand(4,4),torch.rand(4,4),0.2)
    #print(get_dataset(torch.rand(4,4),torch.rand(4,4),torch.rand(4,4),0.2,10))


def get_tuple_new(theta_a_z,theta_y_a,theta_y_w,theta_y_epsilon, theta_a_epsilon, p=0.2):
    vec1 = torch.tensor([0.1,0.1,0.4,0.4])
    vec2 = torch.tensor([0.4,0.4,0.1,0.1])

    epsilon = torch.bernoulli(torch.tensor([p])).float()

    cat0 = torch.distributions.categorical.Categorical(vec1) 
    cat1 = torch.distributions.categorical.Categorical(vec2) 

    z = torch.nn.functional.one_hot((cat0.sample()*epsilon + cat1.sample()*(1-epsilon)).long(), num_classes = 4).unsqueeze(0).float() #discrete so categorical depending on epsilon (paper has Normal as cont)
    w = torch.nn.functional.one_hot((cat0.sample()*epsilon + cat1.sample()*(1-epsilon)).long(), num_classes = 4).unsqueeze(0).float() #discrete so categorical depending on epsilon (paper has Normal as cont)

    proba = (vec1 * epsilon + vec2 * (1-epsilon)).unsqueeze(0)  
    proby = (vec1 * epsilon + vec2 * (1-epsilon)).unsqueeze(0)
    
    print("proba shape",proba.shape)
    print("theta_a_z shape",theta_a_z.shape)
    print("z.T shape",z.T.shape)
    print("theta_a_epsilon shape",theta_a_epsilon.shape)
    print("epsilon.T shape",epsilon.T.shape)


    proba = proba * (theta_a_z @ z.T + theta_a_epsilon * epsilon.T).T # a is a function of z and epsilon from the graph
    a = torch.bernoulli(torch.sigmoid(proba).sample().long(), num_classes = 4).float()

    proby = proby * (theta_y_a @ a.T + theta_y_w @ w.T + theta_y_epsilon * epsilon.T).T # y is a function of a, w and epsilon from the graph
    #y = torch.bernoulli(torch.nn.Sigmoid(proby).sample().long(), num_classes = 4).float()
    y = torch.bernoulli(torch.sigmoid(proby).squeeze()).float()

    return epsilon,z,w,a,y

def get_dataset_new(theta_a_z,theta_y_a,theta_y_w,theta_y_epsilon, theta_a_epsilon, p, total):
    U,Z,W,X,Y = [],[],[],[],[]
    for _ in range(total):
        u,z,w,x,y = get_tuple_new(theta_a_z,theta_y_a,theta_y_w,theta_y_epsilon, theta_a_epsilon, p)
        U.append(torch.tensor([[u]]))
        Z.append(z)
        W.append(w)
        X.append(x)
        Y.append(y)
    # print(u.shape,z.shape,w.shape,x.shape,y.shape)
    U = torch.cat(U,0)
    Z = torch.cat(Z,0)
    W = torch.cat(W,0)
    X = torch.cat(X,0)
    Y = torch.cat(Y,0)

    return U,Z,W,X,Y

if __name__=="__main__":
    get_tuple(torch.rand(4,4),torch.rand(4,4),torch.rand(4,4),0.2)
    #print(get_dataset(torch.rand(4,4),torch.rand(4,4),torch.rand(4,4),0.2,10))

In [51]:
parameters_debug = True
step1_debug = True
step2_debug = True
step3_debug = True
step4_debug = False
step5_debug = False
step6_debug = False
testing = False
parameter_tuning = False
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fc15d203d50>

In [52]:
p_source = 0.8
p_target = 0.2
total = 10000
factorisation_atol = 1e-1

# Step 2 parameters
specific_a_index = 0  # First value of A

# "sklearn" for sklearn's NMF
# "volmin_1" for volmin NMF, adapted from https://github.com/kharchenkolab/vrnmf
# "volmin_2" for volmin NMF, adapted from https://github.com/bm424/mvcnmf/blob/master/mvcnmf.py
nmf_method = "volmin_2" 
# parameters for volmin NMF
w_vol = 5
delta = 1e-8
n_iter = 100000
err_cut = 1e-8

In [174]:
theta_w_epsilon = torch.tensor([
    [-1],
    [1]
])

theta_y_a = torch.tensor([
    [-1, 1, 1, 1],
    [1, -1, 1, 1],
    [1, 1, -1, 1],
    [1, 1, 1, -1]
])

theta_y_w = torch.tensor([
    [-1, 1, 1, 1],
    [1, -1, 1, 1],
    [1, 1, -1, 1],
    [1, 1, 1, -1]
])

theta_y_epsilon = torch.tensor([
    [-1, 1, 1, 1]
])

theta_a_epsilon = torch.tensor([
    [-1, 1, 1, 1]
])

theta_a_z = torch.tensor([
    [-1, 1, 1, 1],
    [1, -1, 1, 1],
    [1, 1, -1, 1],
    [1, 1, 1, -1]
])

In [178]:
def get_tuple_new(theta_a_z,theta_y_a,theta_y_w,theta_y_epsilon, theta_a_epsilon, p=0.2):
    vec1 = torch.tensor([0.1,0.1,0.4,0.4]) # (4,)
    vec2 = torch.tensor([0.4,0.4,0.1,0.1]) # (4,)

    epsilon = torch.bernoulli(torch.tensor([p])).long() # (1,)
    epsilon_one_hot = torch.nn.functional.one_hot(torch.tensor(epsilon), num_classes=2).squeeze() # (1,2)

    cat0 = torch.distributions.categorical.Categorical(vec1) 
    cat1 = torch.distributions.categorical.Categorical(vec2) 

    z = torch.nn.functional.one_hot((cat0.sample()*epsilon + cat1.sample()*(1-epsilon)).long(), num_classes = 4).squeeze() #(1,4) #discrete so categorical depending on epsilon (paper has Normal as cont)
    w = torch.nn.functional.one_hot((cat0.sample()*epsilon + cat1.sample()*(1-epsilon)).long(), num_classes = 4).squeeze() #(1,4) #discrete so categorical depending on epsilon (paper has Normal as cont)

    ###### FIX FROM HERE ######
    # z is (1,4), theta_a_z is (4,4), theta_a_epsilon is (4,1)
    a_logits = (z.float() @ theta_a_z.float()  + theta_a_epsilon * epsilon.T) # a is a function of z and epsilon from the graph 
    a = torch.bernoulli(torch.sigmoid(a_logits)).float().squeeze() # (1,4)

    # a is (1,4), theta_y_a is (4,4), w is (1,4), theta_y_w is (4,4), theta_y_epsilon is (1,4), epsilon is (1,)
    y_logits = (a.float() @ theta_y_a.float()   + w.float() @ theta_y_w.float() + theta_y_epsilon.float() * epsilon) # y is a function of a, w and epsilon from the graph
    y = torch.bernoulli(torch.sigmoid(y_logits).squeeze()).float()
    print(y)

    return epsilon,z,w,a,y

In [179]:
source_data, target_data = get_data_new(theta_a_z, theta_y_a, theta_y_w, theta_y_epsilon, theta_a_epsilon, p_source, p_target, total)
Z_source, epsilon_source, W_source, A_source, Y_source = source_data
Z_target, epsilon_target, W_target, A_target, Y_target = target_data

  epsilon_one_hot = torch.nn.functional.one_hot(torch.tensor(epsilon), num_classes=2).squeeze() # (1,2)


tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([0., 1., 1., 1.])
tensor([0., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 0.])
tensor([1., 0., 1., 1.])
tensor([1., 0., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([0., 1., 1., 1.])
tensor([1., 0., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([0., 1., 1., 1.])
tensor([1., 1., 0., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 0.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([0., 1., 1., 1.])
tensor([1., 0., 0., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([0., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])
tensor([0., 1., 0., 1.])
tensor([1., 1., 1., 1.])
tensor([1., 1., 1., 1.])


KeyboardInterrupt: 

In [177]:
print("Y_source.shape", Y_source.shape)
print("W_source.shape", W_source.shape)
print("A_source.shape", A_source.shape)
print("Z_source.shape", Z_source.shape)

print("Y_source", Y_source)
print("W_source", W_source)
print("A_source", A_source)
print("Z_source", Z_source)

Y_source.shape (40000,)
W_source.shape (40000,)
A_source.shape (40000,)
Z_source.shape (40000,)
Y_source [1. 1. 0. ... 0. 0. 1.]
W_source [0 0 1 ... 0 0 1]
A_source [0. 1. 1. ... 0. 1. 0.]
Z_source [0 0 0 ... 1 0 0]


In [None]:
def plot_histograms(source_data):
    Z_source, epsilon_source, W_source, A_source, Y_source = source_data

    # Convert tensors to numpy arrays
    Z_source_np = Z_source
    epsilon_source_np = epsilon_source
    W_source_np = W_source
    A_source_np = A_source
    Y_source_np = Y_source

    # Flatten the arrays for histogram plotting
    Z_source_flat = Z_source_np.flatten()
    epsilon_source_flat = epsilon_source_np.flatten()
    W_source_flat = W_source_np.flatten()
    A_source_flat = A_source_np.flatten()
    Y_source_flat = Y_source_np.flatten()

    # Plot histograms
    plt.figure(figsize=(10, 5))

    plt.subplot(2, 3, 1)
    plt.hist(Z_source_flat, bins=20, color='blue', alpha=0.7)
    plt.title('Histogram of Z_source')

    plt.subplot(2, 3, 2)
    plt.hist(epsilon_source_flat, bins=20, color='green', alpha=0.7)
    plt.title('Histogram of epsilon_source')

    plt.subplot(2, 3, 3)
    plt.hist(W_source_flat, bins=20, color='red', alpha=0.7)
    plt.title('Histogram of W_source')

    plt.subplot(2, 3, 4)
    plt.hist(A_source_flat, bins=20, color='purple', alpha=0.7)
    plt.title('Histogram of A_source')

    plt.subplot(2, 3, 5)
    plt.hist(Y_source_flat, bins=20, color='orange', alpha=0.7)
    plt.title('Histogram of Y_source')

    plt.tight_layout()
    plt.show()

In [None]:
plot_histograms(source_data)

In [None]:
plot_histograms(target_data)

In [None]:
print(Y_source.shape)
print(W_source.shape)
print(A_source.shape)
print(Z_source.shape)

print(A_source)

In [None]:
def plot_distributions(Z_source, W_source, A_source, Y_source):
    combinations = list(itertools.product(range(4), range(4)))
    
    plt.figure(figsize=(20, 20))
    
    # Convert one-hot encoded tensors to indices directly
    Z_indices = np.argmax(Z_source, axis=1).flatten()
    A_indices = np.argmax(A_source, axis=1).flatten()
    Y_indices = Y_source.flatten()  # Y_source is already binary, just flatten it
    W_indices = np.argmax(W_source, axis=1).flatten()

    print(Y_indices.shape)
    
    # Print the first few indices for debugging
    print("Z_indices (first 10):", Z_indices[:10])
    print("A_indices (first 10):", A_indices[:10])
    print("Y_indices (first 10):", Y_indices[:10])
    print("W_indices (first 10):", W_indices[:10])
    
    for i, (z_val, a_val) in enumerate(combinations):
        indices = (Z_indices == z_val) & (A_indices == a_val)
        
        Y_subsample = Y_indices[indices]
        W_subsample = W_indices[indices]
        
        plt.subplot(8, 4, 2 * i + 1)
        plt.hist(Y_subsample, bins=np.arange(3) - 0.5, color='orange', alpha=0.7)
        plt.title(f'Y | Z={z_val}, A={a_val}')
        plt.xticks([0, 1])
        
        plt.subplot(8, 4, 2 * i + 2)
        plt.hist(W_subsample, bins=np.arange(5) - 0.5, color='blue', alpha=0.7)
        plt.title(f'W | Z={z_val}, A={a_val}')
        plt.xticks([0, 1, 2, 3])
    
    plt.tight_layout()
    plt.show()

plot_distributions(Z_source, W_source, A_source, Y_source)

In [None]:
print("Z_source shape:", Z_source.shape)
print("A_source shape:", A_source.shape)
print("W_source shape:", W_source.shape)
print("epsilon_source shape:", epsilon_source.shape)
print("Y_source shape:", Y_source.shape)
print("Z_source:", Z_source)
print("A_source:", A_source)
print("W_source:", W_source)
print("epsilon_source:", epsilon_source)
print("Y_source:", Y_source)


In [None]:
num_classes_Y = 2 #AMEND THIS???????????????????????????????
num_classes_W = W_source.shape[1]
num_features_Z = Z_source.shape[1]
num_features_A = A_source.shape[1]

sum_epsilon = np.sum(epsilon_source)
print("Sum of epsilon_source:", sum_epsilon.item())

sum_Y = np.sum(Y_source)
print("Sum of Y_source:", sum_Y.item())

In [None]:
# =============================================================================
# Step 1: Estimate p(Y|Z,a) and p(W|Z,a)
# =============================================================================

# Train model to estimate p(Y|Z,a)
# By stacking with A, we condition on A by including all values of A in the input
ZA_source = np.hstack((Z_source, A_source))  # We go from 4 features to 4 + 4 = 8 features
if step1_debug:
    print("ZA_source.shape", ZA_source.shape)  # Debug print statement
    print("ZA_source", ZA_source)  # Debug print statement

############### LOGISTIC REGRESSION VERSION ###############

#model_Y = LogisticRegression(input_dim=ZA_source.shape[1], num_classes=2)
#model_Y.train(torch.tensor(ZA_source, dtype=torch.float32), torch.tensor(Y_source, dtype=torch.float32))

# model_Y = LogisticRegressionGD(input_dim=ZA_source.shape[1], num_classes=Y_source.shape[1])
# model_Y.train_model(torch.tensor(ZA_source, dtype=torch.float32), torch.tensor(Y_source, dtype=torch.float32), learning_rate=0.01, epochs=100, verbose=True)
# p_Y_given_ZA = get_probabilities(model_Y, Z_source, A_source)

In [None]:
class SimpleNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 16)
        self.fc5 = nn.Linear(16, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        x = torch.relu(self.fc3(x))
        x = self.dropout(x)
        x = torch.relu(self.fc4(x))
        x = self.dropout(x)
        x = self.fc5(x)
        return x
    
def train_nn(model, X, Y, learning_rate=0.001, epochs=50, batch_size=32):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)
    dataset = torch.utils.data.TensorDataset(X, Y)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    for epoch in range(epochs):
        for batch_X, batch_Y in dataloader:
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_Y)
            loss.backward()
            optimizer.step()
        if (epoch+1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

def eval_nn(model, X, Y):
    with torch.no_grad():
        outputs = model(X)
        _, predicted = torch.max(outputs, 1)
        correct = (predicted == Y).sum().item()
        accuracy = correct / Y.size(0)
    return accuracy

In [None]:
# Define and train the neural network
input_dim = ZA_source.shape[1]
model_Y = SimpleNN(input_dim, num_classes_Y)
train_nn(model_Y, torch.tensor(ZA_source, dtype=torch.float32), torch.tensor(Y_source.flatten(), dtype=torch.long), learning_rate=0.001, epochs=500, batch_size=16)

# Evaluate the model
accuracy_nn = eval_nn(model_Y, torch.tensor(ZA_source, dtype=torch.float32), torch.tensor(Y_source.flatten(), dtype=torch.long))
print(f'Accuracy of neural network model on training set: {accuracy_nn:.4f}')

In [None]:
print(ZA_source.shape)
print(Y_source.shape)

In [None]:
Y_source.sum()

In [None]:
############### SKLEARN VERSION ###############

print("Y_source flatten", Y_source.flatten())

##### LOGISTIC REGRESSION #####
#model_Y = SklearnLogisticRegression(max_iter=1000)
#model_Y.fit(ZA_source, Y_source.flatten())

##### GRAIDENT BOOSTING #####
#model_Y = GradientBoostingClassifier(n_estimators=1000, learning_rate=0.01, max_depth=5, random_state=0)
#model_Y.fit(ZA_source, Y_source.flatten())

##### RANDOM FOREST #####
model_Y = RandomForestClassifier(n_estimators=100, max_depth=3, random_state=0)
model_Y.fit(ZA_source, Y_source.flatten())

##### NEURAL NETWORK #####
#model_Y = MLPClassifier(hidden_layer_sizes=(8, 4), max_iter=1000)
#model_Y.fit(ZA_source, Y_source.flatten())


print(ZA_source)
print(Y_source.flatten())
Y_train_pred = model_Y.predict(ZA_source)
print("Y_train_pred",Y_train_pred)

Y_train_true = Y_source
accuracy_Y_train = np.mean(Y_train_pred == Y_source)
print(f"Accuracy of model_Y on training set: {accuracy_Y_train:.4f}")


In [None]:
cm = confusion_matrix(Y_source.flatten(), model_Y.predict(ZA_source))
print("Classification Report:")
print(classification_report(Y_source.flatten(), model_Y.predict(ZA_source)))

In [None]:
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Predicted 0', 'Predicted 1'], yticklabels=['Actual 0', 'Actual 1'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix for Y')
plt.show()

In [None]:
def get_probabilities_one_hot_nn(model, Z, A):
    """
    Computes the softmax probabilities from the trained model.
    
    Args:
        model (model): Trained model.
        Z (numpy.ndarray): Feature matrix Z.
        A (numpy.ndarray): Feature matrix A.

    Returns:
        numpy.ndarray: Probability matrix reshaped to (|Z|, |A|, |Y|) or (|Z|, |A|, |W|).
    """
    num_Z = Z.shape[1]
    num_A = A.shape[1]
    num_classes = model.fc5.out_features #number of layers here

    # Generate all possible one-hot vectors for Z
    possible_Z = np.eye(num_Z)
    possible_A = np.eye(num_A)
    
    probabilities = []
    
    for z in possible_Z:
        for a in possible_A:
            ZA = np.hstack((z.reshape(1, -1), a.reshape(1, -1)))
            ZA_tensor = torch.tensor(ZA, dtype=torch.float32)
            with torch.no_grad():
                probs = model(ZA_tensor)
                probs = torch.softmax(probs, dim=1).numpy()
                probabilities.append(probs[0])
    
    probabilities = np.array(probabilities).reshape((num_Z, num_A, num_classes))
    
    return probabilities

In [None]:

p_Y_given_ZA = get_probabilities_one_hot_nn(model_Y, Z_source, A_source)
print("p_Y_given_ZA", p_Y_given_ZA)
print("p_Y_given_ZA shape:", p_Y_given_ZA.shape)  # Debug print statement


In [None]:

# Verify the shape of p_Y_given_ZA
assert p_Y_given_ZA.shape == (num_features_Z, num_features_A, num_classes_Y), f"p_Y_given_ZA shape mismatch: {p_Y_given_ZA.shape}"
assert np.allclose(p_Y_given_ZA.sum(axis=2), 1.0), "p_Y_given_ZA rows do not sum to 1"
print("Step 1: p_Y_given_ZA shape and sum are correct.")



In [None]:
W_source_indices = torch.argmax(torch.tensor(W_source), dim=1)
print(W_source)
print(W_source_indices)

In [None]:
# print(W_source)
# print(np.argmax(W_source, axis=1))

In [None]:
############### LOGISTIC REGRESSION VERSION ###############
#Train model to estimate p(W|Z,a)
# model_W = LogisticRegressionGD(input_dim=ZA_source.shape[1], num_classes=W_source.shape[1])
# model_W.train_model(torch.tensor(ZA_source, dtype=torch.float32), torch.tensor(Y_source, dtype=torch.float32), learning_rate=0.01, epochs=100, verbose=True)

# p_W_given_ZA = get_probabilities_one_hot(model_W, Z_source, A_source)



############### SKLEARN VERSION ###############

model_W = SklearnLogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=200)
model_W.fit(ZA_source, np.argmax(W_source, axis=1))
W_train_pred = model_W.predict(ZA_source)
W_train_true = np.argmax(W_source, axis=1)
accuracy_W_train = np.mean(W_train_pred == W_train_true)
print(f"Accuracy of model_W on training set: {accuracy_W_train:.4f}")
#p_W_given_ZA = model_W.predict_proba(ZA_source).reshape(num_features_Z, num_features_A, num_classes_W)



############### NEURAL NETWORK VERSION ###############

# Define and train the neural network
# input_dim = ZA_source.shape[1]
# model_W = SimpleNN(input_dim, num_classes_W)
# train_nn(model_W, torch.tensor(ZA_source, dtype=torch.float32), W_source_indices, learning_rate=0.001, epochs=500, batch_size=16)

# # Evaluate the model
# accuracy_nn = eval_nn(model_W, torch.tensor(ZA_source, dtype=torch.float32), W_source_indices)
# print(f'Accuracy of neural network model on training set: {accuracy_nn:.4f}')


In [None]:
cm_w = confusion_matrix(W_source_indices.numpy(), model_W.predict(ZA_source))
print("Classification Report for W:")
print(classification_report(W_source_indices.numpy(), model_W.predict(ZA_source)))

In [None]:
plt.figure(figsize=(8, 6))
sns.heatmap(cm_w, annot=True, fmt='d', cmap='Blues', xticklabels=['Class 0', 'Class 1', 'Class 2', 'Class 3'], yticklabels=['Class 0', 'Class 1', 'Class 2', 'Class 3'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix for W')
plt.show()

In [None]:
p_W_given_ZA = get_probabilities(model_W, Z_source, A_source)

if step1_debug:
    print("p_W_given_ZA shape:", p_W_given_ZA.shape)  # Debug print statement
    print("p_W_given_ZA", p_W_given_ZA)  # Debug print statement

# Verify the shape of p_W_given_ZA
assert p_W_given_ZA.shape == (num_features_Z, num_features_A, num_classes_W), f"p_W_given_ZA shape mismatch: {p_W_given_ZA.shape}"
assert np.allclose(p_W_given_ZA.sum(axis=2), 1.0), "p_W_given_ZA rows do not sum to 1"
if step1_debug:
    print("Step 1: p_W_given_ZA shape and sum are correct.")

print("STEP 1 DONE")