<a href="https://colab.research.google.com/github/alex-tianhuang/idrfeatlib/blob/main/notebooks/FeatureMimic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Setup environment by downloading and installing the idrfeatlib repo.
#
# You only need to run this once.
!file idrfeatlib/ >/dev/null && rm -rf idrfeatlib
!git clone https://github.com/alex-tianhuang/idrfeatlib --quiet
%pip install idrfeatlib/

Processing ./idrfeatlib
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: idrfeatlib
  Building wheel for idrfeatlib (setup.py) ... [?25l[?25hdone
  Created wheel for idrfeatlib: filename=idrfeatlib-0.0.0-py3-none-any.whl size=31311 sha256=d818c5f20b8dc9f239f2daa474f5c1edca26975c4ec16c8b9b010e078e577e2b
  Stored in directory: /tmp/pip-ephem-wheel-cache-lilw9kpa/wheels/fd/73/2d/1c5cfad6d18b968112550482fad0b8617d5ba1997a480307db
Successfully built idrfeatlib
Installing collected packages: idrfeatlib
  Attempting uninstall: idrfeatlib
    Found existing installation: idrfeatlib 0.0.0
    Uninstalling idrfeatlib-0.0.0:
      Successfully uninstalled idrfeatlib-0.0.0
Successfully installed idrfeatlib-0.0.0


In [None]:
# Define/prepare design scripts.
#
# Functions `design_all` and `main` have been adapted from
# `idrfeatlib/scripts/feature-mimic.py`.

def design_all(tasks):
    """
    Run the feature-mimic designer on every task and append results to a CSV.

    Each task is a tuple
    ``(query, target, protid, regionid, designid, seed, designer, colnames, args)``
    as built by `main`. For each task the target sequence is featurized and
    installed as ``designer.metric.origin``; then ``designer.design_loop`` is
    run starting from `query`. When `query` is None, a random starting
    sequence of the same length as `target` is drawn from ``designer.rng``
    (seeded with `seed`), retrying up to 15 times until every feature can be
    computed.

    Output rows are appended to ``args.output_file`` (columns `colnames`)
    after each task finishes. With ``args.keep_trajectory`` every yielded
    progress row is written; otherwise only the last one, without its
    "Iteration" entry.

    A task that fails (target cannot be featurized, no featurizable random
    query could be generated, the design loop raises one of the acceptable
    errors, or the design loop yields nothing) is reported on stderr and
    skipped. The remaining tasks still run.
    """
    import tqdm
    import sys
    import csv
    from idrfeatlib.featurizer import Featurizer
    acceptable_errors = (ArithmeticError, ValueError, KeyError)
    # Loop invariants, hoisted out of the per-task loop.
    AMINOACIDS = list("ACDEFGHIKLMNPQRSTVWY")
    MAX_RETRIES = 15
    for query, target, protid, regionid, designid, seed, designer, colnames, args in tqdm.tqdm(tasks, desc="designing sequences"):
        featurizer = Featurizer(designer.featurizer)
        try:
            designer.metric.origin, _ = featurizer.featurize(target, acceptable_errors=())
        except acceptable_errors as e:
            print("cannot featurize target (protid=%s,regionid=%s): %s" % (protid, regionid, e), file=sys.stderr)
            # Skip only this task; the remaining tasks are still designed.
            continue
        # `seed` is None for tasks read from a query file; presumably that
        # reseeds from system entropy (random.Random semantics) — confirm for
        # the designer's rng type.
        designer.rng.seed(seed)
        if query is None:
            for _ in range(MAX_RETRIES):
                try_query = "".join(designer.rng.choice(AMINOACIDS) for _ in range(len(target)))
                try:
                    featurizer.featurize(try_query, acceptable_errors=())
                except acceptable_errors:
                    continue
                query = try_query
                break
            else:
                # Reached only when every retry failed; `continue` here moves
                # on to the next task (this `else` is outside the retry loop).
                print("cannot generate query with all features (protid=%s,regionid=%s,length=%d,seed=%d)" % (protid, regionid, len(target), seed), file=sys.stderr)
                continue
        try:
            if args.keep_trajectory:
                save = list(designer.design_loop(query, acceptable_errors=acceptable_errors))
            else:
                # Reset per task so a design loop that yields nothing cannot
                # reuse the previous task's final row.
                progress = None
                for progress in designer.design_loop(query, acceptable_errors=acceptable_errors):
                    progress.pop("Iteration")
                save = [] if progress is None else [progress]
        except acceptable_errors:
            # `%s` rather than `%d`: seed is None for query-file tasks.
            print("query did not have all features (protid=%s,regionid=%s,seed=%s)" % (protid, regionid, seed), file=sys.stderr)
            continue
        if not save:
            print("designer produced no output (protid=%s,regionid=%s,seed=%s)" % (protid, regionid, seed), file=sys.stderr)
            continue
        # Append per task so completed work survives an interruption; `main`
        # skips tasks already present in the file on the next run.
        with open(args.output_file, "a", newline="") as file:
            writer = csv.DictWriter(file, colnames)
            for row in save:
                row["ProteinID"] = protid
                row["DesignID"] = designid
                if regionid is not None:
                    row["RegionID"] = regionid
                if args.save_seed:
                    row["Seed"] = seed
                writer.writerow(row)


def _slice_region(fa, regions, protid, regionid):
    """
    Return ``(target, start, stop)`` for one region, or None if it is unusable.

    None is returned when the protein is absent from `fa`, when the region is
    absent from `regions` (e.g. it was dropped by the length filter or never
    listed), or when the region bounds run past the end of the sequence. The
    last case is also reported on stderr.
    """
    import sys
    target_whole = fa.get(protid)
    if target_whole is None:
        return None
    bounds = regions.get(protid, {}).get(regionid)
    if bounds is None:
        return None
    start, stop = bounds
    target = target_whole[start:stop]
    # A short slice means the region extends past the sequence end.
    if len(target) != stop - start:
        print("invalid region `%s` for protein `%s` (start=%d,stop=%d,seqlen=%s)" % (regionid, protid, start, stop, len(target_whole)), file=sys.stderr)
        return None
    return target, start, stop


def main(args):
    """
    Build the design tasks described by `args` and run them via `design_all`.

    Adapted from `idrfeatlib/scripts/feature-mimic.py`. `args` is an
    `argparse.Namespace` (see `run_colab_wrapper`) providing: input_sequences,
    input_regions, feature_file, feature_weights_file, weights_feature_vector,
    query_file, seeds_file, n_random, design_id, keep_trajectory, save_seed,
    greedy, output_file.

    Targets are whole FASTA sequences or, when `args.input_regions` is given,
    the listed region slices. Targets shorter than 30 residues are dropped.
    Starting queries come from `args.query_file`. Otherwise they are drawn at
    random from per-task seeds, read from `args.seeds_file` or generated
    (`args.n_random` per target, default 1).

    If `args.output_file` already exists with matching columns, tasks whose
    (ProteinID[, RegionID], DesignID) already appear in it are skipped.

    Raises:
        ValueError: the feature-weights file could not be parsed.
        RuntimeError: the requested weights vector is missing, or its
            features differ from the featurizer's.
        SystemExit: the existing output file has different column names.
    """
    from idrfeatlib import FeatureVector
    from idrfeatlib.utils import read_nested_csv, iter_nested, read_fasta, read_regions_csv
    from idrfeatlib.featurizer import compile_featurizer
    from idrfeatlib.native import compile_native_featurizer
    from idrfeatlib.metric import Metric
    from idrfeatlib.designer import FeatureDesigner, GreedyFeatureDesigner
    import os
    import json
    import sys
    import random
    import csv

    LENGTH_THRESHOLD = 30
    SEED_COLNAME = "Seed"
    MAX_SEED = 2 ** 64
    CONVERGENCE_THRESHOLD = 1e-4
    GOOD_MOVES_THRESHOLD = 3
    DECENT_MOVES_THRESHOLD = 5
    QUERY_COLNAME = "Sequence"

    # Locate the weights vector. Parse errors are re-raised with the file
    # name attached, because a bare "could not convert string to float" does
    # not say which input file is malformed.
    try:
        for label, feature_vector in FeatureVector.load(args.feature_weights_file):
            if label == args.weights_feature_vector:
                break
        else:
            raise RuntimeError("could not find feature vector `%s` in %s" % (args.weights_feature_vector, args.feature_weights_file))
    except ValueError as err:
        raise ValueError("could not parse feature weights file `%s`: %s" % (args.feature_weights_file, err)) from err
    # `design_all` overwrites `metric.origin` per task.
    metric = Metric(feature_vector, feature_vector)

    if args.feature_file:
        with open(args.feature_file, "r") as file:
            config = json.load(file)
        featurizer, errors = compile_featurizer(config)
    else:
        featurizer, errors = compile_native_featurizer()
    for featname, error in errors.items():
        print("error compiling `%s`: %s" % (featname, error), file=sys.stderr)
    if featurizer.keys() != metric.weights.as_dict.keys():
        raise RuntimeError("featurizer and metric feature vector `%s` have different features" % args.weights_feature_vector)

    if args.greedy:
        designer = GreedyFeatureDesigner(featurizer, metric, convergence_threshold=CONVERGENCE_THRESHOLD)
        designer.rng = random.Random() # type: ignore
    else:
        # NOTE(review): `covergence_threshold` is spelled differently from the
        # `convergence_threshold` passed to GreedyFeatureDesigner above. Confirm
        # it matches FeatureDesigner's actual keyword argument.
        designer = FeatureDesigner(featurizer, metric, covergence_threshold=CONVERGENCE_THRESHOLD, good_moves_threshold=GOOD_MOVES_THRESHOLD, decent_moves_threshold=DECENT_MOVES_THRESHOLD, rng=random.Random())

    fa = dict(read_fasta(args.input_sequences))
    featnames = featurizer.keys()
    use_regions = args.input_regions is not None

    # Output columns; the order must match any existing output file.
    colnames = ["ProteinID"]
    if use_regions:
        colnames.append("RegionID")
    colnames += ["DesignID", "Sequence", "Time"]
    if args.save_seed:
        colnames.append("Seed")
    if args.keep_trajectory:
        colnames.append("Iteration")
    colnames += featnames

    n_random = args.n_random or 1
    seed_rng = random.Random()
    tasks = []

    if not use_regions:
        fa = {protid: seq for protid, seq in fa.items() if len(seq) >= LENGTH_THRESHOLD}
        if args.query_file is None:
            if args.seeds_file is None:
                seeds = {protid: [seed_rng.randint(0, MAX_SEED) for _ in range(n_random)] for protid in fa}
            else:
                seeds = read_nested_csv(args.seeds_file, 1, group_multiple=True)
                seeds = {protid: [row[SEED_COLNAME] for row in rows] for protid, rows in seeds.items() if protid in fa}
            for protid, prot_seeds in seeds.items():
                target = fa.get(protid)
                if target is None:
                    continue
                for counter, seed in enumerate(prot_seeds):
                    seed = int(seed)
                    design_id = args.design_id.format(counter=counter, seed=seed, proteinid=protid)
                    tasks.append(
                        (None, target, protid, None, design_id, seed, designer, colnames, args)
                    )
        else:
            qries = read_nested_csv(args.query_file, 2)
            for protid, designid, row in iter_nested(qries, 2):
                target = fa.get(protid)
                if target is None:
                    continue
                tasks.append(
                    (row[QUERY_COLNAME], target, protid, None, designid, None, designer, colnames, args)
                )
    else:
        regions = read_regions_csv(args.input_regions)
        regions = {protid: ret for protid, entry in regions.items() if (ret := {regionid: (start, stop) for regionid, (start, stop) in entry.items() if stop - start >= LENGTH_THRESHOLD})}
        if args.query_file is None:
            if args.seeds_file is None:
                seeds = {protid: {regionid: [seed_rng.randint(0, MAX_SEED) for _ in range(n_random)] for regionid in entry} for protid, entry in regions.items() if protid in fa}
            else:
                seeds = read_nested_csv(args.seeds_file, 2, group_multiple=True)
                seeds = {protid: {regionid: [row[SEED_COLNAME] for row in rows] for regionid, rows in entry.items()} for protid, entry in seeds.items() if protid in fa}
            for protid, regionid, region_seeds in iter_nested(seeds, 2):
                # Seeds may name regions missing from (or filtered out of)
                # `regions`; those are skipped rather than raising KeyError.
                resolved = _slice_region(fa, regions, protid, regionid)
                if resolved is None:
                    continue
                target, start, stop = resolved
                for counter, seed in enumerate(region_seeds):
                    seed = int(seed)
                    design_id = args.design_id.format(counter=counter, seed=seed, proteinid=protid, regionid=regionid, start=start, stop=stop)
                    tasks.append(
                        (None, target, protid, regionid, design_id, seed, designer, colnames, args)
                    )
        else:
            qries = read_nested_csv(args.query_file, 3)
            for protid, regionid, designid, row in iter_nested(qries, 3):
                resolved = _slice_region(fa, regions, protid, regionid)
                if resolved is None:
                    continue
                target, _, _ = resolved
                tasks.append(
                    (row[QUERY_COLNAME], target, protid, regionid, designid, None, designer, colnames, args)
                )

    # Read any existing output (closing it before it is rewritten) to find
    # which tasks are already done.
    existing_fieldnames = None
    checkpoint_rows = []
    if os.path.exists(args.output_file):
        with open(args.output_file, "r", newline="") as file:
            reader = csv.DictReader(file)
            existing_fieldnames = reader.fieldnames
            if existing_fieldnames is not None:
                if existing_fieldnames != colnames:
                    BOLD_RED = "\033[1;31m"
                    NORMAL = "\033[0m"
                    print(BOLD_RED + "cannot overwrite file `%s` with different column names (shown below):" % args.output_file + NORMAL, file=sys.stderr)
                    print(",".join(colnames), file=sys.stderr)
                    sys.exit(1)
                if args.keep_trajectory:
                    # Only trajectories that reached their END row count as done.
                    checkpoint_rows = [
                        row for row in reader if row.pop("Iteration") == "END"
                    ]
                else:
                    checkpoint_rows = list(reader)
    if existing_fieldnames is None:
        # Missing or empty output file: write a fresh header and run everything.
        with open(args.output_file, "w", newline="") as file:
            csv.DictWriter(file, colnames).writeheader()
        design_all(tasks)
        return

    # Task tuple indices: 2 = protid, 3 = regionid, 4 = designid.
    checkpoint_keys = ["ProteinID", "DesignID"]
    task_key_indices = [2, 4]
    if use_regions:
        checkpoint_keys.insert(1, "RegionID")
        task_key_indices.insert(1, 3)
    done = {tuple(row[key] for key in checkpoint_keys) for row in checkpoint_rows}
    tasks_not_done = [
        task for task in tasks
        if tuple(task[index] for index in task_key_indices) not in done
    ]
    design_all(tasks_not_done)


def display_csv(output_name):
    """
    Render the CSV file at `output_name` as a table in the notebook.

    Depends on pandas being importable, which Colab provides out of the box.
    """
    from IPython.display import display
    import pandas as pd

    table = pd.read_csv(output_name)

    banner = ("", "Showing output below", "--------------------", "")
    for line in banner:
        print(line)
    display(table)
    print()
    print()

def _ask_yes_no(prompt):
    """Prompt the user; True only for a `y` answer (case- and whitespace-insensitive)."""
    return input(prompt).strip().lower() == 'y'


def _upload_unless_kept(files, filename):
    """
    Upload `filename` through Colab, unless it already exists and the user
    declines to overwrite it (in which case the existing file is reused).

    Returns True if an upload was performed.
    """
    import os
    if os.path.exists(filename) and not _ask_yes_no(f"The file {filename} already exists. Would you like to overwrite it? (y/n)"):
        return False
    files.upload_file(filename)
    return True


def run_colab_wrapper(output_name):
    """
    Interactively gather inputs and parameters in Colab, run `main`, then show
    and download the resulting CSV.

    Args:
        output_name: path of the output CSV. If it already exists, the user
            chooses between deleting it ("w") and appending to it. When
            appending, `main` skips designs already present in the file.
    """
    import argparse
    import os
    from google.colab import files

    args = argparse.Namespace()

    args.input_sequences = 'input_sequences.fasta'
    _upload_unless_kept(files, args.input_sequences)

    if _ask_yes_no("Would you like to upload a file containing region boundaries? (y/n)"):
        args.input_regions = 'input_regions.csv'
        files.upload_file(args.input_regions)
    else:
        args.input_regions = None

    # Unlike the other inputs, declining to overwrite an existing feature
    # config means "don't use it" (so `main` uses the native featurizer),
    # not "reuse it".
    args.feature_file = 'feature_config.json'
    if os.path.exists(args.feature_file):
        use_feature_file = _ask_yes_no(f"The file {args.feature_file} already exists. Would you like to overwrite it? (y/n)")
        if not use_feature_file:
            print(f"Ignoring {args.feature_file}")
    else:
        use_feature_file = _ask_yes_no("Would you like to upload a file containing feature configuration? (y/n)")
    if use_feature_file:
        files.upload_file(args.feature_file)
    else:
        args.feature_file = None

    args.feature_weights_file = 'feature_weights.csv'
    _upload_unless_kept(files, args.feature_weights_file)

    # Seed arguments (simplified for Colab)
    print("Choose how to generate query sequences:")
    print("1. Randomly sample sequences (provide number)")
    print("2. Use seeds from a file (provide file)")
    print("3. Provide query sequences in a file")
    choice = input("Enter choice (1, 2, or 3): ").strip()

    args.n_random = None
    args.seeds_file = None
    args.query_file = None

    if choice == '1':
        raw_count = input("Enter the number of random sequences per region/protein: ").strip()
        try:
            n_random = int(raw_count)
        except ValueError:
            n_random = 0
        # Zero or negative counts would silently produce no designs.
        if n_random < 1:
            print("Invalid number. Defaulting to 1 random sequence.")
            n_random = 1
        args.n_random = n_random
    elif choice == '2':
        args.seeds_file = 'seeds.csv'
        print(f"Please upload the seeds file ({args.seeds_file}):")
        files.upload_file(args.seeds_file)
    elif choice == '3':
        args.query_file = 'query_sequences.csv'
        print(f"Please upload the query sequences file ({args.query_file}):")
        files.upload_file(args.query_file)
    else:
        print("Invalid choice. Defaulting to 1 random sequence per region/protein.")
        args.n_random = 1

    if _ask_yes_no("Would you like to adjust other parameters? Press `n` for defaults (y/n)"):
        args.weights_feature_vector = input("Enter the label for the weights feature vector (default: 'weights'): ").strip() or "weights"
        args.keep_trajectory = _ask_yes_no("Keep full trajectory? (y/n): ")
        args.save_seed = _ask_yes_no("Save seed in output? (y/n): ")
        args.design_id = input("Enter design ID format string (default: '{counter}'): ") or "{counter}"
        args.greedy = _ask_yes_no("Use greedy optimization? (y/n): ")
    else:
        args.weights_feature_vector = "weights"
        args.keep_trajectory = False
        args.save_seed = False
        args.design_id = "{counter}"
        args.greedy = False

    args.output_file = output_name
    if os.path.exists(args.output_file):
        overwrite = input(f"Output file '{args.output_file}' already exists. Overwrite (w) or append (a)? (w/a): ").strip().lower()
        if overwrite == 'w':
            os.remove(args.output_file)
            print(f"Output file '{args.output_file}' overwritten.")
        else:
            print("Output file exists and not overwriting. Will append to it.")

    print("Starting design task...")
    main(args)
    print("Design task finished.")
    display_csv(args.output_file)
    print(f"Downloading output file to {args.output_file}")
    files.download(args.output_file)


In [None]:
# This cell will:
#
# 1. Ask for the input files and parameters:
#    - (required) a FASTA file of sequences
#    - (optional) a CSV file of region boundaries, see extra section below
#    - (optional) a JSON file defining custom features, see extra section below
#    - (required) a CSV file of feature weights,
#                 for either the default features or the custom features
#    - how to generate starting query sequences: random sampling, a seeds
#      file, or a file of query sequences
#    - (optional) five other parameters: the weights feature-vector label,
#      whether to keep the full trajectory, whether to save seeds, the design
#      ID format string, and whether to use greedy optimization
# 2. Design sequences, appending results to `output_features.csv` as each
#    design finishes
# 3. Display the output table and download `output_features.csv`
#
# After running the cells above, you can run this cell as many times as you like.

OUTPUT_CSV = "output_features.csv"  # file that designs are written to and downloaded from
run_colab_wrapper(OUTPUT_CSV)
print("Done!")


The file input_sequences.fasta already exists. Would you like to overwrite it? (y/n)n
Would you like to upload a file containing region boundaries? (y/n)n
Would you like to upload a file containing feature configuration? (y/n)n
The file feature_weights.csv already exists. Would you like to overwrite it? (y/n)n
Choose how to generate query sequences:
1. Randomly sample sequences (provide number)
2. Use seeds from a file (provide file)
3. Provide query sequences in a file
Enter choice (1, 2, or 3): 1
Enter the number of random sequences per region/protein: 10
Would you like to adjust other parameters? Press `n` for defaults (y/n)n
Starting design task...


ValueError: could not convert string to float: 'PAK1'