In [1]:
import json
from collections import Counter
from itertools import islice
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from megnet.data.crystal import CrystalGraph
from megnet.data.graph import GaussianDistance
from pymatgen.core import Structure

In [2]:
# Training targets table: one row per structure id (`_id`) with its `band_gap`.
train_dir = Path('data') / 'dichalcogenides_public'
targets_csv = train_dir / 'targets.csv'
train = pd.read_csv(targets_csv)

In [3]:
# Peek at the first rows of the targets table.
train.head(n=5)

Unnamed: 0,_id,band_gap
0,6141cf0f51c1cbd9654b8870,1.0843
1,6141cf1051c1cbd9654b8872,1.1102
2,6141cf11cc0e69a0cf28ab35,1.1484
3,6141cf11b842c2e72e2f2d48,1.8068
4,6141cf11ae4fb853db2e3f14,0.36


In [4]:
def read_pymatgen_dict(file):
    """Load a pymatgen ``Structure`` from a JSON file.

    Parameters
    ----------
    file : str or os.PathLike
        Path to a JSON file holding a dict accepted by ``Structure.from_dict``.

    Returns
    -------
    pymatgen.core.Structure
        The structure rebuilt from the decoded dict.
    """
    # JSON is UTF-8 by spec; don't let the platform's locale encoding decide.
    with open(file, "r", encoding="utf-8") as f:
        d = json.load(f)
    return Structure.from_dict(d)

In [13]:
# Load the first training structure as a worked example.
first_id = train['_id'].iloc[0]
structure = read_pymatgen_dict(train_dir / 'structures' / f'{first_id}.json')
structure

Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
PeriodicSite: Mo (0.0000, 1.8419, 3.7198) [0.0417, 0.0833, 0.2500]
PeriodicSite: Mo (-1.5952, 4.6048, 3.7198) [0.0417, 0.2083, 0.2500]
PeriodicSite: Mo (-3.1903, 7.3677, 3.7198) [0.0417, 0.3333, 0.2500]
PeriodicSite: Mo (-4.7855, 10.1306, 3.7198) [0.0417, 0.4583, 0.2500]
PeriodicSite: Mo (-6.3806, 12.8935, 3.7198) [0.0417, 0.5833, 0.2500]
PeriodicSite: Mo (-7.9758, 15.6564, 3.7198) [0.0417, 0.7083, 0.2500]
PeriodicSite: Mo (-9.5709, 18.4193, 3.7198) [0.0417, 0.8333, 0.2500]
PeriodicSite: Mo (-11.1661, 21.1822, 3.7198) [0.0417, 0.9583, 0.2500]
PeriodicSite: Mo (3.1903, 1.8419, 3.7198) [0.1667, 0.0833, 0.2500]
PeriodicSite: Mo (1.5952, 4.6048, 3.7198) [0.1667, 0.2083, 0.2500]
PeriodicSite: Mo (0.0000, 7.3677, 3

In [14]:
# Crystal-graph converter: neighbour cutoff of 4, bond distances expanded onto
# 100 Gaussian centres evenly spaced over [0, 5].
# NOTE(review): centres above 4 lie beyond the neighbour cutoff — confirm intended.
gaussian_centers = np.linspace(0, 5, 100)
converter = CrystalGraph(
    cutoff=4,
    bond_converter=GaussianDistance(gaussian_centers, 1),  # 2nd arg presumably the width — confirm
)
res = converter(structure)

In [24]:
# Distinct atom values in the example graph.
atom_numbers = res['atom']
np.unique(atom_numbers)

array([16, 42])

In [35]:
import torch
from torch_geometric.utils import to_networkx
from torch_geometric.data import Data
import networkx as nx

In [12]:
# d = Data(
#     x=torch.tensor(res['atom']).reshape(-1, 1),
#     edge_index=torch.tensor(np.vstack([res['index1'], res['index2']])),
#     edge_attr=torch.tensor(res['bond'])
# )
# g = to_networkx(d, to_undirected=True)
# nx.draw(g)

In [7]:
# Cached per-structure graphs; presumably built with cutoff=4 per the filename — confirm.
# NOTE: read_pickle can execute arbitrary code; only load trusted local caches.
data_cache_path = Path('data') / 'data_cache_cutoff4.pickle'
data = pd.read_pickle(data_cache_path)

In [5]:
from monty.serialization import loadfn

In [6]:
# Materials Project snapshot (2018.6.1) loaded via monty.
MP_JSON = 'data/mp.2018.6.1.json'
mp = loadfn(MP_JSON)

In [7]:
# Number of entries in the MP snapshot.
n_materials = len(mp)
n_materials

69239

In [40]:
# Element frequencies over (up to) the first N_TRAIN_SAMPLE training structures.
# Slicing (instead of range(N)) cannot raise IndexError when `train` is shorter,
# and the separate loop name keeps the example `structure` above intact.
N_TRAIN_SAMPLE = 1500
train_atom_counter = Counter()
for structure_id in train['_id'].values[:N_TRAIN_SAMPLE]:
    train_structure = read_pymatgen_dict(train_dir / f'structures/{structure_id}.json')
    train_atom_counter.update(int(z) for z in converter(train_structure)['atom'])
atom_count = dict(train_atom_counter)
atom_count

{16: 189032, 42: 94538, 34: 1611, 74: 789}

In [41]:
# Element frequencies over (up to) the first N_MP_SAMPLE Materials Project entries,
# for comparison with the training-set counts. islice avoids IndexError if mp is shorter.
N_MP_SAMPLE = 50000
mp_atom_counter = Counter()
for crystal in islice(mp, N_MP_SAMPLE):
    mp_atom_counter.update(crystal['graph']['atom'])
atom_count = dict(mp_atom_counter)

In [42]:
# Occurrences of S (16), Mo (42), Se (34) and W (74) in the most recent `atom_count`.
tuple(atom_count[z] for z in (16, 42, 34, 74))

(47903, 6818, 20575, 5928)

In [8]:
# Keep only MP crystals built from the dichalcogenide elements S (16), Se (34),
# Mo (42) and W (74), carrying along their graph and band gap.
ALLOWED_ATOMS = {16, 34, 42, 74}
# A crystal is dropped once it has this many atoms outside ALLOWED_ATOMS;
# 1 means "no foreign atoms at all". (Earlier idea: min(20, len(atoms) // 2).)
FOREIGN_ATOM_THRESHOLD = 1

matched_records = []
for crystal in mp:
    atoms = crystal['graph']['atom']
    n_foreign = sum(a not in ALLOWED_ATOMS for a in atoms)
    if n_foreign < FOREIGN_ATOM_THRESHOLD:
        matched_records.append({
            'graph': crystal['graph'],
            'band_gap': crystal['band_gap'],
        })
matched_dataset = pd.DataFrame(matched_records)

In [94]:
# matched_dataset.to_pickle('data/mp2018_matched.pickle')

In [9]:
# Number of MP crystals that passed the element filter.
matched_dataset.shape[0]

246

In [8]:
def make_meta_feature(graph, elements=(16, 42, 34, 74)):
    """Summarise one crystal graph as per-element atom fractions plus size counts.

    Parameters
    ----------
    graph : mapping
        Needs an ``'atom'`` entry (the per-atom values whose fractions are
        reported, e.g. atomic numbers) and a sized ``'bond'`` entry whose
        length is reported as ``link_count``.
    elements : iterable of int, optional
        Values that always get a ``ratio<Z>`` entry, defaulting to 0 when
        absent, so rows built from different graphs share these columns.
        Defaults to S (16), Mo (42), Se (34) and W (74). Any other value found
        in ``graph['atom']`` still gets its own ``ratio<Z>`` entry, appended
        after these.

    Returns
    -------
    pandas.Series
        The ``ratio<Z>`` entries, then ``atom_count`` and ``link_count``.
        Values share one dtype, so whenever a fraction is present the counts
        are upcast to float (e.g. ``191.0``).
    """
    record = {f'ratio{z}': 0 for z in elements}
    atoms = pd.Series(graph['atom'])
    # Named n_atoms so it is not confused with the notebook-level `atom_count` dict.
    n_atoms = len(atoms)
    fractions = (atoms.value_counts() / n_atoms).to_dict()
    for z, fraction in fractions.items():
        record[f'ratio{z}'] = fraction
    record['atom_count'] = n_atoms
    record['link_count'] = len(graph['bond'])
    return pd.Series(record)

In [9]:
# One row per cached structure with its band_gap target (NaN for test structures).
# how='left' preserves `data`'s row order: an outer merge sorts join keys
# lexicographically, which breaks any later positional/index alignment with `data`.
assert train['_id'].isin(data['id']).all(), 'some training ids are missing from the graph cache'
train_test = (
    data[['id']]
    .merge(train, left_on='id', right_on='_id', how='left', validate='many_to_one')
    .drop('_id', axis=1)
)
# many_to_one guarantees one output row per `data` row, so reuse data's index
# to keep index-aligned operations against `data` correct.
train_test.index = data.index
train_test

Unnamed: 0,id,band_gap
0,6141cf72cc0e69a0cf28ab65,1.1379
1,6142706cbaaf234b352906e2,1.1097
2,6141f8fd31cf3ef3d4a9f1ac,0.3205
3,6141e726baaf234b35290474,1.1176
4,6141d3be4e27a1844a5f0188,1.1467
...,...,...
5928,6142615131cf3ef3d4a9f4b8,
5929,6141f4354e27a1844a5f046c,
5930,6141e31a4e27a1844a5f0346,
5931,61429f0531cf3ef3d4a9f5b4,


In [10]:
# Per-structure meta features, tagged with their id so the join is key-based:
# a positional concat silently mismatches rows if train_test's order differs from data's.
meta_features = data['graph'].apply(make_meta_feature).assign(id=data['id'])
# Start from the base columns so re-running this cell doesn't duplicate feature columns.
train_test = train_test[['id', 'band_gap']].merge(
    meta_features, on='id', how='left', validate='many_to_one'
)
train_test

Unnamed: 0,id,band_gap,ratio16,ratio42,ratio34,ratio74,atom_count,link_count
0,6141cf72cc0e69a0cf28ab65,1.1379,0.659686,0.335079,0.005236,0.000000,191.0,2028.0
1,6142706cbaaf234b352906e2,1.1097,0.663158,0.331579,0.000000,0.005263,190.0,2008.0
2,6141f8fd31cf3ef3d4a9f1ac,0.3205,0.663158,0.331579,0.005263,0.000000,190.0,2004.0
3,6141e726baaf234b35290474,1.1176,0.663158,0.331579,0.000000,0.005263,190.0,2008.0
4,6141d3be4e27a1844a5f0188,1.1467,0.659686,0.329843,0.005236,0.005236,191.0,2028.0
...,...,...,...,...,...,...,...,...
5928,6142615131cf3ef3d4a9f4b8,,0.663158,0.331579,0.005263,0.000000,190.0,2004.0
5929,6141f4354e27a1844a5f046c,,0.659686,0.329843,0.005236,0.005236,191.0,2028.0
5930,6141e31a4e27a1844a5f0346,,0.659686,0.329843,0.005236,0.005236,191.0,2028.0
5931,61429f0531cf3ef3d4a9f5b4,,0.663158,0.331579,0.005263,0.000000,190.0,2004.0


In [11]:
# Rows with a known band_gap are training data; the rest form the test set.
has_target = train_test['band_gap'].notna()
train_meta = train_test[has_target]
test_meta = train_test[~has_target]
train_meta

Unnamed: 0,id,band_gap,ratio16,ratio42,ratio34,ratio74,atom_count,link_count
0,6141cf72cc0e69a0cf28ab65,1.1379,0.659686,0.335079,0.005236,0.000000,191.0,2028.0
1,6142706cbaaf234b352906e2,1.1097,0.663158,0.331579,0.000000,0.005263,190.0,2008.0
2,6141f8fd31cf3ef3d4a9f1ac,0.3205,0.663158,0.331579,0.005263,0.000000,190.0,2004.0
3,6141e726baaf234b35290474,1.1176,0.663158,0.331579,0.000000,0.005263,190.0,2008.0
4,6141d3be4e27a1844a5f0188,1.1467,0.659686,0.329843,0.005236,0.005236,191.0,2028.0
...,...,...,...,...,...,...,...,...
2961,6143b24631cf3ef3d4a9f78e,0.4186,0.663158,0.331579,0.005263,0.000000,190.0,2004.0
2962,614217f84e27a1844a5f05f2,1.1467,0.659686,0.329843,0.005236,0.005236,191.0,2028.0
2963,6141d829baaf234b35290372,1.1057,0.663158,0.331579,0.000000,0.005263,190.0,2008.0
2964,61420dc34e27a1844a5f0592,1.1455,0.659686,0.329843,0.005236,0.005236,191.0,2028.0


In [14]:
# train[['_id']].merge(train_meta, left_on='_id', right_on='id', how='left').drop('_id', axis=1).to_csv('data/train_meta.csv', index=False)
# test_meta.to_csv('data/test_meta.csv', index=False)

In [83]:
# Rows sharing every meta-feature below are indistinguishable to a model that
# only sees these columns; measure how much their band_gap targets disagree.
GROUP_COLS = ['ratio16', 'ratio42', 'ratio34', 'ratio74', 'atom_count', 'link_count']
MAX_BAND_GAP_SPREAD = 0.02  # same units as band_gap
APPLY_GROUP_MEAN = False    # True: map each noisy group to its mean band_gap

band_gap_by_group = train_meta.groupby(GROUP_COLS)['band_gap']
band_gap_spread = band_gap_by_group.max() - band_gap_by_group.min()
noisy_groups = band_gap_spread[band_gap_spread > MAX_BAND_GAP_SPREAD].index
# Keys are the same tuples the groupby yields, so they can be matched against
# another frame grouped by GROUP_COLS.
replace_dict = (
    band_gap_by_group.mean().loc[noisy_groups].to_dict() if APPLY_GROUP_MEAN else {}
)
replace_dict
    

{}

In [84]:
# Attach the predictions from results/exp_11 to the test rows for inspection.
submission = pd.read_csv('results/exp_11/submission.csv')
# Drop any prior 'predictions' column so re-running the cell does not yield
# predictions_x / predictions_y; validate rejects duplicate ids in the submission.
test_meta = (
    test_meta
    .drop(columns='predictions', errors='ignore')
    .merge(submission, on='id', how='left', validate='many_to_one')
)
test_meta

Unnamed: 0,id,band_gap,ratio16,ratio42,ratio34,ratio74,atom_count,link_count,predictions
0,6141cf9631cf3ef3d4a9edb4,,0.663158,0.331579,0.005263,0.000000,190.0,3510.0,0.282416
1,6141d2fd9cbada84a8676921,,0.656250,0.328125,0.010417,0.005208,192.0,3584.0,1.807065
2,6142341931cf3ef3d4a9f3c0,,0.663158,0.331579,0.005263,0.000000,190.0,3510.0,0.402278
3,6142199f4e27a1844a5f05fa,,0.659686,0.329843,0.005236,0.005236,191.0,3546.0,1.142626
4,6141d441ee0a3fd43fb47b65,,0.659686,0.329843,0.010471,0.000000,191.0,3548.0,0.365480
...,...,...,...,...,...,...,...,...,...
2962,6142615131cf3ef3d4a9f4b8,,0.663158,0.331579,0.005263,0.000000,190.0,3510.0,0.402219
2963,6141f4354e27a1844a5f046c,,0.659686,0.329843,0.005236,0.005236,191.0,3546.0,1.142377
2964,6141e31a4e27a1844a5f0346,,0.659686,0.329843,0.005236,0.005236,191.0,3546.0,1.142564
2965,61429f0531cf3ef3d4a9f5b4,,0.663158,0.331579,0.005263,0.000000,190.0,3510.0,0.402235


In [85]:
# for gname, gdf in test_meta.groupby(
#     ['ratio16', 'ratio42', 'ratio34', 'ratio74', 'atom_count', 'link_count']):
#     if gname in replace_dict.keys():
#         gdf['band_gap'] = replace_dict[gname]
#         display(gdf.loc[np.abs(gdf['band_gap'] - gdf['predictions']) >= 0.0001])