<a href="https://colab.research.google.com/github/alezakuskin/Stark_ML/blob/Ions/Predictions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title # Run this cell to get all dependencies and packages ready
# Install the `roman` package (Roman-numeral parsing, used when decoding
# ionization stages such as "Fe II") into the current environment.
!pip install roman

# True when the notebook runs inside Google Colab; used below to pick the
# matching implementation of clear_output().
RunInColab = 'google.colab' in str(get_ipython())

import os
import re
from itertools import compress
from urllib import request, parse

import catboost
import joblib
import mariadb
import numpy as np
import pandas as pd
import roman
import xgboost

# !git clone -b Ions https://github.com/alezakuskin/Stark_ML
import Stark_ML
from Stark_ML.utils.terms import *

# Expose a single clear_output() helper whose backend matches the runtime:
# the Colab output API inside Google Colab, IPython's display module otherwise.
if not RunInColab:
    from IPython import display
    def clear_output():
        '''Clear the output of the current cell (plain Jupyter/IPython).'''
        display.clear_output()
else:
    from google.colab import output
    def clear_output():
        '''Clear the output of the current cell (Google Colab).'''
        output.clear()
        
def predict_width(data_for_prediction):
    '''
    Get predicted Stark broadening parameters for input lines

    The prediction is the mean of five pretrained regressors (two XGBoost,
    one CatBoost, two LightGBM). Two of them take the energy columns as
    supplied, three take them after conversion with energy_to_fraction().
    The input dataframe is left unmodified: normalization is done on a copy.

    Parameters
    ----------
    data_for_prediction : pd.DataFrame, dataframe with any number of rows,
        all values of input features filled in; with "Element", "Wavelength",
        "Z number", "w (A)", "d (A)" columns.

    Returns
    ----------
    numpy.ndarray
        A one-dimensional array with predicted values of broadening parameters in \u212B
    '''
    # Identification/target columns that are not model features
    non_feature_columns = ['Element', 'Wavelength', 'Z number', 'w (A)', 'd (A)']
    # Energy-like columns converted for the "Enorm" models
    energy_columns = ['E lower', 'E upper', 'Gap to ion']

    #Importing pretrained models
    # NOTE: the .pkl files are pickles loaded through joblib; only use model
    # files obtained from a trusted source.
    model1 = xgboost.XGBRegressor()
    model1.load_model('Stark_ML/XGB_A+I_Eraw_Raw_No.json')

    model2 = xgboost.XGBRegressor()
    model2.load_model('Stark_ML/XGB_A+I_Enorm_Aug_No.json')

    model3 = catboost.CatBoostRegressor()
    model3.load_model('Stark_ML/CatBoost_A+I_Enorm_Raw_No.json')

    model4 = joblib.load('Stark_ML/LightGBM_A+I_Eraw_Raw_No.pkl')

    model5 = joblib.load('Stark_ML/LightGBM_A+I_Enorm_Raw_Scaler.pkl')

    #Loading Standard Scaler
    scaler = joblib.load('Stark_ML/scaler_width.pkl')

    #Getting predictions
    epsilon = 1e-3

    #Models without energy normalization
    raw_features = data_for_prediction.drop(columns=non_feature_columns)
    pred1 = model1.predict(raw_features)
    pred4 = model4.predict(raw_features)

    #Models with energy normalization
    # Work on a copy so the caller's energies are not overwritten; columns are
    # converted one after another, in the same order as before.
    normalized = data_for_prediction.copy()
    for column in energy_columns:
        normalized[column] = energy_to_fraction(normalized, column)
    normalized_features = normalized.drop(columns=non_feature_columns)
    pred2 = model2.predict(normalized_features)
    pred3 = model3.predict(normalized_features)
    pred5 = model5.predict(scaler.transform(normalized_features))

    # Average the ensemble, then map the model output back to a width:
    # (exp(y) - 1) * epsilon
    preds = (pred1 + pred2 + pred3 + pred4 + pred5)/5
    preds = (np.exp(preds) - 1) * epsilon

    return preds

def predict_shift(data_for_prediction):
    '''
    Get predicted Stark broadening and shift parameters for input lines

    Widths are predicted first with predict_width() and then used as the
    "w (A)" input feature of the shift model. The input dataframe is left
    unmodified.

    Parameters
    ----------
    data_for_prediction : pd.DataFrame, dataframe with any number of rows,
        all values of input features filled in; with "Element", "Wavelength",
        "Z number", "w (A)", "d (A)" columns.

    Returns
    ----------
    numpy.ndarray
        A two-dimensional array with predicted values of both broadening (1st column)
        and shift (2nd column) parameters in \u212B
    '''
    #Importing pretrained models
    # NOTE: joblib files are pickles; only load models from a trusted source.
    model = joblib.load('Stark_ML/RF_Both_Eraw_Aug_No.pkl')

    #Get broadening predictions first
    # predict_width may rewrite the energy columns of the frame it receives,
    # so it gets its own copy: the shift model ("Eraw") must see the energies
    # exactly as supplied by the caller.
    widths = predict_width(data_for_prediction.copy())

    #Adjust input data (on a copy, so the caller's "w (A)" column is untouched)
    features = data_for_prediction.copy()
    features['w (A)'] = widths
    features = features[model.model.feature_names_in_]

    #Get shift predictions
    preds = model.predict(features)

    return np.column_stack((widths, preds))

# Wipe the installation log printed by this cell so it ends with clean output.
clear_output()

In [2]:
def connect_to_DB(username,
                  password,
                  server = "laser365-1.chem.msu.ru",
                  port=3306,
                  database = "kurucz"):
    '''
    Open a MariaDB connection to the line database.

    Parameters
    ----------
    username : user name passed to mariadb.connect as `user`
    password : password for that user
    server : host name of the database server
    port : TCP port of the database server
    database : name of the database to select

    Returns
    ----------
    The connection object returned by mariadb.connect; closing it is the
    caller's responsibility.
    '''
    connection_settings = {
        'user': username,
        'password': password,
        'host': server,
        'port': port,
        'database': database,
    }
    return mariadb.connect(**connection_settings)

In [3]:
def is_valid_element(symbol):
    '''
    Tell whether `symbol` is a chemical element symbol (H ... Og).

    The comparison ignores letter case: the symbol is normalized with
    str.capitalize() before the lookup, so "fe", "FE" and "Fe" are all valid.
    '''
    # Symbols of all 118 elements, ordered by atomic number
    known_symbols = (
        "H He Li Be B C N O F Ne "
        "Na Mg Al Si P S Cl Ar K Ca "
        "Sc Ti V Cr Mn Fe Co Ni Cu Zn "
        "Ga Ge As Se Br Kr Rb Sr Y Zr "
        "Nb Mo Tc Ru Rh Pd Ag Cd In Sn "
        "Sb Te I Xe Cs Ba La Ce Pr Nd "
        "Pm Sm Eu Gd Tb Dy Ho Er Tm Yb "
        "Lu Hf Ta W Re Os Ir Pt Au Hg "
        "Tl Pb Bi Po At Rn Fr Ra Ac Th "
        "Pa U Np Pu Am Cm Bk Cf Es Fm "
        "Md No Lr Rf Db Sg Bh Hs Mt Ds "
        "Rg Cn Nh Fl Mc Lv Ts Og"
    ).split()
    normalized = symbol.capitalize()
    return normalized in known_symbols

In [4]:
###-----------------------------------
###Working version
###-----------------------------------
def convert_species_request(s):
    '''
    Parse a spectra request such as "Fe I; Si IX,XI; Mg I-IV".

    Species are separated by ";". Each species is an element symbol that may
    be followed by ionization stages in Roman numerals: single stages
    separated by commas and/or inclusive ranges written with "-". Stages are
    converted to ion charges (stage I -> 0, stage II -> 1, ...).

    Parameters
    ----------
    s : str, the request string

    Returns
    ----------
    tuple (elements, ionizations)
        elements : list of capitalized element symbols
        ionizations : list of the same length; each item is either a list of
        integer charges or the string 'All' when no stage was given

    Raises
    ----------
    ValueError
        If a species does not match the expected format, the element symbol
        is unknown, a range is descending, or no stage could be extracted.
    Errors raised by roman.fromRoman for malformed numerals propagate as is.
    '''
    def parse_roman_part(part):
        # Convert one comma-separated item ("III" or "I-IV") to a list of charges
        part = part.strip()
        if not part:
            return []
        if '-' in part:
            start, end = [roman.fromRoman(p.strip().upper()) for p in part.split('-')]
            if start > end:
                raise ValueError(f"Ionization stage range {part} is descending")
            # Stage N corresponds to charge N-1, so the inclusive stage range
            # start..end maps to charges start-1 .. end-1
            return list(range(start - 1, end))
        return [roman.fromRoman(part.upper()) - 1]

    def parse_ionization_stages(stages):
        # Separate by commas, then handle each part
        ionization = []
        for part in re.split(r'\s*,\s*', stages):
            ionization.extend(parse_roman_part(part))
        return ionization

    elements = []
    ionizations = []

    for chem_elem in s.split(';'):
        # fullmatch: trailing characters that are neither Roman numerals nor
        # separators make the request invalid instead of being ignored
        match = re.fullmatch(r'([A-Za-z]{1,2})\s*([ivx,\s-]*)', chem_elem.strip(), re.IGNORECASE)
        if not match:
            raise ValueError("Invalid input format")

        element = match.group(1).capitalize()
        if not is_valid_element(element):
            raise ValueError(f"Chemical element symbol {element} is incorrect")

        stages = match.group(2).strip()
        if stages:
            ionization = parse_ionization_stages(stages)
            if not ionization:
                raise ValueError(f"No ionization stages found for {element}")
        else:
            ionization = 'All'

        elements.append(element)
        ionizations.append(ionization)

    return elements, ionizations

In [None]:
#@title #Request data from DataBase
spectra = "Ti i; Ti ii; ti III" #@param {type: "string"}
#@markdown Examples of allowed spectra:
#@markdown **Ar I** or **Mg I-IV** or **Fe I; Si IX,XI**

#@markdown

#@markdown ###Enter wavelength in *nm*:
lower = 240 #@param {type: "number"}
upper = 245 #@param {type: "number"}

target = "both" #@param ["broadening", "shift", "both"] {type:"raw"}

#@markdown

#@markdown ###Would you like to save lines that cannot be encoded automatically to a separate file

save_for_manual_check = True #@param {type: "boolean"}

elements, ionizations = convert_species_request(spectra)

# Credentials are read from the environment so that they are never stored in
# the notebook; both fall back to an empty string when the variable is unset.
connection = connect_to_DB(username = os.environ.get('STARK_DB_USER', ''),
                           password = os.environ.get('STARK_DB_PASSWORD', ''))

# One query per requested species; values are passed as bound parameters
# ("?" placeholders) instead of being pasted into the SQL text.
frames = []
try:
    cur = connection.cursor()
    try:
        for el, ion in zip(elements, ionizations):
            query = '''
            SELECT *
            FROM mytestview2
            WHERE airwl >= ?
            AND airwl <= ?
            AND el_name = ?
            '''
            params = [lower, upper, el]
            if ion != 'All':
                # Restrict to the requested ion charges
                placeholders = ', '.join('?' * len(ion))
                query += f'AND ion_stage IN ({placeholders})\n'
                params.extend(ion)
            cur.execute(query, tuple(params))
            column_names = [desc[0] for desc in cur.description]
            req_results = pd.DataFrame(cur.fetchall(), columns=column_names)
            if not req_results.empty:
                frames.append(req_results)
    finally:
        cur.close()
finally:
    # Release the connection even when a query fails
    connection.close()

if not frames:
    raise ValueError(f"No lines found for '{spectra}' between {lower} and {upper}")
DB_df = pd.concat(frames, ignore_index=True)

# Header rows of the training table, read from the Stark_ML package directory
stark_data_path = os.path.join(list(Stark_ML.__path__)[0], 'Source_files', 'Stark_data.xlsx')
data_i = pd.read_excel(stark_data_path,
                       sheet_name='Ions',
                       usecols='A:BQ',
                       nrows = 2
                   )
request_df = split_OK_check(DB_to_StarkML(DB_df, data_i), save_manual_check = save_for_manual_check)

In [None]:
#@title #The main part
#@markdown By default you will get results for the database query above.

#@markdown You can upload your own *.txt* file or manually sanitized *for_manual_check.txt* to the panel on the left and specify the filename:

filename = 'requested_lines.txt' #@param {type:"string"}
linelist_name = filename  # name exactly as entered, without the repository prefix
filename = 'Stark_ML/' + filename

#@markdown Select whether you would like to get predictions for a single temperature value or for a temperature range
Temperature_mode = 'single' #@param ['single', 'range']

#@markdown If you selected *range* in the previous field, specify all three parameters here:
Low_T = 8000   #@param {type: "number"}
High_T = 10000 #@param {type: "number"}
T_step = 100  #@param {type: "number"}



#Loading linelist: look inside the Stark_ML directory first, then fall back to
#the current working directory (e.g. a file uploaded by the user)
try:
    data_predictions = pd.read_csv(filename, index_col = 0)
except FileNotFoundError:
    data_predictions = pd.read_csv(linelist_name, index_col = 0)

#Data preprocessing
data_predictions.insert(data_predictions.columns.get_loc('E upper')+1, 'Gap to ion', 0)
data_predictions['Gap to ion'] = gap_to_ion(data_predictions, 'E upper')

#Temperatures to predict for
if Temperature_mode == 'single':
    temperatures = np.array([Low_T])
elif Temperature_mode == 'range':
    if T_step <= 0 or High_T < Low_T:
        raise ValueError("Temperature range requires T_step > 0 and High_T >= Low_T")
    temperatures = np.arange(Low_T, High_T + 1, T_step)
else:
    raise ValueError(f"Unknown Temperature_mode: {Temperature_mode!r}")

#Repeat every line once per temperature in a single vectorized step and
#restore the original column dtypes afterwards
dtypes = data_predictions.dtypes.to_dict()
n_lines = len(data_predictions)
row_positions = np.repeat(np.arange(n_lines), len(temperatures))
data_predictions = (
    data_predictions.iloc[row_positions]
    .reset_index(drop = True)
    .assign(T = np.tile(temperatures, n_lines))
    .astype(dtypes)
)
data_predictions = data_predictions.sort_values(['Wavelength', 'T']).reset_index(drop = True)

#Get predictions
if target == 'broadening':
    preds = pd.Series(predict_width(data_predictions), name = 'w (A)')
elif target == 'shift':
    preds = pd.Series(predict_shift(data_predictions)[:, 1], name = 'd (A)')
elif target == 'both':
    preds = pd.DataFrame(predict_shift(data_predictions), columns = ['w (A)', 'd (A)'])
else:
    raise ValueError(f"Unknown target: {target!r}")


#building output file
columns = ['Element', 'Charge', 'Wavelength', 'T', 'w (A)', 'd (A)']
#@markdown

#@markdown ###Select additional transition parameters you would like to include in output file
Element_symbol = True  #@param {type: 'boolean'}
Wavelength     = True  #@param {type: 'boolean'}
Temperature    = True  #@param {type: 'boolean'}
Charge         = True #@param {type: 'boolean'}

#Identification columns chosen in the form (the flags map onto the first four
#entries of `columns`), followed by the predicted value(s)
selected_columns = list(compress(columns, [Element_symbol, Charge, Wavelength, Temperature]))
results = pd.concat([data_predictions[selected_columns], preds], axis = 1)

#Output name: PREDICTED_<input name without extension>.csv
results.to_csv(f'PREDICTED_{os.path.splitext(linelist_name)[0]}.csv', index = False)
print(results)

## Congratulations! If the previous cell finished execution without errors, you can now download the `PREDICTED_<filename>.csv` file with the predicted values of the Stark broadening and/or shift parameters.

### For more details refer to 'paper' or contact us: ale-zakuskin@laser.chem.msu.ru