In [None]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
from timebudget import timebudget

from cplAE_MET.utils.load_config import load_config
from cplAE_MET.utils.dataset import load_MET_dataset 
from cplAE_MET.utils.dataclass import met_dataclass

from cplAE_MET.models.subnetworks_M import AE_M
from cplAE_MET.models.subnetworks_E import AE_E
from cplAE_MET.models.subnetworks_ME import AE_ME_int
from cplAE_MET.models.subnetworks_T import AE_T
from cplAE_MET.models.torch_utils import min_var_loss

# Resolve data locations from the project config, then load the MET dataset
# and wrap it in the project's dataclass (accessed below as dat.XE, dat.XT, ...).
CONFIG_FILE = 'config.toml'
dir_pth = load_config(CONFIG_FILE)
D = load_MET_dataset(dir_pth['MET_data'])
dat = met_dataclass(D)

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# E data -----------
xe = torch.as_tensor(dat.XE).float().to(device)

# get indices of valid data points before nan -> 0.0
valid_xe = ~torch.isnan(xe)
xe = torch.nan_to_num(xe)

# Per-feature std of E, passed to AE_E as `gnoise_std` (presumably the scale of
# its Gaussian input noise, together with gnoise_std_frac). Computed over valid
# (non-NaN) entries only: after nan_to_num the missing entries are 0.0, and
# including them would distort each feature's spread. Uses the unbiased (n-1)
# estimator, matching torch.var's default; clamps guard features with <2 valid values.
n_valid_e = valid_xe.sum(dim=0, keepdim=True)
mean_e = xe.sum(dim=0, keepdim=True) / n_valid_e.clamp(min=1)
var_e = (torch.square(xe - mean_e) * valid_xe).sum(dim=0, keepdim=True) / (n_valid_e - 1).clamp(min=1)
gnoise_std = var_e.sqrt()

# M data -----------
# insert a singleton axis at position 1 (presumably a channel dim for the M encoder)
xm = np.expand_dims(dat.XM_centered,axis=1)
xm = torch.as_tensor(xm).float().to(device)
xsd = torch.as_tensor(dat.Xsd).float().to(device)

# get indices of valid data points before nan -> 0.0
valid_xm = ~torch.isnan(xm)
valid_xsd = ~torch.isnan(xsd)

# xm and xsd nan -> 0.0
xm = torch.nan_to_num(xm)
xsd = torch.nan_to_num(xsd)

# T data -----------
xt = torch.as_tensor(dat.XT).float().to(device)
valid_xt = ~torch.isnan(xt)
xt = torch.nan_to_num(xt)


In [None]:
from itertools import chain

# Model init -----------
# Seed torch's RNG (CPU and CUDA) so weight initialisation, and presumably the
# dropout / noise draws during training, repeat across kernel restarts.
# Some CUDA kernels can still be non-deterministic.
SEED = 0
torch.manual_seed(SEED)

model_config = dict(latent_dim=2,
                    T=dict(dropout_p=0.2),
                    E=dict(gnoise_std_frac=0.05,
                           dropout_p=0.2))
ae_t = AE_T(config=model_config)
ae_e = AE_E(config=model_config, gnoise_std=gnoise_std)
ae_m = AE_M(config=model_config)
ae_me = AE_ME_int(config=model_config)

# Move every sub-network to the target device; they are optimised jointly below.
for net in (ae_t, ae_m, ae_e, ae_me):
    net.to(device)

optimizer = torch.optim.Adam(chain(ae_e.parameters(), ae_m.parameters(), ae_me.parameters(), ae_t.parameters()), lr=0.001)

In [None]:
# Cell-level availability masks (recompute per batch if batching is added):
# a cell counts as having a modality if at least one of its entries is non-NaN.
valid_m_cells = torch.any(valid_xm.view(valid_xm.shape[0], -1), dim=1)
valid_e_cells = torch.any(valid_xe.view(valid_xe.shape[0], -1), dim=1)
valid_t_cells = torch.any(valid_xt.view(valid_xt.shape[0], -1), dim=1)
paired_me = torch.logical_and(valid_m_cells, valid_e_cells)
paired_met = torch.logical_and(paired_me, valid_t_cells)

# Positions, within the paired-ME subset, of cells that also have T data.
# torch.nonzero on a 1-D mask returns shape (k, 1). Squeezing only dim 1 keeps a
# 1-D index even when k == 1; a full squeeze would give a 0-d tensor, which
# drops a dimension when used for indexing later.
paired_me_ind = torch.nonzero(paired_met[paired_me]).squeeze(1)

for mask_name, mask in (('valid_m_cells', valid_m_cells),
                        ('valid_e_cells', valid_e_cells),
                        ('valid_t_cells', valid_t_cells),
                        ('paired_me', paired_me),
                        ('paired_met', paired_met)):
    print(f'{mask_name} cells {int(mask.sum())}')

In [None]:
# Train loop -----------
N_STEPS = 20000
PLOT_EVERY = 500


def _plot_latents(zm_np, ze_np, zme_np, zt_np, colors, paired_mask):
    """Scatter the first two latent dims of the M, E, ME and T arms side by side.

    Args:
        zm_np, ze_np, zt_np: latents as numpy arrays, rows aligned with ``colors``.
        zme_np: latents of the paired-ME cells only.
        colors: per-row colours; indexed with ``paired_mask`` for the ME panel.
        paired_mask: boolean numpy mask selecting the paired-ME rows.
    """
    panels = ((zm_np, colors, 'M'),
              (ze_np, colors, 'E'),
              (zme_np, colors[paired_mask], 'ME'),
              (zt_np, colors, 'T'))
    fig, axes = plt.subplots(1, len(panels), figsize=(8, 2))
    for ax, (z, c, title) in zip(axes, panels):
        ax.scatter(z[:, 0], z[:, 1], c=c, s=1)
        ax.set(title=title)
    plt.tight_layout()
    plt.show()
    plt.close(fig)


# Host-side copy of the paired-ME mask, needed only for plotting. It is hoisted
# out of the loop instead of being re-transferred at every plotting step.
paired_me_np = paired_me.detach().cpu().numpy()

for step in range(N_STEPS):
    optimizer.zero_grad()

    # t arm
    zt = ae_t.enc_xt_to_zt(xt)
    xrt = ae_t.dec_zt_to_xt(zt)

    # e arm: the intermediate encoding is detached before the latent layer, and
    # the full reconstruction runs under no_grad, so loss_rec_e is monitored
    # but contributes no gradient.
    ze_int_enc = ae_e.enc_xe_to_ze_int(xe)
    ze = ae_e.enc_ze_int_to_ze(ze_int_enc.detach())
    ze_int_dec = ae_e.dec_ze_to_ze_int(ze)
    with torch.no_grad():
        xre = ae_e.dec_ze_int_to_xe(ze_int_dec)

    # m arm (same structure as the e arm)
    zm_int_enc = ae_m.enc_xm_to_zm_int(xm, xsd)
    zm = ae_m.enc_zm_int_to_zm(zm_int_enc.detach())
    zm_int_dec = ae_m.dec_zm_to_zm_int(zm)
    with torch.no_grad():
        # pool_0_ind / pool_1_ind are read from the encoder after its forward
        # pass (presumably unpooling indices) -- TODO confirm in subnetworks_M
        xrm, xrsd = ae_m.dec_zm_int_to_xm(zm_int_dec,
                                          ae_m.enc_xm_to_zm_int.pool_0_ind,
                                          ae_m.enc_xm_to_zm_int.pool_1_ind)

    # me arm: joint latent from the (non-detached) M and E intermediate encodings
    zme_paired = ae_me.enc_zme_int_to_zme(zm_int_enc[paired_me,...], ze_int_enc[paired_me,...])
    zm_int_dec_paired, ze_int_dec_paired = ae_me.dec_zme_to_zme_int(zme_paired)

    # connect back to m and e
    xre_me_paired = ae_e.dec_ze_int_to_xe(ze_int_dec_paired)
    xrm_me_paired, xrsd_me_paired = ae_m.dec_zm_int_to_xm(zm_int_dec_paired,
                                        ae_m.enc_xm_to_zm_int.pool_0_ind[paired_me,...],
                                        ae_m.enc_xm_to_zm_int.pool_1_ind[paired_me,...])

    # Reconstruction losses, averaged over valid (non-NaN) entries only.
    # NOTE(review): loss_rec_t is computed but never added to `loss`, so the T
    # decoder receives no gradient -- confirm whether this is intentional.
    loss_rec_t = torch.mean(torch.masked_select(torch.square(xt-xrt), valid_xt))
    loss_rec_e = torch.mean(torch.masked_select(torch.square(xe-xre), valid_xe))
    loss_rec_m = torch.mean(torch.masked_select(torch.square(xm-xrm), valid_xm))
    loss_rec_sd = torch.mean(torch.masked_select(torch.square(xsd-xrsd), valid_xsd))

    loss_rec_m_me = torch.mean(torch.masked_select(torch.square(xm[paired_me,...]-xrm_me_paired), valid_xm[paired_me,...]))
    loss_rec_e_me = torch.mean(torch.masked_select(torch.square(xe[paired_me,...]-xre_me_paired), valid_xe[paired_me,...]))
    loss_rec_sd_me = torch.mean(torch.masked_select(torch.square(xsd[paired_me,...]-xrsd_me_paired), valid_xsd[paired_me,...]))

    # Coupling on paired-ME cells: pull each single-modality latent towards the
    # (detached) joint ME latent. Names now match the modality each term uses.
    loss_cpl_me_e = torch.mean(torch.square(zme_paired.detach() - ze[paired_me,...]))
    loss_cpl_me_m = torch.mean(torch.square(zme_paired.detach() - zm[paired_me,...]))

    # Coupling to T, only on cells that have M, E and T data.
    zt_paired_met = zt[paired_me,...]
    zt_paired_met = zt_paired_met[paired_me_ind,...]
    zme_paired_met = zme_paired[paired_me_ind,...]
    loss_cpl_me_t = min_var_loss(zme_paired_met, zt_paired_met)

    # The sd reconstruction terms are currently disabled (weight 0).
    # Summation order is unchanged: the E coupling term precedes the M term.
    loss = loss_rec_e + loss_rec_m + 0*loss_rec_sd \
        + loss_rec_e_me + loss_rec_m_me + 0*loss_rec_sd_me \
        + loss_cpl_me_e + loss_cpl_me_m + loss_cpl_me_t

    loss.backward()
    optimizer.step()
    if (step + 1) % PLOT_EVERY == 0:
        print(f'step: {step} loss: {loss.item()}')
        _plot_latents(zm.detach().cpu().numpy(),
                      ze.detach().cpu().numpy(),
                      zme_paired.detach().cpu().numpy(),
                      zt.detach().cpu().numpy(),
                      dat.cluster_color, paired_me_np)

# Downstream cells use the final-step latents as numpy arrays. Export them
# explicitly instead of relying on the last iteration hitting the plotting
# branch, which only happens when N_STEPS is a multiple of PLOT_EVERY.
zt = zt.detach().cpu().numpy()
ze = ze.detach().cpu().numpy()
zm = zm.detach().cpu().numpy()
zme_paired = zme_paired.detach().cpu().numpy()

In [None]:
# One row per cell: final-step E and M latents plus metadata for the plotly
# figures. `paired` holds 'True'/'False' strings so plotly treats it as categorical.
df = pd.DataFrame({
    'ze0': ze[:, 0],
    'ze1': ze[:, 1],
    'zm0': zm[:, 0],
    'zm1': zm[:, 1],
    'paired': paired_me.detach().cpu().numpy().astype(str),
    'color': dat.cluster_color,
    'id': dat.specimen_id,
    'cluster': dat.cluster_label,
})


In [None]:

# M latents, then E latents, coloured by whether the cell has both M and E data.
for x_col, y_col in (("zm0", "zm1"), ("ze0", "ze1")):
    fig = px.scatter(
        df,
        x=x_col,
        y=y_col,
        color="paired",
        hover_data=["id", "cluster"],
        range_x=(-5, 5),
        range_y=(-5, 5),
        width=400,
        height=300,
    )
    fig.show()

In [None]:

# M latents, then E latents, coloured by cluster using the dataset's own palette.
cluster_to_color = dict(zip(df.cluster, df.color))
for x_col, y_col in (("zm0", "zm1"), ("ze0", "ze1")):
    fig = px.scatter(
        df,
        x=x_col,
        y=y_col,
        color="cluster",
        hover_data=["id", "cluster"],
        color_discrete_map=cluster_to_color,
        range_x=(-5, 5),
        range_y=(-5, 5),
        width=800,
        height=600,
    )
    fig.show()

In [None]:
# Joint ME latents (paired-ME cells only), coloured by cluster.
# The frame is stored as `df_me` so the per-cell `df` built earlier is not
# overwritten; otherwise the E/M plotting cells above break on re-run.
paired_me_mask = paired_me.detach().cpu().numpy()
df_me = pd.DataFrame(dict(zme0=zme_paired[:, 0], zme1=zme_paired[:, 1],
                          color=dat.cluster_color[paired_me_mask],
                          id=dat.specimen_id[paired_me_mask],
                          cluster=dat.cluster_label[paired_me_mask]))

fig = px.scatter(df_me,
                 x="zme0", y="zme1", color="cluster",
                 hover_data=["id", "cluster"],
                 color_discrete_map=dict(zip(df_me.cluster, df_me.color)),
                 range_x=(-5,5),
                 range_y=(-5,5),
                 width=800,height=600)
fig.show()