In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.4.0
[0m

In [2]:
import numpy as np
import torch
from torch_geometric.datasets import MD17

In [3]:
# Molecule to load; see the `Name` column of the MD17 summary table for the
# accepted values.
name = 'benzene'
# Constructing the dataset downloads the raw archive and processes it on first
# use (see the "Downloading ... Processing... Done!" output). An additional
# explicit `dataset.process()` call is therefore not needed: it would only redo
# the processing of every molecule and rewrite the same processed file.
dataset = MD17('/notebooks/datasets/MD17', name)

Downloading http://quantum-machine.org/gdml/data/npz/md17_benzene2017.npz
Processing...
Done!


The dataset contains: 

- The Cartesian `positions` of the atoms in Angstrom (attribute `pos`),
- Their `atomic numbers` (attribute `z`),
- The total `energy` in kcal/mol (attribute `energy`),
- The `forces` in kcal/mol/Angstrom on each atom (attribute `force`).
    
The latter two are the regression targets for this collection. Specifically, each molecule is represented by a chunk of 12 consecutive rows, where each row corresponds to a single atom of a given benzene molecule.

Also, all this data is stored in `MD17/NAME/processed/data.pt`.

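The following table summarizes all the datasets handled by `MD17`. The `Name` column gives the string to pass as `name`, and in the `#Examples` column the dot is a thousands separator (e.g. `627.983` means 627,983 molecules).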

```text

+--------------------+--------------------+-------------------------------+-----------+
| Molecule           | Level of Theory    | Name                          | #Examples |
+====================+====================+===============================+===========+
| Benzene            | DFT                | :obj:`benzene`                | 627.983   |
+--------------------+--------------------+-------------------------------+-----------+
| Uracil             | DFT                | :obj:`uracil`                 | 133.770   |
+--------------------+--------------------+-------------------------------+-----------+
| Naphthalene        | DFT                | :obj:`napthalene`             | 326.250   |
+--------------------+--------------------+-------------------------------+-----------+
| Aspirin            | DFT                | :obj:`aspirin`                | 211.762   |
+--------------------+--------------------+-------------------------------+-----------+
| Salicylic acid     | DFT                | :obj:`salicylic acid`         | 320.231   |
+--------------------+--------------------+-------------------------------+-----------+
| Malonaldehyde      | DFT                | :obj:`malonaldehyde`          | 993.237   |
+--------------------+--------------------+-------------------------------+-----------+
| Ethanol            | DFT                | :obj:`ethanol`                | 555.092   |
+--------------------+--------------------+-------------------------------+-----------+
| Toluene            | DFT                | :obj:`toluene`                | 442.790   |
+--------------------+--------------------+-------------------------------+-----------+
| Paracetamol        | DFT                | :obj:`paracetamol`            | 106.490   |
+--------------------+--------------------+-------------------------------+-----------+
| Azobenzene         | DFT                | :obj:`azobenzene`             | 99.999    |
+--------------------+--------------------+-------------------------------+-----------+
| Benzene (R)        | DFT (PBE/def2-SVP) | :obj:`revised benzene`        | 100.000   |
+--------------------+--------------------+-------------------------------+-----------+
| Uracil (R)         | DFT (PBE/def2-SVP) | :obj:`revised uracil`         | 100.000   |
+--------------------+--------------------+-------------------------------+-----------+
| Naphthalene (R)    | DFT (PBE/def2-SVP) | :obj:`revised napthalene`     | 100.000   |
+--------------------+--------------------+-------------------------------+-----------+
| Aspirin (R)        | DFT (PBE/def2-SVP) | :obj:`revised aspirin`        | 100.000   |
+--------------------+--------------------+-------------------------------+-----------+
| Salicylic acid (R) | DFT (PBE/def2-SVP) | :obj:`revised salicylic acid` | 100.000   |
+--------------------+--------------------+-------------------------------+-----------+
| Malonaldehyde (R)  | DFT (PBE/def2-SVP) | :obj:`revised malonaldehyde`  | 100.000   |
+--------------------+--------------------+-------------------------------+-----------+
| Ethanol (R)        | DFT (PBE/def2-SVP) | :obj:`revised ethanol`        | 100.000   |
+--------------------+--------------------+-------------------------------+-----------+
| Toluene (R)        | DFT (PBE/def2-SVP) | :obj:`revised toluene`        | 100.000   |
+--------------------+--------------------+-------------------------------+-----------+
| Paracetamol (R)    | DFT (PBE/def2-SVP) | :obj:`revised paracetamol`    | 100.000   |
+--------------------+--------------------+-------------------------------+-----------+
| Azobenzene (R)     | DFT (PBE/def2-SVP) | :obj:`revised azobenzene`     | 99.988    |
+--------------------+--------------------+-------------------------------+-----------+
| Benzene            | CCSD(T)            | :obj:`benzene CCSD(T)`        | 1.500     |
+--------------------+--------------------+-------------------------------+-----------+
| Aspirin            | CCSD               | :obj:`aspirin CCSD`           | 1.500     |
+--------------------+--------------------+-------------------------------+-----------+
| Malonaldehyde      | CCSD(T)            | :obj:`malonaldehyde CCSD(T)`  | 1.500     |
+--------------------+--------------------+-------------------------------+-----------+
| Ethanol            | CCSD(T)            | :obj:`ethanol CCSD(T)`        | 2.000     |
+--------------------+--------------------+-------------------------------+-----------+
| Toluene            | CCSD(T)            | :obj:`toluene CCSD(T)`        | 1.501     |
+--------------------+--------------------+-------------------------------+-----------+
| Benzene            | DFT FHI-aims       | :obj:`benzene FHI-aims`       | 49.863    |
+--------------------+--------------------+-------------------------------+-----------+

```

In [5]:
# Load the collated data straight from the processed file. The path is taken
# from `dataset` instead of being hard-coded, so it always matches the root
# directory and `name` chosen when the dataset was created.
# NOTE: `torch.load` unpickles the file, so only use it on files you produced
# yourself (as is the case here).
data = torch.load(dataset.processed_paths[0])[0]
print(data)
# Show each stored field in turn: positions, atomic numbers, energies, forces.
for field in ('pos', 'z', 'energy', 'force'):
    print(data[field])

Data(pos=[7535796, 3], z=[7535796], energy=[627983], force=[7535796, 3])
tensor([[   0.0000,    1.3970,    0.0000],
        [   1.2098,    0.6985,    0.0000],
        [   1.2098,   -0.6985,    0.0000],
        ...,
        [-177.5760,  196.8930, -167.1030],
        [-176.2220,  199.0660, -167.0190],
        [-177.0470,  200.9420, -165.6810]])
tensor([6, 6, 6,  ..., 1, 1, 1])
tensor([-146536.1094, -146536.1250, -146536.0938,  ..., -146528.2031,
        -146528.1875, -146528.1094])
tensor([[ 4.7385e-11, -3.8260e+00, -6.7172e-13],
        [-3.2493e+00, -1.9419e+00, -3.4315e-13],
        [-3.2493e+00,  1.9419e+00,  6.6742e-13],
        ...,
        [-1.1325e+00,  9.4016e+00,  1.8223e+00],
        [ 2.0346e+01, -4.2068e+00, -4.6041e+00],
        [-3.6244e+00, -4.3245e+00, -1.7331e+01]])
