In [None]:
"""

"""
import os
import json
from torch_geometric.data import DataLoader
from datetime import datetime
import random
import pandas as pd
from DeepHostGuest.utils.data import *
from DeepHostGuest.models import *
from torch.optim.lr_scheduler import MultiStepLR
from sklearn.model_selection import KFold
import shutil
import numpy as np

# Seed every RNG the notebook relies on (PyTorch CPU + all CUDA devices, NumPy,
# Python's `random`) so the shuffle, the k-fold split and weight init repeat exactly.
torch.manual_seed(1000)
torch.cuda.manual_seed_all(1000)
np.random.seed(1000)
random.seed(1000)

# Number of augmented copies per complex (file suffixes _0 .. _{aug_fold-1}).
aug_fold = 10

# Input locations -- placeholders, point these at the real data.
host_ply_dir = '/path/to/host_ply'
host_mol_dir = '/path/to/mol'
guest_mol_dir = '/path/to/guest_mol'
# Working directory that receives the copied inputs and the trained models.
train_dir = '/path/to/training'
os.makedirs(train_dir, exist_ok=True)

# analysis_info.json is keyed by complex prefix (the original note says 1,499 entries);
# only the keys are used here, as file-name prefixes.
with open('./analysis_info.json', 'r') as info_file:
    check_info = json.load(info_file)

prefixes = list(check_info)
print(len(prefixes))
# Seeded shuffle, so the tail slice taken in the next cell is a random validation set.
random.shuffle(prefixes)
%matplotlib inline

In [None]:
# Hold out the last 99 shuffled complexes as the validation set; the rest form the
# training pool that is later split 9:1 into train/test.
valid_names = prefixes[-99:]
valid_lookup = set(valid_names)  # constant-time membership for the copy loop

# Destination folders inside train_dir.
train_host_dir = os.path.join(train_dir, 'host_ply')
train_guest_dir = os.path.join(train_dir, 'guest_mol')
valid_aug_dir = os.path.join(train_dir, 'valid', 'aug_structures')
valid_struct_dir = os.path.join(train_dir, 'valid', 'structures')

for folder in (train_dir, train_host_dir, train_guest_dir,
               os.path.join(train_dir, 'train_model'),
               os.path.join(train_dir, 'valid'),
               valid_aug_dir, valid_struct_dir):
    os.makedirs(folder, exist_ok=True)

for prefix in prefixes:
    is_valid = prefix in valid_lookup
    # Every augmented copy goes either to valid/aug_structures or to the training
    # host/guest folders, depending on which side of the split the complex is on.
    host_dest = valid_aug_dir if is_valid else train_host_dir
    guest_dest = valid_aug_dir if is_valid else train_guest_dir
    for i in range(aug_fold):
        shutil.copy(os.path.join(host_ply_dir, f'{prefix}_1_{i}.ply'), host_dest)
        shutil.copy(os.path.join(guest_mol_dir, f'{prefix}_2_{i}.mol'), guest_dest)
    if is_valid:
        # Copy 0 of each validation complex is also stored under an un-suffixed name
        # (host .ply, host .mol, guest .mol).
        shutil.copy(os.path.join(host_ply_dir, f'{prefix}_1_0.ply'),
                    os.path.join(valid_struct_dir, f'{prefix}_1.ply'))
        shutil.copy(os.path.join(host_mol_dir, f'{prefix}_1_0.mol'),
                    os.path.join(valid_struct_dir, f'{prefix}_1.mol'))
        shutil.copy(os.path.join(guest_mol_dir, f'{prefix}_2_0.mol'),
                    os.path.join(valid_struct_dir, f'{prefix}_2.mol'))

In [None]:
# Passed to HostGuest_dataset; also selects the guest feature width and batch size below.
removeHs = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Guest (ligand) node-feature width: 13 when removeHs is set, 14 otherwise.
# Construction order below is kept fixed so seeded weight initialisation is reproducible.
guest_node_features = 13 if removeHs else 14
guest_model = LigandNet(guest_node_features, edge_features=7, residual_layers=10, dropout_rate=0.10)
host_model = TargetNet(1, residual_layers=10, dropout_rate=0.10)
# `dist_threhold` (sic) is the keyword DeepDock takes; the training code reads it back as
# model.dist_threhold to mask the MDN loss.
model = DeepDock(guest_model, host_model, hidden_dim=64, n_gaussians=10,
                 dropout_rate=0.10, dist_threhold=10.).to(device)

# Optimisation hyper-parameters.
lr = 0.001
epochs = 300
batch_size = 32 if removeHs else 16
save_each = 25      # write a numbered checkpoint every `save_each` epochs
aux_weight = 0.001  # weight of the auxiliary atom/bond-type cross-entropy terms
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# LR is multiplied by 0.2 at each milestone (scheduler steps once per epoch);
# milestones at or beyond `epochs` have no effect on training.
scheduler = MultiStepLR(optimizer, milestones=[50, 100, 150, 200, 250, 300, 350, 400, 450, 500], gamma=0.2)
losses = []  # one row of train/test losses per epoch, appended by the training loop

In [None]:
%%time

# Train/test = 9:1, taken as the first fold of a seeded, shuffled 10-fold KFold.
kfold = KFold(n_splits=10, shuffle=True, random_state=1000)
training_set_path = os.path.join(train_dir, 'host_ply')
training_names = [name for name in os.listdir(training_set_path) if name.endswith('.ply')]

# NOTE(review): this assumes HostGuest_dataset yields one sample per host .ply file, in
# an order matching these indices -- confirm. The split is over individual augmented
# copies, not complexes, so copies of one complex can land in both train and test
# (consider a grouped split such as GroupKFold keyed by complex prefix).
n = len(training_names)
train_index, test_index = next(kfold.split(np.arange(n)))
print(len(train_index), len(test_index))

db_complex = HostGuest_dataset(removeHs=removeHs, root=train_dir)
db_complex_train = [db_complex[idx] for idx in train_index]
db_complex_test = [db_complex[idx] for idx in test_index]
print('Complexes in training set:', len(db_complex_train))
print('Complexes in test set:', len(db_complex_test))
loader_train = DataLoader(db_complex_train, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(db_complex_test, batch_size=batch_size, shuffle=False)

now = datetime.now()
print(now.strftime("Start date: %d/%m/%Y at %H:%M:%S"))
# Prefix for checkpoints/logs: {dataset}_{removeHs}_{epochs}_{batch_size}_{lr}_{aux_weight}_mlr
model_name = f'dist10_data1400_{removeHs}_{epochs}_{batch_size}_{lr}_{aux_weight}_mlr'
# Checkpoints and loss CSVs written below are relative to this directory.
os.chdir(os.path.join(train_dir, 'train_model'))

def train():
    model.train()

    total_loss = 0
    mdn_loss = 0
    atom_loss = 0
    bond_loss = 0
    for data in loader_train:
        optimizer.zero_grad()
        target, ligand = data
        target, ligand = target.to(device), ligand.to(device)
        atom_labels = torch.argmax(ligand.x, dim=1, keepdim=False)
        bond_labels = torch.argmax(ligand.edge_attr, dim=1, keepdim=False)

        pi, sigma, mu, dist, atom_types, bond_types, batch = model(ligand, target)

        mdn = mdn_loss_fn(pi, sigma, mu, dist)
        mdn = mdn[torch.where(dist <= model.dist_threhold)[0]]
        mdn = mdn.mean()
        atom = F.cross_entropy(atom_types, atom_labels)
        bond = F.cross_entropy(bond_types, bond_labels)
        loss = mdn + (atom * aux_weight) + (bond * aux_weight)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * (ligand.batch.max().item() + 1)
        mdn_loss += mdn.item() * (ligand.batch.max().item() + 1)
        atom_loss += atom.item() * (ligand.batch.max().item() + 1)
        bond_loss += bond.item() * (ligand.batch.max().item() + 1)

        #print('Step, Total Loss: {:.3f}, MDN: {:.3f}'.format(total_loss, mdn_loss))
        if np.isinf(mdn_loss) or np.isnan(mdn_loss): break

    return total_loss / len(loader_train.dataset), mdn_loss / len(loader_train.dataset), atom_loss / len(
        loader_train.dataset), bond_loss / len(loader_train.dataset)


@torch.no_grad()
def test(dataset):
    model.eval()

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    total_loss = 0
    mdn_loss = 0
    atom_loss = 0
    bond_loss = 0
    for data in loader:
        target, ligand = data
        target, ligand = target.to(device), ligand.to(device)
        atom_labels = torch.argmax(ligand.x, dim=1, keepdim=False)
        bond_labels = torch.argmax(ligand.edge_attr, dim=1, keepdim=False)

        pi, sigma, mu, dist, atom_types, bond_types, batch = model(ligand, target)

        mdn = mdn_loss_fn(pi, sigma, mu, dist)
        mdn = mdn[torch.where(dist <= model.dist_threhold)[0]]
        mdn = mdn.mean()
        atom = F.cross_entropy(atom_types, atom_labels)
        bond = F.cross_entropy(bond_types, bond_labels)
        loss = mdn + (atom * aux_weight) + (bond * aux_weight)

        total_loss += loss.item() * (ligand.batch.max().item() + 1)
        mdn_loss += mdn.item() * (ligand.batch.max().item() + 1)
        atom_loss += atom.item() * (ligand.batch.max().item() + 1)
        bond_loss += bond.item() * (ligand.batch.max().item() + 1)

    return total_loss / len(loader.dataset), mdn_loss / len(loader.dataset), atom_loss / len(
        loader.dataset), bond_loss / len(loader.dataset)


# Column order of each row appended to `losses` (training losses, then test losses).
LOSS_COLUMNS = ['total_loss', 'mdn_loss', 'atom_loss', 'bond_loss',
                'test_total_loss', 'test_mdn_loss', 'test_atom_loss', 'test_bond_loss']


def losses_frame(rows):
    """Return per-epoch loss ``rows`` as a DataFrame headed by ``LOSS_COLUMNS``."""
    return pd.DataFrame(rows, columns=LOSS_COLUMNS)


def save_checkpoint(path, epoch, total_loss, mdn_loss, atom_loss, bond_loss, **extra):
    """Save model/optimizer state, the torch CPU RNG state and the given losses to ``path``.

    Any ``extra`` keyword values are stored as additional checkpoint entries.
    """
    state = {'epoch': epoch, 'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'rng_state': torch.get_rng_state(), 'total_loss': total_loss,
             'mdn_loss': mdn_loss, 'atom_loss': atom_loss, 'bond_loss': bond_loss}
    state.update(extra)
    torch.save(state, path)


# Best test total loss so far. +inf guarantees the first finite test loss is saved;
# a fixed sentinel such as 1000 would never trigger if losses start above it.
prev_test_total_loss = float('inf')
for epoch in range(1, epochs + 1):
    total_loss, mdn_loss, atom_loss, bond_loss = train()
    if np.isinf(mdn_loss) or np.isnan(mdn_loss):
        print('Inf ERROR')
        break
    test_total_loss, test_mdn_loss, test_atom_loss, test_bond_loss = test(db_complex_test)
    scheduler.step()
    losses.append([total_loss, mdn_loss, atom_loss, bond_loss,
                   test_total_loss, test_mdn_loss, test_atom_loss, test_bond_loss])

    # Keep the checkpoint with the lowest test *total* loss: compare like with like
    # (the stored best is a total loss, so the new value must be one too).
    if test_total_loss <= prev_test_total_loss:
        prev_test_total_loss = test_total_loss
        save_checkpoint(f'{model_name}_minTestLoss.chk', epoch,
                        total_loss, mdn_loss, atom_loss, bond_loss,
                        test_total_loss=test_total_loss)
    # Rewrite the running loss log every epoch so progress survives an interruption.
    # NOTE(review): this log is 'dist7_'-prefixed while the final log is not; names kept
    # unchanged so existing tooling still finds both files.
    losses_frame(losses).to_csv(f'dist7_{model_name}_loss.csv')

    print('Epoch: {:03d}, Total Loss: {:.3f}, Valid Loss: {:.3f}'.format(epoch, total_loss, test_total_loss))

    if epoch % save_each == 0:
        save_checkpoint(f'{model_name}_epoch_{epoch}.chk', epoch,
                        total_loss, mdn_loss, atom_loss, bond_loss)

# Final checkpoint and loss log; `epoch` is the last epoch entered (including one
# aborted on a non-finite loss).
save_checkpoint(f'{model_name}_epoch_{epoch}.chk', epoch,
                total_loss, mdn_loss, atom_loss, bond_loss)
l = losses_frame(losses)
l.to_csv(f'{model_name}_loss.csv')
l[['total_loss', 'test_total_loss']].plot()