In [None]:
import boto3
import matplotlib.pyplot as plt
import os
import pandas as pd
from qa_mods import *
import multiprocessing
import h5py
import fsspec
import numpy as np


# Shared S3 client; the functions and cells below reference it as a module-level global.
s3client = boto3.client('s3')
# Module-level paginator for the 'list_objects' operation.
# NOTE(review): no visible cell reads this name (find_raw_filtered creates its own
# 'list_objects_v2' paginator in a local variable of the same name) — confirm before removing.
paginator = s3client.get_paginator('list_objects')

In [None]:
# Choose data source: 'manifest' or 's3'
# NOTE(review): none of the cells visible in this notebook read data_source — confirm
# where the mode switch is applied.
data_source = ''

# === Common parameters (needed for both modes) ===
order = ''      # Needed for output file naming; also the order-level directory of the S3 prefix '{proj}/{order}/'

# === Parameters for 's3' mode only ===
provider = ''  # psomagen, novogene; selects the bucket 'czi-{provider}'
proj = ''  # Project-level directory, first component of the S3 prefix '{proj}/{order}/'

In [None]:
def find_raw_filtered(bucket_name, prefix):
    """
    List the filtered feature-barcode matrix files stored under an S3 prefix.

    Every object under ``prefix`` whose key ends with
    'filtered_feature_bc_matrix.h5' is collected. If at least one of those keys
    contains 'per_sample_outs', only the keys containing it are returned;
    otherwise the run is assumed not to have used "multi" and every collected
    key is returned.

    Args:
        bucket_name: Name of the S3 bucket to list.
        prefix: Key prefix (S3 "directory") to search under.

    Returns:
        A list of matching object keys (not full s3:// URIs); empty when
        nothing under the prefix matches.
    """
    suffix = 'filtered_feature_bc_matrix.h5'
    multi_marker = 'per_sample_outs'

    # Use a paginator to handle cases with more than 1000 objects
    paginator = s3client.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    # The suffix filter is applied client-side; pages without any objects
    # carry no 'Contents' entry, hence the .get() default.
    all_keys = [
        obj['Key']
        for page in pages
        for obj in page.get('Contents', [])
        if obj['Key'].endswith(suffix)
    ]

    # A plain substring test is equivalent to the literal regex search and
    # does not rely on `re` being pulled in by a star import.
    per_sample_keys = [key for key in all_keys if multi_marker in key]
    return per_sample_keys if per_sample_keys else all_keys

In [None]:
# Gather raw filtered h5 cellranger output files
#
# The bucket is derived from `provider` and the order directory from `proj` and
# `order`. Each immediate sub-prefix of the order directory is treated as one
# group and searched with find_raw_filtered; the resulting keys are collected
# in all_raw_h5.

bucket = f'czi-{provider}'
order_dir = f'{proj}/{order}/'
all_raw_h5 = []

# Paginate the delimiter listing as well: a single list call returns only one
# page of results, so group prefixes beyond that page would be silently missed.
group_pages = s3client.get_paginator('list_objects_v2').paginate(
    Bucket=bucket, Prefix=order_dir, Delimiter='/'
)
groups = [
    entry['Prefix']
    for page in group_pages
    for entry in page.get('CommonPrefixes', [])
]

for g in groups:
    all_raw_h5.extend(find_raw_filtered(bucket, g))

# Surface an empty result instead of letting later cells report a total of 0.
if not all_raw_h5:
    print(f'WARNING: no filtered h5 files found under s3://{bucket}/{order_dir}')

In [None]:
# Do a sanity check and review raw filtered h5 files that will be counted.
# The bare expression below displays the collected S3 keys as the cell output.

all_raw_h5

In [None]:
def process_raw_filtered(s3_uri):
    """
    Count the barcodes stored in one filtered H5 matrix on S3.

    The file is opened through fsspec and handed to h5py; only the length of
    the 'matrix/barcodes' dataset is read from it.

    Args:
        s3_uri: Full 's3://bucket/key' URI of the H5 file.

    Returns:
        A dict with keys 'uri' (the identifier produced by clean_uri),
        'observations_cells' (the barcode count, or 0 on failure) and
        'status' ('Success', or 'Error: <message>' on failure).
    """
    try:
        print(f'Accessing metadata for {s3_uri}')

        with fsspec.open(s3_uri, 'rb') as remote_file, h5py.File(remote_file, 'r') as h5_handle:
            n_barcodes = h5_handle['matrix']['barcodes'].shape[0]

        outcome = {
            'uri': clean_uri(s3_uri),
            'observations_cells': n_barcodes,
            'status': 'Success'
        }

    except Exception as err:
        # Best-effort: a failing file is reported with a zero count rather
        # than aborting the whole run; the PID identifies the worker process.
        print(f'[{os.getpid()}] Error processing {s3_uri}: {err}')
        outcome = {
            'uri': clean_uri(s3_uri),
            'observations_cells': 0,
            'status': f'Error: {str(err)}'
        }

    return outcome


def clean_uri(s3_uri):
    """
    Build a short summary identifier (GroupID, plus sample when present) from an S3 URI.

    Assumes the layout produced by the gather cell,
    's3://<bucket>/<proj>/<order>/<GroupID>/...', so that splitting on '/'
    puts the GroupID at index 5.

    - Non-multi runs (no 'per_sample_outs' in the URI): returns 'GroupID'.
    - Multi runs: returns 'GroupID/sample'. The sample is the directory that
      holds the file, or its parent when the file sits in a 'count'
      subdirectory (v9 cellranger has 'count', v10 does not).

    Args:
        s3_uri: Full 's3://bucket/key' URI of an H5 file.

    Returns:
        The summary identifier string. If the URI has too few path components
        to contain a GroupID, the URI itself is returned unchanged so callers
        still get a usable label instead of an IndexError.
    """
    parts = s3_uri.split('/')
    # ['s3:', '', bucket, proj, order, GroupID, ...] -> GroupID is parts[5]
    if len(parts) < 6:
        return s3_uri
    group_id = parts[5]

    # Substring test is equivalent to the literal regex search and avoids
    # depending on `re` being provided by a star import.
    if 'per_sample_outs' not in s3_uri:
        return group_id

    sample = parts[-3] if parts[-2] == 'count' else parts[-2]
    return f'{group_id}/{sample}'

In [None]:
# Count cells for every collected h5 file using a pool of worker processes.
# Raise NUM_PROCESSES when there are many h5 files to analyze.

NUM_PROCESSES = 5

with multiprocessing.Pool(NUM_PROCESSES) as pool:
    # Each key is turned into a full s3:// URI and handed to
    # process_raw_filtered; map returns the per-file result dicts in the
    # same order as all_raw_h5.
    all_file_results = pool.map(
        process_raw_filtered,
        (f's3://{bucket}/{key}' for key in all_raw_h5),
    )

In [None]:
# Print a summary report per group and final total.
# Files that could not be read are returned by process_raw_filtered with
# 0 cells and a non-'Success' status; they are listed explicitly after the
# total so an incomplete total is not mistaken for a complete one.

print('\n--- Summary Report ---')
print(pd.DataFrame.from_dict(all_file_results))

total_observations = sum(result['observations_cells'] for result in all_file_results)
failed_results = [result for result in all_file_results if result['status'] != 'Success']

print(f'\nTotal observations across all files: {total_observations}\n')

if failed_results:
    print(
        f'WARNING: {len(failed_results)} of {len(all_file_results)} file(s) '
        'could not be read and are counted as 0 cells:'
    )
    for result in failed_results:
        print(f"  {result['uri']}: {result['status']}")