In [37]:
import argparse
import csv
import os
import shutil
import sys
import time

import numpy as np
import torch
import torch.nn as nn
from sklearn import metrics
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

from cgcnn.data import CIFData
from cgcnn.data import collate_pool
from cgcnn.model import CrystalGraphConvNet

In [27]:
# Inference configuration. Relative paths resolve against the notebook's
# working directory.
args = argparse.Namespace()

args.modelpath = './checkpoint.pth.tar'  # trained checkpoint loaded by main()
args.cifpath = '../data/dichalcogenides_private/cifs'  # root passed to CIFData
args.batch_size = 32
args.workers = 1  # DataLoader worker processes
# Use the GPU only when one is actually present; a hard-coded True makes every
# `.cuda()` call fail on CPU-only machines.
args.cuda = torch.cuda.is_available()
args.print_freq = 10  # NOTE(review): not read by any cell visible here

# NOTE(review): not read by any cell visible here; kept for compatibility.
best_mae_error = 1e10

In [3]:
class AverageMeter(object):
    """Track the most recent value together with a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic (last value, mean, sum and count)."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations.

        ``val`` becomes the latest value; the running sum is weighted by ``n``
        and the mean is recomputed from it.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


In [4]:
def class_eval(prediction, target):
    """Binary-classification metrics computed from model outputs.

    ``prediction`` is exponentiated before use, so it is treated as
    per-class log-probabilities with one column per class. ``target`` holds
    the true labels and is squeezed to 1-D. Only the two-column (binary)
    case is supported.

    Returns
    -------
    tuple
        ``(accuracy, precision, recall, fscore, auc_score)``; the AUC is
        computed from the probability of the second column.

    Raises
    ------
    NotImplementedError
        If ``prediction`` does not have exactly two columns.
    """
    probs = np.exp(prediction.numpy())
    target_arr = target.numpy()
    labels_pred = np.argmax(probs, axis=1)
    labels_true = np.squeeze(target_arr)
    if probs.shape[1] != 2:
        raise NotImplementedError

    precision, recall, fscore, _ = metrics.precision_recall_fscore_support(
        labels_true, labels_pred, average='binary')
    auc_score = metrics.roc_auc_score(labels_true, probs[:, 1])
    accuracy = metrics.accuracy_score(labels_true, labels_pred)
    return accuracy, precision, recall, fscore, auc_score


In [5]:

def mae(prediction, target):
    """
    Mean absolute error between ``prediction`` and ``target``.

    Parameters
    ----------

    prediction: torch.Tensor (N, 1)
    target: torch.Tensor (N, 1)

    Returns
    -------
    torch.Tensor
        A 0-d tensor holding the mean of ``|target - prediction|`` over all
        elements.
    """
    abs_err = (target - prediction).abs()
    return abs_err.mean()


In [6]:
class Normalizer(object):
    """Standardise tensors with a stored mean/std and invert the transform."""

    def __init__(self, tensor):
        """Take the mean and standard deviation of the sample ``tensor``.

        ``torch.std`` is used with its default (unbiased) estimator.
        """
        self.mean, self.std = torch.mean(tensor), torch.std(tensor)

    def norm(self, tensor):
        """Map ``tensor`` into normalised units: ``(x - mean) / std``."""
        centred = tensor - self.mean
        return centred / self.std

    def denorm(self, normed_tensor):
        """Undo :meth:`norm`: ``x * std + mean``."""
        scaled = normed_tensor * self.std
        return scaled + self.mean

    def state_dict(self):
        """Snapshot of the statistics, keyed ``'mean'`` and ``'std'``."""
        return dict(mean=self.mean, std=self.std)

    def load_state_dict(self, state_dict):
        """Restore statistics previously produced by :meth:`state_dict`."""
        self.mean, self.std = state_dict['mean'], state_dict['std']


In [33]:
def validate(val_loader, model, criterion, normalizer, test=False,
             output_path='test_results.csv'):
    """Run ``model`` over ``val_loader`` in eval mode, optionally saving predictions.

    Parameters
    ----------
    val_loader : iterable
        Yields ``(inp, target, batch_cif_ids)`` batches; ``inp`` is a 4-tuple
        whose last element is a list of tensors (as built by ``main``).
    model : torch.nn.Module
        Called as ``model(*input_var)``.
    criterion : callable
        Loss applied to ``(output, target_var)``.
    normalizer : Normalizer
        Normalises regression targets and de-normalises regression outputs.
    test : bool
        When True, collect per-structure targets/predictions and write rows
        ``(cif_id, target, prediction)`` to ``output_path`` as CSV.
    output_path : str
        CSV destination used when ``test`` is True.

    Returns
    -------
    float
        For regression, the sample-weighted mean absolute error in original
        target units; for classification, the sample-weighted mean loss.

    Notes
    -----
    Reads the notebook globals ``args.cuda`` and ``model_args.task``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    mae_errors = AverageMeter()
    is_regression = model_args.task == 'regression'

    if test:
        test_targets = []
        test_preds = []
        test_cif_ids = []

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # Disable autograd for the whole batch, *including* the forward pass, so no
    # computation graph is built during inference.
    with torch.no_grad():
        for inp, target, batch_cif_ids in tqdm(val_loader):
            if args.cuda:
                input_var = (inp[0].cuda(non_blocking=True),
                             inp[1].cuda(non_blocking=True),
                             inp[2].cuda(non_blocking=True),
                             [crys_idx.cuda(non_blocking=True) for crys_idx in inp[3]])
            else:
                input_var = (inp[0], inp[1], inp[2], inp[3])

            if is_regression:
                target_normed = normalizer.norm(target)
            else:
                target_normed = target.view(-1).long()
            target_var = (target_normed.cuda(non_blocking=True)
                          if args.cuda else target_normed)

            # compute output
            output = model(*input_var)
            loss = criterion(output, target_var)
            losses.update(loss.item(), target.size(0))

            output_cpu = output.cpu()
            if is_regression:
                # Outputs live in normalised units; map them back before use.
                test_pred = normalizer.denorm(output_cpu)
                mae_errors.update(mae(test_pred, target).item(), target.size(0))
            else:
                # Outputs are exponentiated as log-probabilities (main pairs this
                # task with NLLLoss); keep one probability per structure.
                # NOTE(review): assumes a binary model, column 1 = positive class.
                test_pred = torch.exp(output_cpu)[:, 1]

            if test:
                test_preds += test_pred.view(-1).tolist()
                test_targets += target.view(-1).tolist()
                test_cif_ids += batch_cif_ids

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    if test:
        # newline='' is required by the csv module to avoid blank rows on Windows.
        with open(output_path, 'w', newline='') as f:
            writer = csv.writer(f)
            for cif_id, true_value, pred in zip(test_cif_ids, test_targets,
                                                test_preds):
                writer.writerow((cif_id, true_value, pred))

    return mae_errors.avg if is_regression else losses.avg
    

In [34]:
# Architecture hyper-parameters handed to CrystalGraphConvNet in main().
# NOTE(review): these presumably must match the values used to train the
# checkpoint, otherwise load_state_dict will reject the weights.
# `task` selects the loss and target handling ('regression' or 'classification').
model_args = argparse.Namespace(
    atom_fea_len=64,
    n_conv=3,
    h_fea_len=128,
    n_h=1,
    task='regression',
)

In [35]:
def main():
    """Run CGCNN inference over the structures under ``args.cifpath``.

    Builds a ``CrystalGraphConvNet`` from the notebook global ``model_args``,
    restores its weights and target-normalisation statistics from
    ``args.modelpath``, and hands everything to ``validate(..., test=True)``,
    which writes the per-structure predictions to CSV.

    Raises
    ------
    FileNotFoundError
        If ``args.modelpath`` does not exist. Predicting without it would use
        random weights and a zero-std normalizer (division by zero), so the
        run is refused instead of silently producing garbage.
    """
    # Check the checkpoint first: it is cheap, dataset loading may not be.
    if not os.path.isfile(args.modelpath):
        raise FileNotFoundError(
            "=> no model found at '{}'".format(args.modelpath))

    # load data; no shuffling so the prediction CSV has a reproducible order
    dataset = CIFData(args.cifpath)
    test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.workers, collate_fn=collate_pool,
                             pin_memory=args.cuda)

    # build model; input feature widths are read off the first sample
    structures, _, _ = dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    is_classification = model_args.task == 'classification'
    model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                atom_fea_len=model_args.atom_fea_len,
                                n_conv=model_args.n_conv,
                                h_fea_len=model_args.h_fea_len,
                                n_h=model_args.n_h,
                                classification=is_classification)
    if args.cuda:
        model.cuda()

    # Inference only, so no optimizer; the criterion is passed to validate().
    criterion = nn.NLLLoss() if is_classification else nn.MSELoss()

    # Placeholder statistics, replaced by the checkpoint's values below.
    normalizer = Normalizer(torch.zeros(3))

    print("=> loading model '{}'".format(args.modelpath))
    # Returning `storage` unchanged loads every tensor onto the CPU, so
    # GPU-saved checkpoints also open on CPU-only machines.
    checkpoint = torch.load(args.modelpath,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    normalizer.load_state_dict(checkpoint['normalizer'])
    print("=> loaded model '{}' (epoch {}, validation {})"
          .format(args.modelpath, checkpoint['epoch'],
                  checkpoint['best_mae_error']))

    validate(test_loader, model, criterion, normalizer, test=True)


In [38]:
# Run inference end to end: load the data and checkpoint, then write predictions.
main()

=> loading model './checkpoint.pth.tar'
=> loaded model './checkpoint.pth.tar' (epoch 10, validation 0.07258342951536179)


100%|██████████| 93/93 [05:49<00:00,  3.75s/it]
