In [None]:
from collections import defaultdict
from itertools import product
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib
import seaborn as sns
import sklearn.metrics as skm
import sklearn.utils as skut
import sklearn.preprocessing as skpp
import time
import warnings

from tabpfn import TabPFNClassifier
from sklearn.metrics import accuracy_score, f1_score
import condo

In [None]:
# ---------------------------------------------------------------------------
# Data loading and experiment configuration.
# Source (s) = scATAC features, target (t) = scRNA features.
# ---------------------------------------------------------------------------
datapath = "../data"


def _as_str_column(values):
    """Cast `values` to strings and reshape them into an (n, 1) column."""
    return values.astype(str).reshape(-1, 1)


# Feature matrices, each passed through its own sklearn Normalizer
# (default settings).
Xs = skpp.Normalizer().fit_transform(np.load(f"{datapath}/scatac_feat.npy"))
Xt = skpp.Normalizer().fit_transform(np.load(f"{datapath}/scrna_feat.npy"))

# Cell-type labels: numeric codes in the text files, stored as string columns.
Zs = _as_str_column(np.loadtxt(f"{datapath}/SNAREseq_atac_types.txt").astype(int))
Zt = _as_str_column(np.loadtxt(f"{datapath}/SNAREseq_rna_types.txt").astype(int))

# Paired-sample info: the row index, as a string, identifies each sample, so
# row i of the source and row i of the target carry the same label.
Ys = _as_str_column(np.arange(Xs.shape[0]))
Yt = _as_str_column(np.arange(Xt.shape[0]))

transform_type = 'affine'

# (display name, zero-argument adapter factory) for every method compared.
# The two ConDo entries share a factory; they differ only in which labels the
# later cells pass to `fit`.
_adapter_kwargs = dict(transform_type=transform_type, verbose=0)
minfos = [
    ('MMD', partial(condo.AdapterMMD, **_adapter_kwargs)),
    ('ConDo MMD - cell type', partial(condo.ConDoAdapterMMD, **_adapter_kwargs)),
    ('ConDo MMD - cell identity', partial(condo.ConDoAdapterMMD, **_adapter_kwargs)),
]

num_random = 10                      # number of random repetitions
n_pairs_list = [5, 10, 20, 50, 100]  # numbers of pairs used to fit the adapters
clf_train = 500                      # number of pairs used to train the classifier

In [None]:
# Fit every adapter on the full source/target sets, map the source into the
# target space, and report how strongly the pooled (adapted source + target)
# data clusters by pair label, by cell type, and by assay of origin.

# Label vectors for the pooled rows. They do not depend on the method, so they
# are built once instead of once per metric per method.
pooled_pair = np.concatenate([Ys.flatten(), Yt.flatten()], axis=0)
pooled_celltype = np.concatenate([Zs.flatten(), Zt.flatten()], axis=0)
pooled_assay = np.concatenate([['src'] * Zs.shape[0], ['tgt'] * Zt.shape[0]], axis=0)

# Extra (source, target) label arguments handed to `fit` for each method.
fit_confounders = {
    'MMD': (),
    'ConDo MMD - cell type': (Zs, Zt),
    'ConDo MMD - cell identity': (Ys, Yt),
}

for mname, mfunc in minfos:
    if mname == 'No Adaptation':
        adaptXs = Xs.copy()
    elif mname in fit_confounders:
        adapter = mfunc()
        adapter.fit(Xs, Xt, *fit_confounders[mname])
        adaptXs = adapter.transform(Xs)
    else:
        # Fail loudly rather than transforming with an adapter that was never fit.
        raise ValueError(f"Unknown adaptation method: {mname!r}")

    pooled_X = np.concatenate([adaptXs, Xt], axis=0)

    ch_pair = skm.calinski_harabasz_score(pooled_X, labels=pooled_pair)
    ch_celltype = skm.calinski_harabasz_score(pooled_X, labels=pooled_celltype)
    ch_assay = skm.calinski_harabasz_score(pooled_X, labels=pooled_assay)
    sil_pair = skm.silhouette_score(pooled_X, labels=pooled_pair)
    sil_celltype = skm.silhouette_score(pooled_X, labels=pooled_celltype)
    sil_assay = skm.silhouette_score(pooled_X, labels=pooled_assay)

    print(
        f"{mname}\t "
        f"CH-pair:{ch_pair:.4f} Sil-pair:{sil_pair:.4f} "
        f"CH-cell:{ch_celltype:.4f} Sil-cell:{sil_celltype:.4f} "
        f"CH-assay:{ch_assay:.4f} Sil-assay:{sil_assay:.4f} "
    )

In [None]:
# Repeat the clustering evaluation `num_random` times. In each repetition the
# adapters are fit on a random subset of source and target rows, then applied
# to ALL source rows; the metrics are computed on the full pooled data.

# Per-method lists of scores, one entry per repetition.
chs_pair = {mname: [] for mname, _ in minfos}
chs_celltype = {mname: [] for mname, _ in minfos}
chs_assay = {mname: [] for mname, _ in minfos}
sils_pair = {mname: [] for mname, _ in minfos}
sils_celltype = {mname: [] for mname, _ in minfos}
sils_assay = {mname: [] for mname, _ in minfos}

# Label vectors for the pooled rows; independent of repetition and method, so
# built once rather than six times per inner iteration.
pooled_pair = np.concatenate([Ys.flatten(), Yt.flatten()], axis=0)
pooled_celltype = np.concatenate([Zs.flatten(), Zt.flatten()], axis=0)
pooled_assay = np.concatenate([['src'] * Zs.shape[0], ['tgt'] * Zt.shape[0]], axis=0)

# Rows drawn per domain for fitting. Capped at the number of available rows so
# that sampling without replacement cannot fail on a smaller dataset.
n_fit_rows = 1000

for rix in range(num_random):
    rng = skut.check_random_state(rix)
    ixs = rng.choice(Xs.shape[0], size=min(n_fit_rows, Xs.shape[0]), replace=False)
    ixt = rng.choice(Xt.shape[0], size=min(n_fit_rows, Xt.shape[0]), replace=False)

    # Extra (source, target) label arguments handed to `fit` for each method.
    # NOTE(review): ixs and ixt are drawn independently, so for the
    # 'cell identity' variant the sampled source and target rows generally do
    # not share pair labels -- confirm this is intended.
    fit_confounders = {
        'MMD': (),
        'ConDo MMD - cell type': (Zs[ixs, :], Zt[ixt, :]),
        'ConDo MMD - cell identity': (Ys[ixs, :], Yt[ixt, :]),
    }

    for mname, mfunc in minfos:
        if mname not in fit_confounders:
            # Fail loudly rather than transforming with an adapter that was never fit.
            raise ValueError(f"Unknown adaptation method: {mname!r}")
        adapter = mfunc()
        adapter.fit(Xs[ixs, :], Xt[ixt, :], *fit_confounders[mname])
        adaptXs = adapter.transform(Xs)

        pooled_X = np.concatenate([adaptXs, Xt], axis=0)

        chs_pair[mname].append(skm.calinski_harabasz_score(pooled_X, labels=pooled_pair))
        chs_celltype[mname].append(skm.calinski_harabasz_score(pooled_X, labels=pooled_celltype))
        chs_assay[mname].append(skm.calinski_harabasz_score(pooled_X, labels=pooled_assay))
        sils_pair[mname].append(skm.silhouette_score(pooled_X, labels=pooled_pair))
        sils_celltype[mname].append(skm.silhouette_score(pooled_X, labels=pooled_celltype))
        sils_assay[mname].append(skm.silhouette_score(pooled_X, labels=pooled_assay))

In [None]:
# Print mean and np.std (in parentheses) of each silhouette score across the
# random repetitions, one line per method and label kind.
# The Calinski-Harabasz lists (chs_*) are collected too but are not printed here.
silhouette_tables = (
    ('Sil-pair', sils_pair),
    ('Sil-cell', sils_celltype),
    ('Sil-assay', sils_assay),
)
for mname, _ in minfos:
    for tag, table in silhouette_tables:
        runs = table[mname]
        print(f"{mname} {tag}:{np.mean(runs)} ({np.std(runs)})")

In [None]:
# Few-pair transfer evaluation with a TabPFN classifier.
#
# For each repetition:
#   * `clf_train` pair labels are drawn; their source rows train the classifier.
#   * For each n_pairs, a disjoint set of n_pairs pair labels is drawn and its
#     source/target rows are used to fit the adapters in both directions.
#   * Two settings are scored against the target cell types:
#       - accs / f1s   : classifier trained on adapted source, predicts raw target
#       - saccs / sf1s : classifier trained on raw source, predicts adapted target
#     The *_unseen variants restrict scoring to target rows whose pair label was
#     used neither for classifier training nor for adapter fitting.
pair_labels = np.unique(np.concatenate([Ys, Yt]))
n_total_pairs = pair_labels.size


def _new_results():
    """Return an empty {n_pairs: {method name: []}} results table."""
    return {n_pairs: {mname: [] for mname, _ in minfos} for n_pairs in n_pairs_list}


def _rows_with_labels(Y, labels):
    """Return ascending row indices of the (n, 1) label column `Y` whose label is in `labels`."""
    return np.flatnonzero(np.isin(Y.ravel(), labels))


accs = _new_results()
f1s = _new_results()
saccs = _new_results()
sf1s = _new_results()

accs_unseen = _new_results()
f1s_unseen = _new_results()
saccs_unseen = _new_results()
sf1s_unseen = _new_results()

Zt_true = Zt.ravel()  # ground-truth target cell types

clf = TabPFNClassifier()
for rix in range(num_random):
    rng = skut.check_random_state(rix)
    clf_train_pairlabels = rng.choice(pair_labels, size=clf_train, replace=False)
    clf_train_sixs = _rows_with_labels(Ys, clf_train_pairlabels)
    remaining_pairlabels = np.setdiff1d(pair_labels, clf_train_pairlabels)
    for n_pairs in n_pairs_list:
        cur_pairlabels = rng.choice(remaining_pairlabels, size=n_pairs, replace=False)
        cur_sixs = _rows_with_labels(Ys, cur_pairlabels)
        cur_tixs = _rows_with_labels(Yt, cur_pairlabels)

        unseen_pairlabels = np.setdiff1d(remaining_pairlabels, cur_pairlabels)
        # A plain index array, like the other *_ixs variables (not the raw
        # tuple returned by nonzero()).
        unseen_tixs = _rows_with_labels(Yt, unseen_pairlabels)
        Zt_true_unseen = Zt_true[unseen_tixs]

        # Extra (source, target) label arguments handed to `fit` for each method.
        fit_confounders = {
            'MMD': (),
            'ConDo MMD - cell type': (Zs[cur_sixs, :], Zt[cur_tixs, :]),
            'ConDo MMD - cell identity': (Ys[cur_sixs, :], Yt[cur_tixs, :]),
        }

        for mname, mfunc in minfos:
            if mname not in fit_confounders:
                # Fail loudly rather than transforming with adapters that were never fit.
                raise ValueError(f"Unknown adaptation method: {mname!r}")
            confounders = fit_confounders[mname]

            s2tadapter = mfunc()
            t2sadapter = mfunc()
            s2tadapter.fit(Xs[cur_sixs, :], Xt[cur_tixs, :], *confounders)
            # Reverse direction: swap the data and the label arguments alike.
            t2sadapter.fit(Xt[cur_tixs, :], Xs[cur_sixs, :], *confounders[::-1])
            adaptXs = s2tadapter.transform(Xs)
            adaptXt = t2sadapter.transform(Xt)

            # Setting 1: train on adapted source, predict the raw target.
            clf.fit(adaptXs[clf_train_sixs, :], Zs[clf_train_sixs, :].ravel())
            pred_tgt = clf.predict(Xt)
            # Metrics use the (y_true, y_pred) argument order.
            accs[n_pairs][mname].append(accuracy_score(Zt_true, pred_tgt))
            f1s[n_pairs][mname].append(f1_score(Zt_true, pred_tgt, average='macro'))
            accs_unseen[n_pairs][mname].append(
                accuracy_score(Zt_true_unseen, pred_tgt[unseen_tixs]))
            f1s_unseen[n_pairs][mname].append(
                f1_score(Zt_true_unseen, pred_tgt[unseen_tixs], average='macro'))

            # Setting 2: train on raw source, predict the adapted target.
            clf.fit(Xs[clf_train_sixs, :], Zs[clf_train_sixs, :].ravel())
            pred_tgt = clf.predict(adaptXt)
            saccs[n_pairs][mname].append(accuracy_score(Zt_true, pred_tgt))
            sf1s[n_pairs][mname].append(f1_score(Zt_true, pred_tgt, average='macro'))
            saccs_unseen[n_pairs][mname].append(
                accuracy_score(Zt_true_unseen, pred_tgt[unseen_tixs]))
            sf1s_unseen[n_pairs][mname].append(
                f1_score(Zt_true_unseen, pred_tgt[unseen_tixs], average='macro'))

In [None]:
# Accuracy and macro-F1 on all target rows for the classifier trained on the
# adapted source data, as a function of the number of pairs used to fit the
# adapters. One figure (and one PDF) per metric.

# One colour per entry in minfos.
color_dict = {'MMD': 'red', 'ConDo MMD - cell type': 'magenta', 'ConDo MMD - cell identity': 'cyan'}

for metricname, metric in [('Accuracy', accs), ('F1-score', f1s)]:
    fig, ax = plt.subplots(figsize=(3, 2))
    for mname, _ in minfos:
        cur_vals = [metric[n_pairs][mname] for n_pairs in n_pairs_list]
        ax.errorbar(
            n_pairs_list,
            [np.mean(cur) for cur in cur_vals],
            # Standard error over the runs actually recorded for this point.
            [np.std(cur) / np.sqrt(len(cur)) for cur in cur_vals],
            label=mname, color=color_dict[mname], alpha=0.5,
        )
    ax.set_xscale('log')
    ax.set_xticks(n_pairs_list, n_pairs_list)
    ax.minorticks_off()
    ax.set_ylabel(metricname)
    ax.set_xlabel('Number of pairs')
    ax.legend(loc='lower right')
    fig.tight_layout()
    fig.savefig(f'SNAREseq-{metricname}-tabpfnOnAdapted.pdf')

In [None]:
# Accuracy and macro-F1 on all target rows for the classifier trained on the
# raw source data and applied to the adapted target, as a function of the
# number of pairs used to fit the adapters. One figure (and one PDF) per metric.

# One colour per entry in minfos.
color_dict = {'MMD': 'red', 'ConDo MMD - cell type': 'magenta', 'ConDo MMD - cell identity': 'cyan'}

for metricname, metric in [('Accuracy', saccs), ('F1-score', sf1s)]:
    fig, ax = plt.subplots(figsize=(3, 2))
    for mname, _ in minfos:
        cur_vals = [metric[n_pairs][mname] for n_pairs in n_pairs_list]
        ax.errorbar(
            n_pairs_list,
            [np.mean(cur) for cur in cur_vals],
            # Standard error over the runs actually recorded for this point.
            [np.std(cur) / np.sqrt(len(cur)) for cur in cur_vals],
            label=mname, color=color_dict[mname], alpha=0.5,
        )
    ax.set_xscale('log')
    ax.set_xticks(n_pairs_list, n_pairs_list)
    ax.minorticks_off()
    ax.set_ylabel(metricname)
    ax.set_xlabel('Number of pairs')
    ax.legend(loc='lower right')
    fig.tight_layout()
    fig.savefig(f'SNAREseq-{metricname}-tabpfn.pdf')

In [None]:
# Two-panel figure (accuracy | macro-F1) restricted to the "unseen" target
# rows, for the classifier trained on adapted source data; all methods shown.
fig, axes = plt.subplots(figsize=(8, 2.5), ncols=2)
for ax, (metricname, metric) in zip(axes, [('Accuracy', accs_unseen), ('F1-score', f1s_unseen)]):
    for mname, _ in minfos:
        cur_vals = [metric[n_pairs][mname] for n_pairs in n_pairs_list]
        ax.errorbar(
            n_pairs_list,
            [np.mean(cur) for cur in cur_vals],
            # Standard error over the runs actually recorded for this point.
            [np.std(cur) / np.sqrt(len(cur)) for cur in cur_vals],
            label=mname, alpha=0.8, capsize=2,
        )
    ax.set_xscale('log')
    ax.set_xticks(n_pairs_list, n_pairs_list)
    ax.minorticks_off()
    # Limit the number of y ticks on EVERY panel; the pyplot-level
    # locator_params call only affects the current (last created) axes.
    ax.locator_params(axis='y', nbins=4)
    ax.set_ylabel(metricname)
    ax.set_xlabel('Number of pairs')
# Single legend, anchored to the right-hand panel.
axes[-1].legend(bbox_to_anchor=(1., 0.7))
fig.tight_layout()
fig.savefig(f'SNAREseq-tabpfn-unseen-OnAdapted-wcellidentity.pdf')

In [None]:
# Two-panel figure (accuracy | macro-F1) restricted to the "unseen" target
# rows, for the classifier trained on raw source data and applied to the
# adapted target; all methods shown.
fig, axes = plt.subplots(figsize=(8, 2.5), ncols=2)
for ax, (metricname, metric) in zip(axes, [('Accuracy', saccs_unseen), ('F1-score', sf1s_unseen)]):
    for mname, _ in minfos:
        cur_vals = [metric[n_pairs][mname] for n_pairs in n_pairs_list]
        ax.errorbar(
            n_pairs_list,
            [np.mean(cur) for cur in cur_vals],
            # Standard error over the runs actually recorded for this point.
            [np.std(cur) / np.sqrt(len(cur)) for cur in cur_vals],
            label=mname, alpha=0.8, capsize=2,
        )
    ax.set_xscale('log')
    ax.set_xticks(n_pairs_list, n_pairs_list)
    ax.minorticks_off()
    # Limit the number of y ticks on EVERY panel; the pyplot-level
    # locator_params call only affects the current (last created) axes.
    ax.locator_params(axis='y', nbins=4)
    ax.set_ylabel(metricname)
    ax.set_xlabel('Number of pairs')
# Single legend, anchored to the right-hand panel.
axes[-1].legend(bbox_to_anchor=(1., 0.7))
fig.tight_layout()
fig.savefig(f'SNAREseq-tabpfn-unseen-wcellidentity.pdf')

In [None]:
# Two-panel figure (accuracy | macro-F1) restricted to the "unseen" target
# rows, for the classifier trained on adapted source data. Only the first two
# entries of minfos are drawn, and legend labels are cut to their first nine
# characters (so 'ConDo MMD - cell type' is shown as 'ConDo MMD').
fig, axes = plt.subplots(figsize=(8, 2.5), ncols=2)
for ax, (metricname, metric) in zip(axes, [('Accuracy', accs_unseen), ('F1-score', f1s_unseen)]):
    for mname, _ in minfos[0:2]:
        cur_vals = [metric[n_pairs][mname] for n_pairs in n_pairs_list]
        ax.errorbar(
            n_pairs_list,
            [np.mean(cur) for cur in cur_vals],
            # Standard error over the runs actually recorded for this point.
            [np.std(cur) / np.sqrt(len(cur)) for cur in cur_vals],
            label=mname[0:9], alpha=0.8, capsize=2,
        )
    ax.set_xscale('log')
    ax.set_xticks(n_pairs_list, n_pairs_list)
    ax.minorticks_off()
    # Limit the number of y ticks on EVERY panel; the pyplot-level
    # locator_params call only affects the current (last created) axes.
    ax.locator_params(axis='y', nbins=4)
    ax.set_ylabel(metricname)
    ax.set_xlabel('Number of pairs')
# Single legend, anchored to the right-hand panel.
axes[-1].legend(bbox_to_anchor=(1., 0.7))
fig.tight_layout()
fig.savefig(f'SNAREseq-tabpfn-unseen-OnAdapted.pdf')

In [None]:
# Two-panel figure (accuracy | macro-F1) restricted to the "unseen" target
# rows, for the classifier trained on raw source data and applied to the
# adapted target. Only the first two entries of minfos are drawn, and legend
# labels are cut to their first nine characters (so 'ConDo MMD - cell type'
# is shown as 'ConDo MMD').
fig, axes = plt.subplots(figsize=(8, 2.5), ncols=2)
for ax, (metricname, metric) in zip(axes, [('Accuracy', saccs_unseen), ('F1-score', sf1s_unseen)]):
    for mname, _ in minfos[0:2]:
        cur_vals = [metric[n_pairs][mname] for n_pairs in n_pairs_list]
        ax.errorbar(
            n_pairs_list,
            [np.mean(cur) for cur in cur_vals],
            # Standard error over the runs actually recorded for this point.
            [np.std(cur) / np.sqrt(len(cur)) for cur in cur_vals],
            label=mname[0:9], alpha=0.8, capsize=2,
        )
    ax.set_xscale('log')
    ax.set_xticks(n_pairs_list, n_pairs_list)
    ax.minorticks_off()
    # Limit the number of y ticks on EVERY panel; the pyplot-level
    # locator_params call only affects the current (last created) axes.
    ax.locator_params(axis='y', nbins=4)
    ax.set_ylabel(metricname)
    ax.set_xlabel('Number of pairs')
# Single legend, anchored to the right-hand panel.
axes[-1].legend(bbox_to_anchor=(1., 0.7))
fig.tight_layout()
fig.savefig(f'SNAREseq-tabpfn-unseen.pdf')

In [None]:
# Number of source samples (rows of Xs).
np.shape(Xs)[0]