# Feature Filtering

In [7]:
import os
from collections import defaultdict
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.SeqFeature import SeqFeature, FeatureLocation


In [8]:
# ---------------------------------------------------
# 1. Reference Parsing
# ---------------------------------------------------
from collections import defaultdict
from Bio import SeqIO

def parse_reference_fasta(ref_fasta):
    """
    Read reference FASTA, parse known feature segments, return dict of feature lengths.
    Assumes the reference or an accompanying table has boundaries or lengths of each feature.
    """
    feature_lengths = {}
  
    for record in SeqIO.parse(ref_fasta, "fasta"):
        feature_lengths[record.id] = len(record.seq)
  
    return feature_lengths

ref_fasta_path = "00-Data/references/ref_28s_features.fasta"
ref_lengths = parse_reference_fasta(ref_fasta_path)

# ---------------------------------------------------
# 2. Parse Read Reference Positions
# ---------------------------------------------------
# Store as (read_id, type) -> [(start, end, strand), ...]
read_ref_positions = defaultdict(list)
alignment_file = "00-Data/repeatmaskered/sample-cutadapt_sup-filtered.fasta.out.xm"

with open(alignment_file) as infile:
    for line_num, line in enumerate(infile, 1):
        # Split the line on whitespace
        parts = line.strip().split()
        if len(parts) < 10:
            # Skip lines that don't have enough columns
            print(f"Skipping presumed header line {line_num}: {line}")
            continue  

        # Extract the relevant information
        read_id = parts[4]
        strand = parts[8]
        feat_type = parts[9].split("#")[0]

        # Extract start and end positions based on strand
        try:
            if strand == "+":
                start = int(parts[10])
                end = int(parts[11])
            elif strand == "C":
                start = int(parts[12])
                end = int(parts[11])
            else:
                # Skip if strand information is invalid
                print(f"Skipping line {line_num}: invalid strand {strand}")
                continue  
        except (IndexError, ValueError):
            # Skip lines with invalid position data
            print(f"Skipping line {line_num}: invalid positions")
            continue  

        read_ref_positions[(read_id, feat_type)].append((start, end, strand))

# ---------------------------------------------------
# 3. Helper Function to Get Best Flanking Pair with Combined Length Checks
# ---------------------------------------------------
def get_best_flanking_pair(record, feat_3_names, feat_5_names):
    """
    Return (best_3_feat, best_5_feat) or None if no valid pair is found.
    'Best' means same-strand pair with the longest distance between them.
    """
    three_features = []
    five_features = []

    # Gather 3end & 5end features
    for feat in record.features:
        if feat.type in feat_3_names:
            three_features.append(feat)
        elif feat.type in feat_5_names:
            five_features.append(feat)
    
    # No 3 or 5 features -> no valid pair
    if not three_features or not five_features:
        return None

    # --------------------------------------------
    # **New Addition**: Calculate sum of lengths per direction
    # --------------------------------------------
    # Define directions as (feature type, strand)
    direction_sums = defaultdict(int)
    for feat in three_features + five_features:
        direction = (feat.type, feat.strand)
        direction_sums[direction] += len(feat.location)

    # Identify valid directions with sum >= 100 nt
    valid_directions = set()
    for direction, total_length in direction_sums.items():
        if total_length >= 100:
            valid_directions.add(direction)

    # Filter features to keep only those in valid directions
    filtered_three_features = [feat for feat in three_features if (feat.type, feat.strand) in valid_directions]
    filtered_five_features = [feat for feat in five_features if (feat.type, feat.strand) in valid_directions]

    # If no features remain after filtering, return None
    if not filtered_three_features or not filtered_five_features:
        return None

    # Proceed to find the best pair from filtered features
    best_pair = None
    max_distance = -1

    # Find pair with same strand & the greatest distance
    for tfeat in filtered_three_features:
        for ffeat in filtered_five_features:
            if tfeat.strand == ffeat.strand and tfeat.strand in (1, -1):
                left_coord = min(tfeat.location.start, ffeat.location.start)
                right_coord = max(tfeat.location.end, ffeat.location.end)
                distance = right_coord - left_coord
                if distance > max_distance:
                    max_distance = distance
                    best_pair = (tfeat, ffeat)
    
    return best_pair  # could be None if no same-strand pairs

# ---------------------------------------------------
# 4. Define Input Files and Feature Names
# ---------------------------------------------------
input_file = "00-Data/feature_annotated/Repeatmasker_transgene.gb"
feat_3_names = ["3end_200nt", "3end_400nt"]  # adjust if needed
feat_5_names = ["5end_200nt", "5end_400nt"]  # adjust if needed

# ---------------------------------------------------
# 5. Apply Filter 1 and Classify Records with Detailed Failure Reasons
# ---------------------------------------------------
records_filter1_pass = []
records_filter1_fail = []

# Initialize counters (optional, for tracking purposes)
counters = defaultdict(int)
# Possible counters:
# - 'discarded_missing_features': missing 3end or 5end features
# - 'discarded_due_to_combined_length': failing combined length requirements
# - 'discarded_due_to_flanking_pair': failing to find best flanking pair

# Initialize lists to store read IDs that passed and failed
passed_reads = []  # Not used
failed_reads = []  # Not used

# Initialize list for reads that failed due to combined length
discarded_due_to_combined_length_reads = []

# Read the input file and classify records based on Filter 1
for record in SeqIO.parse(input_file, "genbank"):
    best_pair = get_best_flanking_pair(record, feat_3_names, feat_5_names)
    if best_pair is not None:
        # Since combined lengths are already checked in get_best_flanking_pair,
        # we no longer need to check individual feature lengths here.
        records_filter1_pass.append(record)
        passed_reads.append(record.id)
    else:
        # Determine the reason for failure
        # Check if the failure was due to combined length or no valid pair
        # To do this, we need to recalculate the combined lengths

        # Gather all 3end & 5end features
        three_features = [feat for feat in record.features if feat.type in feat_3_names]
        five_features = [feat for feat in record.features if feat.type in feat_5_names]

        # Calculate sum of lengths per direction
        direction_sums = defaultdict(int)
        for feat in three_features + five_features:
            direction = (feat.type, feat.strand)
            direction_sums[direction] += len(feat.location)

        # Check if any direction meets the combined length requirement
        any_combined_length_met = any(total >= 100 for total in direction_sums.values())

        if not any_combined_length_met:
            # Failed due to combined length requirements
            counters['discarded_due_to_combined_length'] += 1
            discarded_due_to_combined_length_reads.append(record.id)
        else:
            # Failed to find a valid flanking pair despite meeting combined length
            counters['discarded_due_to_flanking_pair'] += 1

        # Add to failed records
        records_filter1_fail.append(record)
        failed_reads.append(record.id)

# Print out how many passed/failed Filter 1
total_records = len(records_filter1_pass) + len(records_filter1_fail)
print(f"Step 1 (Filter 1) - Total records: {total_records}")
print(f"Passed Filter 1 (have flanking regions with combined directions >= 100bp): {len(records_filter1_pass)}")
print(f"Failed Filter 1 (no valid flanking regions or combined directions < 100bp): {len(records_filter1_fail)}")

# Print the number of reads discarded due to specific reasons
print("\nReasons for Discarding Reads:")
print(f" - Discarded due to failing combined length requirements: {counters['discarded_due_to_combined_length']}")
print(f" - Discarded due to failing to find a best flanking pair: {counters['discarded_due_to_flanking_pair']}")


Step 1 (Filter 1) - Total records: 26386
Passed Filter 1 (have flanking regions with combined directions >= 100bp): 10175
Failed Filter 1 (no valid flanking regions or combined directions < 100bp): 16211

Reasons for Discarding Reads:
 - Discarded due to failing combined length requirements: 15345
 - Discarded due to failing to find a best flanking pair: 866


In [None]:

######################################################
#        FILTER 2 - From Filter 1, keep those with an
#           insert feature (adjust feature name).
######################################################

def has_required_feature(record, feature_names=["Transgene", "3UTR"]):
    """
    Check if the record contains any feature with type in feature_names.
    
    Args:
        record (SeqRecord): The sequence record to check.
        feature_names (list): List of feature types to look for.
    
    Returns:
        bool: True if any feature matches, False otherwise.
    """
    return any(feat.type in feature_names for feat in record.features)

# Initialize lists to store passed and failed records for Filter 2
records_filter2_pass = []
records_filter2_fail = []  # Records that have flanking region but neither Transgene nor 3UTR

# Iterate over records that passed Filter 1
for record in records_filter1_pass:
    if has_required_feature(record, feature_names=["Transgene", "3UTR"]):
        records_filter2_pass.append(record)
    else:
        records_filter2_fail.append(record)
        counters['discarded_due_to_missing_transgene_or_3UTR'] += 1

# Output the results
print(f"\nStep 2 (Filter 2) - Starting with {len(records_filter1_pass)} Filter 1-passed records.")
print(f"Passed Filter 2 (contain Transgene or 3UTR): {len(records_filter2_pass)}")
print(f"Failed Filter 2 (no Transgene or 3UTR): {len(records_filter2_fail)}")

# Print the number of reads discarded due to missing Transgene or 3UTR
print(f" - Discarded due to missing Transgene or 3UTR: {counters['discarded_due_to_missing_transgene_or_3UTR']}")

###################################################################################
# # CELL 3: FILTER 3 - Among those that failed Filter 2 (but passed Filter 1),
# #         check how many have ≥50 unannotated bases between the 3end and 5end.
###################################################################################

def unannotated_region_in_flanking_interval(record, feat_3_names, feat_5_names, min_gap=50):
    """
    Returns True if there is a region of >= min_gap bases between
    the flanking features that is NOT overlapped by any other feature.
    We use the 'best' 3end-5end pair from get_best_flanking_pair().
    """
    best_pair = get_best_flanking_pair(record, feat_3_names, feat_5_names)
    if not best_pair:
        return False  # no flanking pair

    three_feat, five_feat = best_pair
    left = min(three_feat.location.end, five_feat.location.end)
    right = max(three_feat.location.start, five_feat.location.start)
    # The 'internal region' is left..right, if left < right.
    # If left >= right, it means the features overlap or are adjacent.
    
    if left >= right:
        return False
    
    gap_length = right - left
    if gap_length < min_gap:
        return False
    
    # Now we check if ANY feature overlaps that region
    for feat in record.features:
        f_start = feat.location.start
        f_end = feat.location.end
        # If a feature overlaps the gap region, it's annotated
        if f_start < right and f_end > left:
            return False  # Gap region is at least partially annotated
    
    # If we reached here, there's no feature that covers that gap region,
    # so we have at least 50 unannotated bases.
    return True

count_filter3 = 0
for record in records_filter2_fail:  # those that had flanking region but no transgene
    if unannotated_region_in_flanking_interval(
        record,
        feat_3_names,
        feat_5_names,
        min_gap=50
    ):
        count_filter3 += 1

print(f"\nStep 3 (Filter 3) - Among {len(records_filter2_fail)} records that have flanking regions but no Transgene or 3UTR:")
print(f"Number with ≥50 unannotated bases between flanking regions: {count_filter3}")

##################################################################################
# # CELL 4: GENBANK TRIMMING (passed Filter 1 & 2) + STORE POSITIONS
# #         Don't alter the record ID (no "_trimmed" suffix).
##################################################################################
# from Bio import SeqIO
# from Bio.SeqRecord import SeqRecord
# from Bio.SeqFeature import SeqFeature, FeatureLocation

# Adjust these paths if needed:
output_file_gb = "00-Data/feature_annotated/Repeatmasker_transgene_filtered.gb"
removed_file_gb= "00-Data/feature_annotated/Repeatmasker_transgene_removed.gb"

trim_positions = {}   # { record.id : (slice_left, slice_right, strand) }
trimmed_records = []
removed_records = []

for record in records_filter2_pass:
    best_pair = get_best_flanking_pair(record, feat_3_names, feat_5_names)
    if not best_pair:
        # Theoretically shouldn't happen if it passed Filter 1,
        # but we'll handle it just in case:
        removed_records.append(record)
        continue

    three_feat, five_feat = best_pair
    slice_left = min(three_feat.location.start, five_feat.location.start)
    slice_right = max(three_feat.location.end, five_feat.location.end)
    slice_strand = three_feat.strand  # same as five_feat.strand

    # Store the coords for FASTQ trimming (Cell 5, presumably)
    trim_positions[record.id] = (slice_left, slice_right, slice_strand)

    # Create trimmed GenBank record (optional):
    sub_seq = record.seq[slice_left : slice_right]

    # Reverse-complement if negative strand
    if slice_strand == -1:
        sub_seq = sub_seq.reverse_complement()

    # Keep the same ID:
    new_record = SeqRecord(
        sub_seq,
        id=record.id,
        name=record.name,
        description=(
            f"Region {slice_left}-{slice_right} from original {record.id}; "
            f"strand={slice_strand}"
        ),
    )

    # Adjust features
    adjusted_features = []
    length_sub = slice_right - slice_left

    for feat in record.features:
        f_start = feat.location.start
        f_end   = feat.location.end

        # Keep only if fully within the slice
        if f_start >= slice_left and f_end <= slice_right:
            new_start = f_start - slice_left
            new_end   = f_end   - slice_left
            new_strand = feat.strand

            if slice_strand == -1:
                # Flip for reverse complement
                tmp_start = length_sub - new_end
                tmp_end   = length_sub - new_start
                new_start, new_end = tmp_start, tmp_end
                # Flip strand if it's 1 or -1
                if new_strand in (1, -1):
                    new_strand = -new_strand

            new_location = FeatureLocation(new_start, new_end, strand=new_strand)
            new_feat = SeqFeature(new_location, type=feat.type, qualifiers=feat.qualifiers)
            adjusted_features.append(new_feat)

    new_record.features = adjusted_features
    # Copy certain annotations if desired
    new_record.annotations["molecule_type"] = record.annotations.get("molecule_type","DNA")

    trimmed_records.append(new_record)

#
# IMPORTANT: Write out the trimmed and removed records AFTER this loop.
#

# Write trimmed records
if trimmed_records:
    with open(output_file_gb, "w") as handle:
        SeqIO.write(trimmed_records, handle, "genbank")

# Write removed records
if removed_records:
    with open(removed_file_gb, "w") as handle:
        SeqIO.write(removed_records, handle, "genbank")

print("Trimming done.")
print(f"Trimmed records: {len(trimmed_records)}")
print(f"Removed records: {len(removed_records)}")



Step 2 (Filter 2) - Starting with 10175 Filter 1-passed records.
Passed Filter 2 (contain Transgene or 3UTR): 846
Failed Filter 2 (no Transgene or 3UTR): 9329
 - Discarded due to missing Transgene or 3UTR: 9329

Step 3 (Filter 3) - Among 9329 records that have flanking regions but no Transgene or 3UTR:
Number with ≥50 unannotated bases between flanking regions: 0
Trimming done.
Trimmed records: 846
Removed records: 0


## Trimming of the FastQ File 

Extract the length of the Feature Annotation and trimm the FastQ File accordingly 

In [10]:
#######################################################################
# CELL 5: FASTQ TRIMMING BASED ON STORED (ID, START, END, STRAND), 
#         KEEPING THE SAME READ IDs.
#######################################################################

from Bio.Seq import Seq

input_file_fastq  = "00-Data/samples_filtered/sample-cutadapt_sup-filtered.fastq"
output_file_fastq = "00-Data/samples_filtered/sample-cutadapt_sup-filtered-trimmed.fastq"

trimmed_fastq_records = []

with open(input_file_fastq, "r") as in_handle:
    for fastq_record in SeqIO.parse(in_handle, "fastq"):
        # If we have trimming info for this read
        if fastq_record.id in trim_positions:
            slice_left, slice_right, slice_strand = trim_positions[fastq_record.id]

            # Ensure slice_right does not exceed read length
            read_len = len(fastq_record.seq)
            if slice_right > read_len:
                slice_right = read_len  # or skip if out-of-bounds

            # Slice the sequence and quality
            trimmed_seq = fastq_record.seq[slice_left : slice_right]
            trimmed_qual = fastq_record.letter_annotations["phred_quality"][slice_left : slice_right]

            # Reverse complement if strand is -1
            if slice_strand == -1:
                trimmed_seq = trimmed_seq.reverse_complement()
                trimmed_qual = trimmed_qual[::-1]

            # Create a new FASTQ record with the same ID
            new_fastq_record = SeqRecord(
                trimmed_seq,
                id=fastq_record.id,   # same ID
                name=fastq_record.name,
                description=fastq_record.description + f" [trimmed {slice_left}-{slice_right}]"
            )
            new_fastq_record.letter_annotations["phred_quality"] = trimmed_qual

            trimmed_fastq_records.append(new_fastq_record)

# Write the trimmed FASTQ
with open(output_file_fastq, "w") as out_handle:
    SeqIO.write(trimmed_fastq_records, out_handle, "fastq")

print(f"Step 5 (FASTQ trimming) - Trimmed reads: {len(trimmed_fastq_records)}")
print(f"Output FASTQ file: {output_file_fastq}")


Step 5 (FASTQ trimming) - Trimmed reads: 846
Output FASTQ file: 00-Data/samples_filtered/sample-cutadapt_sup-filtered-trimmed.fastq


## New Minimap Alignment with trimmed Fastq 

In [11]:
# repeat alignment with trimmed reads, reduce minimal chain score requirement 
# -m 15

!minimap2 -a --splice --MD --eqx 00-Data/references/ref_400nt_flanking_Transgene.fasta 00-Data/samples_filtered/sample-cutadapt_sup-filtered-trimmed.fastq | samtools view -h -F 4 -o 00-Data/filtered_aligned/trimmed.sam

[M::mm_idx_gen::0.000*3.97] collected minimizers
[M::mm_idx_gen::0.001*3.34] sorted minimizers
[M::main::0.001*3.08] loaded/built the index for 1 target sequence(s)
[M::mm_mapopt_update::0.001*2.98] mid_occ = 10
[M::mm_idx_stat] kmer size: 15; skip: 10; is_hpc: 0; #seq: 1
[M::mm_idx_stat::0.001*2.89] distinct minimizers: 435 (99.31% are singletons); average occurrences: 1.011; average spacing: 5.232; total length: 2302
[M::worker_pipeline::0.134*2.84] mapped 846 sequences
[M::main] Version: 2.28-r1209
[M::main] CMD: minimap2 -a --splice --MD --eqx 00-Data/references/ref_400nt_flanking_Transgene.fasta 00-Data/samples_filtered/sample-cutadapt_sup-filtered-trimmed.fastq
[M::main] Real time: 0.134 sec; CPU: 0.380 sec; Peak RSS: 0.017 GB


### Minimap Alignment with full length FASTQ and greater Endbonus 

In [12]:
!minimap2 -a --splice --end-bonus 1000 --MD --eqx 00-Data/references/ref_crRNA_Transgene.fasta 00-Data/samples_filtered/sample-cutadapt_sup-filtered.fastq  | samtools view -h -F 4 -o 00-Data/filtered_aligned/crRNA.sam

[M::mm_idx_gen::0.001*2.19] collected minimizers
[M::mm_idx_gen::0.001*2.42] sorted minimizers
[M::main::0.001*2.41] loaded/built the index for 1 target sequence(s)
[M::mm_mapopt_update::0.001*2.33] mid_occ = 30
[M::mm_idx_stat] kmer size: 15; skip: 10; is_hpc: 0; #seq: 1
[M::mm_idx_stat::0.001*2.26] distinct minimizers: 1844 (97.07% are singletons); average occurrences: 1.075; average spacing: 5.331; total length: 10572
[M::worker_pipeline::77.346*2.97] mapped 26386 sequences
[M::main] Version: 2.28-r1209
[M::main] CMD: minimap2 -a --splice --end-bonus 1000 --MD --eqx 00-Data/references/ref_crRNA_Transgene.fasta 00-Data/samples_filtered/sample-cutadapt_sup-filtered.fastq
[M::main] Real time: 77.347 sec; CPU: 229.968 sec; Peak RSS: 0.988 GB
