In [None]:
import thermoengine as thermo

import numpy as np
import pandas as pd
import pickle as pkl
import os
import matplotlib.pyplot as plt
from collections import OrderedDict

%matplotlib notebook

In [None]:
# Lookup table from the phase names used in the LEPR spreadsheet to short
# phase symbols.  A value of None marks a phase that has no symbol assigned
# here.  NOTE(review): the symbols look like thermoengine phase abbreviations
# (see the commented-out modelDB.get_phase check in the next cell) -- confirm.
# Insertion order is kept identical to the original listing.
LEPR_phase_symbols = {
    'Liquid': 'Liq',
    'Clinopyroxene': 'Cpx',
    'Garnet': 'Grt',
    'Olivine': 'Ol',
    'Orthopyroxene': 'Opx',
    'Biotite': 'Bt',
    'Fluid': None,
    'Corundum': 'Crn',
    'Rutile': 'Rt',
    'Plagioclase': 'Fsp',           # shares 'Fsp' with 'Potassium feldspar'
    'Amphibole': 'Cam',
    'Zoisite': 'Zo',
    'Cordierite': 'Crd',
    'Muscovite': 'Ms',
    'Quartz': 'Qz',
    'Kyanite': 'Ky',
    'Potassium feldspar': 'Fsp',    # shares 'Fsp' with 'Plagioclase'
    'Sillimanite': 'Sil',
    'Spinel': 'SplS',
    'Staurolite': None,
    'Melilite': 'Mll',
    'Carbonate melt': None,
    'Nepheline': 'NphS',
    'Ilmenite': 'Ilm',
    'Eskolaite': None,
    'Anorthite': 'An',
    'cc-dol': None,
}

In [None]:
# for sym in LEPR_phase_symbols.values():
#     if sym is not None:
#         modelDB.get_phase(sym)

In [None]:
def load_analysis(analysis_file='data/garnet-calib.pkl'):
    """Load a previously pickled analysis object from disk.

    Parameters
    ----------
    analysis_file : str
        Path of the pickle file to read.

    Returns
    -------
    object
        Whatever object was pickled into ``analysis_file``.  An empty
        ``dict`` is returned when the file does not exist, or when it exists
        but cannot be read/unpickled (in which case a warning is issued so
        the problem is not silently hidden).

    Notes
    -----
    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
    file.  Only use this on files you wrote yourself.
    """
    import warnings

    try:
        with open(analysis_file, 'rb') as f:
            return pkl.load(f)
    except FileNotFoundError:
        # Nothing saved yet: start with an empty analysis.
        return {}
    except Exception as err:
        # Keep the original best-effort behaviour (fall back to {}), but do
        # not swallow KeyboardInterrupt/SystemExit and tell the user that an
        # existing file was unreadable before it gets overwritten by a save.
        warnings.warn(
            'Could not load analysis from {!r} ({}: {}); '
            'starting with an empty analysis.'.format(
                analysis_file, type(err).__name__, err),
            stacklevel=2)
        return {}
        
def save_analysis(analysis, analysis_file='data/garnet-calib.pkl'):
    """Pickle ``analysis`` to ``analysis_file``, replacing any previous content.

    Parameters
    ----------
    analysis : object
        Any picklable object (``load_analysis`` falls back to a ``dict``).
    analysis_file : str
        Destination path of the pickle file.

    Raises
    ------
    Exception
        Whatever ``pickle`` raises for an unpicklable object, or ``OSError``
        if the file cannot be opened for writing.
    """
    # Serialise first: open(..., 'wb') truncates the target immediately, so
    # pickling straight into the file would destroy the previous save if the
    # object turned out to be unpicklable.
    payload = pkl.dumps(analysis)
    with open(analysis_file, 'wb') as f:
        f.write(payload)

In [None]:
def load_LEPR_data(filename, data_dir='data/'):
    """Read a LEPR Excel workbook and split it into experiment and phase tables.

    Every sheet of the workbook is read.  The 'Experiment' and
    'phase_symbols' sheets are required; all remaining sheets are treated as
    per-phase data tables.

    Parameters
    ----------
    filename : str
        Name of the Excel file.
    data_dir : str
        Directory containing ``filename``.  The two are combined with
        ``os.path.join``, so a trailing slash is optional.

    Returns
    -------
    exp_data : pandas.DataFrame
        Run conditions and bulk composition columns of the 'Experiment'
        sheet (only those that are present), renamed to short names
        ('T', 'T_err', 'P', 'P_err', 'fO2', 'SiO2', ...).  Rows without an
        'Index' value are dropped and the integer 'Index' becomes the index.
    metadata : pandas.DataFrame
        Descriptive columns of the 'Experiment' sheet ('Index',
        'Experiment', 'Author (year)', ...), returned unmodified (rows are
        not dropped and 'Index' stays a regular column).
    phs_data : dict of pandas.DataFrame
        The remaining sheets keyed by sheet name.  Each is indexed by the
        integer 'Index', has all-NaN rows removed and remaining NaN set to 0.
    phs_symbols : pandas.DataFrame
        The 'phase_symbols' sheet as read.

    Raises
    ------
    KeyError
        If a required sheet or an 'Index' column is missing.
    """
    # Single source of truth: spreadsheet column -> short name.  The order
    # here is also the column order requested from the 'Experiment' sheet.
    exp_column_names = {
        'T (C)': 'T', 'err T (C)': 'T_err',
        'P (GPa)': 'P', 'err P (GPa)': 'P_err',
        'fO2 cond': 'fO2',
        'Wt: SiO2': 'SiO2', 'Wt: TiO2': 'TiO2', 'Wt: Al2O3': 'Al2O3',
        'Wt: Fe2O3': 'Fe2O3', 'Wt: Cr2O3': 'Cr2O3', 'Wt: FeO': 'FeO',
        'Wt: MnO': 'MnO', 'Wt: MgO': 'MgO', 'Wt: NiO': 'NiO',
        'Wt: CaO': 'CaO', 'Wt: Na2O': 'Na2O', 'Wt: K2O': 'K2O',
        'Wt: P2O5': 'P2O5', 'Wt: H2O+': 'H2O', 'Wt: CO2': 'CO2',
    }
    metadata_columns = ['Index', 'Experiment', 'Author (year)', 'Laboratory',
                        'Device', 'Container', 'Method', 'Duration_hours',
                        'Phases']

    def set_data_index(df):
        # Keep only rows that carry an experiment index, make it an integer
        # and use it as the row index.  Works on copies so the frame passed
        # in is never modified.
        indexed = df.dropna(subset=['Index'])
        indexed = indexed.assign(Index=indexed['Index'].astype(int))
        return indexed.set_index('Index')

    # sheet_name=None makes pandas return a mapping of every sheet.
    lepr_data = pd.read_excel(os.path.join(data_dir, filename),
                              sheet_name=None)

    experiments = lepr_data.pop('Experiment')

    # filter() silently skips requested columns that the sheet lacks.
    exp_data = experiments.filter(['Index'] + list(exp_column_names), axis=1)
    exp_data = exp_data.rename(columns=exp_column_names)
    exp_data = set_data_index(exp_data)

    metadata = experiments.filter(metadata_columns, axis=1)

    phs_symbols = lepr_data.pop('phase_symbols')

    # Whatever sheets remain are the per-phase tables.
    phs_data = lepr_data
    for iphs_name in list(phs_data):
        iphs = set_data_index(phs_data[iphs_name])
        phs_data[iphs_name] = iphs.dropna(how='all').fillna(0)

    return exp_data, metadata, phs_data, phs_symbols

def major_wt_oxide_LEPR_data(phs_data):
    """Extract the major-oxide weight columns of every phase table.

    Parameters
    ----------
    phs_data : dict of pandas.DataFrame
        Phase tables keyed by phase name.  Every table must contain all of
        the 'Wt: <oxide>' columns listed below, otherwise pandas raises
        ``KeyError``.

    Returns
    -------
    collections.OrderedDict
        For each phase (same order as ``phs_data``) a new DataFrame holding
        only the oxide columns, renamed from 'Wt: SiO2' to 'SiO2' etc.
    """
    wt_columns = ['Wt: SiO2', 'Wt: TiO2', 'Wt: Al2O3', 'Wt: Fe2O3',
                  'Wt: Cr2O3', 'Wt: FeO', 'Wt: MnO', 'Wt: MgO', 'Wt: NiO',
                  'Wt: CoO', 'Wt: CaO', 'Wt: Na2O', 'Wt: K2O', 'Wt: P2O5',
                  'Wt: H2O', 'Wt: CO2']

    # Dropping the first four characters removes the 'Wt: ' prefix.
    short_names = OrderedDict((col, col[4:]) for col in wt_columns)

    return OrderedDict(
        (phase, phs_data[phase][wt_columns].rename(columns=short_names))
        for phase in phs_data
    )

In [None]:
def mol_oxide_to_mol_endmember(mol_oxides, endmember_site_occ, endmember_stoic):
    """Convert an oxide composition to endmember amounts by two NNLS fits.

    The matrix products below imply these shapes:

    * ``endmember_site_occ`` : (n_endmember, n_site)
    * ``endmember_stoic``    : (n_endmember, n_oxide)
    * ``mol_oxides``         : (n_oxide,)

    Steps:

    1. Build a site-occupancy -> oxide matrix from the endmember
       stoichiometry and the pseudo-inverse of the site-occupancy matrix.
    2. Solve for non-negative site occupancies that reproduce ``mol_oxides``.
    3. Solve for non-negative endmember amounts that reproduce those site
       occupancies.

    Parameters
    ----------
    mol_oxides : array_like
        Oxide amounts of the phase.
    endmember_site_occ : numpy.ndarray
        Site occupancies of each endmember.
    endmember_stoic : numpy.ndarray
        Oxide stoichiometry of each endmember.

    Returns
    -------
    endmember_comp : numpy.ndarray
        Non-negative endmember amounts, length n_endmember.
    residual : float
        Residual norm of the *second* fit only (site occupancies ->
        endmembers); the misfit of the oxide -> site-occupancy fit is not
        returned.
    """
    # Imported here because the notebook's import cell does not bring
    # scipy.optimize into scope; without it the nnls calls raise NameError.
    from scipy import optimize

    endmember_site_occ_inv = np.linalg.pinv(endmember_site_occ.T)
    site_occ_stoic = np.dot(endmember_stoic.T, endmember_site_occ_inv)

    # Step 2: oxide amounts -> site occupancies (misfit intentionally unused).
    site_occ, _oxide_residual = optimize.nnls(site_occ_stoic, mol_oxides)

    # Step 3: site occupancies -> endmember amounts.  NNLS is used instead of
    # multiplying by the pseudo-inverse so the result cannot go negative.
    endmember_comp, residual = optimize.nnls(endmember_site_occ.T, site_occ)

    return endmember_comp, residual

In [None]:
def tern_scatter(XA, XB, XC, values, fignum=None,           
                 marker='o', vmin=-.1, vmax=+.1, cmap='viridis',
                 labels=['A','B','C'], label_offset=.03):
    """Scatter coloured points on a ternary (triangle) diagram.

    Corner A is the top apex, B the bottom-left corner and C the
    bottom-right corner of a unit-sided equilateral triangle.  A point is
    placed at ``x = XC + XA/2``, ``y = sqrt(3)/2 * XA``; ``XB`` only takes
    part in determining the number of points.  The placement presumably
    assumes ``XA + XB + XC == 1`` -- this is not checked.

    Parameters
    ----------
    XA, XB, XC : scalar or array
        Fractions of the three components.  Scalars are broadcast to the
        length of the longest argument.  Non-scalar inputs are used as given
        and must support arithmetic (e.g. numpy arrays).
    values : scalar or array
        Quantity mapped to marker colour; a scalar is broadcast likewise.
    fignum : int, optional
        Number of an existing figure to draw into.  When ``None`` a new
        frameless figure is created and the triangle outline and corner
        labels are drawn first.
    marker, vmin, vmax, cmap
        Passed to ``plt.scatter``.
    labels : list of 3 str or None
        Corner labels in the order A, B, C (only used for a new figure).
        ``None`` suppresses the labels.
    label_offset : float
        Distance of the labels from the corners.

    Returns
    -------
    int
        The number of the figure that was drawn into, so further calls can
        add points to it via ``fignum``.
    """
    def get_len(arr):
        # Scalars have no len(); count them as a single point.
        try:
            return len(arr)
        except TypeError:
            return 1

    N = max(get_len(values), get_len(XA), get_len(XB), get_len(XC))

    def broadcast(x):
        # Expand a scalar to N points; leave array-likes untouched.
        return x*np.ones(N) if np.isscalar(x) else x

    values = broadcast(values)
    XA = broadcast(XA)
    XB = broadcast(XB)
    XC = broadcast(XC)

    if fignum is None:
        plt.figure(frameon=False)

        # Triangle outline: B (0, 0) -> C (1, 0) -> A (0.5, sqrt(3)/2) -> B.
        plt.plot([0,1,.5,0],[0,0,np.sqrt(3)/2,0],'k-')
        plt.gca().set_aspect('equal')
        plt.gca().axis('off')
        fignum=plt.gcf().number

        if labels is not None:
            assert len(labels)==3, (
                'labels must be length 3 list of strings')

            plt.text(0.5, np.sqrt(3)/2+label_offset, labels[0], 
                     fontsize=14, 
                     horizontalalignment='center',
                     verticalalignment='bottom')
            plt.text(0-label_offset, 0-label_offset, labels[1], 
                     horizontalalignment='right', fontsize=14,
                     verticalalignment='top')
            plt.text(1+label_offset, 0-label_offset, labels[2], 
                     horizontalalignment='left', fontsize=14,
                     verticalalignment='top')

    else:
        # Re-activate the existing figure so the points are added to it.
        plt.figure(fignum)

    plt.scatter(XC-0.5*(1-XA)+.5, np.sqrt(3)/2*XA, 
                c=values, marker=marker, 
                vmin=vmin, vmax=vmax, cmap=cmap)

    return fignum


In [None]:
#TAS diagram
def add_LeMaitre_fields(plot_axes, fontsize=8, color=(0.6, 0.6, 0.6)):
    """Draw the TAS classification fields of Le Maitre et al. (2002) on axes.

    The field boundaries and names are added to a pre-existing axes object
    whose x axis is silica and whose y axis is total alkalis.  If necessary,
    the axes object can be retrieved with ``plt.gca()``, e.g.::

        ax1 = plt.gca()
        add_LeMaitre_fields(ax1)
        ax1.plot(silica, total_alkalis, 'o')

    It may be necessary to call ``plt.draw()`` afterwards to update the plot.

    Parameters
    ----------
    plot_axes : matplotlib axes
        Object providing ``plot`` and ``text`` methods.
    fontsize : number
        Font size of the field names.
    color : matplotlib color
        Colour used for both the boundary lines and the field names.

    Raises
    ------
    ImportError
        If ``matplotlib.pyplot`` has not been imported.
    TypeError
        If ``plot_axes`` lacks a ``plot`` or ``text`` method.

    References
    ----------
    Le Maitre RW (2002) Igneous rocks : IUGS classification and glossary of
        terms : recommendations of the International Union of Geological
        Sciences Subcommission on the Systematics of igneous rocks, 2nd ed.
        Cambridge University Press, Cambridge
    """
    import sys
    from collections import namedtuple

    # Check matplotlib is imported.  (The original raised an undefined
    # MissingModuleException here, which surfaced as a NameError.)
    if 'matplotlib.pyplot' not in sys.modules:
        raise ImportError("""Matplotlib not imported.
        Matplotlib is installed as part of many scientific packages and is
        required to create plots.""")

    # Check that plot_axes can draw before anything is added to it.
    if not (hasattr(plot_axes, 'plot') and hasattr(plot_axes, 'text')):
        raise TypeError('plot_axes is not a matplotlib axes instance.')

    # Field boundaries as straight segments (x1, y1) -> (x2, y2).
    FieldLine = namedtuple('FieldLine', 'x1 y1 x2 y2')
    lines = (FieldLine(x1=41, y1=0, x2=41, y2=7),
             FieldLine(x1=41, y1=7, x2=52.5, y2=14),
             FieldLine(x1=45, y1=0, x2=45, y2=5),
             FieldLine(x1=41, y1=3, x2=45, y2=3),
             FieldLine(x1=45, y1=5, x2=61, y2=13.5),
             FieldLine(x1=45, y1=5, x2=52, y2=5),
             FieldLine(x1=52, y1=5, x2=69, y2=8),
             FieldLine(x1=49.4, y1=7.3, x2=52, y2=5),
             FieldLine(x1=52, y1=5, x2=52, y2=0),
             FieldLine(x1=48.4, y1=11.5, x2=53, y2=9.3),
             FieldLine(x1=53, y1=9.3, x2=57, y2=5.9),
             FieldLine(x1=57, y1=5.9, x2=57, y2=0),
             FieldLine(x1=52.5, y1=14, x2=57.6, y2=11.7),
             FieldLine(x1=57.6, y1=11.7, x2=63, y2=7),
             FieldLine(x1=63, y1=7, x2=63, y2=0),
             FieldLine(x1=69, y1=12, x2=69, y2=8),
             FieldLine(x1=45, y1=9.4, x2=49.4, y2=7.3),
             FieldLine(x1=69, y1=8, x2=77, y2=0))

    # Field names with the anchor position and rotation (degrees) of the text.
    FieldName = namedtuple('FieldName', 'name x y rotation')
    names = (FieldName('Picro\nbasalt', 43, 2, 0),
             FieldName('Basalt', 48.5, 2, 0),
             FieldName('Basaltic\nandesite', 54.5, 2, 0),
             FieldName('Andesite', 60, 2, 0),
             FieldName('Dacite', 68.5, 2, 0),
             FieldName('Rhyolite', 76, 9, 0),
             FieldName('Trachyte\n(Q < 20%)\n\nTrachydacite\n(Q > 20%)',
                       64.5, 11.5, 0),
             FieldName('Basaltic\ntrachyandesite', 53, 8, -20),
             FieldName('Trachy-\nbasalt', 49, 6.2, 0),
             FieldName('Trachyandesite', 57.2, 9, 0),
             FieldName('Phonotephrite', 49, 9.6, 0),
             FieldName('Tephriphonolite', 53.0, 11.8, 0),
             FieldName('Phonolite', 57.5, 13.5, 0),
             FieldName('Tephrite\n(Ol < 10%)', 45, 8, 0),
             FieldName('Foidite', 44, 11.5, 0),
             FieldName('Basanite\n(Ol > 10%)', 43.5, 6.5, 0))

    # zorder=0 keeps the fields underneath any data plotted on the axes.
    for line in lines:
        plot_axes.plot([line.x1, line.x2], [line.y1, line.y2],
                       '-', color=color, zorder=0)
    for name in names:
        plot_axes.text(name.x, name.y, name.name, color=color, size=fontsize,
                       horizontalalignment='center', verticalalignment='top',
                       rotation=name.rotation, zorder=0)