In [1]:
import pickle
from load_data import LoadData

In [2]:
import torch
import numpy as np
import pandas as pd
import os
import random
from pathlib import Path

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [3]:
SEED = 42
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
seed_everything(SEED)

## Prepare dataset

In [4]:
class CombinationDataset(Dataset):
    def __init__(self, database='DCDB', neg_ratio=1, transform=None):
        if (database != 'DCDB') & (database != 'C_DCDB'):
            raise ValueError('database must be DCDB or C_DCDB')
        if neg_ratio < 1:
            raise ValueError('neg_ratio must be greater than 1')
        self.database = database
        self.neg_ratio = neg_ratio
        self.transform = transform
        self.data_path = Path('data/processed')/f'{database}_neg{neg_ratio}.pt'
        if self.data_path.exists():
            self.data = torch.load(self.data_path)
        else:
            self._process()
            self.data = torch.load(self.data_path)
        
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]
    
    def _process(self):
        print('Processing dataset...')
        dataset_list = self._create_dataset()
        print(f'Saving dataset...{self.data_path}')
        torch.save(dataset_list, self.data_path)
    
    def _create_dataset(self):
        dataloader = LoadData()
        # Get embedding
        with open('data/embedding/embeddings_node2vec_msi_seed0.pkl', 'rb') as f:
            embedding_dict = pickle.load(f)
        # drug dictionary
        drug_id2name, drug_name2id = dataloader.get_dict(type='drug')
        # Prepare positive labels
        pos_df = pd.read_csv(f'data/labels/{self.database}_msi.tsv', sep='\t')

        dataset_list = []
        # positive samples
        for i in range(len(pos_df)):
            drug1_id = pos_df.iloc[i, 0]
            drug2_id = pos_df.iloc[i, 1]
            comb_embedding = np.concatenate([embedding_dict[drug1_id], embedding_dict[drug2_id]])
            dataset_list.append([torch.tensor(comb_embedding, dtype=torch.float), torch.tensor(1, dtype=torch.long)])
        # negative samples
        n_drug = len(drug_id2name)
        while len(dataset_list) < len(pos_df) * (1 + self.neg_ratio):
            drug1_id = random.choice(list(drug_id2name.keys()))
            drug2_id = random.choice(list(drug_id2name.keys()))
            if drug1_id == drug2_id:
                continue
            if ((pos_df['drug_1'] == drug1_id) & (pos_df['drug_2'] == drug2_id)).any():
                continue
            if ((pos_df['drug_1'] == drug2_id) & (pos_df['drug_2'] == drug1_id)).any():
                continue
            comb_embedding = np.concatenate([embedding_dict[drug1_id], embedding_dict[drug2_id]])
            dataset_list.append([torch.tensor(comb_embedding, dtype=torch.float), torch.tensor(0, dtype=torch.long)])
        return dataset_list

In [5]:
# dataset = CombinationDataset(database='DCDB', neg_ratio=1)
# print(len(dataset))

In [6]:
dataset = CombinationDataset(database='C_DCDB', neg_ratio=1)
print(len(dataset))

8442


In [7]:
train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

## Training

In [8]:
class CombNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, comb_type='cat', dropout=0.1):
        super(CombNet, self).__init__()
        self.input_dim = input_dim # dimension of concatenated drug embeddings
        if (comb_type != 'cat') & (comb_type != 'sum') & (comb_type != 'diff') & (comb_type != 'sumdiff'):
            raise ValueError('comb_type must be cat, sum, diff or sumdiff')
        self.comb_type = comb_type
        self.lt = nn.Sequential(
            nn.Linear(input_dim // 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        dual_dim = hidden_dim if (comb_type == 'sum' or comb_type == 'diff') else hidden_dim * 2
        self.fc = nn.Sequential(
            nn.Linear(dual_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            # nn.Linear(hidden_dim, hidden_dim),
            # nn.BatchNorm1d(hidden_dim),
            # nn.ReLU(),
            # nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim) # for BCEWithLogitsLoss
        )
    def forward(self, data):
        drug1, drug2 = data[:, :self.input_dim//2], data[:, self.input_dim//2:] # drug1과 drug2를 분리
        drug1, drug2 = self.lt(drug1), self.lt(drug2)
        if self.comb_type == 'cat': # [(drug1), (drug2)]로 concat
            comb = torch.cat([drug1, drug2], dim=1) # (batch_size, hidden_dim * 2)
        elif self.comb_type == 'sum': # [(drug1) + (drug2)]
            comb = drug1 + drug2 # (batch_size, hidden_dim)
        elif self.comb_type == 'diff': # [(drug1) - (drug2)]
            comb = torch.abs(drug1 - drug2) # (batch_size, hidden_dim)
        elif self.comb_type == 'sumdiff': # [(drug1) + (drug2), (drug1) - (drug2)]로 concat
            comb = torch.cat([drug1 + drug2, torch.abs(drug1 - drug2)], dim=1) # (batch_size, hidden_dim * 2)
        return self.fc(comb)

In [9]:
def train(model, device, train_loader, criterion, optimizer, metric='accuracy'):
    # train
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.float().to(device)
        optimizer.zero_grad()
        output = model(data).view(-1) # z
        # print(output)
        loss = criterion(output, target) # z, y
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if metric == 'accuracy':
            pred = torch.sigmoid(output).round()
            correct += pred.eq(target.view_as(pred)).sum().item()
        elif metric == 'auc':
            pass
    return train_loss / (batch_idx + 1), correct / len(train_loader.dataset)

In [10]:
def evaluate(model, device, loader, criterion, metric='accuracy', checkpoint=None):
    # evaluate
    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint))
    model.eval()
    eval_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.float().to(device)
            output = model(data).view(-1)
            eval_loss += criterion(output, target).item()
            if metric == 'accuracy':
                pred = torch.sigmoid(output).round()
                correct += pred.eq(target.view_as(pred)).sum().item()
            elif metric == 'auc':
                pass
    return eval_loss / (batch_idx + 1), correct / len(loader.dataset)

In [11]:
input_dim = dataset[0][0].shape[0]
hidden_dim = input_dim
output_dim = 1
model = CombNet(input_dim, hidden_dim, output_dim, comb_type='cat')

In [12]:
EPOCHS = 100
LR = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)

In [13]:
best_valid_loss = float('inf')
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, device, train_loader, criterion, optimizer)
    valid_loss, valid_acc = evaluate(model, device, valid_loader, criterion)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    print(f'Epoch {epoch+1:03d}: | Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f} Val. Loss: {valid_loss:.4f} | Val. Acc: {valid_acc*100:.2f}%')

Epoch 001: | Train Loss: 0.5848 | Train Acc: 69.02 Val. Loss: 0.5315 | Val. Acc: 73.58%
Epoch 002: | Train Loss: 0.4578 | Train Acc: 78.37 Val. Loss: 0.4970 | Val. Acc: 76.07%
Epoch 003: | Train Loss: 0.4012 | Train Acc: 82.32 Val. Loss: 0.5028 | Val. Acc: 74.76%
Epoch 004: | Train Loss: 0.3547 | Train Acc: 84.61 Val. Loss: 0.4409 | Val. Acc: 80.33%
Epoch 005: | Train Loss: 0.3227 | Train Acc: 85.98 Val. Loss: 0.4586 | Val. Acc: 79.50%
Epoch 006: | Train Loss: 0.3025 | Train Acc: 87.47 Val. Loss: 0.5233 | Val. Acc: 76.18%
Epoch 007: | Train Loss: 0.2659 | Train Acc: 88.86 Val. Loss: 0.5289 | Val. Acc: 76.78%
Epoch 008: | Train Loss: 0.2418 | Train Acc: 90.11 Val. Loss: 0.5206 | Val. Acc: 77.37%
Epoch 009: | Train Loss: 0.2119 | Train Acc: 91.68 Val. Loss: 0.5008 | Val. Acc: 78.32%
Epoch 010: | Train Loss: 0.1975 | Train Acc: 92.11 Val. Loss: 0.5505 | Val. Acc: 78.79%
Epoch 011: | Train Loss: 0.1831 | Train Acc: 92.89 Val. Loss: 0.6089 | Val. Acc: 76.66%
Epoch 012: | Train Loss: 0.1633 

In [14]:
test_loss, test_acc = evaluate(model, device, test_loader, criterion, checkpoint='model.pt')
print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.4129 | Test Acc: 81.07%
