## Step 1: Get structure(s) using OPTiMaDe

In [None]:
from __future__ import print_function
from __future__ import with_statement

from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.data.structure import StructureData, Kind, Site
from aiida.orm.data.cif import CifData
from aiida.orm.calculation import Calculation

from aiida.tools.dbimporters.plugins.cod import CodDbImporter

import ase.io
import numpy as np
import ipywidgets as ipw
from base64 import b64decode
from IPython.display import display, clear_output, Image, HTML
from fileupload import FileUploadWidget

import nglview

import requests
import tempfile

In [None]:
# Currently selected structure node; stays None until a structure is picked.
atoms = None

# Options of the structure dropdown: (label, payload) pairs. The first entry
# is a placeholder whose payload is flagged with "status": False.
structures = [("select structure", dict(status=False))]

# Known OPTiMaDe endpoints as (label, settings) pairs. "importer" holds a
# lazily created importer object for the endpoint (None until first use).
databases = [
    (
        "Crystallography Open Database (CoD)",
        dict(
            name="cod",
            url="http://www.crystallography.net/cod/optimade",
            importer=None,
        ),
    ),
    (
        "AiiDA @ localhost:5000",
        dict(
            name="aiida",
            url="http://127.0.0.1:5000/optimade",
            importer=None,
        ),
    ),
    (
        "Custom",
        dict(
            name="custom",
            url="http://26c3722d.ngrok.io/optimade",
            importer=None,
        ),
    ),
]

# Settings of the database to query; CoD (first entry) is the default.
query_db = databases[0][1]

# Oldest OPTiMaDe API version accepted, as (MAJOR, MINOR, PATCH).
min_api_version = (0, 9, 5)

# Shared layout for the input widgets.
layout = ipw.Layout(width="400px")
#style = {"description_width":"100px"}

# Structure viewer widget; the cell output is cleared right after creating it.
viewer = nglview.NGLWidget()
clear_output()

In [None]:
class ApiVersionError(Exception):
    """Signals a problem with an API version (wrong or malformed version)."""

In [None]:
class OptimadeImporter(object):
    """
    Minimal client for an OPTiMaDe REST API (written for v0.9.5 and v0.9.7a).

    On construction the ``/info`` endpoint of the database is queried to
    determine the API version, which is stored in ``self.api_version`` as a
    tuple of integers.

    :param name: Short name of the database (default: "cod").
    :param url: Base URL of the OPTiMaDe API (default: the CoD endpoint).

    Other keyword arguments are accepted and ignored, so a full database
    settings dict can be passed with ``**``.
    """

    def __init__(self, **kwargs):
        self.db = kwargs.get("name", "cod")
        self.db_baseurl = kwargs.get(
            "url", "http://www.crystallography.net/cod/optimade")
        self.api_version = self._set_api_version()

    def _set_api_version(self):
        """
        Ask the ``/info`` endpoint for the API version of the database.

        :return: API version as a tuple of integers, e.g. ``(0, 9, 5)``.
        :raises ImportError: if the server does not answer with HTTP 200.
        :raises ValueError: if the reported version cannot be parsed.
        """
        url = ''.join([self.db_baseurl, "/info"])
        r = requests.get(url)

        if r.status_code != 200:
            raise ImportError("Query returned HTTP status code: {}".format(r.status_code))

        response = r.json()
        return self._parse_api_version(response["meta"]["api_version"])

    @staticmethod
    def _parse_api_version(raw_version):
        """
        Convert a version string such as "v0.9.7a" into a tuple of integers.

        An optional leading "v" is dropped and any non-numeric suffix of a
        component (e.g. the "a" of an alpha release) is ignored.

        :param raw_version: Version string as reported by the server.
        :return: Tuple of integers, e.g. ``(0, 9, 7)``.
        :raises ValueError: if a component does not start with a digit.
        """
        version = raw_version.strip()
        if version[:1] in ("v", "V"):
            version = version[1:]

        parsed = []
        for part in version.split('.'):
            # Keep all leading digits, so that "10a" becomes 10 (not 1)
            digits = part[:len(part) - len(part.lstrip("0123456789"))]
            if not digits:
                raise ValueError(
                    "Cannot parse API version {!r}".format(raw_version))
            parsed.append(int(digits))

        return tuple(parsed)

    def query(self, filter=None):
        """
        Query the database and return the decoded JSON response.

        :param filter: ``None`` to query the ``/all`` endpoint, or a dict of
            filters to query ``/structures``. Only the "id" key is used: if
            it can be converted to an integer that single structure is
            requested, otherwise all structures are requested.
        :return: Decoded JSON response of the server.
        :raises TypeError: if ``filter`` is neither ``None`` nor a dict.
        :raises ImportError: if the server answers with an HTTP status >= 400.
        """
        # Type check (None is allowed and selects the "/all" endpoint)
        if filter is not None and not isinstance(filter, dict):
            raise TypeError("'filter' should be a dict")

        query_str = ""
        if filter is None:
            endpoint = "/all"
        else:
            endpoint = "/structures"
            idn = None
            if "id" in filter:
                try:
                    idn = int(filter["id"])
                except (TypeError, ValueError):
                    # Return all structures - id not typed in correctly
                    idn = None
            # Compare against None so that an id of 0 is still queried
            if idn is not None:
                query_str = "/{}".format(idn)

        # Make query - get data
        url = ''.join([self.db_baseurl, endpoint, query_str])
        r = requests.get(url)

        if r.status_code >= 400:
            raise ImportError("Query returned HTTP status code: {}".format(r.status_code))
        elif r.status_code != 200:
            print("Query returned HTTP status code: {}".format(r.status_code))

        return r.json()
    

def ver_to_str(api_version):
    """
    Render an API version tuple as a "v"-prefixed, dot-separated string.

    :param api_version: Tuple of integers: (MAJOR,MINOR,PATCH)
    :return: String of the form "vMAJOR.MINOR.PATCH"
    :raises TypeError: if api_version is not a tuple
    :raises ApiVersionError: if more than three components are given, or if
        only MAJOR is given and it is 0
    """
    if not isinstance(api_version, tuple):
        raise TypeError("api_version must be of type 'tuple'.")

    n_parts = len(api_version)
    if n_parts > 3:  # Shouldn't be necessary to check
        raise ApiVersionError("Too many arguments for api_version. "
                              "API version is defined as maximum (MAJOR,MINOR,PATCH).")
    if n_parts == 1 and api_version[0] == 0:  # Shouldn't be necessary to check
        raise ApiVersionError("When API MAJOR version is 0, MINOR version MUST be specified.")

    return "v" + ".".join(map(str, api_version))
    

In [None]:
def query(idn=None, formula=None):
    """
    Run a query against the currently selected database.

    The importer is created on first use and cached under the "importer"
    key of the selected database settings.

    :param idn: Structure id to filter on (skipped when None)
    :param formula: Formula to filter on (skipped when None)
    :return: Tuple (response of the importer, API version of the importer)
    """
    importer = query_db["importer"]
    if importer is None:
        importer = OptimadeImporter(**query_db)
        query_db["importer"] = importer

    # Only pass on the filters that were actually given
    # TODO: Implement 'filter' queries
    candidates = (("id", idn), ("formula", formula))
    active_filters = dict(
        (key, value) for key, value in candidates if value is not None
    )

    return importer.query(active_filters), importer.api_version
    

def _entry_from_cif_link(entry):
    """Build (CifData, formula, url) by downloading the CIF file an entry links to."""
    cif_url = entry["links"]["self"]
    cif_response = requests.get(cif_url)
    with tempfile.NamedTemporaryFile(mode='w+') as f:
        f.write(cif_response.text)
        f.flush()
        entry_cif = CifData(file=f.name, parse_policy='lazy')

    formula = entry_cif.get_ase().get_chemical_formula()
    return entry_cif, formula, cif_url


def _entry_from_attributes(entry):
    """Build (cif, formula, "") from an entry's attributes; None if it contains a vacancy."""
    attr = entry["attributes"]
    species = list(attr["species"].values())

    # Vacancies cannot be represented here, so such entries are skipped
    # altogether. Checking up front avoids editing the symbol lists while
    # looping over them.
    for kind in species:
        if "vacancy" in kind["chemical_symbols"]:
            return None

    s = StructureData(cell=attr["lattice_vectors"])
    for kind in species:
        s.append_kind(Kind(
            symbols=kind["chemical_symbols"],
            weights=kind["concentration"],
            mass=kind["mass"],
            name=kind["original_name"]
        ))

    for idx, position in enumerate(attr["cartesian_site_positions"]):
        s.append_site(Site(
            kind_name=attr["species_at_sites"][idx],
            position=position
        ))

    return s._get_cif(), s.get_formula(), ""


def on_click_query(b):
    """
    Button callback: query the selected database and fill the structure dropdown.

    The text of the id field is used as structure id when it is an integer,
    otherwise as formula. Progress and errors are reported through
    ``query_message``; found structures are written to the global
    ``structures`` list and to ``drop_structure``.

    :param b: The clicked button (not used).
    """
    global structures
    structures = [("select structure", {"status": False})]

    idn = None
    formula = None
    try:
        idn = int(inp_id.value)
    except (TypeError, ValueError):
        formula = str(inp_id.value)

    # Custom host
    # NB! There are no checks on the host input by user, only if empty or not.
    if query_db["name"] == "custom":
        if inp_host.value == "":
            query_message.value = "You must specify a host URL, e.g. 'localhost:5000'"
            return
        custom_url = "http://{}/optimade".format(inp_host.value)
        if query_db["url"] != custom_url:
            query_db["url"] = custom_url
            # A cached importer still points at the previous host;
            # drop it so that a new one is created for the new URL.
            query_db["importer"] = None

    query_message.value = "Quering the database ... "
    try:
        response, api_version = query(idn=idn, formula=formula)
    except (ImportError, requests.exceptions.RequestException) as exc:
        # Show the failure instead of leaving the "Quering ..." message up
        query_message.value = "Query failed: {}".format(exc)
        drop_structure.options = structures
        return

    # API version check
    if api_version < min_api_version:
        # Too old API version: there is no guarantee that the entries are
        # readable/parseable, so none of them are considered.
        query_message.value = "OPTiMaDe API {} is not supported. " \
                              "Must be at least {}.".format(ver_to_str(api_version), ver_to_str(min_api_version))
        drop_structure.options = structures
        return
    # API version 0.9.5 (specifically for CoD) only links to CIF files
    legacy_api = api_version == min_api_version

    # A response may carry a single entry object (or nothing) instead of a list
    entries = response["data"]
    if entries is None:
        entries = []
    elif isinstance(entries, dict):
        entries = [entries]

    count = 0
    non_valid_count = 0
    for entry in entries:
        if legacy_api:
            parsed = _entry_from_cif_link(entry)
        else:
            # API version 0.9.7a
            parsed = _entry_from_attributes(entry)

        if parsed is None:
            # Non-valid entries are counted, but not offered for selection
            count += 1
            non_valid_count += 1
            continue

        entry_cif, entry_formula, cif_url = parsed
        entry_id = entry["id"]
        entry_name = "{} (id: {})".format(entry_formula, entry_id)
        structures.append((entry_name, {
            "status": True,
            "cif": entry_cif,
            "url": cif_url,
            "id": entry_id,
        }))
        count += 1

    query_message.value = "Quering the database ... %d structure(s) found" % count
    if non_valid_count > 0:
        query_message.value += " ... {} non-valid structure(s) found " \
                               "(vacancies are not allowed)".format(non_valid_count)

    drop_structure.options = structures
    if len(structures) > 1:
        drop_structure.value = structures[1][1]

def on_change_db(c):
    """
    Dropdown callback: make the newly selected database the one to query.

    :param c: Change event; its 'new' item holds the selected database settings.
    """
    global query_db
    selected = c['new']
    query_db = selected

    # The host text-field is only editable for the "Custom" database
    inp_host.disabled = selected["name"] != "custom"

In [None]:
# Database selection: dropdown over the entries of `databases`
head_dbs = ipw.HTML("OPTiMaDe database:")
drop_dbs = ipw.Dropdown(description="", options=databases, layout=layout)
# Host input for the "Custom" database; starts disabled and is toggled by
# on_change_db whenever the selected database changes
head_host = ipw.HTML("Custom host:")
inp_host = ipw.Text(description="http://", value="26c3722d.ngrok.io",
                    placeholder="e.g.: localhost:5000", layout=layout, disabled=True)
txt_host = ipw.HTML("/optimade")
drop_dbs.observe(on_change_db, names="value")

# Filter input: a single text field, read by on_click_query
head_filters = ipw.HTML("<h4><strong>Filters:</strong></h4>")
inp_id = ipw.Text(description="id:", value="", placeholder='e.g.: 9009008', layout=layout)

# Button that starts the query
btn_query = ipw.Button(description='Query in DB')
btn_query.on_click(on_click_query)

# Status line updated by on_click_query
query_message = ipw.HTML("Waiting for input...")

# Show all widgets of this step
display(
    ipw.HBox([head_dbs, drop_dbs]),
    ipw.HBox([head_host, inp_host, txt_host]),
    head_filters,
    ipw.HBox([inp_id, btn_query]),
    query_message
)


## Step 2: Select Structure

In [None]:
def refresh_structure_view():
    """
    Show the currently selected structure (global ``atoms``) in the viewer.

    A previously shown component is stripped of its representations and
    removed before the new structure is added.
    """
    if hasattr(viewer, "component_0"):
        viewer.clear_representations()
        # The ball+stick representation is removed three times in a row
        for _ in range(3):
            viewer.component_0.remove_ball_and_stick()
        viewer.component_0.remove_unitcell()
        viewer.remove_component(viewer.component_0.id)

    ase_structure = nglview.ASEStructure(atoms.get_ase())
    viewer.add_component(ase_structure) # adds ball+stick
    viewer.add_unitcell()
    viewer.center()

In [None]:
def on_change(c):
    """
    Dropdown callback: select a structure, list related calculations and show it.

    Sets the global ``atoms`` to the selected structure, prints existing
    calculations whose 'formula' extra matches the structure's formula,
    updates the ``link`` widget and refreshes the structure viewer.
    Nothing happens for the placeholder entry (falsy 'status').

    :param c: Change event; its 'new' item holds the selected structure dict.
    """
    global atoms
    # Escape helper: `html` on Python 3, `cgi` on Python 2
    try:
        from html import escape
    except ImportError:
        from cgi import escape

    new_element = c['new']
    if not new_element['status']:
        return
    atoms = new_element['cif']
    formula = atoms.get_ase().get_chemical_formula()

    # search for existing calculations using chosen structure
    qb = QueryBuilder()
    qb.append(StructureData)
    qb.append(Calculation, filters={'extras.formula':formula}, descendant_of=StructureData)
    qb.order_by({Calculation:{'ctime':'desc'}})
    for n in qb.iterall():
        calc = n[0]
        print("Found existing calculation: PK=%d | %s"%(calc.pk, calc.get_extra("structure_description")))
        thumbnail = b64decode(calc.get_extra("thumbnail"))
        display(Image(data=thumbnail))

    # URL and id originate from the queried server, so they are escaped
    # before being placed into the HTML widget.
    label = '{} entry {}'.format(
        escape("{}".format(query_db["name"]), True),
        escape("{}".format(new_element['id']), True)
    )
    if new_element['url'] != "":
        struct_url = new_element['url'].split('.cif')[0]+'.html'
        link.value = '<a href="{}" target="_blank">{}</a>'.format(escape(struct_url, True), label)
    else:
        link.value = label
    refresh_structure_view()


    
# Structure selection: dropdown over the entries of `structures`;
# on_change is called whenever its value changes
drop_structure = ipw.Dropdown(description="", options=structures, layout=layout )
drop_structure.observe(on_change, names='value')
# HTML widget that on_change fills with the link/label of the selected entry
link = ipw.HTML("Link to the web-page will appear here")
# Show dropdown, link and structure viewer
display(drop_structure, link, ipw.VBox([viewer]))

## Step 3: Store in AiiDA Database

In [None]:
def on_click_store(b):
    """
    Button callback: store the selected structure in the AiiDA database.

    Depending on the selected data type either a copy of the selected
    structure node is stored ('CifData') or a StructureData built from its
    ASE representation ('StructureData'); the description is set on the
    latter only. All output goes to ``store_out``.

    :param b: The clicked button (not used).
    """
    with store_out:
        clear_output()
        if atoms is None:
            print("Specify a structure first!")
            return

        # Compare by value: identity ('is') of strings is not reliable
        selected_format = data_format.value
        if selected_format == 'CifData':
            s = atoms.copy()
        elif selected_format == 'StructureData':
            ase_atoms = atoms.get_ase()
            s = StructureData(ase=ase_atoms)
            # ensure that tags got correctly translated into kinds
            for t1, k in zip(ase_atoms.get_tags(), s.get_site_kindnames()):
                # Use all trailing digits of the kind name (not only the
                # last character), so that tags >= 10 compare correctly
                tag_digits = k[len(k.rstrip('0123456789')):]
                t2 = int(tag_digits) if tag_digits else 0
                assert t1 == t2, "tag {} does not match kind name {}".format(t1, k)
            s.description = inp_descr.value
        else:
            # Nothing selected or unknown type: there is nothing to store
            print("Unknown data type: {}".format(selected_format))
            return

        s.store()
        print("Stored in AiiDA: "+repr(s))

# Optional description, read by on_click_store
inp_descr = ipw.Text(placeholder="Description (optional)")   
# Button that triggers storing the selected structure
btn_store = ipw.Button(description='Store in AiiDA')
btn_store.on_click(on_click_store)
# Choice of the node type to store; on_click_store reads its value
data_format = ipw.RadioButtons(
    options=['CifData', 'StructureData'],
    # no explicit initial `value` is set here
    description='Data type:',
    disabled=False
)


# Output area that on_click_store prints into
store_out = ipw.Output()
# Show all widgets of this step
display(data_format, ipw.HBox([btn_store, inp_descr]), store_out)