In [1]:
import pandas as pd
from tqdm import tqdm
import os
from core import read_dom_table
import glob
import xml.etree.ElementTree as ET
import networkx as nx
from sklearn.model_selection import train_test_split

In [2]:
def get_xml_file_paths(top_directory):
    pattern = f"{top_directory}/**/*.xml"
    file_paths = glob.glob(pattern, recursive=True)    
    return file_paths

In [3]:
def get_exhchangable_profiles(file_path):
    model = os.path.basename(file_path).split('.')[0]
    tree = ET.parse(os.path.expanduser(file_path))
    root = tree.getroot()
    data = []
    exchangeable_names = []
    parent_names = []
    for gene in root.findall('.//gene'):
        gene_name = gene.attrib.get('name', '')
        exchangeables = gene.findall('.//exchangeables')
        if exchangeables:
            parent_names.append(gene_name)
            for exch in exchangeables:
                exch_genes = exch.findall('.//gene')
                for sub_gene in exch_genes:
                    exch_name = sub_gene.attrib.get('name', '')
                    data.append({'model': model, 'parent_profile': gene_name, 'profile': exch_name})
                    exchangeable_names.append(exch_name)
        data.append({'model': model, 'parent_profile': gene_name, 'profile': gene_name})
    df = pd.DataFrame(data)
    rm_exchangable = [x for x in exchangeable_names if x not in parent_names]
    df = df[~df['parent_profile'].isin(rm_exchangable)].reset_index(drop=True)
    return df


In [4]:
defense_finder_genes_df = pd.read_parquet('../data/interim/defense_finder_genes_genomes.pq')
defense_homolog_df = pd.read_parquet('../data/interim/defense_finder_homologs_profile_names.pq')

In [5]:
true_defense_seq_ids = defense_finder_genes_df['seq_id'].drop_duplicates()
print(len(true_defense_seq_ids))
defense_homolog_seq_ids = defense_homolog_df['seq_id'].drop_duplicates()
print(len(defense_homolog_seq_ids))

241596
2272123


In [6]:
%%time
seq_id_seq_df = pd.read_parquet('../data/interim/refseq_seq_ids.pq')
defense_seq_id_seq_df = seq_id_seq_df[seq_id_seq_df['seq_id'].isin(true_defense_seq_ids)]
defense_homolog_seq_id_seq_df = seq_id_seq_df[seq_id_seq_df['seq_id'].isin(defense_homolog_seq_ids)]
del seq_id_seq_df

CPU times: user 1min 31s, sys: 45.7 s, total: 2min 16s
Wall time: 2min 42s


In [7]:
%%time
all_clusters = pd.read_table('../data/interim/refseqs_clusters_mode1.tsv',
                             names=['cluster_id', 'seq_id'])
defense_clusters = (all_clusters[all_clusters['seq_id'].isin(true_defense_seq_ids)]
                    .reset_index(drop=True))
defense_homolog_clusters = (all_clusters[all_clusters['seq_id'].isin(defense_homolog_seq_ids)])
del all_clusters

CPU times: user 1min 41s, sys: 5.36 s, total: 1min 46s
Wall time: 1min 46s


## Get exchangable genes

In [8]:
definition_file_paths = get_xml_file_paths(os.path.expanduser('~/defense-finder-models/definitions/'))

In [9]:
len(definition_file_paths)

229

In [10]:
gene_group_list = []
for f in tqdm(definition_file_paths):
    gene_group_list.append(get_exhchangable_profiles(f))
gene_group_df = pd.concat(gene_group_list)

100%|██████████| 229/229 [00:02<00:00, 84.35it/s] 


In [11]:
all_genes = gene_group_df['profile'].drop_duplicates().to_list()
len(all_genes)

1018

In [12]:
gene_group_edges = (gene_group_df.merge(gene_group_df, how='inner', on='parent_profile')
                    .query('profile_x != profile_y')
                    [['profile_x', 'profile_y']]
                    .drop_duplicates()
                    .rename(columns={'profile_x': 'gene_name_x', 
                                     'profile_y': 'gene_name_y'}))

## Get overlapping gene clusters

In [13]:
seq_id_gene_df = (defense_finder_genes_df[['seq_id', 'gene_name']].drop_duplicates()
                  .merge(defense_clusters, how='inner', on='seq_id'))

In [14]:
gene_clust_edges = (seq_id_gene_df.merge(seq_id_gene_df, how='inner', on='cluster_id')
                    .query('gene_name_x != gene_name_y')
                    [['gene_name_x', 'gene_name_y']]
                    .drop_duplicates())

In [15]:
gene_clust_network = nx.from_pandas_edgelist(gene_clust_edges, 'gene_name_x', 'gene_name_y')

In [16]:
G = nx.from_pandas_edgelist(gene_clust_edges, 'gene_name_x', 'gene_name_y')
nodes = list(G.nodes)
for gene_name in all_genes:
    if gene_name not in nodes:
        G.add_node(gene_name)
print('# Edges:', G.number_of_edges())
print('# Nodes:', G.number_of_nodes())
components = nx.connected_components(G)
component_list = []
for i, comp in enumerate(components):
    subgraph = G.subgraph(comp)
    max_degree_node = max(subgraph, key=lambda node: subgraph.degree(node))
    for name in comp:
        component_list.append({'name': name, 
                               'component': i, 
                               'component_name': max_degree_node})
print('# Components:', i)
gene_component_df = pd.DataFrame(component_list)
gene_component_df['component_name'].value_counts().head(20)

# Edges: 444
# Nodes: 1018
# Components: 703


component_name
Lamassu-Fam__LmuB_SMC_Cap4_nuclease_II         99
Cas__csm3gr7_III_IV                            16
Cas__cas5_I_9                                  13
Cas__cas1_I_II_III_IV_V_VI_1                   13
RosmerTA__RmrA_2662548665                      12
Lamassu-Fam__LmuA_effector_Cap4_nuclease_II     9
Cas__cas6_I_II_III_IV_V_VI_7                    8
Cas__cas2_I_II_III_IV_V_VI_3                    6
Cas__cas7_I_15                                  5
Cas__cas8b2_I-B_2                               5
Cas__csm6_III_3                                 5
Cas__cse2gr11_I-E_8                             5
Cas__cas10_III_1                                5
Septu__PtuB                                     4
Cas__cmr5gr11_III-B_18                          4
Cas__csm2gr11_III-A_15                          4
Cas__cas2_I-E_1                                 4
Thoeris_I__ThsA_new_grand                       4
Cas__cas5u_I-G_1                                4
CBASS__Cyclase_II                  

## Cluster genes

In [18]:
cat_edges = (pd.concat([gene_group_edges, gene_clust_edges])
             .drop_duplicates())
G = nx.from_pandas_edgelist(cat_edges, 'gene_name_x', 'gene_name_y')
nodes = list(G.nodes)
for gene_name in all_genes:
    if gene_name not in nodes:
        G.add_node(gene_name)
print('# Edges:', G.number_of_edges())
print('# Nodes:', G.number_of_nodes())
components = nx.connected_components(G)
component_list = []
for i, comp in enumerate(components):
    subgraph = G.subgraph(comp)
    max_degree_node = max(subgraph, key=lambda node: subgraph.degree(node))
    for name in comp:
        component_list.append({'name': name, 
                               'component': i, 
                               'component_name': max_degree_node})
print('# Components:', i)
gene_component_df = pd.DataFrame(component_list)
gene_component_df['component_name'].value_counts().head(20)

# Edges: 8524
# Nodes: 1018
# Components: 278


component_name
Lamassu-Fam__LmuB_SMC_Cap4_nuclease_II    168
Cas__cas8b2_I-B_2                          99
Cas__csm3gr7_III-D_3                       50
Menshen__NsnC_2525400098                   45
Cas__csa5gr11_III-B_1                      37
Cas__cas5_I-A_4                            36
RosmerTA__RmrT_2600853143                  33
Cas__csx1_III_17                           27
Cas__cas6_I_II_III_IV_V_VI_10              27
Cas__cas12f1_V-F_2                         26
Cas__cas7_I-B_1                            22
Cas__csx22_III-A_1                         19
Cas__csx19_III-D_18                        19
RosmerTA__RmrA_2603008502                  16
Cas__cas10_III-C_3                         14
Cas__cas2_I-E_1                            13
Cas__csn2_II-A_5                           13
Cas__cas10d_I-D_1                           8
Lamassu-Fam__LmuC_acc_Sir2                  8
Cas__cas9_II-B_2                            6
Name: count, dtype: int64

## Split clusters

In [19]:
test_size = 0.15
val_size = 0.05
unique_components = list(gene_component_df['component_name'].unique())
train_val_components, test_components = train_test_split(unique_components, test_size=test_size, random_state=7)
train_components, val_components = train_test_split(train_val_components, test_size=val_size/(1-test_size), random_state=7)

In [20]:
len(val_components)/len(unique_components)

0.05017921146953405

In [21]:
def assign_profile(component, test_components=test_components, train_components=train_components, val_components=val_components):
    if component in val_components:
        return 'val'
    elif component in test_components:
        return 'test'
    else:
        return 'train'

In [22]:
gene_component_df['split'] = gene_component_df['component_name'].apply(assign_profile)

In [23]:
gene_component_df['split'].value_counts()

split
train    882
test      99
val       37
Name: count, dtype: int64

In [24]:
gene_component_df.groupby('split')['component'].nunique()

split
test      42
train    223
val       14
Name: component, dtype: int64

In [25]:
pd.set_option('display.max_rows', 200)
gene_component_df[gene_component_df['split'] == 'test']

Unnamed: 0,name,component,component_name,split
213,Wadjet__JetA_II,2,Wadjet__JetA_II,test
214,Wadjet__JetA_I,2,Wadjet__JetA_II,test
215,Wadjet__JetA_III,2,Wadjet__JetA_II,test
216,Wadjet__JetB_III,3,Wadjet__JetB_III,test
217,Wadjet__JetB_I,3,Wadjet__JetB_III,test
218,Wadjet__JetB_II,3,Wadjet__JetB_III,test
371,Cas__csx10gr5_III-D_5,22,Cas__csm3gr7_III-D_3,test
372,Cas__csm3gr7_III_IV,22,Cas__csm3gr7_III-D_3,test
373,Cas__cmr6gr7_III-B_2,22,Cas__csm3gr7_III-D_3,test
374,Cas__csm3gr7_III-D_5,22,Cas__csm3gr7_III-D_3,test


## Split seq ids into train/test/val

In [73]:
max_seqs_per_cluster = 5
filtered_defense_finder_genes = (defense_finder_genes_df
                                 .merge(defense_clusters, how='inner', 
                                        on='seq_id')
                                 .merge(gene_component_df
                                        .rename(columns={'name': 'gene_name'}), 
                                        on='gene_name')
                                 .groupby('seq_id')
                                 .sample(n=1, random_state=7)
                                 .sample(frac=1, random_state=7)  # shuffle
                                 .groupby('cluster_id')
                                 .head(max_seqs_per_cluster)
                                 .merge(defense_seq_id_seq_df, how='inner', 
                                        on='seq_id'))

In [60]:
assert filtered_defense_finder_genes['seq_id'].value_counts().max() == 1
assert filtered_defense_finder_genes.groupby('cluster_id')['split'].nunique().max() == 1

In [63]:
from IPython.display import display

In [65]:
for split, split_df in filtered_defense_finder_genes.groupby('split'):
    print(split)
    display(split_df['gene_name'].value_counts().head(20))

test


gene_name
Wadjet__JetB_I                720
Wadjet__JetA_I                594
ShosTA__ShosT                 400
Rst_PARIS__DUF4435            376
Cas__cmr1gr7_III-B_1          318
Cas__csm3gr7_III-D_2          318
ShosTA__ShosA                 302
Wadjet__JetB_II               301
Cas__cmr3gr5_III-B_III-C_5    293
DRT_2__drt2                   279
Cas__cmr6gr7_III-B_3          252
Cas__csm3gr7_III-D_3          241
Wadjet__JetA_II               238
PD-Lambda-5__PD-Lambda-5_A    230
Cas__csm5gr7_III-A_2          224
PsyrTA__PsyrA                 220
Cas__cmr6gr7_III-B_2          220
Wadjet__JetA_III              217
SpbK__SpbK                    199
Wadjet__JetB_III              196
Name: count, dtype: int64

train


gene_name
RM_Type_IV__Type_IV_REases                7006
RM_Type_II__Type_II_REases                4437
RM__Type_I_S                              4199
RM_Type_II__Type_II_MTases                3110
RM_Type_IIG__Type_IIG                     2800
AbiD__AbiD                                1318
RM__Type_I_REases                         1280
Septu__PtuA                               1252
RloC__RloC                                1162
Septu__PtuB                               1155
RM__Type_I_MTases                         1103
Shedu__SduA                                940
Cas__cas3_I_5                              936
Gabija__GajA                               898
CBASS__Cyclase_II                          870
Mokosh_TypeI__MkoA                         857
Cas__cas2_I_II_III_IV_V_VI_3               815
Gabija__GajB_2                             807
Lamassu-Fam__LmuB_SMC_Cap4_nuclease_II     784
RM_Type_III__Type_III_REases               734
Name: count, dtype: int64

val


gene_name
AbiEii__AbiEii         809
SoFic__SoFic           573
Cas__cas10_III-D_3     381
Cas__cas10_III_1       343
Kiwa__KwaA             313
SanaTA__SanaT          306
Cas__cas10_III-A_1     303
SspBCDE__SspB          204
SspBCDE__SspC          200
PD-T4-3__PD-T4-3       137
Druantia_III__DruH     132
BREX__brxF             110
Cas__cas10_III-B_12     74
Cas__cas10_III-C_3      58
SspBCDE__SspD           57
Cas__cas10_III_6        56
BREX__brxP              54
Bunzi__BnzB             52
Cas__csb3_I-G_1         44
Cas__cas10d_I-D_3       42
Name: count, dtype: int64

In [2]:
import pandas as pd
filtered_defense_finder_genes = pd.read_parquet('../data/interim/defense_finder_model_seqs.pq')

In [3]:
filtered_defense_finder_genes['split'].value_counts()

split
train    99642
test     10128
val       4434
Name: count, dtype: int64

## Outputs

In [26]:
gene_component_df.to_csv('../data/interim/defense_finder_gene_split_df.csv', index=False)

In [66]:
filtered_defense_finder_genes.to_parquet('../data/interim/defense_finder_model_seqs.pq', index=False)
