In [1]:
import torch
import numpy as np
import pandas as pd
import os
import random
from pathlib import Path

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
SEED = 42
def seed_everything(seed):
    '''Seed every random number generator this notebook relies on.

    Args:
        seed (int): value used to seed Python's `random`, NumPy and PyTorch
            (CPU and all CUDA devices).

    Side effects:
        Sets the PYTHONHASHSEED environment variable and switches cuDNN to
        deterministic mode with auto-tuning disabled.
    '''
    random.seed(seed)
    # Only influences Python processes started after this point (e.g. workers).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN pick kernels by timing them, which can differ
    # between runs; it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False
seed_everything(SEED)

### Prepare dataset
새로운 data에 대해서는 아래 class 새로 구현해야 함.

In [3]:
class CombinationDataset(Dataset):
    '''Drug-pair dataset: labelled positive combinations plus sampled negatives.

    Every sample is a two-element list ``[features, label]``. ``features`` is a
    float tensor holding the embedding of the first drug followed by the
    embedding of the second drug; ``label`` is a long tensor (1 for a pair read
    from the label file, 0 for a randomly sampled negative pair).

    The processed sample list is cached in ``data/processed`` and reloaded on
    later instantiations with the same arguments.
    '''
    def __init__(self, database='DCDB', dimension='3D', neg_ratio=1, transform=None):
        '''
        Args:
            database (str): name of the database (DCDB, C_DCDB)
            dimension (str): drug embedding dimension (3D, 12D)
            neg_ratio (int): number of negative samples per positive sample.
                1 gives as many negatives as positives, 2 twice as many, and so on.
            transform (callable, optional): Optional transform applied to a sample
                each time it is fetched through ``__getitem__``.

        Raises:
            ValueError: if `database` or `dimension` is not one of the supported
                values, or if `neg_ratio` is smaller than 1.
        '''
        if database not in ('DCDB', 'C_DCDB'):
            raise ValueError('database must be DCDB or C_DCDB')
        if dimension not in ('3D', '12D'):
            raise ValueError('dimension must be 3D or 12D')
        if neg_ratio < 1:
            raise ValueError('neg_ratio must be at least 1')
        self.dimension = dimension
        self.database = database
        self.transform = transform
        self.neg_ratio = neg_ratio
        self.data_path = Path('data/processed')/f'{database}_{dimension}_neg{neg_ratio}.pt'
        if not self.data_path.exists():
            self._process()
        # NOTE(review): torch.load unpickles the file; only point this at caches
        # written by _process().
        self.data = torch.load(self.data_path)

    def __len__(self):
        '''Return the number of samples (positives and negatives).'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Return sample `idx`, passed through `transform` when one was given.'''
        sample = self.data[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def _process(self):
        '''Build the sample list and write it to `self.data_path`.'''
        print('Processing dataset...')
        dataset_list = self._create_dataset()
        print(f'Saving dataset...{self.data_path}')
        # The cache directory may not exist on a fresh checkout.
        self.data_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(dataset_list, self.data_path)

    @staticmethod
    def _make_sample(drug_embedding, idx1, idx2, label):
        '''Return ``[features, label]`` for the drugs at rows `idx1` and `idx2`.'''
        comb_embedding = np.concatenate([drug_embedding[idx1], drug_embedding[idx2]])
        return [torch.tensor(comb_embedding, dtype=torch.float), torch.tensor(label, dtype=torch.long)]

    def _create_dataset(self):
        '''Assemble positive samples from the label file and sample negatives.

        Returns:
            list: ``[features, label]`` samples, all positives first, followed by
            ``len(positives) * neg_ratio`` negatives.

        Raises:
            ValueError: if no embedding file matches `self.dimension`.
        '''
        # Get drug embedding. The listing is sorted because os.listdir returns
        # entries in arbitrary order, which would otherwise make the column order
        # of the concatenated embedding differ between machines.
        embedding_dir = Path('data/embedding')
        embedding = sorted(embedding_dir/x for x in os.listdir(embedding_dir) if self.dimension in x)
        if not embedding:
            raise ValueError(f'no embedding file containing {self.dimension} found in {embedding_dir}')
        # Concatenate drug embeddings of each network
        drug_embedding = pd.concat([pd.read_csv(x) for x in embedding], axis=1).to_numpy()
        print(f'Shape of concatenated drug embedding: {drug_embedding.shape}')
        # Create drug dictionary
        _drug_list = pd.read_csv('data/raw/drug_dict.txt', sep=':', header=None)[0].tolist()
        drug_id2idx = {drug_id: i for i, drug_id in enumerate(_drug_list)}
        drug_idx2id = dict(enumerate(_drug_list))
        # Prepare positive labels
        pos_df = pd.read_csv(f'data/labels/{self.database}_deepdtnet.tsv', sep='\t')

        dataset_list = []
        # positive samples: the first two columns hold the two drug ids
        for drug1_id, drug2_id in zip(pos_df.iloc[:, 0], pos_df.iloc[:, 1]):
            dataset_list.append(
                self._make_sample(drug_embedding, drug_id2idx[drug1_id], drug_id2idx[drug2_id], 1))
        # negative samples: random pairs of distinct drugs that are not listed as
        # a positive pair in either order. A set gives constant-time lookups
        # instead of scanning the whole label frame for every candidate.
        # NOTE(review): the same negative pair can be drawn more than once.
        known_pairs = set(zip(pos_df['drug_1'], pos_df['drug_2']))
        n_drug = drug_embedding.shape[0]
        n_total = len(pos_df) * (1 + self.neg_ratio)
        while len(dataset_list) < n_total:
            drug1_id = drug_idx2id[np.random.randint(0, n_drug)]
            drug2_id = drug_idx2id[np.random.randint(0, n_drug)]
            if drug1_id == drug2_id:
                continue
            if (drug1_id, drug2_id) in known_pairs or (drug2_id, drug1_id) in known_pairs:
                continue
            dataset_list.append(
                self._make_sample(drug_embedding, drug_id2idx[drug1_id], drug_id2idx[drug2_id], 0))
        return dataset_list

In [4]:
# Build the C_DCDB dataset from 12D embeddings with neg_ratio=1, or load it
# from the cache if it was processed before.
dataset = CombinationDataset(
    neg_ratio=1,
    dimension='12D',
    database='C_DCDB',
)
# Total number of samples (positives and negatives).
print(len(dataset))

Processing dataset...
Shape of concatenated drug embedding: (732, 72)
Saving dataset...data/processed/C_DCDB_12D.pt
3224


In [5]:
# 80 / 10 / 10 split into train, validation and test subsets; the test subset
# takes whatever remains after rounding.
train_size = int(0.8 * len(dataset))
valid_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - valid_size
# A dedicated, seeded generator keeps the split identical no matter how many
# random numbers earlier cells have already consumed from the global RNG.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, valid_size, test_size], generator=split_generator)

# Only the training data is shuffled; the evaluation loaders keep a fixed batch
# order so their metrics are comparable from epoch to epoch.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

### Training

In [7]:
class CombNet(nn.Module):
    '''Binary classifier over a pair of drug embeddings.

    The input row is split into two halves (one per drug). Both halves go
    through the same projection `lt`, the two projections are merged according
    to `comb_type`, and `fc` maps the merged vector to `output_dim` raw logits.
    '''
    def __init__(self, input_dim, hidden_dim, output_dim, comb_type='cat', dropout=0.1):
        '''
        Args:
            input_dim (int): width of the concatenated pair embedding; must be even.
            hidden_dim (int): width of the hidden layers.
            output_dim (int): number of output logits.
            comb_type (str): how the two projected drugs are merged:
                'cat' (concatenation), 'sum', 'diff' (absolute difference) or
                'sumdiff' (sum concatenated with absolute difference).
            dropout (float): dropout probability used inside `fc`.

        Raises:
            ValueError: if `comb_type` is unknown or `input_dim` is odd.
        '''
        super(CombNet, self).__init__()
        if comb_type not in ('cat', 'sum', 'diff', 'sumdiff'):
            raise ValueError('comb_type must be cat, sum, diff or sumdiff')
        # An odd width would split into halves of different sizes and only fail
        # later, inside forward, with a shape mismatch.
        if input_dim % 2 != 0:
            raise ValueError('input_dim must be even (two drug embeddings of equal size)')
        self.input_dim = input_dim # dimension of concatenated drug embeddings
        self.comb_type = comb_type
        # Projection shared by both drugs.
        self.lt = nn.Sequential(
            nn.Linear(input_dim // 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        )
        # 'sum' and 'diff' keep the hidden width; 'cat' and 'sumdiff' double it.
        dual_dim = hidden_dim if comb_type in ('sum', 'diff') else hidden_dim * 2
        self.fc = nn.Sequential(
            nn.Linear(dual_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim) # raw logits, for BCEWithLogitsLoss
        )
    def forward(self, data):
        '''Return logits of shape (batch_size, output_dim) for a batch of pairs.

        Args:
            data (Tensor): 2-D tensor whose rows hold the first drug's embedding
                followed by the second drug's embedding.
        '''
        half = self.input_dim // 2
        drug1, drug2 = data[:, :half], data[:, half:] # split into drug1 and drug2
        drug1, drug2 = self.lt(drug1), self.lt(drug2)
        if self.comb_type == 'cat': # concatenate [(drug1), (drug2)]
            comb = torch.cat([drug1, drug2], dim=1) # (batch_size, hidden_dim * 2)
        elif self.comb_type == 'sum': # (drug1) + (drug2)
            comb = drug1 + drug2 # (batch_size, hidden_dim)
        elif self.comb_type == 'diff': # |(drug1) - (drug2)|
            comb = torch.abs(drug1 - drug2) # (batch_size, hidden_dim)
        elif self.comb_type == 'sumdiff': # concatenate [(drug1) + (drug2), |(drug1) - (drug2)|]
            comb = torch.cat([drug1 + drug2, torch.abs(drug1 - drug2)], dim=1) # (batch_size, hidden_dim * 2)
        return self.fc(comb)

In [8]:
def train(model, device, train_loader, criterion, optimizer, metric='accuracy'):
    '''Run one training epoch and report the average loss and accuracy.

    Args:
        model: network returning one logit per sample.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(logits, target)`` with float targets.
        optimizer: optimizer stepping the parameters of `model`.
        metric (str): 'accuracy' counts correct predictions at a 0.5 threshold on
            the sigmoid of the logits. Any other value (including 'auc', which is
            not implemented) leaves the second return value at 0.

    Returns:
        tuple: (mean of the per-batch losses, fraction of correctly classified
        samples among the samples actually iterated).

    Raises:
        ValueError: if `train_loader` yields no batches.
    '''
    # train
    model.train()
    train_loss = 0
    correct = 0
    n_batches = 0
    n_samples = 0
    for data, target in train_loader:
        data, target = data.to(device), target.float().to(device)
        optimizer.zero_grad()
        output = model(data).view(-1) # z
        loss = criterion(output, target) # z, y
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        n_batches += 1
        # Count what was really seen, so loaders that skip samples (for example
        # drop_last=True) do not deflate the accuracy.
        n_samples += target.numel()
        if metric == 'accuracy':
            pred = torch.sigmoid(output).round()
            correct += pred.eq(target.view_as(pred)).sum().item()
        elif metric == 'auc':
            pass
    if n_batches == 0:
        raise ValueError('train_loader yielded no batches')
    return train_loss / n_batches, correct / n_samples

In [9]:
def evaluate(model, device, loader, criterion, metric='accuracy', checkpoint=None):
    '''Evaluate `model` on `loader` without updating its parameters.

    Args:
        model: network returning one logit per sample.
        device: device the batches (and the checkpoint tensors) are moved to.
        loader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(logits, target)`` with float targets.
        metric (str): 'accuracy' counts correct predictions at a 0.5 threshold on
            the sigmoid of the logits. Any other value (including 'auc', which is
            not implemented) leaves the second return value at 0.
        checkpoint (str, optional): path of a saved state dict. When given, it is
            loaded into `model` (in place) before evaluating.

    Returns:
        tuple: (mean of the per-batch losses, fraction of correctly classified
        samples among the samples actually iterated).

    Raises:
        ValueError: if `loader` yields no batches.
    '''
    # evaluate
    if checkpoint is not None:
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()
    eval_loss = 0
    correct = 0
    n_batches = 0
    n_samples = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.float().to(device)
            output = model(data).view(-1)
            eval_loss += criterion(output, target).item()
            n_batches += 1
            # Count what was really seen, so loaders that skip samples do not
            # deflate the accuracy.
            n_samples += target.numel()
            if metric == 'accuracy':
                pred = torch.sigmoid(output).round()
                correct += pred.eq(target.view_as(pred)).sum().item()
            elif metric == 'auc':
                pass
    if n_batches == 0:
        raise ValueError('loader yielded no batches')
    return eval_loss / n_batches, correct / n_samples

In [10]:
# The input width is the length of one sample's feature vector; the hidden
# layers use the same width and the network emits a single logit.
input_dim = dataset[0][0].size(0)
hidden_dim, output_dim = input_dim, 1
model = CombNet(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    comb_type='cat',
)

In [11]:
# Training hyper-parameters.
EPOCHS, LR = 200, 0.001
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)
# The model outputs raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), weight_decay=1e-5, lr=LR)

In [12]:
# Train for EPOCHS epochs, keeping the weights of the epoch with the lowest
# validation loss in 'model.pt'.
best_valid_loss = float('inf')
for epoch in range(EPOCHS):
    train_loss, train_acc = train(model, device, train_loader, criterion, optimizer)
    valid_loss, valid_acc = evaluate(model, device, valid_loader, criterion)
    # Checkpoint only when the validation loss improves.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    # Every field is separated by ' | ' and both accuracies carry a '%' sign.
    print(f'Epoch {epoch+1:03d}: | Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | Val. Loss: {valid_loss:.4f} | Val. Acc: {valid_acc*100:.2f}%')

Epoch 001: | Train Loss: 0.6801 | Train Acc: 58.36 Val. Loss: 0.6759 | Val. Acc: 54.97%
Epoch 002: | Train Loss: 0.6577 | Train Acc: 60.49 Val. Loss: 0.7416 | Val. Acc: 54.04%
Epoch 003: | Train Loss: 0.6331 | Train Acc: 64.13 Val. Loss: 0.6847 | Val. Acc: 56.21%
Epoch 004: | Train Loss: 0.6252 | Train Acc: 64.75 Val. Loss: 0.6740 | Val. Acc: 58.70%
Epoch 005: | Train Loss: 0.6054 | Train Acc: 65.88 Val. Loss: 0.6663 | Val. Acc: 56.52%
Epoch 006: | Train Loss: 0.6084 | Train Acc: 66.38 Val. Loss: 0.6821 | Val. Acc: 56.83%
Epoch 007: | Train Loss: 0.5862 | Train Acc: 68.05 Val. Loss: 0.6861 | Val. Acc: 61.80%
Epoch 008: | Train Loss: 0.5748 | Train Acc: 70.22 Val. Loss: 0.6747 | Val. Acc: 58.39%
Epoch 009: | Train Loss: 0.5697 | Train Acc: 70.03 Val. Loss: 0.6483 | Val. Acc: 61.49%
Epoch 010: | Train Loss: 0.5559 | Train Acc: 71.27 Val. Loss: 0.5812 | Val. Acc: 67.39%
Epoch 011: | Train Loss: 0.5425 | Train Acc: 73.05 Val. Loss: 0.5565 | Val. Acc: 67.39%
Epoch 012: | Train Loss: 0.5483 

In [13]:
# Restore the checkpointed weights from 'model.pt' and score them on the
# held-out test split.
test_loss, test_acc = evaluate(
    model=model,
    device=device,
    loader=test_loader,
    criterion=criterion,
    checkpoint='model.pt',
)
print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.5904 | Test Acc: 70.59%


### Case study

In [14]:
# DrugBank ids are taken from the first ':'-separated field of every line.
drug_table = pd.read_csv('data/raw/drug_dict.txt', sep=':', header=None)
drug_list = drug_table[0].tolist() # list of drugs in DrugBank id
# Lookup tables in both directions between a DrugBank id and its list position.
drug_id2idx = {drug_id: idx for idx, drug_id in enumerate(drug_list)} # key: id, value: index
drug_idx2id = dict(enumerate(drug_list)) # key: index, value: id

In [15]:
# Print the index of each drug examined in the case study.
for query_id in (
    'DB01098',  # rosuvastatin
    'DB01076',  # atorvastatin
    'DB01039',  # fenofibrate
):
    print(drug_id2idx[query_id])

486

To be added...