In [1]:
import os
import re
from collections import OrderedDict
import math
from ctgomartini.api import MartiniTopFile



In [2]:
def Extract_contacts_from_top(top, molecule_name):
    """
    Extract the contacts of one molecule from the nonbond_params of a topology.

    Contact sites are atoms whose atomtype is ``f"{molecule_name}_<digits>"``.
    Every nonbond_params entry between two such atomtypes is converted into a
    contact between the corresponding atom ids.

    Parameters
    ##########
    top: MartiniTopFile
        Parsed topology; must provide ``moleculeTypes[molecule_name]._topology['atoms']``
        and ``forcefield._parameters['nonbond_params']``.
    molecule_name: str
        Name of the molecule type (also the prefix of its contact atomtypes).

    Return
    ######
    contacts: list
        ``[atomid1, atomid2, functype, *params]`` string fields with
        ``atomid1 <= atomid2``, sorted by ``(atomid1, atomid2)``.

    Raises
    ######
    ValueError
        If a nonbond_params entry pairs a contact atomtype with a non-contact
        atomtype, or a contact atomtype has no atom in the molecule.
    AssertionError
        If a contact entry has a function type other than 1.
    """
    atoms = top.moleculeTypes[molecule_name]._topology['atoms']
    nonbond_params = top.forcefield._parameters['nonbond_params']

    # Raw string avoids the invalid '\d' escape warning; re.escape makes the
    # molecule name match literally even if it contains regex metacharacters.
    pattern = rf'{re.escape(molecule_name)}_\d+'

    # atomtype -> atomid of the atom carrying that contact atomtype
    contact_atoms = {}
    for item in atoms:
        atomtype = item[1]
        if re.fullmatch(pattern, atomtype):
            contact_atoms[atomtype] = int(item[0])

    contacts = []
    for fields in nonbond_params:
        atomtype1, atomtype2 = fields[0], fields[1]
        is_contact1 = re.fullmatch(pattern, atomtype1) is not None
        is_contact2 = re.fullmatch(pattern, atomtype2) is not None
        if not (is_contact1 or is_contact2):
            continue  # ordinary force-field pair, not a contact
        if is_contact1 != is_contact2:
            raise ValueError(
                f"Error: Unsupported contact between {atomtype1} and {atomtype2}! {fields}"
            )

        try:
            atomid1 = contact_atoms[atomtype1]
            atomid2 = contact_atoms[atomtype2]
        except KeyError as exc:
            raise ValueError(f"Error: not contact_atoms! {fields}") from exc

        atomid1, atomid2 = sorted([atomid1, atomid2])
        assert fields[2] == '1', f"Error: only support functype 1: {fields}"
        contacts.append([str(atomid1), str(atomid2)] + fields[2:])

    contacts = sorted(contacts, key=lambda fields: (int(fields[0]), int(fields[1])))
    return contacts

def GetAtomNames(atomidlist, atoms):
    """
    Look up the atom name (column 4) of each atom id in ``atomidlist``.

    Parameters
    ##########
    atomidlist: list
        Atom ids (int or numeric str).
    atoms: list
        Atom fields; column 0 is the atom id, column 4 the atom name.

    Return
    ######
    list of atom names in the order of ``atomidlist``. An unknown id raises KeyError.
    """
    row_by_id = {}
    for atom_fields in atoms:
        row_by_id[int(atom_fields[0])] = atom_fields
    return [row_by_id[int(atom_id)][4] for atom_id in atomidlist]

def GetAngleDiehdralType(atomnamelist):
    """
    Encode a sequence of atom names as a type string: 'BB' -> 'B', 'SC*' -> 'S'.

    Example: ['BB', 'BB', 'SC1'] -> 'BBS'. Any other atom name raises ValueError.
    """
    letters = []
    for name in atomnamelist:
        if name == 'BB':
            letters.append('B')
            continue
        if name.startswith('SC'):
            letters.append('S')
            continue
        raise ValueError(f"Error: Unsupport atomnames other than BB and SC*: {atomnamelist}")
    return ''.join(letters)

def DifferentiateAngles(angles, atoms):
    """
    Split angles into backbone-only (BBB) angles and all other angles.

    Parameters
    #########
    angles: list
        list of angle fields; the first three columns are atom ids
    atoms: list
        list of atom fields

    Return
    ######
    BBB_angles, notBBB_angles
        Two lists, each preserving the input order. Atom names other than
        BB / SC* raise ValueError (from GetAngleDiehdralType).
    """
    BBB_angles, notBBB_angles = [], []
    for angle_fields in angles:
        names = GetAtomNames(angle_fields[:3], atoms)
        target = BBB_angles if GetAngleDiehdralType(names) == 'BBB' else notBBB_angles
        target.append(angle_fields)
    return BBB_angles, notBBB_angles

def DifferentiateDihedrals(dihedrals, atoms):
    """
    Split dihedrals into BBBB, SSSS and SBBS dihedrals.

    Parameters
    #########
    dihedrals: list
        list of dihedral fields; the first four columns are atom ids
    atoms: list
        list of atom fields

    Return
    ######
    BBBB_dihedrals, SSSS_dihedrals, SBBS_dihedrals
        Three lists, each preserving the input order.

    Raises
    ######
    AssertionError
        If any dihedral has another type (e.g. 'BBBS').
    ValueError
        From GetAngleDiehdralType, for atom names other than BB / SC*.
    """
    groups = {'BBBB': [], 'SSSS': [], 'SBBS': []}
    other_dihedrals = []

    for fields in dihedrals:
        atomnamelist = GetAtomNames(fields[:4], atoms)
        # Classify once per dihedral and dispatch to the matching bucket.
        dihedral_type = GetAngleDiehdralType(atomnamelist)
        groups.get(dihedral_type, other_dihedrals).append(fields)

    assert other_dihedrals == [], f'Error: not supported dihedral type: {other_dihedrals}'
    return groups['BBBB'], groups['SSSS'], groups['SBBS']


In [7]:
# Inspect molA's bonds. NOTE: molA is created by the topology-loading cell further
# down (system_open.top), which has to be executed before this cell.
molA._topology['bonds']

[['1', '3', '1', '0.350', '4000'],
 ['3', '5', '1', '0.350', '4000'],
 ['5', '7', '1', '0.350', '4000'],
 ['7', '9', '1', '0.350', '4000'],
 ['9', '11', '1', '0.350', '4000'],
 ['11', '13', '1', '0.350', '4000'],
 ['13', '15', '1', '0.350', '4000'],
 ['15', '17', '1', '0.350', '4000'],
 ['17', '21', '1', '0.350', '4000'],
 ['21', '23', '1', '0.350', '4000'],
 ['23', '25', '1', '0.350', '4000'],
 ['25', '29', '1', '0.350', '4000'],
 ['29', '31', '1', '0.350', '4000'],
 ['31', '35', '1', '0.350', '4000'],
 ['35', '38', '1', '0.350', '4000'],
 ['38', '40', '1', '0.350', '4000'],
 ['40', '41', '1', '0.350', '4000'],
 ['41', '43', '1', '0.350', '4000'],
 ['43', '45', '1', '0.350', '4000'],
 ['45', '50', '1', '0.350', '4000'],
 ['50', '52', '1', '0.350', '4000'],
 ['84', '86', '1', '0.350', '4000'],
 ['86', '89', '1', '0.350', '4000'],
 ['89', '91', '1', '0.350', '4000'],
 ['91', '93', '1', '0.350', '4000'],
 ['93', '98', '1', '0.350', '4000'],
 ['98', '100', '1', '0.350', '4000'],
 ['100', 

In [8]:
# Short alias for molA, used by the atom-inspection cell that follows.
mol = molA


In [9]:
# Inspect the atom rows. Presumably the GROMACS [ atoms ] columns
# (id, atomtype, resnr, resname, atomname, cgnr, charge, mass) — TODO confirm.
mol._topology['atoms']

[['1', 'Q5', '5', 'LEU', 'BB', '1', '1', '72.0'],
 ['2', 'SC2', '5', 'LEU', 'SC1', '2', '0.0', '54.0'],
 ['3', 'SP2', '6', 'VAL', 'BB', '3', '0.0', '54.0'],
 ['4', 'SC3', '6', 'VAL', 'SC1', '4', '0.0', '54.0'],
 ['5', 'SP2', '7', 'VAL', 'BB', '5', '0.0', '54.0'],
 ['6', 'SC3', '7', 'VAL', 'SC1', '6', '0.0', '54.0'],
 ['7', 'SP2', '8', 'ALA', 'BB', '7', '0.0', '54.0'],
 ['8', 'TC3', '8', 'ALA', 'SC1', '8', '0.0', '36.0'],
 ['9', 'P2', '9', 'THR', 'BB', '9', '0.0', '72.0'],
 ['10', 'SP1', '9', 'THR', 'SC1', '10', '0.0', '54.0'],
 ['11', 'P2', '10', 'ASP', 'BB', '11', '0.0', '72.0'],
 ['12', 'SQ5n', '10', 'ASP', 'SC1', '12', '-1.0', '54.0'],
 ['13', 'P2', '11', 'THR', 'BB', '13', '0.0', '72.0'],
 ['14', 'SP1', '11', 'THR', 'SC1', '14', '0.0', '54.0'],
 ['15', 'SP2', '12', 'ALA', 'BB', '15', '0.0', '54.0'],
 ['16', 'TC3', '12', 'ALA', 'SC1', '16', '0.0', '36.0'],
 ['17', 'P2', '13', 'PHE', 'BB', '17', '0.0', '72.0'],
 ['18', 'SC4', '13', 'PHE', 'SC1', '18', '0.0', '54.0'],
 ['19', 'TC5', '

In [None]:
def GetLongElasticBonds(mol):
    """
    Placeholder for selecting long elastic bonds between backbone beads of ``mol``.

    NOTE(review): this cell was left unfinished (its only statement,
    ``BB_atoms = mol._top``, appears truncated) and returned None. The backbone
    atoms are collected from ``mol._topology['atoms']`` (atom name column 4 ==
    'BB', the same layout GetAtomNames reads), but the rule that decides which
    bonds count as "long" was never written. Raise NotImplementedError so
    callers do not silently receive None.
    """
    BB_atoms = [fields for fields in mol._topology['atoms'] if fields[4] == 'BB']
    raise NotImplementedError(
        f"GetLongElasticBonds is not implemented yet ({len(BB_atoms)} BB atoms found)."
    )


In [4]:
# Directory containing system_open.top / system_close.top (loaded by relative path
# below). The absolute path is the author's local default; set GOMARTINI_WORKDIR
# to run the notebook from another machine or directory.
working_path = os.environ.get(
    "GOMARTINI_WORKDIR",
    "/home/ys/SongYang/GoMartini3/gbp_sw/Martini3_300ns/repeat-copy/",
)
os.chdir(working_path)

In [6]:
def LoadMoleculeWithContacts(topfile, molecule_name):
    """
    Parse ``topfile`` and return ``(top, mol)``.

    ``mol`` is the molecule type ``molecule_name`` of the parsed topology, with the
    contacts from Extract_contacts_from_top stored in ``mol._topology['contacts']``.
    """
    top = MartiniTopFile(topfile)
    contacts = Extract_contacts_from_top(top, molecule_name)
    mol = top.moleculeTypes[molecule_name]
    mol._topology['contacts'] = contacts
    return top, mol

# State A: open structure
topfileA = 'system_open.top'
mol_nameA = 'gbp_open'
topA, molA = LoadMoleculeWithContacts(topfileA, mol_nameA)

# State B: closed structure
topfileB = 'system_close.top'
mol_nameB = 'gbp_close'
topB, molB = LoadMoleculeWithContacts(topfileB, mol_nameB)


In [None]:
# Planned classification of bonded terms. Notes only: most of these names are not
# defined anywhere in this notebook, so evaluating them as code raised NameError.
# angles:
#   BBB_angles, BBS_regular_angles, SBB_regular_angles, BSS_angles,
#   SBB_scFix_angles, BBS_scFix_angles
#   (currently DifferentiateAngles only splits into BBB_angles / notBBB_angles)
# dihedrals:
#   BBBB_dihedrals, SSSS_dihedrals, SBBS_scFix_dihedrals

In [30]:
# Quick check: classify molA's dihedrals. The combination cell below recomputes
# these per state as molA_* / molB_* variables.
# BBB_angles, notBBB_angles = DifferentiateAngles(molA._topology['angles'], molA._topology['atoms'])
BBBB_dihedrals, SSSS_dihedrals, SBBS_dihedrals = DifferentiateDihedrals(molA._topology['dihedrals'], molA._topology['atoms'])

In [98]:
# Name field of molA's first 'moleculetype' entry.
molA._topology['moleculetype'][0][0]

'gbp_open'

In [79]:
# Build a 3-state combined molecule: state 1 = open (molA), state 2 = closed (molB),
# state 3 = open again (molA). Every combine_* call below takes its per-state inputs
# in this order, so all of them are derived from these two lists. (State 3's atoms
# are molA's, consistent with the bonds/contacts/angles/dihedrals sections.)
state_names = [mol_nameA, mol_nameB, mol_nameA]
state_mols = [molA, molB, molA]
n_states = len(state_mols)

# Spreads up to which a term is merged into a single averaged entry instead of
# being kept per state: contact sigma, angle value, dihedral value.
CONTACT_SIGMA_CUTOFF = 0.06
ANGLE_CUTOFF = 15
DIHEDRAL_CUTOFF = 30

mbatoms = CombineMols.combine_atoms(
    'gbp',
    [[name, mol._topology['atoms']] for name, mol in zip(state_names, state_mols)],
)

mbbonds, mbconstraints = CombineMols.combine_bonds_constraints(
    n_states,
    [mol._topology['bonds'] for mol in state_mols],
    [mol._topology['constraints'] for mol in state_mols],
)

mbexclusions = CombineMols.combine_exclusions(
    [mol._topology['exclusions'] for mol in state_mols]
)

mbcontacts, mbmulti_contacts = CombineMols.combine_contacts(
    n_states,
    [mol._topology['contacts'] for mol in state_mols],
    CONTACT_SIGMA_CUTOFF,
)

# Only backbone (BBB) angles are combined; the other angles are compared between
# the states in the checks further below.
molA_BBB_angles, molA_notBBB_angles = DifferentiateAngles(molA._topology['angles'], molA._topology['atoms'])
molB_BBB_angles, molB_notBBB_angles = DifferentiateAngles(molB._topology['angles'], molB._topology['atoms'])
BBB_angles_by_name = {mol_nameA: molA_BBB_angles, mol_nameB: molB_BBB_angles}
mbangles, mbmulti_angles = CombineMols.combine_angles(
    n_states,
    [BBB_angles_by_name[name] for name in state_names],
    ANGLE_CUTOFF,
)

# Only BBBB dihedrals are combined; SSSS / SBBS dihedrals are compared below.
molA_BBBB_dihedrals, molA_SSSS_dihedrals, molA_SBBS_dihedrals = DifferentiateDihedrals(molA._topology['dihedrals'], molA._topology['atoms'])
molB_BBBB_dihedrals, molB_SSSS_dihedrals, molB_SBBS_dihedrals = DifferentiateDihedrals(molB._topology['dihedrals'], molB._topology['atoms'])
BBBB_dihedrals_by_name = {mol_nameA: molA_BBBB_dihedrals, mol_nameB: molB_BBBB_dihedrals}
# (name keeps the 'dihdedrals' spelling used by the inspection cells)
mbdihdedrals, mbmulti_dihedrals = CombineMols.combine_dihedrals(
    n_states,
    [BBBB_dihedrals_by_name[name] for name in state_names],
    DIHEDRAL_CUTOFF,
)

In [96]:
# SSSS dihedrals are not passed to combine_dihedrals; check that both states have
# identical ones (order-insensitive comparison).
SameListList([molA_SSSS_dihedrals, molB_SSSS_dihedrals], sort=True)

True

In [90]:
# SBBS dihedrals are not combined either; check that both states have identical
# ones (order-sensitive comparison, since sort is not set).
SameListList([molA_SBBS_dihedrals, molB_SBBS_dihedrals])

True

In [94]:
# Only BBB angles go through combine_angles; check that the remaining angles are
# identical in both states (order-sensitive comparison).
SameListList([molA_notBBB_angles, molB_notBBB_angles])

True

In [88]:
# molA's SSSS dihedrals with numeric fields converted to float.
ForceListFloat(molA_SSSS_dihedrals)

[[49.0, 47.0, 48.0, 46.0, 2.0, 180.0, 50.0],
 [70.0, 69.0, 67.0, 66.0, 2.0, 180.0, 100.0],
 [97.0, 95.0, 96.0, 94.0, 2.0, 180.0, 50.0],
 [190.0, 188.0, 189.0, 187.0, 2.0, 180.0, 50.0],
 [195.0, 193.0, 194.0, 192.0, 2.0, 180.0, 50.0],
 [274.0, 272.0, 273.0, 271.0, 2.0, 180.0, 50.0],
 [323.0, 321.0, 322.0, 320.0, 2.0, 180.0, 50.0],
 [368.0, 366.0, 367.0, 365.0, 2.0, 180.0, 50.0],
 [418.0, 416.0, 417.0, 415.0, 2.0, 180.0, 50.0],
 [480.0, 478.0, 479.0, 477.0, 2.0, 180.0, 50.0],
 [491.0, 489.0, 490.0, 488.0, 2.0, 180.0, 50.0],
 [503.0, 502.0, 500.0, 499.0, 2.0, 180.0, 100.0]]

In [86]:
# molB's SSSS dihedrals with numeric fields converted to float.
ForceListFloat(molB_SSSS_dihedrals)

[[49.0, 47.0, 48.0, 46.0, 2.0, 180.0, 50.0],
 [70.0, 69.0, 67.0, 66.0, 2.0, 180.0, 100.0],
 [97.0, 95.0, 96.0, 94.0, 2.0, 180.0, 50.0],
 [190.0, 188.0, 189.0, 187.0, 2.0, 180.0, 50.0],
 [195.0, 193.0, 194.0, 192.0, 2.0, 180.0, 50.0],
 [274.0, 272.0, 273.0, 271.0, 2.0, 180.0, 50.0],
 [323.0, 321.0, 322.0, 320.0, 2.0, 180.0, 50.0],
 [368.0, 366.0, 367.0, 365.0, 2.0, 180.0, 50.0],
 [418.0, 416.0, 417.0, 415.0, 2.0, 180.0, 50.0],
 [480.0, 478.0, 479.0, 477.0, 2.0, 180.0, 50.0],
 [491.0, 489.0, 490.0, 488.0, 2.0, 180.0, 50.0],
 [503.0, 502.0, 500.0, 499.0, 2.0, 180.0, 100.0]]

In [80]:
# Dihedrals kept per state by combine_dihedrals (missing in some states, or spread
# above the cutoff). Fields: ai aj ak al n_states state functype phi k mult.
mbmulti_dihedrals

[['77', '79', '82', '84', '3', '2', '1', '-120', '400', '1'],
 ['369', '373', '375', '378', '3', '1', '1', '-120', '400', '1'],
 ['369', '373', '375', '378', '3', '3', '1', '-120', '400', '1']]

In [81]:
# Dihedrals merged across all states: [ai, aj, ak, al, '1', mean phi, mean k, '1'].
mbdihdedrals

[['53', '57', '59', '61', '1', '-120.0', '400.0', '1'],
 ['57', '59', '61', '63', '1', '-120.0', '400.0', '1'],
 ['59', '61', '63', '65', '1', '-120.0', '400.0', '1'],
 ['61', '63', '65', '71', '1', '-120.0', '400.0', '1'],
 ['63', '65', '71', '73', '1', '-120.0', '400.0', '1'],
 ['65', '71', '73', '75', '1', '-120.0', '400.0', '1'],
 ['71', '73', '75', '77', '1', '-120.0', '400.0', '1'],
 ['73', '75', '77', '79', '1', '-120.0', '400.0', '1'],
 ['75', '77', '79', '82', '1', '-120.0', '400.0', '1'],
 ['111', '115', '117', '118', '1', '-120.0', '400.0', '1'],
 ['115', '117', '118', '120', '1', '-120.0', '400.0', '1'],
 ['117', '118', '120', '122', '1', '-120.0', '400.0', '1'],
 ['118', '120', '122', '124', '1', '-120.0', '400.0', '1'],
 ['120', '122', '124', '126', '1', '-120.0', '400.0', '1'],
 ['122', '124', '126', '128', '1', '-120.0', '400.0', '1'],
 ['158', '160', '162', '165', '1', '-120.0', '400.0', '1'],
 ['263', '264', '266', '268', '1', '-120.0', '400.0', '1'],
 ['264', '266', 

In [59]:
# Inspect molA's angles.
molA._topology['angles']

[['50', '52', '53', '10', '130', '20'],
 ['52', '53', '57', '10', '130', '20'],
 ['53', '57', '59', '2', '96', '700'],
 ['57', '59', '61', '2', '96', '700'],
 ['59', '61', '63', '2', '96', '700'],
 ['61', '63', '65', '2', '96', '700'],
 ['63', '65', '71', '2', '96', '700'],
 ['65', '71', '73', '2', '96', '700'],
 ['71', '73', '75', '2', '96', '700'],
 ['73', '75', '77', '2', '96', '700'],
 ['75', '77', '79', '2', '96', '700'],
 ['77', '79', '82', '2', '96', '700'],
 ['79', '82', '84', '10', '100', '20'],
 ['82', '84', '86', '10', '100', '20'],
 ['107', '109', '111', '10', '127', '20'],
 ['109', '111', '115', '10', '127', '20'],
 ['111', '115', '117', '2', '96', '700'],
 ['115', '117', '118', '2', '96', '700'],
 ['117', '118', '120', '2', '96', '700'],
 ['118', '120', '122', '2', '98', '100'],
 ['120', '122', '124', '10', '98', '100'],
 ['122', '124', '126', '2', '98', '100'],
 ['124', '126', '128', '2', '96', '700'],
 ['126', '128', '130', '10', '100', '20'],
 ['128', '130', '132', '10

In [95]:
def CombineDict(dict_list):
    """
    Merge dictionaries key-wise.

    Returns an OrderedDict over the sorted union of all keys. Each value is the
    list of that key's values, taken from every dict that contains the key, in
    the order of ``dict_list``.
    """
    all_keys = set()
    for single_dict in dict_list:
        all_keys.update(single_dict)

    merged = OrderedDict()
    for key in sorted(all_keys):
        merged[key] = [single_dict[key] for single_dict in dict_list if key in single_dict]
    return merged

def ForceItemFloat(item):
    """
    Convert ``item`` to float when possible; otherwise return it unchanged.

    Only conversion failures are swallowed: TypeError (e.g. None), ValueError
    (non-numeric strings such as atom names) and OverflowError (ints too large
    for a float). A bare ``except`` would also hide KeyboardInterrupt/SystemExit.
    """
    try:
        return float(item)
    except (TypeError, ValueError, OverflowError):
        return item

def ForceListFloat(itemlist):
    """
    Recursively copy a (nested) list, passing every non-list leaf through ForceItemFloat.

    Always returns new lists; the input is not modified.
    """
    return [
        ForceListFloat(element) if type(element) is list else ForceItemFloat(element)
        for element in itemlist
    ]

def SameListList(listlist, typeforce=True, sort=False):
    """
    Judge whether all lists in ``listlist`` are the same.

    Parameters
    #####
    listlist: list(list)
        list of lists to compare
    typeforce: bool, True
        Convert numeric-looking items to float (via ForceListFloat) before
        comparing, so e.g. '1' and '1.0' compare equal.
    sort: bool, False
        Sort each list before comparing, making the comparison order-insensitive.

    Return
    ######
    True or False. An empty or single-element ``listlist`` is trivially True.
    """
    def normalize(a_list):
        a_list = ForceListFloat(a_list) if typeforce else a_list
        return sorted(a_list) if sort else a_list

    if len(listlist) == 0:
        return True
    ref_list = normalize(listlist[0])
    # Stop at the first mismatch instead of normalizing every remaining list.
    return all(normalize(item) == ref_list for item in listlist[1:])

def SameList(alist, typeforce=True):
    """
    Judge whether all items of ``alist`` are equal (optionally after float coercion).
    """
    # Wrap every item in its own one-element list so SameListList can compare them.
    return SameListList([[element] for element in alist], typeforce)


    

In [60]:
class CombineMols:
    """
    Static helpers that merge topology sections from several states of one molecule.

    Each ``combine_*`` method takes one input list per state; state numbers are the
    1-based positions in that list. Terms present in all states with compatible
    parameters are merged into one averaged entry. combine_contacts,
    combine_angles and combine_dihedrals return the remaining terms as per-state
    entries laid out as ``atom ids + [n_states, state] + other fields``.
    None of the methods modify the input lists.
    """

    # Force constant assigned when a g96 angle (type 2) is rewritten as a
    # restricted bending angle (type 10). Stored as a string like every other field.
    _RESTRICTED_ANGLE_K = '25.0'

    @staticmethod
    def _canonical_order(fields, n_atoms):
        """
        Return ``fields`` with its first ``n_atoms`` atom ids reversed when the first
        id is larger than the last, so a term always maps to the same key.
        A new list is returned when reversing; ``fields`` itself is never modified.
        """
        if int(fields[0]) > int(fields[n_atoms - 1]):
            return fields[n_atoms - 1::-1] + fields[n_atoms:]
        return fields

    @staticmethod
    def _per_state_entries(value, n_atoms, n_mols, transform=None):
        """
        Expand ``[state, *fields]`` records into per-state entries
        ``fields[:n_atoms] + [str(n_mols), state] + fields[n_atoms:]``,
        optionally passing ``fields`` through ``transform`` first.
        """
        entries = []
        for record in value:
            state, fields = record[0], record[1:]
            if transform is not None:
                fields = transform(fields)
            entries.append(fields[:n_atoms] + [str(n_mols), state] + fields[n_atoms:])
        return entries

    @staticmethod
    def _g96_to_restricted(fields):
        """
        Rewrite angle fields ``[ai, aj, ak, functype, theta, k, ...]`` as a
        restricted bending angle (type 10). Type 2 angles keep theta and get
        ``_RESTRICTED_ANGLE_K``; type 10 angles are returned as a copy.
        """
        if fields[3] == '2':
            return fields[:3] + ['10', fields[4], CombineMols._RESTRICTED_ANGLE_K]
        if fields[3] == '10':
            return fields.copy()
        raise ValueError(f"Error: Unsupport angle type. {fields}")

    @staticmethod
    def combine_atoms(mbmolname: str, mols_atoms_pairs: list):
        """
        Combine atoms from different states of molecules

        Atoms identical in all states are copied. Atoms that differ must all be
        contact sites with atomtype ``<molname>_<index>`` (molname of their own
        state) sharing the same index; they become ``<mbmolname>_<index>`` with
        the remaining fields taken from the first state.

        Parameters
        ##########
        mbmolname: str,
            New prefix of the contact-site atomtypes

        mols_atoms_pairs: list
            Atoms from different states of molecules
            [(molnameA, atomtopA), (molnameB, atomtopB), ...]

        Return
        ######
        atomtop: list
            combined atom fields, sorted by atom id

        Raises
        ######
        AssertionError
            Fewer than 2 molecules, or the states do not share the same atom ids.
        ValueError
            Differing atoms violate the contact-site rule described above.
        """
        n_mols = len(mols_atoms_pairs)
        # Assert that there are not less than 2 mols
        assert n_mols >= 2, "Error: The number of mols must more than or equal 2"

        # Assert that different molecules have the same number of atoms
        n_atoms = len(mols_atoms_pairs[0][1])
        for pair in mols_atoms_pairs[1:]:
            assert len(pair[1]) == n_atoms, f"Error: {pair[0]} has {len(pair[1])} atoms, expected {n_atoms}"

        mols_atoms_dict_list = [
            {(int(atom[0]),): atom for atom in pair[1]} for pair in mols_atoms_pairs
        ]
        mols_atoms_dict_combined = CombineDict(mols_atoms_dict_list)

        # Contact-site atomtypes look like '<molname>_<index>', e.g. 'gbp_open_12'.
        contact_atomtype = re.compile(r'^(\w+)_(\d+)$')

        mbmol_atomtop = []
        for key, value in mols_atoms_dict_combined.items():
            assert len(value) == n_mols, f"Error: The number of molecules with the same atomid is not equal to the number of molecules. {key}"
            if SameListList(value):
                mbmol_atomtop.append(value[0].copy())
                continue

            combination_error = ValueError(
                "Error: atoms from different states of one molecule cannot meet the combination rule!", value
            )
            contact_indices = []
            for pair, atom in zip(mols_atoms_pairs, value):
                match = contact_atomtype.match(atom[1])
                if match is None or match.group(1) != pair[0]:
                    raise combination_error
                contact_indices.append(match.group(2))
            if not SameList(contact_indices):
                raise combination_error

            newatom = value[0].copy()
            newatom[1] = f"{mbmolname}_{contact_indices[-1]}"
            mbmol_atomtop.append(newatom)

        assert len(mbmol_atomtop) == n_atoms
        return mbmol_atomtop

    @staticmethod
    def combine_bonds_constraints(n_mols, mols_bonds_list, mols_constraints_list):
        """
        Combine bonds and constraints from different states of molecules.

        Every atom pair must occur in all ``n_mols`` states, at most once per state
        (as a bond or as a constraint). Its distance is averaged over all states
        and its force constant over the states where it is a bond. Pairs that are
        a bond in at least one state become bonds; the others become constraints.

        Parameters
        ##########
        n_mols: int
            Number of states.
        mols_bonds_list: list
            Per-state bond lists, fields ``[ai, aj, '1', b0, kb, ...]``.
        mols_constraints_list: list
            Per-state constraint lists, fields ``[ai, aj, '1', b0, ...]``.

        Return
        ######
        mbbonds: list of [ai, aj, '1', b0, kb]
        mbconstraints: list of [ai, aj, '1', b0]
            String fields with ai <= aj, means rounded to 3 decimals, sorted by pair.
        """
        mols_connections_dict_list = []
        for i, bonds in enumerate(mols_bonds_list):
            state = str(i + 1)
            connection_dict = {}
            for bond in bonds:
                assert bond[2] == "1", f"Error: bond type must be 1: {bond}"
                bond = CombineMols._canonical_order(bond, 2)
                connection_dict[(int(bond[0]), int(bond[1]))] = [state] + bond
            mols_connections_dict_list.append(connection_dict)

        for i, constraints in enumerate(mols_constraints_list):
            state = str(i + 1)
            connection_dict = {}
            for constraint in constraints:
                assert constraint[2] == "1", f"Error: constraint type must be 1: {constraint}"
                constraint = CombineMols._canonical_order(constraint, 2)
                # Constraints have no force constant; None fills the k slot.
                connection_dict[(int(constraint[0]), int(constraint[1]))] = [state] + constraint[:4] + [None]
            mols_connections_dict_list.append(connection_dict)

        mols_connections_dict_combined = CombineDict(mols_connections_dict_list)

        mbbonds = []
        mbconstraints = []
        for key, value in mols_connections_dict_combined.items():
            n_states_in_value = len({fields[0] for fields in value})
            assert n_states_in_value == len(value), f"Error: value repeats! {value}"
            assert n_states_in_value == n_mols, f"Error: {key} does not have {n_mols} values"
            # fields = [state, ai, aj, functype, b0, kb-or-None, ...]
            dist_list = [float(fields[4]) for fields in value]
            k_list = [float(fields[5]) for fields in value if fields[5] is not None]
            dist_mean = str(round(sum(dist_list) / len(dist_list), 3))
            head = value[0][1:4]
            if k_list:
                k_mean = str(round(sum(k_list) / len(k_list), 3))
                mbbonds.append(head + [dist_mean, k_mean])
            else:
                mbconstraints.append(head + [dist_mean])

        return mbbonds, mbconstraints

    @staticmethod
    def combine_exclusions(mols_exclusions_list):
        """
        Combine exclusions from different states of molecules.

        All exclusion pairs of all states are pooled and de-duplicated, then grouped
        into rows ``[ai, aj, ak, ...]`` keyed by the smaller atom id of each pair.
        Rows are sorted by ai and the ids within a row ascending; all ids are str.
        """
        exclusion_pairs = set()
        for exclusions in mols_exclusions_list:
            for fields in exclusions:
                for other in fields[1:]:
                    exclusion_pairs.add(tuple(sorted([int(fields[0]), int(other)])))

        mbexclusion_dict = OrderedDict()
        for low, high in sorted(exclusion_pairs):
            row = mbexclusion_dict.setdefault((low,), [str(low)])
            # Compare as strings (the row stores str ids). This also drops a
            # self-exclusion (low == high), which would only repeat the row's own id.
            if str(high) not in row:
                row.append(str(high))

        return list(mbexclusion_dict.values())

    @staticmethod
    def combine_contacts(n_mols, mols_contacts_list, cutoff):
        """
        Combine contacts from different states of molecules.

        A contact present in all states whose sigma spread (max - min) is
        <= ``cutoff`` becomes one entry with mean sigma and epsilon. Any other
        contact is kept per state.

        Parameters
        ##########
        n_mols: int
            Number of states.
        mols_contacts_list: list
            Per-state contact lists, fields ``[ai, aj, '1', sigma, epsilon, ...]``.
        cutoff: float
            Maximum sigma spread for merging.

        Return
        ######
        mbcontacts: list of [ai, aj, '1', sigma, epsilon]
        mbmulti_contacts: list of [ai, aj, n_states, state, '1', sigma, epsilon, ...]
        """
        assert len(mols_contacts_list) == n_mols, "Error: The number of contacts is not equal to the number of molecules."
        mols_contacts_dict_list = []
        for i, contacts in enumerate(mols_contacts_list):
            state = str(i + 1)
            mols_contacts_dict = {}
            for fields in contacts:
                assert fields[2] == '1', f"Error: contact type is not 1. {fields}"
                fields = CombineMols._canonical_order(fields, 2)
                # If a state lists the same pair twice, its first entry wins.
                mols_contacts_dict.setdefault((int(fields[0]), int(fields[1])), [state] + fields)
            mols_contacts_dict_list.append(mols_contacts_dict)
        mols_contacts_dict_combined = CombineDict(mols_contacts_dict_list)

        mbcontacts = []
        mbmulti_contacts = []
        for key, value in mols_contacts_dict_combined.items():
            n_states_in_value = len({fields[0] for fields in value})
            assert n_states_in_value == len(value), f"Error: contact value repeats! {value}"
            if n_states_in_value == n_mols:
                # fields = [state, ai, aj, '1', sigma, epsilon, ...]
                sigma_list = [float(fields[4]) for fields in value]
                epsilon_list = [float(fields[5]) for fields in value]
                if max(sigma_list) - min(sigma_list) <= cutoff:
                    mean_sigma = round(sum(sigma_list) / len(sigma_list), 10)
                    mean_epsilon = round(sum(epsilon_list) / len(epsilon_list), 10)
                    mbcontacts.append([str(key[0]), str(key[1]), '1', str(mean_sigma), str(mean_epsilon)])
                    continue
            # Missing in some states, or sigma differs too much: keep per state.
            mbmulti_contacts.extend(CombineMols._per_state_entries(value, 2, n_mols))

        return mbcontacts, mbmulti_contacts

    @staticmethod
    def combine_angles(n_mols, mols_angles_list, cutoff):
        """
        Combine angles from different states of molecules
        Convert the angle type 2 (g96 angles) to type 10 (restricted angles) if the angles of the same atoms from different states have different types

        An angle present in all states whose value spread is <= ``cutoff`` becomes
        one entry with mean angle and force constant. It keeps type 2 only if all
        states are type 2; otherwise every state is rewritten as type 10 first.
        Other angles are kept per state, always as type 10.

        Parameters
        ##########
        n_mols: int
            Number of states.
        mols_angles_list: list
            Per-state angle lists, fields ``[ai, aj, ak, functype, theta, k]``
            with functype '2' or '10' and 0 <= theta <= 180.
        cutoff: float
            Maximum spread of theta for merging.

        Return
        ######
        mbangles: list of [ai, aj, ak, functype, theta, k]
        mbmulti_angles: list of [ai, aj, ak, n_states, state, '10', theta, k]
        """
        assert len(mols_angles_list) == n_mols, 'The number of molecules is not equal to the number of angles'
        mols_angles_dict_list = []
        for i, angles in enumerate(mols_angles_list):
            state = str(i + 1)
            mols_angles_dict = {}
            for fields in angles:
                assert fields[3] in ['2', '10'], f"Error: angle type is not 2 or 10. {fields}"
                assert float(fields[4]) >= 0 and float(fields[4]) <= 180, f"Error: angles should be in 0-180. {fields}"
                fields = CombineMols._canonical_order(fields, 3)
                key = (int(fields[0]), int(fields[1]), int(fields[2]))
                mols_angles_dict.setdefault(key, [state] + fields)
            mols_angles_dict_list.append(mols_angles_dict)
        mols_angles_dict_combined = CombineDict(mols_angles_dict_list)

        mbangles = []
        mbmulti_angles = []
        for key, value in mols_angles_dict_combined.items():
            n_states_in_value = len({fields[0] for fields in value})
            assert n_states_in_value == len(value), f"Error: angle value repeats! {value}"
            if n_states_in_value != n_mols:
                mbmulti_angles.extend(
                    CombineMols._per_state_entries(value, 3, n_mols, CombineMols._g96_to_restricted)
                )
                continue

            # fields = [state, ai, aj, ak, functype, theta, k]
            type_list = [fields[4] for fields in value]
            angle_list = [float(fields[5]) for fields in value]
            diff_angle = max(angle_list) - min(angle_list)
            # g96 angles are kept only when all states are g96 and close enough;
            # otherwise everything is expressed as restricted angles.
            if not SameList(type_list) or diff_angle > cutoff:
                value = [[fields[0]] + CombineMols._g96_to_restricted(fields[1:]) for fields in value]
                type_list = [fields[4] for fields in value]
            k_list = [float(fields[6]) for fields in value]

            if diff_angle <= cutoff:
                mean_angle = round(sum(angle_list) / len(angle_list), 2)
                mean_k = round(sum(k_list) / len(k_list), 2)
                mbangles.append([str(key[0]), str(key[1]), str(key[2]), type_list[0], str(mean_angle), str(mean_k)])
            else:
                mbmulti_angles.extend(
                    CombineMols._per_state_entries(value, 3, n_mols, CombineMols._g96_to_restricted)
                )
        return mbangles, mbmulti_angles

    @staticmethod
    def combine_dihedrals(n_mols, mols_dihedrals_list, cutoff):
        """
        Combine dihedrals from different states of molecules
        assert periodic dihedrals

        A dihedral present in all states whose circular spread (from
        Calculate_DiffDihedral) is <= ``cutoff`` becomes one entry with the
        circular mean angle and mean force constant; others are kept per state.

        Parameters
        ##########
        n_mols: int
            Number of states.
        mols_dihedrals_list: list
            Per-state dihedral lists, fields ``[ai, aj, ak, al, '1', phi, k, '1']``
            with -180 < phi <= 180.
        cutoff: float
            Maximum spread of phi for merging.

        Return
        ######
        mbdihedrals: list of [ai, aj, ak, al, '1', phi, k, '1']
        mbmulti_dihedrals: list of [ai, aj, ak, al, n_states, state, '1', phi, k, '1']
        """
        assert len(mols_dihedrals_list) == n_mols, 'The number of molecules is not equal to the number of dihedrals'
        mols_dihedrals_dict_list = []
        for i, dihedrals in enumerate(mols_dihedrals_list):
            state = str(i + 1)
            mols_dihedrals_dict = {}
            for fields in dihedrals:
                assert fields[4] in ['1'], f"Error: dihedral type is not 1. {fields}"
                assert fields[7] == '1', f"Error: dihedral n is not 1. {fields}"
                assert float(fields[5]) > -180 and float(fields[5]) <= 180, f"Error: dihedrals should be in -180 -- +180. {fields}"
                fields = CombineMols._canonical_order(fields, 4)
                key = (int(fields[0]), int(fields[1]), int(fields[2]), int(fields[3]))
                mols_dihedrals_dict.setdefault(key, [state] + fields)
            mols_dihedrals_dict_list.append(mols_dihedrals_dict)
        mols_dihedrals_dict_combined = CombineDict(mols_dihedrals_dict_list)

        mbdihedrals = []
        mbmulti_dihedrals = []
        for key, value in mols_dihedrals_dict_combined.items():
            n_states_in_value = len({fields[0] for fields in value})
            assert n_states_in_value == len(value), f"Error: one state has more than one dihedral for same atoms! {value}"
            if n_states_in_value != n_mols:
                mbmulti_dihedrals.extend(CombineMols._per_state_entries(value, 4, n_mols))
                continue

            # fields = [state, ai, aj, ak, al, functype, phi, k, multiplicity]
            dihedral_list = [float(fields[6]) for fields in value]
            # Calculate_DiffDihedral lives in a later notebook cell; it must have
            # been executed before this method is called.
            diff_dihedral, mean_dihedral = Calculate_DiffDihedral(dihedral_list)
            k_list = [float(fields[7]) for fields in value]

            if diff_dihedral <= cutoff:
                mean_dihedral = round(mean_dihedral, 2)
                mean_k = round(sum(k_list) / len(k_list), 2)
                mbdihedrals.append([str(key[0]), str(key[1]), str(key[2]), str(key[3]), '1', str(mean_dihedral), str(mean_k), '1'])
            else:
                mbmulti_dihedrals.extend(CombineMols._per_state_entries(value, 4, n_mols))
        return mbdihedrals, mbmulti_dihedrals



In [69]:
def Calculate_DiffDihedral(dihedral_list):
    """
    Spread and circular mean of a set of dihedral angles (degrees).

    The angles are compared in two representations:
    - as given, in (-180, 180] (``clock``);
    - with negative values shifted by +360, in [0, 360) (``anticlock``).

    A representation is usable when its spread (max - min) is below 180 degrees,
    i.e. the angles do not straddle its cut point. If both are usable (all angles
    on the same side of 0), they give the same mean, and the [0, 360) one is used.

    Parameters
    ##########
    dihedral_list: list of float
        Dihedral angles in degrees, expected in (-180, 180].

    Return
    ######
    diff_max, mean_dihedral
        Spread of the angles in the chosen representation, and their mean mapped
        back into (-180, 180].

    Raises
    ######
    ValueError
        If ``dihedral_list`` is empty, or the spread is >= 180 in both representations.
    """
    if len(dihedral_list) == 0:
        raise ValueError('Error: dihedral_list is empty')

    clock_dihedral_list = sorted(dihedral_list)
    anticlock_dihedral_list = sorted(d + 360 if d < 0 else d for d in dihedral_list)
    clock_diff_max = abs(clock_dihedral_list[-1] - clock_dihedral_list[0])
    anticlock_diff_max = abs(anticlock_dihedral_list[-1] - anticlock_dihedral_list[0])

    # Accept whichever representation keeps the spread below 180. (Checking both
    # "< 180" cases explicitly fixes nearby same-sign angles such as [10, 20],
    # which used to fall through to the error branch.)
    if anticlock_diff_max < 180:
        chosen = anticlock_dihedral_list
    elif clock_diff_max < 180:
        chosen = clock_dihedral_list
    else:
        raise ValueError(
            f'Error: something wrong with the dihedrals {clock_dihedral_list} '
            f'(spread {anticlock_diff_max} in [0, 360), {clock_diff_max} in (-180, 180])'
        )

    diff_max = abs(chosen[-1] - chosen[0])
    mean_dihedral = sum(chosen) / len(chosen)
    if mean_dihedral > 180:
        mean_dihedral -= 360
    if mean_dihedral <= -180:
        mean_dihedral += 360
    return diff_max, mean_dihedral

In [100]:
# Sanity check: spread and circular mean of three dihedrals (expected (135, 30.0)).
Calculate_DiffDihedral([90, 45, -45])

(135, 30.0)