In [1]:
import sys
from pathlib import Path

sys.path.append(str(Path().resolve().parent))

In [2]:
from ocpmodels.common.utils import make_trainer_from_conf_str
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
import pandas as pd
import os
from tqdm.notebook import tqdm

from pymatgen.core.periodic_table import Element
from pymatgen.core.composition import Composition

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fa29c5d1820>

In [3]:
def to_reduced_formula(list_of_z):
    return Composition.from_dict(
        Counter([Element.from_Z(i).symbol for i in list_of_z])
    ).reduced_formula

In [4]:
trainer = make_trainer_from_conf_str(
    "faenet-is2re-all",
    overrides={
        "is_debug": True,
        "graph_rewiring": "",
        "optim": {
            "batch_size": 256,
        },
        "task": {
            "dataset": "stats_lmdb",
        }
    },
)

🏭 Overriding num_workers from 4 to 23 to match the machine's CPUs. Use --no_cpus_to_workers=true to disable this behavior.
Setting max_steps to  21578 from max_epochs (12), dataset length (460328), and batch_size (256)

🗑️ Setting dropout_lin for output block to 0.0
⛄️ No layer to freeze

Using max_steps for scheduler -> 21578


In [5]:
batch = next(iter(trainer.loaders["train"]))[0]

In [6]:
sample = batch.to_data_list()[0]
print(sample)
print(sample.stats["sid"][0])
sample["tags"]

Data(
  edge_index=[2, 2985],
  pos=[65, 3],
  cell=[1, 3, 3],
  atomic_numbers=[65],
  natoms=[1],
  cell_offsets=[2985, 3],
  force=[65, 3],
  distances=[2985],
  fixed=[65],
  sid=[1],
  tags=[65],
  y_init=[1],
  y_relaxed=[1],
  pos_relaxed=[65, 3],
  id='0_256684',
  load_time=[1],
  transform_time=[1],
  total_get_time=[1],
  idx_in_dataset=[1],
  stats={
    atomic_numbers_bulk=[64],
    atomic_numbers_ads=[1],
    composition_bulk='ZnSnN2',
    composition_ads='H2',
    idx_in_dataset=[1],
    sid=[1],
    y_relaxed=[1],
    y_init=[1]
  }
)
2077917


tensor([0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
        0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 2])

In [7]:
# %timeit to_reduced_formula(sample.atomic_numbers.int())


In [8]:
# from multiprocessing import Pool


# def make_entry(batch_list):
#     batch = batch_list[0]
#     entries = []
#     for sample in batch.to_data_list():
#         entries.append(
#             {
#                 "atomic_numbers": sample.atomic_numbers.int().tolist(),
#                 "composition": to_reduced_formula(sample.atomic_numbers.int()),
#                 "idx_in_dataset": sample.idx_in_dataset.item(),
#                 "sid": sample.sid.item(),
#                 "y_relaxed": sample.y_relaxed.item(),
#                 "y_init": sample.y_init.item(),
#             }
#         )
#     return entries


# num_workers = trainer.loaders["train"].num_workers * 2
# # iterate over batches by chunks of n_workers
# iterator = iter(trainer.loaders["train"])
# n_iters = len(trainer.loaders["train"]) // num_workers
# if len(trainer.loaders["train"]) % num_workers != 0:
#     n_iters += 1


# entries = []
# for _ in tqdm(range(n_iters)):
#     batch_list = []
#     for _ in tqdm(range(num_workers), leave=False):
#         try:
#             batch_list.append(next(iterator))
#         except StopIteration:
#             break
#     with Pool(num_workers) as p:
#         entries += sum(p.map(make_entry, batch_list), [])

In [9]:
entries = []
for batch_list in tqdm(trainer.loaders["train"]):
    batch = batch_list[0]
    entries.append(batch.stats)

  0%|          | 0/1799 [00:00<?, ?it/s]

In [10]:
flat_entries = {k: [] for k in entries[0]}
for entry in tqdm(entries):
    for k, v in entry.items():
        if len(v[0]) == 1:
            flat_entries[k] += [u[0] for u in v]
        else:
            flat_entries[k] += v
for k, v in flat_entries.items():
    print(k, len(v))

  0%|          | 0/1799 [00:00<?, ?it/s]

atomic_numbers_bulk 460328
atomic_numbers_ads 460328
composition_bulk 460328
composition_ads 460328
idx_in_dataset 460328
sid 460328
y_relaxed 460328
y_init 460328


In [11]:
df = pd.DataFrame(flat_entries) # df = pd.read_json("/network/scratch/s/schmidtv/crystals-proxys/data/is2re/comp.json")
desc = df.describe()
df.info()
df.to_json("/network/scratch/s/schmidtv/crystals-proxys/data/is2re/comp.json")
desc.to_json("/network/scratch/s/schmidtv/crystals-proxys/data/is2re/description.json")

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 460328 entries, 0 to 460327
Data columns (total 8 columns):
 #   Column               Non-Null Count   Dtype  
---  ------               --------------   -----  
 0   atomic_numbers_bulk  460328 non-null  object 
 1   atomic_numbers_ads   460328 non-null  object 
 2   composition_bulk     460328 non-null  object 
 3   composition_ads      460328 non-null  object 
 4   idx_in_dataset       460328 non-null  int64  
 5   sid                  460328 non-null  int64  
 6   y_relaxed            460328 non-null  float64
 7   y_init               460328 non-null  float64
dtypes: float64(2), int64(2), object(4)
memory usage: 28.1+ MB


In [3]:
df = pd.read_json("/network/scratch/s/schmidtv/crystals-proxys/data/is2re/comp.json")

In [4]:
atoms = []
for ats in tqdm(df.atomic_numbers):
    atoms += ats

  0%|          | 0/460328 [00:00<?, ?it/s]

In [9]:
atom_dist = Counter(atoms)
atom_dist = {k: atom_dist[k] for k in sorted(atom_dist.keys())}
atom_dist_named = {Element.from_Z(k).symbol: v for k, v in atom_dist.items()}

In [11]:
df["composition"].sample(10)

178539          HgAsHPd5CO
73657           Ti12H29C2O
351639        Bi7H6Pd14C2O
42635     Fe3Ni3H5(Pt12C)2
49087        Ag3Ge3H2CSe6O
26536         Si11H5Rh17C2
38661                CS18N
443730           Hf17Ni28H
220171            Bi9HIr4C
139300          Ta8Pd16NO2
Name: composition, dtype: object

In [16]:
import pickle as pkl

In [17]:
bulk_db = pkl.load(open("/network/scratch/s/schmidtv/ocp/datasets/ocp/dataset-creation/bulk_db_flat_2021sep20.pkl", "rb"))

In [18]:
len(bulk_db)

11410

In [16]:
# distribution of atomic numbers with element names as xticks
atomic_numbers_counter = Counter(sum(df.atomic_numbers.tolist(), []))

In [29]:
trainer.loaders["train"].batch_sampler.drop_last

False

In [30]:
len(trainer.loaders["train"])

1799

TypeError: 'DataLoader' object is not subscriptable

In [None]:
ce