<a href="https://colab.research.google.com/github/SYWoo02/KNU_Capstone_Project_2025/blob/main/%EC%84%B1%EC%9D%B8%EB%B3%91_%EC%95%BD%EB%AC%BC%EC%83%81%ED%98%B8%EC%9E%91%EC%9A%A9_%EC%98%88%EC%B8%A1%EB%AA%A8%EB%8D%B8%EA%B5%AC%ED%98%84_%EB%B0%8F_%EC%84%B1%EB%8A%A5%ED%8F%89%EA%B0%80_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 라이브러리 불러오기
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# Load the drug-combination table.
df = pd.read_csv("성인병 관련 약물조합 w interaction_type(label)_26.csv")

# Build one shared drug vocabulary over both ID columns (order of first
# appearance in row-major order) and attach integer index columns.
unique_drugs = pd.unique(df[["Drug1_ID", "Drug2_ID"]].values.ravel())
drug_to_index = dict(zip(unique_drugs, range(len(unique_drugs))))
for id_column, index_column in (("Drug1_ID", "Drug1_Idx"), ("Drug2_ID", "Drug2_Idx")):
    df[index_column] = df[id_column].map(drug_to_index)

# Labels are indexed in sorted order, so Label_Idx i corresponds to label_set[i].
label_set = sorted(df["Label"].unique())
label_to_index = dict(zip(label_set, range(len(label_set))))
df["Label_Idx"] = df["Label"].map(label_to_index)
num_classes = len(label_to_index)

# Dataset 정의
class DrugCombinationDataset(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe[["Drug1_Idx", "Drug2_Idx", "Label_Idx"]].reset_index(drop=True)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        return {
            "drug1": torch.tensor(row["Drug1_Idx"], dtype=torch.long),
            "drug2": torch.tensor(row["Drug2_Idx"], dtype=torch.long),
            "label": torch.tensor(row["Label_Idx"], dtype=torch.long)
        }

# Train/validation split and data loaders
SPLIT_SEED = 42  # fixed seed so the 80/20 train/validation split is reproducible across runs

full_dataset = DrugCombinationDataset(df)
train_size = int(0.8 * len(full_dataset))
train_dataset, valid_dataset = random_split(
    full_dataset,
    [train_size, len(full_dataset) - train_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64)

# Class weights ("balanced": inversely proportional to class frequency).
# NOTE: computed over the full label column (train + validation). This ensures
# every class index 0..num_classes-1 gets a weight, even if the random split
# leaves a class out of the training subset. It does, however, leak the
# validation label distribution into the loss weighting.
labels_np = df["Label_Idx"].values
from sklearn.utils.class_weight import compute_class_weight
# classes=arange(num_classes) makes weight[i] correspond to Label_Idx i, which is
# the order nn.CrossEntropyLoss(weight=...) expects.
class_weights = compute_class_weight(
    class_weight="balanced", classes=np.arange(num_classes), y=labels_np
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float)

# Model definition
class ImprovedDrugInteractionModel(nn.Module):
    """Embedding + MLP classifier over an ordered pair of drug indices.

    Both drugs are looked up in one shared embedding table. The two embeddings
    are concatenated (drug1 first) and passed through dropout (p=0.3), one
    ReLU hidden layer and a final linear layer that yields ``num_classes``
    logits.
    """

    def __init__(self, num_drugs, embedding_dim, hidden_dim, num_classes):
        super().__init__()
        # Submodules are registered in the same order and under the same names
        # as before, so parameter initialisation and state_dict keys are unchanged.
        self.embedding = nn.Embedding(num_drugs, embedding_dim)
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(2 * embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, drug1, drug2):
        """Return logits of shape (batch, num_classes) for 1-D index tensors."""
        pair = torch.stack((drug1, drug2), dim=1)              # (batch, 2)
        features = self.embedding(pair).flatten(start_dim=1)   # (batch, 2 * embedding_dim)
        hidden = F.relu(self.fc1(self.dropout(features)))
        return self.fc2(hidden)

# Training
NUM_EPOCHS = 5000  # number of passes over the training split
EVAL_EVERY = 100   # validate and log once every this many epochs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedDrugInteractionModel(len(drug_to_index), 64, 64, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor.to(device))


def _validation_accuracy(net, loader, dev):
    """Return top-1 accuracy of ``net`` over ``loader`` (0.0 for an empty loader).

    Puts ``net`` in eval mode and leaves it there; the training loop switches
    back with ``model.train()`` at the start of each epoch.
    """
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            drug1 = batch["drug1"].to(dev)
            drug2 = batch["drug2"].to(dev)
            label = batch["label"].to(dev)
            pred = net(drug1, drug2).argmax(1)
            correct += (pred == label).sum().item()
            total += label.size(0)
    return correct / total if total else 0.0


for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        drug1 = batch["drug1"].to(device)
        drug2 = batch["drug2"].to(device)
        label = batch["label"].to(device)

        optimizer.zero_grad()
        output = model(drug1, drug2)
        loss = loss_fn(output, label)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # Mean of the per-batch (class-weighted) training losses. The max(..., 1)
    # guard avoids ZeroDivisionError if the training loader is empty.
    avg_loss = total_loss / max(len(train_loader), 1)

    # Every EVAL_EVERY epochs, print the training loss and the validation accuracy.
    if (epoch + 1) % EVAL_EVERY == 0:
        val_acc = _validation_accuracy(model, valid_loader, device)
        print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f} → Validation Accuracy: {val_acc:.4f}")


[Epoch 100] Loss: 0.1154 → Validation Accuracy: 0.8793
[Epoch 200] Loss: 0.0522 → Validation Accuracy: 0.9167
[Epoch 300] Loss: 0.0231 → Validation Accuracy: 0.9349
[Epoch 400] Loss: 0.0117 → Validation Accuracy: 0.9435
[Epoch 500] Loss: 0.0104 → Validation Accuracy: 0.9492
[Epoch 600] Loss: 0.0103 → Validation Accuracy: 0.9492
[Epoch 700] Loss: 0.0050 → Validation Accuracy: 0.9579
[Epoch 800] Loss: 0.0070 → Validation Accuracy: 0.9492
[Epoch 900] Loss: 0.0032 → Validation Accuracy: 0.9550
[Epoch 1000] Loss: 0.0045 → Validation Accuracy: 0.9559
[Epoch 1100] Loss: 0.0033 → Validation Accuracy: 0.9569
[Epoch 1200] Loss: 0.0029 → Validation Accuracy: 0.9492
[Epoch 1300] Loss: 0.0025 → Validation Accuracy: 0.9588
[Epoch 1400] Loss: 0.0040 → Validation Accuracy: 0.9540
[Epoch 1500] Loss: 0.0026 → Validation Accuracy: 0.9511
[Epoch 1600] Loss: 0.0035 → Validation Accuracy: 0.9492
[Epoch 1700] Loss: 0.0026 → Validation Accuracy: 0.9579
[Epoch 1800] Loss: 0.0032 → Validation Accuracy: 0.9521
[

In [None]:
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

# Collect predictions over the whole validation split
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for batch in valid_loader:
        d1 = batch["drug1"].to(device)
        d2 = batch["drug2"].to(device)
        labels = batch["label"].to(device)
        preds = model(d1, d2).argmax(1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

# Per-class precision, recall and F1.
# - Rows cover the class indices present in y_true or y_pred. This is the set
#   sklearn uses by default, so the rows and averages are the same as before.
# - target_names maps each Label_Idx back to its original "Label" value
#   (label_set is sorted, and Label_Idx i == label_set[i]).
# - zero_division=0 reports undefined precision/recall as 0 (the default
#   behaviour) without emitting UndefinedMetricWarning.
present_classes = sorted({int(c) for c in y_true} | {int(c) for c in y_pred})
print(
    classification_report(
        y_true,
        y_pred,
        labels=present_classes,
        target_names=[str(label_set[i]) for i in present_classes],
        digits=4,
        zero_division=0,
    )
)

# Recorded run: overall validation accuracy of about 96%. Most classes score
# well, but several low-support classes are noisy (e.g. Label_Idx 3: F1 0.40 on
# 5 samples). The same validation split was also used to monitor training, so
# these numbers are somewhat optimistic.

              precision    recall  f1-score   support

           0     1.0000    1.0000    1.0000         3
           1     0.9692    0.9130    0.9403        69
           2     0.9615    1.0000    0.9804        25
           3     0.4000    0.4000    0.4000         5
           4     0.7857    0.8462    0.8148        13
           5     1.0000    0.7778    0.8750         9
           6     1.0000    1.0000    1.0000         1
           7     1.0000    0.9565    0.9778        46
           8     0.5000    1.0000    0.6667         2
          10     0.0000    0.0000    0.0000         0
          11     0.7143    0.8333    0.7692         6
          12     1.0000    1.0000    1.0000        39
          13     1.0000    1.0000    1.0000        19
          14     0.6667    1.0000    0.8000         4
          16     1.0000    0.5000    0.6667         2
          17     0.9864    0.9769    0.9816       519
          18     0.3333    1.0000    0.5000         1
          19     0.5000    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
