## GATCNN Implementation in PyG
reference : [github](https://github.com/superlouis/GATGNN/tree/master/gatgnn)

In [1]:
from dataloader_v2 import *
from gatgnn import *


from sklearn.model_selection import train_test_split
import random
import torch
import numpy as np
import pandas as pd
import os
import shutil
import argparse

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error as sk_MAE
from tabulate import tabulate
import time


In [2]:
# SETTING UP CODE TO RUN ON GPU
gpu_id = 0
device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')


In [3]:
# DATALOADER/ TARGET NORMALIZATION
src_CIF = 'CIF-DATA'
random_num = 456
random.seed(random_num)
dataset = pd.read_csv(
    '../data/dichalcogenides_public/targets.csv').sample(frac=1, random_state=random_num)
NORMALIZER = DATA_normalizer(dataset['band_gap'].values)
RSM = {'radius': 4, 'step': 0.5, 'max_num_nbr': 16}


***

utils

In [4]:
def set_model_properties(crystal_property):
    if crystal_property in ['poisson-ratio', 'band-gap', 'absolute-energy', 'fermi-energy', 'formation-energy', 'new-property']:
        norm_action = None
        classification = None
    elif crystal_property == 'is_metal':
        norm_action = 'classification-1'
        classification = 1
    elif crystal_property == 'is_not_metal':
        norm_action = 'classification-0'
        classification = 1
    else:
        norm_action = 'log'
        classification = None
    return norm_action, classification


In [5]:
def torch_MAE(tensor1, tensor2):
    return torch.mean(torch.abs(tensor1-tensor2))


def torch_accuracy(pred_tensor, true_tensor):
    _, pred_tensor = torch.max(pred_tensor, dim=1)
    correct = (pred_tensor == true_tensor).sum().float()
    total = pred_tensor.size(0)
    accuracy_ans = correct/total
    return accuracy_ans


In [6]:
def output_training(metrics_obj, epoch, estop_val, extra='---'):
    header_1, header_2 = 'MSE | e-stop', 'MAE | TIME'
    if metrics_obj.c_property in ['is_metal', 'is_not_metal']:
        header_1, header_2 = 'Cross_E | e-stop', 'Accuracy | TIME'

    train_1, train_2 = metrics_obj.training_loss1[epoch], metrics_obj.training_loss2[epoch]
    valid_1, valid_2 = metrics_obj.valid_loss1[epoch], metrics_obj.valid_loss2[epoch]

    tab_val = [['TRAINING', f'{train_1:.4f}', f'{train_2:.4f}'], [
        'VALIDATION', f'{valid_1:.4f}', f'{valid_2:.4f}'], ['E-STOPPING', f'{estop_val}', f'{extra}']]

    output = tabulate(tab_val, headers=[
                      f'EPOCH # {epoch}', header_1, header_2], tablefmt='fancy_grid')
    print(output)
    return output


In [7]:
def load_metrics():
    saved_metrics = pickle.load(open("MODELS/metrics_.pickle", "rb", -1))
    return saved_metrics


In [8]:
# MOdel Training Early Stopping Function
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, increment=0.001, save_best=True, classification=None):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
        """
        self.classification = classification
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf

        self.increment = increment
        self.flag_value = f' *** '
        self.FLAG = None
        self.save_best = save_best

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score <= self.best_score + self.increment:
            if self.classification == None:
                self.increase_measure(val_loss, model, score)
            else:
                self.decrease_measure(val_loss, model, score)
        elif score > self.best_score + self.increment:
            if self.classification == None:
                self.decrease_measure(val_loss, model, score)
            else:
                self.increase_measure(val_loss, model, score)

    def increase_measure(self, val_loss, model, score):
        self.counter += 1
        self.flag_value = f'> {self.counter} / {self.patience}'
        self.FLAG = True
        if self.save_best == False:
            self.save_checkpoint(val_loss, model)
        if self.counter >= self.patience:
            self.early_stop = True

    def decrease_measure(self, val_loss, model, score):
        self.best_score = score
        self.save_checkpoint(val_loss, model)
        self.counter = 0
        self.flag_value = f' *** '

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            pass
        torch.save(model.state_dict(), '../tmp/models/crystal-checkpoint.pt')
        self.val_loss_min = val_loss
        self.FLAG = False


***

In [9]:

CRYSTAL_DATA = CIF_Dataset(
    dataset, root_dir='../data/dichalcogenides_public/cifs/', **RSM)
idx_list = list(range(len(dataset)))
random.shuffle(idx_list)

In [5]:
private_dataset = pd.read_csv('../data/dichalcogenides_private/targets.csv')
private_data = CIF_Dataset(
    private_dataset, root_dir='../data/dichalcogenides_private/cifs/', **RSM)
test_idx = list(range(len(private_dataset)))

In [10]:

train_idx, val_idx = train_test_split(
    idx_list, train_size=0.8, random_state=random_num)


In [6]:

norm_action, classification = set_model_properties('band_gap')


In [19]:

training_set = CIF_Lister(train_idx, CRYSTAL_DATA,
                          NORMALIZER, norm_action, df=dataset, src='MEGNET')
validation_set = CIF_Lister(val_idx, CRYSTAL_DATA,
                            NORMALIZER, norm_action,  df=dataset, src='MEGNET')


In [7]:
test_set = CIF_Lister(test_idx, private_data,
                            NORMALIZER, norm_action,  df=private_dataset, src='MEGNET')

In [13]:
training_set[0]


Data(x=[191, 92], edge_index=[2, 3056], edge_attr=[3056, 9], y=[1], global_feature=[1, 103], cluster=[191], num_atoms=[1], coords=[191, 3], the_idx=[1])

***

In [8]:
# === Model Configs ===
n_heads = 4
number_neurons = 64
number_layers = 3
xtra_l = True
global_att = 'composit'  # ['composit', 'cluster']
attention_technique = 'learnable'  # ['fixed', 'random', 'learnable']
concat_comp = True
data_src = 'MEGNET'  # ['CGCNN','MEGNET','NEW']

# ====================
# === Model Training Configs ===
learning_rate = 5e-3
milestones = [150, 250]
stop_patience = 150
crystal_property = 'band_gap'
num_epochs = 200
train_param = {'batch_size': 32, 'shuffle': True}
valid_param = {'batch_size': 32, 'shuffle': True}


In [24]:

# NEURAL-NETWORK
the_network = GATGNN(n_heads, classification, neurons=number_neurons, nl=number_layers, xtra_layers=xtra_l, global_attention=global_att,
                     unpooling_technique=attention_technique, concat_comp=concat_comp, edge_format=data_src)
net = the_network.to(device)


In [16]:
# LOSS & OPTMIZER & SCHEDULER
if classification == 1:
    criterion = nn.CrossEntropyLoss().cuda()
    funct = torch_accuracy
else:
    criterion = nn.SmoothL1Loss().cuda()
    funct = torch_MAE

optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=1e-1)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.3)

# EARLY-STOPPING INITIALIZATION
early_stopping = EarlyStopping(patience=stop_patience, increment=1e-6,
                               verbose=True, save_best=True, classification=classification)

# METRICS-OBJECT INITIALIZATION
metrics = METRICS(crystal_property, num_epochs, criterion, funct, device)


In [None]:
print(f'> TRAINING MODEL ...')
train_loader = torch_DataLoader(dataset=training_set,   **train_param)
valid_loader = torch_DataLoader(dataset=validation_set, **valid_param)
for epoch in range(num_epochs):
    # TRAINING-STAGE
    net.train()
    start_time = time.time()
    for data in train_loader:
        data = data.to(device)
        predictions = net(data)
        train_label = metrics.set_label('training', data)
        loss = metrics('training', predictions, train_label, 1)
        _ = metrics('training', predictions, train_label, 2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metrics.training_counter += 1
    metrics.reset_parameters('training', epoch)
    # VALIDATION-PHASE
    net.eval()
    for data in valid_loader:
        data = data.to(device)
        with torch.no_grad():
            predictions = net(data)
        valid_label = metrics.set_label('validation', data)
        _ = metrics('validation', predictions, valid_label, 1)
        _ = metrics('validation', predictions, valid_label, 2)

        metrics.valid_counter += 1

    metrics.reset_parameters('validation', epoch)
    scheduler.step()
    end_time = time.time()
    e_time = end_time-start_time
    metrics.save_time(e_time)

    # EARLY-STOPPING
    early_stopping(metrics.valid_loss2[epoch], net)
    flag_value = early_stopping.flag_value + \
        '_'*(22-len(early_stopping.flag_value))
    if early_stopping.FLAG == True:
        estop_val = flag_value
    else:
        estop_val = '@best: saving model...'
        best_epoch = epoch+1
    output_training(metrics, epoch, estop_val, f'{e_time:.1f} sec.')

    if early_stopping.early_stop:
        print("> Early stopping")
        break


***

Load & Predict for private dataset

In [9]:
pred_model = GATGNN(n_heads, classification, neurons=number_neurons, nl=number_layers, xtra_layers=xtra_l, global_attention=global_att,
 unpooling_technique=attention_technique, concat_comp=concat_comp, edge_format=data_src)
pred_model.load_state_dict(torch.load('../tmp/models/crystal-checkpoint.pt'))
pred_model = pred_model.to(device)
pred_model.eval()


GATGNN(
  (embed_n): Linear(in_features=92, out_features=64, bias=True)
  (embed_e): Linear(in_features=9, out_features=64, bias=True)
  (embed_comp): Linear(in_features=103, out_features=64, bias=True)
  (node_att): ModuleList(
    (0): GAT_Crystal()
    (1): GAT_Crystal()
    (2): GAT_Crystal()
  )
  (batch_norm): ModuleList(
    (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (cluster_att): CLUSTER_Attention(
    (learn_unpool): Linear(in_features=131, out_features=3, bias=True)
    (layer_1): Linear(in_features=64, out_features=64, bias=True)
    (atten_layer): Linear(in_features=64, out_features=1, bias=True)
  )
  (comp_atten): COMPOSITION_Attention(
    (node_layer1): Linear(in_features=167, out_features=32, bias=True)
    (atten_layer): Linear(in_features=32, o

In [11]:
test_param = {'batch_size': 32, 'shuffle': True}
test_loader = torch_DataLoader(dataset=test_set, **test_param)
pred_arr = []
for data in test_loader:
    data = data.to(device)
    with torch.no_grad():
        bt_preds = pred_model(data)
    pred_arr.extend(list(bt_preds.cpu().numpy()))

In [16]:
finalized_preds = NORMALIZER.denorm(torch.Tensor(pred_arr)).numpy()
finalized_preds

array([0.6117314 , 0.9379415 , 0.9002631 , ..., 0.6145652 , 0.63429594,
       0.8844879 ], dtype=float32)

In [17]:
submission_df = pd.DataFrame({'id': private_dataset['_id'].values, 'predictions': finalized_preds})
submission_df.head()

Unnamed: 0,id,predictions
0,6141cf0efbfd4bd9ab2c2f7e,0.611731
1,6141cf0fe689ecc4c43cdd4b,0.937941
2,6141cf10b842c2e72e2f2d44,0.900263
3,6141cf10b842c2e72e2f2d46,0.614716
4,6141cf1302d926221cabc549,0.614632


In [18]:
submission_df.to_csv('../data/sample_submission1.csv', index=False)

***