# Uncomment for Google Colab

# Basic Example

## Initialize the database

In [1]:
from colabfit.tools.database import MongoDatabase

# Open the MongoDB database named 'test' with nprocs=1.
# drop_database=True presumably wipes any existing 'test' contents so the
# example starts from an empty database — confirm before pointing this at
# a database you care about.
client = MongoDatabase(
    'test',
    nprocs=1,
    drop_database=True,
)

## Attaching a property definition

In [2]:
# Register the 'energy-forces' property definition with two required,
# unit-bearing float fields.
client.insert_property_definition(
    {
        'property-id': 'energy-forces',
        'property-title': 'A default property for storing energies and forces',
        'property-description': 'Energies and forces computed using DFT',
        # Empty extent: presumably a scalar per configuration.
        'energy': {
            'type': 'float',
            'has-unit': True,
            'extent': [],
            'required': True,
            'description': 'Cohesive energy',
        },
        # Extent [':', 3]: presumably one 3-vector per atom.
        'forces': {
            'type': 'float',
            'has-unit': True,
            'extent': [':', 3],
            'required': True,
            'description': 'Atomic forces',
        },
    }
)



In [3]:
# Read the stored definition back. As the output shows, the short
# 'property-id' is stored expanded into a full tag
# ('tag:@,0000-00-00:property/energy-forces').
client.get_property_definition('energy-forces')['definition']

{'property-id': 'tag:@,0000-00-00:property/energy-forces',
 'property-title': 'A default property for storing energies and forces',
 'property-description': 'Energies and forces computed using DFT',
 'energy': {'type': 'float',
  'has-unit': True,
  'extent': [],
  'required': True,
  'description': 'Cohesive energy'},
 'forces': {'type': 'float',
  'has-unit': True,
  'extent': [':', 3],
  'required': True,
  'description': 'Atomic forces'}}

## Adding data

### Generating configurations

#### Manually

In [4]:
import numpy as np
from colabfit.tools.configuration import Configuration

# Build 999 synthetic hydrogen configurations. Configuration number i has
# i atoms at positions drawn uniformly from [0, 1), a reference energy of
# i*i, and normally distributed forces. The energy goes in atoms.info and
# the forces in atoms.arrays.
images = []
for i in range(1, 1000):
    atoms = Configuration(
        symbols='H' * i,
        positions=np.random.random((i, 3)),
    )
    atoms.info['_name'] = f'configuration_{i}'
    atoms.info['dft_energy'] = i * i
    atoms.arrays['dft_forces'] = np.random.normal(size=(i, 3))
    images.append(atoms)

#### Using `load_data()`

In [None]:
import os
import tempfile

from ase.io import write

# Destination of the example file. Local runs use the platform's temporary
# directory ('/tmp' on most Unix systems), so the example also works on
# Windows. For Google Colab, uncomment the second assignment. It comes
# after the local one, so it overrides it.
outfile = os.path.join(tempfile.gettempdir(), 'example.extxyz')
# outfile = '/content/example.extxyz'   # Google Colab

# Serialize every configuration built above into a single file. This step
# is needed for both local and Colab runs: load_data() reads it back.
write(outfile, images)

In [None]:
from colabfit.tools.database import load_data

# Read the configurations back from the file written in the previous cell.
# name_field='_name' presumably picks the info key that holds each
# configuration's name (set as 'configuration_<i>' during generation).
images = list(load_data(
    file_path=outfile,  # set in the previous cell for both local and Colab runs
    file_format='xyz',
    name_field='_name',
    elements=['H'],
    default_name=None,
    verbose=True
))

### Defining a `property_map`

In [None]:
# Map each property definition (keyed by its property-id) to the
# Configuration fields that supply its values and the units of those values:
#   {property field: {'field': key in atoms.info / atoms.arrays, 'units': ...}}
property_map = {
    'energy-forces': {
        'energy': {'field': 'dft_energy', 'units': 'eV'},     # stored in atoms.info
        'forces': {'field': 'dft_forces', 'units': 'eV/Ang'},  # stored in atoms.arrays
    },
}

### `insert_data()`

In [None]:
from colabfit.tools.property_settings import PropertySettings

# Insert the configurations and their energy/forces properties. The
# PropertySettings entry records how the 'energy-forces' values were
# computed. The result is consumed later as (configuration id, property id)
# pairs.
ids = list(
    client.insert_data(
        images,
        property_map=property_map,
        property_settings={
            'energy-forces': PropertySettings(
                method='VASP',
                description='A basic VASP calculation',
                files=None,
                labels=['PBE', 'GGA'],
            ),
        },
        generator=False,
        verbose=True,
    )
)

In [None]:
# Transpose the (configuration id, property id) pairs into two parallel
# tuples, then count the distinct ids of each kind.
all_co_ids, all_pr_ids = zip(*ids)

len(set(all_co_ids)), len(set(all_pr_ids))

In [None]:
# Load every configuration in the database ('all'), with verbose output.
configurations = client.get_configurations('all', verbose=True)

## Defining a `ConfigurationSet`

In [None]:
# Ids of the configurations inserted above that have fewer than 100 sites.
co_ids = client.get_data(
    'configurations',
    fields='_id',
    query={
        '_id': {'$in': all_co_ids},
        'nsites': {'$lt': 100},
    },
    ravel=True,
).tolist()

In [None]:
# Group the small configurations into a ConfigurationSet and display its id.
cs_id = client.insert_configuration_set(
    co_ids,
    description='Configurations with fewer than 100 atoms',
)
cs_id

In [None]:
# Retrieve the ConfigurationSet object stored under cs_id.
cs = client.get_configuration_set(cs_id)['configuration_set']

In [None]:
# Show the description supplied when the set was inserted.
cs.description

In [None]:
# Print each aggregated statistic of the configuration set on its own line.
for key, value in cs.aggregated_info.items():
    print(key, value)

## Creating a `Dataset` from scratch

In [None]:
# Partition the inserted configurations by size:
# fewer than 100 sites (co_ids1) vs. 100 or more sites (co_ids2).
co_ids1 = client.get_data(
    'configurations',
    fields='_id',
    query={'_id': {'$in': all_co_ids}, 'nsites': {'$lt': 100}},
    ravel=True,
).tolist()
co_ids2 = client.get_data(
    'configurations',
    fields='_id',
    query={'_id': {'$in': all_co_ids}, 'nsites': {'$gte': 100}},
    ravel=True,
).tolist()

print(len(co_ids1))
print(len(co_ids2))

In [None]:
# Wrap each size group in its own ConfigurationSet.
cs_id1 = client.insert_configuration_set(co_ids1, description='Small configurations')
cs_id2 = client.insert_configuration_set(co_ids2, description='Big configurations')

In [None]:
# Display the id of the small-configuration set.
cs_id1

In [None]:
# Retrieve the small-configuration set (rebinds `cs` from earlier cells).
cs = client.get_configuration_set(cs_id1)['configuration_set']

In [None]:
# Show the description passed at insertion ('Small configurations').
cs.description

In [None]:
# Collect the ids of every property whose 'relationships.configurations'
# array contains at least one configuration from either size group, then
# show how many were found.
pr_ids = client.get_data(
    'properties',
    fields='_id',
    query={
        'relationships.configurations': {
            '$elemMatch': {'$in': co_ids1 + co_ids2},
        },
    },
    ravel=True,
).tolist()
len(pr_ids)

In [None]:
# Bundle both size-based configuration sets and their properties into a
# dataset. Metadata arguments come first, then the content ids.
ds_id = client.insert_dataset(
    name='basic_example',
    description="This is an example dataset",
    authors=['J. E. Lennard-Jones'],
    links=['https://en.wikipedia.org/wiki/John_Lennard-Jones'],
    cs_ids=[cs_id1, cs_id2],
    pr_ids=pr_ids,
    resync=True,
)
ds_id

In [None]:
ds = client.get_dataset(ds_id)['dataset']

# Print each aggregated statistic of the new dataset on its own line.
for key, value in ds.aggregated_info.items():
    print(key, value)

## Applying labels to configurations

In [None]:
# Label configurations with fewer than 100 sites as 'small'
# (presumably restricted to configurations in dataset ds_id — confirm).
client.apply_labels(
    dataset_id=ds_id,
    collection_name='configurations',
    query={'nsites': {'$lt': 100}},
    labels={'small'},
    verbose=True,
)

In [None]:
# Re-fetch the small-configuration set WITHOUT resync and inspect its
# aggregated labels; the 'small' label just applied may not appear yet.
cs = client.get_configuration_set(cs_id)['configuration_set']
cs.aggregated_info['labels']

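Note: the configuration set's aggregated info is not refreshed automatically when labels are applied, so the labels shown above are stale. Fetch the set again with `resync=True` to recompute them.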

In [None]:
# Re-fetch with resync=True so the aggregated info is recomputed and now
# reflects the 'small' label.
cs = client.get_configuration_set(cs_id, resync=True)['configuration_set']
cs.aggregated_info['labels']

In [None]:
# Label every configuration matched by the empty query as 'random_data'
# (presumably all configurations in dataset ds_id — confirm).
client.apply_labels(
    dataset_id=ds_id,
    collection_name='configurations',
    query={},
    labels={'random_data'},
    verbose=True,
)

In [None]:
# Resync again so the aggregated labels include 'random_data' as well.
cs = client.get_configuration_set(cs_id, resync=True)['configuration_set']
cs.aggregated_info['labels']

# Exploring the data

## Aggregated data

In [None]:
# Reload the dataset with resync=True so its aggregated info reflects the
# labels applied above.
dataset = client.get_dataset(ds_id, resync=True)['dataset']

In [None]:
# Print each aggregated statistic of the resynced dataset.
for key, value in dataset.aggregated_info.items():
    print(key, value)

In [None]:
# Summary statistics for each aggregated property field, computed over the
# dataset's property ids.
client.get_statistics(dataset.aggregated_info['property_fields'], ids=dataset.property_ids)

In [None]:
# Histogram of each aggregated property field over the dataset's properties.
client.plot_histograms(dataset.aggregated_info['property_fields'], ids=dataset.property_ids)

## Apply transformations to properties

In [None]:
# Transpose the (configuration id, property id) pairs from insert_data into
# two parallel tuples and show their lengths.
all_co_ids, all_pr_ids = zip(*ids)

len(all_co_ids), len(all_pr_ids)

In [None]:
# Convert to per-atom energies: each stored 'energy-forces.energy' value f
# is divided by the site count of its linked configuration
# (doc['configuration']['nsites']).
# NOTE(review): if apply_transformation writes the result back to the
# database, re-running this cell divides the energies again — confirm
# before re-executing.
client.apply_transformation(
    dataset_id=ds_id,
    property_ids=all_pr_ids,
    update_map={
        'energy-forces.energy':
        lambda f, doc: f/doc['configuration']['nsites']
    },
    configuration_ids=all_co_ids,
)

In [None]:
# Re-plot energies and forces after the per-atom energy transformation.
client.plot_histograms(['energy-forces.energy', 'energy-forces.forces'], ids=dataset.property_ids)

## Filtering

In [None]:
def ff(pr_doc):
    """Decide whether a property document passes the energy/force cut-offs.

    Reads ``pr_doc['energy-forces'][<field>]['source-value']`` for the
    'energy' and 'forces' fields. Returns a truthy value only when the
    largest absolute energy is below 100 AND the largest absolute force
    component is below 3.
    """
    fields = pr_doc['energy-forces']
    largest_energy = np.abs(fields['energy']['source-value']).max()
    largest_force = np.abs(fields['forces']['source-value']).max()
    return (largest_energy < 100) and (largest_force < 3)

In [None]:
# Reload the dataset document.
# NOTE(review): `dataset` is not read by the filtering cells that follow
# (they use ds_id directly) — confirm this reload is still needed.
dataset = client.get_dataset(ds_id)['dataset']

In [None]:
# Run `ff` over the dataset's properties, fetching only the energy and force
# fields. Returns configuration sets and property ids, presumably those that
# pass the filter — the next cells re-insert them as a new dataset.
clean_config_sets, clean_property_ids = client.filter_on_properties(
    ds_id,
    filter_fxn=ff,
    fields=['energy-forces.energy', 'energy-forces.forces'],
    verbose=True
)

In [None]:
# Re-insert every filtered configuration set that still has members and
# collect the new set ids; sets emptied by the filter are skipped.
new_cs_ids = []
for cs in clean_config_sets:
    if not len(cs.configuration_ids):
        continue
    new_cs_ids.append(
        client.insert_configuration_set(cs.configuration_ids, cs.description, verbose=True)
    )

print(new_cs_ids)

In [None]:
# Build a second dataset from the filtered configuration sets and
# properties. Metadata arguments come first, then content ids and options.
ds_id_clean = client.insert_dataset(
    name='basic_example_filtered',
    description="A dataset generated during a basic filtering example",
    authors=['ColabFit'],
    links=[],
    cs_ids=new_cs_ids,
    pr_ids=clean_property_ids,
    resync=True,
    verbose=True,
)
ds_id_clean

In [None]:
# Load the filtered dataset and count how many properties it contains.
clean_ds = client.get_dataset(ds_id_clean)['dataset']
len(clean_ds.property_ids)

In [None]:
# Histograms for the filtered dataset; values are expected to respect the
# cut-offs in `ff` (|energy| < 100, |force| < 3).
client.plot_histograms(['energy-forces.energy', 'energy-forces.forces'], ids=clean_ds.property_ids)