In [7]:
import os

import boto3
from botocore.exceptions import ClientError
from tqdm import tqdm

from deep_ccf_registration.datasets.slice_dataset import SubjectMetadata, AcquisitionAxis

In [9]:
# Default AWS profile for every boto3 client/session created in this kernel
# (the helpers below still accept an explicit aws_profile).
os.environ.update(AWS_PROFILE='533267145346_ProductionRW')

In [13]:
from concurrent.futures import as_completed, ThreadPoolExecutor


def list_smartspim_stitched_prefixes(bucket_name, delimiter='/', aws_profile=None):
    """
    Return the sorted top-level S3 prefixes that start with 'SmartSPIM_' and contain 'stitched'.

    Example match: SmartSPIM_000393_2023-01-06_13-35-10_stitched_2023-02-02_22-28-35/

    Args:
        bucket_name (str): S3 bucket to list.
        delimiter (str): Grouping delimiter passed to list_objects_v2 (default: '/').
        aws_profile (str, optional): Named AWS profile; when falsy, boto3's default client is used.

    Returns:
        list: Sorted matching prefixes, or [] if S3 raised a ClientError (the error is printed).
    """
    s3_client = (
        boto3.Session(profile_name=aws_profile).client('s3')
        if aws_profile
        else boto3.client('s3')
    )

    try:
        # The paginator handles buckets with more than 1000 common prefixes.
        pages = s3_client.get_paginator('list_objects_v2').paginate(
            Bucket=bucket_name,
            Prefix='SmartSPIM_',
            Delimiter=delimiter,
        )
        found = [
            common['Prefix']
            for page in pages
            for common in page.get('CommonPrefixes', [])
            if 'stitched' in common['Prefix']
        ]
    except ClientError as e:
        print(f"Error accessing S3 bucket: {e}")
        return []

    return sorted(found)

def list_smartspim_atlas_alignments_prefixes(bucket_name, delimiter='/', aws_profile=None):
    """
    Find atlas-alignment sub-prefixes directly under SmartSPIM stitched prefixes:
    SmartSPIM_*stitched*/image_atlas_alignment/

    Args:
        bucket_name (str): Name of the S3 bucket
        delimiter (str): Delimiter to use for prefix grouping (default: '/')
        aws_profile (str, optional): AWS profile name to use

    Returns:
        list: Sorted ``image_atlas_alignment`` prefixes. A stitched prefix whose listing raises
        a ClientError is reported and skipped, so results from the other prefixes are kept.
    """
    if aws_profile:
        s3_client = boto3.Session(profile_name=aws_profile).client('s3')
    else:
        s3_client = boto3.client('s3')

    # Step 1: all SmartSPIM stitched prefixes (this helper prints S3 errors itself and returns []).
    stitched_prefixes = list_smartspim_stitched_prefixes(bucket_name, delimiter, aws_profile)

    # Step 2: keep the immediate child prefix named image_atlas_alignment, if present.
    target_suffix = f'image_atlas_alignment{delimiter}'
    paginator = s3_client.get_paginator('list_objects_v2')
    atlas_prefixes = []

    for stitched_prefix in tqdm(stitched_prefixes):
        try:
            page_iterator = paginator.paginate(
                Bucket=bucket_name,
                Prefix=stitched_prefix,
                Delimiter=delimiter
            )
            for page in page_iterator:
                for prefix_info in page.get('CommonPrefixes', []):
                    prefix = prefix_info['Prefix']
                    if prefix.endswith(target_suffix):
                        atlas_prefixes.append(prefix)
        except ClientError as e:
            # Skip only this prefix; discarding everything after a long scan loses too much work.
            tqdm.write(f"Error accessing S3 prefix {stitched_prefix}: {e}")

    return sorted(atlas_prefixes)


def check_single_prefix(args):
    """
    Look for objects whose key ends with ``filename`` anywhere under a single prefix.

    Intended to be submitted to a ThreadPoolExecutor, so it never raises: any failure,
    including building the S3 client, is returned as an error string.

    Args:
        args (tuple): (bucket_name, prefix, filename, aws_profile). aws_profile may be None,
            in which case boto3's default credential chain (e.g. AWS_PROFILE) is used.

    Returns:
        tuple: (prefix, keys), where keys is a list of matching object keys (possibly empty),
        or (prefix, "Error: <exception>") if anything failed.
    """
    bucket_name, prefix, filename, aws_profile = args

    try:
        # Always build a private Session per call: boto3.client() goes through the shared
        # default session, which boto3 documents as unsafe to use from multiple threads.
        session = boto3.Session(profile_name=aws_profile)
        s3_client = session.client('s3')

        # Search for the file anywhere under this prefix (no delimiter -> recursive listing).
        paginator = s3_client.get_paginator('list_objects_v2')
        file_locations = [
            obj['Key']
            for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix)
            for obj in page.get('Contents', [])
            if obj['Key'].endswith(filename)
        ]
        return prefix, file_locations

    except Exception as e:
        return prefix, f"Error: {e}"

def check_file_in_prefixes(bucket_name, prefixes, filename, aws_profile=None, max_workers=20):
    """
    Check which prefixes contain a specific file anywhere under the prefix.
    Each prefix is checked concurrently via check_single_prefix.

    Args:
        bucket_name (str): Name of the S3 bucket
        prefixes (list): List of S3 prefixes to check
        filename (str): Name of the file to look for (matched as a key suffix)
        aws_profile (str, optional): AWS profile name to use
        max_workers (int): Maximum number of concurrent threads (default: 20)

    Returns:
        list: Sorted prefixes that contain at least one matching object. Sorting matters because
        as_completed yields in completion order, which otherwise differs from run to run.
    """
    print(f"Checking {len(prefixes)} prefixes using {max_workers} concurrent workers...")

    prefixes_with_file = []
    check_args = [(bucket_name, prefix, filename, aws_profile) for prefix in prefixes]

    with tqdm(total=len(check_args), desc="Checking prefixes", unit="prefix") as pbar:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(check_single_prefix, args) for args in check_args]

            for future in as_completed(futures):
                prefix, result = future.result()

                if isinstance(result, list):
                    if result:  # at least one matching key
                        prefixes_with_file.append(prefix)
                else:
                    # check_single_prefix reports failures as an error string instead of raising.
                    tqdm.write(f"❌ Error in {prefix}: {result}")

                pbar.update(1)

    return sorted(prefixes_with_file)


In [4]:
# Every stitched SmartSPIM acquisition in the bucket.
all_experiments = list_smartspim_stitched_prefixes(
    aws_profile='533267145346_ProductionRW',
    bucket_name='aind-open-data',
)

In [5]:
# Stitched acquisitions that have an image_atlas_alignment/ sub-prefix.
experiments = list_smartspim_atlas_alignments_prefixes(
    aws_profile='533267145346_ProductionRW',
    bucket_name='aind-open-data',
)

100%|██████████| 3022/3022 [05:11<00:00,  9.71it/s]


In [14]:
# Alignment prefixes that contain the SyN inverse warp somewhere beneath them.
experiments_w_ls_inverse = check_file_in_prefixes(
    bucket_name='aind-open-data',
    prefixes=experiments,
    filename='ls_to_template_SyN_1InverseWarp.nii.gz',
    aws_profile='533267145346_ProductionRW',
)

Checking 2552 prefixes using 20 concurrent workers...


Checking prefixes: 100%|██████████| 2552/2552 [53:43<00:00,  1.26s/prefix]  


In [26]:
# Fraction of distinct ids (token after 'SmartSPIM_') that have an inverse warp.
len({prefix.split('_')[1] for prefix in experiments_w_ls_inverse}) / len({prefix.split('_')[1] for prefix in all_experiments})

0.6231536926147705

In [25]:
any(prefix.split('_')[1] == '000393' for prefix in all_experiments)

True

In [20]:
import json
# Cache the slow (~1h) S3 scan result so later cells can reload it.
with open('/Users/adam.amster/smartspim-registration/prefixes_with_inverse_warp.json', 'w') as f:
    json.dump(experiments_w_ls_inverse, f, indent=2)

In [10]:
import json
import pandas as pd

with open('/Users/adam.amster/smartspim-registration/prefixes_with_inverse_warp.json') as f:
    experiments_w_ls_inverse = json.load(f)

# Top-level stitched directory of each alignment prefix.
base_paths = [prefix.split('/')[0] for prefix in experiments_w_ls_inverse]

# Parse the id (2nd '_' token) and the stitch timestamp (last two '_' tokens), then keep the
# first row per experiment_id after an ascending sort, i.e. the earliest stitch.
# NOTE(review): confirm the earliest (not latest) stitched run is the one wanted.
experiments = (
    pd.DataFrame({'base_path': base_paths})
    .assign(
        experiment_id=lambda df: df['base_path'].str.split('_').str[1].astype(str),
        stitched_datetime=lambda df: pd.to_datetime(
            df['base_path'].str.split('_').str[-2:].str.join('_'),
            format='%Y-%m-%d_%H-%M-%S',
        ),
    )
    .sort_values(by='stitched_datetime', ascending=True)
    .drop_duplicates(subset='experiment_id')
)

In [11]:
def get_alignment_artifacts(experiments: pd.DataFrame,
                            aws_profile: str = '533267145346_ProductionRW',
                            bucket: str = 'aind-open-data'):
    """
    Resolve, for each experiment, the S3 prefix holding its atlas-alignment artifacts.

    The registration channel is read from the acquisition's processing_manifest.json. When the
    manifest lists several channels, the single channel whose alignment directory contains the
    inverse warp is chosen. The resolved prefix is kept only if both
    ``ls_to_template_SyN_1InverseWarp.nii.gz`` and ``ls_to_template_rigid_0GenericAffine.mat`` exist.

    Args:
        experiments: Frame with a ``base_path`` column of stitched prefixes shaped like
            ``<acquisition>_stitched_<date>_<time>``.
        aws_profile: AWS profile used to build the S3 client.
        bucket: Bucket holding both the acquisition and the stitched data.

    Returns:
        tuple: (registration_artifacts_prefixes, errors). The first list is aligned with the rows
        of ``experiments`` and holds ``<base_path>/image_atlas_alignment/<channel>``, or None when
        the artifacts could not be resolved. ``errors`` lists the stitched base paths that failed.

    Raises:
        RuntimeError: No processing_manifest.json in either known location.
        ValueError: The manifest lists no registration channel.
    """
    session = boto3.Session(profile_name=aws_profile, region_name='us-west-2')
    s3 = session.client('s3')

    def object_exists(key):
        # head_object raises ClientError (e.g. 404/403) when the key cannot be read.
        try:
            s3.head_object(Bucket=bucket, Key=key)
        except ClientError:
            return False
        return True

    registration_artifacts_prefixes = []
    errors = []

    for row in tqdm(experiments.itertuples(), total=experiments.shape[0]):
        stitched_path = row.base_path
        # Drop the trailing '_stitched_<date>_<time>' to get the acquisition prefix.
        base_path = '_'.join(stitched_path.split('_')[:-3])

        manifest_candidates = (
            f'{base_path}/SPIM/derivatives/processing_manifest.json',
            f'{base_path}/derivatives/processing_manifest.json',
        )
        processing_manifest_path = next((key for key in manifest_candidates if object_exists(key)), None)
        if processing_manifest_path is None:
            raise RuntimeError(f'could not find processing_manifest for {base_path}')
        obj = s3.get_object(Bucket=bucket, Key=processing_manifest_path)
        processing_manifest = json.loads(obj['Body'].read())

        registration_channels = processing_manifest['pipeline_processing']['registration']['channels']

        if len(registration_channels) == 0:
            raise ValueError(f'expected 1 channel {base_path}')
        if len(registration_channels) == 1:
            registration_channel = registration_channels[0]
        else:
            # Probe a concrete artifact file per channel. S3 has no real directories, so a
            # head_object on the bare '.../image_atlas_alignment/<channel>' key typically 404s
            # (unless a directory-marker object exists) even when the channel has output.
            channels_with_artifacts = [
                channel for channel in registration_channels
                if object_exists(f'{stitched_path}/image_atlas_alignment/{channel}/ls_to_template_SyN_1InverseWarp.nii.gz')
            ]
            if len(channels_with_artifacts) != 1:
                errors.append(stitched_path)
                print(f'expected 1 channel to have alignment but found {len(channels_with_artifacts)} ({stitched_path})')
                registration_artifacts_prefixes.append(None)
                continue
            registration_channel = channels_with_artifacts[0]

        registration_artifacts_prefix = f'{stitched_path}/image_atlas_alignment/{registration_channel}'

        # Both the inverse warp and the rigid affine are needed downstream.
        required_files = ('ls_to_template_SyN_1InverseWarp.nii.gz', 'ls_to_template_rigid_0GenericAffine.mat')
        if not all(object_exists(f'{registration_artifacts_prefix}/{name}') for name in required_files):
            # Record the stitched path, consistent with the multi-channel error branch above.
            errors.append(stitched_path)
            print(f'Could not find registration artifacts {registration_artifacts_prefix}')
            registration_artifacts_prefix = None
        registration_artifacts_prefixes.append(registration_artifacts_prefix)

    return registration_artifacts_prefixes, errors

In [12]:
# One entry per row of `experiments` (None where the artifacts could not be resolved).
image_atlas_alignment_paths, errors = get_alignment_artifacts(experiments)

 11%|█         | 170/1561 [01:25<11:03,  2.10it/s]

Could not find registration artifacts SmartSPIM_719360_2024-06-28_14-08-20_stitched_2024-07-10_01-02-48/image_atlas_alignment/Ex_639_Em_667


 25%|██▍       | 389/1561 [03:15<09:41,  2.02it/s]

expected 1 channel to have alignment but found 0


100%|██████████| 1561/1561 [11:33<00:00,  2.25it/s]


In [13]:
# Attach the resolved alignment prefix and drop experiments where it could not be found.
experiments = (
    experiments
    .assign(image_atlas_alignment_path=image_atlas_alignment_paths)
    .dropna(subset=['image_atlas_alignment_path'])
)

In [14]:
# Downstream cells refer to the stitched prefix as `stitched_path`.
# To read from a local /data mount instead of S3, prefix both `stitched_path` and
# `image_atlas_alignment_path` with '/data/'.
experiments = experiments.rename(mapper={'base_path': 'stitched_path'}, axis='columns')

In [17]:
def get_processing_manifest(s3, base_path):
    """
    Return the S3 key of processing_manifest.json for an acquisition prefix.

    Tries ``<base_path>/SPIM/derivatives/`` first, then ``<base_path>/derivatives/``, in the
    aind-open-data bucket.

    Raises:
        RuntimeError: If neither candidate key can be read with head_object.
    """
    candidate_keys = (
        f'{base_path}/SPIM/derivatives/processing_manifest.json',
        f'{base_path}/derivatives/processing_manifest.json',
    )
    for candidate in candidate_keys:
        try:
            s3.head_object(Bucket='aind-open-data', Key=candidate)
        except ClientError:
            continue
        return candidate
    raise RuntimeError(f'could not find processing_manifest for {base_path}')

In [15]:
from zarr.errors import GroupNotFoundError
from pathlib import Path
import zarr
import boto3

def get_volume_shapes(experiments: pd.DataFrame):
    """
    For each experiment, read the atlas-alignment parameters and the fused-volume shape at the
    zarr resolution level that was used for registration.

    Args:
        experiments: Frame with ``stitched_path`` and ``image_atlas_alignment_path`` columns. The
            registration channel is the last component of ``image_atlas_alignment_path``.

    Returns:
        tuple of four lists aligned with the rows of ``experiments``:
        (shapes, volume_paths, registration_downsamples, alignment_resolutions).
        ``shapes`` holds ``array.shape[2:]`` of the registration-level zarr array (presumably the
        spatial dims of a 5-D OME-Zarr array; TODO confirm). ``alignment_resolutions`` holds the
        alignment spacing converted from millimetres to micrometres. All four entries are None
        where the OME-Zarr group could not be opened.

    Raises:
        ValueError: processing.json has no 'Image atlas alignment' step, or its unit is not
            'millimetre'.
    """
    s3_client = boto3.client('s3')

    shapes = []
    volume_paths = []
    registration_downsamples = []
    alignment_resolutions = []

    for row in tqdm(experiments.itertuples(), total=experiments.shape[0]):
        alignment_dir = Path(row.image_atlas_alignment_path)
        registration_channel = alignment_dir.name

        # Alignment parameters live in image_atlas_alignment/<channel>/metadata/processing.json.
        registration_processing_path = alignment_dir / 'metadata' / 'processing.json'
        response = s3_client.get_object(Bucket='aind-open-data', Key=str(registration_processing_path))
        registration_processing = json.loads(response['Body'].read())
        alignment_spacing_um, alignment_downsample_factor = _parse_alignment_step(registration_processing)

        volume_path = Path(row.stitched_path) / 'image_tile_fusing' / 'OMEZarr' / f'{registration_channel}.zarr'

        try:
            volume = zarr.open(f's3://aind-open-data/{volume_path}', storage_options={'anon': True}, mode='r')
        except GroupNotFoundError:
            print(f'cannot load zarr group {volume_path}')
            shapes.append(None)
            volume_paths.append(None)
            registration_downsamples.append(None)
            alignment_resolutions.append(None)
            continue

        # The zarr level key is the downsample factor taken from the alignment input location.
        shapes.append(volume[str(alignment_downsample_factor)].shape[2:])
        volume_paths.append(volume_path)
        registration_downsamples.append(alignment_downsample_factor)
        alignment_resolutions.append(alignment_spacing_um)

    return shapes, volume_paths, registration_downsamples, alignment_resolutions


def _parse_alignment_step(registration_processing: dict):
    """Return (spacing in micrometres, zarr level) from the 'Image atlas alignment' step of a processing.json dict."""
    alignment_step = next(
        (step for step in registration_processing['processing_pipeline']['data_processes']
         if step['name'] == 'Image atlas alignment'),
        None,
    )
    if alignment_step is None:
        raise ValueError("no 'Image atlas alignment' step in processing.json")

    parameters = alignment_step['parameters']
    unit = parameters['unit']
    if unit != 'millimetre':
        # Bind to a local first: nesting same-type quotes inside an f-string needs Python 3.12+.
        raise ValueError(f'unexpected unit {unit}')
    spacing_um = [x * 1e3 for x in parameters['spacing']]

    # The last component of input_location (e.g. '.../3') names the zarr resolution level.
    downsample_factor = int(Path(alignment_step['input_location']).name)
    return spacing_um, downsample_factor

In [16]:
# Four lists aligned with `experiments` rows (None where the OME-Zarr could not be opened).
shapes, volume_paths, registration_downsamples, registration_resolutions = get_volume_shapes(experiments)

  0%|          | 5/1559 [00:04<15:23,  1.68it/s]  

cannot load zarr group SmartSPIM_AK031_2023-07-11_16-27-40_stitched_2023-07-29_03-22-11/image_tile_fusing/OMEZarr/Ex_639_Em_660.zarr


  2%|▏         | 28/1559 [00:15<10:47,  2.36it/s]

cannot load zarr group SmartSPIM_AK030_2023-07-10_23-22-02_stitched_2023-09-27_01-28-29/image_tile_fusing/OMEZarr/Ex_639_Em_660.zarr


  3%|▎         | 54/1559 [00:27<10:24,  2.41it/s]

cannot load zarr group SmartSPIM_AK029_2023-11-01_12-32-06_stitched_2023-11-22_03-30-32/image_tile_fusing/OMEZarr/Ex_639_Em_667.zarr


  4%|▎         | 55/1559 [00:28<09:26,  2.65it/s]

cannot load zarr group SmartSPIM_692913_2023-10-23_21-38-08_stitched_2023-11-22_03-42-37/image_tile_fusing/OMEZarr/Ex_639_Em_667.zarr


 47%|████▋     | 726/1559 [06:00<06:25,  2.16it/s]

cannot load zarr group SmartSPIM_660851_2023-04-03_16-25-48_stitched_2025-01-17_00-58-31/image_tile_fusing/OMEZarr/Ex_639_Em_667.zarr


100%|██████████| 1559/1559 [12:59<00:00,  2.00it/s]


In [None]:
# Attach per-experiment volume info and drop rows whose fused volume could not be opened.
experiments = experiments.assign(
    registration_shape=shapes,
    volume_path=volume_paths,
    registration_downsample=registration_downsamples,
    registration_resolution=registration_resolutions,
).dropna(subset=['volume_path'])

In [20]:
pd.unique(pd.Series(registration_downsamples))

array([ 3., nan])

In [26]:
# Distinct alignment spacings (µm). NOTE(review): values differ in axis order
# ((14.4, 14.4, 16.0) vs (16.0, 14.4, 14.4)); confirm that axis order is handled downstream.
experiments['registration_resolution'].map(tuple).unique()

array([(14.4, 14.4, 16.0), (16.0, 14.4, 14.4)], dtype=object)

In [117]:
# Distinct registration-level volume shapes. The column is `registration_shape`; there is no `shape` column.
experiments['registration_shape'].unique()

array([(480, 1279, 927), (502, 1280, 930), (482, 1279, 928),
       (454, 1100, 927), (467, 1281, 928), (502, 1278, 929),
       (573, 1101, 927), (471, 1281, 929), (501, 1281, 927),
       (468, 1279, 928), (513, 1278, 927), (564, 1279, 928),
       (508, 1279, 928), (538, 1280, 928), (538, 1282, 930),
       (497, 1280, 929), (593, 1280, 929), (514, 1099, 928),
       (463, 1279, 928), (486, 1279, 928), (564, 1100, 928),
       (496, 1280, 929), (518, 1280, 929), (540, 1279, 929),
       (527, 1282, 929), (617, 1282, 930), (475, 1279, 929),
       (546, 1280, 930), (463, 1282, 929), (440, 1281, 928),
       (517, 1282, 930), (539, 1281, 928), (443, 1282, 929),
       (446, 1282, 929), (509, 1282, 930), (472, 1282, 930),
       (516, 1280, 931), (536, 1281, 929), (495, 1282, 930),
       (534, 1281, 930), (426, 1108, 934), (450, 1108, 933),
       (478, 1108, 934), (491, 1284, 928), (499, 1282, 929),
       (516, 1108, 933), (461, 1108, 933), (466, 1101, 927),
       (403, 1100, 927),

In [18]:
def get_axes(s3_client, base_path: str):
    """
    Read the acquisition axes from ``<base_path>/acquisition.json`` in the aind-open-data bucket.

    Returns:
        list[AcquisitionAxis]: One entry per element of the JSON ``axes`` array.
    """
    key = str(Path(base_path) / 'acquisition.json')
    payload = s3_client.get_object(Bucket='aind-open-data', Key=key)['Body'].read()
    return [AcquisitionAxis(**axis) for axis in json.loads(payload)['axes']]

In [25]:
import boto3
s3_client = boto3.client('s3')

# One SubjectMetadata record per experiment that survived the filtering above.
# `experiments` has no `subject_id` column; the id parsed from the stitched prefix lives in
# `experiment_id` (presumably the subject id, e.g. '000393'; NOTE(review): confirm).
dataset_meta = [
    SubjectMetadata(
        subject_id=row.experiment_id,
        stitched_volume_path=row.volume_path,
        axes=get_axes(s3_client=s3_client, base_path=row.stitched_path),
        registered_shape=row.registration_shape,
        registered_resolution=row.registration_resolution,
        ls_to_template_inverse_warp_path=Path(row.image_atlas_alignment_path) / 'ls_to_template_SyN_1InverseWarp.nii.gz',
        ls_to_template_affine_matrix_path=Path(row.image_atlas_alignment_path) / 'ls_to_template_rigid_0GenericAffine.mat',
    )
    for row in experiments.itertuples()
]

In [32]:
# Round-trip through model_dump_json so the pydantic models' JSON-mode serialization is used.
with open('/Users/adam.amster/smartspim-registration/dataset_meta.json', 'w') as f:
    json.dump([json.loads(meta.model_dump_json()) for meta in dataset_meta], f, indent=2)