In [None]:
import ast
import os
from math import sqrt

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Libraries needed to compute molecular properties (RDKit descriptors)
from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, Lipinski, rdMolDescriptors, Fragments

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset pairing feature vectors with regression targets.

    Args:
        vector_data: array-like of feature rows (converted with ``np.array``).
        targets: array-like of targets, indexed in parallel with ``vector_data``.

    Each item is a ``(features, target)`` pair of float32 tensors.
    """

    def __init__(self, vector_data, targets):
        self.data, self.targets = np.array(vector_data), np.array(targets)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        def as_float(value):
            return torch.tensor(value, dtype=torch.float32)

        features = as_float(self.data[idx])
        label = as_float(self.targets[idx])
        return features, label

In [None]:
class CombinedDataset(Dataset):
    """Multimodal regression dataset loaded from a CSV file.

    Columns read from the CSV:
        'Smiles'                -- SMILES string of each molecule.
        'smiles_feature_vector' -- string representation of a Python list of numbers.
        'image_feature_vector'  -- string representation of a Python list of numbers.
        'target_protein_vector' -- string representation of a Python list of numbers.
        'Standard Value'        -- regression target.

    ``__getitem__`` returns ``(smiles_vector, image_vector, protein_vector,
    properties_vector, target)`` as float32 tensors. ``properties_vector`` holds
    the descriptors from :meth:`calculate_properties`; ``target`` is 0-d.

    Args:
        csv_file: path or buffer accepted by ``pd.read_csv``.
    """

    def __init__(self, csv_file):
        self.df = pd.read_csv(csv_file)
        self.smiles_data = self.df['Smiles'].tolist()
        # Vectors are stored as string representations of lists. Parse them with
        # ast.literal_eval instead of eval: eval would execute any Python code
        # that ended up in a CSV cell.
        self.smiles_features = [self._parse_vector(v) for v in self.df['smiles_feature_vector']]
        self.image_features = [self._parse_vector(v) for v in self.df['image_feature_vector']]
        self.protein_features = [self._parse_vector(v) for v in self.df['target_protein_vector']]
        self.targets = self.df['Standard Value'].values
        # Descriptors are computed once up front, one row per molecule.
        self.properties_features = np.array([self.calculate_properties(smiles) for smiles in self.smiles_data])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        smiles_vector = torch.tensor(self.smiles_features[idx], dtype=torch.float32)
        image_vector = torch.tensor(self.image_features[idx], dtype=torch.float32)
        protein_vector = torch.tensor(self.protein_features[idx], dtype=torch.float32)
        properties_vector = torch.tensor(self.properties_features[idx], dtype=torch.float32)
        target = torch.tensor(self.targets[idx], dtype=torch.float32)
        return smiles_vector, image_vector, protein_vector, properties_vector, target

    @staticmethod
    def _parse_vector(text):
        """Convert a list literal such as ``'[0.1, 0.2]'`` into a numpy array.

        Raises:
            ValueError / SyntaxError: if ``text`` is not a plain Python literal
                (e.g. it contains a function call or arbitrary expression).
        """
        return np.array(ast.literal_eval(text))

    @staticmethod
    def calculate_properties(smiles):
        """Compute 18 RDKit molecular descriptors for one SMILES string.

        Each value is rounded to 6 decimal places, and the order is fixed.

        Raises:
            ValueError: if RDKit cannot parse ``smiles``.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles returns None on a parse failure. Without this check,
            # the first descriptor call fails with an opaque RDKit error.
            raise ValueError(f'RDKit could not parse SMILES: {smiles!r}')

        descriptor_fns = (
            Descriptors.MolWt,                              # Molecular weight
            Crippen.MolLogP,                                # LogP (Crippen method)
            Descriptors.TPSA,                               # Topological polar surface area
            Lipinski.NumHAcceptors,                         # Number of H-bond acceptors
            Lipinski.NumHDonors,                            # Number of H-bond donors
            Lipinski.NumRotatableBonds,                     # Number of rotatable bonds
            Chem.GetFormalCharge,                           # Formal charge of the molecule
            rdMolDescriptors.CalcNumAtomStereoCenters,      # Number of atom stereocenters
            rdMolDescriptors.CalcFractionCSP3,              # Fraction of sp3 carbon atoms
            Descriptors.NumAliphaticCarbocycles,            # Number of aliphatic carbocycles
            Descriptors.NumAromaticRings,                   # Number of aromatic rings
            Descriptors.NumHeteroatoms,                     # Number of heteroatoms
            Fragments.fr_COO,                               # Number of carboxylic acid groups
            Fragments.fr_Al_OH,                             # Number of aliphatic alcohol groups
            Fragments.fr_alkyl_halide,                      # Number of alkyl halide groups
            Descriptors.NumAromaticCarbocycles,             # Number of aromatic carbocycles
            Fragments.fr_piperdine,                         # Number of piperidine groups
            Fragments.fr_methoxy,                           # Number of methoxy groups
        )
        return [round(fn(mol), 6) for fn in descriptor_fns]

In [None]:
class Molecular_Properties_Model(nn.Module):
    """MLP that embeds a molecular-property vector.

    Architecture: three Linear -> BatchNorm1d -> ReLU blocks (128, 64, 32 units),
    then a final Linear -> BatchNorm1d projection to ``output_dim``.

    Submodules keep the names fc1..fc4, bn1..bn4 and relu1..relu3, registered in
    that block order, so state_dict keys are unchanged.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        widths = [input_dim, 128, 64, 32]
        for n, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f'fc{n}', nn.Linear(d_in, d_out))
            setattr(self, f'bn{n}', nn.BatchNorm1d(d_out))
            setattr(self, f'relu{n}', nn.ReLU())
        # Output projection: normalized, but no activation.
        self.fc4 = nn.Linear(32, output_dim)
        self.bn4 = nn.BatchNorm1d(output_dim)

    def forward(self, x):
        hidden_blocks = (
            (self.fc1, self.bn1, self.relu1),
            (self.fc2, self.bn2, self.relu2),
            (self.fc3, self.bn3, self.relu3),
        )
        for linear, norm, activation in hidden_blocks:
            x = activation(norm(linear(x)))
        return self.bn4(self.fc4(x))

In [None]:
class RegressionModel(nn.Module):
    """MLP regressor with a linear output head.

    Architecture: three Linear -> BatchNorm1d -> ReLU -> Dropout blocks
    (512, 256, 128 units; dropout 0.3, 0.2, 0.1), then Linear -> ``output_dim``.

    Submodule names (fc1..fc4, bn1..bn3, relu1..relu3, drop1..drop3) and their
    registration order are preserved, so state_dict keys are unchanged.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # (in_features, out_features, dropout probability) for each hidden block
        block_spec = ((input_dim, 512, 0.3), (512, 256, 0.2), (256, 128, 0.1))
        for n, (d_in, d_out, p_drop) in enumerate(block_spec, start=1):
            setattr(self, f'fc{n}', nn.Linear(d_in, d_out))
            setattr(self, f'bn{n}', nn.BatchNorm1d(d_out))
            setattr(self, f'relu{n}', nn.ReLU())
            setattr(self, f'drop{n}', nn.Dropout(p=p_drop))
        # Output layer: raw regression values, no normalization or activation.
        self.fc4 = nn.Linear(128, output_dim)

    def forward(self, x):
        for linear, norm, activation, dropout in (
            (self.fc1, self.bn1, self.relu1, self.drop1),
            (self.fc2, self.bn2, self.relu2, self.drop2),
            (self.fc3, self.bn3, self.relu3, self.drop3),
        ):
            x = dropout(activation(norm(linear(x))))
        return self.fc4(x)


In [None]:
def train_and_validate(device, 
                       train_loader, 
                       val_loader,                       
                       molecular_model, 
                       regression_model,
                       num_epochs, 
                       lr,
                       save_path,
                       save_interval
                      ):
    """Jointly train the property encoder and the regressor, validating each epoch.

    Each training batch is unpacked as ``(smiles_vec, img_vec, prot_vec,
    prop_vec, targets)``. ``prop_vec`` goes through ``molecular_model``. Its
    output is concatenated with the other three vectors along dim 1 and fed to
    ``regression_model``. Both models use separate AdamW optimizers with the
    same learning rate.

    Args:
        device: torch device that models and batches are placed on.
        train_loader: iterable of training batches (must be non-empty).
        val_loader: iterable of validation batches, passed to ``validate``.
        molecular_model: encoder applied to the property vector.
        regression_model: regressor applied to the concatenated features.
        num_epochs: number of passes over ``train_loader``.
        lr: learning rate for both optimizers.
        save_path: directory for checkpoints; created if missing.
        save_interval: save a checkpoint every ``save_interval`` epochs.

    Returns:
        pd.DataFrame with columns 'train_loss' and 'val_loss', one row per epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if len(train_loader) == 0:
        raise ValueError('train_loader is empty; nothing to train on')

    molecular_optimizer = optim.AdamW(molecular_model.parameters(), lr)
    regression_optimizer = optim.AdamW(regression_model.parameters(), lr)
    loss_fn = nn.MSELoss()

    # torch.save does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        molecular_model.train()
        regression_model.train()
        train_loss = 0.0

        for smiles_vec, img_vec, prot_vec, prop_vec, targets in train_loader:
            smiles_vec, img_vec, prot_vec, prop_vec, targets = smiles_vec.to(device), img_vec.to(device), prot_vec.to(device), prop_vec.to(device), targets.to(device)
            molecular_optimizer.zero_grad()
            regression_optimizer.zero_grad()

            molecular_output = molecular_model(prop_vec)
            combined_features = torch.cat([smiles_vec, img_vec, prot_vec, molecular_output], dim=1)

            predictions = regression_model(combined_features)
            # With CombinedDataset, targets are batched as (batch,) while predictions
            # are (batch, 1). MSELoss would broadcast them to (batch, batch) and
            # average the wrong pairs, so align the shapes explicitly.
            loss = loss_fn(predictions, targets.view_as(predictions))
            loss.backward()
            molecular_optimizer.step()
            regression_optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)

        # Keyword arguments: validate's parameter order differs from this function's.
        val_loss = validate(molecular_model=molecular_model,
                            regression_model=regression_model,
                            val_loader=val_loader,
                            device=device,
                            loss_fn=loss_fn)
        history['val_loss'].append(val_loss)

        print(f'Epoch {epoch + 1}: Training Loss = {train_loss:.4f}, Validation Loss = {val_loss:.4f}')

        if (epoch + 1) % save_interval == 0:
            checkpoint_path = os.path.join(save_path, f'checkpoint_epoch_{epoch + 1}.pt')
            torch.save({
                'epoch': epoch,  # 0-based epoch index
                'molecular_model_state_dict': molecular_model.state_dict(),
                'regression_model_state_dict': regression_model.state_dict(),
                'molecular_optimizer_state_dict': molecular_optimizer.state_dict(),
                'regression_optimizer_state_dict': regression_optimizer.state_dict(),
                'loss': train_loss,
                'val_loss': val_loss,
            }, checkpoint_path)
            print(f'save checkpoint, {checkpoint_path}')

    print('train complete')

    return pd.DataFrame(history)

def validate(molecular_model, regression_model, val_loader, device, loss_fn):
    """Return the mean per-batch loss of the two-stage model on ``val_loader``.

    Batches are unpacked as ``(smiles_vec, img_vec, prot_vec, prop_vec, targets)``.
    Features are concatenated in the same order as in the training loop:
    ``[smiles, image, protein, molecular_model(prop)]``. Both models are put in
    eval mode, and no gradients are tracked.

    Args:
        molecular_model: encoder applied to the property vector.
        regression_model: regressor applied to the concatenated features.
        val_loader: iterable of validation batches supporting ``len()``.
        device: torch device that batches are moved to.
        loss_fn: loss callable ``(predictions, targets) -> scalar tensor``.

    Returns:
        float: sum of per-batch losses divided by the number of batches.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    num_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError('val_loader is empty; cannot compute validation loss')

    molecular_model.eval()
    regression_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for smiles_vec, img_vec, prot_vec, prop_vec, targets in val_loader:
            smiles_vec, img_vec, prot_vec, prop_vec, targets = (
                t.to(device) for t in (smiles_vec, img_vec, prot_vec, prop_vec, targets)
            )
            molecular_output = molecular_model(prop_vec)
            combined_features = torch.cat([smiles_vec, img_vec, prot_vec, molecular_output], dim=1)
            predictions = regression_model(combined_features)
            # Align (batch,) targets with (batch, 1) predictions to avoid broadcasting.
            loss = loss_fn(predictions, targets.view_as(predictions))
            val_loss += loss.item()

    return val_loss / num_batches

In [None]:
# Build the train/validation datasets from the preprocessed CSVs.
file_path = os.path.join('..', 'train_data')
train_file = os.path.join(file_path, 'train_data_img_smi_protein.csv')
val_file = os.path.join(file_path, 'val_data_img_smi_protein.csv')

# CombinedDataset is the dataset class defined above. It yields
# (smiles, image, protein, properties, target) tuples, which is what the
# training loop unpacks.
train_dataset = CombinedDataset(train_file)
val_dataset = CombinedDataset(val_file)

In [None]:
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Derive layer widths from the data instead of hard-coding them. The property
# encoder consumes the descriptor vector. The regressor consumes
# [smiles, image, protein, molecular embedding] concatenated along dim 1,
# as in the training loop.
MOLECULAR_EMBED_DIM = 32
sample_smiles, sample_image, sample_protein, sample_properties, _ = train_dataset[0]
properties_dim = sample_properties.numel()
regression_input_dim = (sample_smiles.numel() + sample_image.numel()
                        + sample_protein.numel() + MOLECULAR_EMBED_DIM)

molecular_model = Molecular_Properties_Model(input_dim=properties_dim, output_dim=MOLECULAR_EMBED_DIM).to(device)
regression_model = RegressionModel(input_dim=regression_input_dim, output_dim=1).to(device)

In [None]:
# drop_last=True for training: both models use BatchNorm1d, which raises in
# training mode on a batch with a single sample. That happens for the final
# batch whenever len(train_dataset) % 32 == 1. Validation runs in eval mode,
# so every sample is kept there.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [None]:
# Training hyperparameters and checkpoint location.
num_epochs, save_interval = 200, 10   # total epochs; checkpoint every `save_interval` epochs
lr = 1e-3                             # learning rate shared by both AdamW optimizers
save_path = '../ckpts/regression'     # checkpoint directory

In [None]:
history_df = train_and_validate(device=device,
                                train_loader=train_loader,
                                val_loader=val_loader,
                                molecular_model=molecular_model,
                                regression_model=regression_model,
                                num_epochs=num_epochs,
                                lr=lr,
                                save_path=save_path,
                                save_interval=save_interval)

# The returned frame has a default 0-based index, while the printed log and
# checkpoint file names count epochs from 1. Shift the index so the saved
# 'Epoch' column matches them.
history_df.index = history_df.index + 1
history_df.to_csv('training_history.csv', index_label='Epoch')