In [1]:
import torch
from pathlib import Path
import sys
import ocpmodels
import numpy as np

In [2]:
from ocpmodels.common.utils import (
    build_config,
    create_grid,
    save_experiment_log,
    setup_imports,
    setup_logging,
)
from ocpmodels.common.flags import flags
from ocpmodels.common.registry import registry

In [3]:
def dict2str(d, level=0, spaces=4, margin=30):
    """
    Simple recursive pretty-printer for (nested) dicts.
    (victor)

    Each key is left-aligned in a column of width ``margin`` (minus the current
    indentation), followed by ``": "`` and the value. A nested dict is rendered
    on the following lines, indented by ``spaces`` per nesting level, and is
    followed by an empty line.

    Args:
        d (dict): mapping to render.
        level (int): current nesting depth, used for indentation.
        spaces (int): number of spaces per nesting level.
        margin (int): total width of the key column, indentation included.

    Returns:
        str: the rendered text, one ``key: value`` entry per line.
    """
    indent = " " * spaces * level
    # str.ljust never fails: an f-string width spec raises ValueError once
    # ``margin - spaces * level`` is negative (deep nesting) and right-aligns
    # non-string keys such as ints.
    key_width = margin - spaces * level
    lines = []
    for k, v in d.items():
        line = f"{indent}{str(k).ljust(key_width)}: "
        if isinstance(v, dict):
            line += "\n" + dict2str(v, level + 1, spaces, margin)
        else:
            line += str(v)
        lines.append(line + "\n")
    return "".join(lines)


# Config

In [4]:
# Emulate the command-line flags expected by ocp's `flags` parser.
# Only append a flag when it is not already present, so re-running this cell
# does not keep growing `sys.argv` with duplicate arguments.
for cli_flag in ("--mode=train", "--config=configs/is2re/10k/schnet/schnet.yml"):
    if cli_flag not in sys.argv:
        sys.argv.append(cli_flag)
setup_logging()

In [5]:
parser = flags.get_parser()
# Flags the parser does not define are returned separately and handed to
# build_config as config overrides.
args, override_args = parser.parse_known_args()
config = build_config(args, override_args)
config_text = dict2str(config)
print("Config:", "\n" + config_text)

Config: 
trainer                       : energy
dataset                       : [{'src': '/network/projects/_groups/ocp/oc20/is2re/10k/train/data.lmdb', 'normalize_labels': True, 'target_mean': -1.525913953781128, 'target_std': 2.279365062713623}, {'src': '/network/projects/_groups/ocp/oc20/is2re/all/val_id/data.lmdb'}]
logger                        : tensorboard
task                          : 
    dataset                   : single_point_lmdb
    description               : Relaxed state energy prediction from initial structure.
    type                      : regression
    metric                    : mae
    labels                    : ['relaxed energy']

model                         : 
    name                      : schnet
    hidden_channels           : 256
    num_filters               : 128
    num_interactions          : 3
    num_gaussians             : 100
    cutoff                    : 6.0
    use_pbc                   : True
    regress_forces            : False

optim 

### Override config

In [6]:
# Notebook-friendly overrides of the optimizer section: 4 dataloader workers
# and a single-graph batch.
# NOTE(review): several outputs further down (e.g. `natoms` with 64 entries)
# were produced with a larger batch size; re-run the notebook to refresh them.
config["optim"].update(num_workers=4, batch_size=1)

### Make trainer and task

In [7]:
setup_imports()
# Look up the trainer class registered under config["trainer"] (default
# "energy") and build it from the config sections; optional settings fall back
# to the defaults given here.
trainer_cls = registry.get_trainer_class(config.get("trainer", "energy"))
trainer = trainer_cls(
    task=config["task"],
    model=config["model"],
    dataset=config["dataset"],
    optimizer=config["optim"],
    identifier=config["identifier"],
    timestamp_id=config.get("timestamp_id", None),
    run_dir=config.get("run_dir", "./"),
    is_debug=config.get("is_debug", False),
    print_every=config.get("print_every", 10),
    seed=config.get("seed", 0),
    logger=config.get("logger", "tensorboard"),
    local_rank=config["local_rank"],
    amp=config.get("amp", False),
    cpu=config.get("cpu", False),
    slurm=config.get("slurm", {}),
)

amp: false
cmd:
  checkpoint_dir: ./checkpoints/2022-04-12-09-56-16
  commit: bd247bc
  identifier: ''
  logs_dir: ./logs/tensorboard/2022-04-12-09-56-16
  print_every: 10
  results_dir: ./results/2022-04-12-09-56-16
  seed: 0
  timestamp_id: 2022-04-12-09-56-16
dataset:
  normalize_labels: true
  src: /network/projects/_groups/ocp/oc20/is2re/10k/train/data.lmdb
  target_mean: -1.525913953781128
  target_std: 2.279365062713623
gpus: 1
logger: tensorboard
model: schnet
model_attributes:
  cutoff: 6.0
  hidden_channels: 256
  num_filters: 128
  num_gaussians: 100
  num_interactions: 3
  regress_forces: false
  use_pbc: true
optim:
  batch_size: 1
  eval_batch_size: 64
  lr_gamma: 0.1
  lr_initial: 0.005
  lr_milestones:
  - 1562
  - 2343
  - 3125
  max_epochs: 30
  num_workers: 4
  warmup_factor: 0.2
  warmup_steps: 468
slurm: {}
task:
  dataset: single_point_lmdb
  description: Relaxed state energy prediction from initial structure.
  labels:
  - relaxed energy
  metric: mae
  type: reg



2022-04-12 09:56:11 (INFO): Loaded SchNetWrap with 541697 parameters.




In [8]:
# Instantiate the task class registered for the configured mode ("train" here).
task_cls = registry.get_task_class(config["mode"])
task = task_cls(config)

In [9]:
# Hand the trainer to the task. NOTE(review): see the ocpmodels task class for
# what `setup` does exactly.
task.setup(trainer)

# Explore Trainer

In [10]:
# name of the concrete trainer class resolved from the registry
print(type(trainer).__name__)

EnergyTrainer


## Data

Data explanations can also be found in `scripts/README_is2res.md`

In [11]:
# Grab a single batch from the training loader for inspection.
# Unlike a for/break loop, this fails loudly on an empty loader instead of
# leaving `batch` undefined (or stale from a previous run).
batch = next(iter(trainer.train_loader), None)
if batch is None:
    raise RuntimeError("trainer.train_loader did not yield any batch")

In [12]:
# The loader yields a list of batch objects rather than a single batch; in this
# run it holds one element (see output). NOTE(review): presumably one entry per
# device — confirm against the collater used by the trainer.
print(len(batch))
b = batch[0]

1


### Batch contents

In [21]:
# what's a batch item? A torch_geometric `DataBatch`; its repr lists every
# attribute together with its shape.
b

DataBatch(edge_index=[2, 1372], pos=[40, 3], cell=[1, 3, 3], atomic_numbers=[40], natoms=[1], cell_offsets=[1372, 3], force=[40, 3], distances=[1372], fixed=[40], sid=[1], tags=[40], y_init=[1], y_relaxed=[1], pos_relaxed=[40, 3], batch=[40], ptr=[2], neighbors=[1])

In [43]:
print(
    "`edge_index` contains the edges of all graphs in the batch:",
    b.edge_index.shape,
    b.edge_index,
    sep="\n",
)

`edge_index` contains the edges of all graphs in the batch:
torch.Size([2, 170988])
tensor([[  19,   18,   14,  ..., 4570, 4514, 4530],
        [   0,    0,    0,  ..., 4584, 4584, 4584]], device='cuda:0')


In [44]:
print(
    "`batch` contains the graph id of each atom in the batch:",
    b.batch.shape,
    b.batch,
    sep="\n",
)

`batch` contains the graph id of each atom in the batch:
torch.Size([4585])
tensor([ 0,  0,  0,  ..., 63, 63, 63], device='cuda:0')


In [68]:
print("`pos` and `pos_relaxed` contain the 3D position of each atom in the batch, respectively initially or in the relaxed state:")
print(b.pos.shape, b.pos_relaxed.shape)
# first three atoms: initial positions, then relaxed positions
for positions in (b.pos, b.pos_relaxed):
    print(positions[:3])

`pos` and `pos_relaxed` contain the 3D position of each atom in the batch, respectively initially or in the relaxed state:
torch.Size([4585, 3]) torch.Size([4585, 3])
tensor([[ 6.3715,  1.4460, 16.5523],
        [ 9.9958,  1.4460, 18.7673],
        [ 2.7473,  1.4460, 14.3374]], device='cuda:0')
tensor([[ 6.3715,  1.4460, 16.5523],
        [ 9.9823,  1.3752, 18.7785],
        [ 2.7473,  1.4460, 14.3374]], device='cuda:0')


In [48]:
print(
    "`cell` contains the 3D cell dimensions of each graph:",
    b.cell.shape,
    b.cell[:2],
    sep="\n",
)

`cell` contains the 3D cell dimensions of each graph:
torch.Size([64, 3, 3])
tensor([[[10.8727, -0.0000, -1.5002],
         [-3.6242,  7.3342, -2.2150],
         [ 0.0000,  0.0000, 32.5804]],

        [[10.8287,  0.0000,  0.0000],
         [ 0.0000, 10.9216,  2.1019],
         [ 0.0000,  0.0000, 29.4261]]], device='cuda:0')


In [50]:
print(
    "`atomic_numbers` contains the atomic number of each atom in the batch:",
    b.atomic_numbers.shape,
    b.atomic_numbers,
    sep="\n",
)

`atomic_numbers` contains the atomic number of each atom in the batch:
torch.Size([4585])
tensor([39., 39., 39.,  ...,  1.,  1.,  8.], device='cuda:0')


In [53]:
print("`natoms` contains the number of atoms in each graph in the batch:")
print(b.natoms.shape)
print(b.natoms)
n_atoms_total = b.natoms.sum()
print(n_atoms_total)
# Sanity check: per-graph atom counts must add up to the number of atoms in
# the batch (one `batch` entry per atom).
assert n_atoms_total.item() == b.batch.numel()

`natoms` contains the number of atoms in each graph in the batch:
torch.Size([64])
tensor([ 40,  55,  43,  56,  78, 111,  70,  63,  67,  99, 100,  52,  60,  90,
         40,  69,  77,  99,  52,  85, 100,  32,  50,  78,  67,  66,  68,  79,
        114,  53,  54,  79,  97, 183,  92,  67,  75,  50,  10, 100,  64,  51,
         87,  54,  70,  87,  65, 103,  70,  70,  65,  69,  52,  98,  59,  49,
         75,  69,  40,  67,  71,  32,  85, 113], device='cuda:0')
tensor(4585, device='cuda:0')


In [160]:
print("`cell_offsets` contains the 3D 'cell offset' of each edge in the batch:\n")
print("[n_edges x 3] offset matrix where each index corresponds to the unit cell offset necessary to find the corresponding neighbor in  `edge_index`. For example,  `cell_offsets[0, :] = [0,1,0]` corresponds to `edge_index[:, 0]= [1,0]` representing node 1 as node 0’s neighbor located one unit cell over in the +y direction.\n")
print(b.cell_offsets.shape)
print(b.cell_offsets)
print(torch.unique(b.cell_offsets))

# three blank lines as a visual separator
print("\n\n")

# Inspect one periodic edge. A single `edge_id` is used for every lookup so
# the offset, the endpoints and the distance all describe the same edge.
edge_id = 38
print(b.cell_offsets[edge_id, :])
print(b.edge_index[:, edge_id])
# cell of the graph that owns this edge (graph of its first endpoint)
edge_cell = b.cell[b.batch[b.edge_index[0, edge_id]]]
print(edge_cell)
print(b.distances[edge_id])
# length of the second lattice vector, for comparison with the edge length
# (edge 38's offset is along this lattice direction)
print(torch.norm(edge_cell[1]))

`cell_offsets` contains the 3D 'cell offset' of each edge in the batch:

[n_edges x 3] offset matrix where each index corresponds to the unit cell offset necessary to find the corresponding neighbor in  `edge_index`. For example,  `cell_offsets[0, :] = [0,1,0]` corresponds to `edge_index[:, 0]= [1,0]` representing node 1 as node 0’s neighbor located one unit cell over in the +y direction.

torch.Size([170988, 3])
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]], device='cuda:0')
tensor([-1,  0,  1], device='cuda:0')



tensor([ 0, -1,  0], device='cuda:0')
tensor([10,  0], device='cuda:0')
tensor([[10.8727, -0.0000, -1.5002],
        [-3.6242,  7.3342, -2.2150],
        [ 0.0000,  0.0000, 32.5804]], device='cuda:0')
tensor(5.5423, device='cuda:0')
tensor(8.4754, device='cuda:0')


In [161]:
print(
    "`force` contains the 3D forces 'experienced' by each atom in the batch:",
    b.force.shape,
    b.force,
    sep="\n",
)

`force` contains the 3D forces 'experienced' by each atom in the batch:
torch.Size([4585, 3])
tensor([ 0.2466, -0.0354,  0.4422], device='cuda:0')


In [58]:
# NOTE(review): units are presumably Å, the same units as the model cutoff
# (6.0) — confirm against the dataset documentation.
print(
    "`distances` contains the length (A?) of each edge in the batch",
    b.distances.shape,
    b.distances,
    sep="\n",
)

`distances` contains the length (A?) of each edge in the batch
torch.Size([170988])
tensor([2.9223, 2.9223, 2.9319,  ..., 5.9122, 5.9325, 5.9584], device='cuda:0')


In [59]:
print(
    "`fixed` contains the boolean flag for each atom being fixed or not",
    b.fixed.shape,
    b.fixed,
    sep="\n",
)

`fixed` contains the boolean flag for each atom being fixed or not
torch.Size([4585])
tensor([1., 0., 1.,  ..., 0., 0., 0.], device='cuda:0')


In [62]:
print(
    "`sid` contains the system id associated with each graph in the batch",
    b.sid.shape,
    b.sid[:5],
    sep="\n",
)

`sid` contains the system id associated with each graph in the batch
torch.Size([64])
tensor([1831766, 1974982,  589818, 2158881,  550912], device='cuda:0')


In [63]:
print(
    "`tags` contains the tag of each atom: 0 - Fixed, sub-surface atoms, 1 - Free, surface atoms 2 - Free, adsorbate atoms",
    b.tags.shape,
    b.tags,
    sep="\n",
)

`tags` contains the tag of each atom: 0 - Fixed, sub-surface atoms, 1 - Free, surface atoms 2 - Free, adsorbate atoms
torch.Size([4585])
tensor([0, 1, 0,  ..., 2, 2, 2], device='cuda:0')


In [65]:
energies = (b.y_init, b.y_relaxed)
print("`y_init` and `y_relaxed` respectively contain the initial and relaxed energies of each graph:")
print(*(e.shape for e in energies))
# first five graphs: initial energies, then relaxed energies
print(*(e[:5] for e in energies))

`y_init` and `y_relaxed` respectively contain the initial and relaxed energies of each graph:
torch.Size([64]) torch.Size([64])
tensor([0.5247, 2.9358, 0.5761, 0.1011, 3.9838], device='cuda:0') tensor([-2.6237, -0.4135, -2.4655, -3.4437, -0.8626], device='cuda:0')


In [69]:
print("`ptr` contains the cumulative atom offsets of each graph: the atoms of graph i are the rows ptr[i]:ptr[i+1] of per-atom tensors")
print(b.ptr.shape)
print(b.ptr)
# ptr has one more entry than there are graphs, starts at 0, and its
# consecutive differences are the per-graph atom counts.
assert b.ptr[0] == 0
assert (b.ptr[1:] - b.ptr[:-1] == b.natoms).all()

`ptr` ??
torch.Size([65])
tensor([   0,   40,   95,  138,  194,  272,  383,  453,  516,  583,  682,  782,
         834,  894,  984, 1024, 1093, 1170, 1269, 1321, 1406, 1506, 1538, 1588,
        1666, 1733, 1799, 1867, 1946, 2060, 2113, 2167, 2246, 2343, 2526, 2618,
        2685, 2760, 2810, 2820, 2920, 2984, 3035, 3122, 3176, 3246, 3333, 3398,
        3501, 3571, 3641, 3706, 3775, 3827, 3925, 3984, 4033, 4108, 4177, 4217,
        4284, 4355, 4387, 4472, 4585], device='cuda:0')


In [162]:
print("`neighbors` contains the number of edges of each graph in the batch:")
print(b.neighbors.shape)
print(b.neighbors)
print(b.neighbors.sum())
# Per-graph edge counts must add up to the total number of edges in the batch.
assert b.neighbors.sum().item() == b.edge_index.shape[1]

`neighbors` ??
torch.Size([64])
tensor([1372, 1634, 1813, 2502, 2300, 5449, 3109, 1976, 1640, 4159, 2952,  880,
        2677, 2684, 1086, 2954, 3554, 4461, 2455, 3349, 2812,  894, 2024, 2734,
        1760, 2691, 2576, 1466, 2760, 2254, 1604, 2504, 4218, 7825, 4034, 2290,
        3070, 2100,   90, 4486, 2920, 1890, 2012, 2180, 2886, 3236, 2741, 2954,
        2056, 3220, 1352, 3212, 2217, 4535, 2602, 1816, 3114, 3243, 1292, 2912,
        2396,  378, 3330, 5296], device='cuda:0')
tensor(170988, device='cuda:0')


### Find atoms in graphs

In [164]:
# select a graph you care about. graph_id is equivalent to a batch index
graph_id = 0
# find the atom indices of that graph.
# nonzero(..., as_tuple=True)[0] is always 1-D, whereas argwhere(...).squeeze()
# collapses to a 0-d tensor for a single-atom graph and breaks len() below.
atoms = torch.nonzero(b.batch == graph_id, as_tuple=True)[0]
# the number of atoms selected above should match natoms
assert len(atoms) == b.natoms[graph_id]
print(atoms)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39], device='cuda:0')


In [92]:
# find the edges of that graph, ie edges whose source atom is in the graph's
# atoms list. `edge_indices` is a boolean mask over all edges of the batch.
edge_indices = torch.isin(b.edge_index[0, :], atoms)
# select the actual edges from the mask
edges = b.edge_index[:, edge_indices]
# Check the assumption that an edge never links atoms of two different graphs:
# every target atom of the selected edges must also belong to the graph.
assert torch.isin(edges[1, :], atoms).all()
print(edge_indices.shape)
print(edge_indices)
print(edges.shape)
print(edges)


torch.Size([170988])
tensor([ True,  True,  True,  ..., False, False, False], device='cuda:0')
torch.Size([2, 1372])
tensor([[19, 18, 14,  ..., 25, 19, 14],
        [ 0,  0,  0,  ..., 39, 39, 39]], device='cuda:0')


In [87]:
# per-atom force vectors of the selected graph, one row per atom
forces = b.force[atoms]
print(forces.size())

torch.Size([40, 3])


In [93]:
# edge lengths of the selected graph, using the same boolean edge mask
distances = b.distances[edge_indices]
# exactly one length per selected edge
assert distances.shape[0] == edges.shape[1]
print(distances.size())

torch.Size([1372])


In [104]:
# find the tag of each graph atom
tags = b.tags[atoms]
print(tags)
# Map tags (0, 1, 2) to readable names with a plain Python list. Indexing a
# NumPy string array yields np.str_ items, which NumPy >= 2 prints as
# np.str_('fixed') instead of 'fixed'.
TAG_NAMES = ("fixed", "surface", "adsorbate")
print([TAG_NAMES[t] for t in tags.tolist()])


tensor([0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0,
        0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 2, 2, 2, 2], device='cuda:0')
['fixed', 'surface', 'fixed', 'fixed', 'fixed', 'surface', 'surface', 'fixed', 'fixed', 'fixed', 'fixed', 'surface', 'surface', 'fixed', 'surface', 'fixed', 'surface', 'fixed', 'fixed', 'fixed', 'surface', 'surface', 'fixed', 'fixed', 'fixed', 'surface', 'fixed', 'fixed', 'fixed', 'surface', 'surface', 'fixed', 'fixed', 'fixed', 'fixed', 'surface', 'adsorbate', 'adsorbate', 'adsorbate', 'adsorbate']


In [108]:
# atomic numbers of the adsorbate atoms (tag 2) of the selected graph
graph_numbers = b.atomic_numbers[atoms]
adsorbates_numbers = graph_numbers[tags == 2]
print(adsorbates_numbers)

tensor([6., 6., 1., 1.], device='cuda:0')


In [124]:
from ase import Atoms

# ASE `Atoms` object built from the adsorbate's atomic numbers (float -> int)
x = Atoms(numbers=adsorbates_numbers.int().tolist())

In [165]:
from collections import Counter

# Chemical formula of the adsorbate (atoms with tag 2) of every graph in the batch.
symbols = []
for gid in range(len(b.sid)):
    # boolean mask over the batch's atoms selecting graph `gid`
    in_graph = b.batch == gid
    graph_tags = b.tags[in_graph]
    adsorbate_z = b.atomic_numbers[in_graph][graph_tags == 2]
    formula = Atoms(numbers=[int(z) for z in adsorbate_z]).symbols
    symbols.append(str(formula))

sym_counts = Counter(symbols)
print("H2O:", sym_counts.get("H2O", 0))
print("H2:", sym_counts.get("H2", 0))
print()
print(dict2str(sym_counts, margin=10))

H2O: 1
H2: 0

C2H2      : 3
C2H3O2    : 4
C2H       : 2
C2H5O     : 4
C2H2O2    : 4
N2O       : 2
C2H5      : 3
C2H5O2    : 1
NO2       : 1
H2O       : 1
CH3       : 3
C2H4O2    : 2
CH4       : 4
C2HO2     : 2
CH2       : 2
H2N2      : 1
NO        : 3
N2HO      : 2
C2H4O     : 3
CN        : 1
C2O       : 3
N2H       : 1
NH        : 1
N2H2      : 1
C2H3O     : 3
O4N2      : 1
N         : 1
C2        : 1
C2H2O     : 1
O         : 1
NH3O      : 2



In [17]:
# Forward the batch through the model, then compute the training loss and the
# evaluator's metrics (starting from an empty metrics dict).
out = trainer._forward(batch)
loss = trainer._compute_loss(out, batch)
metrics = trainer.compute_metrics(out, batch, trainer.evaluator, metrics={})


In [30]:
# explore contents of predictions, loss, metrics
print(out.keys())
energy_pred = out["energy"]
print(energy_pred.shape, energy_pred.dtype)
print(loss, dict2str(metrics), sep="\n")

dict_keys(['energy'])
torch.Size([64]) torch.float32
tensor(48.8437, device='cuda:0', grad_fn=<L1LossBackward0>)
energy_mae                    : 
    metric                    : 111.33262634277344
    total                     : 7125.2880859375
    numel                     : 64

energy_mse                    : 
    metric                    : 21693.33984375
    total                     : 1388373.75
    numel                     : 64

energy_within_threshold       : 
    metric                    : 0.0
    total                     : 0
    numel                     : 64




## LMDBs

In [13]:
def ls(p):
    """
    List the entries of a directory, sorted by path.

    Args:
        p (str | Path): directory to list.

    Returns:
        list[Path]: the directory's entries, sorted so the output is
        deterministic (``Path.iterdir`` order depends on the filesystem).

    Raises:
        NotADirectoryError: if ``p`` is not an existing directory.
    """
    path = Path(p)
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not path.is_dir():
        raise NotADirectoryError(f"Not a directory: {path}")
    return sorted(path.iterdir())

In [14]:
# Dataset object behind the training loader (LMDB-backed, judging by the
# `path` / `env` attributes explored below).
d = trainer.train_loader.dataset

In [15]:
# path of the LMDB file backing the dataset
d.path

PosixPath('/network/projects/_groups/ocp/oc20/is2re/10k/train/data.lmdb')

In [16]:
# the LMDB directory holds the data file plus its `-lock` file
ls(d.path.parent)

[PosixPath('/network/projects/_groups/ocp/oc20/is2re/10k/train/data.lmdb-lock'),
 PosixPath('/network/projects/_groups/ocp/oc20/is2re/10k/train/data.lmdb')]

In [17]:
# the dataset's open `lmdb` Environment handle
d.env

<Environment at 0x7fcb07444210>

In [18]:
import lmdb



In [19]:
# number of records stored in the LMDB (10k samples for this split)
d.env.stat()["entries"]

10000

In [20]:
# LMDB keys: the sample index formatted as a string and ascii-encoded
# (this key is used below to fetch sample 0)
f"{0}".encode("ascii")

b'0'

In [21]:
import pickle

# Read and unpickle the first record, keyed by its ascii-encoded index. The
# read transaction is scoped with `with` so it is closed explicitly once the
# value has been fetched.
# NOTE: pickle.loads can execute arbitrary code; only use it on trusted LMDBs.
with d.env.begin() as txn:
    sample = pickle.loads(txn.get(f"{0}".encode("ascii")))

In [26]:
from torch_geometric.data import Data

# Re-wrap the unpickled object in a current torch_geometric `Data`, keeping
# only the attributes that are actually set (not None).
set_fields = {}
for key, value in sample.__dict__.items():
    if value is not None:
        set_fields[key] = value
sample = Data(**set_fields)

In [27]:
# the first sample as a single-graph `Data` object (compare with the batch above)
sample

Data(edge_index=[2, 2964], pos=[86, 3], cell=[1, 3, 3], atomic_numbers=[86], natoms=86, cell_offsets=[2964, 3], force=[86, 3], distances=[2964], fixed=[86], sid=2472718, tags=[86], y_init=6.282500615000004, y_relaxed=-0.025550085000020317, pos_relaxed=[86, 3])