In [1]:
from aiida import load_dbenv
# Load the AiiDA database environment up front; the next cell imports from aiida.orm.
load_dbenv()

In [2]:
import datetime
import time
import numpy as np
from aiida.orm import DataFactory, CalculationFactory, Code, load_node
from aiida.tools.data.array.kpoints import get_explicit_kpoints_path
from aiida.work import run, submit

In [3]:
import numpy


def get_data_cls(descriptor):
    """Return the data class that ``DataFactory`` resolves for ``descriptor``.

    The database environment is loaded first (if it is not already), and
    ``DataFactory`` is imported only afterwards.

    :param descriptor: plugin string handed to ``DataFactory``
        (the notebook uses e.g. ``'structure'``, ``'parameter'``, ``'vasp.potcar'``)
    :return: whatever ``DataFactory(descriptor)`` returns
    """
    load_dbenv_if_not_loaded()
    from aiida.orm import DataFactory as data_factory
    data_cls = data_factory(descriptor)
    return data_cls


def simple(pot_family, import_from, queue, code, computer, no_import):
    """Optionally import POTCARs, then build, store and submit a test VASP run for Si.

    :param pot_family: POTCAR family name used for the import and the lookup
    :param import_from: folder handed to ``import_pots`` when importing
    :param queue: scheduler queue name set on the calculation
    :param code: code label; combined with ``computer`` as ``'<code>@<computer>'``
    :param computer: computer name used in the code lookup string
    :param no_import: when true, the POTCAR import step is skipped

    NOTE(review): ``click``, ``cli_spinner`` and ``create_kpoints`` are not
    defined or imported in the visible notebook cells (only
    ``create_kpoints_path`` exists) -- this function looks like it was copied
    from a CLI script; confirm where those names are supposed to come from.
    """
    load_dbenv_if_not_loaded()
    from aiida.orm import CalculationFactory, Code

    should_import = not no_import
    if should_import:
        click.echo('importing POTCAR files...')
        with cli_spinner():
            import_pots(import_from, pot_family)

    potcar_cls = get_data_cls('vasp.potcar')
    si_potcar = potcar_cls.find_one(family=pot_family, full_name='Si')

    # Inputs are attached before the code/computer, as in the original flow.
    calc = CalculationFactory('vasp.vasp')()
    calc.use_structure(create_structure_Si())
    calc.use_kpoints(create_kpoints())
    calc.use_parameters(create_params_simple())

    code_string = '{}@{}'.format(code, computer)
    code_node = Code.get_from_string(code_string)
    calc.use_code(code_node)
    calc.use_potential(si_potcar, 'Si')
    calc.set_computer(code_node.get_computer())

    calc.set_queue_name(queue)
    calc.set_resources({'num_machines': 1, 'num_mpiprocs_per_machine': 20})
    calc.label = 'Test VASP run'

    calc.store_all()
    calc.submit()


def load_dbenv_if_not_loaded():
    """Call ``aiida.load_dbenv`` unless ``aiida.is_dbenv_loaded`` already reports true."""
    from aiida import load_dbenv as _load, is_dbenv_loaded as _is_loaded
    if _is_loaded():
        return
    _load()


def create_structure_Si():
    """Build a structure node with an fcc-type cell scaled by ``alat`` and one Si atom.

    The cell vectors are ``[[.5, 0, .5], [.5, .5, 0], [0, .5, .5]] * alat`` and a
    single ``'Si'`` atom is appended at ``[.25, .25, .25] * alat``.

    NOTE(review): only one atom is appended; confirm whether a second basis
    atom was intentionally left out. ``alat = 5.4`` is presumably in angstrom.

    :return: a new (unstored) instance of the ``'structure'`` data class
    """
    alat = 5.4
    cell_vectors = numpy.array([[.5, 0, .5], [.5, .5, 0], [0, .5, .5]]) * alat
    si_position = numpy.array([.25, .25, .25]) * alat

    structure_cls = get_data_cls('structure')
    si_structure = structure_cls(cell=cell_vectors)
    si_structure.append_atom(position=si_position, symbols='Si')
    return si_structure


def create_kpoints_path():
    """Return ``get_explicit_kpoints_path`` evaluated on the example Si structure.

    ``make_example_inputs`` reads the ``'explicit_kpoints'`` and
    ``'conv_structure'`` entries of the returned mapping.
    """
    si_structure = create_structure_Si()
    return get_explicit_kpoints_path(structure=si_structure)

def create_params_noncol():
    """Return a ``'parameter'`` data node holding an INCAR-style dict with ``LSORBIT`` enabled.

    NOTE(review): ``SYSTEM`` is ``'InAs'`` although the rest of the notebook
    works with Si, and nothing in the visible cells calls this function.
    """
    incar_tags = {
        'SYSTEM': 'InAs',
        'EDIFF': 1e-5,
        'LORBIT': 11,
        'LSORBIT': '.True.',
        'GGA_COMPAT': '.False.',
        'ISMEAR': 0,
        'SIGMA': 0.05,
        'GGA': 'PE',
        'ENCUT': '280.00 eV',
        'MAGMOM': '6*0.0',
        'NBANDS': 24,
    }
    parameter_cls = get_data_cls('parameter')
    return parameter_cls(dict=incar_tags)


def create_params_simple():
    """Return a ``'parameter'`` data node with the small lower-case tag set used by the examples."""
    incar_tags = {
        'prec': 'NORMAL',
        'encut': 200,
        'ediff': 1e-8,
        'ialgo': 38,
        'ismear': 0,
        'sigma': 0.1,
    }
    parameter_cls = get_data_cls('parameter')
    return parameter_cls(dict=incar_tags)


def import_pots(folder_path, family_name):
    """Upload the POTCAR files under ``folder_path`` into the family ``family_name``.

    Delegates to ``upload_potcar_family`` on the ``'vasp.potcar'`` data class,
    with a fixed description and ``stop_if_existing=False``.
    """
    potcar_cls = get_data_cls('vasp.potcar')
    upload_options = {
        'group_name': family_name,
        'group_description': 'Test family',
        'stop_if_existing': False,
    }
    potcar_cls.upload_potcar_family(folder_path, **upload_options)


In [4]:
def now_str():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM``."""
    timestamp_format = '%Y-%m-%d %H:%M'
    return datetime.datetime.now().strftime(timestamp_format)

In [5]:
def is_same_structure(left, right):
    """Return True when two structures have matching cell, formula and kind names.

    The cells are compared element-wise with an absolute tolerance of 1e-10.
    The previous implementation compared the *signed* difference against the
    tolerance, so any structure whose cell entries were larger than the
    reference was reported as identical.

    :param left: object exposing ``cell``, ``get_formula()`` and ``get_kind_names()``
    :param right: object with the same interface
    :return: ``bool``
    """
    left_cell = np.asarray(left.cell, dtype=float)
    right_cell = np.asarray(right.cell, dtype=float)
    # Differently shaped cells can never match (and may not even broadcast).
    if left_cell.shape != right_cell.shape:
        return False
    if not np.all(np.abs(left_cell - right_cell) < 1e-10):
        return False
    if left.get_formula() != right.get_formula():
        return False
    return bool(left.get_kind_names() == right.get_kind_names())

In [6]:
def new_or_existing_structure(new_structure):
    """Return a queried structure equal to ``new_structure``, or ``new_structure`` itself.

    Every row returned by the structure class's ``querybuild().all()`` is
    compared with ``is_same_structure``; the first match wins.

    :param new_structure: candidate structure
    :return: the first matching queried structure, else ``new_structure``
    """
    structure_cls = get_data_cls('structure')
    for row in structure_cls.querybuild().all():
        candidate = row[0]
        if is_same_structure(new_structure, candidate):
            return candidate
    return new_structure

In [7]:
# Process class of the 'vasp.vasp' calculation plugin; make_example_inputs() takes its
# input template from it and run_example() submits it.
proc = CalculationFactory('vasp.vasp').process()

In [8]:
def make_example_inputs():
    """Fill the input template of ``proc`` with the Si example settings.

    The code is looked up once from ``'vasp@monch'`` and used both as the
    ``code`` input and as the source of the computer option. Previously the
    computer came from ``'vasp@monch'`` while the code input was re-resolved
    from the bare label ``'vasp'``, so the two could refer to different
    machines (or the bare lookup could be ambiguous).

    :return: the populated inputs template
    """
    code = Code.get_from_string('vasp@monch')
    potcar_cls = get_data_cls('vasp.potcar')
    potcar_map = {'Si': 'Si'}
    auto_kpoints = create_kpoints_path()

    inputs = proc.get_inputs_template()
    inputs._label = 'Test {}'.format(now_str())
    inputs._description = 'This is a test'
    inputs.code = code

    inputs._options.max_wallclock_seconds = 3000
    inputs._options.resources = {'num_machines': 1, 'num_mpiprocs_per_machine': 20}
    inputs._options.queue_name = 'dphys_compute'
    inputs._options.computer = code.get_computer()

    inputs.kpoints = auto_kpoints['explicit_kpoints']
    inputs.parameters = create_params_simple()
    # Reuse a queried structure when an equal one is found.
    inputs.structure = new_or_existing_structure(auto_kpoints['conv_structure'])
    inputs.potential = potcar_cls.get_potcars_from_structure(
        family_name='PBE', structure=inputs.structure, mapping=potcar_map)
    inputs.settings = DataFactory('parameter')(
        dict={'parser_settings': {'add_bands': True, 'add_dos': True}})
    return inputs

In [9]:
def run_example():
    """Build the example inputs and hand them to ``submit`` for ``proc``.

    :return: whatever ``submit`` returns
    """
    example_inputs = make_example_inputs()
    return submit(proc, **example_inputs)

In [10]:
def poll_calc(process=None, interval=10):
    """Block until the calculation reaches ``'FINISHED'`` or ``'FAILED'``.

    The node is reloaded with ``load_node(process.pid)`` on every poll, and
    each re-polled state is printed.

    :param process: object with a ``pid`` attribute identifying the node to
        poll. When omitted, the notebook-level name ``running`` is used, as
        before. NOTE(review): ``running`` is not assigned in any visible
        cell -- presumably it should hold the return value of ``run_example()``.
    :param interval: seconds to wait between polls (default 10)
    """
    if process is None:
        # Backward-compatible fallback to the global handle.
        process = running
    terminal_states = ('FINISHED', 'FAILED')
    state = load_node(process.pid).get_state()
    while state not in terminal_states:
        # Sleep *before* re-polling so no time is wasted once the run is done.
        time.sleep(interval)
        state = load_node(process.pid).get_state()
        # Call form prints a single string identically on Python 2 and 3.
        print(state)

In [11]:
def show_bands(calc):
    """Plot the band structure attached to ``calc``.

    Calls ``calc.out.output_band.show_mpl()``. The previous body ignored the
    ``calc`` argument and read an undefined global ``result`` instead.

    :param calc: calculation node whose ``out.output_band`` is plotted
    """
    calc.out.output_band.show_mpl()

In [30]:
# Main imports for UI
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML, Javascript
from fileupload import FileUploadWidget

In [28]:
# HTML widget that carries the current warning markup; show_warning() writes its value.
warningsbox = widgets.HTML("")

# Inline CSS declarations for the warning box; show_warning() joins them with "; ".
styles = [
    "background-color: rgb(242, 222, 222)",
    "border-color: rgb(235, 204, 209)",
    "border-radius: 4px",
    "border-style: solid",
    "border-width: 1px",
    "font-size: 14px",
    "line-height: 20px",
    "margin: 0px",
    "padding: 10px",
    "padding-left: 20px",
    "padding-right: 20px",    
]

def show_warning(msg):
    """
    Render ``msg`` inside the red warning box, or clear the box.

    A falsy ``msg`` (empty string or None) empties the box.
    :param msg: HTML fragment inserted verbatim after "WARNING! " (not validated or escaped)
    """
    if msg:
        inline_css = "; ".join(styles)
        warningsbox.value = '<div style="{}">WARNING! {}</div>'.format(inline_css, msg)
        return
    warningsbox.value = ''

In [31]:
# An object that can be 'display'ed to jump back here.
# NOTE(review): the script navigates to the "#warningsbox" URL fragment; no element with
# that id is created in the visible cells -- confirm the anchor actually exists.
js_jump = Javascript("""window.location.href = "#warningsbox";""")

In [32]:
display(warningsbox)
# To use: call show_warning(msg); pass an empty string to clear the box.

A Jupyter Widget

In [33]:
def nice_errors(default_return=None):
    """Decorator factory: show exceptions in the warning box instead of raising.

    Before each call the warning box is cleared. If the wrapped function
    raises, the message is shown via ``show_warning``, ``js_jump`` is
    displayed, and ``default_return`` is returned.

    Changes from the earlier version:

    * the wrapper keeps the wrapped function's name and docstring
      (``functools.wraps``);
    * the exception text is HTML-escaped before being inserted into the
      warning markup, so messages containing ``<``, ``>`` or ``&`` are shown
      literally instead of being interpreted as HTML.

    :param default_return: value returned when the wrapped function raises
    :return: a decorator
    """
    from functools import wraps
    # Available under this path on both Python 2 and Python 3.
    from xml.sax.saxutils import escape

    def decorator(fn):
        @wraps(fn)
        def inner_fn(*args, **kwargs):
            try:
                # Clean-up before running the function
                show_warning('')
                return fn(*args, **kwargs)
            except Exception as e:
                # Deliberately swallow the exception, but surface it to the user
                show_warning("Internal error during execution. Error: {}".format(
                    escape(str(e))))
                display(js_jump)
                return default_return
        return inner_fn
    return decorator

In [38]:
# Most recently uploaded structure; assigned by on_file_upload(), read by get_structure().
structure = None

# Status line for the upload; on_file_upload() writes progress and result messages here.
upload_out = widgets.HTML()

@nice_errors(default_return=None)
def on_file_upload(c):
    """Turn the uploaded file into a structure and store it in the global ``structure``.

    The upload is written to a temporary file whose suffix is the uploaded
    file name. A name ending in ``.aiida`` is imported with ``import_data``
    and the imported structure nodes are queried; anything else is read with
    ``ase.io.read``. When several structures are found, only the first one
    is used. Progress and result messages are written to ``upload_out``.

    Fixes over the earlier version:

    * the "several structures" message for ASE files referenced
      ``ase_structure`` before it was assigned; it now reports
      ``len(ase_structures)``;
    * the temporary file is closed (and removed) even when parsing fails;
    * the data is written through the binary temp-file handle rather than
      through a second, text-mode handle.

    :param c: change notification passed by ``observe`` (unused)
    :raises ValueError: if an ``.aiida`` archive contains no structure
        (caught and displayed by ``nice_errors``)
    """
    global structure

    from tempfile import NamedTemporaryFile
    import ase.io
    from aiida.orm import DataFactory
    StructureData = DataFactory('structure')

    several_found_msg = "<strong>Number of structures found</strong>: %d; considering only the first one.<br>"
    one_found_msg = "<strong>Structure loaded.</strong><br>"

    upload_out.value = "Uploading structure..."
    # The context manager guarantees clean-up on every exit path.
    with NamedTemporaryFile(suffix=in_file_upload.filename) as tmp:
        tmp.write(in_file_upload.data)
        # Make sure readers that open the file by name see the full content.
        tmp.flush()

        if tmp.name.endswith('.aiida'):
            from aiida.orm.importexport import import_data
            upload_out.value = "Importing data..."
            import_dict = import_data(tmp.name, silent=True)
            imported_nodes = import_dict['Node']['existing'] + import_dict['Node']['new']
            qs = StructureData.query(pk__in=[entry[1] for entry in imported_nodes])
            num_found = qs.count()
            if num_found == 0:
                raise ValueError("No structure found!")
            loaded_structure = qs.first()
        else:
            ase_structures = ase.io.read(tmp.name, index=":")
            num_found = len(ase_structures)
            loaded_structure = StructureData(ase=ase_structures[0])

    if num_found > 1:
        upload_out.value = several_found_msg % num_found
    else:
        upload_out.value = one_found_msg
    structure = loaded_structure
    upload_out.value += '\nStructure chemical formula: <strong>%s</strong>.' % structure.get_formula()
    
# TODO: FileUploadWidget doesn't fire an event when the same file is uploaded twice
# Upload button; on_file_upload() runs whenever the widget's `data` trait changes.
in_file_upload = FileUploadWidget("Upload Structure")
in_file_upload.observe(on_file_upload, names='data')

# Checkbox that switches get_structure() from the uploaded structure to a built-in example.
in_use_example_structure = widgets.Checkbox(
    value=False,
    description='Use an example structure',
    disabled=False
)

# [label, key] pairs; the key is what get_example_structure() receives.
example_structure_options = [['Diamond', 'diamond'], ['Aluminum', 'al'],['GaAs', 'gaas'], ['Cobalt', 'co']]
# Starts disabled; on_use_example_structure_change() enables it when the checkbox is ticked.
in_example_structure = widgets.Dropdown(
    options=example_structure_options,
    value=example_structure_options[0][1],
    disabled=True
)

def on_use_example_structure_change(v):
    """Enable exactly one of the upload button and the example dropdown.

    :param v: change notification; ``v['owner'].value`` is the checkbox state.
        When it is truthy the upload is disabled and the dropdown enabled,
        otherwise the other way round.
    """
    use_example = bool(v['owner'].value)
    in_file_upload.disabled = use_example
    in_example_structure.disabled = not use_example

# Setup listener: toggle the upload/dropdown widgets whenever the checkbox value changes
in_use_example_structure.observe(on_use_example_structure_change, names='value')

# Vertical layout: upload button, its status line, then the checkbox + dropdown side by side.
structure_group = widgets.VBox(
    [
        in_file_upload,
        upload_out,
        widgets.HBox(
            [
            in_use_example_structure,
            in_example_structure,
            ]),
    ])

def get_example_structure(key):
    """Build one of the built-in example structures.

    Supported keys are ``'diamond'``, ``'al'``, ``'gaas'`` and ``'co'``. Each
    structure is generated with ASE's ``crystal`` helper as a primitive cell
    and wrapped in a ``StructureData`` node.

    The four previously copy-pasted branches are replaced by one lookup
    table. ``crystal`` is imported from ``ase.spacegroup`` when available,
    falling back to the legacy ``ase.lattice.spacegroup`` location, so the
    function works with both old and new ASE releases.

    :param key: example identifier (see above)
    :return: a new ``StructureData`` instance
    :raises ValueError: if ``key`` is not a known example
    """
    from aiida.orm import DataFactory
    try:
        from ase.spacegroup import crystal
    except ImportError:
        # Older ASE releases only ship the helper here.
        from ase.lattice.spacegroup import crystal
    StructureData = DataFactory('structure')

    # key -> (symbols, basis, spacegroup number, cell parameters);
    # lattice constants are in angstrom, angles in degrees.
    examples = {
        'diamond': ('C', [(0, 0, 0)], 227, [3.56, 3.56, 3.56, 90, 90, 90]),
        'al': ('Al', [(0, 0, 0)], 225, [4.05, 4.05, 4.05, 90, 90, 90]),
        'gaas': ('GaAs', [(0, 0, 0), (0.25, 0.25, 0.25)], 216, [5.75, 5.75, 5.75, 90, 90, 90]),
        'co': ('Co', [(1. / 3, 2. / 3, 0.25)], 194, [2.5, 2.5, 4.07, 90, 90, 120]),
    }
    if key not in examples:
        raise ValueError("Unknown or unsupported example structure '{}'".format(key))

    symbols, basis, spacegroup, cellpar = examples[key]
    atoms = crystal(symbols, basis, spacegroup=spacegroup,
                    cellpar=cellpar, primitive_cell=True)
    return StructureData(ase=atoms)

def get_structure():
    """Return the structure selected in the UI.

    When the "use an example" checkbox is ticked, the example chosen in the
    dropdown is built and returned; otherwise the uploaded structure is
    returned.

    :raises ValueError: if no example is requested and nothing was uploaded
    """
    if in_use_example_structure.value:
        return get_example_structure(in_example_structure.value)
    if structure is not None:
        return structure
    raise ValueError("You did not upload a structure. Either upload a structure or choose an example.")

In [13]:
# Material dropdown; each option maps a label to the factory function that builds the structure.
# Only Si is offered so far.
STRUCTURE_PICKER = widgets.Dropdown(
    options={
        'Si': create_structure_Si
    },
    value=create_structure_Si,
    description='Choose a Material',
    style={'description_width': 'initial'}
)


def get_key_value_display(key, value, **kwargs):
    """Return a ``Box`` holding two labels: the upper-cased key and the value.

    :param key: converted with ``str`` and upper-cased
    :param value: converted with ``str``
    :param kwargs: forwarded to ``widgets.Box``
    """
    labels = [
        widgets.Label(value=str(key).upper()),
        widgets.Label(value=str(value)),
    ]
    return widgets.Box(labels, **kwargs)


def get_parameters_display(parameters, **kwargs):
    """Return a ``VBox`` with one key/value row per entry of ``parameters.get_dict()``.

    :param parameters: object exposing ``get_dict()``
    :param kwargs: forwarded to ``widgets.VBox``. Unless ``layout`` is given,
        one is built from the optional ``width`` (default ``'30%'``) and
        ``border`` (default ``'solid'``) entries, which are removed from
        ``kwargs`` in that case.
    """
    row_layout = widgets.Layout(justify_content='space-between')
    rows = []
    for key, value in parameters.get_dict().items():
        rows.append(get_key_value_display(key, value, layout=row_layout))

    if 'layout' not in kwargs:
        width = kwargs.pop('width', '30%')
        border = kwargs.pop('border', 'solid')
        kwargs['layout'] = widgets.Layout(width=width, border=border)
    return widgets.VBox(rows, **kwargs)


class HtmlTable(object):
    """Minimal HTML table renderer.

    ``header`` is a sequence of cell values and ``rows`` a sequence of such
    sequences. ``str(table)`` yields the markup. Cell values are inserted
    with ``str.format`` and are not HTML-escaped; header cells use the same
    ``<td>`` template as body cells.
    """
    # Templates are class attributes so subclasses can override the markup.
    table_tpl = '<table class="table">\n{header}{rows}\n</table>'
    table_head_tpl = '<thead>{tr_row}</thead>'
    table_row_tpl = '<tr>{row}</tr>'
    table_cell_tpl = '<td>{}</td>'

    def __init__(self, header=None, rows=None):
        """Store ``header`` and ``rows``; falsy values become empty lists."""
        self.rows = rows if rows else []
        self.header = header if header else []

    @classmethod
    def build_row(cls, row):
        """Return one ``<tr>`` element containing a cell per item of ``row``."""
        cells = ''.join(cls.table_cell_tpl.format(item) for item in row)
        return cls.table_row_tpl.format(row=cells)

    @classmethod
    def build_header(cls, row):
        """Return a ``<thead>`` element wrapping the row built from ``row``."""
        header_row = cls.build_row(row)
        return cls.table_head_tpl.format(tr_row=header_row)

    def build_table(self):
        """Return the full table markup; the header is omitted when empty."""
        header_html = ''
        if self.header:
            header_html = self.build_header(self.header)
        body_html = '\n'.join(self.build_row(row) for row in self.rows)
        return self.table_tpl.format(header=header_html, rows=body_html)

    def __str__(self):
        return self.build_table()
    

def get_kpoints_display(kpoints, **kwargs):
    """Return an ``HTML`` widget with the special-points table and the path labels.

    :param kpoints: object exposing ``get_special_points()`` (whose first
        element maps point names to 3-component positions) and ``labels``
        (a sequence whose items carry the label text at index 1)
    :param kwargs: forwarded to ``widgets.HTML``. Unless ``layout`` is given,
        one is built from the optional ``width`` (default ``'40%'``) and
        ``border`` (default ``'solid'``) entries, which are removed from
        ``kwargs`` in that case.
    """
    named_points = kpoints.get_special_points()[0]
    table_rows = []
    for name, pos in named_points.items():
        table_rows.append([name, pos[0], pos[1], pos[2]])
    points_table = HtmlTable(header=['Name', 'x', 'y', 'z'], rows=table_rows)

    path_labels = [label[1] for label in kpoints.labels]
    path = ' -> '.join(path_labels)

    if 'layout' not in kwargs:
        width = kwargs.pop('width', '40%')
        border = kwargs.pop('border', 'solid')
        kwargs['layout'] = widgets.Layout(width=width, border=border)

    html = '<p>{special_points}</p><p>{path}</p>'.format(special_points=points_table, path=path)
    return widgets.HTML(value=html, **kwargs)

In [14]:
# Build the example inputs once; the cells below display its parameters, k-points and structure.
inputs = make_example_inputs()

In [15]:
display(widgets.Box([get_parameters_display(inputs.parameters), get_kpoints_display(inputs.kpoints)]))

A Jupyter Widget

In [16]:
inputs.kpoints.get_special_points()[0], inputs.kpoints.labels

({'G': [0.0, 0.0, 0.0],
  'K': [0.375, 0.375, 0.75],
  'L': [0.5, 0.5, 0.5],
  'U': [0.625, 0.25, 0.625],
  'W': [0.5, 0.25, 0.75],
  'X': [0.5, 0.0, 0.5]},
 [(0, 'GAMMA'),
  (45, 'X'),
  (60, 'U'),
  (61, 'K'),
  (109, 'GAMMA'),
  (148, 'L'),
  (179, 'W'),
  (201, 'X')])

In [17]:
import nglview

In [18]:
# Shared NGL viewer widget; refresh_structure_view() replaces the structure shown in it.
STRUCTURE_VIEWER = nglview.NGLWidget()

In [19]:
def refresh_structure_view(atoms):
    """Replace the structure shown in ``STRUCTURE_VIEWER`` with ``atoms``.

    If the viewer already has a ``component_0``, its ball-and-stick
    representation is removed (three times, as in the original) together
    with its unit cell, and the component itself is dropped. The new
    structure is then added with a unit cell and the view is centered.

    :param atoms: ASE atoms object wrapped in ``nglview.ASEStructure``
    """
    viewer = STRUCTURE_VIEWER
    if hasattr(viewer, "component_0"):
        # NOTE(review): the removal is repeated three times, presumably as a
        # workaround for representations that are not cleared in one call -- confirm.
        for _ in range(3):
            viewer.component_0.remove_ball_and_stick()
        viewer.component_0.remove_unitcell()
        viewer.remove_component(viewer.component_0.id)

    new_component = nglview.ASEStructure(atoms)
    viewer.add_component(new_component)  # adds ball+stick
    viewer.add_unitcell()
    viewer.center()

In [20]:
display(widgets.VBox([STRUCTURE_VIEWER], layout=widgets.Layout(border='solid')))

A Jupyter Widget

In [21]:
refresh_structure_view(inputs.structure.get_ase())

In [22]:
STRUCTURE_VIEWER.add_component(nglview.ASEStructure(inputs.structure.get_ase()))

<nglview.component.ComponentViewer at 0x10d14bf10>

In [23]:
STRUCTURE_VIEWER.add_unitcell()

In [24]:
STRUCTURE_VIEWER.display()

A Jupyter Widget

In [25]:
asestruc = inputs.structure.get_ase()

In [26]:
asestruc

Atoms(symbols='Si4', pbc=True, cell=[5.4, 5.4, 5.4], masses=...)

In [27]:
nglview.show_ase(asestruc)

A Jupyter Widget

In [28]:
view = nglview.show_pdbid("3pqr")

In [29]:
view

A Jupyter Widget