メモ：RMSDを比較することで、良い特徴量だけを選び出す。さらに、それらの良い特徴量だけで計算を行い、全部入りの計算結果と比較して性能差がどれくらいあるか調べる。性能差がそこまで変わらず、計算時間が減ったら嬉しい。KipfらによるSpatial GCNを用いる。

In [1]:
import pandas as pd
import pickle
import rdkit
from rdkit import Chem, RDLogger
from rdkit.Chem import PandasTools
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem.rdchem import HybridizationType
import rdkit.Chem.AllChem as AllChem
import torch
import torch_geometric as pyg
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.utils import one_hot, scatter
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Dataset, InMemoryDataset, download_url, extract_zip
from torch_geometric.datasets import QM9
from torch_geometric.nn import global_add_pool
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn.models import MLP
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import one_hot, scatter
from torch.utils.data import random_split
import matplotlib.pyplot as plt
import japanize_matplotlib
import numpy as np
import time
import pandas as pd
import os
import os.path as osp
import pickle
import sys
import shutil
from typing import Callable, List, Optional
import tqdm
from math import sqrt as sqrt

In [None]:
# Locations of the raw QM9 files (machine-specific absolute paths).
sdf = "/home/higuchi/Pytorch/GCN/QM9/raw/gdb9.sdf"
qm9csv = pd.read_csv("/home/higuchi/Pytorch/GCN/QM9/raw/gdb9.sdf.csv")

# The molecules are read from a pickle cache instead of being parsed from the
# SDF. The parsing code is kept for reference only:
#   mols = rdkit.Chem.SDMolSupplier(sdf, removeHs=False)  # build mol objects from the SDF
#   mols = [m for m in mols if m is not None]
with open("mols_unprocessed", "rb") as f:
    mols = pickle.load(f)

# ETKDG embedding (takes a very long time on the full dataset).
# Hydrogens are added first; EmbedMolecule is then called on each molecule and
# its return value is discarded, so mols_ETKDG holds the very same objects as
# mols_Hs.
mols_Hs = list(map(rdkit.Chem.AddHs, mols))
mols_ETKDG = []
for mol in mols_Hs:
    ETKDG = AllChem.ETKDG()
    AllChem.EmbedMolecule(mol, ETKDG)
    mols_ETKDG.append(mol)

In [None]:
# Featurize every hydrogen-added molecule with DeepChem's Mol2VecFingerprint
# (fast: approx. 2 min).
# NOTE(review): this cell does not generate 3D structures; each molecule is
# converted to a SMILES string and only that string is passed to the featurizer.
import deepchem
featurizer = deepchem.feat.Mol2VecFingerprint()
mols_features = []
for mol in mols_Hs:
    features = featurizer.featurize(rdkit.Chem.MolToSmiles(mol))
    mols_features.append(features)

Setting:
layer=3, hidden_layer=64
epoch_num=100

In [5]:
class GCN(torch.nn.Module):
    """Two-layer GCN that predicts one scalar per graph.

    Pipeline: GCNConv(num_node_features -> 32) -> ReLU -> GCNConv(32 -> 32)
    -> ReLU -> sum pooling over the nodes of each graph -> Linear(32 -> 1).

    Args:
        dataset: object exposing ``num_node_features``; it is also stored on
            the module as ``self.dataset``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.conv1 = GCNConv(self.dataset.num_node_features, 32)
        self.conv2 = GCNConv(32, 32)
        # Not used in forward(); kept so the set of parameters (and therefore
        # the state_dict keys) stays the same as before.
        self.linear1 = nn.Linear(16, 1)
        self.out = nn.Linear(32, 1)

    def forward(self, data):
        """Return the model output for a (batched) graph object.

        Args:
            data: object providing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Output of the final linear layer applied to the pooled node
            features (one row per graph in the batch).
        """
        x, batch, edge_index = data.x, data.batch, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        # Node embeddings must be pooled into one vector per graph before the
        # regression head. The package is imported as `pyg`, so the bare name
        # `torch_geometric` is not bound; use the directly imported function.
        x = global_add_pool(x, batch)
        x = self.out(x)
        return x

class GCN_N(torch.nn.Module):
    """GCN with a configurable number of graph-convolution steps.

    The first step uses ``conv1`` (num_node_features -> dim). Every further
    step applies the single ``convn`` module (dim -> dim) again, so all steps
    after the first share one set of weights. Node embeddings are summed per
    graph and mapped to one scalar by ``out``.

    Args:
        dataset: object exposing ``num_node_features``; stored as ``self.dataset``.
        layer: total number of convolution steps (default 3).
        dim: hidden width of the convolutions (default 64).
    """

    def __init__(self, dataset, layer=3, dim=64):
        super().__init__()
        self.layer, self.dataset, self.dim = layer, dataset, dim
        self.conv1 = GCNConv(dataset.num_node_features, dim, improved=True)
        self.convn = GCNConv(dim, dim, improved=True)
        self.out = pyg.nn.Linear(dim, 1)

    def forward(self, data):
        """Return one output row per graph in ``data``."""
        edge_index = data.edge_index
        hidden = F.relu(self.conv1(data.x, edge_index))
        # `layer - 1` additional steps, all through the shared `convn` module.
        for _ in range(self.layer - 1):
            hidden = F.relu(self.convn(hidden, edge_index))
        pooled = pyg.nn.global_add_pool(hidden, data.batch)
        return self.out(pooled)

class GCN3(torch.nn.Module):
    """Three-layer GCN that predicts one scalar per graph.

    Pipeline: three GCNConv layers, each followed by ReLU, then sum pooling
    over the nodes of each graph and a final linear layer.

    Args:
        num_node_features: input feature width. When omitted, the value is read
            from the module-level ``dataset`` object, which keeps the historic
            ``GCN3()`` call working.
        hidden_layer: width of every hidden layer (default 64).
    """

    def __init__(self, num_node_features=None, hidden_layer=64):
        super(GCN3, self).__init__()
        if num_node_features is None:
            # Backward-compatible fallback to the notebook-global dataset.
            num_node_features = dataset.num_node_features
        self.conv1 = GCNConv(num_node_features, hidden_layer)
        self.conv2 = GCNConv(hidden_layer, hidden_layer)
        self.conv3 = GCNConv(hidden_layer, hidden_layer)
        self.out = nn.Linear(hidden_layer, 1)

    def forward(self, data):
        """Return the model output for a (batched) graph object.

        Args:
            data: object providing ``x``, ``edge_index`` and ``batch``.
        """
        batch, x, edge_index = data.batch, data.x, data.edge_index
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x, edge_index))
        # Sum node embeddings per graph before the regression head.
        x = global_add_pool(x, batch)
        return self.out(x)

In [10]:
def seed_worker(seed):
    """Seed NumPy, Python's ``random`` and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every generator. When this function is
            passed as ``DataLoader(worker_init_fn=...)``, PyTorch calls it with
            the worker id, so that id is what ends up being used as the seed.
    """
    # Imported here because the file's import block does not import `random`;
    # without it the call below raised NameError.
    import random

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Unit-conversion factors. Judging by the names these are Hartree -> eV and
# kcal/mol -> eV — TODO confirm.
HAR2EV = 27.211386246
KCALMOL2EV = 0.04336414

# 19 multiplicative factors, presumably one per regression target column in the
# order used by MyQM9.process() — TODO confirm.
# NOTE(review): `conversion` is not referenced anywhere in the visible part of
# this file, so the targets are used as stored in the CSV.
conversion = torch.tensor([
    1., 1., HAR2EV, HAR2EV, HAR2EV, 1., HAR2EV, HAR2EV, HAR2EV, HAR2EV, HAR2EV,
    1., KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, 1., 1., 1.
])

# Reference values looked up by MyQM9.atomref(): the key is the target index and
# each list holds five values that atomref() writes to rows 1, 6, 7, 8 and 9 of a
# 100-row table (presumably the atomic numbers of H, C, N, O and F — TODO confirm).
atomrefs = {
    6: [0., 0., 0., 0., 0.],
    7: [
        -13.61312172, -1029.86312267, -1485.30251237, -2042.61123593,
        -2713.48485589
    ],
    8: [
        -13.5745904, -1029.82456413, -1485.26398105, -2042.5727046,
        -2713.44632457
    ],
    9: [
        -13.54887564, -1029.79887659, -1485.2382935, -2042.54701705,
        -2713.42063702
    ],
    10: [
        -13.90303183, -1030.25891228, -1485.71166277, -2043.01812778,
        -2713.88796536
    ],
    11: [0., 0., 0., 0., 0.],
}

class MyQM9(InMemoryDataset):
    """QM9 molecules as graphs whose node features are RDKit atom descriptors.

    :meth:`process` builds the dataset from ``qm9_dataset.csv`` (SMILES plus 19
    regression targets) and ``./uncharacterized.txt`` (rows to skip), both
    resolved relative to the current working directory, and caches the result
    as ``data_v3.pt`` in the processed directory.

    One descriptor can be left out of the node features by binding the
    module-level name ``pre_reduce`` to its name before the dataset is
    processed (a falsy or missing value keeps every descriptor). The cache file
    name does not depend on ``pre_reduce``, so the processed directory has to be
    deleted whenever that value changes.
    """

    # Descriptor names in the column order used for ``Data.x``.
    DESCRIPTOR_NAMES = (
        "atomic_number", "formal_charge", "valence", "degree", "aromatic",
        "sp", "sp2", "sp3", "num_hs",
    )

    # CSV columns used as regression targets, in the column order of ``Data.y``.
    TARGET_COLUMNS = (
        "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298",
        "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom",
        "A", "B", "C",
    )

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def _all_targets(self) -> torch.Tensor:
        """Stack the ``y`` rows of every graph in the dataset."""
        return torch.cat([self.get(i).y for i in range(len(self))], dim=0)

    def mean(self, target: int) -> float:
        """Return the mean of target column ``target`` over the whole dataset."""
        return float(self._all_targets()[:, target].mean())

    def std(self, target: int) -> float:
        """Return the standard deviation of target column ``target``."""
        return float(self._all_targets()[:, target].std())

    def atomref(self, target) -> Optional[torch.Tensor]:
        """Return the reference table for ``target`` as a (100, 1) tensor.

        Rows 1, 6, 7, 8 and 9 receive the five values of ``atomrefs[target]``;
        all other rows are zero. Returns ``None`` when ``atomrefs`` has no entry
        for ``target``.
        """
        if target in atomrefs:
            out = torch.zeros(100)
            out[torch.tensor([1, 6, 7, 8, 9])] = torch.tensor(atomrefs[target])
            return out.view(-1, 1)
        return None

    @property
    def raw_file_names(self):
        # No raw files are declared: the inputs are read from the working
        # directory inside process().
        return []

    @property
    def processed_file_names(self):
        return "data_v3.pt"

    def download(self):
        # Nothing to download.
        pass

    def process(self):
        """Build every molecular graph and save the collated dataset.

        Raises:
            KeyError: if ``pre_reduce`` names a descriptor that does not exist,
                or a molecule contains an element other than H, C, N, O, F, or
                a bond type other than single/double/triple/aromatic.
            ValueError: if RDKit cannot parse one of the SMILES strings.
        """
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        # Descriptor to leave out, read once from the notebook-level global
        # (None when the name has not been defined).
        excluded = globals().get("pre_reduce")
        if excluded:
            if excluded not in self.DESCRIPTOR_NAMES:
                raise KeyError(excluded)
            print(excluded)

        # Regression targets.
        df = pd.read_csv("qm9_dataset.csv")
        df_target = df.reindex(columns=list(self.TARGET_COLUMNS))
        target = torch.tensor(df_target.values.tolist(), dtype=torch.float)
        self.target = target

        with open("./uncharacterized.txt") as f:
            # Rows that must be skipped; the file stores 1-based ids between a
            # 9-line header and a 2-line footer. A set keeps the per-molecule
            # membership test O(1).
            skip = {int(x.split()[0]) - 1 for x in f.read().split('\n')[9:-2]}

        data_list = []
        for i, smi in enumerate(tqdm.tqdm(df["smiles"].tolist())):
            if i in skip:
                continue

            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                raise ValueError(f"RDKit could not parse the SMILES in row {i}: {smi!r}")
            mol = Chem.AddHs(mol)

            N = mol.GetNumAtoms()  # number of atoms in the molecule
            atoms = list(mol.GetAtoms())

            # Not part of the features at present; the lookup still rejects
            # elements outside H/C/N/O/F with a KeyError.
            type_idx = [types[atom.GetSymbol()] for atom in atoms]

            hybridizations = [atom.GetHybridization() for atom in atoms]
            desc_dict = {
                "atomic_number": [atom.GetAtomicNum() for atom in atoms],
                "formal_charge": [atom.GetFormalCharge() for atom in atoms],
                "valence": [atom.GetTotalValence() for atom in atoms],
                "degree": [atom.GetTotalDegree() for atom in atoms],
                "aromatic": [1 if atom.GetIsAromatic() else 0 for atom in atoms],
                "sp": [1 if h == HybridizationType.SP else 0 for h in hybridizations],
                "sp2": [1 if h == HybridizationType.SP2 else 0 for h in hybridizations],
                "sp3": [1 if h == HybridizationType.SP3 else 0 for h in hybridizations],
                "num_hs": [atom.GetTotalNumHs(includeNeighbors=True) for atom in atoms],
            }

            # Each bond is stored in both directions.
            row, col, edge_type = [], [], []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                row += [start, end]
                col += [end, start]
                edge_type += 2 * [bonds[bond.GetBondType()]]

            edge_index = torch.tensor([row, col], dtype=torch.long)
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))
            # Sort the edges by (source, destination).
            perm = (edge_index[0] * N + edge_index[1]).argsort()
            edge_index = edge_index[:, perm]
            edge_attr = edge_attr[perm]

            # Select the columns by NAME. Removing a column with list.remove()
            # compares by value and would drop the first identical column
            # (for example an all-zero formal_charge instead of an all-zero
            # sp), which shifts the feature columns from molecule to molecule.
            columns = [values for name, values in desc_dict.items() if name != excluded]
            x = torch.tensor(columns, dtype=torch.float).t().contiguous()

            y = target[i].unsqueeze(0)
            mol_smiles = rdkit.Chem.MolToSmiles(mol, isomericSmiles=True)
            data = Data(x=x, edge_index=edge_index, smiles=mol_smiles,
                        edge_attr=edge_attr, y=y, idx=i)
            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])

In [11]:
# https://discuss.pytorch.org/t/rmse-loss-function/16540/3
class RMSELoss(nn.Module):
    """Square root of an MSE-style loss, stabilised by a small epsilon.

    Args:
        eps: constant added under the square root so the gradient stays finite
            when the error is exactly zero.
        reduction: ``"sum"`` (default, the historic behaviour) returns
            ``sqrt(sum of squared errors + eps)``, which grows with the number
            of elements and is therefore NOT a root-mean-square error.
            ``"mean"`` returns ``sqrt(mean squared error + eps)``, i.e. the
            actual RMSE of the batch.

    Raises:
        ValueError: if ``reduction`` is neither ``"sum"`` nor ``"mean"``.
    """

    def __init__(self, eps=1e-6, reduction="sum"):
        super().__init__()
        if reduction not in ("sum", "mean"):
            raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")
        self.mse = nn.MSELoss(reduction=reduction)
        self.eps = eps

    def forward(self, yhat, y):
        """Return ``sqrt(reduced squared error + eps)`` as a scalar tensor."""
        return torch.sqrt(self.mse(yhat, y) + self.eps)


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` works on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation loss and update the stopping state.

        The first call, and every call whose loss improves on the best one by
        at least ``delta``, saves a checkpoint and resets the counter. Any
        other call increments the counter and sets ``early_stop`` once the
        counter reaches ``patience``.
        """
        # Higher score is better, so the loss is negated.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [12]:
device = "cuda"
rmse = RMSELoss()
def train(initial_epoch_num=None, early_stopping=None):
    """Train the module-level ``model`` for ``epoch_num`` epochs.

    The function works entirely on notebook globals: ``model``, ``optimizer``,
    ``train_loader``, ``valid_loader``, ``epoch_num``, ``target_idx``,
    ``device`` and the loss module ``rmse``.

    Args:
        initial_epoch_num: number of epochs already trained. A positive value
            offsets the reported epoch numbers (used when resuming); ``None``,
            zero or a negative value means no offset.
        early_stopping: optional callable invoked as
            ``early_stopping(valid_loss, model)`` after every epoch. Training
            stops as soon as its ``early_stop`` attribute is true.

    Returns:
        List with one dict per epoch:
        ``{"Epoch": ..., "train_loss": ..., "valid_loss": ...}``.

    NOTE(review): each reported loss is the sum of the per-batch ``rmse`` values
    divided by the number of graphs seen. Because ``RMSELoss()`` is built with
    its default settings, which take the root of a summed squared error, this
    value is not a dataset-level RMSE and it depends on the batch size.
    """
    def _average(total, count):
        # An empty loader yields NaN instead of raising ZeroDivisionError.
        return total / count if count else float("nan")

    def _run_epoch(loader, training):
        # One pass over `loader`; returns summed batch loss / number of graphs.
        loss_sum, graph_count = 0.0, 0
        for batch in loader:
            batch = batch.to(device)
            if training:
                optimizer.zero_grad()
            prediction = model(batch)
            loss = rmse(prediction, batch.y[:, target_idx].unsqueeze(1))
            if training:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item()
            graph_count += batch.num_graphs
        return _average(loss_sum, graph_count)

    # `None > 0` raises TypeError, so normalise the default first.
    offset = initial_epoch_num if initial_epoch_num and initial_epoch_num > 0 else 0

    results = []
    for epoch in range(offset, offset + epoch_num):
        # train
        model.train()
        train_loss = _run_epoch(train_loader, training=True)

        # validation (no gradients needed, so no autograd graph is built)
        model.eval()
        with torch.no_grad():
            valid_loss = _run_epoch(valid_loader, training=False)

        print(f"Epoch {epoch+1} | train_loss:{train_loss}, valid_loss:{valid_loss}")
        results.append({"Epoch":epoch+1, "train_loss":train_loss, "valid_loss":valid_loss})

        if early_stopping is not None:
            early_stopping(valid_loss, model)
            if early_stopping.early_stop:
                break
    return results


In [None]:
# Exploratory overview plots of the QM9 CSV with AutoViz.
from autoviz.AutoViz_Class import AutoViz_Class
# Make the "seaborn-v0_8" style available under the name "seaborn"
# (presumably because AutoViz requests the old style name — TODO confirm).
plt.style.library["seaborn"] = plt.style.library["seaborn-v0_8"]
vizdf = AutoViz_Class()
filename = "./qm9_dataset.csv"
# An empty depVar is passed, i.e. no dependent variable is named.
graph = vizdf.AutoViz(
    filename,
    depVar=""
)

In [None]:
# Histogram of the "alpha" column of the QM9 CSV, restricted to the range 0-200.
df = pd.read_csv("./qm9_dataset.csv")
plt.hist(df["alpha"].tolist(), range=(0,200))

In [13]:
# Leave-one-descriptor-out experiment: for every entry of `descriptors` the
# dataset is rebuilt without that descriptor (None keeps all of them), a GCN3
# is trained and the loss history, elapsed time and weights are saved.
descriptors = [None, "atomic_number", "aromatic", "sp", "sp2", "sp3", "num_hs"]
device = "cuda"

layer = 3
dim = 64
epoch_num = 100
target_idx = 1
batch_size = 32

# When True, a run whose result file AND model file both exist is resumed and
# the new epochs are appended to the old history.
add_to_old_file = True
for pre_reduce in descriptors:
    # NOTE: the loop variable must be called `pre_reduce`; MyQM9.process()
    # reads that module-level name to decide which descriptor to drop.
    filepath = f"./results/1031/GCN_without_{pre_reduce}"
    model_path = filepath + "_model"
    # Create the output directory up front so the run is not lost at save time.
    os.makedirs(os.path.dirname(filepath), exist_ok=True)

    # The processed cache does not depend on `pre_reduce`, so it has to be
    # deleted to force reprocessing.
    try:
        shutil.rmtree("./QM9")
    except FileNotFoundError:
        pass
    dataset = MyQM9(root="./QM9")
    num_train = int(len(dataset)*0.8)
    num_val = len(dataset) - num_train

    # Fixed split seed: every descriptor is evaluated on the same split, and a
    # resumed run keeps validating on data it has never trained on.
    split_generator = torch.Generator().manual_seed(0)
    train_set, valid_set = random_split(dataset, [num_train, num_val], generator=split_generator)
    # Build the data loaders.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)

    # Reset the resume state for every descriptor so nothing leaks over from
    # the previous loop iteration.
    results_old = []
    time_old = 0.0
    initial_epoch_num = 0
    old_file_exists = (
        add_to_old_file and os.path.isfile(filepath) and os.path.isfile(model_path)
    )

    model = GCN3().to(device)
    # Initialise the optimizer.
    # NOTE(review): the optimizer state is not checkpointed, so a resumed run
    # restarts Adam from scratch.
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01, weight_decay=5e-4)

    if old_file_exists:
        with open(filepath, "rb") as f:
            results_dict_old = pickle.load(f)
        results_old = results_dict_old["results"]
        time_old = results_dict_old["time"]
        if results_old:
            initial_epoch_num = int(results_old[-1]["Epoch"])
        model.load_state_dict(torch.load(model_path))
        print("loaded old model")
    else:
        print("using brand new model")

    start = time.time()
    print(initial_epoch_num)
    results = train(initial_epoch_num=initial_epoch_num) #RMSE
    end = time.time()
    diff = end - start

    if old_file_exists:
        results = results_old + results
        diff += time_old
    results_dict = {"results":results, "time":diff}

    with open(filepath, "wb") as f:
        pickle.dump(results_dict, f)
    torch.save(model.state_dict(), model_path)

None


Processing...
  3%|▎         | 4016/133885 [00:00<00:20, 6396.80it/s]


KeyboardInterrupt: 

In [22]:
# Rebind the module-level `pre_reduce` (the name MyQM9.process() reads) to None
# and show the result-file name that this value produces.
pre_reduce = None
print(f"./results/GCN_without_{pre_reduce}")

./results/GCN_without_None


In [None]:
# Scratch check: how None is rendered as text (it appears in the result-file names).
str(None)

In [None]:
# Scratch check: number of batches in the validation loader.
len(valid_loader)

In [None]:
# Scratch check: node-feature matrix of the first graph in the dataset.
dataset[0].x

In [None]:
# NOTE(review): `res` is not assigned anywhere in the visible part of this file;
# unless another cell defines it, evaluating this raises NameError.
res

In [None]:
# Scratch check: display the current model object.
model

In [None]:
# Load (or build) the dataset stored under ./QM9_reduced/ and create a GCN_N
# with 3 convolution steps and hidden width 64 on the GPU.
dataset = MyQM9(root="./QM9_reduced/")
layer=3
dim=64
device="cuda"
model = GCN_N(dataset=dataset, layer=layer, dim=dim).to(device)

In [None]:
#消す記述子の選択
descriptors = ["atomic_number", "aromatic", "sp", "sp2", "sp3", "num_hs"]
for desc in descriptors:
    print(desc)
    try:
        shutil.rmtree("./QM9_reduced/processed/")
    except FileNotFoundError:
        pass
    dataset_reduced = MyQM9(root="./QM9_reduced")
    dataset = dataset_reduced #必ずチェック！！
    print(dataset[0].x.shape)

    #データの分割(total: 130831)
    num_train = int(len(dataset)*0.8)
    num_val = len(dataset) - num_train
    num_test = 0
    batch_size = 64

    # 乱数の固定
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    g = torch.Generator()
    train_set, valid_set, test_set = random_split(dataset, [num_train, num_val, num_test], g.manual_seed(0))

    #Dataloaderの生成
    train_loader = DataLoader(train_set, batch_size=batch_size, worker_init_fn=seed_worker)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, worker_init_fn=seed_worker)
    test_loader = DataLoader(test_set, batch_size=batch_size, worker_init_fn=seed_worker)
    
    device = "cuda"
    
    layer = 3
    dim = 64
    epoch_num = 100
    target_idx = 1
    mse = F.mse_loss

    start = time.time()
    results = train(mse, dataset=dataset) #RMSE
    end = time.time()
    diff = end-start

    results = {"results":results, "time":diff}

    with open(f"./results/GCN_without_{desc}", "wb") as f:
        pickle.dump(results, f)

In [None]:
import glob
def plot_train(filepath):
    """Plot the per-epoch training loss stored in the pickled result file.

    The file must contain a dict whose ``"results"`` entry is a list of
    per-epoch dicts with a ``"train_loss"`` key. The curve is labelled
    ``training_<file name>``.
    """
    with open(filepath, "rb") as f:
        history = pickle.load(f)
    train_losses = [epoch_entry["train_loss"] for epoch_entry in history["results"]]
    plt.plot(train_losses, label=f"training_{os.path.basename(filepath)}")

def plot_valid(filepath):
    """Plot the per-epoch validation loss stored in the pickled result file.

    The file must contain a dict whose ``"results"`` entry is a list of
    per-epoch dicts with a ``"valid_loss"`` key. The curve is labelled
    ``valid_<file name>``.
    """
    with open(filepath, "rb") as f:
        history = pickle.load(f)
    valid_losses = [epoch_entry["valid_loss"] for epoch_entry in history["results"]]
    plt.plot(valid_losses, label=f"valid_{os.path.basename(filepath)}")

def _load_loss_curves(filepath, key):
    """Return one per-epoch list of ``key`` values for every run file matching ``{filepath}_*``."""
    curves = []
    for file in sorted(glob.glob(f"{filepath}_*")):
        # Weight checkpoints are saved next to the result files with a
        # "_model" suffix; they are not result pickles, so skip them.
        if file.endswith("_model"):
            continue
        with open(file, "rb") as f:
            result = pickle.load(f)
        curves.append([i[key] for i in result["results"]])
    if not curves:
        raise FileNotFoundError(f"no result files match {filepath}_*")
    return curves


def _plot_mean_with_range(curves, label=None):
    """Plot the mean curve with error bars that reach the per-epoch min and max."""
    # Runs may have different lengths; only the epochs present in every run
    # can be averaged.
    n_epochs = min(len(curve) for curve in curves)
    stacked = np.asarray([curve[:n_epochs] for curve in curves], dtype=float)
    result_avg = stacked.mean(axis=0)
    result_error_lower = result_avg - stacked.min(axis=0)
    result_error_upper = stacked.max(axis=0) - result_avg
    # Matplotlib expects yerr as [lower errors, upper errors], in that order.
    plt.errorbar(x=np.arange(n_epochs), y=result_avg,
                 yerr=[result_error_lower, result_error_upper], label=label)


def plot_train_errorbar(filepath, label=None):
    """Plot the mean training loss over all runs saved as ``{filepath}_*``.

    Args:
        filepath: common prefix of the pickled result files.
        label: legend label of the curve.

    Raises:
        FileNotFoundError: if no result file matches the prefix.
    """
    _plot_mean_with_range(_load_loss_curves(filepath, "train_loss"), label=label)


def plot_valid_errorbar(filepath, label=None):
    """Plot the mean validation loss over all runs saved as ``{filepath}_*``.

    Args:
        filepath: common prefix of the pickled result files.
        label: legend label of the curve.

    Raises:
        FileNotFoundError: if no result file matches the prefix.
    """
    _plot_mean_with_range(_load_loss_curves(filepath, "valid_loss"), label=label)

In [None]:
# Overlay the training curves (with error bars) of the runs stored under
# ./results/test/, one curve per entry of `descriptors`.
# NOTE(review): the training cells in this file write to ./results/1031/ and
# ./results/, not to ./results/test/ — confirm the files were copied there.
for desc in descriptors:
    plot_train_errorbar(f"./results/test/GCN_without_{desc}", label=desc)
# Restrict the y-axis to the 1.0-2.0 range.
plt.ylim([1.0,2.0])
plt.legend()

In [None]:
def _load_result(name):
    """Return the unpickled result dict stored as ``./results/<name>``."""
    with open(f"./results/{name}", "rb") as f:
        return pickle.load(f)


# One result dict per experiment. Previously the num_hs file was also loaded
# into `result_sp3`, which silently replaced the sp3 results.
# NOTE(review): no cell in the visible part of this file writes
# ./results/GCN_all — confirm where that file comes from.
result_all = _load_result("GCN_all")
result_atomic_number = _load_result("GCN_without_atomic_number")
result_aromatic = _load_result("GCN_without_aromatic")
result_sp = _load_result("GCN_without_sp")
result_sp2 = _load_result("GCN_without_sp2")
result_sp3 = _load_result("GCN_without_sp3")
result_num_hs = _load_result("GCN_without_num_hs")



In [None]:
# Mean training curve, with min/max error bars, over every run whose file name
# starts with ./results/test/GCN_without_aromatic.
files = glob.glob("./results/test/GCN_without_aromatic*")
result_list = []
for file in files:
    with open(file, "rb") as f:
        result = pickle.load(f)
    result_list.append([i["train_loss"] for i in result["results"]])
result_avg = np.mean(result_list, axis=0)
result_error_upper = np.max(result_list, axis=0) - result_avg
result_error_lower = result_avg - np.min(result_list, axis=0)
# Matplotlib expects yerr as [lower errors, upper errors], in that order.
plt.errorbar(x=np.arange(len(result_avg)), y=result_avg, yerr=[result_error_lower, result_error_upper])

x: ノード特徴量

y: 正解ラベル

pos: 原子の座標

edge_index: エッジインデックス

edge_attr: エッジ特徴量(使えん)

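ノード特徴量（MyQM9.process で x に入る列。pre_reduce で指定した1つは除かれる）
atomic_number: 原子番号
formal_charge: 形式電荷
valence: 原子価
degree: 次数（隣接原子の数）
aromatic: 芳香性
sp: sp混成
sp2: sp2混成
sp3: sp3混成
num_hs: 結合している水素の数
type_idx: 原子の種類（計算はしているが、現在の x には含めていない）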

In [None]:
# Draw one molecule of the dataset as a graph, with node size proportional to PageRank.
import networkx
from matplotlib import pyplot as plt
import numpy as np
from torch_geometric.utils import to_networkx

data = dataset[100]
nxg = to_networkx(data)

# Element symbols for the node labels, keyed by atomic number.
elements = {
    1:"H",
    2:"He",
    3:"Li",
    4:"Be",
    5:"B",
    6:"C",
    7:"N",
    8:"O",
    9:"F"
}
elem_labels = {}
# NOTE(review): this reads `data.z`, but MyQM9.process() only stores x,
# edge_index, smiles, edge_attr, y and idx. The cell presumably expects a
# dataset that provides `z` (atomic numbers) — confirm which dataset is loaded.
for i in range(data.num_nodes):
    elem = elements[int(data.z[i])]
    elem_labels[i] = elem

pagerank = networkx.pagerank(nxg) # PageRank is a node-centrality (importance) measure
pagerank_max = np.array(list(pagerank.values())).max()

# Node positions used for drawing (fixed seed, so the layout is reproducible).
draw_position = networkx.spring_layout(nxg,seed=0)

plt.figure(figsize=(10,10))

# Node sizes are scaled so that the highest-ranked node gets size 1000.
networkx.draw_networkx_nodes(nxg,
                            draw_position,
                            node_size=[v / pagerank_max * 1000 for v in pagerank.values()]
                            )

networkx.draw_networkx_edges(nxg, draw_position, arrowstyle='-', alpha=0.2)
networkx.draw_networkx_labels(nxg, draw_position, elem_labels, font_size=10)
plt.show()