In [1]:
import scanpy as sc
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
from torch.optim import AdamW
from tqdm import tqdm

In [2]:
merged_tisch2_sub = sc.read_h5ad('/storage/sc_cancer_atlas/3CA/merged/merged_sc_atlas_preprocessed_scaled_subset_0.2.h5ad')

In [3]:
# Use var.index for gene names
gene_names = merged_tisch2_sub.var.index

# Generate sentences describing top genes for each cell
top_genes_texts = []
for i in range(merged_tisch2_sub.shape[0]):
    # Get indices of top 20 genes (sorted by expression value)
    top_genes_indices = merged_tisch2_sub.X[i].argsort()[-20:][::-1]
    top_genes = gene_names[top_genes_indices]
    sentence = f"Top genes are {', '.join(top_genes)}."
    top_genes_texts.append(sentence)

# Combine with cell types
data = pd.DataFrame({
    "text": top_genes_texts,
    "label": merged_tisch2_sub.obs['cell_type_merged_coarse']
})
data = data.dropna()  # Drop rows with missing labels

In [4]:
label_encoder = LabelEncoder()
data['encoded_label'] = label_encoder.fit_transform(data['label'])
num_classes = len(label_encoder.classes_)

In [5]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(data, test_size=0.2, random_state=42, stratify=data['encoded_label'])

In [6]:
class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts.reset_index(drop=True)
        self.labels = labels.reset_index(drop=True)
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]

In [7]:
# 훈련 데이터셋 및 데이터로더
train_dataset = TextDataset(train_df['text'], train_df['encoded_label'])
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 테스트 데이터셋 및 데이터로더
test_dataset = TextDataset(test_df['text'], test_df['encoded_label'])
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [8]:
# SentenceTransformer 모델 로드
model = SentenceTransformer('all-MiniLM-L6-v2')

# 분류 헤드 추가
classifier = nn.Linear(model.get_sentence_embedding_dimension(), num_classes)



In [9]:
class ClassificationModel(nn.Module):
    def __init__(self, model, classifier):
        super(ClassificationModel, self).__init__()
        self.model = model
        self.classifier = classifier
    
    def forward(self, input_texts):
        # 토크나이즈 및 입력 데이터 생성
        input_features = self.model.tokenize(input_texts)
        # 입력 데이터를 모델의 디바이스로 이동
        input_features = {key: value.to(self.model.device) for key, value in input_features.items()}
        # SentenceTransformer 모델을 통해 임베딩 추출
        embeddings = self.model(input_features)['sentence_embedding']
        # 분류 헤드를 통해 로짓 계산
        logits = self.classifier(embeddings)
        return logits

In [10]:
classification_model = ClassificationModel(model, classifier)
# GPU 사용 가능 시 모델을 GPU로 이동
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classification_model.to(device)

ClassificationModel(
  (model): SentenceTransformer(
    (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
    (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
    (2): Normalize()
  )
  (classifier): Linear(in_features=384, out_features=13, bias=True)
)

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(classification_model.parameters(), lr=2e-5)

In [12]:
num_epochs = 10

for epoch in range(num_epochs):
    classification_model.train()
    total_loss = 0
    for texts, labels in tqdm(train_dataloader, desc=f'Epoch {epoch+1}'):
        labels = labels.to(device).long()
        
        optimizer.zero_grad()
        logits = classification_model(texts)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    avg_loss = total_loss / len(train_dataloader)
    print(f'Epoch {epoch+1}, Loss: {avg_loss}')
    
    # 평가
    classification_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for texts, labels in test_dataloader:
            labels = labels.to(device).long()
            logits = classification_model(texts)
            predictions = torch.argmax(logits, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total
    print(f'Epoch {epoch+1}, Validation Accuracy: {accuracy}')

Epoch 1: 100%|██████████████████████████████| 5875/5875 [04:40<00:00, 20.94it/s]


Epoch 1, Loss: 1.9273898255977224
Epoch 1, Validation Accuracy: 0.5514255319148936


Epoch 2: 100%|██████████████████████████████| 5875/5875 [04:40<00:00, 20.95it/s]


Epoch 2, Loss: 1.4651445059979216
Epoch 2, Validation Accuracy: 0.5834680851063829


Epoch 3: 100%|██████████████████████████████| 5875/5875 [04:41<00:00, 20.88it/s]


Epoch 3, Loss: 1.283619684665761
Epoch 3, Validation Accuracy: 0.5891702127659575


Epoch 4: 100%|██████████████████████████████| 5875/5875 [04:44<00:00, 20.64it/s]


Epoch 4, Loss: 1.1965293184036905
Epoch 4, Validation Accuracy: 0.5985744680851064


Epoch 5: 100%|██████████████████████████████| 5875/5875 [04:45<00:00, 20.60it/s]


Epoch 5, Loss: 1.1469673238206417
Epoch 5, Validation Accuracy: 0.601531914893617


Epoch 6: 100%|██████████████████████████████| 5875/5875 [04:43<00:00, 20.69it/s]


Epoch 6, Loss: 1.1138788581604653
Epoch 6, Validation Accuracy: 0.5992765957446808


Epoch 7: 100%|██████████████████████████████| 5875/5875 [04:44<00:00, 20.68it/s]


Epoch 7, Loss: 1.0905856708313557
Epoch 7, Validation Accuracy: 0.6057021276595744


Epoch 8: 100%|██████████████████████████████| 5875/5875 [04:43<00:00, 20.71it/s]


Epoch 8, Loss: 1.0699980573400538
Epoch 8, Validation Accuracy: 0.6044255319148936


Epoch 9: 100%|██████████████████████████████| 5875/5875 [04:44<00:00, 20.67it/s]


Epoch 9, Loss: 1.0525244386196138
Epoch 9, Validation Accuracy: 0.6082127659574468


Epoch 10: 100%|█████████████████████████████| 5875/5875 [04:45<00:00, 20.59it/s]


Epoch 10, Loss: 1.036445513202789
Epoch 10, Validation Accuracy: 0.6055531914893617


In [13]:
# 모델 저장
torch.save(classification_model.state_dict(), 'classification_model.pt')

# 모델 로드
classification_model = ClassificationModel(model, classifier)
classification_model.load_state_dict(torch.load('classification_model.pt'))
classification_model.to(device)

ClassificationModel(
  (model): SentenceTransformer(
    (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
    (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
    (2): Normalize()
  )
  (classifier): Linear(in_features=384, out_features=13, bias=True)
)

In [18]:
def predict_label(text):
    classification_model.eval()
    with torch.no_grad():
        logits = classification_model([text])
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_label_id = torch.argmax(probabilities, dim=1).item()
        predicted_label = label_encoder.inverse_transform([predicted_label_id])[0]
    return predicted_label

# 실제 라벨 (인코딩된 라벨)
true_labels = data['encoded_label'].values

# 예측 라벨을 저장할 리스트
predicted_labels = []

# 데이터셋의 텍스트 리스트
texts = data['text'].tolist()

# 예측 수행
for text in tqdm(texts, desc='Predicting'):
    predicted_label = predict_label(text)
    predicted_label_encoded = label_encoder.transform([predicted_label])[0]
    predicted_labels.append(predicted_label_encoded)

# 성능 지표 계산
accuracy = accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average='weighted')
recall = recall_score(true_labels, predicted_labels, average='weighted')
f1 = f1_score(true_labels, predicted_labels, average='weighted')

print("모델 성능 평가:")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision (Weighted): {precision:.4f}")
print(f"Recall (Weighted): {recall:.4f}")
print(f"F1 Score (Weighted): {f1:.4f}")

# 자세한 분류 보고서 출력
print("\n분류 보고서:")
print(classification_report(true_labels, predicted_labels, target_names=label_encoder.classes_))

Predicting: 100%|██████████████████████| 234997/234997 [14:19<00:00, 273.40it/s]


NameError: name 'accuracy_score' is not defined

In [19]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report

# 성능 지표 계산
accuracy = accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average='weighted')
recall = recall_score(true_labels, predicted_labels, average='weighted')
f1 = f1_score(true_labels, predicted_labels, average='weighted')

print("모델 성능 평가:")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision (Weighted): {precision:.4f}")
print(f"Recall (Weighted): {recall:.4f}")
print(f"F1 Score (Weighted): {f1:.4f}")

# 자세한 분류 보고서 출력
print("\n분류 보고서:")
print(classification_report(true_labels, predicted_labels, target_names=label_encoder.classes_))

  _warn_prf(average, modifier, msg_start, len(result))


모델 성능 평가:
Accuracy: 0.6409
Precision (Weighted): 0.7988
Recall (Weighted): 0.6409
F1 Score (Weighted): 0.6579

분류 보고서:
               precision    recall  f1-score   support

            B       0.91      0.72      0.80     12542
   Epithelial       0.90      0.72      0.80     19782
   Fibroblast       0.75      0.85      0.80      4254
Immune_Others       0.00      0.00      0.00       227
    Lymphatic       0.18      0.02      0.03       117
     Lymphoid       0.72      0.41      0.52      3497
    Malignant       0.36      0.98      0.53     41852
      Myeloid       0.94      0.55      0.69     39891
         NK_T       0.92      0.62      0.74     62783
       Others       0.00      0.00      0.00        92
      Stromal       0.79      0.26      0.40     17614
     Vascular       0.86      0.59      0.70     11623
          nan       0.89      0.43      0.58     20723

     accuracy                           0.64    234997
    macro avg       0.63      0.47      0.51    234997

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## Fine-tuning 하기 전의 모델

In [20]:
# 사전 학습된 모델 로드
pretrained_model = SentenceTransformer('all-MiniLM-L6-v2')

# 데이터셋 임베딩 생성 (훈련 및 테스트)
train_embeddings = pretrained_model.encode(train_df['text'].tolist(), convert_to_tensor=True)
test_embeddings = pretrained_model.encode(test_df['text'].tolist(), convert_to_tensor=True)



In [21]:
from sklearn.linear_model import LogisticRegression

# 분류기 학습
classifier = LogisticRegression(max_iter=1000, random_state=42)
classifier.fit(train_embeddings.cpu().numpy(), train_df['encoded_label'])

In [22]:
# 테스트 세트 예측
test_predictions = classifier.predict(test_embeddings.cpu().numpy())

# 성능 지표 계산
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report

accuracy = accuracy_score(test_df['encoded_label'], test_predictions)
precision = precision_score(test_df['encoded_label'], test_predictions, average='weighted')
recall = recall_score(test_df['encoded_label'], test_predictions, average='weighted')
f1 = f1_score(test_df['encoded_label'], test_predictions, average='weighted')

print("Pretrained 모델 성능 평가:")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision (Weighted): {precision:.4f}")
print(f"Recall (Weighted): {recall:.4f}")
print(f"F1 Score (Weighted): {f1:.4f}")

# 자세한 분류 보고서 출력
print("\n분류 보고서:")
print(classification_report(test_df['encoded_label'], test_predictions, target_names=label_encoder.classes_))

Pretrained 모델 성능 평가:
Accuracy: 0.4465
Precision (Weighted): 0.5205
Recall (Weighted): 0.4465
F1 Score (Weighted): 0.4280

분류 보고서:
               precision    recall  f1-score   support

            B       0.64      0.20      0.30      2508
   Epithelial       0.54      0.37      0.44      3957
   Fibroblast       0.64      0.39      0.49       851
Immune_Others       0.00      0.00      0.00        45
    Lymphatic       0.00      0.00      0.00        23
     Lymphoid       0.78      0.23      0.36       699
    Malignant       0.32      0.85      0.46      8371
      Myeloid       0.59      0.36      0.45      7978
         NK_T       0.55      0.54      0.55     12557
       Others       0.00      0.00      0.00        18
      Stromal       0.55      0.13      0.21      3523
     Vascular       0.53      0.20      0.29      2325
          nan       0.50      0.21      0.29      4145

     accuracy                           0.45     47000
    macro avg       0.44      0.27      0.3

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
