# Custom CGCNN (Crystal Graph Convolutional Neural Network) with Sample Regression

1. Refactor the original CGCNN code for the Jupyter Notebook environment
    * Reorder the functions and modify the program arguments (args)
2. Run the main() function of the original CGCNN code in the global frame (top-level notebook scope)
3. Run the 10 sample CIF files and sample target values in ./data/sample-regression to verify that the CGCNN code works



## Google Drive Mount



In [4]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [5]:
cd /content/drive/MyDrive/

/content/drive/MyDrive


In [6]:
ls -l

total 44
drwx------ 2 root root  4096 Oct 14  2020 '00. Github'/
drwx------ 2 root root  4096 Nov  6  2020 '01. Backup Files'/
drwx------ 2 root root  4096 Jan 20  2021  cgcnn/
drwx------ 2 root root  4096 Oct  1 14:07  Temp/
-rw------- 1 root root 27670 Jan 13 06:52  temp.ipynb


## CGCNN Github Repository Clone



In [7]:
#!git clone https://github.com/txie-93/cgcnn

## Pymatgen Library Installation



In [None]:
!pip install pymatgen==2020.11.11

In [9]:
!pip list | grep imgaug
!pip list | grep pymatgen
!python --version

imgaug                        0.2.9
pymatgen                      2020.11.11
Python 3.7.12


In [10]:
cd /content/drive/My Drive/cgcnn

/content/drive/My Drive/cgcnn


In [11]:
ls -l

total 5948
drwx------ 2 root root    4096 Jan 20  2021 cgcnn/
-rw------- 1 root root 1006068 Mar  4  2021 checkpoint2.pth.tar
-rw------- 1 root root 1006068 Mar  9  2021 checkpoint_46744_energy_per_atom.pth.tar
-rw------- 1 root root 1003342 Jan  5 07:11 checkpoint.pth.tar
drwx------ 2 root root    4096 Jan 20  2021 data/
-rw------- 1 root root    1065 Jan 20  2021 LICENSE
-rw------- 1 root root   20707 Jan 20  2021 main.py
-rw------- 1 root root 1006068 Mar  4  2021 model_best2.pth.tar
-rw------- 1 root root 1006068 Mar  9  2021 model_best_46744_energy_per_atom.pth.tar
-rw------- 1 root root 1003342 Jan  5 07:08 model_best.pth.tar
-rw------- 1 root root   11219 Jan 20  2021 predict.py
drwx------ 2 root root    4096 Jan 20  2021 pre-trained/
-rw------- 1 root root    8225 Jan 20  2021 README.md
drwx------ 2 root root    4096 Mar  9  2021 SrTiO3_folder/
-rw------- 1 root root       0 Jan  5 07:11 test_results.csv


## **cgcnn/cgcnn/data.py**



In [12]:
from __future__ import print_function, division

import csv
import functools
import json
import os
import random
import warnings

import numpy as np
import torch

from pymatgen.core.structure import Structure
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler

### get_train_val_test_loader



In [55]:
def get_train_val_test_loader(dataset, collate_fn = default_collate,
                              batch_size = 64, train_ratio = None,
                              val_ratio = 0.1, test_ratio = 0.1, return_test = False,
                              num_workers = 1, pin_memory = False, **kwargs):
    """
    Utility function for dividing a dataset to train, val, test datasets.

    !!! The dataset needs to be shuffled before using the function !!!

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
              The full dataset to be divided.
    collate_fn : callable
                 Function that merges a list of samples into a mini-batch.
    batch_size : int
    train_ratio : float
    val_ratio : float
    test_ratio : float
    return_test : bool
                  Whether to return the test dataset loader. 
                  If False, the last test_size data will be hidden.
    num_workers : int
    pin_memory : bool

    Returns
    -------
    train_loader : torch.utils.data.DataLoader
                   DataLoader that random samples the training data.
    val_loader : torch.utils.data.DataLoader
                 DataLoader that random samples the validation data.
    (test_loader) : torch.utils.data.DataLoader
                    DataLoader that random samples the test data, returns if return_test = True.
    """
    total_size = len(dataset) # Number of total data

    # Optional absolute sizes; kwargs.get() avoids a KeyError when a key is not passed
    train_size = kwargs.get('train_size')
    valid_size = kwargs.get('val_size')
    test_size  = kwargs.get('test_size')

    if train_size is None:
        if train_ratio is None:
            assert val_ratio + test_ratio < 1
            train_ratio = 1 - val_ratio - test_ratio
            print('[Warning] train_ratio is None, using 1 - val_ratio - test_ratio = {} as training data.'.format(train_ratio))
        else:
            assert train_ratio + val_ratio + test_ratio <= 1

    indices = list(range(total_size)) # Index list of total data

    # Fall back to the ratio-based size when no absolute size is given
    if not train_size:
        train_size = int(train_ratio * total_size)
    if not test_size:
        test_size = int(test_ratio * total_size)
    if not valid_size:
        valid_size = int(val_ratio * total_size)

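    # Random samplers for the training and validation subsets.
    # Explicit end positions keep the slices correct even when test_size is 0
    # (a slice bound of -0 would otherwise select the wrong range).
    val_end   = max(total_size - test_size, 0)
    val_start = max(val_end - valid_size, 0)
    train_sampler = SubsetRandomSampler(indices = indices[:train_size])
    val_sampler   = SubsetRandomSampler(indices = indices[val_start:val_end])

    if return_test: # When the test data loader is used
        test_sampler = SubsetRandomSampler(indices = indices[val_end:])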

    train_loader = DataLoader(dataset, batch_size = batch_size,
                              sampler = train_sampler,
                              num_workers = num_workers,
                              collate_fn = collate_fn, pin_memory = pin_memory)
    val_loader = DataLoader(dataset, batch_size = batch_size,
                            sampler = val_sampler,
                            num_workers = num_workers,
                            collate_fn = collate_fn, pin_memory = pin_memory)
    
    if return_test: # When the test data loader is used
        test_loader = DataLoader(dataset, batch_size = batch_size,
                                 sampler = test_sampler,
                                 num_workers = num_workers,
                                 collate_fn = collate_fn, pin_memory = pin_memory)  
        
    if return_test: # When the test data loader is used
        return train_loader, val_loader, test_loader
    else:
        return train_loader, val_loader
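
The index partitioning used by `get_train_val_test_loader` can be illustrated without PyTorch. The following is a minimal sketch, not part of CGCNN: `split_indices` is a hypothetical helper, and it assumes the sizes are already resolved to integers. Training data is taken from the front of the (pre-shuffled) index list, while validation and test data are taken from the back.

```python
def split_indices(total_size, train_size, val_size, test_size):
    """Hypothetical helper mirroring the slicing logic of the loader function."""
    indices = list(range(total_size))
    # Training samples come from the front of the list.
    train_idx = indices[:train_size]
    # Validation and test samples come from the back; explicit positions
    # avoid the ambiguity of a negative bound when test_size is 0.
    val_end = max(total_size - test_size, 0)
    val_start = max(val_end - val_size, 0)
    val_idx = indices[val_start:val_end]
    test_idx = indices[val_end:]
    return train_idx, val_idx, test_idx


train_idx, val_idx, test_idx = split_indices(10, 6, 2, 2)
print(train_idx)  # [0, 1, 2, 3, 4, 5]
print(val_idx)    # [6, 7]
print(test_idx)   # [8, 9]
```

Because the split is positional, the dataset has to be shuffled beforehand, as the docstring warns.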

### collate_pool



In [14]:
def collate_pool(dataset_list):
    """
    Collate a list of data and return a batch for predicting crystal properties.
    (Builds a graph batch from crystal graph data.)

    Parameters
    ----------

    dataset_list : list of tuples for each data point.
                   (atom_fea, nbr_fea, nbr_fea_idx, target)
                   
                   atom_fea : torch.Tensor shape (n_i, atom_fea_len)
                   nbr_fea : torch.Tensor shape (n_i, M, nbr_fea_len)
                   nbr_fea_idx : torch.LongTensor shape (n_i, M)
                   target : torch.Tensor shape (1, )
                   cif_id : str or int

    Returns
    -------
    N = sum(n_i) (total number of atoms); N0 = number of crystals in the batch

    batch_atom_fea : torch.Tensor shape (N, orig_atom_fea_len)
                     Atom features from atom type
    batch_nbr_fea : torch.Tensor shape (N, M, nbr_fea_len)
                    Bond features of each atom's M neighbors
    batch_nbr_fea_idx : torch.LongTensor shape (N, M)
                        Indices of M neighbors of each atom
    crystal_atom_idx : list of torch.LongTensor of length N0
                       Mapping from the crystal idx to atom idx
    target : torch.Tensor shape (N0, 1)
             Target value for prediction
    batch_cif_ids : list
    """
    batch_atom_fea, batch_nbr_fea, batch_nbr_fea_idx = [], [], []
    crystal_atom_idx, batch_target = [], []
    batch_cif_ids = []
    base_idx = 0

    for i, ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id) in enumerate(dataset_list):

        n_i = atom_fea.shape[0]  # Number of atoms for this crystal
        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        batch_nbr_fea_idx.append(nbr_fea_idx + base_idx)

        new_idx = torch.LongTensor(np.arange(n_i) + base_idx)

        crystal_atom_idx.append(new_idx)
        batch_target.append(target)
        batch_cif_ids.append(cif_id)
        base_idx += n_i

    return (torch.cat(batch_atom_fea, dim = 0),
            torch.cat(batch_nbr_fea, dim = 0),
            torch.cat(batch_nbr_fea_idx, dim = 0),
            crystal_atom_idx), \
            torch.stack(batch_target, dim = 0), \
            batch_cif_ids
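
The key step in `collate_pool` is shifting each crystal's neighbor indices by `base_idx` so that they still point at the right atoms after concatenation. A minimal pure-Python sketch of that bookkeeping follows; `offset_neighbor_indices` and the toy neighbor lists are hypothetical and use plain lists in place of tensors.

```python
def offset_neighbor_indices(per_crystal_nbr_idx):
    """Concatenate per-crystal neighbor index tables into one batch table."""
    batch_nbr_idx, crystal_atom_idx = [], []
    base_idx = 0
    for nbr_idx in per_crystal_nbr_idx:
        n_i = len(nbr_idx)  # number of atoms in this crystal
        # Shift local atom indices into the batch-wide numbering.
        batch_nbr_idx.extend([[j + base_idx for j in row] for row in nbr_idx])
        # Remember which batch rows belong to this crystal (used for pooling).
        crystal_atom_idx.append(list(range(base_idx, base_idx + n_i)))
        base_idx += n_i
    return batch_nbr_idx, crystal_atom_idx


crystal_a = [[1, 1], [0, 0]]          # 2 atoms, M = 2 neighbors each
crystal_b = [[1, 2], [0, 2], [0, 1]]  # 3 atoms, M = 2 neighbors each
batch_idx, atom_map = offset_neighbor_indices([crystal_a, crystal_b])
print(batch_idx)  # [[1, 1], [0, 0], [3, 4], [2, 4], [2, 3]]
print(atom_map)   # [[0, 1], [2, 3, 4]]
```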

### GaussianDistance



In [15]:
class GaussianDistance(object):
    """
    Expands interatomic distances in a Gaussian basis.

    Unit : Angstrom
    """
    def __init__(self, dmin, dmax, step, var = None):
        """
        Parameters
        ----------

        dmin : float
               Minimum interatomic distance
        dmax : float
               Maximum interatomic distance
        step : float
               Step size for the Gaussian filter
        """
        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax + step, step)
        if var is None:
            var = step
        self.var = var 

    def expand(self, distances):
        """
        Apply Gaussian distance filter to a numpy distance array

        Parameters
        ----------

        distances : np.array shape n-d array
                   A distance matrix of any shape
        
        Returns
        -------
        expanded_distance : shape (n+1)-d array
                            Expanded distance matrix with the last dimension of length len(self.filter)
        """
        return np.exp( -(distances[..., np.newaxis] - self.filter)**2 / self.var**2 )
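
To make the shape change concrete, here is a small standalone sketch of the same expansion. `gaussian_expand` is a hypothetical function-style equivalent of `GaussianDistance.expand`, and the distance values and the `dmin`/`dmax`/`step` settings are chosen only for illustration.

```python
import numpy as np


def gaussian_expand(distances, dmin, dmax, step, var=None):
    """Expand each distance over Gaussian centers dmin, dmin + step, ..., dmax."""
    centers = np.arange(dmin, dmax + step, step)
    if var is None:
        var = step
    # Broadcasting adds a trailing axis of length len(centers).
    return np.exp(-(distances[..., np.newaxis] - centers) ** 2 / var ** 2)


distances = np.array([[1.0, 1.5, 2.0],
                      [0.5, 1.0, 0.0]])
expanded = gaussian_expand(distances, dmin=0.0, dmax=2.0, step=0.5)
print(expanded.shape)  # (2, 3, 5)
```

Each distance becomes a vector that peaks at the center closest to it; a distance of 1.0 gives its largest response (exactly 1) at the center located at 1.0.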

### AtomInitializer



In [16]:
class AtomInitializer(object):
    """
    Base class for initializing the vector representation for atoms.

    !!! Use one AtomInitializer per dataset !!!
    """
    def __init__(self, atom_types):
        self.atom_types = set(atom_types) # Store as a set to remove duplicate atom types
        self._embedding = {}

    def get_atom_fea(self, atom_type):    # Return the initial vector representation for the given atom type
        assert atom_type in self.atom_types 
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        self._embedding = state_dict
        self.atom_types = set(self._embedding.keys())
        self._decodedict = {idx : atom_type for atom_type, idx in self._embedding.items()}

    def state_dict(self):
        return self._embedding

    def decode(self, idx):
        if not hasattr(self, '_decodedict'):
            self._decodedict = {idx : atom_type for atom_type, idx in self._embedding.items()}
        return self._decodedict[idx]

### AtomCustomJSONInitializer



In [17]:
class AtomCustomJSONInitializer(AtomInitializer):
    """
    Initialize atom feature vectors using a JSON file, which is a python
    dictionary mapping from element number to a list representing the
    feature vector of the element.

    Parameters
    ----------

    elem_embedding_file : str
                          The path to the .json file
    """
    def __init__(self, elem_embedding_file):
        with open(elem_embedding_file) as f: # JSON file open
            elem_embedding = json.load(f)    # Load JSON file contents
        elem_embedding = {int(key) : value for key, value in elem_embedding.items()} # string key(atomic number) -> integer
        atom_types = set(elem_embedding.keys()) # Store as a set to remove duplicate atomic numbers
        super(AtomCustomJSONInitializer, self).__init__(atom_types)
        for key, value in elem_embedding.items():
            self._embedding[key] = np.array(value, dtype = float) # Store each atomic number's vector as an np.array in the _embedding attribute
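
JSON object keys are always strings, which is why the initializer converts them to integers. The sketch below shows that round trip on a toy file. The three-dimensional vectors for hydrogen (1) and oxygen (8) are made up for illustration and are not the contents of the real `atom_init.json`.

```python
import json
import os
import tempfile

# Hypothetical toy embedding: atomic number (as a string key) -> feature vector
toy_embedding = {"1": [1, 0, 0], "8": [0, 1, 0]}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "atom_init.json")
    with open(path, "w") as f:
        json.dump(toy_embedding, f)
    with open(path) as f:
        loaded = json.load(f)  # keys come back as strings

# Convert the string keys to integer atomic numbers, as the initializer does.
embedding = {int(key): [float(v) for v in value] for key, value in loaded.items()}
print(sorted(embedding))  # [1, 8]
print(embedding[8])       # [0.0, 1.0, 0.0]
```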

### CIFData



In [18]:
class CIFData(Dataset):
    """
    The CIFData dataset is a wrapper for a dataset where the crystal structures
    are stored in the form of CIF files. The dataset should have the following 
    directory structure:

    root_dir
    ├── id_prop.csv
    ├── atom_init.json
    ├── id0.cif
    ├── id1.cif
    ├── ...

    id_prop.csv : a CSV file with two columns. The first column records a 
    unique ID for each crystal, and the second column records the value of
    the target property.

    atom_init.json : a JSON file that stores the initialization vector for each element.

    ID.cif : a CIF file that records the crystal structure, where ID is the
    unique ID for the crystal.

    Parameters
    ----------

    root_dir : str
               The path to the root directory of the dataset
    max_num_nbr : int
                  The maximum number of neighbors while constructing the crystal graph
    radius : float
             The cutoff radius for searching neighbors
    dmin : float
           The minimum distance for constructing GaussianDistance
    step : float
           The step size for constructing GaussianDistance
    random_seed : int
                  Random seed for shuffling the dataset

    Returns
    -------

    atom_fea : torch.Tensor shape (n_i, atom_fea_len)
    nbr_fea : torch.Tensor shape (n_i, M, nbr_fea_len)
    nbr_fea_idx : torch.LongTensor shape (n_i, M)
    target : torch.Tensor shape (1, )
    cif_id : str or int
    """
    def __init__(self, root_dir, max_num_nbr = 12, radius = 8, dmin = 0, step = 0.2,
                 random_seed = 123):
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        assert os.path.exists(root_dir), 'root_dir does not exist!'

        id_prop_file = os.path.join(self.root_dir, 'id_prop.csv')
        assert os.path.exists(id_prop_file), 'id_prop.csv does not exist!'

        with open(id_prop_file) as f:
            reader = csv.reader(f)
            self.id_prop_data = [row for row in reader]
        
        random.seed(random_seed)
        random.shuffle(self.id_prop_data)

        atom_init_file = os.path.join(self.root_dir, 'atom_init.json')
        assert os.path.exists(atom_init_file), 'atom_init.json file does not exist!'

        self.ari = AtomCustomJSONInitializer(atom_init_file)
        self.gdf = GaussianDistance(dmin = dmin, dmax = self.radius, step = step)

    def __len__(self):
        return len(self.id_prop_data)

    @functools.lru_cache(maxsize = None)  # Cache loaded structures
    def __getitem__(self, idx):
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id + '.cif'))

        atom_fea = np.vstack([self.ari.get_atom_fea(crystal[i].specie.number) for i in range(len(crystal))])
        atom_fea = torch.Tensor(atom_fea)

        all_nbrs = crystal.get_all_neighbors(self.radius, include_index = True)
        all_nbrs = [sorted(nbrs, key = lambda x : x[1]) for nbrs in all_nbrs]

        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            if len(nbr) < self.max_num_nbr:
                warnings.warn('{} did not find enough neighbors to build the graph. If this happens frequently, consider increasing the radius.'.format(cif_id))
                nbr_fea_idx.append( list(map(lambda x : x[2], nbr)) + [0] * (self.max_num_nbr - len(nbr)) )
                nbr_fea.append( list(map(lambda x : x[1], nbr)) + [self.radius + 1.] * (self.max_num_nbr - len(nbr)) )
            else:
                nbr_fea_idx.append( list(map(lambda x : x[2], nbr[:self.max_num_nbr])) )
                nbr_fea.append( list(map(lambda x : x[1], nbr[:self.max_num_nbr])) )
        
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
        nbr_fea = self.gdf.expand(nbr_fea)
        atom_fea = torch.Tensor(atom_fea)
        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
        target = torch.Tensor([float(target)])

        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id
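
The neighbor handling in `__getitem__` keeps the `max_num_nbr` closest neighbors and pads shorter lists with index 0 and a distance of `radius + 1`. A minimal sketch of that rule is shown below; `pad_neighbors` and the `(distance, index)` pair layout are assumptions made for illustration and differ from the pymatgen neighbor objects used above.

```python
def pad_neighbors(nbrs, max_num_nbr, radius):
    """nbrs: list of (distance, atom_index) pairs for one atom (hypothetical layout)."""
    # Keep the closest neighbors first, truncated to max_num_nbr entries.
    nbrs = sorted(nbrs, key=lambda pair: pair[0])[:max_num_nbr]
    n_missing = max_num_nbr - len(nbrs)
    # Pad with index 0 and a distance just beyond the cutoff radius.
    idx = [index for _, index in nbrs] + [0] * n_missing
    dist = [distance for distance, _ in nbrs] + [radius + 1.0] * n_missing
    return idx, dist


idx, dist = pad_neighbors([(2.5, 3), (1.2, 1)], max_num_nbr=4, radius=8)
print(idx)   # [1, 3, 0, 0]
print(dist)  # [1.2, 2.5, 9.0, 9.0]
```

Since the padded distance lies beyond the last Gaussian center, its expanded features are small, although the padded index still points at atom 0.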


## **cgcnn/cgcnn/model.py**



In [19]:
from __future__ import print_function, division

import torch
import torch.nn as nn

### ConvLayer



In [20]:
class ConvLayer(nn.Module):
    """
    Convolutional operation on graphs
    """
    def __init__(self, atom_fea_len, nbr_fea_len):
        """
        Initialize ConvLayer.

        Parameters
        ----------

        atom_fea_len : int
                       Number of atom hidden features.
        nbr_fea_len : int
                      Number of bond features.
        """
        super(ConvLayer, self).__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        self.fc_full = nn.Linear(in_features = 2 * self.atom_fea_len + self.nbr_fea_len,
                                 out_features = 2 * self.atom_fea_len)
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(num_features = 2 * self.atom_fea_len)
        self.bn2 = nn.BatchNorm1d(num_features = self.atom_fea_len)
        self.softplus2 = nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """
        Forward pass

        N : Total number of atoms in the batch
        M : Max number of neighbors

        Parameters
        ----------

        atom_in_fea : Variable(torch.Tensor) shape (N, atom_fea_len)
                      Atom hidden features before convolution
        nbr_fea : Variable(torch.Tensor) shape (N, M, nbr_fea_len)
                  Bond features of each atom's M neighbors
        nbr_fea_idx : torch.LongTensor shape (N, M)
                      Indices of M neighbors of each atom

        Returns
        -------

        atom_out_fea : torch.Tensor shape (N, atom_fea_len)
                       Atom hidden features after convolution
        """
        # TODO will there be problems with the index zero padding?
        N, M = nbr_fea_idx.shape

        # Convolution: gather neighbor features, concatenate, then apply the gated update.
        # Tracking the tensor shape at each step helps when modifying this block.
        atom_nbr_fea = atom_in_fea[nbr_fea_idx, :] # (N, M, atom_fea_len): features of each atom's M neighbors
        total_nbr_fea = torch.cat([atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len), atom_nbr_fea, nbr_fea], dim = 2)

        total_gated_fea = self.fc_full(total_nbr_fea)
        total_gated_fea = self.bn1(total_gated_fea.view(-1, self.atom_fea_len * 2)).view(N, M, self.atom_fea_len * 2)

        nbr_filter, nbr_core = total_gated_fea.chunk(2, dim = 2)
        nbr_filter = self.sigmoid(nbr_filter)
        nbr_core = self.softplus1(nbr_core)

        nbr_sumed = torch.sum(nbr_filter * nbr_core, dim = 1)
        nbr_sumed = self.bn2(nbr_sumed)
        
        out = self.softplus2(atom_in_fea + nbr_sumed)
        return out
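
The gather and concatenation at the start of `ConvLayer.forward` are easiest to follow by checking shapes on a toy input. The sketch below uses NumPy in place of PyTorch (integer-array indexing behaves the same way here), and all sizes and the random values are arbitrary assumptions.

```python
import numpy as np

N, M, atom_fea_len, nbr_fea_len = 3, 2, 4, 5
rng = np.random.default_rng(0)
atom_in_fea = rng.normal(size=(N, atom_fea_len))
nbr_fea = rng.normal(size=(N, M, nbr_fea_len))
nbr_fea_idx = np.array([[1, 2], [0, 2], [0, 1]])  # neighbor atom indices

# Gather: row i, slot j holds the feature vector of atom nbr_fea_idx[i, j].
atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]  # (N, M, atom_fea_len)

# Repeat each center atom's features M times so they align with its neighbors.
center_fea = np.broadcast_to(atom_in_fea[:, np.newaxis, :], (N, M, atom_fea_len))

# Concatenate center, neighbor, and bond features along the last axis.
total_nbr_fea = np.concatenate([center_fea, atom_nbr_fea, nbr_fea], axis=2)
print(total_nbr_fea.shape)  # (3, 2, 13)
```

The last dimension is `2 * atom_fea_len + nbr_fea_len`, which matches the `in_features` of `fc_full`.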


### CrystalGraphConvNet



In [21]:
class CrystalGraphConvNet(nn.Module):
    """
    Create a crystal graph convolutional neural network for predicting total
    material properties.
    """
    def __init__(self, orig_atom_fea_len, nbr_fea_len,
                 atom_fea_len = 64, n_conv = 3, h_fea_len = 128, n_h = 1,
                 classification = False):
        """
        Initialize CrystalGraphConvNet.

        Parameters
        ----------

        orig_atom_fea_len : int
                            Number of atom features in the input.
        nbr_fea_len : int
                      Number of bond features.
        atom_fea_len : int
                       Number of hidden atom features in the convolutional layers
        n_conv : int
                 Number of convolutional layers
        h_fea_len : int
                    Number of hidden features after pooling
        n_h : int
              Number of hidden layers after pooling
        """
        super(CrystalGraphConvNet, self).__init__()
        self.classification = classification
        self.embedding = nn.Linear(in_features = orig_atom_fea_len, out_features = atom_fea_len) # Dimension reduction layer
        self.convs = nn.ModuleList([ConvLayer(atom_fea_len = atom_fea_len,
                                              nbr_fea_len = nbr_fea_len) for _ in range(n_conv)]) # Iterative Graph Convolution layers
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()

        if n_h > 1: # Iterative FCL layers
            self.fcs = nn.ModuleList([nn.Linear(h_fea_len, h_fea_len) for _ in range(n_h - 1)])
            self.softpluses = nn.ModuleList([nn.Softplus() for _ in range(n_h - 1)])
        
        if self.classification:
            self.fc_out = nn.Linear(h_fea_len, 2) # Binary classification
        else:
            self.fc_out = nn.Linear(h_fea_len, 1) # Only one physical property Regression -> Maybe we can modify into multiple property prediction
        
        if self.classification:
            self.logsoftmax = nn.LogSoftmax(dim = 1)
            self.dropout = nn.Dropout()

    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """
        Forward pass

        N : Total number of atoms in the batch
        M : Max number of neighbors
        N0 : Total number of crystals in the batch

        Parameters
        ----------

        atom_fea : Variable(torch.Tensor) shape (N, orig_atom_fea_len)
                   Atom features from atom type
        nbr_fea : Variable(torch.Tensor) shape (N, M, nbr_fea_len)
                  Bond features of each atom's M neighbors
        nbr_fea_idx : torch.LongTensor shape (N, M)
                      Indices of M neighbors of each atom
        crystal_atom_idx : list of torch.LongTensor of length N0
                           Mapping from the crystal idx to atom idx

        Returns
        -------

        prediction : torch.Tensor shape (N0, 1) for regression, (N0, 2) for classification
                     Predicted property per crystal (log-probabilities for classification)
        """
        atom_fea = self.embedding(atom_fea)

        for conv_func in self.convs: # Iterative Graph Convolution
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
        
        crys_fea = self.pooling(atom_fea, crystal_atom_idx) # Graph Pooling
        crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
        crys_fea = self.conv_to_fc_softplus(crys_fea)

        if self.classification:
            crys_fea = self.dropout(crys_fea)

        if hasattr(self, 'fcs') and hasattr(self, 'softpluses'): # Check whether the instance has the 'fcs' and 'softpluses' attributes
            for fc, softplus in zip(self.fcs, self.softpluses):
                crys_fea = softplus(fc(crys_fea))
        
        out = self.fc_out(crys_fea)

        if self.classification:
            out = self.logsoftmax(out)
        
        return out

    def pooling(self, atom_fea, crystal_atom_idx):
        """
        Pooling the atom features to crystal features

        N : Total number of atoms in the batch
        N0 : Total number of crystals in the batch

        Parameters
        ----------

        atom_fea : Variable(torch.Tensor) shape (N, atom_fea_len)
                   Atom feature vectors of the batch
        crystal_atom_idx : list of torch.LongTensor of length N0
                           Mapping from the crystal idx to atom idx
        """
        assert sum([len(idx_map) for idx_map in crystal_atom_idx]) == atom_fea.data.shape[0]

        # Simple Mean Pooling operation --> We can modify it to the modern pooling function
        summed_fea = [torch.mean(atom_fea[idx_map], dim = 0, keepdim = True) for idx_map in crystal_atom_idx]

        return torch.cat(summed_fea, dim = 0)
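
The pooling step averages the atom rows that belong to each crystal, producing one feature vector per crystal. A small NumPy sketch with made-up feature values (two crystals with 2 and 3 atoms):

```python
import numpy as np

# 5 atoms in total, each with 2 features (toy values)
atom_fea = np.array([[1.0, 2.0],
                     [3.0, 4.0],
                     [10.0, 20.0],
                     [20.0, 30.0],
                     [30.0, 40.0]])
# Rows 0-1 belong to crystal 0, rows 2-4 to crystal 1.
crystal_atom_idx = [np.array([0, 1]), np.array([2, 3, 4])]

# Mean over each crystal's atoms, then stack the per-crystal vectors.
crys_fea = np.concatenate(
    [atom_fea[idx].mean(axis=0, keepdims=True) for idx in crystal_atom_idx],
    axis=0,
)
print(crys_fea.shape)  # (2, 2)
```

Averaging (rather than summing) keeps the crystal feature scale independent of the number of atoms in the cell.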


## **cgcnn/main.py**



In [22]:
import argparse
import os
import shutil
import sys
import time
import warnings
from random import sample 

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from sklearn import metrics
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR

# from cgcnn.data import CIFData
# from cgcnn.data import collate_pool, get_train_val_test_loader
# from cgcnn.model import CrystalGraphConvNet

### Args



In [51]:
class Args(object):
    """ Replacement for the argparse-based command-line arguments """

    def __init__(self):

        self.data_options = '/content/drive/My Drive/cgcnn/data/sample-regression' # Absolute path
        self.task = 'regression'
        self.disable_cuda = False  # action = 'store_true' in the original parser, so the default is False
        self.workers = 0           # Number of data loading workers
        self.epochs = 30           # Number of total epochs to run
        self.start_epoch = 0       # Manual epoch number (useful on restarts)
        self.batch_size = 256      # Mini-batch size
        self.lr = 0.01             # Initial learning rate
        self.lr_milestones = [100] # Milestones for scheduler
        self.momentum = 0.9 
        self.weight_decay = 0
        self.print_freq = 10
        self.resume = None         # Path to latest checkpoint

        # Train_group
        self.train_ratio = None # Percentage of training data to be loaded
        self.train_size = None  # Number of training data to be loaded

        # Valid_group
        self.val_ratio = 0.1    # Percentage of validation data to be loaded
        self.val_size = 1000    # Number of validation data to be loaded

        # Test_group
        self.test_ratio = 0.1   # Percentage of test data to be loaded
        self.test_size = 1000   # Number of test data to be loaded
        
        self.optim = 'Adam'     # Choose an optimizer, SGD or Adam
        self.atom_fea_len = 64  # Number of hidden atom features in conv layers
        self.h_fea_len = 128    # Number of hidden features after pooling
        self.n_conv = 3         # Number of conv layers
        self.n_h = 1            # Number of hidden layers after pooling

        self.cuda = None

args = Args()
args

<__main__.Args at 0x7fbd844e5690>

In [52]:
args.cuda = not args.disable_cuda and torch.cuda.is_available()
print(args.cuda) # True if GPU on

True


In [53]:
if args.task == 'regression':
    best_mae_error = 1e10
else:
    best_mae_error = 0.

print(args.task, best_mae_error)

regression 10000000000.0


### Normalizer



In [35]:
class Normalizer(object): 
    """
    Normalize a Tensor and restore it later.
    Computes the mean and standard deviation of the given Tensor and normalizes with them.
    """
    
    def __init__(self, tensor):
        """ tensor is taken as a sample to calculate the mean and std """
        self.mean = torch.mean(tensor)
        self.std  = torch.std(tensor)

    def norm(self, tensor): # Normalize method
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor): # Denormalize method
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        return {'mean' : self.mean, 'std' : self.std}
    
    def load_state_dict(self, state_dict):
        self.mean = state_dict['mean']
        self.std  = state_dict['std']
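
The normalize/denormalize pair is a plain standardization round trip. The sketch below reproduces it with the standard library on toy target values. `statistics.stdev` is the sample standard deviation (dividing by n - 1), which matches the default behavior of `torch.std`.

```python
import statistics

targets = [1.0, 2.0, 3.0, 4.0]     # toy target values
mean = statistics.mean(targets)
std = statistics.stdev(targets)    # sample standard deviation (n - 1 in the denominator)

# norm: shift to zero mean and scale to unit standard deviation
normed = [(t - mean) / std for t in targets]
# denorm: invert the transformation to recover the original scale
restored = [n * std + mean for n in normed]

print(mean)  # 2.5
```

Note that if all targets are identical the standard deviation is zero and the normalization divides by zero, so the sample used to build the normalizer needs some spread.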

### mae



In [36]:
def mae(prediction, target):
    """
    Computes the mean absolute error between prediction and target

    Parameters
    ----------

    prediction : torch.Tensor (N, 1)
    target : torch.Tensor (N, 1)
    """
    return torch.mean(torch.abs(target - prediction))

### class_eval



In [37]:
def class_eval(prediction, target):
    """ Evaluation function used only for the classification task (binary, e.g. metal vs. semiconductor) """

    prediction = np.exp(prediction.numpy()) # Convert to probability
    target = target.numpy()
    pred_label = np.argmax(prediction, axis = 1) 
    target_label = np.squeeze(target)

    if not target_label.shape:
        target_label = np.asarray([target_label])
    
    if prediction.shape[1] == 2:
        precision, recall, fscore, _ = metrics.precision_recall_fscore_support(target_label,
                                                                               pred_label,
                                                                               average = 'binary') # binary classification
        auc_score = metrics.roc_auc_score(target_label, prediction[:, 1]) # AUC
        accuracy = metrics.accuracy_score(target_label, pred_label)       # Accuracy
    else:
        raise NotImplementedError

    return accuracy, precision, recall, fscore, auc_score
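
`class_eval` exponentiates the model output because, in classification mode, the network ends with a log-softmax layer. The following pure-Python sketch shows why `exp` recovers probabilities; `log_softmax` here is a hand-written illustration and the logits are arbitrary.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


log_probs = log_softmax([2.0, 0.5])         # what the model would output
probs = [math.exp(lp) for lp in log_probs]  # convert back to probabilities
pred_label = probs.index(max(probs))        # same role as np.argmax

print(pred_label)  # 0
```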

### AverageMeter



In [38]:
class AverageMeter(object):
    """
    Computes and stores the average and current value.
    Used to track the metrics that should be recorded during training.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
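
`update(val, n)` weights each value by the batch size `n`, so the stored average is a per-sample mean rather than a per-batch mean. A minimal sketch of the same arithmetic, with `weighted_running_average` as a hypothetical stand-in and made-up batch values:

```python
def weighted_running_average(batches):
    """batches: iterable of (value, n) pairs, e.g. (batch loss, batch size)."""
    total, count = 0.0, 0
    history = []
    for value, n in batches:
        total += value * n   # same as AverageMeter.sum
        count += n           # same as AverageMeter.count
        history.append(total / count)
    return history


history = weighted_running_average([(0.5, 4), (1.0, 4), (2.0, 2)])
print(history)  # [0.5, 0.75, 1.0]
```

The last, smaller batch contributes less: an unweighted mean of the three batch values would be about 1.17 instead of 1.0.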

### save_checkpoint



In [39]:
def save_checkpoint(state, is_best, filename = 'checkpoint.pth.tar'):
    torch.save(state, filename) # Save the model state dictionary under the given file name
    if is_best: # If this is the best-performing model so far
        shutil.copyfile(filename, 'model_best.pth.tar') 

### adjust_learning_rate



In [40]:
def adjust_learning_rate(optimizer, epoch, k):
    """ Sets the learning rate to the initial LR decayed by a factor of 10 every k epochs """
    assert type(k) is int
    lr = args.lr * (0.1 ** (epoch // k))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
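
The schedule implemented by `adjust_learning_rate` is a step decay, `lr = lr0 * 0.1 ** (epoch // k)`. The sketch below isolates the formula; `step_decay_lr` is a hypothetical helper that takes the initial rate as an argument instead of reading the global `args.lr`.

```python
def step_decay_lr(initial_lr, epoch, k):
    """Learning rate after dividing by 10 once every k epochs."""
    return initial_lr * (0.1 ** (epoch // k))


# The rate stays constant within each block of k epochs and drops at its boundary.
for epoch in (0, 9, 10, 25):
    print(epoch, step_decay_lr(0.01, epoch, k=10))
```

With `initial_lr = 0.01` and `k = 10`, epochs 0 to 9 use 0.01, epochs 10 to 19 use roughly 0.001, and epochs 20 to 29 use roughly 0.0001 (up to floating-point rounding).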

### train



In [41]:
def train(train_loader, model, criterion, optimizer, epoch, normalizer):

    # We will record batch_time, data_time, losses with value and average
    batch_time = AverageMeter() # batch process runtime
    data_time  = AverageMeter() # data loading time
    losses     = AverageMeter()

    if args.task == 'regression':
        mae_errors = AverageMeter() # MAE: Mean Absolute Error
    else:          # classification
        accuracies = AverageMeter() # Accuracy: an intuitive indicator of model performance, but mind class imbalance (bias) in the domain
        precisions = AverageMeter() # Precision: fraction of samples classified as True that are actually True (PPV, Positive Predictive Value)
        recalls    = AverageMeter() # Recall: fraction of actually True samples that the model predicted as True (sensitivity, hit rate)
        fscores    = AverageMeter() # F1 score: evaluates model performance more reliably when the labels are imbalanced
        auc_scores = AverageMeter() # AUC (Area Under Curve): the ROC curve is a graph without a single-number summary, so the area under it is used; the closer to the maximum of 1, the better the model

    # Switch to train mode
    model.train()

    end = time.time()
    for i, (input, target, _) in enumerate(train_loader):

        # Measure data loading time
        data_time.update(time.time() - end)

        if args.cuda: # GPU: wrap the data in Variables and move everything to the CUDA device
            input_var = (Variable(input[0].cuda(non_blocking = True)),
                         Variable(input[1].cuda(non_blocking = True)),
                         input[2].cuda(non_blocking = True),
                         [crys_idx.cuda(non_blocking = True) for crys_idx in input[3]])
        else:         # CPU: wrap the data in Variables and keep it on the CPU
            input_var = (Variable(input[0]),
                         Variable(input[1]),
                         input[2],
                         input[3])
        
        # Normalize target
        if args.task == 'regression':
            target_normed = normalizer.norm(target) # Normalize the target values
        else: # classification
            target_normed = target.view(-1).long()  # Flatten the targets (labels) to a 1-D vector of integers

        if args.cuda: # GPU
            target_var = Variable(target_normed.cuda(non_blocking = True)) # Move the target values to the CUDA device
        else: # CPU
            target_var = Variable(target_normed) # Keep the target values on the CPU

        # Compute output
        output = model(*input_var)
        loss = criterion(output, target_var)

        # Measure accuracy and record loss
        if args.task == 'regression': # Loss and MAE recorded
            mae_error = mae(prediction = normalizer.denorm(output.data.cpu()), target = target)
            losses.update(loss.data.cpu(), target.size(0))
            mae_errors.update(mae_error, target.size(0))
        else: 
            accuracy, precision, recall, fscore, auc_score = class_eval(prediction = output.data.cpu(), target = target)
            losses.update(loss.data.cpu().item(), target.size(0))
            accuracies.update(accuracy, target.size(0))
            precisions.update(precision, target.size(0))
            recalls.update(recall, target.size(0))
            fscores.update(fscore, target.size(0))
            auc_scores.update(auc_score, target.size(0))

        # Compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            if args.task == 'regression':
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'MAE {mae_errors.val:.3f} ({mae_errors.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                               batch_time = batch_time,
                                                                               data_time = data_time,
                                                                               loss = losses,
                                                                               mae_errors = mae_errors))
            else:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Accu {accu.val:.3f} ({accu.avg:.3f})\t'
                      'Precision {prec.val:.3f} ({prec.avg:.3f})\t'
                      'Recall {recall.val:.3f} ({recall.avg:.3f})\t'
                      'F1 {f1.val:.3f} ({f1.avg:.3f})\t'
                      'AUC {auc.val:.3f} ({auc.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                 batch_time = batch_time,
                                                                 data_time = data_time,
                                                                 loss = losses,
                                                                 accu = accuracies,
                                                                 prec = precisions,
                                                                 recall = recalls, 
                                                                 f1 = fscores,
                                                                 auc = auc_scores))
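
The meters used in `train` can be pictured with a small sketch. It assumes `AverageMeter` follows the common running-average pattern (latest value in `val`, sample-weighted mean in `avg`). The class name `RunningAverage` below is hypothetical and only stands in for the notebook's own class.

```python
class RunningAverage:
    """Hypothetical stand-in for AverageMeter: tracks the latest value and a weighted mean."""

    def __init__(self):
        self.val = 0.0    # most recent value passed to update()
        self.sum = 0.0    # weighted sum of all values seen so far
        self.count = 0    # total weight (number of samples) seen so far
        self.avg = 0.0    # weighted mean; stays 0.0 until the first update

    def update(self, val, n=1):
        # n is the batch size, so avg is a per-sample mean across batches
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = RunningAverage()
meter.update(2.0, n=3)   # batch of 3 samples with mean error 2.0
meter.update(4.0, n=1)   # batch of 1 sample with error 4.0
print(meter.val, meter.avg)  # 4.0 2.5
```

Under this assumption, a meter that never receives an update keeps `avg` at 0.0, which is what a loop over an empty data loader would report.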
                

### validate



In [42]:
def validate(val_loader, model, criterion, normalizer, test = False):

    # We will record batch_time, losses with value and average
    batch_time = AverageMeter()
    losses = AverageMeter()

    if args.task == 'regression':
        mae_errors = AverageMeter()
    else:
        accuracies = AverageMeter() # Accuracy
        precisions = AverageMeter() # Precision
        recalls    = AverageMeter() # Recall
        fscores    = AverageMeter() # F1 score
        auc_scores = AverageMeter() # AUC

    if test: # When the test data loader is used
        test_targets = []
        test_preds = []
        test_cif_ids = []

    # Switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target, batch_cif_ids) in enumerate(val_loader):
        
        if args.cuda: # GPU
            with torch.no_grad(): # Disable gradient calculation
                input_var = (Variable(input[0].cuda(non_blocking = True)),
                             Variable(input[1].cuda(non_blocking = True)),
                             input[2].cuda(non_blocking = True),
                             [crys_idx.cuda(non_blocking = True) for crys_idx in input[3]])
        else: # CPU
            with torch.no_grad(): # Disable gradient calculation
                input_var = (Variable(input[0]),
                             Variable(input[1]),
                             input[2],
                             input[3])
        
        if args.task == 'regression':
            target_normed = normalizer.norm(target) # Normalize the target values
        else:
            target_normed = target.view(-1).long()  # Flatten the targets (labels) to a 1-D vector of integers

        if args.cuda: # GPU
            with torch.no_grad(): # Disable gradient calculation
                target_var = Variable(target_normed.cuda(non_blocking = True)) # Move the target values to the CUDA device
        else:         # CPU
            with torch.no_grad(): # Disable gradient calculation
                target_var = Variable(target_normed) # Keep the target values on the CPU

        # Compute output
        output = model(*input_var)
        loss = criterion(output, target_var)

        # Measure accuracy and record loss
        if args.task == 'regression':
            mae_error = mae(normalizer.denorm(output.data.cpu()), target)
            losses.update(loss.data.cpu().item(), target.size(0))
            mae_errors.update(mae_error, target.size(0))

            if test: # When the test data loader is used
                test_pred = normalizer.denorm(output.data.cpu())
                test_target = target
                test_preds += test_pred.view(-1).tolist()
                test_targets += test_target.view(-1).tolist()
                test_cif_ids += batch_cif_ids

        else: # classification
            accuracy, precision, recall, fscore, auc_score = class_eval(output.data.cpu(), target)
            losses.update(loss.data.cpu().item(), target.size(0))
            accuracies.update(accuracy, target.size(0))
            precisions.update(precision, target.size(0))
            recalls.update(recall, target.size(0))
            fscores.update(fscore, target.size(0))
            auc_scores.update(auc_score, target.size(0))

            if test: # When the test data loader is used
                test_pred = torch.exp(output.data.cpu())
                test_target = target
                assert test_pred.shape[1] == 2
                test_preds += test_pred[:, 1].tolist()
                test_targets += test_target.view(-1).tolist()
                test_cif_ids += batch_cif_ids
        
        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            if args.task == 'regression':
                print('Test : [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'MAE {mae_errors.val:.3f} ({mae_errors.avg:.3f})'.format(i, len(val_loader),
                                                                               batch_time = batch_time,
                                                                               loss = losses,
                                                                               mae_errors = mae_errors))
            else:
                print('Test : [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Accu {accu.val:.3f} ({accu.avg:.3f})\t'
                      'Precision {prec.val:.3f} ({prec.avg:.3f})\t'
                      'Recall {recall.val:.3f} ({recall.avg:.3f})\t'
                      'F1 {f1.val:.3f} ({f1.avg:.3f})\t'
                      'AUC {auc.val:.3f} ({auc.avg:.3f})'.format(i, len(val_loader),
                                                                 batch_time = batch_time,
                                                                 loss = losses,
                                                                 accu = accuracies, 
                                                                 prec = precisions,
                                                                 recall = recalls,
                                                                 f1 = fscores,
                                                                 auc = auc_scores))
                
    if test: # When the test data loader is used
        star_label = '**'
        import csv
        with open('test_results.csv', 'w', newline = '') as f: # Save the test-set targets and predictions to a CSV file
            writer = csv.writer(f)
            for cif_id, target, pred in zip(test_cif_ids, test_targets, test_preds):
                writer.writerow((cif_id, target, pred))
    else:
        star_label = '*'


    if args.task == 'regression':
        print(' {star} MAE {mae_errors.avg:.3f}'.format(star = star_label, mae_errors = mae_errors))
        return mae_errors.avg
    else: # classification
        print(' {star} AUC {auc.avg:.3f}'.format(star = star_label, auc = auc_scores))
        return auc_scores.avg
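
In the classification branch of `validate`, `torch.exp(output)` is applied because `nn.NLLLoss` expects log-probabilities as its input. Exponentiating therefore recovers class probabilities, and column 1 holds the probability of the positive class. The plain-Python sketch below illustrates that relation for one sample. The logits are made-up numbers, and `log_softmax` is a helper written for this illustration, not the notebook's code.

```python
import math

def log_softmax(logits):
    """Return log-probabilities for one sample (numerically stable)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

logits = [0.5, 2.0]                          # hypothetical raw scores for classes 0 and 1
log_probs = log_softmax(logits)              # what a log-softmax output layer would emit
probs = [math.exp(lp) for lp in log_probs]   # same role as torch.exp(output)

positive_prob = probs[1]                     # same role as test_pred[:, 1]
nll = -log_probs[1]                          # negative log-likelihood if the true label is 1
print(probs, positive_prob, nll)
```

The two probabilities sum to 1, which is why the binary case only needs to store the second column.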

### main

In [43]:
cd /content/drive/My Drive/cgcnn/

/content/drive/My Drive/cgcnn


In [44]:
ls -l

total 5948
drwx------ 2 root root    4096 Jan 20  2021 cgcnn/
-rw------- 1 root root 1006068 Mar  4  2021 checkpoint2.pth.tar
-rw------- 1 root root 1006068 Mar  9  2021 checkpoint_46744_energy_per_atom.pth.tar
-rw------- 1 root root 1003342 Jan  5 07:11 checkpoint.pth.tar
drwx------ 5 root root    4096 Jan 20  2021 data/
-rw------- 1 root root    1065 Jan 20  2021 LICENSE
-rw------- 1 root root   20707 Jan 20  2021 main.py
-rw------- 1 root root 1006068 Mar  4  2021 model_best2.pth.tar
-rw------- 1 root root 1006068 Mar  9  2021 model_best_46744_energy_per_atom.pth.tar
-rw------- 1 root root 1003342 Jan  5 07:08 model_best.pth.tar
-rw------- 1 root root   11219 Jan 20  2021 predict.py
drwx------ 2 root root    4096 Jan 20  2021 pre-trained/
-rw------- 1 root root    8225 Jan 20  2021 README.md
drwx------ 2 root root    4096 Mar  9  2021 SrTiO3_folder/
-rw------- 1 root root       0 Jan  5 07:11 test_results.csv


In [45]:
print(args)
print(best_mae_error)

<__main__.Args object at 0x7fbd90433950>
10000000000.0


In [47]:
# Load data
dataset = CIFData(root_dir = args.data_options)         # Pass the root (dataset) directory to CIFData to prepare the dataset pipeline
print('Data Pipeline Instance : {}'.format(dataset))    # Data pipeline instance
print('Number of Total data : {}'.format(len(dataset))) # Number of total data points

Data Pipeline Instance : <__main__.CIFData object at 0x7fbd88bf0290>
Number of Total data : 10


In [49]:
collate_fn = collate_pool # Collate function that assembles a batch of graph data
print('Collate function for graph data batch : {}'.format(collate_fn))

Collate function for graph data batch : <function collate_pool at 0x7fbd90b4db00>


In [75]:
train_loader, val_loader, test_loader = get_train_val_test_loader(dataset = dataset,
                                                                  collate_fn = collate_fn,
                                                                  batch_size = args.batch_size,
                                                                  train_ratio = 0.6,
                                                                  num_workers = args.workers,
                                                                  val_ratio = 0.2,
                                                                  test_ratio = 0.2,
                                                                  pin_memory = args.cuda,
                                                                  train_size = args.train_size,
                                                                  val_size = args.val_size,
                                                                  test_size = args.test_size,
                                                                  return_test = True)
print(len(train_loader))
print(len(val_loader))
print(len(test_loader))

1
0
1


In [62]:
# Obtain target value normalizer
if args.task == 'classification':
    normalizer = Normalizer(torch.zeros(2))
    normalizer.load_state_dict({'mean' : 0, 'std' : 1.}) # Mean 0 and standard deviation 1, so normalization leaves the values unchanged
else: # regression
    if len(dataset) < 500:
        warnings.warn('Dataset has less than 500 data points. Lower accuracy is expected.')
        sample_data_list = [dataset[i] for i in range(len(dataset))]
    else:
        print('Dataset has more than 500 data points. Good!')
        sample_data_list = [dataset[i] for i in sample(population = range(len(dataset)), k = 500)]



In [63]:
print(sample_data_list)

[((tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0

In [64]:
_, sample_target, _ = collate_pool(sample_data_list) # Collate the sample data into a single graph batch
normalizer = Normalizer(sample_target)               # Build the normalizer from the distribution of the sample target values

print(sample_target)
print(normalizer)

tensor([[ 9.],
        [ 8.],
        [ 6.],
        [10.],
        [ 3.],
        [ 4.],
        [ 7.],
        [ 2.],
        [ 5.],
        [ 1.]])
<__main__.Normalizer object at 0x7fbd8422c8d0>
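
The effect of the normalizer on the ten sample targets printed above can be sketched in plain Python. This assumes `Normalizer` stores the mean and the sample standard deviation of the tensor it is given and applies `(x - mean) / std`. The `norm` and `denorm` functions below are illustrative re-implementations, not the notebook's class.

```python
import statistics

targets = [9., 8., 6., 10., 3., 4., 7., 2., 5., 1.]  # sample_target values printed above

mean = statistics.mean(targets)    # 5.5
std = statistics.stdev(targets)    # sample standard deviation (n - 1 in the denominator)

def norm(x):
    """Map a raw target to zero-mean, unit-variance units."""
    return (x - mean) / std

def denorm(z):
    """Map a normalized value back to the original target units."""
    return z * std + mean

normed = [norm(t) for t in targets]
restored = [denorm(z) for z in normed]

# An error of 1.0 in normalized units corresponds to `std` in original units,
# which is why the Loss (normalized) and MAE (denormalized) columns differ in scale.
error_in_original_units = denorm(1.0) - denorm(0.0)
print(mean, std, error_in_original_units)
```

In the training loop the loss is computed on normalized targets, while the MAE is computed after `denorm`, so the two columns in the log are not directly comparable.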


In [65]:
# Build model
structures, _, _ = dataset[0] # ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id) -> Crystal structure
print(type(structures))
print(structures) # (atom_fea, nbr_fea, nbr_fea_idx)

<class 'tuple'>
(tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1.,

In [66]:
orig_atom_fea_len = structures[0].shape[-1] # Original atomic feature length : 92 (set by the atom initialization JSON file)
nbr_fea_len = structures[1].shape[-1]       # Neighbor feature length

print('Original Atomic feature length : {}'.format(orig_atom_fea_len))
print('Neighbor feature length : {}'.format(nbr_fea_len))

Original Atomic feature length : 92
Neighbor feature length : 41


In [67]:
model = CrystalGraphConvNet(orig_atom_fea_len = orig_atom_fea_len,
                            nbr_fea_len = nbr_fea_len,
                            atom_fea_len = args.atom_fea_len,
                            n_conv = args.n_conv,
                            h_fea_len = args.h_fea_len,
                            n_h = args.n_h,
                            classification = (args.task == 'classification'))
print(model)

CrystalGraphConvNet(
  (embedding): Linear(in_features=92, out_features=64, bias=True)
  (convs): ModuleList(
    (0): ConvLayer(
      (fc_full): Linear(in_features=169, out_features=128, bias=True)
      (sigmoid): Sigmoid()
      (softplus1): Softplus(beta=1, threshold=20)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (softplus2): Softplus(beta=1, threshold=20)
    )
    (1): ConvLayer(
      (fc_full): Linear(in_features=169, out_features=128, bias=True)
      (sigmoid): Sigmoid()
      (softplus1): Softplus(beta=1, threshold=20)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (softplus2): Softplus(beta=1, threshold=20)
    )
    (2): ConvLayer(
      (fc_full): Linear(in_features=169, out_featu
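
The layer sizes in the printout above follow from the feature lengths reported earlier. Below is a small arithmetic check. It assumes that each convolution layer concatenates the central-atom features, the neighbor-atom features, and the bond features before `fc_full`, and that `fc_full` produces two halves of size `atom_fea_len`. The variable names are illustrative.

```python
orig_atom_fea_len = 92   # input width of the embedding layer (from the printout)
atom_fea_len = 64        # output width of the embedding layer (from the printout)
nbr_fea_len = 41         # neighbor (bond) feature length reported by the notebook

# Assumed input of fc_full: [central atom | neighbor atom | bond] features
fc_full_in = 2 * atom_fea_len + nbr_fea_len
# Assumed output of fc_full: two atom_fea_len-sized halves
fc_full_out = 2 * atom_fea_len

print(fc_full_in, fc_full_out)  # 169 128
```

Both values match `Linear(in_features=169, out_features=128)` in the printed model, and the widths 128 and 64 match the two `BatchNorm1d` layers.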

In [68]:
if args.cuda:    # GPU 
    model.cuda() # model -> cuda device

In [69]:
# Define loss function and optimizer
if args.task == 'classification':
    criterion = nn.NLLLoss() # Negative Log Likelihood Loss
else:
    criterion = nn.MSELoss() # Mean Squared Error (Squared L2 norm)

if args.optim == 'SGD':
    optimizer = optim.SGD(model.parameters(), args.lr,
                          momentum = args.momentum,
                          weight_decay = args.weight_decay)
elif args.optim == 'Adam':
    optimizer = optim.Adam(model.parameters(), args.lr,
                           weight_decay = args.weight_decay)
else: # Other optimizers (e.g. torch.optim.LBFGS) could be added here
    raise NameError('Only SGD or Adam is allowed as --optim')

print('Criterion : {}'.format(criterion))
print('Optimizer : {}'.format(optimizer))

Criterion : MSELoss()
Optimizer : Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.01
    weight_decay: 0
)


In [70]:
# ========= Optionally resume from a checkpoint ==========
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_mae_error = checkpoint['best_mae_error']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
    else:
        print("=> No checkpoint found at '{}'".format(args.resume))
# ========================================================

In [71]:
# Learning rate scheduler
scheduler = MultiStepLR(optimizer, milestones = args.lr_milestones, gamma = 0.1)
print(scheduler)

<torch.optim.lr_scheduler.MultiStepLR object at 0x7fbd83f66090>
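
`MultiStepLR` multiplies the learning rate by `gamma` each time the epoch counter reaches a milestone. The closed form can be sketched without PyTorch. The milestone list `[100]` below is a hypothetical value, since `args.lr_milestones` is not printed in this section, while the base rate 0.01 is taken from the optimizer printout.

```python
from bisect import bisect_right

def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate in effect at `epoch` for a step schedule with the given milestones."""
    passed = bisect_right(sorted(milestones), epoch)  # milestones reached so far
    return base_lr * gamma ** passed

base_lr = 0.01            # from the optimizer printout
milestones = [100]        # hypothetical; stands in for args.lr_milestones
gamma = 0.1

for epoch in (0, 29, 99, 100, 150):
    print(epoch, multistep_lr(base_lr, milestones, gamma, epoch))
```

Under that assumed milestone, a 30-epoch run would train at 0.01 throughout, because no milestone is reached.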


In [72]:
from tqdm import tqdm

for epoch in tqdm(range(args.start_epoch, args.epochs)):

    # Train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, normalizer)

    # Evaluate on validation set
    mae_error = validate(val_loader, model, criterion, normalizer, test = False)

    # ========== Check whether the mae_error is Nan or not  ==========
    if mae_error != mae_error: # NaN is the only value that is not equal to itself, so this is True only for NaN
        print('Exit due to NaN')
        sys.exit(1)
    # ================================================================

    scheduler.step() # Decay the learning rate at the configured milestones

    # Remember the best mae_error and save checkpoint
    if args.task == 'regression':
        is_best = mae_error < best_mae_error
        best_mae_error = min(mae_error, best_mae_error)
    else:
        is_best = mae_error > best_mae_error
        best_mae_error = max(mae_error, best_mae_error) # For classification, validate() returns the AUC (higher is better), so the maximum is kept despite the variable name
        
    save_checkpoint({'epoch' : epoch + 1,
                     'state_dict' : model.state_dict(),
                     'best_mae_error' : best_mae_error,
                     'optimizer' : optimizer.state_dict(),
                     'normalizer' : normalizer.state_dict(),
                     'args' : vars(args)},
                     is_best)

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch: [0][0/1]	Time 0.177 (0.177)	Data 0.020 (0.020)	Loss 1.2355 (1.2355)	MAE 2.922 (2.922)
 * MAE 0.000


 17%|█▋        | 5/30 [00:01<00:05,  4.84it/s]

Epoch: [1][0/1]	Time 0.013 (0.013)	Data 0.001 (0.001)	Loss 14.6273 (14.6273)	MAE 11.290 (11.290)
 * MAE 0.000
Epoch: [2][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 1.0136 (1.0136)	MAE 2.640 (2.640)
 * MAE 0.000
Epoch: [3][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 2.0168 (2.0168)	MAE 4.082 (4.082)
 * MAE 0.000
Epoch: [4][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 3.8110 (3.8110)	MAE 5.749 (5.749)
 * MAE 0.000
Epoch: [5][0/1]	Time 0.014 (0.014)	Data 0.001 (0.001)	Loss 2.5025 (2.5025)	MAE 4.516 (4.516)
 * MAE 0.000
Epoch: [6][0/1]	Time 0.008 (0.008)	Data 0.001 (0.001)	Loss 1.2619 (1.2619)	MAE 2.913 (2.913)
 * MAE 0.000
Epoch: [7][0/1]	Time 0.011 (0.011)	Data 0.001 (0.001)	Loss 0.6835 (0.6835)	MAE 1.953 (1.953)
 * MAE 0.000
Epoch: [8][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 0.4798 (0.4798)	MAE 1.779 (1.779)
 * MAE 0.000


 40%|████      | 12/30 [00:01<00:01, 11.83it/s]

Epoch: [9][0/1]	Time 0.011 (0.011)	Data 0.002 (0.002)	Loss 0.4237 (0.4237)	MAE 1.806 (1.806)
 * MAE 0.000
Epoch: [10][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 0.4133 (0.4133)	MAE 1.810 (1.810)
 * MAE 0.000
Epoch: [11][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 0.4140 (0.4140)	MAE 1.797 (1.797)
 * MAE 0.000
Epoch: [12][0/1]	Time 0.012 (0.012)	Data 0.003 (0.003)	Loss 0.4158 (0.4158)	MAE 1.772 (1.772)
 * MAE 0.000
Epoch: [13][0/1]	Time 0.008 (0.008)	Data 0.001 (0.001)	Loss 0.4162 (0.4162)	MAE 1.737 (1.737)
 * MAE 0.000
Epoch: [14][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 0.4154 (0.4154)	MAE 1.696 (1.696)
 * MAE 0.000


 67%|██████▋   | 20/30 [00:01<00:00, 20.62it/s]

Epoch: [15][0/1]	Time 0.011 (0.011)	Data 0.001 (0.001)	Loss 0.4137 (0.4137)	MAE 1.652 (1.652)
 * MAE 0.000
Epoch: [16][0/1]	Time 0.011 (0.011)	Data 0.001 (0.001)	Loss 0.4107 (0.4107)	MAE 1.609 (1.609)
 * MAE 0.000
Epoch: [17][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 0.4056 (0.4056)	MAE 1.568 (1.568)
 * MAE 0.000
Epoch: [18][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 0.3979 (0.3979)	MAE 1.532 (1.532)
 * MAE 0.000
Epoch: [19][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 0.3876 (0.3876)	MAE 1.503 (1.503)
 * MAE 0.000
Epoch: [20][0/1]	Time 0.012 (0.012)	Data 0.003 (0.003)	Loss 0.3749 (0.3749)	MAE 1.480 (1.480)
 * MAE 0.000
Epoch: [21][0/1]	Time 0.010 (0.010)	Data 0.000 (0.000)	Loss 0.3607 (0.3607)	MAE 1.458 (1.458)
 * MAE 0.000
Epoch: [22][0/1]	Time 0.008 (0.008)	Data 0.001 (0.001)	Loss 0.3470 (0.3470)	MAE 1.436 (1.436)
 * MAE 0.000


100%|██████████| 30/30 [00:02<00:00, 14.42it/s]

Epoch: [23][0/1]	Time 0.008 (0.008)	Data 0.001 (0.001)	Loss 0.3348 (0.3348)	MAE 1.414 (1.414)
 * MAE 0.000
Epoch: [24][0/1]	Time 0.010 (0.010)	Data 0.001 (0.001)	Loss 0.3225 (0.3225)	MAE 1.384 (1.384)
 * MAE 0.000
Epoch: [25][0/1]	Time 0.008 (0.008)	Data 0.001 (0.001)	Loss 0.3100 (0.3100)	MAE 1.347 (1.347)
 * MAE 0.000
Epoch: [26][0/1]	Time 0.010 (0.010)	Data 0.001 (0.001)	Loss 0.2965 (0.2965)	MAE 1.304 (1.304)
 * MAE 0.000
Epoch: [27][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 0.2827 (0.2827)	MAE 1.254 (1.254)
 * MAE 0.000
Epoch: [28][0/1]	Time 0.011 (0.011)	Data 0.003 (0.003)	Loss 0.2700 (0.2700)	MAE 1.204 (1.204)
 * MAE 0.000
Epoch: [29][0/1]	Time 0.009 (0.009)	Data 0.001 (0.001)	Loss 0.2572 (0.2572)	MAE 1.162 (1.162)
 * MAE 0.000
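
The NaN guard in the training loop relies on the IEEE 754 rule that NaN compares unequal to everything, including itself. A short check using only the standard library shows that `x != x` and `math.isnan(x)` agree on plain floats:

```python
import math

def is_nan_by_self_compare(x):
    """True only for NaN, because NaN != NaN under IEEE 754."""
    return x != x

for value in (0.0, 1.162, float('inf'), float('nan')):
    assert is_nan_by_self_compare(value) == math.isnan(value)
    print(value, is_nan_by_self_compare(value))
```

`math.isnan` states the intent more directly, while the self-comparison needs no import.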





In [73]:
ls -l # checkpoint.pth.tar model was generated!!

total 5948
drwx------ 2 root root    4096 Jan 20  2021 cgcnn/
-rw------- 1 root root 1006068 Mar  4  2021 checkpoint2.pth.tar
-rw------- 1 root root 1006068 Mar  9  2021 checkpoint_46744_energy_per_atom.pth.tar
-rw------- 1 root root 1003089 Jan 14 09:42 checkpoint.pth.tar
drwx------ 5 root root    4096 Jan 20  2021 data/
-rw------- 1 root root    1065 Jan 20  2021 LICENSE
-rw------- 1 root root   20707 Jan 20  2021 main.py
-rw------- 1 root root 1006068 Mar  4  2021 model_best2.pth.tar
-rw------- 1 root root 1006068 Mar  9  2021 model_best_46744_energy_per_atom.pth.tar
-rw------- 1 root root 1003089 Jan 14 09:42 model_best.pth.tar
-rw------- 1 root root   11219 Jan 20  2021 predict.py
drwx------ 2 root root    4096 Jan 20  2021 pre-trained/
-rw------- 1 root root    8225 Jan 20  2021 README.md
drwx------ 2 root root    4096 Mar  9  2021 SrTiO3_folder/
-rw------- 1 root root       0 Jan  5 07:11 test_results.csv


In [74]:
# Test best model
print('---------- Evaluate Model on Test Set ----------')
best_checkpoint = torch.load('model_best.pth.tar')
model.load_state_dict(best_checkpoint['state_dict'])
validate(test_loader, model, criterion, normalizer, test = True)

---------- Evaluate Model on Test Set ----------
Test : [0/1]	Time 0.005 (0.005)	Loss 362.1590 (362.1590)	MAE 57.568 (57.568)
 ** MAE 57.568


tensor(57.5677)
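
Every validation line in the training log reads `MAE 0.000`, consistent with the `len(val_loader)` value of 0 printed earlier. The sketch below replays only the best-model bookkeeping from the training loop with a constant metric of 0.0, to show which epochs would be flagged as best. It assumes `save_checkpoint` copies the checkpoint to `model_best.pth.tar` only when `is_best` is true, which is not shown in this section.

```python
best_mae_error = 1e10          # initial value printed by the notebook
val_metrics = [0.0] * 30       # constant validation MAE, as in the log above

best_epochs = []
for epoch, mae_error in enumerate(val_metrics):
    is_best = mae_error < best_mae_error      # same comparison as the regression branch
    best_mae_error = min(mae_error, best_mae_error)
    if is_best:
        best_epochs.append(epoch)

print(best_epochs)  # [0]
```

Under that assumption only the first epoch is ever marked as best, so the test evaluation would load the weights saved after epoch 0 rather than the final ones, which may contribute to the large test MAE.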

# Conclusion

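* The original CGCNN code was refactored for the Jupyter Notebook environment, and regression training on the sample data was confirmed to run end to end.
* Because the model was trained on only 10 sample structures (with an empty validation loader in this run), the very large test MAE (57.568) is expected.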

