diff --git a/proteinflow/__init__.py b/proteinflow/__init__.py index 68c48b7..2594065 100644 --- a/proteinflow/__init__.py +++ b/proteinflow/__init__.py @@ -339,8 +339,8 @@ def generate_data( the sequence similarity threshold for excluding chains exclude_clusters : bool, default False if `True`, exclude clusters that contain chains similar to chains in the `exclude_chains` list - exclude_based_on_cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional - if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters + exclude_based_on_cdr : list, optional + if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters (choose from "H1", "H2", "H3", "L1", "L2", "L3") load_ligands : bool, default False if `True`, load ligands from the PDB files exclude_chains_without_ligands : bool, default False @@ -506,8 +506,8 @@ def split_data( the sequence similarity threshold for excluding chains exclude_clusters : bool, default False if `True`, exclude clusters that contain chains similar to chains in the `exclude_chains` list - exclude_based_on_cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional - if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters + exclude_based_on_cdr : list, optional + if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters (choose from "H1", "H2", "H3", "L1", "L2", "L3") random_seed : int, default 42 random seed for reproducibility (set to `None` to use a random seed) exclude_chains_without_ligands : bool, default False diff --git a/proteinflow/cli.py b/proteinflow/cli.py index f78daae..425d5f1 100644 --- a/proteinflow/cli.py +++ b/proteinflow/cli.py @@ -193,6 +193,7 @@ def download(**kwargs): @click.option( "--exclude_based_on_cdr", type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]), + multiple=True, help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters", ) @click.option( @@ -302,6 +303,7 @@ def generate(**kwargs): @click.option( "--exclude_based_on_cdr", type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]), + multiple=True, help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters", ) @click.option( diff --git a/proteinflow/split/__init__.py b/proteinflow/split/__init__.py index a8545b0..8892066 100644 --- a/proteinflow/split/__init__.py +++ b/proteinflow/split/__init__.py @@ -1395,11 +1395,7 @@ def _exclude_biounits( Since `proteinflow` assumes splitting at the level of biounits, when using `exclude_clusters` the dictionaries are adjusted to exclude full biounits that the newly excluded chains / CDRs are part of. This is done by moving the full biounits to the excluded set - and removing the rest of the clusters they belong to from training / test / validation. - - For example, if for antibody Ab CDR H1 is in cluster A, CDR H2 is in cluster B and CDR H3 is in cluster C, and cluster C is in - the excluded set, then clusters A and B are removed from the training / test / validation sets and added to the excluded set with only the Ab CDRs. - The files for the other biounits that are part of the excluded clusters are also moved to the excluded set but not added to the split dictionary. + and removing the corresponding entries from all training / test / validation clusters. """ set_to_exclude = set(excluded_biounits) @@ -1416,8 +1412,9 @@ def _exclude_biounits( for i, chain in enumerate(clusters_dict[cluster]): if chain[0] in set_to_exclude: if exclude_clusters: - if exclude_based_on_cdr is not None and cluster.endswith( - exclude_based_on_cdr + if ( + exclude_based_on_cdr is not None + and cluster.split("__")[-1] in exclude_based_on_cdr ): exclude_whole_cluster = True elif exclude_based_on_cdr is None: @@ -1444,19 +1441,15 @@ def _exclude_biounits( test_clusters_dict, ]: for cluster in list(clusters_dict.keys()): - excluded_biounit_in_cluster = False - if cluster in excluded_clusters_dict: - excluded_biounit_in_cluster = True + to_exclude = [] for i, (file, chain) in enumerate(clusters_dict[cluster]): if file in excluded_biounits: - excluded_biounit_in_cluster = True - excluded_clusters_dict[cluster].append((file, chain)) - excluded_biounits.add(file) - # remove cluster from training / validation / test set if at least one biounit in the cluster is excluded - if exclude_clusters and excluded_biounit_in_cluster: - chains = clusters_dict.pop(cluster) - excluded_biounits.update([x[0] for x in chains]) - excluded_clusters_dict = {k: list(v) for k, v in excluded_clusters_dict.items()} + to_exclude.append(i) + clusters_dict[cluster] = [ + x for i, x in enumerate(clusters_dict[cluster]) if i not in to_exclude + ] + if len(clusters_dict[cluster]) == 0: + clusters_dict.pop(cluster) return ( train_clusters_dict, valid_clusters_dict, @@ -1482,7 +1475,7 @@ def _split_data( A list of files to exclude from the dataset exclude_clusters : bool, default False If True, exclude all files in a cluster if at least one file in the cluster is in `excluded_files` - exclude_based_on_cdr : str, optional + exclude_based_on_cdr : list, optional If not `None`, exclude all files in a cluster if the cluster name does not end with `exclude_based_on_cdr` """