In [1]:
import sys
import os 
import torch
import torch_cluster

from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import argparse
import json
import random
from utils import *
from dataset import *
from sc_model import *
from datetime import datetime
from tqdm import tqdm
import pandas as pd
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import roc_auc_score, accuracy_score, precision_recall_curve, auc


2024-04-30 01:40:51.655640: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Root data folder (relative path, resolved against the current working directory).
# NOTE(review): base_path already ends with '/', so the two derived paths below contain a
# double slash (e.g. 'data_su2020//VAE/'); presumably harmless on POSIX -- confirm if
# portability matters.
base_path = f'data_su2020/'
# NOTE(review): target_dir is not referenced by any other visible cell of this notebook.
target_dir = f'{base_path}/VAE/'
# Folder passed to load_dataset_and_preprocessors() as its base path in the export loops.
data_dir = f'{base_path}/data/'

# dir_path="covid"
# base_path = f"../ProtoCell4P-main/data/{dir_path}/"
# target_dir = f'{base_path}/VAE/'
# dir create
# os.makedirs(target_dir, exist_ok=False)


In [3]:

def load_dataset_and_preprocessors(base_path, exp, device):
    """Load the saved train/val/test datasets and preprocessors of one experiment.

    Args:
        base_path: Directory containing the ``*_exp{exp}_HVG_count_noflt_only2`` files.
        exp: Experiment identifier interpolated into the file names.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset, label_encoder, scaler)``.

    Note:
        Both ``torch.load`` and ``pickle.load`` unpickle arbitrary objects; only point
        this function at files from a trusted source.
    """
    # Imported locally so the function does not rely on `pickle` leaking into the
    # notebook namespace through one of the star imports.
    import pickle

    # All five files of an experiment share this file-name suffix.
    suffix = f"exp{exp}_HVG_count_noflt_only2"

    def _load_split(split):
        # One saved dataset object per split, mapped onto the requested device.
        return torch.load(f"{base_path}/{split}_dataset_{suffix}.pt", map_location=device)

    def _load_pickle(name):
        # The context manager closes the file even if unpickling fails.
        with open(f"{base_path}/{name}_{suffix}.pkl", 'rb') as f:
            return pickle.load(f)

    train_dataset = _load_split("train")
    val_dataset = _load_split("val")
    test_dataset = _load_split("test")
    label_encoder = _load_pickle("label_encoder")
    scaler = _load_pickle("scaler")

    return train_dataset, val_dataset, test_dataset, label_encoder, scaler



In [4]:
# Index of the preferred CUDA device.
device_num = 4
# Use the requested GPU only when it actually exists. `torch.cuda.is_available()` alone is
# also true on hosts with fewer than `device_num + 1` GPUs, where 'cuda:4' would only fail
# later, on first use, with an invalid-device error. In that case fall back to the CPU.
device = torch.device(
    f'cuda:{device_num}'
    if torch.cuda.is_available() and torch.cuda.device_count() > device_num
    else 'cpu'
)
print("INFO: Using device: {}".format(device))


INFO: Using device: cuda:4


In [5]:
class InstanceDataset2(Dataset):
    """Dataset that yields one instance per index instead of one bag.

    Similar in spirit to a MIL dataset, but each item carries the data and the labels
    of a single instance.

    Args:
        data (Tensor): Feature data; the first dimension indexes instances.
        ids (Tensor): Bag ID of each instance.
        labels (Tensor): Label of each bag.
        instance_labels (Tensor): Label of each instance.
        bag_labels (Tensor): Indexed per instance; returned as the item's bag label.
    """

    def __init__(self, data, ids, labels, instance_labels, bag_labels):
        self.data = data
        self.labels = labels
        self.ids = ids
        # Private copy of the bag IDs, promoted to shape (1, N) when `ids` is 1-D.
        # `self.ids` keeps the caller's original tensor.
        bag_id_rows = ids.clone()
        if bag_id_rows.dim() == 1:
            bag_id_rows = bag_id_rows.unsqueeze(0)
        self.mil_ids = bag_id_rows
        self.instance_labels = instance_labels
        self.bag_labels = bag_labels
        # Sorted unique bag IDs taken from the first row.
        self.bags = torch.unique(self.mil_ids[0])

    def __len__(self):
        """Number of instances (size of the first dimension of ``data``)."""
        return self.data.size(0)

    def __getitem__(self, index):
        """Return ``(data, bag_id, instance_label, bag_label)`` for one instance."""
        return (
            self.data[index],
            self.ids[index],
            self.instance_labels[index],
            self.bag_labels[index],
        )
    

def update_instance_labels_with_bag_labels(instance_dataset):
    """Build an ``InstanceDataset2`` whose per-instance ``bag_labels`` hold each instance's bag label.

    For every instance, the bag ID returned by ``instance_dataset[i]`` is looked up in
    ``instance_dataset.bags`` and the label at that position in
    ``instance_dataset.labels`` is recorded.

    Args:
        instance_dataset: Dataset exposing ``data``, ``ids``, ``labels``,
            ``instance_labels``, ``bags``, ``__len__`` and a ``__getitem__`` whose
            second element is the instance's bag ID.

    Returns:
        InstanceDataset2: A new dataset sharing the input's tensors, plus the computed
        per-instance bag labels. The input dataset itself is not modified.

    Raises:
        IndexError: If an instance's bag ID does not occur in ``instance_dataset.bags``.

    Note:
        The label tensor is allocated on the notebook-level ``device``.
    """
    n_instances = len(instance_dataset)
    # Hoisted attribute lookups; both are loop-invariant.
    bags = instance_dataset.bags
    labels_per_bag = instance_dataset.labels

    combined_labels = torch.empty(n_instances, dtype=torch.long, device=device)
    for i in range(n_instances):
        # Read the bag ID by position instead of unpacking a fixed-length tuple, so that
        # datasets whose items carry extra fields (InstanceDataset2 returns four) also work.
        bag_id = instance_dataset[i][1]
        bag_index = (bags == bag_id).nonzero(as_tuple=True)[0][0]
        combined_labels[i] = labels_per_bag[bag_index]

    return InstanceDataset2(
        instance_dataset.data,
        instance_dataset.ids,
        instance_dataset.labels,
        instance_dataset.instance_labels,
        combined_labels,
    )

# Optimizer

In [6]:

class AttentionModule(nn.Module):
    """Attention scorer built as Linear(L, D) -> Tanh -> Linear(D, K).

    Args:
        L: Size of each input feature vector.
        D: Hidden width of the scorer.
        K: Number of scores produced per input row.
    """

    def __init__(self, L, D, K):
        super().__init__()
        self.L, self.D, self.K = L, D, K

        scorer_layers = [
            nn.Linear(L, D),
            nn.Tanh(),
            nn.Linear(D, K),
        ]
        self.attention = nn.Sequential(*scorer_layers)

    def forward(self, H):
        """Return raw attention scores with dimensions 0 and 1 swapped.

        For a 2-D ``H`` of shape (N, L) the result has shape (K, N). No softmax is
        applied here; normalisation over N is left to the caller.
        """
        scores = self.attention(H)      # N x K
        return scores.transpose(1, 0)   # K x N
class GatedAttentionModule(nn.Module):
    """Gated attention scorer: a tanh branch and a sigmoid gate, multiplied and projected to K scores.

    Args:
        L: Size of each input feature vector.
        D: Hidden width of both branches.
        K: Number of scores produced per input row.
    """

    def __init__(self, L, D, K):
        super().__init__()
        self.L, self.D, self.K = L, D, K

        # Content branch (tanh) and gate branch (sigmoid) both map L -> D.
        self.attention_V = nn.Sequential(nn.Linear(L, D), nn.Tanh())
        self.attention_U = nn.Sequential(nn.Linear(L, D), nn.Sigmoid())
        # Projects the gated representation D -> K.
        self.attention_weights = nn.Linear(D, K)

    def forward(self, H):
        """Return raw gated attention scores with dimensions 0 and 1 swapped.

        For a 2-D ``H`` of shape (N, L) the result has shape (K, N). No softmax is
        applied here; normalisation over N is left to the caller.
        """
        content = self.attention_V(H)                   # N x D
        gate = self.attention_U(H)                      # N x D
        scores = self.attention_weights(content * gate) # element-wise product, N x K
        return scores.transpose(1, 0)                   # K x N


class TeacherBranch(nn.Module):
    """Bag-level classifier: attention-pool the instances, then classify the pooled vector.

    Args:
        input_dims: Size of each instance feature vector.
        latent_dims: Hidden width of the classifier head (stored as ``L`` and ``D``).
        attention_module: Module mapping the instance matrix to raw attention scores.
        num_classes: Number of output logits.
        activation_function: Activation class instantiated between the linear layers.
        dropout_rate: Accepted for interface compatibility; no layer uses it.
    """

    def __init__(self, input_dims, latent_dims, attention_module,
                 num_classes=2,
                 activation_function=nn.Sigmoid,
                 dropout_rate=0.1):
        super().__init__()
        self.input_dims = input_dims
        self.L = latent_dims
        self.K = 1
        self.D = latent_dims
        self.attention_module = attention_module
        self.num_classes = num_classes

        # Classifier head: input_dims -> L -> L -> num_classes.
        # (A narrower alternative head with L//4 hidden units was previously tried here.)
        head_layers = [
            nn.Linear(self.input_dims, self.L),
            activation_function(),
            nn.Linear(self.L, self.L),
            activation_function(),
            nn.Linear(self.L, self.num_classes),
        ]
        self.bagNN = nn.Sequential(*head_layers)
        self.initialize_weights()

    def forward(self, input, replaceAS=None):
        """Pool ``input`` with softmax-normalised attention and return the class logits.

        Args:
            input: 2-D instance matrix (``torch.mm`` requires 2-D operands).
            replaceAS: Optional raw scores used instead of the attention module's
                output; they are softmax-normalised over dim 1 as well.
        """
        raw_scores = self.attention_module(input) if replaceAS is None else replaceAS
        attention_weights = torch.softmax(raw_scores, dim=1)

        pooled = torch.mm(attention_weights, input).squeeze()
        return self.bagNN(pooled)

    def initialize_weights(self):
        """Xavier-normal weights and zero biases for every ``nn.Linear`` under this module.

        ``self.modules()`` also walks ``attention_module``, so its linear layers are
        re-initialised too.
        """
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight.data)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias.data)

class StudentBranch(nn.Module):
    """Instance-level classifier: a three-layer MLP applied to each instance vector.

    Args:
        input_dims: Size of each instance feature vector.
        latent_dims: Hidden width of the MLP (stored as ``L`` and ``D``).
        num_classes: Number of output logits per instance.
        activation_function: Activation class instantiated between the linear layers.
    """

    def __init__(self, input_dims, latent_dims,
                 num_classes=2,
                 activation_function=nn.ReLU):
        super().__init__()
        self.input_dims = input_dims
        self.L = latent_dims
        self.K = 1
        self.D = latent_dims
        self.num_classes = num_classes

        # MLP: input_dims -> L -> L -> num_classes.
        # (A narrower alternative with L//4 hidden units was previously tried here.)
        mlp_layers = [
            nn.Linear(self.input_dims, self.L),
            activation_function(),
            nn.Linear(self.L, self.L),
            activation_function(),
            nn.Linear(self.L, self.num_classes),
        ]
        self.instanceNN = nn.Sequential(*mlp_layers)
        self.initialize_weights()

    def forward(self, input):
        """Return the raw class logits for ``input`` (no softmax is applied)."""
        return self.instanceNN(input)

    def initialize_weights(self):
        """Xavier-normal weights and zero biases for every ``nn.Linear`` in this module."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight.data)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias.data)

class EncoderBranch(nn.Module):
    """Trainable MLP on top of a VAE encoder that is evaluated without gradients.

    Args:
        proto_vae: Object providing ``latent_dim``, ``features(x)`` and
            ``reparameterize(mu, logVar)``.
        output_dims: Width of every layer of the trainable MLP, including its output.
        activation_function: Activation class instantiated between the linear layers.
    """

    def __init__(self, proto_vae, output_dims, activation_function=nn.Tanh):
        super().__init__()
        self.proto_vae = proto_vae
        self.activation_function = activation_function
        self.output_dims = output_dims

        # Trainable head: latent_dim -> output_dims -> output_dims -> output_dims.
        head_layers = [
            nn.Linear(self.proto_vae.latent_dim, self.output_dims),
            activation_function(),
            nn.Linear(self.output_dims, self.output_dims),
            activation_function(),
            nn.Linear(self.output_dims, self.output_dims),
        ]
        self.encoder_layer = nn.Sequential(*head_layers)
        self.initialize_weights()

    def forward(self, input):
        """Sample a latent vector from the VAE (no gradients) and encode it with the MLP."""
        latent_dim = self.proto_vae.latent_dim
        log_eps = np.log(1e-8)
        with torch.no_grad():
            vae_latent = self.proto_vae.features(input)
            # The first `latent_dim` columns are used as the mean; the remaining columns
            # are used as the log-variance, clamped to [log(1e-8), -log(1e-8)].
            mu = vae_latent[:, :latent_dim]
            logVar = vae_latent[:, latent_dim:].clamp(log_eps, -log_eps)
            z = self.proto_vae.reparameterize(mu, logVar)

        return self.encoder_layer(z)

    def initialize_weights(self):
        """Xavier-normal weights and zero biases for the MLP only; ``proto_vae`` is untouched."""
        linear_layers = (m for m in self.encoder_layer.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight.data)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias.data)
### AENB
class AENB(nn.Module):
    """Fully connected autoencoder whose decoder emits two positive outputs per input feature.

    The encoder maps ``input_dim -> hidden_layers... -> latent_dim``. The decoder mirrors
    it, ``latent_dim -> reversed(hidden_layers)... -> 2 * input_dim``, and its output is
    split into ``mu_recon`` (exp, clamped) and ``theta_recon`` (softplus, clamped).
    Presumably these are the mean and dispersion of a negative-binomial likelihood; the
    loss is not defined in this class -- confirm against the training code.

    Args:
        input_dim: Number of input features.
        latent_dim: Width of the bottleneck.
        device: Stored on the instance; this class does not move any tensor itself.
        hidden_layers: Iterable of hidden widths for the encoder (reversed for the decoder).
        activation_function: Activation class instantiated after each hidden linear layer.
    """

    def __init__(self, input_dim, latent_dim, device, hidden_layers, activation_function=nn.ReLU):
        super().__init__()
        self.device = device
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.epsilon = 1e-4  # not read anywhere in this class
        self.hidden_layers = hidden_layers
        self.activation_function = activation_function

        # Encoder: input_dim -> hidden_layers... -> latent_dim
        feature_layers = []
        previous_dim = input_dim
        for layer_dim in self.hidden_layers:
            feature_layers.append(nn.Linear(previous_dim, layer_dim))
            feature_layers.append(self.activation_function())
            previous_dim = layer_dim
        feature_layers.append(nn.Linear(previous_dim, latent_dim))
        self.features = nn.Sequential(*feature_layers)

        # Decoder: latent_dim -> reversed(hidden_layers)... -> 2 * input_dim
        # The decoder consumes the latent vector, so its first layer must be sized from
        # latent_dim. Reusing the encoder's running width here only worked when
        # latent_dim happened to equal the last hidden width.
        decoder_layers = []
        previous_dim = latent_dim
        for layer_dim in reversed(self.hidden_layers):
            decoder_layers.append(nn.Linear(previous_dim, layer_dim))
            decoder_layers.append(self.activation_function())
            previous_dim = layer_dim
        decoder_layers.append(nn.Linear(previous_dim, input_dim * 2))
        self.decoder_layers = nn.Sequential(*decoder_layers)

        self._initialize_weights()

    def decoder(self, z):
        """Decode ``z`` into ``(mu_recon, theta_recon)``, each with ``input_dim`` columns."""
        decoded = self.decoder_layers(z)
        # First half of the columns -> exp, clamped to [1e-6, 1e6].
        mu_recon = torch.exp(decoded[:, :self.input_dim]).clamp(1e-6, 1e6)
        # Second half -> softplus, clamped to [1e-4, 1e4].
        theta_recon = nn.functional.softplus(decoded[:, self.input_dim:]).clamp(1e-4, 1e4)
        return mu_recon, theta_recon

    def forward(self, x, y=None, is_train=True):
        """Encode ``x`` and decode it again; ``y`` and ``is_train`` are accepted but unused."""
        z = self.features(x)
        return self.decoder(z)

    def _initialize_weights(self):
        """Xavier-normal weights and zero biases for every linear layer of encoder and decoder."""
        for stack in (self.features, self.decoder_layers):
            for m in stack.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_normal_(m.weight)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

In [7]:
# Load the single-cell dataset, select highly variable genes, keep the cells of three
# severity groups, and derive the sample-level label table used for the splits below.
import scanpy as sc
from sklearn.model_selection import train_test_split
from scipy import sparse
# NOTE: this rebinds `pd`. From here on `pd` is modin.pandas, not the pandas module
# imported in the first cell.
import modin.pandas as pd
import ray
# Starts a local Ray instance.
ray.init()

# NOTE(review): absolute, machine-specific path -- consider a configurable data directory.
dat = sc.read_h5ad('/home/local/kyeonghunjeong_920205/nipa_bu/COVID19/3.analysis/9.MIL/covid19_sc/su_2020_processed.h5ad')
print(dat.shape)
# Filter genes with min_cells=5.
sc.pp.filter_genes(dat, min_cells=5)
# Snapshot taken before normalize_total/log1p, so adata_raw keeps the pre-normalisation values.
adata_raw = dat.copy()
sc.pp.normalize_total(dat, target_sum=1e4)

# NOTE(review): this message is printed before log1p and highly-variable-gene selection run.
print("Preprocessing Complete!")
print(dat.shape)

sc.pp.log1p(dat)
sc.pp.highly_variable_genes(dat, n_top_genes=2000)
# The highly-variable flags are computed on the normalised, log-transformed `dat`, but the
# columns are taken from the pre-normalisation snapshot.
adata = adata_raw[:, dat.var.highly_variable]




print(adata.shape)
# Keep only cells whose severity is mild, moderate or severe.
adata = adata[adata.obs['disease_severity_standard'].isin(['mild','moderate', 'severe'])]
print(adata.shape)
# Binary target: mild -> 0, moderate/severe -> 1.
mapping = {'mild': 0, 'moderate': 1, 'severe': 1}

# NOTE(review): `adata` is a slice of `adata_raw` at this point; assigning new obs columns
# on a slice presumably makes AnnData copy it implicitly (with a warning) -- an explicit
# `.copy()` after the filter above would make that intent visible. Confirm.
adata.obs['disease_numeric'] = adata.obs['disease_severity_standard'].map(mapping)
# Integer code per distinct value of obs['sample'].
adata.obs['sample_id_numeric'], _ = pd.factorize(adata.obs['sample'])

# Distinct (disease_numeric, sample_id_numeric) pairs; the later cells pass this table to
# train_test_split.
sample_labels = adata.obs[['disease_numeric', 'sample_id_numeric']].drop_duplicates()

2024-04-30 01:41:28,572	INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


(559517, 33538)
Preprocessing Complete!
(559517, 23989)
(559517, 2000)
(515141, 2000)




In [9]:
# Directory of one training run; the following cells write their outputs here and the
# export loop loads the trained models from it.
# NOTE(review): absolute, machine-specific path -- consider a configurable base directory.
saved_model_path = '/home/local/kyeonghunjeong_920205/nipa_bu/COVID19/3.analysis/9.MIL/scAMIL_cell/WENO_su_2020_model_vae_ed128_md64_lr0.0001_500_0.1_5_15_leaktantan_fix_auc433_noflt_only2'

In [12]:

# Rename the '_index' column of the raw `.var` table to 'features', then write `adata` to disk.
# NOTE(review): this reaches into AnnData private attributes (`_raw`, `_var`); presumably a
# workaround for '_index' being rejected as a column name when writing h5ad -- confirm.
adata.__dict__['_raw'].__dict__['_var'] = adata.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})#
adata.write(filename=f"{saved_model_path}/anndata_proc.h5ad")

In [14]:
# Save the per-cell metadata table (adata.obs) next to the trained models.
adata.obs.to_csv(f"{saved_model_path}/meta.csv")

In [None]:
# Same run directory as assigned in the earlier cell, re-declared so the export loop below
# can be run on its own.
saved_model_path = '/home/local/kyeonghunjeong_920205/nipa_bu/COVID19/3.analysis/9.MIL/scAMIL_cell/WENO_su_2020_model_vae_ed128_md64_lr0.0001_500_0.1_5_15_leaktantan_fix_auc433_noflt_only2'

In [17]:
# Sample-level split proportions, in the order [train, val, test]; identical for every experiment.
split_ratio = [0.5, 0.25, 0.25]

# For each experiment: score every test-set cell with the trained student and teacher
# models, save the per-cell table, and save the metadata of the matching test samples.
for exp in range(1, 9):
    print(f'Experiment {exp}')
    # Only the test split is scored; the train/val datasets are loaded but discarded.
    _, _, test_dataset, label_encoder, scaler = load_dataset_and_preprocessors(data_dir, exp, device)
    instance_test_dataset = update_instance_labels_with_bag_labels(test_dataset)

    model_teacher = torch.load(f'{saved_model_path}/model_teacher_exp{exp}.pt', map_location=device)
    model_encoder = torch.load(f'{saved_model_path}/model_encoder_exp{exp}.pt', map_location=device)
    model_student = torch.load(f'{saved_model_path}/model_student_exp{exp}.pt', map_location=device)

    model_encoder.eval()
    model_student.eval()
    model_teacher.eval()
    with torch.no_grad():
        # Encode all test cells and keep the first `input_dims` columns expected by the teacher.
        features = model_encoder(instance_test_dataset.data.clone().detach().float().to(device))[:, :model_teacher.input_dims]
        cell_score_stud = model_student(features)
        # Raw (pre-softmax) teacher attention score per cell.
        cell_score_teacher = model_teacher.attention_module(features).squeeze(0)
        cell_score_stud_softmax = torch.softmax(cell_score_stud, dim=1)

    features_np = features.detach().cpu().numpy()
    cell_score_stud_softmax_np = cell_score_stud_softmax.detach().cpu().numpy()
    cell_score_stud_np = cell_score_stud.detach().cpu().numpy()
    cell_score_teacher_np = cell_score_teacher.detach().cpu().numpy()
    instance_labels_np = instance_test_dataset.instance_labels.cpu().detach().numpy()

    df = pd.DataFrame(features_np, columns=[f'feature_{i}' for i in range(features_np.shape[1])])

    df['cell_type'] = label_encoder.inverse_transform(instance_labels_np)
    df['cell_score_teacher'] = cell_score_teacher_np
    df['cell_score_stud'] = cell_score_stud_np[:, 1]
    df['cell_score_stud_softmax'] = cell_score_stud_softmax_np[:, 1]
    df['bag_labels'] = instance_test_dataset.bag_labels.cpu().detach().numpy()
    df['instance_labels'] = instance_labels_np
    # Min-max scale the teacher score with vectorised NumPy reductions; the Python builtins
    # min()/max() iterate over the array element by element. As before, a constant score
    # column yields NaN (0/0).
    teacher_min = cell_score_teacher_np.min()
    teacher_max = cell_score_teacher_np.max()
    df['cell_score_teacher_minmax'] = (cell_score_teacher_np - teacher_min) / (teacher_max - teacher_min)
    df.to_csv(f'{saved_model_path}/cell_score_{exp}.csv', index=False)

    # Re-create the stratified sample split for this experiment (seeded with `exp`) to
    # recover which samples form the test set.
    train_val_set, test_set = train_test_split(sample_labels, test_size=split_ratio[2], random_state=exp, stratify=sample_labels['disease_numeric'])
    train_set, val_set = train_test_split(train_val_set, test_size=split_ratio[1] / (1 - split_ratio[2]), random_state=exp, stratify=train_val_set['disease_numeric'])
    test_set.to_csv(f"{saved_model_path}/test_set_barcodes_{exp}.csv")
    test_data = adata[adata.obs['sample_id_numeric'].isin(test_set['sample_id_numeric'])]
    test_data.obs.to_csv(f"{saved_model_path}/obs_{exp}.csv")





Experiment 1




Experiment 2




Experiment 3




Experiment 4




Experiment 5




Experiment 6




Experiment 7




Experiment 8




In [18]:
# Release cached GPU memory between the two export runs.
torch.cuda.empty_cache()

In [19]:
# Run directory of the second model variant; the loop below loads its models and writes
# its outputs here.
# NOTE(review): absolute, machine-specific path -- consider a configurable base directory.
saved_model_path = '/home/local/kyeonghunjeong_920205/nipa_bu/COVID19/3.analysis/9.MIL/scAMIL_cell/NO_Opt_student_WENO_su_2020_model_vae_ed128_md64_lr0.0001_500_0.1_500_15_NO_Opt_student_leaktantan_fix_auc4054054_noflt_only2'

# Sample-level split proportions, in the order [train, val, test]; identical for every experiment.
split_ratio = [0.5, 0.25, 0.25]

# For each experiment: score every test-set cell with the trained student and teacher
# models, save the per-cell table, and save the metadata of the matching test samples.
for exp in range(1, 9):
    print(f'Experiment {exp}')
    # Only the test split is scored; the train/val datasets are loaded but discarded.
    _, _, test_dataset, label_encoder, scaler = load_dataset_and_preprocessors(data_dir, exp, device)
    instance_test_dataset = update_instance_labels_with_bag_labels(test_dataset)

    model_teacher = torch.load(f'{saved_model_path}/model_teacher_exp{exp}.pt', map_location=device)
    model_encoder = torch.load(f'{saved_model_path}/model_encoder_exp{exp}.pt', map_location=device)
    model_student = torch.load(f'{saved_model_path}/model_student_exp{exp}.pt', map_location=device)

    model_encoder.eval()
    model_student.eval()
    model_teacher.eval()
    with torch.no_grad():
        # Encode all test cells and keep the first `input_dims` columns expected by the teacher.
        features = model_encoder(instance_test_dataset.data.clone().detach().float().to(device))[:, :model_teacher.input_dims]
        cell_score_stud = model_student(features)
        # Raw (pre-softmax) teacher attention score per cell.
        cell_score_teacher = model_teacher.attention_module(features).squeeze(0)
        cell_score_stud_softmax = torch.softmax(cell_score_stud, dim=1)

    features_np = features.detach().cpu().numpy()
    cell_score_stud_softmax_np = cell_score_stud_softmax.detach().cpu().numpy()
    cell_score_stud_np = cell_score_stud.detach().cpu().numpy()
    cell_score_teacher_np = cell_score_teacher.detach().cpu().numpy()
    instance_labels_np = instance_test_dataset.instance_labels.cpu().detach().numpy()

    df = pd.DataFrame(features_np, columns=[f'feature_{i}' for i in range(features_np.shape[1])])

    df['cell_type'] = label_encoder.inverse_transform(instance_labels_np)
    df['cell_score_teacher'] = cell_score_teacher_np
    df['cell_score_stud'] = cell_score_stud_np[:, 1]
    df['cell_score_stud_softmax'] = cell_score_stud_softmax_np[:, 1]
    df['bag_labels'] = instance_test_dataset.bag_labels.cpu().detach().numpy()
    df['instance_labels'] = instance_labels_np
    # Min-max scale the teacher score with vectorised NumPy reductions; the Python builtins
    # min()/max() iterate over the array element by element. As before, a constant score
    # column yields NaN (0/0).
    teacher_min = cell_score_teacher_np.min()
    teacher_max = cell_score_teacher_np.max()
    df['cell_score_teacher_minmax'] = (cell_score_teacher_np - teacher_min) / (teacher_max - teacher_min)
    df.to_csv(f'{saved_model_path}/cell_score_{exp}.csv', index=False)

    # Re-create the stratified sample split for this experiment (seeded with `exp`) to
    # recover which samples form the test set.
    train_val_set, test_set = train_test_split(sample_labels, test_size=split_ratio[2], random_state=exp, stratify=sample_labels['disease_numeric'])
    train_set, val_set = train_test_split(train_val_set, test_size=split_ratio[1] / (1 - split_ratio[2]), random_state=exp, stratify=train_val_set['disease_numeric'])
    test_data = adata[adata.obs['sample_id_numeric'].isin(test_set['sample_id_numeric'])]
    test_data.obs.to_csv(f"{saved_model_path}/obs_{exp}.csv")
    print("실제 데이터에서 훈련, 검증, 테스트 샘플 추출 완료")
    # Free cached GPU memory before the next experiment's models are loaded.
    torch.cuda.empty_cache()
    # test_data.obs.columns = [sub.replace('(', '') for sub in test_data.obs.columns]
    # test_data.obs.columns = [sub.replace(')', '') for sub in test_data.obs.columns]
    # test_data.obs.rename(columns={'_index': 'index'}, inplace=True)
    # test_data.__dict__['_raw'].__dict__['_var'] = adata.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})#
    # del test_data.obs['disease_numeric']
    # del test_data.obs['sample_id_numeric']
    # test_data.write(filename=f"{saved_model_path}/anndata_{exp}.h5ad")





Experiment 1




실제 데이터에서 훈련, 검증, 테스트 샘플 추출 완료
Experiment 2




실제 데이터에서 훈련, 검증, 테스트 샘플 추출 완료
Experiment 3




실제 데이터에서 훈련, 검증, 테스트 샘플 추출 완료
Experiment 4




실제 데이터에서 훈련, 검증, 테스트 샘플 추출 완료
Experiment 5




실제 데이터에서 훈련, 검증, 테스트 샘플 추출 완료
Experiment 6




실제 데이터에서 훈련, 검증, 테스트 샘플 추출 완료
Experiment 7




실제 데이터에서 훈련, 검증, 테스트 샘플 추출 완료
Experiment 8




실제 데이터에서 훈련, 검증, 테스트 샘플 추출 완료
