In [None]:
import sys

# Make the parent directory importable. The path is relative to the
# notebook's working directory. The membership check keeps the cell
# idempotent: re-running it does not pile up duplicate sys.path entries.
if '../' not in sys.path:
    sys.path.append('../')

In [None]:
import os
from glob import glob

import pickle

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F

from torch_geometric.utils import from_smiles
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Dataset
from torch_geometric.nn.models import DimeNet, DimeNetPlusPlus

from sklearn.preprocessing import StandardScaler

from rdkit import Chem
from rdkit.Chem import AllChem


In [None]:
import yaml

# Read the YAML configuration file into `config`.
# NOTE(review): the file is parsed with yaml.FullLoader; if this config can
# ever come from an untrusted source, consider yaml.safe_load — confirm.
with open("../config/config.yaml", 'r') as file:
    config = yaml.load(file, Loader=yaml.FullLoader)

In [None]:
# Load the train and test tables, using the 'id' column as the row index.
# Paths are relative to the notebook's working directory.
train = pd.read_csv('../data/train.csv', index_col='id')
test = pd.read_csv('../data/test.csv', index_col='id')

In [None]:
# `df` is another name for `train` (no copy is made), so the column
# assignments below also modify `train`.
df = train

# Convert every column selected by the 'Int64' dtype filter to categorical.
for col in df.select_dtypes('Int64').columns:
    df[col] = df[col].astype('category')

# Standardize all float64 columns with a scaler fitted on this same frame.
# NOTE(review): every float64 column is selected, which presumably includes
# the MLM/HLM target columns used as labels later in the notebook — confirm
# that scaling the targets is intended.
float_cols = df.select_dtypes('float64').columns
scaler = StandardScaler()
df[float_cols] = scaler.fit_transform(df[float_cols])

In [None]:
# Build one PyG graph per row of `df`: topology/features from the SMILES
# string, 3D coordinates from an RDKit ETKDG embedding, the molecular
# descriptors as a graph-level attribute, and (MLM, HLM) as the target.
data_list = []
skipped_ids = []  # row ids whose SMILES could not be parsed or embedded

# Descriptor columns stored on each graph, in this fixed order.
chem_feature_cols = ['AlogP', 'Molecular_Weight',
                     'Num_H_Acceptors', 'Num_H_Donors', 'Num_RotatableBonds',
                     'LogD', 'Molecular_PolarSurfaceArea']

for index, row in df.iterrows():
    # Create the RDKit molecule from the SMILES string.
    mol = Chem.MolFromSmiles(row['SMILES'])
    if mol is None:
        # MolFromSmiles returns None (it does not raise) for an invalid SMILES.
        skipped_ids.append(index)
        continue

    # NOTE(review): hydrogens are not added before embedding. If Chem.AddHs is
    # enabled here, `positions` would also contain hydrogen coordinates and
    # would presumably no longer match the node count of from_smiles(...)
    # below — confirm before changing.

    # Generate the 3D structure. EmbedMolecule returns the id of the new
    # conformer, or -1 when embedding fails; in the failure case the molecule
    # has no conformer and GetConformer() would raise, aborting the whole loop.
    conf_id = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
    if conf_id < 0:
        skipped_ids.append(index)
        continue
    conf = mol.GetConformer()
    positions = conf.GetPositions()

    # Build the graph and attach the 3D atom positions.
    data = from_smiles(row['SMILES'])
    data.pos = torch.tensor(positions, dtype=torch.float)

    # Attach the chemical descriptors as a graph-level attribute.
    # NOTE(review): graph_attr and y are 1-D tensors; when several graphs are
    # batched they are presumably concatenated along dim 0 (shape [B*7] and
    # [B*2]) rather than stacked — confirm downstream code reshapes them.
    chem_features = torch.tensor([row[name] for name in chem_feature_cols],
                                 dtype=torch.float)
    data.graph_attr = chem_features

    # Use MLM and HLM as the target labels.
    data.y = torch.tensor([row['MLM'], row['HLM']], dtype=torch.float)

    data_list.append(data)

# Skipped rows mean data_list is shorter than df and no longer aligned
# with it position by position; report them so this is visible.
if skipped_ids:
    print(f"Skipped {len(skipped_ids)} molecule(s) that could not be parsed "
          f"or embedded in 3D: {skipped_ids}")

In [None]:
class CustomDataset(Dataset):
    """In-memory dataset that serves graphs from a pre-built list.

    Args:
        data_list: Sequence of graph objects. It is stored by reference
            (not copied) and indexed directly by ``get``.
    """

    def __init__(self, data_list):
        # Initialize the base class with its defaults. The previous call,
        # super().__init__(self), passed this instance as the base class's
        # first positional parameter (the dataset root), which is not a path.
        super().__init__()
        self.data_list = data_list

    def len(self):
        """Return the number of graphs held by the dataset."""
        return len(self.data_list)

    def get(self, idx):
        """Return the graph stored at position ``idx`` of ``data_list``."""
        return self.data_list[idx]

In [None]:
# Wrap the prepared graphs in the dataset class and iterate them one graph
# per batch.
# NOTE(review): despite its name, `test_loader` iterates the graphs built
# from `df` (the training frame), not the `test` table.
dataset = CustomDataset(data_list)
test_loader = DataLoader(dataset, batch_size=1)

In [None]:
# Fetch the first batch from the loader (creates a fresh iterator each run).
inputs = next(iter(test_loader))