CSV-Filter Image Generation Pipeline

This notebook implements the CSV-Filter method for generating 4-channel CIGAR images (Match, Deletion, Insertion, Soft Clip) from the read alignments surrounding candidate structural variants.

Sourced from Xia, Z., Xiang, W., Wang, Q., Li, X., Li, Y., Gao, J., ... & Cui, Y. (2024). CSV-Filter: a deep learning-based comprehensive structural variant filtering method for both short and long reads. Bioinformatics, 40(9), btae539.


In [None]:
import os
import sys
import numpy as np
import torch
import pysam
import torchvision
import gzip
import re
import json
from multiprocessing import Pool
from pathlib import Path
from tqdm import tqdm

# Report the PyTorch build and whether a CUDA device is visible.
print(f"🔧 PyTorch version: {torch.__version__}")
_device_name = 'CUDA' if torch.cuda.is_available() else 'CPU'
print(f"🔧 Device: {_device_name}")

In [None]:
# Image-generation helpers in this cell follow Xia et al. (2024), CSV-Filter utilities.py.

# CSV-Filter output geometry: every channel is rescaled to HEIGHT x HEIGHT pixels.
HEIGHT = 224
# Shared transform applied to each per-channel CIGAR canvas.
resize = torchvision.transforms.Resize(size=[HEIGHT, HEIGHT])

# CIGAR operation codes that advance along the reference: M(0), D(2), N(3), =(7), X(8).
_REF_CONSUMING_OPS = frozenset((0, 2, 3, 7, 8))
# Painted operations drawn as a run starting at the current reference position
# (M, D). The other painted operations, I(1) and S(4), consume no reference and
# are drawn as a run centred on the current position instead.
_SPAN_PAINTED_OPS = frozenset((0, 2))
# Channel order of the returned image: Match, Deletion, Insertion, Soft clip.
_CHANNEL_OPS = (0, 2, 1, 4)


def _render_cigar_channel(reads, ref_min, canvas, target_op):
    """Draw every ``target_op`` CIGAR run of ``reads`` into ``canvas`` and return it resized.

    ``canvas`` is a (1, len(reads), width) buffer reused across channels, so it
    is zeroed first. Row ``i`` holds ``reads[i]``; column ``j`` corresponds to
    reference position ``ref_min + j``. Painted pixels are set to 255.
    """
    canvas.zero_()
    for row, read in enumerate(reads):
        cursor = read.reference_start - ref_min
        for op, length in read.cigartuples or ():
            if op == target_op:
                if op in _SPAN_PAINTED_OPS:
                    canvas[0, row, cursor:cursor + length] = 255
                else:
                    half = length // 2
                    # NOTE(review): if half > cursor the slice start is negative
                    # and counts from the right edge, so the run is usually
                    # dropped. Kept as-is to match the published CSV-Filter code.
                    canvas[0, row, cursor - half:cursor + half] = 255
            if op in _REF_CONSUMING_OPS:
                cursor += length
    # torchvision's resize may return its input unchanged when it already has
    # the target size; clone so the next channel's zero_() cannot clobber it.
    return resize(canvas).clone()


def cigar_new_img_single_memory(bam_path, chromosome, begin, end):
    """
    CSV-Filter function: Generate 4-channel CIGAR image [Match, Deletion, Insertion, Soft Clip]

    Follows Xia et al. (2024) utilities.py. Every read overlapping the region
    becomes one row of a per-channel canvas spanning [min read start, max read
    end) on the reference. Each canvas is then resized to HEIGHT x HEIGHT.

    Args:
        bam_path: Path to an indexed BAM file (pysam ``fetch`` requires the index)
        chromosome: Chromosome name (e.g., "1", "chr1")
        begin: Region start passed to pysam ``fetch``
        end: Region end passed to pysam ``fetch``

    Returns:
        torch.Tensor: 4-channel image (4, HEIGHT, HEIGHT). It is all zeros when
        no read with an aligned reference span overlaps the region.
    """
    with pysam.AlignmentFile(bam_path, "rb") as sam_file:
        # Fetch the region once and reuse the reads for all four channels.
        # pysam reports reference_end as None for reads with no aligned span
        # (e.g. unmapped reads placed at their mate's position). Such reads
        # would make the max() below fail, so they are skipped.
        reads = [read for read in sam_file.fetch(chromosome, begin, end)
                 if read.reference_end is not None]
        if not reads:
            return torch.zeros([4, HEIGHT, HEIGHT])

        ref_min = min(read.reference_start for read in reads)
        ref_max = max(read.reference_end for read in reads)
        width = ref_max - ref_min
        if width <= 0:
            # Nothing spans the reference; resize cannot handle a zero-width canvas.
            return torch.zeros([4, HEIGHT, HEIGHT])

        # One full-resolution buffer shared by all channels keeps peak memory
        # at a single (reads x width) canvas.
        canvas = torch.zeros([1, len(reads), width])
        channels = [_render_cigar_channel(reads, ref_min, canvas, op)
                    for op in _CHANNEL_OPS]

    return torch.cat(channels, dim=0)

def _parse_info_column(info):
    """Split a VCF INFO column into a ``{key: value}`` dict.

    Keys are matched exactly, so e.g. ``CIEND`` is never mistaken for ``END``.
    Flag entries without ``=`` map to ``''``. If a key repeats, its first
    occurrence wins.
    """
    parsed = {}
    for entry in info.split(';'):
        key, _, value = entry.partition('=')
        parsed.setdefault(key, value)
    return parsed


def extract_variants_from_vcf(vcf_path, max_variants=None):
    """Extract structural variants (DEL/INS/DUP/INV, >= 50 bp) from a VCF file.

    Args:
        vcf_path: Path (str or os.PathLike) to a plain or gzip-compressed
            (name ending in ".gz") VCF.
        max_variants: Stop after collecting this many variants. ``None`` or any
            falsy value (including 0) means no limit.

    Returns:
        list[dict]: One dict per kept record, with keys 'chrom', 'pos', 'end',
        'svtype', 'svlen', 'ref', 'alt'. ``pos`` is the VCF POS as written and
        ``svlen`` is the absolute SV length. Without an END tag, ``end`` is
        derived from SVLEN or from the REF/ALT lengths.

    Records with fewer than 8 columns or a non-integer POS/END/SVLEN are skipped.
    """
    variants = []
    opener = gzip.open if str(vcf_path).endswith('.gz') else open

    with opener(vcf_path, 'rt') as handle:
        for line in handle:
            if line.startswith('#'):
                continue

            if max_variants and len(variants) >= max_variants:
                break

            fields = line.strip().split('\t')
            if len(fields) < 8:
                continue

            chrom, ref, alt = fields[0], fields[3], fields[4]
            info = _parse_info_column(fields[7])

            try:
                pos = int(fields[1])
                sv_type = info.get('SVTYPE') or None
                svlen_text = info.get('SVLEN')
                # Multi-allelic records may list several lengths; use the first.
                svlen = abs(int(svlen_text.split(',')[0])) if svlen_text else None
                end_text = info.get('END')
                end_pos = int(end_text) if end_text else None
            except ValueError:
                # Malformed numeric field: skip this record and keep going.
                continue

            # Calculate end position if not provided
            if end_pos is None:
                if svlen:
                    end_pos = pos + svlen
                else:
                    end_pos = pos + max(len(ref), len(alt))

            # Calculate SVLEN if not provided
            if svlen is None:
                svlen = end_pos - pos

            # Filter for structural variants (>=50bp)
            if svlen >= 50 and sv_type in ('DEL', 'INS', 'DUP', 'INV'):
                variants.append({
                    'chrom': chrom,
                    'pos': pos,
                    'end': end_pos,
                    'svtype': sv_type,
                    'svlen': svlen,
                    'ref': ref,
                    'alt': alt
                })

    return variants

print("CSV-Filter core functions loaded")

In [None]:
# Benchmark datasets: truth-comparison VCFs (TP calls from tp-comp, FP calls)
# plus the path of the matching reference genome.
# NOTE(review): the processing and test cells look up a 'bam' entry per
# dataset, which is not defined here. TODO: add the alignment path for each.
DATASETS = dict(
    HG002_GRCh37=dict(
        tp_comp_vcf='../data/raw/bench-HG002_GRCh37_noqc/tp-comp.vcf.gz',
        fp_vcf='../data/raw/bench-HG002_GRCh37_noqc/fp.vcf.gz',
        ref='../data/raw/GRCh37.fna',
    ),
    HG002_GRCh38=dict(
        tp_comp_vcf='../data/raw/bench-HG002_GRCh38-GIABv3_noqc/tp-comp.vcf.gz',
        fp_vcf='../data/raw/bench-HG002_GRCh38-GIABv3_noqc/fp.vcf.gz',
        ref='../data/raw/GRCh38.fa',
    ),
    HG005_GRCh38=dict(
        tp_comp_vcf='../data/raw/bench-HG005_GRCh38_noqc/tp-comp.vcf.gz',
        fp_vcf='../data/raw/bench-HG005_GRCh38_noqc/fp.vcf.gz',
        ref='../data/raw/GRCh38.fa',
    ),
)

# Pipeline settings shared by the driver and the worker processes.
CONFIG = dict(
    cache_dir='/content/drive/MyDrive/SV_Diffusion/all_datasets_images',
    max_coverage=1000,  # Skip variants with >1000 reads
    chunk_size=50,      # Variants per worker
    num_workers=4,      # Parallel workers
    sync_every=1000,    # Progress sync frequency (appears unused in the visible cells)
)

print(f"Configured {len(DATASETS)} datasets")
print(f"Output directory: {CONFIG['cache_dir']}")

In [None]:
# Parallel Processing Functions

def process_variant_chunk_parallel(args):
    """
    Parallel worker: render and save CSV-Filter images for one chunk of variants.

    Args:
        args: (variants_chunk, bam_path, dataset_output_dir, dataset_name).
            variants_chunk is a list of (variant dict, output filename) pairs.
            The arguments are packed into one tuple so the function can be
            dispatched with Pool.apply_async.

    Returns:
        dict: Processing statistics for this chunk. Holds the counters
        'processed', 'skipped_existing', 'skipped_coverage' and 'failed', plus
        the lists 'skipped_high_coverage' and 'failed_variants'; every list
        entry is tagged with the dataset name.

    Existing output files are never regenerated. Each image is first written to
    a temporary file and then renamed into place with os.replace. An interrupted
    run therefore cannot leave a truncated ``.pt`` file that a resumed run would
    mistake for finished work.
    """
    import gc
    from contextlib import suppress

    variants_chunk, bam_path, dataset_output_dir, dataset_name = args

    # Local results for this worker
    worker_results = {
        'processed': 0,
        'skipped_existing': 0,
        'skipped_coverage': 0,
        'failed': 0,
        'skipped_high_coverage': [],
        'failed_variants': []
    }

    for variant, filename in variants_chunk:
        tmp_path = None
        # Per-variant isolation: any failure is recorded and the chunk continues.
        try:
            # Skip if file already exists
            output_path = os.path.join(dataset_output_dir, filename)
            if os.path.exists(output_path):
                worker_results['skipped_existing'] += 1
                continue

            # Memory safety: count overlapping reads, stopping once past the limit.
            # The context manager closes the BAM even if fetch raises.
            with pysam.AlignmentFile(bam_path, "rb") as sam_file:
                read_count = 0
                for _ in sam_file.fetch(variant['chrom'], variant['pos'], variant['end']):
                    read_count += 1
                    if read_count > CONFIG['max_coverage']:
                        break

            if read_count > CONFIG['max_coverage']:
                worker_results['skipped_coverage'] += 1
                worker_results['skipped_high_coverage'].append({
                    'dataset': dataset_name,
                    'variant': variant,
                    'reason': f'High coverage: {read_count}+ reads',
                    'filename': filename,
                    'read_count': read_count
                })
                continue

            # Generate 4-channel CIGAR image using CSV-Filter
            image = cigar_new_img_single_memory(
                bam_path=bam_path,
                chromosome=variant['chrom'],
                begin=variant['pos'],
                end=variant['end']
            )

            # Verify 4-channel format (explicit raise: asserts vanish under -O)
            if image.shape[0] != 4:
                raise ValueError(f"Expected 4 channels, got {image.shape[0]}")

            # Write under a process-unique temporary name that does not end in
            # '.pt', so existing-file scans ignore it, then rename into place.
            tmp_path = f"{output_path}.{os.getpid()}.tmp"
            torch.save({
                'image': image,
                'metadata': variant,
                'shape': image.shape,
                'format': '4ch_cigar',
                'method': 'CSV-Filter (Xia et al. 2024)'
            }, tmp_path)
            os.replace(tmp_path, output_path)
            tmp_path = None

            worker_results['processed'] += 1

            # Release the image before the next variant
            del image

            # Garbage collect periodically
            if worker_results['processed'] % 50 == 0:
                gc.collect()

        except Exception as e:
            worker_results['failed'] += 1
            worker_results['failed_variants'].append({
                'dataset': dataset_name,
                'variant': variant,
                'error': str(e)
            })

            # Best-effort removal of a partially written temporary file.
            if tmp_path is not None:
                with suppress(OSError):
                    os.remove(tmp_path)
            gc.collect()

    return worker_results

In [None]:
# Main Processing Pipeline

# Config entries every dataset needs before its variants can be processed.
_REQUIRED_DATASET_KEYS = ('tp_comp_vcf', 'fp_vcf', 'bam')


def _count_pt_files(directory):
    """Return the number of ``.pt`` files in ``directory`` (0 if it does not exist)."""
    if not os.path.exists(directory):
        return 0
    return sum(1 for name in os.listdir(directory) if name.endswith('.pt'))


def process_all_datasets():
    """
    Main processing pipeline: Generate CSV-Filter images for all datasets.

    For each entry in DATASETS, the TP and FP variants are read from its VCFs
    and labelled. They are then dispatched in chunks of CONFIG['chunk_size'] to
    CONFIG['num_workers'] processes running process_variant_chunk_parallel.
    Variants whose image file already exists are skipped, so runs can be resumed.

    Datasets are reported and skipped when they lack a required entry
    ('tp_comp_vcf', 'fp_vcf', 'bam') or when any configured path does not exist.

    Returns:
        dict with keys 'total_generated', 'total_skipped', 'failed_variants',
        'skipped_high_coverage' and 'output_dir'. Returns None if the user
        declines the confirmation prompt.
    """

    print("CSV-Filter Image Generation Pipeline")

    # Create output directory
    os.makedirs(CONFIG['cache_dir'], exist_ok=True)

    # Check files left over from earlier (possibly interrupted) runs
    print(f"\n Checking existing files...")
    total_existing = 0
    for dataset_name in DATASETS:
        existing_count = _count_pt_files(os.path.join(CONFIG['cache_dir'], dataset_name))
        total_existing += existing_count
        print(f"   {dataset_name}: {existing_count} existing files")

    print(f"   Total existing files: {total_existing}")

    # Rough estimate: hard-coded ~20,000 variants in total and ~1,000 variants
    # per worker per hour.
    estimated_hours = max(1, (20000 - total_existing) // (CONFIG['num_workers'] * 1000))

    proceed = input(f"\n Ready to process all datasets!\n"
                    f"Existing files: {total_existing} (will be skipped)\n"
                    f"Estimated time: ~{estimated_hours} hours\n"
                    f"Cache: {CONFIG['cache_dir']}\n"
                    f"Resume-safe: Never overwrites existing files\n"
                    f"Proceed? (y/n): ")

    if proceed.strip().lower() != 'y':
        print("⏸Processing cancelled by user")
        return None

    print(f"Starting resume-safe parallel processing...")

    # Global tracking
    total_images_generated = 0
    total_skipped = 0
    failed_variants = []
    skipped_high_coverage = []

    # Process each dataset
    for dataset_name, config in DATASETS.items():
        print(f"\n Processing dataset: {dataset_name}")

        # Validate the configuration before the slow VCF parsing. Previously,
        # a missing 'bam' entry raised KeyError only after both VCFs had been
        # read, and that aborted the whole run.
        missing_keys = [key for key in _REQUIRED_DATASET_KEYS if not config.get(key)]
        if missing_keys:
            print(f"    Missing config entries: {missing_keys}")
            continue

        # Verify files exist
        missing_files = [path for path in config.values() if not os.path.exists(path)]
        if missing_files:
            print(f"    Missing files: {missing_files}")
            continue

        bam_path = config['bam']

        # Extract variants from VCFs
        print(f"    Extracting variants...")
        tp_variants = extract_variants_from_vcf(config['tp_comp_vcf'])
        fp_variants = extract_variants_from_vcf(config['fp_vcf'])

        print(f"    Found {len(tp_variants)} TP and {len(fp_variants)} FP variants")

        # Tag each parsed variant (in place) with its label and dataset
        all_variants = []
        for label, variants in (('TP', tp_variants), ('FP', fp_variants)):
            for variant in variants:
                variant.update({'label': label, 'dataset': dataset_name})
                all_variants.append(variant)

        # Create output directory
        dataset_output_dir = os.path.join(CONFIG['cache_dir'], dataset_name)
        os.makedirs(dataset_output_dir, exist_ok=True)

        # Queue only variants whose image file does not exist yet
        existing_files = {f for f in os.listdir(dataset_output_dir) if f.endswith('.pt')}
        variants_to_process = []
        skipped_count = 0

        for variant in all_variants:
            filename = f"{variant['dataset']}_{variant['label']}_{variant['chrom']}_{variant['pos']}_{variant['end']}_{variant['svtype']}_{variant['svlen']}bp.pt"

            if filename in existing_files:
                skipped_count += 1
            else:
                variants_to_process.append((variant, filename))

        total_skipped += skipped_count

        print(f"    Skipping {skipped_count} existing files")
        print(f"    Processing {len(variants_to_process)} new variants")

        if not variants_to_process:
            print(f"    All variants already processed!")
            continue

        # Split into chunks for parallel processing
        chunk_size = CONFIG['chunk_size']
        variant_chunks = [
            (variants_to_process[i:i + chunk_size], bam_path, dataset_output_dir, dataset_name)
            for i in range(0, len(variants_to_process), chunk_size)
        ]

        print(f"    Split into {len(variant_chunks)} chunks")

        # Process in parallel
        dataset_stats = {'processed': 0, 'skipped_coverage': 0, 'failed': 0}

        with Pool(processes=CONFIG['num_workers']) as pool:
            chunk_results = [pool.apply_async(process_variant_chunk_parallel, (chunk_args,))
                             for chunk_args in variant_chunks]

            # Collect results in submission order with progress tracking
            with tqdm(total=len(variants_to_process), desc=f"{dataset_name}", unit="variants") as pbar:
                for result in chunk_results:
                    chunk_result = result.get()  # Wait for completion

                    # Update statistics
                    dataset_stats['processed'] += chunk_result['processed']
                    dataset_stats['skipped_coverage'] += chunk_result['skipped_coverage']
                    dataset_stats['failed'] += chunk_result['failed']

                    # Collect failed/skipped items
                    skipped_high_coverage.extend(chunk_result['skipped_high_coverage'])
                    failed_variants.extend(chunk_result['failed_variants'])

                    total_images_generated += chunk_result['processed']

                    # Count every variant the worker handled, including files
                    # that appeared on disk after the queue was built, so the
                    # bar reaches its total.
                    completed = (chunk_result['processed'] +
                                 chunk_result['skipped_existing'] +
                                 chunk_result['skipped_coverage'] +
                                 chunk_result['failed'])
                    pbar.update(completed)
                    pbar.set_postfix({
                        'new': dataset_stats['processed'],
                        'total': total_images_generated,
                        'skipped': total_skipped,
                        'failed': len(failed_variants)
                    })

        print(f"    Completed: {dataset_stats['processed']} new images generated")

    # Final summary
    print(f"\n PROCESSING COMPLETE!")
    print("=" * 60)
    print(f"  New images generated: {total_images_generated}")
    print(f"  Existing images skipped: {total_skipped}")
    print(f"  Failed variants: {len(failed_variants)}")
    print(f"  High-coverage variants skipped: {len(skipped_high_coverage)}")
    print(f"  Output directory: {CONFIG['cache_dir']}")
    print(f"  Format: 4-channel CSV-Filter CIGAR")

    # Save skipped variants for later processing
    if skipped_high_coverage:
        skipped_dir = os.path.join(CONFIG['cache_dir'], 'skipped_variants')
        os.makedirs(skipped_dir, exist_ok=True)

        with open(os.path.join(skipped_dir, 'high_coverage_variants.json'), 'w') as f:
            json.dump(skipped_high_coverage, f, indent=2)
        print(f"\n Saved {len(skipped_high_coverage)} high-coverage variants for later processing")

    # Show final counts
    print(f"\n Final file counts:")
    for dataset_name in DATASETS:
        dataset_dir = os.path.join(CONFIG['cache_dir'], dataset_name)
        if os.path.exists(dataset_dir):
            print(f"    {dataset_name}: {_count_pt_files(dataset_dir)} total files")

    return {
        'total_generated': total_images_generated,
        'total_skipped': total_skipped,
        'failed_variants': failed_variants,
        'skipped_high_coverage': skipped_high_coverage,
        'output_dir': CONFIG['cache_dir']
    }

In [None]:
# Test CSV-Filter Functions

def test_csv_filter():
    """Smoke-test cigar_new_img_single_memory on one sample region and plot the channels.

    Uses DATASETS['HG002_GRCh37']['bam'] and the region 1:900035-900114.

    Returns:
        bool: True if the image was generated and displayed. False if no BAM is
        configured, the BAM file is missing, or generation/plotting raised.
    """
    print("Testing CSV-Filter implementation...")

    # Test parameters
    test_bam = DATASETS.get('HG002_GRCh37', {}).get('bam')
    test_chrom = '1'
    test_pos = 900035
    test_end = test_pos + 79

    # DATASETS entries may lack a 'bam' path; report that instead of raising KeyError.
    if not test_bam:
        print(" No 'bam' path configured for DATASETS['HG002_GRCh37']")
        return False

    if not os.path.exists(test_bam):
        print(f" Test BAM file not found: {test_bam}")
        return False

    try:
        # Generate test image
        image = cigar_new_img_single_memory(test_bam, test_chrom, test_pos, test_end)

        print(f" CSV-Filter test successful!")
        print(f"    Image shape: {image.shape}")
        print(f"    Data type: {image.dtype}")
        print(f"    Value range: {image.min():.3f} - {image.max():.3f}")

        # Display sample
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        axes = axes.flatten()
        channel_names = ['Match (M)', 'Deletion (D)', 'Insertion (I)', 'Soft Clip (S)']

        for i, (ax, name) in enumerate(zip(axes, channel_names)):
            img_data = image[i].numpy()
            im = ax.imshow(img_data, cmap='hot', aspect='auto')
            ax.set_title(f'Channel {i}: {name}')
            plt.colorbar(im, ax=ax, shrink=0.8)

        plt.suptitle(f'CSV-Filter 4-Channel Test\n{test_chrom}:{test_pos}-{test_end}', fontsize=14)
        plt.tight_layout()
        plt.show()

        return True

    except Exception as e:
        print(f" CSV-Filter test failed: {e}")
        return False

# Run the smoke test and report whether the full pipeline is ready to run.
test_success = test_csv_filter()

if not test_success:
    print(f"\n Fix test errors before proceeding")
else:
    print(f"\n Ready to process datasets!")
    print(f" Run: results = process_all_datasets()")

In [None]:
# Run Processing
# process_all_datasets() asks for confirmation first. `results` is None if the
# user declines; otherwise it is the summary dict the function returns.
results = process_all_datasets()