# Import Dependencies

In [None]:
%matplotlib widget

#data tools
import pandas as pd
import numpy as np
import qgrid
from pymatgen.ext.matproj import MPRester
from pymatgen.vis import structure_chemview as viz

#simtool loading and interface
from simtool import findInstalledSimToolNotebooks,searchForSimTool
from simtool import getSimToolInputs,getSimToolOutputs,Run

#user interface utilities 
import os, stat
import ipywidgets as widgets
from IPython.display import display
from IPython.display import clear_output
  
# importing required libraries
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt


### Create ~/.mpkey.txt with your Materials Project API key (run this if the file doesn't already exist)

In [None]:
#key security
# Ask for a Materials Project API key, validate it, and store it in
# ~/.mpkey.txt with owner-only read/write permissions.
keyfile_path = os.path.expanduser('~/.mpkey.txt')
try:
    # strip() tolerates stray whitespace picked up while pasting
    user = input('Paste MP API key: ').strip()
    clear_output()  # remove the pasted key from the visible cell output
    if not user:
        raise ValueError('Empty')
    if not user.isalnum():
        raise ValueError('Wrong Key')
    # Create the file with owner-only permissions from the start, rather than
    # writing it under the default umask and chmod-ing afterwards (which
    # leaves a window where other users could read the key).
    fd = os.open(keyfile_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                 stat.S_IREAD | stat.S_IWRITE)
    with os.fdopen(fd, 'w') as keyfile:
        keyfile.write(user)
    # os.open's mode only applies on creation; also tighten a pre-existing file.
    os.chmod(keyfile_path, stat.S_IREAD | stat.S_IWRITE)
    del user
    print("Success")
except (ValueError, OSError) as err:
    # Narrow handler: report the reason instead of silently swallowing
    # everything (including KeyboardInterrupt).
    print(f"Something seems wrong with your key ({err})")

## Pick the semiconductor of your choice
### Choice (a): query Materials Project (MP) and filter on properties. The DataFrame is updated with each selection.

In [None]:
# Read the stored API key and fetch all cubic materials from Materials Project
# into a DataFrame (sc_df) used by the selection widgets below.
with open(os.path.expanduser("~/.mpkey.txt"), "r") as keyfile:
    # Read-only access is enough; strip() drops a trailing newline or
    # whitespace that would otherwise become part of the key.
    apikey = keyfile.readline().strip()
rester = MPRester(apikey)
sc_dicts = rester.query(
    {"crystal_system": "cubic"},
    ["task_id", "pretty_formula", "formula", "elements", "e_above_hull",
     "spacegroup.number", "band_gap", "crystal_system"],
)
sc_df = pd.DataFrame(sc_dicts)

In [None]:
# define a function for visualizing input structures

def parse_poscar(poscar_str):
    """Parse a VASP5-style POSCAR string (as written by pymatgen).

    Returns ``(cell_vectors, xyz, species, elements)``:

    * ``cell_vectors`` -- (3, 3) float array, one lattice vector per row,
      with the POSCAR scale factor applied;
    * ``xyz`` -- (N, 3) float array of Cartesian site positions;
    * ``species`` -- list of N element symbols, one per site;
    * ``elements`` -- element symbols in the order of the species line.

    Raises ValueError if fewer coordinate lines are present than the
    species counts announce.
    """
    lines = poscar_str.split('\n')

    # Line 1 is the universal scale factor; VASP reads a negative value as
    # the target cell volume rather than as a multiplier.
    scale = float(lines[1].split()[0])
    cell_vectors = np.array([lines[i].split()[:3] for i in (2, 3, 4)], dtype=float)
    if scale < 0:
        scale = (-scale / abs(np.linalg.det(cell_vectors))) ** (1.0 / 3.0)
    cell_vectors = cell_vectors * scale

    elements = lines[5].split()
    counts = [int(count) for count in lines[6].split()]
    species = [el for el, count in zip(elements, counts) for _ in range(count)]

    # An optional "Selective dynamics" line sits before the coordinate-mode line.
    mode_index = 8 if lines[7].strip()[:1] in ('s', 'S') else 7
    cartesian = lines[mode_index].strip()[:1] in ('c', 'C', 'k', 'K')

    # Read exactly one coordinate line per announced site, so anything
    # after the positions is never mistaken for a site.
    start = mode_index + 1
    coord_lines = [line for line in lines[start:start + len(species)] if line.strip()]
    if len(coord_lines) != len(species):
        raise ValueError(
            f"POSCAR lists {len(species)} sites but has {len(coord_lines)} coordinate lines")
    coords = np.array([line.split()[:3] for line in coord_lines], dtype=float).reshape(-1, 3)

    # Cartesian positions are scaled like the lattice; fractional positions
    # are mapped through the (already scaled) lattice vectors.
    xyz = coords * scale if cartesian else coords @ cell_vectors
    return cell_vectors, xyz, species, elements


def unit_cell_edges(cell_vectors):
    """Return a (3, 17) array of x/y/z points whose polyline traces all 12 cell edges."""
    corners = np.array([[0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0],
                        [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1],
                        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0]]).T
    return (corners @ cell_vectors).T


def element_color_map(elements, hues=('b', 'g', 'r', 'y')):
    """Map each element symbol to a matplotlib color code, cycling through *hues*.

    Cycling (instead of zipping) keeps compounds with more elements than
    available hues plottable.
    """
    return {el: hues[i % len(hues)] for i, el in enumerate(elements)}


def mpid_plot(mpid):
    """Fetch the conventional cell of *mpid* and draw its atoms and cell edges in 3D.

    Uses the global ``rester`` (an MPRester) to download the structure, then
    plots the sites colored by element plus the unit-cell outline.
    """
    struct = rester.get_structure_by_material_id(mpid, final=False, conventional_unit_cell=True)
    cell_vectors, xyz, species, elements = parse_poscar(struct.to(fmt="poscar"))
    cell = unit_cell_edges(cell_vectors)
    color_dict = element_color_map(elements)

    # Close figures from earlier calls so they do not accumulate.
    plt.close('all')

    fig = plt.figure("input structure")
    # Axes3D(fig) no longer attaches itself to the figure in recent matplotlib
    # (leaving a blank plot); add_subplot with a 3d projection does.
    ax = fig.add_subplot(projection='3d')

    xs, ys, zs = xyz.T
    ax.scatter(xs, ys, zs, color=[color_dict[s] for s in species], s=256)
    ax.plot(cell[0], cell[1], cell[2], color='black')

    ax.set_title("POSCAR")
    ax.axis('off')

    plt.show()



In [None]:
# if you know the mp-id: give it here to get a structure object
mpid = widgets.IntText(
    value=2133,
    description='MPID:',
    disabled=False
)

plot_button = widgets.Button(
    description='plot',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='plot'
)

# Dedicated output area for the structure plot. A distinct name means later
# cells that define their own `output` widget cannot redirect these plots.
plot_output = widgets.Output()


def update_plot(args):
    """Redraw the structure whenever the MPID box changes (observer of ``mpid.value``)."""
    # Output produced inside a widget callback is only shown in the notebook
    # when it is captured by an Output widget.
    with plot_output:
        plot_output.clear_output(wait=True)
        mpid_plot("mp-" + str(mpid.value))


def on_button_clicked(b):
    """Plot the row selected in the qgrid table and sync the MPID box to it."""
    with plot_output:
        selected_rows = mpid_widget.get_selected_rows()
        if not selected_rows:
            print("Select a row in the table below first.")
            return
        selected_row = mpid_widget.get_changed_df().index[selected_rows[0]]
        mpid_selected = sc_df.at[selected_row, 'task_id']
        plot_output.clear_output(wait=True)
        mpid_plot(mpid_selected)
        # Update the box without firing update_plot, which would otherwise
        # draw the same structure a second time.
        mpid.unobserve(update_plot, 'value')
        try:
            mpid.value = int(mpid_selected.split('-')[1])
        finally:
            mpid.observe(update_plot, 'value')


mpid.observe(update_plot, 'value')
plot_button.on_click(on_button_clicked)
display(mpid, plot_button, plot_output)

# if you don't, search for your structure by filtering
mpid_widget = qgrid.show_grid(sc_df)

# filter by pretty_formula to look for a specific structure
display(mpid_widget)

In [None]:
# Fetch the conventional unit cell for the MPID currently in the widget.
struct = rester.get_structure_by_material_id(
    f"mp-{mpid.value}",
    final=False,
    conventional_unit_cell=True,
)

In [None]:
#debugging struct
# Debug aid: show the Python type of the `a` attribute of the structure's first site.
type(struct.sites[0].a)

In [None]:
# if you're just hunting for something we display the structure for you.
# Build a chemview viewer for `struct` with pymatgen's structure_chemview
# helper (imported as `viz`), then request its ball_and_stick representation.
mv = viz.quick_view(struct)
mv.ball_and_stick()

In [None]:
#This is passed to the simtool to perform simulations
# Serialize the structure to POSCAR text; the bare name below echoes it as cell output.
POSCAR_str = struct.to(fmt="poscar")
POSCAR_str

In [None]:
#debug poscar conversion
# Round-trip check: write POSCAR_str to ./POSCAR and load it back with pymatgen.
from pymatgen.core import Structure
with open("./POSCAR", "w") as f:
    # The with-statement closes the file on exit; no explicit close() needed.
    f.write(POSCAR_str)

struct2 = Structure.from_file("./POSCAR")

### Choice (b): upload your own POSCAR file directly; no query is necessary.

In [None]:
# Read a POSCAR file you provide (replace the placeholder path with your own).
# os.path (lowercase) is the path module; os.Path does not exist.
with open(os.path.expanduser("~/we_know_this_dir/POSCAR")) as poscar_file:
    # read() keeps the whole file as one string, matching the str that
    # choice (a) produces with struct.to(fmt="poscar").
    POSCAR_str = poscar_file.read()
    

# Run structure relaxation, SCF, and phonon calculations, and extract spectra with the simtool

### Find the simtool notebook (670raman, or relax_sim as configured below) and confirm its location

In [None]:
#simToolName = "670raman"
simToolName = "relax_sim"

# Look up where the simtool is installed and list every field of the result.
simToolLocation = searchForSimTool(simToolName)
for field_name, field_value in simToolLocation.items():
    print(f"{field_name} = {field_value}")

In [None]:
# List installed notebooks for simToolName (requested as a string via
# returnString=True) and print them.
installedSimToolNotebooks = findInstalledSimToolNotebooks(simToolName,returnString=True)
print(installedSimToolNotebooks)

### User Set Validated Inputs
670raman will automatically activate your REST API interface to Materials Project
if you have the dotfile ".mpkey.txt" in your home directory.

Otherwise, it will attempt to generate a realistic crystal structure from your chemical description.

In [None]:
#Enter your values with units! The simtool will make sure you know what you're talking about.
# Load the simtool's input definitions; later cells assign their `.value` fields.
inputs = getSimToolInputs(simToolLocation)

In [None]:
# Display the simtool input definitions held in `inputs`.
inputs

In [None]:
# Widgets that collect simtool settings and submit a run from the notebook.
compound = widgets.Text(
    value='ZnO',
    placeholder='chemical formula',
    description='Compound:',
    disabled=False
)
spacegroup = widgets.IntText(
    value=186,
    placeholder='space group',
    description='Space group:',
    disabled=False
)

ecutwfc = widgets.BoundedFloatText(
    value=50,
    min=50,
    max=400,
    step=10,
    description='ecutwfc:',
    disabled=False
)

ecutrho = widgets.BoundedFloatText(
    value=200,
    min=200,
    max=1600,
    step=40,
    description='ecutrho:',
    disabled=False
)

button = widgets.Button(
    description='run simtool',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='run to submit qe simtool'
)

log = widgets.Select(
    options=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
    value='DEBUG',
    description='Log Level:',
    disabled=False
)

walltime = widgets.Text(
    value='01:00:00',
    placeholder='walltime',
    description='walltime:',
    disabled=False
)

numnodes = widgets.IntText(
    value=8,
    placeholder='nodes',
    description='nodes:',
    disabled=False
)

# Offer every regular file in ./simtool/pseudo/ as a pseudopotential choice,
# sorted so the dropdown order does not depend on the filesystem.
pseudo_dir = "./simtool/pseudo/"
pp_list = sorted(
    filename for filename in os.listdir(pseudo_dir)
    if os.path.isfile(os.path.join(pseudo_dir, filename))
)

# TODO: filter by selected compound compositions
filtered_pp_list = pp_list

pp_menu1 = widgets.Combobox(
    placeholder="choose a pseudopotential",
    options=filtered_pp_list,
    description='pseudopotential 1:',
    disabled=False
)

pp_menu2 = widgets.Combobox(
    placeholder="choose a pseudopotential",
    options=filtered_pp_list,
    description='pseudopotential 2:',
    disabled=False
)

smearing = widgets.Select(
    options=['smearing','fixed'],
    value='fixed',
    rows = 2,
    description='smearing:',
    disabled=False
)

# Separate Output widget for run messages. Using a new name avoids rebinding a
# global `output` that earlier cells' callbacks may write to at click time.
run_output = widgets.Output()


def runSim2l():
    """Copy the current widget values and the global POSCAR_str into `inputs`."""
    inputs['loglevel'].value = log.value
    inputs['walltime'].value = walltime.value
    inputs['numnodes'].value = numnodes.value
    inputs['ecutwfc'].value = ecutwfc.value
    inputs['ecutrho'].value = ecutrho.value
    inputs['pps'].value = [pp_menu1.value, pp_menu2.value]
    inputs['smearing'].value = smearing.value
    inputs['POSCAR_str'].value = POSCAR_str #TODO make a widget of this at some point


def on_run_clicked(b):
    """Check the pseudopotential choices, submit the simtool run, and print its spectra."""
    with run_output:
        chosen_pps = [pp_menu1.value, pp_menu2.value]
        if not any(chosen_pps):
            print("Choose at least one pseudopotential before running the simtool.")
            return
        # Combobox accepts free text, so catch names that are not in the pseudo dir.
        unknown = [pp for pp in chosen_pps if pp and pp not in pp_list]
        if unknown:
            print("Unknown pseudopotential(s):", ", ".join(unknown))
            return
        # NOTE(review): compound/spacegroup are not copied into `inputs` by
        # runSim2l and their widgets are not displayed, so this message shows
        # their default values rather than the submitted structure.
        print("submitting sim2l run with formula" , compound.value, spacegroup.value)
        runSim2l()
        r = Run(simToolLocation,inputs)
        r.getResultSummary()
        print(r.read('spectra'))


button.on_click(on_run_clicked)

structure = widgets.VBox([compound,spacegroup])  # built but not displayed
simulation = widgets.VBox([ecutrho, ecutwfc, smearing, pp_menu1, pp_menu2])
run_details = widgets.VBox([walltime, numnodes, log])

settings = widgets.Accordion(children=[simulation, run_details])
settings.set_title(0, 'simulation')
settings.set_title(1, 'run details')

accordion = widgets.VBox([settings, button, run_output])
display(accordion)

In [None]:
inputs  # display the input definitions (their documentation), if desired

In [None]:
# Hard-coded ZnO run settings, assigned to the simtool inputs in this order.
zno_input_values = {
    'loglevel': "DEBUG",
    'walltime': "01:00:00",
    'numnodes': 8,
    'compound': "ZnO",
    'ecutwfc': 50,
    'ecutrho': 200,
    'spacegroup_international': 186,
    'pps': ['O.pbe-hgh.UPF', 'Zn.pbe-d-hgh.UPF'],
    'smearing': "fixed",
}
for input_name, input_value in zno_input_values.items():
    inputs[input_name].value = input_value

### Show the predetermined outputs and their explanations

In [None]:
# Load the simtool's declared output definitions.
outputs = getSimToolOutputs(simToolLocation)

In [None]:
# Display the simtool output definitions held in `outputs`.
outputs

### Run simtool to obtain Predicted Raman Tensor and Spectrum Graph

In [None]:
# Run the simtool with the current `inputs`; the cells below read results from `r`.
r = Run(simToolLocation,inputs)

In [None]:
# Show the run's result summary.
r.getResultSummary()

In [None]:
# Print the run's 'logreport' output.
print(r.read('logreport'))

In [None]:
# Read the run's 'spectra' output and display it as the cell result.
r.read('spectra')

In [None]:
#check inputs
# Display the input dictionary recorded on the Run object.
r.input_dict

In [None]:
#find output location
# Print the run's output directory (`r.outdir`).
print(r.outdir)