In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import time
import random
import argparse
from scipy.sparse import issparse
from collections import Counter
import torch
from torch.utils.data import DataLoader

# NOTE: the two star imports below can shadow names imported above them, so the
# relative order of the surrounding imports is intentionally left unchanged.
from merge_Dataset import PretrainDataset
from main_tcr_train import *
from utils import *

import copy
import pickle
import umap
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score


In [None]:
# Set display options: show every DataFrame column instead of truncating.
pd.set_option('display.max_columns', None)

# --------------------------- Parse config --------------------------- #
argparser = argparse.ArgumentParser()
argparser.add_argument('--config', type=str, default='./config.yaml')
# Inside a Jupyter kernel, sys.argv holds the kernel launcher's own arguments
# (e.g. "-f <connection-file>.json"). parse_args() rejects them and raises
# SystemExit; parse_known_args() still honours --config and ignores the rest.
args, _unknown_args = argparser.parse_known_args()

config = load_config(args.config)

# Set random seed
set_random_seed(config['Train']['Trainer_parameter']['random_seed'])

# device = config['Train']['Model_Parameter']['device']
# Keep the original preference for the second GPU ("cuda:1"), but fall back to
# the first GPU or the CPU so the notebook still runs on hosts with fewer devices.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [None]:
# --------------------------- Load and preprocess data --------------------------- #
# Input files mentioned by the original author:
#   merged_8datasets_after_combat_hvg_celltypist.h5ad (used here)
#   merged_8datasets_celltype_subset.h5ad
adata = sc.read_h5ad("../Data/pretrain_raw_data/merged_8datasets_after_combat_hvg_celltypist.h5ad")

# Store the matrix as a dense float32 NumPy array, densifying first when it is sparse.
adata.X = np.array(
    adata.X.toarray() if issparse(adata.X) else adata.X,
    dtype=np.float32,
)

# Wrap the AnnData object in the project dataset and fetch its sequences / vocabulary.
dataset = PretrainDataset(adata)
smile_seqs, vocab_dict = dataset.get_smile_seqs()

# print(dataset.type.value_counts())
print("dataset over")


In [None]:
# --------------------------- Initialize and optionally load pretrained model --------------------------- #
# Read the config sub-sections once so the lookups below stay short.
train_cfg = config['Train']
sampling_cfg = train_cfg['Sampling']
trainer_cfg = train_cfg['Trainer_parameter']

model = init_model(config, device)

# Plateau-based LR scheduler attached to the optimizer that the model carries.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    model.optimizer, mode='min', factor=0.5, patience=10
)

# === DataLoader ===
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=sampling_cfg['batch_size'],
    shuffle=sampling_cfg['sample_shuffle'],
)

# === Logger === #
output_dir = train_cfg['output_dir']
logger = get_logger(f"{output_dir}/training_8_datasets_AE.log")

# === Training Params ===
epochs = trainer_cfg['epoch']
patience = trainer_cfg['patience']
best_loss = float('inf')

# === Metrics Storage ===
all_loss = []
# Number of distinct values per label kind; used later as the KMeans cluster counts.
number_of_GEO, number_of_beta, number_of_type = (
    len(Counter(labels))
    for labels in (dataset.GEOlabel, dataset.cdr3, dataset.type)
)


In [None]:
# Training loop: one call to train_one_epoch per epoch, periodic clustering
# metrics, best-model checkpointing and patience-based early stopping.

# Initialise before the loop: the counter used to be created only inside the
# "improved" branch, so a first-epoch loss that is not below best_loss (e.g. NaN)
# raised NameError on `early_stop_counter += 1`.
early_stop_counter = 0

# NOTE(review): `scheduler` (ReduceLROnPlateau) is created during setup but never
# stepped in this loop. Unless train_one_epoch steps it internally, the learning
# rate is never reduced — confirm, and add `scheduler.step(avg_loss)` if intended.

for epoch in range(1, epochs + 1):
    start_time = time.time()

    (
        avg_loss,
        avg_components,
        all_embeddings,
        all_rna_embeddings,
        all_tcr_embeddings,
        all_GEO_labels,
        all_beta_labels,
        all_type_labels,
    ) = train_one_epoch(model, dataloader, device)
    all_loss.append(avg_loss)
    logger.info(
        f"Epoch [{epoch}/{epochs}] | Loss : {avg_loss:.5f} | "
        f"Contrastive: {avg_components['contrastive']:.5f} | "
        f"Recon: {avg_components['recon']:.5f} | "
        f"Fusion: {avg_components['fusion']:.5f} | "
        f"Time: {time.time() - start_time:.2f}s"
    )

    ######## Calculate metrics every 5 epochs (starting at epoch 10) ###########
    if epoch > 9 and epoch % 5 == 0:
        emb = np.vstack(all_embeddings)
        enc = LabelEncoder()
        # For each label kind: cluster the embeddings with KMeans using as many
        # clusters as there are distinct labels, then compare clusters to labels via NMI.
        nmi_geo = normalized_mutual_info_score(
            enc.fit_transform(all_GEO_labels),
            KMeans(n_clusters=number_of_GEO).fit(emb).labels_,
        )
        nmi_beta = normalized_mutual_info_score(
            enc.fit_transform(all_beta_labels),
            KMeans(n_clusters=number_of_beta).fit(emb).labels_,
        )
        nmi_cell = normalized_mutual_info_score(
            enc.fit_transform(all_type_labels),
            KMeans(n_clusters=number_of_type).fit(emb).labels_,
        )

        print(f"\n[Epoch {epoch}] NMI Scores:")
        print(f"  GEO      NMI: {nmi_geo:.4f}")
        print(f"  beta   NMI: {nmi_beta:.4f}")
        print(f"  type NMI: {nmi_cell:.4f}")

    ######## save best model ###########
    if avg_loss < best_loss:
        best_loss = avg_loss
        early_stop_counter = 0
        # state_dict() returns references to the live parameters; deep-copy so this
        # snapshot keeps the best weights instead of tracking later updates.
        best_epoch_model = copy.deepcopy(model.state_dict())
        save_best_model(model, epoch, output_dir, stage="pretrain_after_batchremove")
    else:
        early_stop_counter += 1

    if early_stop_counter >= patience:
        logger.info(f"Early stopping at epoch {epoch}. No improvement for {patience} epochs.")
        break

print("train over")


In [None]:
# Clustering visualization: gather the embeddings and labels produced by the
# most recent train_one_epoch call and persist them to disk.
final_embeddings_np = np.vstack(all_embeddings)
final_GEO_cluster_labels, final_beta_cluster_labels, final_type_cluster_labels = (
    np.array(labels)
    for labels in (all_GEO_labels, all_beta_labels, all_type_labels)
)

# save
import pickle

# The RNA / TCR embeddings are stored as returned by training (not stacked here).
embedding_payload = {
    "embeddings": final_embeddings_np,
    "GEO_labels": final_GEO_cluster_labels,
    "beta_labels": final_beta_cluster_labels,
    "type_labels": final_type_cluster_labels,
    "all_rna_embeddings": all_rna_embeddings,
    "all_tcr_embeddings": all_tcr_embeddings,
}
with open("embedding_and_labels.pkl", "wb") as out_file:
    pickle.dump(embedding_payload, out_file)


In [None]:

import matplotlib.pyplot as plt
import numpy as np
import umap
from matplotlib.lines import Line2D
# Dimensionality reduction with UMAP. The three projections below share the
# same hyper-parameters, so they are declared once.
umap_params = dict(
    n_components=2,
    n_neighbors=100,
    min_dist=0.1,
    metric='cosine',
    n_jobs=-1,
)

# all embeddings
reducer_all = umap.UMAP(**umap_params)
reduced_embeddings_umap = reducer_all.fit_transform(final_embeddings_np)
print("umap all embedding")

# rna
final_rna_embeddings_np = np.vstack(all_rna_embeddings)
reducer_rna = umap.UMAP(**umap_params)
rna_reduced_embeddings_umap = reducer_rna.fit_transform(final_rna_embeddings_np)
print("rna_embeddings_umap over")

# tcr
final_tcr_embeddings_np = np.vstack(all_tcr_embeddings)
reducer_tcr = umap.UMAP(**umap_params)
tcr_reduced_embeddings_umap = reducer_tcr.fit_transform(final_tcr_embeddings_np)
print("tcr_embeddings_umap over")



# Scatter plot of a 2-D UMAP projection: the five most frequent labels get a
# distinct colour each, every other label is drawn in light gray.
draw_reduced_embeddings_umap = reduced_embeddings_umap
visiual_label = final_type_cluster_labels  # visiual_label = np.array(adata.obs['cdr3'])

unique_labels = np.unique(visiual_label)
label_counts = Counter(visiual_label)
top_labels = [label for label, _ in label_counts.most_common(5)]
top_colors = [plt.cm.tab10(i) for i in range(len(top_labels))]
label2color = dict(zip(top_labels, top_colors))

fig, ax = plt.subplots(figsize=(8, 6))
for label in unique_labels:
    mask = visiual_label == label
    ax.scatter(
        draw_reduced_embeddings_umap[mask, 0],
        draw_reduced_embeddings_umap[mask, 1],
        color=label2color.get(label, 'lightgray'),  # non-top labels default to gray
        s=0.1,
        label=str(label) if label in top_labels else None,  # avoid duplicate legend entries
    )

# Custom legend handles: the scatter markers (s=0.1) are too small to read in a legend.
legend_handles = [
    Line2D([0], [0], marker='o', color='w',
           markerfacecolor=label2color[label],
           label=label, markersize=4)
    for label in top_labels
]
# NOTE(review): the title says "Dataset Labels" while the plotted labels are the
# "type" labels — confirm the intended wording.
ax.legend(
    handles=legend_handles,
    title="Dataset Labels",
    loc='center left',
    bbox_to_anchor=(1.02, 0.5),
    markerscale=1,
    fontsize=10,
    title_fontsize=12,
    borderaxespad=0,
    frameon=False,
)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# ax.set_title(f'UMAP Visualization of top 5 beta {nmi_cell}', fontsize=14)
ax.set_xlabel('UMAP_1', fontsize=12)
ax.set_ylabel('UMAP_2', fontsize=12)
ax.set_aspect('equal', adjustable='box')
fig.tight_layout()

# Save BEFORE showing: once plt.show() has displayed (and released) the figure,
# a later plt.savefig() writes a new, empty canvas.
# NOTE(review): the file name mentions "beta" although the "type" labels are plotted.
fig.savefig("umap_final_beta_cluster_labels.png", dpi=300, bbox_inches='tight')
plt.show()
