STREME uses the ZOOPS (zero or one occurrence per sequence) model. For sensitivity experiments, this allows STREME to
work with diluted data sets in which only a fraction of the input sequences actually contain the desired motif(s).

Our loss function quickly begins to struggle with diluted input data. Try to find a loss function that does not
penalize input sequences that contain no good motif.

In [1]:
import logging
from pathlib import Path
import sys

from modules import ModelDataSet, plotting, SequenceRepresentation

# Set up logging: timestamped "<time> <level>: <message>" lines, DEBUG and above
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s: %(message)s",
    encoding='utf-8',
)
# getLogger() without a name returns the root logger
logger = logging.getLogger()

2025-02-21 13:53:26.511307: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-21 13:53:26.536872: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-21 13:53:26.783493: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-21 13:53:27.731350: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Settings used by the cells below
SEED = 42       # handed to the training call as rand_seed
MAXSEQS = 100   # keep at most this many sequences from the input fasta (None disables the limit)
mode = "DNA"    # "DNA" selects DataMode.DNA; any other value selects a translated data mode
phylo_t = 0     # only read when mode is not "DNA": 0 -> Translated, otherwise Translated_noStop

In [3]:
# load data

# Input fasta file; the relative path is resolved against the current working directory
fasta = Path('../data/STREME_benchmark_data/diluted_ds_primary/1.00/wgEncodeAwgTfbsBroadK562CtcfUniPk.narrowPeak.centered100bp.1.00pure.fasta')
# Raise explicitly instead of using `assert`: assert statements are removed when Python runs
# with -O, which would let a missing input file slip through to the loader.
if not fasta.is_file():
    raise FileNotFoundError(f"[ERROR] >>> Input file '{fasta}' not found")


# Choose the data mode: DNA mode maps directly to DataMode.DNA; for any other mode the
# translated variant is picked depending on phylo_t.
if mode == 'DNA':
    datamode = ModelDataSet.DataMode.DNA
elif phylo_t == 0:
    datamode = ModelDataSet.DataMode.Translated
else:
    # A non-zero phylo_t switches to the translated mode without stop codon; tell the user
    logging.warning(
        "[main] Phylo_t is not 0.0 and data mode is set to 'Translated'. Setting data mode to "
        "'Translated_noStop', using only 20-letter aa alphabet without stop codon."
    )
    datamode = ModelDataSet.DataMode.Translated_noStop

logging.info(f"[main] Data mode: {datamode}")

# === LOAD DATA ===

logging.info("[main] Loading sequences")
sequences = SequenceRepresentation.loadFasta_agnostic(fasta)

# If a limit is set and the fasta holds more sequences than that, keep only the first MAXSEQS
if MAXSEQS is not None and len(sequences) > MAXSEQS:
    logging.info(f"[main] Limiting data to {MAXSEQS}/{len(sequences)} sequences from the input fasta")
    sequences = sequences[:MAXSEQS]

# Each sequence becomes its own Genome, built from a one-element sequence list
genomes = [SequenceRepresentation.Genome([seq]) for seq in sequences]

2025-02-21 13:53:41,524 INFO: [main] Data mode: DataMode.DNA
2025-02-21 13:53:41,524 INFO: [main] Loading sequences
2025-02-21 13:53:42,305 INFO: [main] Limiting data to 100/51992 sequences from the input fasta


In [4]:
# Display the first five entries of `genomes`
genomes[:5]

[Genome from chr22:22901677-22901578 with 1 chromosomes: 0 (1 sequence[s]),
 Genome from chr22:20918741-20918642 with 1 chromosomes: 0 (1 sequence[s]),
 Genome from chr22:23285067-23284968 with 1 chromosomes: 0 (1 sequence[s]),
 Genome from chr22:21239921-21239822 with 1 chromosomes: 0 (1 sequence[s]),
 Genome from chr22:23624139-23624040 with 1 chromosomes: 0 (1 sequence[s])]

In [None]:
# NOTE(review): this cell has no execution count, i.e. it was not run. `args`, `ProfileFindingSetup`,
# `training`, `evaluator` and `outdir` are neither imported nor defined in the cells shown above,
# so they have to be provided before this cell can execute -- TODO confirm where they should come from.

# Wrap the loaded genomes in a ModelDataSet using the data mode selected earlier
data = ModelDataSet.ModelDataSet(genomes, datamode,
                                 tile_size=args.tile_size, tiles_per_X=args.tiles_per_X,
                                 batch_size=args.batch_size, prefetch=args.prefetch)

# Training setup: `epochs` is fixed to 350 here, every other hyper-parameter is read from `args`.
# NOTE(review): `phylo_t` is taken from `args` although the notebook also defines its own `phylo_t`
# that drives the data mode selection -- TODO confirm both are meant to hold the same value.
trainsetup = ProfileFindingSetup.ProfileFindingTrainingSetup(data,
                                                             U=args.U, k=args.k,
                                                             midK=args.midK, s=args.s,
                                                             epochs=350, gamma=args.gamma, l2=args.l2,
                                                             match_score_factor=args.match_score_factor,
                                                             learning_rate=args.learning_rate,
                                                             lr_patience=args.lr_patience,
                                                             lr_factor=args.lr_factor,
                                                             rho=args.rho, sigma=args.sigma,
                                                             profile_plateau=args.profile_plateau,
                                                             profile_plateau_dev=args.profile_plateau_dev,
                                                             n_best_profiles=args.n_best_profiles,
                                                             phylo_t=args.phylo_t)

# Initialise the profiles from k-mers; plotting is switched off
trainsetup.initializeProfiles_kmers(enforceU=args.enforceU,
                                    minU=args.minU, minOcc=args.minOcc,
                                    overlapTilesize=args.overlapTilesize,
                                    plot=False)

# Plain string literal: the message has no placeholders, so no f-string is needed
logging.info("[main] Start training and evaluation")
# The fasta file name is passed as the first argument and SEED as rand_seed
training.trainAndEvaluate(fasta.name, trainsetup, evaluator,
                          outdir,  # type: ignore
                          do_not_train=args.do_not_train,
                          rand_seed=SEED)  # type: ignore