# In [1]:
from argparse import ArgumentParser
# Architectures selectable with -architecture (each one is built by modelselect).
nets = ["alexnet", "vgg19_bn", "resnet152", "densenet201", "resnet152_extended", "alexnet_extended",
        "densenet201_extended", "resnet34", "resnet34_extended", "densenet121_extended"]
# Toxicity endpoints; each names a dataset directory under ./data and ./images.
tox = ["IGC50", "IBC50", "LC50", "LC50DM"]

argparser = ArgumentParser()
args_group = argparser.add_argument_group(title='Running args')
args_group.add_argument('-seed', type=int, help='Seed used to split the data into training, validation and test sets', required=False, default=1)
args_group.add_argument('-task', type=str, help='Set this argument to train in order to train the network, or to predict to load a pretrained model', required=False, default="train", choices=['train', 'predict'])
args_group.add_argument('-architecture', type=str, help='ConvNet to be used', required=False, choices=nets, default="vgg19_bn")
args_group.add_argument('-lr', type=float, help='Maximum learning rate value. Please note that optimal values vary across architectures', required=False, default=0.01)
args_group.add_argument('-step_size_lr_decay', type=int, help='Step size to decrease the learning rate by a given factor (parameter drop_factor_lr)', required=False, default=25)
args_group.add_argument('-drop_factor_lr', type=float, help='The learning rate is reduced by the factor indicated in this argument', required=False, default=0.6)
args_group.add_argument('-batch_size', type=int, help='Batch size', required=False, default=16)
args_group.add_argument('-data_augmentation', type=int, help='Whether data augmentation should be applied to the validation and training sets (1: yes; 0: no)', required=False, default=1, choices=[0, 1])
args_group.add_argument('-nb_epochs_training_per_cycle', type=int, help='Number of epochs for each learning rate annealing cycle', required=False, default=200)
args_group.add_argument('-nb_epochs_training', type=int, help='Number of epochs to be considered for training', required=False, default=600)
args_group.add_argument('-epochs_early_stop', type=int, help='Number of epochs for early stopping', required=False, default=250)
# type=list turns every token into the list of its characters ('150' -> ['1', '5', '0']);
# consumers must join the characters back before converting to int.
args_group.add_argument('-load_dict', type=list, help='Per-fold epoch of the checkpoint to resume from (0 = train the fold from scratch)', required=False, nargs='+', default=[])
args_group.add_argument('-Tox', type=str, help='Toxicity endpoint (dataset) to model', required=False, default="IGC50", choices=tox)
args_group.add_argument('-cv', type=int, help='Number of cross-validation folds', required=False, default=5)

# Arguments are hard-coded for notebook use; call parse_args() without a list to
# read them from the command line instead.
args = argparser.parse_args(['-batch_size', '16'
                             , "-architecture"
                             , "densenet201_extended"
                             , '-load_dict', '0', '0', '0', '0', '0'
                             , '-nb_epochs_training', '300'
                             , '-lr', '0.01'
                             , '-nb_epochs_training_per_cycle', '100'
                             , '-Tox', 'IGC50'
                             , '-cv', '5'])

# Without -load_dict every fold would index into an empty list; default to
# "train every fold from scratch" using the same list-of-characters shape.
if not args.load_dict:
    args.load_dict = [['0'] for _ in range(args.cv)]

seed = args.seed
net = args.architecture
lr = args.lr

import torch
from sklearn.metrics import mean_squared_error
from math import sqrt
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models
import os, glob, time
import copy
import numpy as np
import scipy
import os,sys, os.path
from collections import defaultdict
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
import rdkit.rdBase
from rdkit import DataStructs
from rdkit.DataStructs import BitVectToText
from rdkit import DataStructs
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG

from IPython.core.display import SVG
from torch.autograd import Variable
import multiprocessing

from torchvision import transforms
import pandas as pd
from rdkit.Chem import Draw
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
from IPython import display

import torch.utils.data as data
from PIL import Image
import os
import os.path

def default_loader(path):
    """Load the image at ``path`` and return it converted to RGB.

    The file is opened in a ``with`` block so its handle is closed as soon as
    ``convert`` has decoded the pixels into a new image, instead of lingering
    until garbage collection (many DataLoader epochs otherwise accumulate open
    files / ResourceWarnings).
    """
    with open(path, 'rb') as fh:
        return Image.open(fh).convert('RGB')

def default_flist_reader(flist):
    """Parse a list file whose lines are ``<image path> <integer label>``.

    Blank lines are skipped. The label is the last whitespace-separated field,
    so image paths may themselves contain spaces.

    Args:
        flist: path of the list file.

    Returns:
        List of ``(path, int(label))`` tuples in file order.

    Raises:
        ValueError: if a non-blank line has no label field or the label is
            not an integer.
    """
    imlist = []
    with open(flist, 'r') as rf:
        for raw in rf:
            line = raw.strip()
            if not line:
                continue
            impath, imlabel = line.rsplit(None, 1)
            imlist.append((impath, int(imlabel)))
    return imlist

class ImageFilelist(data.Dataset):
    """Map-style dataset over an in-memory list of ``(image_path, target)`` pairs.

    Args:
        paths_labels: indexable sequence of ``(path, target)`` tuples.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each target.
        flist_reader: accepted for signature compatibility only; it is never
            called, because samples come from ``paths_labels``.
        loader: callable turning a path into an image.
    """

    def __init__(self, paths_labels, transform=None, target_transform=None,
                 flist_reader=default_flist_reader, loader=default_loader):
        self.loader = loader
        self.target_transform = target_transform
        self.transform = transform
        self.imlist = paths_labels

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index`` with the transforms applied."""
        path, label = self.imlist[index]
        image = self.loader(path)
        image = image if self.transform is None else self.transform(image)
        label = label if self.target_transform is None else self.target_transform(label)
        return image, label

    def __len__(self):
        """Number of ``(path, target)`` pairs held by the dataset."""
        return len(self.imlist)

# Read the training and prediction (external test) molecules of the selected
# endpoint. Every record must carry the "CAS" and "Tox" SD properties read below.
suppl_train = Chem.SDMolSupplier("./data/{}/{}_training.sdf".format(args.Tox, args.Tox))
suppl_pre = Chem.SDMolSupplier("./data/{}/{}_prediction.sdf".format(args.Tox, args.Tox))
mols_train = list(suppl_train)
mols_test = list(suppl_pre)
print(len(mols_train), len(mols_test))
if len(suppl_train) + len(suppl_pre) == len(mols_train) + len(mols_test):
    print("mols ready")

# RDKit yields None for records it cannot parse; fail with a clear message
# instead of an opaque error in MolToSmiles/GetProp below.
_bad_train = [i for i, m in enumerate(mols_train) if m is None]
_bad_test = [i for i, m in enumerate(mols_test) if m is None]
if _bad_train or _bad_test:
    raise ValueError("Unparsable SDF records - training: {}, prediction: {}".format(_bad_train, _bad_test))

my_smiles_train = [Chem.MolToSmiles(submol) for submol in mols_train]
my_smiles_test = [Chem.MolToSmiles(submol) for submol in mols_test]
chembl_ids_train = [m.GetProp("CAS") for m in mols_train]
chembl_ids_test = [m.GetProp("CAS") for m in mols_test]
activities_train = [float(m.GetProp("Tox")) for m in mols_train]
activities_test = [float(m.GetProp("Tox")) for m in mols_test]

if len(my_smiles_train) + len(my_smiles_test) != len(activities_train) + len(activities_test):
    raise ValueError("The number of compounds does not correspond to the number of bioactivities")

# Permutation of the training rows that defines the cross-validation folds.
# NOTE: it is shuffled twice with the same seed; kept as-is so fold assignment
# stays reproducible with results produced by earlier runs.
base_indices = np.arange(0, len(activities_train))
np.random.seed(seed)
np.random.shuffle(base_indices)
np.random.seed(seed)
np.random.shuffle(base_indices)

#-----------------------------------------------
# data augmentation
#-----------------------------------------------

# The name ``transforms`` is rebound below from the torchvision module to a dict
# of pipelines (later code indexes transforms['train'] etc.). Building the
# pipelines through a private alias of the module keeps this cell safe to re-run.
from torchvision import transforms as _tv_transforms

# Normalization constants (presumably the usual ImageNet channel mean/std).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _build_pipeline(augment):
    """Resize to 224, optionally add random flips/rotation, then tensorize and normalize."""
    steps = [_tv_transforms.Resize(224)]
    if augment:
        steps += [
            _tv_transforms.RandomHorizontalFlip(),
            _tv_transforms.RandomVerticalFlip(),
            _tv_transforms.RandomRotation(degrees=90),
        ]
    steps += [_tv_transforms.ToTensor(), _tv_transforms.Normalize(_NORM_MEAN, _NORM_STD)]
    return _tv_transforms.Compose(steps)


# With -data_augmentation 1 the random augmentations are applied to both the
# training and the validation images; the 'test' pipeline is never augmented.
_augment = args.data_augmentation == 1
transforms = {
    'train': _build_pipeline(_augment),
    'val': _build_pipeline(_augment),
    'test': _build_pipeline(False),
}


#------------------------------------
# Data loaders
#------------------------------------

# (image path, activity) pairs consumed by ImageFilelist; each molecule's image
# is looked up by its CAS number under ./images/<Tox>/{train,test}/.
paths_labels_train = [
    ('./images/{}/train/{}.png'.format(args.Tox, cas), activity)
    for cas, activity in zip(chembl_ids_train, activities_train)
]
paths_labels_test = [
    ('./images/{}/test/{}.png'.format(args.Tox, cas), activity)
    for cas, activity in zip(chembl_ids_test, activities_test)
]

# DataLoader settings: num_workers=0 loads batches in the main process.
workers = 0
# NOTE(review): the training loader is not shuffled, so every epoch visits the
# samples in the same order.
shuffle = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
#-----------------------------------------------
# Training the model
#-----------------------------------------------


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` on one cross-validation fold and return it with its best weights.

    Uses module-level state prepared by the cross-validation driver: ``cv``
    (fold index), ``net`` (architecture name, part of the checkpoint path),
    ``dataloaders`` ('train' and 'val' loaders), ``index_val`` (row positions of
    the validation fold in ``kf_pred_all``) and the per-fold history lists
    ``train_loss``, ``lr_decay``, ``scores_mse`` and ``scores_r2``, each appended
    once per epoch.

    Whenever validation r^2 improves, the weights are kept in memory, the fold's
    validation predictions/observations are written into ``kf_pred_all`` and a
    checkpoint ``./models/checkpoint/<Tox>/<seed>/<net>/<epoch>_<cv>.pth`` is
    saved, replacing the previous best checkpoint of this fold.

    NOTE(review): ``args.epochs_early_stop`` is not consulted; training always
    runs until ``num_epochs``.

    Args:
        model: network to train.
        criterion: loss called as ``criterion(predictions, labels)``.
        optimizer: optimizer a resumed checkpoint is restored into. It is
            replaced by a fresh ``SGD(lr=args.lr)`` at the start of every epoch
            where ``early % args.nb_epochs_training_per_cycle == 0`` (always
            including the first epoch).
        scheduler: StepLR scheduler, rebuilt together with the optimizer.
        num_epochs: training stops before this epoch index.

    Returns:
        ``model`` with the weights of the best-r^2 epoch loaded.
    """
    print("-" * 20)
    print("Strat Training {} {} cv: {}".format('Aquatox', args.Tox, cv + 1))
    print(net)
    print("data_augmentation: {}".format(args.data_augmentation))
    print("batch_size: {}".format(args.batch_size))

    ckpt_dir = "./models/checkpoint/{}/{}/{}".format(args.Tox, args.seed, net)

    def ckpt_path(epoch):
        # Checkpoint file of this fold for the given epoch.
        return "{}/{}_{}.pth".format(ckpt_dir, epoch, cv)

    # -load_dict is parsed with type=list, so an entry is the list of a token's
    # characters ('150' -> ['1', '5', '0']). Join them back so multi-digit
    # epochs are not truncated to their first digit; a missing entry means
    # "train this fold from scratch".
    entry = args.load_dict[cv] if cv < len(args.load_dict) else '0'
    if isinstance(entry, (list, tuple)):
        entry = ''.join(str(ch) for ch in entry)
    load_epoch = int(entry)

    start_epoch = -1
    prev_ckpt = None  # best checkpoint currently on disk; deleted once superseded
    if load_epoch > 0:
        # Resume from a previously saved checkpoint (breakpoint).
        prev_ckpt = ckpt_path(load_epoch)
        checkpoint = torch.load(prev_ckpt, map_location=device)
        print("strat from", checkpoint["epoch"], "epoch")
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        scheduler.load_state_dict(checkpoint['lr_schedule'])
        model.load_state_dict(checkpoint['net'])
        del checkpoint

    print("-" * 20)
    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_r2 = -10000
    early = 0  # epochs since the last validation r^2 improvement
    start_time = time.time()
    for epoch in range(start_epoch + 1, num_epochs):
        time_epoch = time.time()

        # Cyclical learning rate: ``early`` is reset to 0 on every improvement,
        # so the optimizer (lr back to args.lr) and scheduler are rebuilt after
        # each new best epoch and every nb_epochs_training_per_cycle epochs
        # without improvement.
        if early % args.nb_epochs_training_per_cycle == 0:
            optimizer = optim.SGD(model.parameters(), lr=args.lr)
            scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size_lr_decay, gamma=args.drop_factor_lr)

        print("-" * 20)
        print('Epoch {}/{} early:{}'.format(epoch, num_epochs - 1, early))

        # ---- training pass
        model.train()
        epoch_losses = 0.0
        deno = 0.0
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device).float()
            optimizer.zero_grad()
            preds = model(inputs).squeeze(1).float()
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()
            epoch_losses += loss.item() * len(preds)
            deno += len(preds)
        epoch_loss = epoch_losses / deno
        train_loss.append(epoch_loss)
        print('{} Loss: {:.4f} {}'.format('train', epoch_loss, deno))

        # The learning rate only starts decaying after 20 epochs without improvement.
        if early > 20:
            scheduler.step()

        # ---- validation pass (no autograd graph needed)
        model.eval()
        pred = []
        obs = []
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                outputs = model(inputs.to(device))
                pred.extend(outputs.reshape(-1).tolist())
                obs.extend(labels.float().tolist())

        mse = mean_squared_error(obs, pred)
        r2 = r2_score(obs, pred)
        lr = optimizer.param_groups[0]['lr']
        print('val Loss: {:.4f} r^2: {:.4f} lr: {:.4f}'.format(mse, r2, lr))
        lr_decay.append(lr)
        scores_mse.append(mse)
        scores_r2.append(r2)

        if r2 > best_r2:
            best_r2 = r2
            best_model_wts = copy.deepcopy(model.state_dict())
            early = 0

            kf_pred_all.iloc[index_val, 0] = pred
            kf_pred_all.iloc[index_val, 1] = obs

            print('save at:', epoch)
            checkpoint = {
                "net": model.state_dict(),
                'optimizer': optimizer.state_dict(),
                "epoch": epoch,
                "lr_schedule": scheduler.state_dict(),
            }
            os.makedirs(ckpt_dir, exist_ok=True)
            new_ckpt = ckpt_path(epoch)
            torch.save(checkpoint, new_ckpt)
            # Keep only the latest best checkpoint of this fold; never touch a
            # file that was not written (or resumed from) by this run.
            if prev_ckpt is not None and prev_ckpt != new_ckpt and os.path.exists(prev_ckpt):
                os.remove(prev_ckpt)
            prev_ckpt = new_ckpt
        elif r2 < best_r2:
            early += 1

        elapsed = time.time() - time_epoch
        print('Epoch complete in {:.0f}m {:.0f}s'.format(elapsed // 60, elapsed % 60))

    time_elapsed = time.time() - start_time
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best r2: {:4f}'.format(best_r2))
    print("-" * 20)
    print("end")
    print("-" * 20)
    model.load_state_dict(best_model_wts)
    return model

def train_test(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train the final model on the full training set and return its best weights.

    Follows the same schedule as ``train_model`` (cyclical SGD, model selection
    by r^2 on ``dataloaders['val']``) but never resumes from or writes
    checkpoints. Uses module-level state: ``net``, ``dataloaders`` ('train' and
    'val' loaders), ``test_pred`` (overwritten with the predictions and
    observations of every new best epoch) and the history lists ``train_loss``,
    ``lr_decay``, ``scores_mse`` and ``scores_r2``, each appended once per epoch.

    Args:
        model: network to train.
        criterion: loss called as ``criterion(predictions, labels)``.
        optimizer: replaced by a fresh ``SGD(lr=args.lr)`` at the start of every
            epoch where ``early % args.nb_epochs_training_per_cycle == 0``
            (always including the first epoch).
        scheduler: StepLR scheduler, rebuilt together with the optimizer.
        num_epochs: number of epochs to run.

    Returns:
        ``model`` with the weights of the best-r^2 epoch loaded.
    """
    print("-" * 20)
    print("Strat Training Test {} {}".format('Aquatox', args.Tox))
    print(net)
    print("data_augmentation: {}".format(args.data_augmentation))
    print("batch_size: {}".format(args.batch_size))
    print("-" * 20)
    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_r2 = -10000
    early = 0  # epochs since the last r^2 improvement
    start_time = time.time()
    for epoch in range(num_epochs):
        time_epoch = time.time()

        # Cyclical learning rate: rebuilt after each improvement (early == 0)
        # and every nb_epochs_training_per_cycle epochs without improvement.
        if early % args.nb_epochs_training_per_cycle == 0:
            optimizer = optim.SGD(model.parameters(), lr=args.lr)
            scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size_lr_decay, gamma=args.drop_factor_lr)

        print("-" * 20)
        print('Epoch {}/{} early:{}'.format(epoch, num_epochs - 1, early))

        # ---- training pass
        model.train()
        epoch_losses = 0.0
        deno = 0.0
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device).float()
            optimizer.zero_grad()
            preds = model(inputs).squeeze(1).float()
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()
            epoch_losses += loss.item() * len(preds)
            deno += len(preds)
        epoch_loss = epoch_losses / deno
        train_loss.append(epoch_loss)
        print('{} Loss: {:.4f} {}'.format('train', epoch_loss, deno))

        # The learning rate only starts decaying after 20 epochs without improvement.
        if early > 20:
            scheduler.step()

        # ---- evaluation pass (no autograd graph needed)
        model.eval()
        pred = []
        obs = []
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                outputs = model(inputs.to(device))
                pred.extend(outputs.reshape(-1).tolist())
                obs.extend(labels.float().tolist())

        mse = mean_squared_error(obs, pred)
        r2 = r2_score(obs, pred)
        lr = optimizer.param_groups[0]['lr']
        print('val Loss: {:.4f} r^2: {:.4f} lr: {:.4f}'.format(mse, r2, lr))
        lr_decay.append(lr)
        scores_mse.append(mse)
        scores_r2.append(r2)

        if r2 > best_r2:
            best_r2 = r2
            best_model_wts = copy.deepcopy(model.state_dict())
            early = 0

            test_pred.iloc[:, 0] = pred
            test_pred.iloc[:, 1] = obs

            print('save at:', epoch)
        elif r2 < best_r2:
            early += 1

        elapsed = time.time() - time_epoch
        print('Epoch complete in {:.0f}m {:.0f}s'.format(elapsed // 60, elapsed % 60))

    time_elapsed = time.time() - start_time
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best r2: {:4f}'.format(best_r2))
    print("-" * 20)
    print("end")
    print("-" * 20)
    model.load_state_dict(best_model_wts)
    return model

#-----------------------------------------------
# Architectures
#-----------------------------------------------

# Hidden-layer widths of the deep regression head used by the *_extended
# variants and by vgg19_bn.
_DEEP_HEAD = (4096, 1000, 200, 100)

# architecture name -> (torchvision.models constructor name, head widths).
# A width tuple of None means a single Linear(in_features, 1) output layer.
_ARCHITECTURES = {
    "alexnet": ("alexnet", None),
    "alexnet_extended": ("alexnet", _DEEP_HEAD),
    "vgg19_bn": ("vgg19_bn", _DEEP_HEAD),
    "resnet34": ("resnet34", None),
    "resnet34_extended": ("resnet34", _DEEP_HEAD),
    "resnet152": ("resnet152", None),
    "resnet152_extended": ("resnet152", _DEEP_HEAD),
    "densenet201": ("densenet201", None),
    "densenet201_extended": ("densenet201", _DEEP_HEAD),
    "densenet121_extended": ("densenet121", (2048, 1000, 200, 100)),
}


def _regression_head(in_features, hidden):
    """Build a single-output regression head on top of ``in_features`` features.

    ``hidden=None`` (or empty) gives one ``Linear(in_features, 1)``; otherwise
    each width adds ``Linear -> ReLU -> Dropout(0.5)`` before the final
    ``Linear(last_width, 1)``.
    """
    if not hidden:
        return nn.Linear(in_features, 1)
    modules = []
    width = in_features
    for out_features in hidden:
        modules.append(nn.Linear(in_features=width, out_features=out_features, bias=True))
        modules.append(nn.ReLU(inplace=True))
        modules.append(nn.Dropout(p=0.5))
        width = out_features
    modules.append(nn.Linear(in_features=width, out_features=1, bias=True))
    return nn.Sequential(*modules)


def modelselect(net):
    """Build an untrained torchvision backbone with a single-output regression head.

    Args:
        net: architecture name; one of the keys of ``_ARCHITECTURES`` (the
            same names as the module-level ``nets`` list).

    Returns:
        The model, moved to ``device``.

    Raises:
        ValueError: if ``net`` is not a known architecture.
    """
    if net not in _ARCHITECTURES:
        raise ValueError("The selected architecture is not available: {}".format(net))

    ctor_name, hidden = _ARCHITECTURES[net]
    model_ft = getattr(models, ctor_name)(pretrained=False)

    # Head input sizes are read from the backbone instead of being hard-coded.
    if ctor_name.startswith("resnet"):
        model_ft.fc = _regression_head(model_ft.fc.in_features, hidden)
    elif ctor_name.startswith("densenet"):
        model_ft.classifier = _regression_head(model_ft.classifier.in_features, hidden)
    elif ctor_name == "alexnet" and hidden is None:
        # Keep AlexNet's classifier and only swap its last Linear layer.
        model_ft.classifier[6] = nn.Linear(model_ft.classifier[6].in_features, 1)
    elif ctor_name == "alexnet":
        # classifier[0] is Dropout; classifier[1] is the first Linear layer.
        model_ft.classifier = _regression_head(model_ft.classifier[1].in_features, hidden)
    else:  # vgg19_bn: classifier[0] is the first Linear layer
        model_ft.classifier = _regression_head(model_ft.classifier[0].in_features, hidden)

    return model_ft.to(device)

def _make_loader(paths_labels, split):
    """Wrap ``(path, label)`` pairs in a DataLoader using the ``split`` pipeline of ``transforms``."""
    return torch.utils.data.DataLoader(
        ImageFilelist(paths_labels=paths_labels, transform=transforms[split]),
        batch_size=args.batch_size, shuffle=shuffle, num_workers=workers)


def _predict(model, loader):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        ``(pred, obs)``: flat lists of Python floats in loader order; labels are
        cast to float32 first, as in the training loop.
    """
    model.to(device)
    model.eval()
    pred, obs = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            pred.extend(outputs.reshape(-1).tolist())
            obs.extend(labels.float().tolist())
    return pred, obs


# ``net``, ``cv``, ``dataloaders``, ``index_val``, ``kf_pred_all``, ``test_pred``
# and the history lists (scores_r2, lr_decay, scores_mse, train_loss) are
# module-level names read by train_model / train_test.
for net in ["vgg19_bn", "resnet152_extended", "densenet201_extended"]:
    results_dir = "./results/{}/DL/{}/{}".format(args.Tox, args.seed, net)

    #-----------------------------------------------
    # Training: k-fold cross-validation on the training set
    #-----------------------------------------------
    r2_all = []
    lr_all = []
    mse_all = []
    train_loss_all = []
    fold_columns = ["cv{}".format(k + 1) for k in range(args.cv)]
    kf_pred_all = pd.DataFrame(index=range(len(mols_train)), columns=["cv1~{}".format(args.cv), "True"])
    test_score_all = pd.DataFrame(index=list(range(len(mols_test))) + ["r2", "MSE"],
                                  columns=fold_columns + ["True"])
    if args.task == "train":
        index = base_indices
        step = int(len(index) / args.cv)
        for cv in range(args.cv):
            model_ft = modelselect(net)
            optimizer_ft = optim.SGD(model_ft.parameters(), lr=args.lr)
            scores_r2 = []
            lr_decay = []
            scores_mse = []
            train_loss = []
            if cv < args.cv - 1:
                index_train = np.concatenate([index[:cv * step], index[(cv + 1) * step:]], axis=0)
                index_val = index[cv * step:(cv + 1) * step]
            else:
                # The last fold also takes the remainder of the integer division.
                index_train = index[0:cv * step]
                index_val = index[cv * step:]

            paths_labels_train_train = [paths_labels_train[i] for i in index_train]
            paths_labels_train_val = [paths_labels_train[i] for i in index_val]

            dataloaders = {
                'train': _make_loader(paths_labels_train_train, 'train'),
                'val': _make_loader(paths_labels_train_val, 'val'),
                'test': _make_loader(paths_labels_test, 'test'),
            }
            criterion = torch.nn.MSELoss()
            exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=args.step_size_lr_decay, gamma=args.drop_factor_lr)
            model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=args.nb_epochs_training)

            # Score this fold's best model on the external test set.
            pred, obs = _predict(model_ft, dataloaders['test'])
            mse = mean_squared_error(obs, pred)
            r2 = r2_score(obs, pred)
            test_score_all.iloc[:len(pred), cv] = pred
            test_score_all.iloc[:len(obs), args.cv] = obs  # last column: "True"
            test_score_all.loc["r2", fold_columns[cv]] = r2
            test_score_all.loc["MSE", fold_columns[cv]] = mse
            print('test Loss: {:.4f} r^2: {}'.format(mse, r2))

            r2_all.append(scores_r2)
            lr_all.append(lr_decay)
            mse_all.append(scores_mse)
            train_loss_all.append(train_loss)

        os.makedirs("{}/train".format(results_dir), exist_ok=True)
        os.makedirs("{}/test".format(results_dir), exist_ok=True)
        pd.DataFrame(r2_all).T.to_csv("{}/train/r2_all.csv".format(results_dir), index=None)
        pd.DataFrame(lr_all).T.to_csv("{}/train/lr_all.csv".format(results_dir), index=None)
        pd.DataFrame(mse_all).T.to_csv("{}/train/mse_all.csv".format(results_dir), index=None)
        kf_pred_all.to_csv("{}/train/kf_pred_all.csv".format(results_dir), index=None)
        pd.DataFrame(train_loss_all).T.to_csv("{}/train/train_loss_all.csv".format(results_dir), index=None)
        test_score_all.to_csv("{}/test/test_score_all.csv".format(results_dir), index=None)

    #-----------------------------------------------
    # Final model: train on the whole training set, monitor the test set
    #-----------------------------------------------
    scores_r2 = []
    lr_decay = []
    scores_mse = []
    train_loss = []
    test_pred = pd.DataFrame(index=range(len(activities_test)), columns=["pred", "True"])

    # Test images go through the deterministic 'test' pipeline, as in the CV
    # loop above, rather than the randomly augmented 'val' one.
    # NOTE(review): train_test picks its best epoch by r^2 on this same test
    # set, so the reported test score is optimistically biased.
    dataloaders = {
        'train': _make_loader(paths_labels_train, 'train'),
        'val': _make_loader(paths_labels_test, 'test'),
    }

    model_ft = modelselect(net)
    criterion = torch.nn.MSELoss()
    # train_test replaces the optimizer with SGD(lr=args.lr) before its first
    # epoch, so this instance only has to exist.
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=args.lr)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=args.step_size_lr_decay, gamma=args.drop_factor_lr)
    model_ft = train_test(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=args.nb_epochs_training)

    # The -task train branch is not the only path reaching here.
    os.makedirs("{}/test".format(results_dir), exist_ok=True)
    pd.DataFrame(scores_r2).to_csv("{}/test/scores_r2.csv".format(results_dir), index=None)
    pd.DataFrame(lr_decay).to_csv("{}/test/lr_decay.csv".format(results_dir), index=None)
    pd.DataFrame(scores_mse).to_csv("{}/test/scores_mse.csv".format(results_dir), index=None)
    pd.DataFrame(train_loss).to_csv("{}/test/train_loss.csv".format(results_dir), index=None)
    test_pred.to_csv("{}/test/test_pred.csv".format(results_dir), index=None)