In [1]:
import pickle
from load_data import LoadData

In [2]:
import torch
import numpy as np
import pandas as pd
import os
import random
from pathlib import Path

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchmetrics import AUROC, Accuracy, Precision, Recall
from torchmetrics.classification import BinaryAUROC, BinaryF1Score

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score

In [3]:
SEED = 42
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
seed_everything(SEED)

## Prepare dataset

In [4]:
class CombinationDataset(Dataset):
    def __init__(self, database='DCDB', neg_ratio=1, duplicate=False, transform=None):
        if (database != 'DCDB') & (database != 'C_DCDB'):
            raise ValueError('database must be DCDB or C_DCDB')
        if neg_ratio < 1:
            raise ValueError('neg_ratio must be greater than 1')
        self.database = database
        self.neg_ratio = neg_ratio
        self.transform = transform
        self.duplicate = duplicate
        self.data_path = Path('data/processed')/f'{database}_neg{neg_ratio}_dup{int(duplicate)}.pt'
        if self.data_path.exists():
            self.data = torch.load(self.data_path)
        else:
            self._process()
            self.data = torch.load(self.data_path)
        
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]
    
    def _process(self):
        print('Processing dataset...')
        dataset_list = self._create_dataset()
        print(f'Saving dataset...{self.data_path}')
        torch.save(dataset_list, self.data_path)
    
    def _create_dataset(self):
        dataloader = LoadData()
        # Get embedding
        with open('data/embedding/embeddings_node2vec_msi_seed0.pkl', 'rb') as f:
            embedding_dict = pickle.load(f)
        # drug dictionary
        drug_id2name, drug_name2id = dataloader.get_dict(type='drug')
        # Prepare positive labels
        pos_df = pd.read_csv(f'data/labels/{self.database}_msi.tsv', sep='\t')

        dataset_list = []
        # positive samples
        for i in range(len(pos_df)):
            drug1_id = pos_df.iloc[i, 0]
            drug2_id = pos_df.iloc[i, 1]
            comb_embedding = np.concatenate([embedding_dict[drug1_id], embedding_dict[drug2_id]])
            dataset_list.append([torch.tensor(comb_embedding, dtype=torch.float), torch.tensor(1, dtype=torch.long)])
            if self.duplicate:
                comb_embedding2 = np.concatenate([embedding_dict[drug2_id], embedding_dict[drug1_id]])
                dataset_list.append([torch.tensor(comb_embedding2, dtype=torch.float), torch.tensor(1, dtype=torch.long)])
                
        # negative samples
        count = 0
        while count < len(pos_df) * self.neg_ratio:
        # while len(dataset_list) < len(pos_df) * (1 + self.neg_ratio):
            drug1_id = random.choice(list(drug_id2name.keys()))
            drug2_id = random.choice(list(drug_id2name.keys()))
            if drug1_id == drug2_id:
                continue
            if ((pos_df['drug_1'] == drug1_id) & (pos_df['drug_2'] == drug2_id)).any():
                continue
            if ((pos_df['drug_1'] == drug2_id) & (pos_df['drug_2'] == drug1_id)).any():
                continue
            comb_embedding = np.concatenate([embedding_dict[drug1_id], embedding_dict[drug2_id]])
            dataset_list.append([torch.tensor(comb_embedding, dtype=torch.float), torch.tensor(0, dtype=torch.long)])
            if self.duplicate:
                comb_embedding2 = np.concatenate([embedding_dict[drug2_id], embedding_dict[drug1_id]])
                dataset_list.append([torch.tensor(comb_embedding2, dtype=torch.float), torch.tensor(0, dtype=torch.long)])
            count += 1
        return dataset_list

In [5]:
# dataset = CombinationDataset(database='DCDB', neg_ratio=1)
# print(len(dataset))

In [6]:
dataset = CombinationDataset(database='C_DCDB', neg_ratio=1, duplicate=False)
print(len(dataset))

Processing dataset...
Saving dataset...data/processed/C_DCDB_neg1_dup0.pt
8442


In [7]:
train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# need to split data well if duplicate=True

## Training

In [8]:
class CombNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, comb_type='cat', dropout=0.1):
        super(CombNet, self).__init__()
        self.input_dim = input_dim # dimension of concatenated drug embeddings
        if (comb_type != 'cat') & (comb_type != 'sum') & (comb_type != 'diff') & (comb_type != 'sumdiff'):
            raise ValueError('comb_type must be cat, sum, diff or sumdiff')
        self.comb_type = comb_type
        self.lt = nn.Sequential(
            nn.Linear(input_dim // 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        dual_dim = hidden_dim if (comb_type == 'sum' or comb_type == 'diff') else hidden_dim * 2
        self.fc = nn.Sequential(
            nn.Linear(dual_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            # nn.Linear(hidden_dim, hidden_dim),
            # nn.BatchNorm1d(hidden_dim),
            # nn.ReLU(),
            # nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim) # for BCEWithLogitsLoss
        )
    def forward(self, data):
        drug1, drug2 = data[:, :self.input_dim//2], data[:, self.input_dim//2:] # drug1과 drug2를 분리
        drug1, drug2 = self.lt(drug1), self.lt(drug2)
        if self.comb_type == 'cat': # [(drug1), (drug2)]로 concat
            comb = torch.cat([drug1, drug2], dim=1) # (batch_size, hidden_dim * 2)
        elif self.comb_type == 'sum': # [(drug1) + (drug2)]
            comb = drug1 + drug2 # (batch_size, hidden_dim)
        elif self.comb_type == 'diff': # [(drug1) - (drug2)]
            comb = torch.abs(drug1 - drug2) # (batch_size, hidden_dim)
        elif self.comb_type == 'sumdiff': # [(drug1) + (drug2), (drug1) - (drug2)]로 concat
            comb = torch.cat([drug1 + drug2, torch.abs(drug1 - drug2)], dim=1) # (batch_size, hidden_dim * 2)
        return self.fc(comb)

In [9]:
def train(model, device, train_loader, criterion, optimizer, metric_list=[accuracy_score]):

    # train
    model.train()
    train_loss = 0

    target_list = []
    pred_list = []
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.float().to(device)
        optimizer.zero_grad()
        output = model(data).view(-1) # z
        # print(output)
        loss = criterion(output, target) # z, y
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        pred_list.append(torch.sigmoid(output).detach().cpu().numpy())
        target_list.append(target.long().detach().cpu().numpy())
    
    # metric
    scores = []
    for metric in metric_list:
        if metric == roc_auc_score:
            scores.append(metric(np.concatenate(target_list), np.concatenate(pred_list)))
        else: # accuracy_score, f1_score, precision_score, recall_score
            scores.append(metric(np.concatenate(target_list), np.concatenate(pred_list).round()))
    
    return train_loss / (batch_idx + 1), scores

In [10]:
def evaluate(model, device, loader, criterion, metric_list=[accuracy_score], checkpoint=None):
    # evaluate
    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint))
    model.eval()
    eval_loss = 0

    target_list = []
    pred_list = []
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.float().to(device)
            output = model(data).view(-1)
            eval_loss += criterion(output, target).item()
            pred_list.append(torch.sigmoid(output).detach().cpu().numpy())
            target_list.append(target.long().detach().cpu().numpy())

    scores = []
    for metric in metric_list:
        if metric == roc_auc_score:
            scores.append(metric(np.concatenate(target_list), np.concatenate(pred_list)))
        else: # accuracy_score, f1_score, precision_score, recall_score
            scores.append(metric(np.concatenate(target_list), np.concatenate(pred_list).round()))
    return eval_loss / (batch_idx + 1), scores

In [11]:
input_dim = dataset[0][0].shape[0]
hidden_dim = input_dim
output_dim = 1
model = CombNet(input_dim, hidden_dim, output_dim, comb_type='cat')

In [12]:
EPOCHS = 50
LR = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)

In [13]:
best_valid_loss = float('inf')
for epoch in range(EPOCHS):
    # train_loss, train_acc = train(model, device, train_loader, criterion, optimizer)
    # valid_loss, valid_acc = evaluate(model, device, valid_loader, criterion)
    train_loss, train_scores = train(model, device, train_loader, criterion, optimizer, metric_list=[accuracy_score, roc_auc_score])
    valid_loss, valid_scores = evaluate(model, device, valid_loader, criterion, metric_list=[accuracy_score, roc_auc_score])

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'checkpoint.pt')
    #print(f'Epoch {epoch+1:03d}: | Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Train AUROC: {train_auroc:.2f} || Val. Loss: {valid_loss:.4f} | Val. Acc: {valid_acc*100:.2f}% | Val. AUROC: {valid_auroc:.2f}')
    # print(f'Epoch {epoch+1:03d}: | Train Loss: {train_loss:.4f} | Train scores: {train_scores} || Val. Loss: {valid_loss:.4f} | Val. scores: {valid_scores}')
    print(f'Epoch {epoch+1:03d}: | Train Loss: {train_loss:.4f} | Train Acc: {train_scores[0]*100:.2f}% | Train AUROC: {train_scores[1]:.2f} || Val. Loss: {valid_loss:.4f} | Val. Acc: {valid_scores[0]*100:.2f}% | Val. AUROC: {valid_scores[1]:.2f}')

Epoch 001: | Train Loss: 0.5788 | Train Acc: 68.96% | Train AUROC: 0.76 || Val. Loss: 0.5067 | Val. Acc: 75.00% | Val. AUROC: 0.83
Epoch 002: | Train Loss: 0.4546 | Train Acc: 79.31% | Train AUROC: 0.87 || Val. Loss: 0.4676 | Val. Acc: 76.66% | Val. AUROC: 0.86
Epoch 003: | Train Loss: 0.4056 | Train Acc: 81.64% | Train AUROC: 0.90 || Val. Loss: 0.4751 | Val. Acc: 77.13% | Val. AUROC: 0.87
Epoch 004: | Train Loss: 0.3584 | Train Acc: 84.51% | Train AUROC: 0.92 || Val. Loss: 0.4783 | Val. Acc: 77.49% | Val. AUROC: 0.88
Epoch 005: | Train Loss: 0.3193 | Train Acc: 86.72% | Train AUROC: 0.94 || Val. Loss: 0.4678 | Val. Acc: 79.03% | Val. AUROC: 0.87
Epoch 006: | Train Loss: 0.3041 | Train Acc: 87.04% | Train AUROC: 0.94 || Val. Loss: 0.4564 | Val. Acc: 79.38% | Val. AUROC: 0.88
Epoch 007: | Train Loss: 0.2647 | Train Acc: 89.29% | Train AUROC: 0.96 || Val. Loss: 0.4903 | Val. Acc: 79.03% | Val. AUROC: 0.88
Epoch 008: | Train Loss: 0.2388 | Train Acc: 90.46% | Train AUROC: 0.97 || Val. Los

In [14]:
# test_loss, test_acc = evaluate(model, device, test_loader, criterion, checkpoint='model.pt')
# test_loss, test_acc, test_auroc = evaluate(model, device, test_loader, criterion, metric='both', checkpoint='model.pt')
test_loss, test_scores = evaluate(model, device, test_loader, criterion, metric_list=[accuracy_score, roc_auc_score], checkpoint='checkpoint.pt')
# print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc*100:.2f}% | Test AUROC: {test_auroc:.2f}')
# print(f'Test Loss: {test_loss:.4f} | Test scores: {test_scores}')
print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_scores[0]*100:.2f}% | Test AUROC: {test_scores[1]:.2f}')

Test Loss: 0.4742 | Test Acc: 77.87% | Test AUROC: 0.87
