# In [27]:  (Jupyter notebook cell prompt left over from the export)
import numpy as np
import pandas as pd
import os
import argparse

def get_atom_index():
    """Map each element symbol found in ``geo`` to its type number in ``fort.7``.

    The first HETATM record of every distinct symbol (third whitespace field)
    supplies a serial number (second field).  That serial is used directly as
    a line offset into ``fort.7``, and the second field of that line is stored
    as the symbol's type number.

    Returns:
        dict: symbol -> integer type number, ordered by first appearance in
        ``geo``.  When ``geo`` holds no ``Mg`` record, ``'Mg'`` is mapped to
        the placeholder value 123456.
    """
    # Serial number of the first HETATM record seen for each symbol.
    first_serial = {}
    with open('geo', 'r') as geo_file:
        for record in geo_file:
            if not record.startswith('HETATM'):
                continue
            fields = record.strip().split()
            symbol = fields[2]
            if symbol not in first_serial:
                first_serial[symbol] = int(fields[1])

    with open('fort.7', 'r') as fort_file:
        fort_rows = fort_file.readlines()

    # Line 0 of fort.7 is skipped implicitly: serial k is looked up on line k.
    index_dict = {
        symbol: int(fort_rows[serial].strip().split()[1])
        for symbol, serial in first_serial.items()
    }

    # Placeholder so later lookups of 'Mg' succeed for Mg-free structures.
    index_dict.setdefault('Mg', 123456)
    return index_dict

    

def get_avg_charges(atoms_per_layer, index_dict):
    """Average the per-atom charges stored in ``fort.7``.

    The first line and the last two lines of ``fort.7`` are discarded; the
    remaining rows are parsed as floats.  In every row, column 1 is compared
    against the type numbers in ``index_dict``, columns 2-10 are treated as
    1-based neighbour row numbers (0 meaning "no neighbour"), and the last
    column is taken as the atom's charge.

    Args:
        atoms_per_layer: number of rows dropped from BOTH ends of the table
            before averaging.  0 keeps every row.
        index_dict: mapping with the keys ``'O'``, ``'Zn'`` and ``'Mg'`` giving
            the type number of each element.

    Returns:
        ``(o_zn_avg, o_mg_avg, zn_avg, mg_avg)`` when at least one Mg row is
        present in the retained rows, otherwise the fixed pair ``(1, -1)``.
        ``o_zn_avg`` is the average charge of O rows with exactly four Zn and
        no Mg neighbours; ``o_mg_avg`` is the count-weighted average over all
        O rows having at least one Mg neighbour.  All values are rounded to
        three decimals.

    Raises:
        KeyError: if ``index_dict`` lacks one of the three keys, or if Mg is
            present but no retained O row has the (4 Zn, 0 Mg) environment.
        ZeroDivisionError: if Mg is present but no retained O row has an Mg
            neighbour.
    """
    with open('fort.7', 'r') as f:
        lines = f.readlines()

    data_orig = np.array([line.strip().split() for line in lines[1:-2]], dtype=float)

    # A stop of "-0" would be 0 and yield an empty slice, so use None when no
    # rows are to be trimmed.
    stop = -atoms_per_layer if atoms_per_layer else None
    data = data_orig[atoms_per_layer:stop, :]

    o_type = index_dict['O']
    zn_type = index_dict['Zn']
    mg_type = index_dict['Mg']

    # (zn_neighbours, mg_neighbours) -> running charge statistics of O rows.
    o_stats = {}
    for row in data:
        if row[1] != o_type:
            continue
        zn_count, mg_count = 0, 0
        for neighbor in row[2:11]:
            if neighbor == 0:
                continue
            # Neighbours are resolved against the untrimmed table because the
            # stored numbers refer to rows of the whole file.
            neighbor_type = data_orig[int(neighbor - 1)][1]
            if neighbor_type == zn_type:
                zn_count += 1
            elif neighbor_type == mg_type:
                mg_count += 1
        entry = o_stats.setdefault((zn_count, mg_count), {'sum_charge': 0.0, 'count': 0})
        entry['sum_charge'] += row[-1]
        entry['count'] += 1

    for entry in o_stats.values():
        entry['sum_charge'] = round(entry['sum_charge'], 3)
        entry['average'] = round(entry['sum_charge'] / entry['count'], 3)

    zn_charges, mg_charges = [], []
    for row in data:
        if row[1] == zn_type:
            zn_charges.append(row[-1])
        elif row[1] == mg_type:
            mg_charges.append(row[-1])

    if not mg_charges:
        # No Mg in the retained rows: fall back to the fixed pair.
        return (1, -1)

    zn_avg = round(np.average(zn_charges), 3)
    mg_avg = round(np.average(mg_charges), 3)

    o_charge_sum, o_count_sum = 0, 0
    for (_, mg_count), entry in o_stats.items():
        if mg_count != 0:
            o_charge_sum += entry['count'] * entry['average']
            o_count_sum += entry['count']

    o_mg_avg = round(o_charge_sum / o_count_sum, 3)
    o_zn_avg = o_stats[(4, 0)]['average']
    return (o_zn_avg, o_mg_avg, zn_avg, mg_avg)



# Upper bound on charge-balancing passes; reaching it means the O charges
# cannot be rescaled to cancel the metal charges.
_MAX_BALANCE_ITERATIONS = 100


def _read_hetatm_types():
    """Return the third whitespace field of every HETATM line in 'geo', in file order."""
    with open('geo', 'r') as infile:
        return [line.split()[2] for line in infile if line.startswith('HETATM')]


def _write_geo_with_molcharges(molcharge_lines):
    """Copy 'geo' to 'final/geo', inserting the MOLCHARGE lines before each 'FORMAT ATOM' line."""
    with open('geo', 'r') as infile, open('final/geo', 'w') as outfile:
        for line in infile:
            if line.startswith('FORMAT ATOM'):
                outfile.writelines(entry + '\n' for entry in molcharge_lines)
            outfile.write(line)


def generate_geo(result, mg_type=3):
    """Write ``final/geo``: a copy of ``geo`` with one MOLCHARGE line per atom.

    Atoms are numbered 1, 2, ... in the order their HETATM lines appear in
    ``geo``; the n-th HETATM line is paired with the n-th data row of
    ``fort.7`` (its first line and last two lines are skipped).

    Args:
        result: either a 4-tuple ``(o_zn_avg, o_mg_avg, zn_avg, mg_avg)`` or
            any other length.  With a 4-tuple, Zn and Mg atoms get ``zn_avg``
            / ``mg_avg``; an O atom gets ``o_mg_avg`` if any of its neighbours
            (columns 2-10 of its ``fort.7`` row) has type ``mg_type`` and
            ``o_zn_avg`` otherwise, and all O charges are then rescaled by a
            common factor until the total charge is within 0.001 of zero.
            With any other length, Zn atoms get 1, O atoms get -1 and other
            atoms get no MOLCHARGE line.
        mg_type: type number (column 1 of ``fort.7``) that identifies Mg.
            Defaults to 3, the value that used to be hard-coded.

    Raises:
        ValueError: if the charges are unbalanced and the O charges sum to
            zero, so no scale factor exists.
        RuntimeError: if the total charge is still outside the tolerance
            after ``_MAX_BALANCE_ITERATIONS`` passes.

    In the 4-tuple case no file is written when ``geo`` has no Zn/Mg/O atoms.
    """
    if len(result) != 4:
        fixed_charge = {'O': -1, 'Zn': 1}
        molcharge_lines = [
            f"MOLCHARGE {serial} {serial}   " + str(fixed_charge[atom_type])
            for serial, atom_type in enumerate(_read_hetatm_types(), start=1)
            if atom_type in fixed_charge
        ]
        _write_geo_with_molcharges(molcharge_lines)
        return

    o_zn_avg, o_mg_avg, zn_avg, mg_avg = result

    with open('fort.7', 'r') as f:
        fort_lines = f.readlines()[1:-2]
    fort_rows = np.array([line.strip().split() for line in fort_lines], dtype=float)

    atom_types = _read_hetatm_types()

    # Which O atoms have an Mg neighbour does not depend on the scale factor,
    # so resolve it once instead of on every balancing pass.
    o_has_mg = {}
    for serial, atom_type in enumerate(atom_types, start=1):
        if atom_type == 'O':
            neighbors = [n for n in fort_rows[serial - 1][2:11] if n != 0]
            o_has_mg[serial] = any(
                fort_rows[int(n - 1)][1] == mg_type for n in neighbors)

    tol = 0.001
    fraction = 1
    for _ in range(_MAX_BALANCE_ITERATIONS):
        molcharge_lines = []
        cumulative_o_charge = 0
        cumulative_m_charge = 0
        cumulative_z_charge = 0

        for serial, atom_type in enumerate(atom_types, start=1):
            prefix = f"MOLCHARGE {serial} {serial}   "
            if atom_type == 'O':
                base = o_mg_avg if o_has_mg[serial] else o_zn_avg
                o_avg = base * fraction
                molcharge_lines.append(prefix + str(round(o_avg, 4)))
                cumulative_o_charge += o_avg
            elif atom_type == 'Zn':
                molcharge_lines.append(prefix + str(zn_avg))
                cumulative_z_charge += zn_avg
            elif atom_type == 'Mg':
                molcharge_lines.append(prefix + str(mg_avg))
                cumulative_m_charge += mg_avg

        sum_charges = cumulative_z_charge + cumulative_m_charge + cumulative_o_charge
        print(f'Sum of charges: {sum_charges}')
        if not abs(sum_charges) > tol:
            break
        if cumulative_o_charge == 0:
            raise ValueError(
                'cannot balance charges: the O charges sum to zero')
        # Scale factor that makes the O total equal in size to the metal total.
        fraction = abs((cumulative_z_charge + cumulative_m_charge) / cumulative_o_charge)
    else:
        raise RuntimeError(
            f'charge balancing did not converge within '
            f'{_MAX_BALANCE_ITERATIONS} iterations')

    if molcharge_lines:
        _write_geo_with_molcharges(molcharge_lines)

def main():
    """Command-line entry point.

    Parses the single positional integer ``atoms_per_layer``, makes sure the
    ``final`` output directory exists, then runs the three stages in order:
    ``get_atom_index`` -> ``get_avg_charges`` -> ``generate_geo``.
    """
    cli = argparse.ArgumentParser(
        description='''This script generates a geo file with fixed charges of Zn, Mg, O atoms.''')
    cli.add_argument(
        'atoms_per_layer', metavar='atoms_per_layer', type=int,
        help='''the number of atoms per layer in the geo file.
                                Please make sure it is a valid integer''')
    options = cli.parse_args()

    # generate_geo writes to final/geo, so the directory has to exist first.
    if not os.path.exists('final'):
        os.mkdir('final')

    type_numbers = get_atom_index()
    averages = get_avg_charges(options.atoms_per_layer, type_numbers)
    generate_geo(averages)


# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()