In [18]:
import torch.nn.functional as F
import torch
from main_merck_all_real import get_dataset
from utils import set_seed, get_optimizer, InfIterator
from arguments import get_arguments
from main_origin import get_model
from setenc import get_mixer

import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
    

In [19]:
def calc_loss(y_hat, y, test=False):
    return F.mse_loss(y.cuda().squeeze(), y_hat.cuda().squeeze())

def test(args, dataloader, contextloader=None, model=None, mixer_phi=None, embed_type=None, n_t=10, n_c=5):
    model.eval()
    mixer_phi.eval()
    embedding_list = []
    label_list= []
    loss_list = []
    # print('model ', model)
    # print('mixer_phi ', mixer_phi)
    if embed_type == "train_none":
        with torch.no_grad():
            losses = []
            counts = 0
            for i, (x, y) in enumerate(dataloader):
                if i == n_t:
                    break
                y_hat, embedding_list, label_list = model(x=x.to(args.device), mixer_phi=mixer_phi, embedding_list=embedding_list, label_list=label_list, embed_type=embed_type, embed_test=args.embed_test)

                y = y.cuda().squeeze()
                y_hat = y_hat.cuda().squeeze()

                # y_hat = y_hat[:, 0]
                # print(f"in test: {y.size()=} {y_hat.size()=}")

                loss = calc_loss(y_hat, y, test=True)
                loss_scalar = loss.detach().item()
                loss_list.append(torch.full((x.shape[0],), loss_scalar))

                # print('loss_scalar ', loss_scalar)
                # print('x ', x.size(0))
                # print('y_hat ', y_hat)
                # print('embedding_list ', embedding_list)
                losses.append(loss_scalar * x.size(0))
                counts += x.size(0)
                # if i == 0:
                #     torch.save(x, 'tn_x.pt')
                #     torch.save(y, 'tn_y.pt')
        # self.model.eval()
    elif embed_type == "train_context":
        with torch.no_grad():
            losses = []
            counts = 0
            for i, (x, y) in enumerate(dataloader):
                if i == n_t:
                    break
                
                context_samples = []
                
                # if i == 0:
                #     torch.save(x, 'tc_x.pt')
                #     torch.save(y, 'tc_y.pt')

                for i_c, (x_c, y_c) in enumerate(contextloader):
                    if i_c == n_c:
                        break
                    x_c = x_c.reshape(args.batch_size, -1, x_c.size(-1))
                    if args.n_context > 1:
                        n = torch.randint(1, x_c.size(1), size=(1,)).item()
                        x_c = x_c[:, :n]
                    # print(x_c.shape)
                    # context_samples.append(x_c)
                
                # context_samples = torch.cat(context_samples, dim=0).to(args.device)
                # context_samples = context_samples.reshape(args.batch_size, -1, x_c.size(-1))
                
                # print('### context_samples ', context_samples)
                # torch.save(context_samples, f'context_samples_{i}.pth')
                    y_hat, embedding_list, label_list = model(x=x.to(args.device), context=x_c.to(args.device), mixer_phi=mixer_phi, embedding_list=embedding_list, label_list=label_list, embed_type=embed_type, embed_test=args.embed_test)

                    y = y.cuda().squeeze()
                    y_hat = y_hat.cuda().squeeze()

                    # y_hat = y_hat[:, 0]
                    # print(f"in test: {y.size()=} {y_hat.size()=}")

                    loss = calc_loss(y_hat, y, test=True)
                    loss_list.append(torch.full((x.shape[0],), loss.detach().item()))

                    losses.append(loss.item() * x.size(0))
                    counts += x.size(0)
    mse = sum(losses) / counts
    return mse, embedding_list, label_list, loss_list

In [20]:

def save_one_datapoint_features(args, model, mixer_phi, embed_type, n_t, n_c):
    set_seed(0)
    args.batch_size = 1
    args.tsne_plot = True
    args.embed_test = '2nd_last_ours_best'
    args.specify_ood_dataset = ['dpp4', 'nk1'] # TODO change
    trainloader_test, _, mvalidloader_test, _, contextloader_test, ood1_trainloader_test, ood2_trainloader_test = get_dataset(args=args, test=True)
    mse, embedding_list, label_list, loss_list = test(args=args, dataloader=trainloader_test, contextloader=contextloader_test, model=model, mixer_phi=mixer_phi, embed_type=embed_type, n_t=n_t, n_c=n_c)

    all_embeddings = torch.cat(embedding_list, dim=0)
    all_labels = np.concatenate(label_list, axis=0)
    all_losses = torch.cat(loss_list, dim=0)

    all_labels = torch.tensor(all_labels)
    all_losses = torch.tensor(all_losses)
                
    all_embeddings_np = all_embeddings.numpy()
    all_labels_np = all_labels.numpy()
    all_losses_np = all_losses.numpy()
    
    path = f"/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL/{args.embed_test}_"
    os.makedirs(path, exist_ok=True)
    f_path = f'/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL/{args.embed_test}_/{args.sencoder}_{args.dataset}_{args.vec_type}_{n_t}_{n_c}_{embed_type}.npz'
    np.savez(f_path, embeddings=all_embeddings_np, labels=all_labels_np, losses=all_losses_np)
    print(f'>>> saved {f_path}')
    
    trainloader_test._iterator._shutdown_workers()
    if 'context' in embed_type:
        contextloader_test._iterator._shutdown_workers()
    # ood1_trainloader_test._iterator._shutdown_workers()
    # ood2_trainloader_test._iterator._shutdown_workers()


In [21]:
def save_features(data):

    model = data['model']
    mixer_phi = data['mixer_phi']
    optimizer = data['optimizer']
    mixer_optimizer = data['mixer_optimizer']
    args_ = data['args']

    args = get_arguments()

    for k, v in args_.items():
        setattr(args, k, v)

    # print(e'n_context ', args.n_context )
    # exit()
    set_seed(0)

    model = get_model(args=args)
    mixer_phi = get_mixer(args=args)
    # optimizer = get_optimizer(optimizer=args.optimizer, model=model, lr=args.lr, wd=args.wd)
    # optimizermixer = None if mixer_phi is None else get_optimizer(optimizer=args.optimizer, model=mixer_phi, lr=args.clr, wd=args.cwd)

    model.load_state_dict(data['model'])
    mixer_phi.load_state_dict(data['mixer_phi'])
    # optimizer.load_state_dict(data['optimizer'])
    # mixer_optimizer.load_state_dict(data['mixer_optimizer'])

    model = model.to(args.device)
    mixer_phi = mixer_phi.to(args.device)
    
    for (n_t, n_c) in [(5, 10), (5, 100), (10, 5), (10, 100), (100, 5), (100, 10)]:
        save_one_datapoint_features(args, model, mixer_phi, "train_none", n_t, n_c)

        save_one_datapoint_features(args, model, mixer_phi, "train_context", n_t, n_c)

In [22]:
import os
import torch

# Directory containing .pth files
tsne_model_dir = '/c2/jinakim/Drug_Discovery_j/tsne_model/2nd_last_ours_best'
os.makedirs(tsne_model_dir, exist_ok=True)
# List all .pth files
pth_files = sorted([f for f in os.listdir(tsne_model_dir) if f.endswith('.pth')])

# Load each file
for i, f in enumerate(pth_files):
    if i <= 1:
        continue
    file_path = os.path.join(tsne_model_dir, f)
    print(f"🚀 Loading {file_path}")
    
    data = torch.load(file_path)
    save_features(data)
    # Now 'data' contains the loaded model or state dict or whatever was saved
    # You can process it here if needed
    # For example, just printing some keys if it's a checkpoint
    # if isinstance(data, dict):
    #     print(f"✅ Loaded {f}: keys = {list(data.keys())}")
    # else:
    #     print(f"✅ Loaded {f}: type = {type(data)}")

print("\n🏁 All models loaded.")





🏁 All models loaded.


In [23]:
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Try to use GPU TSNE (cuML), fallback to CPU TSNE (openTSNE)
# try:
#     import cupy as cp
#     from cuml.manifold import TSNE as cuTSNE
#     gpu_available = True
#     print("✅ Using GPU cuML TSNE")
# except ImportError:
from openTSNE import TSNE as cpuTSNE
gpu_available = False
print("⚠️ cuML not available, falling back to CPU openTSNE")

# Set your directory
tsne_dir = '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL/2nd_last_ours_best_/'

# List all .npz files
files = sorted([f for f in os.listdir(tsne_dir) if f.endswith('.npz')])

# --- Group files by their starting prefix (before 3rd underscore) ---
groups = defaultdict(list)
for f in files:
    parts = f.split('_')
    prefix = '_'.join(parts[:5])  # e.g., dsets_dpp4_count
    groups[prefix].append(f)

# --- Process each group ---
for prefix, group_files in groups.items():
    save_path = os.path.join(tsne_dir, f"{prefix}_tsne_combined.pdf")
    
    # if "dpp4_bit" in save_path:
    #     continue
    # --- Skip if already exists ---
    if os.path.exists(save_path):
        print(f"⏩ {save_path} already exists. Skipping...")
        continue

    print(f"\n🚀 Processing group {prefix}")

    all_embeddings = []
    all_labels = []

    for f in group_files:
        file_path = os.path.join(tsne_dir, f)
        data = np.load(file_path)
        embeddings = data['embeddings']
        labels = data['labels']

        print(f"✅ Loaded {f}: embeddings {embeddings.shape}, labels {labels.shape}")

        all_embeddings.append(embeddings)
        all_labels.append(labels)

    # Concatenate all
    all_embeddings = np.concatenate(all_embeddings, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    print(f"✅ Combined embeddings shape: {all_embeddings.shape}")

    # --- Run t-SNE ---
    # if gpu_available:
    #     embeddings_gpu = cp.asarray(all_embeddings)
    #     tsne = cuTSNE(n_components=2, random_state=42)
    #     embeddings_2d_gpu = tsne.fit_transform(embeddings_gpu)
    #     embeddings_2d = cp.asnumpy(embeddings_2d_gpu)
    # else:
    tsne = cpuTSNE(n_components=2, n_jobs=8, random_state=42)
    embeddings_2d = tsne.fit(all_embeddings)

    # --- Plot ---
    import matplotlib.patches as mpatches
    import matplotlib.pyplot as plt

    # Set global font
    plt.rcParams['font.family'] = 'DejaVu Sans'
    plt.rcParams['font.size'] = 10

    # --- t-SNE Scatter Plot with Nice Legend and Slight Grid ---

    plt.figure(figsize=(8, 6))

    # Define color mapping
    color_map = {0: '#ffb347', 1: '#0000CD'}  # orange and blue
    colors = [color_map[label] for label in all_labels]

    # Set point sizes
    sizes = [15 if label == 0 else 40 for label in all_labels]

    # Scatter plot
    scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=colors, s=sizes, alpha=0.8)

    # Slight soft grid
    plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.3)

    # Hide axis labels but keep grid
    plt.gca().set_xticklabels([])
    plt.gca().set_yticklabels([])
    plt.xlabel("")
    plt.ylabel("")
    plt.box(False)

    # --- Add better legend ---
    import matplotlib.lines as mlines

    # Define custom legend handles (use Line2D for circles)
    orange_circle = mlines.Line2D([], [], color='#ffb347', marker='o', linestyle='None', markersize=8, label='ours')
    blue_circle = mlines.Line2D([], [], color='#0000CD', marker='o', linestyle='None', markersize=8, label='ours (w/o context)')

    # Add legend inside plot (upper right)
    plt.legend(handles=[orange_circle, blue_circle],
            loc='upper right',  # inside the plot, top right
            framealpha=0.6,
            prop={'size': 12},
            handletextpad=0.4,
            borderpad=0.5)

    # Tight layout
    plt.tight_layout()

    # Save figure
    plt.savefig(save_path, dpi=300)
    print(f"✅ Saved t-SNE scatter plot with nice legend to {save_path}")
    plt.close()


⚠️ cuML not available, falling back to CPU openTSNE


FileNotFoundError: [Errno 2] No such file or directory: '/c2/jinakim/Drug_Discovery_j/analysis/tsne_last_REAL/2nd_last_ours_best_/'