In [23]:
import torch.nn.functional as F
import torch
from main_merck_all_real import get_dataset
from utils import set_seed, get_optimizer, InfIterator
from arguments import get_arguments
from main_origin import get_model
from setenc import get_mixer

import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
    

In [24]:
def calc_loss(y_hat, y, test=False):
    return F.mse_loss(y.cuda().squeeze(), y_hat.cuda().squeeze())

def test(args, dataloader, contextloader=None, model=None, mixer_phi=None, embed_type=None, n_t=10, n_c=5):
    model.eval()
    mixer_phi.eval()
    embedding_list = []
    label_list= []
    loss_list = []
    # print('model ', model)
    # print('mixer_phi ', mixer_phi)
    if embed_type == "train_none":
        with torch.no_grad():
            losses = []
            counts = 0
            for i, (x, y) in enumerate(dataloader):
                if i == n_t:
                    break
                
                if i == 0:
                    torch.save(x, 't_x.pt')
                    torch.save(y, 't_y.pt')
                #     return
                

                y_hat, embedding_list, label_list = model(x=x.to(args.device), context=None, mixer_phi=mixer_phi, embedding_list=embedding_list, label_list=label_list, embed_type=embed_type, embed_test=args.embed_test)

                y = y.cuda().squeeze()
                y_hat = y_hat.cuda().squeeze()

                # y_hat = y_hat[:, 0]
                # print(f"in test: {y.size()=} {y_hat.size()=}")

                loss = calc_loss(y_hat, y, test=True)
                loss_scalar = loss.detach().item()
                loss_list.append(torch.full((x.shape[0],), loss_scalar))

                # print('loss_scalar ', loss_scalar)
                # print('x ', x.size(0))
                # print('y_hat ', y_hat)
                # print('embedding_list ', embedding_list)
                losses.append(loss_scalar * x.size(0))
                counts += x.size(0)
                # if i == 0:
                #     torch.save(x, 'tn_x.pt')
                #     torch.save(y, 'tn_y.pt')
        # self.model.eval()
    elif embed_type == "ood1_none" or embed_type == "ood2_none":
        with torch.no_grad():
            losses = []
            counts = 0
            for i, (x, y) in enumerate(dataloader):
                if i == n_t:
                    break
                
                # if i == 0:
                #     torch.save(x, 't_x.pt')
                #     torch.save(y, 't_y.pt')
                #     return
                

                y_hat, embedding_list, label_list = model(x=x.to(args.device), context=None, mixer_phi=mixer_phi, embedding_list=embedding_list, label_list=label_list, embed_type=embed_type, embed_test=args.embed_test)

                y = y.cuda().squeeze()
                y_hat = y_hat.cuda().squeeze()

                # y_hat = y_hat[:, 0]
                # print(f"in test: {y.size()=} {y_hat.size()=}")

                loss = calc_loss(y_hat, y, test=True)
                loss_scalar = loss.detach().item()
                loss_list.append(torch.full((x.shape[0],), loss_scalar))

                # print('loss_scalar ', loss_scalar)
                # print('x ', x.size(0))
                # print('y_hat ', y_hat)
                # print('embedding_list ', embedding_list)
                losses.append(loss_scalar * x.size(0))
                counts += x.size(0)
                # if i == 0:
                #     torch.save(x, 'tn_x.pt')
                #     torch.save(y, 'tn_y.pt')
        # self.model.eval()
    elif embed_type == "train_context":
        with torch.no_grad():
            losses = []
            counts = 0
            for i, (x, y) in enumerate(dataloader):
                if i == n_t:
                    break
                
                context_samples = []
                
                if i == 0:
                    torch.save(x, 'tc_x.pt')
                    torch.save(y, 'tc_y.pt')

                for i_c, (x_c, y_c) in enumerate(contextloader):
                    if i_c == n_c:
                        break
                    x_c = x_c.reshape(args.batch_size, -1, x_c.size(-1))
                    if args.n_context > 1:
                        n = torch.randint(1, x_c.size(1), size=(1,)).item()
                        x_c = x_c[:, :n]
                        
                    if i == 0 and i_c == 0:
                        torch.save(x_c, 'tc_x_c.pt')
                        torch.save(y_c, 'tc_y_c.pt')
                        
                        
                    # print('$$$$$$')
                    # print('x_c ', x_c.shape)
                    # print('x ', x.shape)
                    
                    # return
                    # context_samples.append(x_c)
                
                # context_samples = torch.cat(context_samples, dim=0).to(args.device)
                # context_samples = context_samples.reshape(args.batch_size, -1, x_c.size(-1))
                
                # print('### context_samples ', context_samples)
                # torch.save(context_samples, f'context_samples_{i}.pth')
                    y_hat, embedding_list, label_list = model(x=x.to(args.device), context=x_c.to(args.device), mixer_phi=mixer_phi, embedding_list=embedding_list, label_list=label_list, embed_type=embed_type, embed_test=args.embed_test)

                    y = y.cuda().squeeze()
                    y_hat = y_hat.cuda().squeeze()

                    # y_hat = y_hat[:, 0]
                    # print(f"in test: {y.size()=} {y_hat.size()=}")

                    loss = calc_loss(y_hat, y, test=True)
                    loss_list.append(torch.full((x.shape[0],), loss.detach().item()))

                    losses.append(loss.item() * x.size(0))
                    counts += x.size(0)
    elif embed_type == "context_none":
        with torch.no_grad():
            losses = []
            counts = 0
            for i, (x, y) in enumerate(dataloader):
                if i == n_t:
                    break
                
                context_samples = []
                
                if i == 0:
                    torch.save(x, 'c_x.pt')
                    torch.save(y, 'c_y.pt')

                for i_c, (x_c, y_c) in enumerate(contextloader):
                    if i_c == n_c:
                        break
                    x_c = x_c.reshape(args.batch_size, -1, x_c.size(-1))
                    if args.n_context > 1:
                        n = torch.randint(1, x_c.size(1), size=(1,)).item()
                        x_c = x_c[:, :n]
                    
                    if i == 0 and i_c == 0:
                        torch.save(x_c, 'c_x_c.pt')
                        torch.save(y_c, 'c_y_c.pt')
                    # print('====')
                    # print('x_c ', x_c.shape)
                    # print('x ', x.shape)
                    
                    # return
                    
                    # context_samples.append(x_c)
                    
                    # B, S, H = x_c.size() <- did this in model
                    # x_c = x_c.view(B*S, H)
                
                # context_samples = torch.cat(context_samples, dim=0).to(args.device)
                # context_samples = context_samples.reshape(args.batch_size, -1, x_c.size(-1))
                
                # print('### context_samples ', context_samples)
                # torch.save(context_samples, f'context_samples_{i}.pth')
                    y_hat, embedding_list, label_list = model(x=x_c.to(args.device), context=x_c.to(args.device), mixer_phi=mixer_phi, embedding_list=embedding_list, label_list=label_list, embed_type=embed_type, embed_test=args.embed_test)

                    y = y.cuda().squeeze()
                    y_hat = y_hat.cuda().squeeze()

                    # y_hat = y_hat[:, 0]
                    # print(f"in test: {y.size()=} {y_hat.size()=}")

                    loss = calc_loss(y_hat, y, test=True)
                    loss_list.append(torch.full((x.shape[0],), loss.detach().item()))

                    losses.append(loss.item() * x.size(0))
                    counts += x.size(0)
    mse = sum(losses) / counts
    return mse, embedding_list, label_list, loss_list

In [25]:

def save_one_datapoint_features(args, model, mixer_phi, embed_type, n_t, n_c):
    assert args.model_no_context == False
    
    args.embed_test = 'lastlayer_ours_best'
    
    path = f"/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNct{args.model_no_context}_RYV1/{args.embed_test}_"
    os.makedirs(path, exist_ok=True)
    
    
    if args.seed == 42:
        f_path = f'/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNct{args.model_no_context}_RYV1/{args.embed_test}_/{args.sencoder}_{args.dataset}_{args.vec_type}_{n_t}_{n_c}_{embed_type}.npz'
    else:
        f_path = f'/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNct{args.model_no_context}_RYV1/{args.embed_test}_/{args.sencoder}_{args.dataset}_{args.vec_type}_{n_t}_{n_c}_{embed_type}_{args.seed}.npz'
    
    if os.path.exists(f_path):
        print(f"⏩ {f_path} already exists. Skipping...") 
        return
    
    if args.seed == 42:
        set_seed(0)
    else:
        set_seed(args.seed)
    args.batch_size = 1
    args.tsne_plot = True # because of get_dataset
    all_candidates = ['hivprot', 'dpp4', 'nk1']
    args.specify_ood_dataset = [d for d in all_candidates if d != args.dataset]
    trainloader_test, _, mvalidloader_test, _, contextloader_test, ood1_trainloader_test, ood2_trainloader_test = get_dataset(args=args, test=True)
    
    if "ood" not in embed_type:
        mse, embedding_list, label_list, loss_list = test(args=args, dataloader=trainloader_test, contextloader=contextloader_test, model=model, mixer_phi=mixer_phi, embed_type=embed_type, n_t=n_t, n_c=n_c)
    elif 'ood1' in embed_type:
        mse, embedding_list, label_list, loss_list = test(args=args, dataloader=ood1_trainloader_test, contextloader=contextloader_test, model=model, mixer_phi=mixer_phi, embed_type=embed_type, n_t=n_t, n_c=n_c)
    elif 'ood2' in embed_type:
        mse, embedding_list, label_list, loss_list = test(args=args, dataloader=ood2_trainloader_test, contextloader=contextloader_test, model=model, mixer_phi=mixer_phi, embed_type=embed_type, n_t=n_t, n_c=n_c)
    else:
        raise Exception()
    
    all_embeddings = torch.cat(embedding_list, dim=0)
    all_labels = np.concatenate(label_list, axis=0)
    all_losses = torch.cat(loss_list, dim=0)

    all_labels = torch.tensor(all_labels)
    all_losses = torch.tensor(all_losses)
                
    all_embeddings_np = all_embeddings.numpy()
    all_labels_np = all_labels.numpy()
    all_losses_np = all_losses.numpy()
    
    
    np.savez(f_path, embeddings=all_embeddings_np, labels=all_labels_np, losses=all_losses_np)
    print(f'>>> saved {f_path}')
    
    if 'ood1' in embed_type:
        ood1_trainloader_test._iterator._shutdown_workers()
    elif 'ood2' in embed_type:
        ood2_trainloader_test._iterator._shutdown_workers()
    else:
        trainloader_test._iterator._shutdown_workers()
        if 'context' in embed_type:
            contextloader_test._iterator._shutdown_workers()


In [26]:
def save_features(data):

    model = data['model']
    mixer_phi = data['mixer_phi']
    optimizer = data['optimizer']
    mixer_optimizer = data['mixer_optimizer']
    
    ltmse, lvmse, vmse, tmse = data['ltmse'], data['lvmse'], data['vmse'], data['tmse ']
    
    print('>> ltmse ', ltmse)
    print('>> tmse ', tmse)
    print('>> lvmse ', lvmse)
    print('>> vmse ', vmse)
    
    args_ = data['args']

    args = get_arguments()

    for k, v in args_.items():
        setattr(args, k, v)

    if args.seed == 42:
        set_seed(0)
    else:
        set_seed(args.seed)

    model = get_model(args=args)
    mixer_phi = get_mixer(args=args)
    # optimizer = get_optimizer(optimizer=args.optimizer, model=model, lr=args.lr, wd=args.wd)
    # optimizermixer = None if mixer_phi is None else get_optimizer(optimizer=args.optimizer, model=mixer_phi, lr=args.clr, wd=args.cwd)

    model.load_state_dict(data['model'])
    mixer_phi.load_state_dict(data['mixer_phi'])
    # optimizer.load_state_dict(data['optimizer'])
    # mixer_optimizer.load_state_dict(data['mixer_optimizer'])

    model = model.to(args.device)
    mixer_phi = mixer_phi.to(args.device)
    
    # for (n_t, n_c) in [(5, 10), (5, 100), (10, 5), (10, 100), (100, 5), (100, 10)]:
    # for (n_t, n_c) in [(5, 10), (5, 50), (5, 100), (10, 5), (10, 40), (10, 100), (40, 5), (40, 10), (100, 5),]:
    for (n_t, n_c) in [(10, 100),]:
        save_one_datapoint_features(args, model, mixer_phi, "train_none", n_t, n_c)
        save_one_datapoint_features(args, model, mixer_phi, "context_none", n_t, n_c)
        save_one_datapoint_features(args, model, mixer_phi, "train_context", n_t, n_c)
    save_one_datapoint_features(args, model, mixer_phi, "ood1_none", 500, 10)
    save_one_datapoint_features(args, model, mixer_phi, "ood2_none", 500, 10)

In [27]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [28]:
import os
import torch

# Directory containing .pth files
tsne_model_dir = '/c2/jinakim/Drug_Discovery_j/tsne_model2_mNctFalse_RYV1_mixTrue/ours_best/'
os.makedirs(tsne_model_dir, exist_ok=True)
# List all .pth files
# pth_files = ["Model_dsets_nk1_bit_['3a4', 'cb1'].pth"]
# pth_files = ["Model_dsets_nk1_bit_['hivint', 'tdi'].pth"]
pth_files = sorted([f for f in os.listdir(tsne_model_dir) if f.endswith('.pth')])

# Load each file
for i, f in enumerate(pth_files):
    # if i > 5:
    #     break
    file_path = os.path.join(tsne_model_dir, f)
    print(f"🚀 Loading {file_path}")
    
    data = torch.load(file_path)
    save_features(data)
    # Now 'data' contains the loaded model or state dict or whatever was saved
    # You can process it here if needed
    # For example, just printing some keys if it's a checkpoint
    # if isinstance(data, dict):
    #     print(f"✅ Loaded {f}: keys = {list(data.keys())}")
    # else:
    #     print(f"✅ Loaded {f}: type = {type(data)}")

print("\n🏁 All models loaded.")




🚀 Loading /c2/jinakim/Drug_Discovery_j/tsne_model2_mNctFalse_RYV1_mixTrue/ours_best/Model_dsets_nk1_bit_['3a4', 'cb1']_0.pth
>> ltmse  0.4183902328160928
>> tmse  0.4232667506366655
>> lvmse  0.2904639701048533
>> vmse  0.28735986749331155
loading deepsets
DSEncoder(
  (encoder): Sequential(
    (0): PermEquiMax(
      (Gamma): Linear(in_features=512, out_features=512, bias=True)
      (Lambda): Linear(in_features=512, out_features=512, bias=False)
    )
  )
)
Inner args.dataset='nk1' args.vec_type='bit'


  all_losses = torch.tensor(all_losses)


>>> saved /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_nk1_bit_10_100_train_none_0.npz
Inner args.dataset='nk1' args.vec_type='bit'
>>> saved /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_nk1_bit_10_100_context_none_0.npz
Inner args.dataset='nk1' args.vec_type='bit'
>>> saved /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_nk1_bit_10_100_train_context_0.npz
Inner args.dataset='nk1' args.vec_type='bit'
>>> saved /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_nk1_bit_500_10_ood1_none_0.npz
Inner args.dataset='nk1' args.vec_type='bit'
>>> saved /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_nk1_bit_500_10_ood2_none_0.npz
🚀 Loading /c2/jinakim/Drug_Discovery_j/tsne_model2_mNctFalse_RYV1_mixTrue/ours_best/Model_dsets_nk1_bit_['3a4', 'cb1']_1.pth
>>

In [84]:
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Try to use GPU TSNE (cuML), fallback to CPU TSNE (openTSNE)
# try:
#     import cupy as cp
#     from cuml.manifold import TSNE as cuTSNE
#     gpu_available = True
#     print("✅ Using GPU cuML TSNE")
# except ImportError:
from openTSNE import TSNE as cpuTSNE
gpu_available = False
print("⚠️ cuML not available, falling back to CPU openTSNE")

# Set your directory
tsne_dir = '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/'

# List all .npz files
files = sorted([f for f in os.listdir(tsne_dir) if f.endswith('.npz')])

# --- Group files by their starting prefix (before 3rd underscore) ---
groups = defaultdict(list)
for f in files:
    parts = f.split('_')
    prefix = '_'.join(parts[:5])  # e.g., dsets_dpp4_count
    groups[prefix].append(f)

# --- Process each group ---
for prefix, group_files in groups.items():
    save_path = os.path.join(tsne_dir, f"{prefix}_tsne_combined.pdf")
    
    # if "dpp4_bit" in save_path:
    #     continue
    # --- Skip if already exists ---
    if os.path.exists(save_path):
        print(f"⏩ {save_path} already exists. Skipping...")
        continue

    print(f"\n🚀 Processing group {prefix}")

    all_embeddings = []
    all_labels = []

    for f in group_files:
        file_path = os.path.join(tsne_dir, f)
        
        #####
        if 'ood' in file_path:
            continue
        
        # if 'train' not in file_path and 'context' not in file_path:
        #     continue
        #####
        
        data = np.load(file_path)
        
        embeddings = data['embeddings']
        labels = data['labels']

        print(f"✅ Loaded {f}: embeddings {embeddings.shape}, labels {labels.shape}")

        all_embeddings.append(embeddings)
        all_labels.append(labels)

    if len(all_embeddings) == 0 or len(all_labels) == 0:
        continue
    
    # Concatenate all
    all_embeddings = np.concatenate(all_embeddings, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    print(f"✅ Combined embeddings shape: {all_embeddings.shape}")

    # --- Run t-SNE ---
    # if gpu_available:
    #     embeddings_gpu = cp.asarray(all_embeddings)
    #     tsne = cuTSNE(n_components=2, random_state=42)
    #     embeddings_2d_gpu = tsne.fit_transform(embeddings_gpu)
    #     embeddings_2d = cp.asnumpy(embeddings_2d_gpu)
    # else:
    tsne = cpuTSNE(n_components=2, n_jobs=8, random_state=42)
    embeddings_2d = tsne.fit(all_embeddings)

    # --- Plot ---
    import matplotlib.patches as mpatches
    import matplotlib.pyplot as plt

    # Set global font
    plt.rcParams['font.family'] = 'DejaVu Sans'
    plt.rcParams['font.size'] = 10

    # --- t-SNE Scatter Plot with Nice Legend and Slight Grid ---

    plt.figure(figsize=(8, 6))

    # Define color mapping
    color_map = {0: '#ffb347', 1: '#0000CD', -1: '#DDA0DD'}  # orange and blue
    colors = [color_map[label] for label in all_labels]

    # Set point sizes
    sizes = [40 if label == 1 else 15 for label in all_labels]

    # Scatter plot
    scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=colors, s=sizes, alpha=0.8)

    # Slight soft grid
    plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.3)

    # Hide axis labels but keep grid
    plt.gca().set_xticklabels([])
    plt.gca().set_yticklabels([])
    plt.xlabel("")
    plt.ylabel("")
    plt.box(False)

    # --- Add better legend ---
    import matplotlib.lines as mlines

    # Define custom legend handles (use Line2D for circles)
    orange_circle = mlines.Line2D([], [], color='#ffb347', marker='o', linestyle='None', markersize=8, label='ours')
    blue_circle = mlines.Line2D([], [], color='#0000CD', marker='o', linestyle='None', markersize=8, label='ours (w/o context)')
    red_circle = mlines.Line2D([], [], color='#DDA0DD', marker='o', linestyle='None', markersize=8, label='ours (context)')

    # Add legend inside plot (upper right)
    plt.legend(handles=[orange_circle, blue_circle, red_circle],
            loc='upper right',  # inside the plot, top right
            framealpha=0.6,
            prop={'size': 12},
            handletextpad=0.4,
            borderpad=0.5)

    # Tight layout
    plt.tight_layout()

    # Save figure
    plt.savefig(save_path, dpi=300)
    print(f"✅ Saved t-SNE scatter plot with nice legend to {save_path}")
    plt.close()


⚠️ cuML not available, falling back to CPU openTSNE
⏩ /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_dpp4_bit_100_5_tsne_combined.pdf already exists. Skipping...
⏩ /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_dpp4_bit_10_100_tsne_combined.pdf already exists. Skipping...
⏩ /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_dpp4_bit_10_40_tsne_combined.pdf already exists. Skipping...
⏩ /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_dpp4_bit_10_5_tsne_combined.pdf already exists. Skipping...
⏩ /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_dpp4_bit_40_10_tsne_combined.pdf already exists. Skipping...
⏩ /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_dpp4_bit_40_5_tsne_combined.pdf already exists. Skippin

In [85]:
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Try to use GPU TSNE (cuML), fallback to CPU TSNE (openTSNE)
# try:
#     import cupy as cp
#     from cuml.manifold import TSNE as cuTSNE
#     gpu_available = True
#     print("✅ Using GPU cuML TSNE")
# except ImportError:
from openTSNE import TSNE as cpuTSNE
gpu_available = False
print("⚠️ cuML not available, falling back to CPU openTSNE")

# Set your directory
tsne_dir = '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/'

# List all .npz files
files = sorted([f for f in os.listdir(tsne_dir) if f.endswith('.npz')])

# --- Group files by their starting prefix (before 3rd underscore) ---
groups = defaultdict(list)
for f in files:
    parts = f.split('_')
    
    prefix = '_'.join(parts[:5])  # e.g., dsets_dpp4_count
    
    if 'ood' not in f:
        groups[prefix].append(f)
    
    prefix_ = '_'.join(parts[:3])
    if prefix_ in prefix and 'ood' in f:
        groups[prefix].append(f)

# --- Process each group ---
for prefix, group_files in groups.items():
    save_path = os.path.join(tsne_dir, f"{prefix}_tsne_combined_all.pdf")
    
    # if "dpp4_bit" in save_path:
    #     continue
    # --- Skip if already exists ---
    if os.path.exists(save_path):
        print(f"⏩ {save_path} already exists. Skipping...")
        continue

    print(f"\n🚀 Processing group {prefix}")

    all_embeddings = []
    all_labels = []

    for f in group_files:
        file_path = os.path.join(tsne_dir, f)
        
        if len(group_files) != 5:
            print('group_files ', len(group_files))
            continue
        #####
        # if 'ood' in file_path:
        #     continue
        
        # if 'train' not in file_path and 'context' not in file_path:
        #     continue
        #####
        
        data = np.load(file_path)
        
        embeddings = data['embeddings']
        labels = data['labels']

        print(f"✅ Loaded {f}: embeddings {embeddings.shape}, labels {labels.shape}")

        all_embeddings.append(embeddings)
        all_labels.append(labels)

    if len(all_embeddings) == 0 or len(all_labels) == 0:
        continue
    
    # Concatenate all
    all_embeddings = np.concatenate(all_embeddings, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    print(f"✅ Combined embeddings shape: {all_embeddings.shape}")

    # --- Run t-SNE ---
    # if gpu_available:
    #     embeddings_gpu = cp.asarray(all_embeddings)
    #     tsne = cuTSNE(n_components=2, random_state=42)
    #     embeddings_2d_gpu = tsne.fit_transform(embeddings_gpu)
    #     embeddings_2d = cp.asnumpy(embeddings_2d_gpu)
    # else:
    tsne = cpuTSNE(n_components=2, n_jobs=8, random_state=42)
    embeddings_2d = tsne.fit(all_embeddings)

    # --- Plot ---
    import matplotlib.patches as mpatches
    import matplotlib.pyplot as plt

    # Set global font
    plt.rcParams['font.family'] = 'DejaVu Sans'
    plt.rcParams['font.size'] = 10

    # --- t-SNE Scatter Plot with Nice Legend and Slight Grid ---

    plt.figure(figsize=(8, 6))

    # Define color mapping
    color_map = {0: '#ffb347', 1: '#0000CD', -1: '#DDA0DD', 3:'#48b33c', 4:'#3cadb3'}  # orange and blue
    colors = [color_map[label] for label in all_labels]

    # Set point sizes
    sizes = [40 if label == 1 else 15 for label in all_labels]

    # Scatter plot
    scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=colors, s=sizes, alpha=0.8)

    # Slight soft grid
    plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.3)

    # Hide axis labels but keep grid
    plt.gca().set_xticklabels([])
    plt.gca().set_yticklabels([])
    plt.xlabel("")
    plt.ylabel("")
    plt.box(False)

    # --- Add better legend ---
    import matplotlib.lines as mlines

    # Define custom legend handles (use Line2D for circles)
    orange_circle = mlines.Line2D([], [], color='#ffb347', marker='o', linestyle='None', markersize=8, label='ours')
    blue_circle = mlines.Line2D([], [], color='#0000CD', marker='o', linestyle='None', markersize=8, label='ours (w/o context)')
    red_circle = mlines.Line2D([], [], color='#DDA0DD', marker='o', linestyle='None', markersize=8, label='ours (context)')
    green_circle = mlines.Line2D([], [], color='#48b33c', marker='o', linestyle='None', markersize=8, label='OOD1')
    bluegreen_circle = mlines.Line2D([], [], color='#3cadb3', marker='o', linestyle='None', markersize=8, label='OOD2')

    # Add legend inside plot (upper right)
    plt.legend(handles=[orange_circle, blue_circle, red_circle, green_circle, bluegreen_circle],
            loc='upper right',  # inside the plot, top right
            framealpha=0.6,
            prop={'size': 12},
            handletextpad=0.4,
            borderpad=0.5)

    # Tight layout
    plt.tight_layout()

    # Save figure
    plt.savefig(save_path, dpi=300)
    print(f"✅ Saved t-SNE scatter plot with nice legend to {save_path}")
    plt.close()


In [None]:
ood1_baseline_path = '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1_MIXUP_BILEVEL/lastlayer_ours_best_/strans_nk1_bit_500_10_ood1_none.npz'
ood1_ours_path = '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_nk1_bit_500_10_ood1_none.npz'
ood2_baseline_path = '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1_MIXUP_BILEVEL/lastlayer_ours_best_/strans_nk1_bit_500_10_ood2_none.npz'
ood2_ours_path = '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_nk1_bit_500_10_ood2_none.npz'

compare_histogram(ood1_baseline_path, ood1_ours_path, 1)
compare_histogram(ood2_baseline_path, ood2_ours_path, 2)


### Final Figure

In [14]:
# import os
# import numpy as np
# import matplotlib.pyplot as plt
# from collections import defaultdict

# # Try to use GPU TSNE (cuML), fallback to CPU TSNE (openTSNE)
# # try:
# #     import cupy as cp
# #     from cuml.manifold import TSNE as cuTSNE
# #     gpu_available = True
# #     print("✅ Using GPU cuML TSNE")
# # except ImportError:
# from openTSNE import TSNE as cpuTSNE
# gpu_available = False
# print("⚠️ cuML not available, falling back to CPU openTSNE")

# # Set your directory
# tsne_dir = '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/'

# # List all .npz files
# files = sorted([f for f in os.listdir(tsne_dir) if f.endswith('.npz')])

# # --- Group files by their starting prefix (before 3rd underscore) ---
# groups = defaultdict(list)
# for f in files:
#     parts = f.split('_')
    
#     prefix = '_'.join(parts[:5])  # e.g., dsets_dpp4_count
    
#     if 'dsets_nk1_bit_10_100' in prefix:
#         prefix_ = '_'.join(parts[:3])
        
#         if 'ood' not in f:
#             groups[prefix].append(f)
        
#         # print('prefix ', prefix)
#         # print('prefix_ ', prefix_)
#         # print('f')
#     # if prefix_ in prefix and 'ood' in f:
#     #     groups[prefix].append(f)
    
#     if 'dsets_nk1_bit_500_10_ood1_none.npz' in f:
#         groups['dsets_nk1_bit_10_100'].append(f)
#     if 'dsets_nk1_bit_500_10_ood2_none.npz' in f:
#         groups['dsets_nk1_bit_10_100'].append(f)

# # --- Process each group ---
# for prefix, group_files in groups.items():
#     save_path_pdf = os.path.join(tsne_dir, f"{prefix}_tsne_combined_all_final.pdf")
#     save_path_png = os.path.join(tsne_dir, f"{prefix}_tsne_combined_all_final.png")
    
#     # if "dpp4_bit" in save_path:
#     #     continue
#     # --- Skip if already exists ---
#     if os.path.exists(save_path_pdf):
#         print(f"⏩ {save_path_pdf} already exists. Skipping...")
#         continue
    
#     if os.path.exists(save_path_png):
#         print(f"⏩ {save_path_png} already exists. Skipping...")
#         continue

#     print(f"\n🚀 Processing group {prefix}")

#     all_embeddings = []
#     all_labels = []

#     for f in group_files:
#         file_path = os.path.join(tsne_dir, f)
        
#         if len(group_files) != 5:
#             print('group_files ', len(group_files))
#             continue
#         #####
#         # if 'ood' in file_path:
#         #     continue
        
#         # if 'train' not in file_path and 'context' not in file_path:
#         #     continue
#         #####
        
#         data = np.load(file_path)
        
#         embeddings = data['embeddings']
#         labels = data['labels']

#         print(f"✅ Loaded {f}: embeddings {embeddings.shape}, labels {labels.shape}")

#         all_embeddings.append(embeddings)
#         all_labels.append(labels)

#     if len(all_embeddings) == 0 or len(all_labels) == 0:
#         continue
    
#     # Concatenate all
#     all_embeddings = np.concatenate(all_embeddings, axis=0)
#     all_labels = np.concatenate(all_labels, axis=0)

#     print(f"✅ Combined embeddings shape: {all_embeddings.shape}")

#     # --- Run t-SNE ---
#     # if gpu_available:
#     #     embeddings_gpu = cp.asarray(all_embeddings)
#     #     tsne = cuTSNE(n_components=2, random_state=42)
#     #     embeddings_2d_gpu = tsne.fit_transform(embeddings_gpu)
#     #     embeddings_2d = cp.asnumpy(embeddings_2d_gpu)
#     # else:
#     tsne = cpuTSNE(n_components=2, n_jobs=8, random_state=42)
#     embeddings_2d = tsne.fit(all_embeddings)

#     # --- Plot ---
#     import matplotlib.patches as mpatches
#     import matplotlib.pyplot as plt

#     # Set global font
#     plt.rcParams['font.family'] = 'DejaVu Sans'
#     plt.rcParams['font.size'] = 10

#     # --- t-SNE Scatter Plot with Nice Legend and Slight Grid ---

#     plt.figure(figsize=(8, 6))

#     # Define color mapping
#     # color_map = {0: '#ffb347', 1: '#0000CD', -1: '#DDA0DD', 3:'#48b33c', 4:'#3cadb3'}  # orange and blue
#     color_map = {0: '#FFA500', 1: '#2E8B57', -1: '#9B30FF', 3:'#D62728', 4:'#808080'}  # orange and blue
#     colors = [color_map[label] for label in all_labels]

#     # Set point sizes
#     # sizes = [40 if label == 1 else 15 for label in all_labels]
#     sizes = []
#     for label in all_labels:
#         if label == 1:         # mixup (w/o context)
#             sizes.append(60)   # larger to emphasize
#         elif label == -1:      # mixup (context)
#             sizes.append(30)
#         elif label == 0:       # mixup
#             sizes.append(30)
#         elif label in [3, 4]:  # OOD
#             sizes.append(40)   # fairly large to stand out
#         else:
#             sizes.append(15)   # default fallback


#     # Scatter plot
#     scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=colors, s=sizes, alpha=0.8)

#     # Slight soft grid
#     plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.3)

#     # Hide axis labels but keep grid
#     plt.gca().set_xticklabels([])
#     plt.gca().set_yticklabels([])
#     plt.xlabel("")
#     plt.ylabel("")
#     plt.box(False)

#     # --- Add better legend ---
#     import matplotlib.lines as mlines

#     # Define custom legend handles (use Line2D for circles)
#     orange_circle = mlines.Line2D([], [], color='#FFA500', marker='o', linestyle='None', markersize=8, label='ours')
#     blue_circle = mlines.Line2D([], [], color='#2E8B57', marker='o', linestyle='None', markersize=8, label='ours (w/o context)')
#     red_circle = mlines.Line2D([], [], color='#9B30FF', marker='o', linestyle='None', markersize=8, label='ours (context)')
#     green_circle = mlines.Line2D([], [], color='#D62728', marker='o', linestyle='None', markersize=8, label='OOD1')
#     bluegreen_circle = mlines.Line2D([], [], color='#808080', marker='o', linestyle='None', markersize=8, label='OOD2')

#     # Add legend inside plot (upper right)
#     plt.legend(handles=[orange_circle, blue_circle, red_circle, green_circle, bluegreen_circle],
#             loc='upper right',  # inside the plot, top right
#             framealpha=0.6,
#             prop={'size': 12},
#             handletextpad=0.4,
#             borderpad=0.5)

#     # Tight layout
#     plt.tight_layout()

#     # Save figure
#     plt.savefig(save_path_pdf, dpi=300)
#     print(f"✅ Saved t-SNE scatter plot with nice legend to {save_path_pdf}")
#     plt.savefig(save_path_png, dpi=300)
#     print(f"✅ Saved t-SNE scatter plot with nice legend to {save_path_png}")
#     plt.close()


⚠️ cuML not available, falling back to CPU openTSNE

🚀 Processing group dsets_nk1_bit_10_100
✅ Loaded dsets_nk1_bit_10_100_context_none.npz: embeddings (1000, 64), labels (1000,)
✅ Loaded dsets_nk1_bit_10_100_train_context.npz: embeddings (1000, 64), labels (1000,)
✅ Loaded dsets_nk1_bit_10_100_train_none.npz: embeddings (10, 64), labels (10,)
✅ Loaded dsets_nk1_bit_500_10_ood1_none.npz: embeddings (500, 64), labels (500,)
✅ Loaded dsets_nk1_bit_500_10_ood2_none.npz: embeddings (500, 64), labels (500,)
✅ Combined embeddings shape: (3010, 64)
✅ Saved t-SNE scatter plot with nice legend to /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_nk1_bit_10_100_tsne_combined_all_final.pdf
✅ Saved t-SNE scatter plot with nice legend to /c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/dsets_nk1_bit_10_100_tsne_combined_all_final.png


In [13]:
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.manifold import TSNE

# Set your directory
tsne_dir = '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL2_mNctFalse_RYV1/lastlayer_ours_best_/'

# List all .npz files
files = sorted([f for f in os.listdir(tsne_dir) if f.endswith('.npz')])

# --- Group files by prefix ---
groups = defaultdict(list)
for f in files:
    if 'dsets_nk1_bit_10_100' in f or 'dsets_nk1_bit_500_10_ood1_none.npz' in f or 'dsets_nk1_bit_500_10_ood2_none.npz' in f:
        groups['dsets_nk1_bit_10_100'].append(f)

# --- Function to save plot with label exclusions ---
def plot_tsne(embeddings_2d, all_labels, excluded_labels, save_prefix):
    filtered_indices = [i for i, label in enumerate(all_labels) if label not in excluded_labels]
    filtered_embeddings = embeddings_2d[filtered_indices]
    filtered_labels = all_labels[filtered_indices]

    color_map = {0: '#ffb347', 1: '#0000CD', -1: '#228B22', 3: '#8B008B', 4: '#808080'}
    sizes = []
    colors = []
    for label in filtered_labels:
        sizes.append(70 if label == 1 else 18 if label in [3, 4] else 13)
        colors.append(color_map[label])

    plt.figure(figsize=(8, 6))
    plt.scatter(filtered_embeddings[:, 0], filtered_embeddings[:, 1], c=colors, s=sizes, alpha=0.8)
    plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.3)
    plt.xticks([])
    plt.yticks([])
    plt.box(False)

    import matplotlib.lines as mlines
    legend_handles = [
        mlines.Line2D([], [], color='#ffb347', marker='o', linestyle='None', markersize=8, label='ours'),
        mlines.Line2D([], [], color='#0000CD', marker='o', linestyle='None', markersize=8, label='ours (w/o context)'),
        mlines.Line2D([], [], color='#228B22', marker='o', linestyle='None', markersize=8, label='ours (context)')
    ]
    if 3 not in excluded_labels:
        legend_handles.append(mlines.Line2D([], [], color='#8B008B', marker='o', linestyle='None', markersize=8, label='OOD'))
    if 4 not in excluded_labels:
        legend_handles.append(mlines.Line2D([], [], color='#808080', marker='o', linestyle='None', markersize=8, label='OOD'))

    plt.legend(handles=legend_handles, loc='upper right', framealpha=0.6, prop={'size': 12})
    plt.tight_layout()
    plt.savefig(f"{save_prefix}.pdf", dpi=300)
    plt.savefig(f"{save_prefix}.png", dpi=300)
    plt.close()

# --- Process each group ---
for prefix, group_files in groups.items():
    print(f"\n🚀 Processing group {prefix}")
    all_embeddings, all_labels = [], []

    for f in group_files:
        file_path = os.path.join(tsne_dir, f)
        data = np.load(file_path)
        all_embeddings.append(data['embeddings'])
        all_labels.append(data['labels'])
        print(f"✅ Loaded {f}: embeddings {data['embeddings'].shape}, labels {data['labels'].shape}")

    if not all_embeddings:
        continue

    all_embeddings = np.concatenate(all_embeddings, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    print(f"✅ Combined embeddings shape: {all_embeddings.shape}")

    # Run t-SNE using scikit-learn
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(all_embeddings)

    # Save three filtered plots
    base_path = os.path.join(tsne_dir, f"{prefix}_tsne")
    plot_tsne(embeddings_2d, all_labels, excluded_labels=[3, 4], save_prefix=f"{base_path}_no_ood")
    plot_tsne(embeddings_2d, all_labels, excluded_labels=[4], save_prefix=f"{base_path}_no_ood2")
    plot_tsne(embeddings_2d, all_labels, excluded_labels=[3], save_prefix=f"{base_path}_no_ood1")
    print(f"✅ Saved all t-SNE variations for {prefix}")



🚀 Processing group dsets_nk1_bit_10_100
✅ Loaded dsets_nk1_bit_10_100_context_none.npz: embeddings (1000, 64), labels (1000,)
✅ Loaded dsets_nk1_bit_10_100_train_context.npz: embeddings (1000, 64), labels (1000,)
✅ Loaded dsets_nk1_bit_10_100_train_none.npz: embeddings (10, 64), labels (10,)
✅ Loaded dsets_nk1_bit_500_10_ood1_none.npz: embeddings (500, 64), labels (500,)
✅ Loaded dsets_nk1_bit_500_10_ood2_none.npz: embeddings (500, 64), labels (500,)
✅ Combined embeddings shape: (3010, 64)
✅ Saved all t-SNE variations for dsets_nk1_bit_10_100
