In [8]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np

# 1. Model definition
class AETransformerLite(nn.Module):
    """MLP encoder followed by single-head self-attention over the latent values.

    Each sample is compressed by a two-layer MLP into ``latent_dim`` scalars.
    Every scalar is treated as one token, embedded into ``tf_embed_dim``
    features, passed through one self-attention layer, mean-pooled over the
    tokens and mapped to a single sigmoid output.

    Args:
        input_dim: Number of input features per sample.
        latent_dim: Width of the encoder output (also the number of tokens).
        tf_embed_dim: Embedding width used by the attention layer.
        dropout: Dropout probability passed to the attention layer.

    Attributes:
        attn_weights: Attention weights of the most recent forward pass,
            detached and moved to CPU; ``None`` until ``forward`` is called.
    """

    def __init__(self, input_dim, latent_dim=8, tf_embed_dim=8, dropout=0.2):
        super().__init__()
        # Attribute names below become state_dict key prefixes, so they must
        # stay as they are for saved checkpoints to load.
        encoder_layers = [
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        # Lifts each scalar latent value to a tf_embed_dim-wide token.
        self.embedding = nn.Linear(1, tf_embed_dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=tf_embed_dim,
            num_heads=1,
            dropout=dropout,
            batch_first=True,
        )
        head_layers = [
            nn.LayerNorm(tf_embed_dim),
            nn.Linear(tf_embed_dim, 1),
            nn.Sigmoid(),
        ]
        self.ffn = nn.Sequential(*head_layers)
        self.attn_weights = None

    def forward(self, x):
        """Score a batch and cache the attention weights on the module.

        Args:
            x: Input batch; for a 2-D tensor of shape ``(N, input_dim)`` the
                shapes noted below apply.

        Returns:
            Tensor of shape ``(N, 1)`` with values produced by a sigmoid.

        Side effects:
            Overwrites ``self.attn_weights`` with the detached CPU copy of the
            attention weights, shape ``(N, latent_dim, latent_dim)``.
        """
        latent = self.encoder(x)                              # (N, latent_dim)
        # Add a trailing feature axis so each latent value becomes one token.
        tokens = self.embedding(torch.unsqueeze(latent, 2))   # (N, latent_dim, E)
        attended, weights = self.self_attn(
            tokens, tokens, tokens, need_weights=True
        )
        self.attn_weights = weights.detach().cpu()
        pooled = torch.mean(attended, dim=1)                  # average over tokens
        return self.ffn(pooled)

In [9]:
# 2. Load the input table
# pandas names a header-less CSV column "Unnamed: 0"; here that column is used
# as the sample identifier and every other column is treated as one gene.
ID_COLUMN = "Unnamed: 0"
df = pd.read_csv("input.csv")
sample_ids = df[ID_COLUMN].values
# Remove the id column by name for both the gene names and the matrix, so that
# gene_names[j] always labels X[:, j] (the previous positional `columns[1:]`
# only matched when the id column happened to be first).
gene_names = df.columns.drop(ID_COLUMN)
X = torch.tensor(df.drop(columns=[ID_COLUMN]).values, dtype=torch.float32)

In [10]:
# 3. Load the trained model
# input_dim is taken from the loaded matrix; load_state_dict fails with a size
# mismatch if it differs from the width the checkpoint was saved with.
model = AETransformerLite(input_dim=X.shape[1])
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. It also avoids the warning
# recent PyTorch releases emit when torch.load is called without the flag.
state_dict = torch.load("model.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# eval() switches off the attention dropout for inference; as the cell's last
# expression it also displays the module summary.
model.eval()

  model.load_state_dict(torch.load("model.pt", map_location="cpu"))


AETransformerLite(
  (encoder): Sequential(
    (0): Linear(in_features=18631, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=8, bias=True)
  )
  (embedding): Linear(in_features=1, out_features=8, bias=True)
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
  )
  (ffn): Sequential(
    (0): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=8, out_features=1, bias=True)
    (2): Sigmoid()
  )
)

In [11]:
# 4. Run inference
# no_grad() skips autograd bookkeeping. The forward pass also stores the
# attention weights on the model as a side effect, which is read right after.
with torch.no_grad():
    # squeeze(-1) removes only the trailing output dimension, so preds keeps
    # shape (N,) even when X holds a single sample (a bare squeeze() would
    # collapse it to a 0-d tensor).
    preds = model(X).squeeze(-1)
    attns = model.attn_weights  # shape: (N, L, L), L = number of latent tokens

In [12]:
# 5. Derive labels from the sample ids (Stroke / Control)
# An id containing "stroke" (case-insensitive) is labelled 1, anything else 0.
labels = np.fromiter(
    ("stroke" in sid.lower() for sid in sample_ids),
    dtype=int,
    count=len(sample_ids),
)

In [14]:
# 6. True-positive mask: samples with label == 1 whose prediction exceeds 0.5
tp_idx = np.logical_and(labels == 1, preds.numpy() > 0.5)
# Boolean mask over the batch axis keeps only the true-positive attention maps.
tp_attn = attns[tp_idx]  # (TP, L, L)

In [15]:
# 7. Average the attention maps over the true-positive samples
# With zero TP samples the mean over an empty batch is NaN, and every score
# computed from it would silently be NaN as well, so stop with a clear error.
if tp_attn.shape[0] == 0:
    raise ValueError("No true-positive samples: cannot average attention weights.")
attn_mean = tp_attn.mean(dim=0)  # (L, L)
# Sum over the first axis of the averaged map, giving one score per column
# (per attended-to latent token).
importance_score = attn_mean.sum(dim=0).numpy()  # (L,)

In [17]:
# 8. Map the latent-token scores back onto genes through the encoder weights
# Only the weight matrices of the two encoder Linear layers are multiplied; the
# ReLU between them and both biases are ignored, so this is a linear
# approximation of the encoder.
W1, W2 = (model.encoder[layer].weight.detach().numpy() for layer in (0, 2))
# W1: (64, G)   W2: (8, 64)
encoder_full = np.matmul(W2, W1)                        # (8, G)

# Weight each latent row by its attention-based score and sum per gene.
projected = np.matmul(importance_score, encoder_full)   # (G,)

In [18]:
# 9. Top 20 genes
topk = 20
# argsort is ascending, so flip it to rank from the largest projected score down.
top_indices = np.flip(np.argsort(projected))[:topk]
top_genes = gene_names[top_indices]
top_scores = np.take(projected, top_indices)

print("TP 상위 중요 유전자:")
for i, (gene, score) in enumerate(zip(top_genes, top_scores), start=1):
    print(f"{i:2d}. {gene}: {score:.4f}")

TP 상위 중요 유전자:
 1. BACE2: 0.0466
 2. ZFP28: 0.0464
 3. HDAC8: 0.0455
 4. PSCD3: 0.0434
 5. PACSIN1: 0.0394
 6. KRT2B: 0.0383
 7. TMCO4: 0.0381
 8. CCNT2: 0.0381
 9. VENTXP7: 0.0380
10. EMX1: 0.0378
11. NAT5: 0.0373
12. HARS2: 0.0371
13. ZNF3: 0.0370
14. TSSK2: 0.0361
15. KCNAB3: 0.0361
16. XPO5: 0.0361
17. PTD008: 0.0361
18. SLC13A5: 0.0360
19. ZNF358: 0.0359
20. IL21R: 0.0354
