In [None]:
# ============================================================
# BATCHED EXTRACTION: Process subjects in batches
# Upload to OneDrive between batches to manage space
# ============================================================
# 
# CONFIGURATION:
# - Change BATCH_NUMBER before each run (1, 2, 3, ...)
# - Each batch processes SUBJECTS_PER_BATCH (~10) subjects (~5,500 volumes)
# - 62 subjects / 10 per batch = 7 batches total (for 4 classes)
# - For all 13 classes: ~16 batches (lower SUBJECTS_PER_BATCH to ~4)
#
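# BATCH PLAN (4 classes only; batches are slices of the sorted subject list,
# so the IDs below are approximate if some subject numbers are missing):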
# Batch 1: sub-01 to sub-10
# Batch 2: sub-11 to sub-20  
# Batch 3: sub-21 to sub-30
# Batch 4: sub-31 to sub-40
# Batch 5: sub-41 to sub-50
# Batch 6: sub-51 to sub-60
# Batch 7: sub-61 to sub-68 (remaining subjects)
# ============================================================

# ==================== CONFIGURATION ====================
# CHANGE THIS FOR EACH RUN!
# 1-based index of the slice of the sorted subject list to process (1, 2, ..., 7).
BATCH_NUMBER: int = 7

# Number of subjects in each batch (adjust if needed).
SUBJECTS_PER_BATCH: int = 10

# True  -> keep only the 4 limb classes (trial codes 3-6) to save disk space.
# False -> keep all 13 trial types.
EXTRACT_4_CLASSES_ONLY: bool = True
# ========================================================

# ==================== CELL 1: Setup ====================
!pip install -q boto3 nibabel tqdm

import os
import numpy as np
import nibabel as nib
from scipy.ndimage import zoom
from pathlib import Path
import pandas as pd
from tqdm.notebook import tqdm
import boto3
from botocore import UNSIGNED
from botocore.config import Config
import warnings
import shutil
warnings.filterwarnings('ignore')

# S3 client: unsigned (anonymous) requests against OpenNeuro's public bucket
s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
BUCKET = 'openneuro.org'
DATASET = 'ds004044'

# Paths: RAW_DIR holds per-run scratch downloads; each batch writes to its own folder.
RAW_DIR = Path('/kaggle/working/raw_data')
OUTPUT_DIR = Path(f'/kaggle/working/batch_{BATCH_NUMBER:02d}')
# parents=True so the notebook also runs where /kaggle/working does not already exist.
RAW_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Every extracted volume is resampled to this grid.
TARGET_SHAPE = (100, 100, 100)
# Repetition time: event onsets/durations are divided by this to get volume indices.
# NOTE(review): assumed to be seconds and to match the dataset's acquisition TR — confirm.
TR = 2.0

# Trial type mapping: integer trial_type code in the events file -> class folder name
TRIAL_TYPE_MAP = {
    0: 'Rest',
    1: 'Toe movements',
    2: 'Ankle movements',
    3: 'Left leg movements',
    4: 'Right leg movements',
    5: 'Forearm movements',
    6: 'Upper arm movements',
    7: 'Wrist movements',
    8: 'Finger movements',
    9: 'Eye movements',
    10: 'Jaw movements',
    11: 'Lip movements',
    12: 'Tongue movements',
}

# Classes to extract (None means "every code in TRIAL_TYPE_MAP")
if EXTRACT_4_CLASSES_ONLY:
    CLASSES_TO_EXTRACT = {3, 4, 5, 6}
    print("Extracting 4 classes: Left leg, Right leg, Forearm, Upper arm")
else:
    CLASSES_TO_EXTRACT = None  # All classes
    print("Extracting ALL 13 classes")

# Create one output directory per selected class
for code, class_name in TRIAL_TYPE_MAP.items():
    if CLASSES_TO_EXTRACT is None or code in CLASSES_TO_EXTRACT:
        (OUTPUT_DIR / class_name).mkdir(exist_ok=True)

print(f"\n{'='*60}")
print(f"BATCH {BATCH_NUMBER}")
print(f"{'='*60}")

# ==================== CELL 2: Get subject list for this batch ====================

def list_all_subjects():
    """Return the sorted, de-duplicated ``sub-*`` folder names under the fmriprep derivatives.

    The bucket is listed with ``Delimiter='/'``, so only the immediate child
    "folders" (CommonPrefixes) of ``{DATASET}/derivatives/fmriprep/`` are seen.
    No check is made here that a subject actually has denoised files.
    """
    root = f'{DATASET}/derivatives/fmriprep/'
    found = set()
    pages = s3.get_paginator('list_objects_v2').paginate(Bucket=BUCKET, Prefix=root, Delimiter='/')
    for page in pages:
        for entry in page.get('CommonPrefixes', []):
            # 'ds004044/derivatives/fmriprep/sub-01/' -> ['ds004044', 'derivatives', 'fmriprep', 'sub-01']
            segments = entry['Prefix'].rstrip('/').split('/')
            if len(segments) < 4:
                continue
            candidate = segments[3]
            if candidate.startswith('sub-'):
                found.add(candidate)
    return sorted(found)

print("Listing all subjects...")
all_subjects = list_all_subjects()
print(f"Total subjects in dataset: {len(all_subjects)}")

# Fail loudly on a bad configuration or an empty listing. Otherwise both cases
# produce an empty slice below and are misreported as "ALL DONE".
if BATCH_NUMBER < 1 or SUBJECTS_PER_BATCH < 1:
    raise ValueError(
        f"BATCH_NUMBER and SUBJECTS_PER_BATCH must be >= 1 "
        f"(got {BATCH_NUMBER} and {SUBJECTS_PER_BATCH})"
    )
if not all_subjects:
    raise RuntimeError(
        f"No subjects found under s3://{BUCKET}/{DATASET}/derivatives/fmriprep/ - "
        f"check the S3 listing before running batches"
    )

# Calculate batch range: BATCH_NUMBER is a 1-based slice of the sorted subject list
start_idx = (BATCH_NUMBER - 1) * SUBJECTS_PER_BATCH
end_idx = min(start_idx + SUBJECTS_PER_BATCH, len(all_subjects))

batch_subjects = all_subjects[start_idx:end_idx]

# Check before printing the range, which is meaningless for an empty batch.
if not batch_subjects:
    print("\n*** NO SUBJECTS IN THIS BATCH - ALL DONE! ***")
    raise SystemExit("Batch complete")

print(f"\nBatch {BATCH_NUMBER}: Subjects {start_idx + 1} to {end_idx}")
print(f"Processing: {batch_subjects}")

# ==================== CELL 3: Helper Functions ====================

def list_subject_files(subject_id):
    """Return the sorted S3 keys of a subject's denoised NIfTI files.

    No ``Delimiter`` is passed, so every key under
    ``{DATASET}/derivatives/fmriprep/{subject_id}/`` is considered, including
    keys in nested folders. A key is kept when it contains ``'denoised'`` and
    ends with ``'.nii.gz'``.
    """
    listing = s3.get_paginator('list_objects_v2').paginate(
        Bucket=BUCKET,
        Prefix=f'{DATASET}/derivatives/fmriprep/{subject_id}/',
    )
    all_keys = (obj['Key'] for page in listing for obj in page.get('Contents', []))
    return sorted(
        key for key in all_keys
        if 'denoised' in key and key.endswith('.nii.gz')
    )

def get_events_key(denoised_key):
    """Map a denoised fMRIPrep key to the raw BIDS events.tsv key for the same run.

    The subject is read from the 4th path component
    (``{DATASET}/derivatives/fmriprep/<subject>/...``). The run number is read
    from the ``_run-<N>`` entity in the file's basename. Session ``ses-1`` and
    task ``motor`` are hard-coded in the returned key.

    Args:
        denoised_key: S3 key of a denoised NIfTI under the fmriprep derivatives.

    Returns:
        S3 key ``{DATASET}/<subject>/ses-1/func/<subject>_ses-1_task-motor_run-<NN>_events.tsv``.

    Raises:
        ValueError: if the key is too short, or the basename has no numeric ``_run-`` entity.
    """
    parts = denoised_key.split('/')
    if len(parts) < 5:
        raise ValueError(f"Unexpected denoised key layout: {denoised_key!r}")
    subject = parts[3]
    # Use the basename rather than parts[4]: list_subject_files lists recursively,
    # so the file may sit in a sub-folder below the subject directory.
    filename = parts[-1]
    if '_run-' not in filename:
        raise ValueError(f"No '_run-' entity in denoised file name: {filename!r}")
    run_part = filename.split('_run-', 1)[1].split('_')[0]
    run_num = int(run_part)  # ValueError if the run label is not numeric
    return f"{DATASET}/{subject}/ses-1/func/{subject}_ses-1_task-motor_run-{run_num:02d}_events.tsv"

def download_file(s3_key, local_path):
    """Fetch ``s3_key`` from the bucket into ``local_path`` unless that file already exists.

    The parent directory is created first. An existing local file is treated as
    already downloaded and is reused as-is.

    Returns:
        True if the file is present afterwards, False if the download raised
        (the error is printed, not re-raised).
    """
    target = Path(local_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    if target.exists():
        return True
    try:
        s3.download_file(BUCKET, s3_key, str(target))
    except Exception as err:
        print(f"    Error: {err}")
        return False
    return True

def preprocess_volume(volume_3d, target_shape=None):
    """Resample a 3D volume to a fixed grid and z-score it.

    Args:
        volume_3d: 3D array for a single time point.
        target_shape: output shape; defaults to the module-level ``TARGET_SHAPE``
            (resolved at call time so existing callers are unaffected).

    Returns:
        float32 array of shape ``target_shape`` with zero mean and unit standard
        deviation. A constant volume (std == 0) is only mean-centred, which
        avoids division by zero.
    """
    if target_shape is None:
        target_shape = TARGET_SHAPE
    zoom_factors = [t / s for t, s in zip(target_shape, volume_3d.shape)]
    # order=1 -> linear (trilinear in 3D) interpolation
    resized = zoom(volume_3d, zoom_factors, order=1)
    mean, std = resized.mean(), resized.std()
    normalized = (resized - mean) / std if std > 0 else resized - mean
    return normalized.astype(np.float32)

def extract_volumes(denoised_path, events_path, subject_id, run_id):
    """Save every event-labelled volume of one run as its own preprocessed 3D NIfTI.

    For each row of the tab-separated events file, selected by ``trial_type``
    through ``CLASSES_TO_EXTRACT``, the volume indices covering
    ``[onset, onset + duration)`` (divided by ``TR``) are processed. Each volume
    goes through ``preprocess_volume`` and is written to
    ``OUTPUT_DIR/<class name>/<subject_id>_run-<run_id>_vol-<idx>.nii.gz``.
    Malformed rows (non-numeric trial_type/onset/duration) and codes missing
    from ``TRIAL_TYPE_MAP`` are skipped with a message, so the rest of the run
    is still extracted.

    Args:
        denoised_path: local path to the 4D denoised NIfTI for the run.
        events_path: local path to the matching events ``.tsv``.
        subject_id: subject label used in output file names.
        run_id: run label used in output file names.

    Returns:
        dict mapping class name -> number of volumes written for this run.

    Raises:
        ValueError: if the image is not 4D.
    """
    events = pd.read_csv(events_path, sep='\t')

    img = nib.load(denoised_path)
    if len(img.shape) != 4:
        raise ValueError(f"Expected a 4D image, got shape {img.shape} for {denoised_path}")
    n_volumes = img.shape[3]
    data_4d = None  # decoded lazily: runs without selected events never call get_fdata()

    counts = {}
    written = set()  # (class_name, vol_idx) already saved for this run

    rows = events[['onset', 'duration', 'trial_type']].itertuples(index=False, name=None)
    for onset, duration, raw_type in rows:
        try:
            trial_type = int(raw_type)
        except (TypeError, ValueError):
            # e.g. NaN / 'n/a' cells: skip this row instead of aborting the whole run
            print(f"    {subject_id} run-{run_id}: skipping event with trial_type={raw_type!r}")
            continue

        if CLASSES_TO_EXTRACT is not None and trial_type not in CLASSES_TO_EXTRACT:
            continue

        class_name = TRIAL_TYPE_MAP.get(trial_type)
        if class_name is None:
            print(f"    {subject_id} run-{run_id}: unknown trial_type {trial_type}, skipping")
            continue

        try:
            # NOTE(review): volumes are labelled from stimulus onset with no hemodynamic lag — confirm intended.
            # Clamp at 0: a negative onset would otherwise give a negative index,
            # which NumPy wraps around to the END of the run.
            start_vol = max(0, int(onset / TR))
            end_vol = min(int((onset + duration) / TR), n_volumes)
        except (TypeError, ValueError):
            print(f"    {subject_id} run-{run_id}: skipping event with onset={onset!r}, duration={duration!r}")
            continue

        for vol_idx in range(start_vol, end_vol):
            if (class_name, vol_idx) in written:
                continue  # overlapping events of the same class: save and count each volume once
            if data_4d is None:
                data_4d = img.get_fdata()
            processed = preprocess_volume(data_4d[:, :, :, vol_idx])

            out_filename = f"{subject_id}_run-{run_id}_vol-{vol_idx:03d}.nii.gz"
            out_path = OUTPUT_DIR / class_name / out_filename

            # Identity affine: the resampled grid no longer matches the source's spatial mapping.
            nib.save(nib.Nifti1Image(processed, np.eye(4)), str(out_path))
            written.add((class_name, vol_idx))
            counts[class_name] = counts.get(class_name, 0) + 1

    return counts

# ==================== CELL 4: Process this batch ====================
print(f"\n{'='*60}")
print(f"PROCESSING BATCH {BATCH_NUMBER}")
print(f"{'='*60}")

results = []       # (subject_id, volumes extracted for that subject)
total_counts = {}  # class name -> volumes extracted in this batch

for subject_id in tqdm(batch_subjects, desc=f"Batch {BATCH_NUMBER}"):
    subject_files = list_subject_files(subject_id)
    subject_counts = {}

    for denoised_key in subject_files:
        # A key that does not follow the expected naming skips this run instead of
        # crashing the whole batch.
        try:
            events_key = get_events_key(denoised_key)
            run_id = denoised_key.split('_run-')[1].split('_')[0]
        except (IndexError, ValueError) as e:
            print(f"  {subject_id}: cannot parse run from {denoised_key} - {e}")
            continue

        local_denoised = RAW_DIR / f"{subject_id}_run-{run_id}_denoised.nii.gz"
        local_events = RAW_DIR / f"{subject_id}_run-{run_id}_events.tsv"

        try:
            if not download_file(denoised_key, local_denoised):
                continue
            if not download_file(events_key, local_events):
                continue
            try:
                counts = extract_volumes(str(local_denoised), str(local_events), subject_id, run_id)
            except Exception as e:
                print(f"  {subject_id} run-{run_id}: ERROR - {e}")
                continue
            for cls, cnt in counts.items():
                subject_counts[cls] = subject_counts.get(cls, 0) + cnt
                total_counts[cls] = total_counts.get(cls, 0) + cnt
        finally:
            # Clean up immediately on every path, so raw downloads from at most one
            # run are on disk at a time.
            for raw_path in (local_denoised, local_events):
                if raw_path.exists():
                    raw_path.unlink()

    results.append((subject_id, sum(subject_counts.values())))

# ==================== CELL 5: Summary ====================
print(f"\n{'='*60}")
print(f"BATCH {BATCH_NUMBER} COMPLETE!")
print(f"{'='*60}")

total_volumes = sum(n_volumes for _, n_volumes in results)
print(f"\nSubjects processed: {len(results)}")
print(f"Total volumes extracted: {total_volumes:,}")

print("\nClass distribution:")
for class_name in sorted(total_counts):
    print(f"  {class_name}: {total_counts[class_name]:,}")

print(f"\nOutput saved to: {OUTPUT_DIR}")

# Show disk usage. `du` may be missing (FileNotFoundError) or fail, so fall back
# to summing file sizes in Python rather than crashing at the very end of a run.
import subprocess
try:
    result = subprocess.run(['du', '-sh', str(OUTPUT_DIR)], capture_output=True, text=True)
    disk_usage = result.stdout.strip() if result.returncode == 0 else ''
except OSError:
    disk_usage = ''
if not disk_usage:
    n_bytes = sum(f.stat().st_size for f in OUTPUT_DIR.rglob('*') if f.is_file())
    disk_usage = f"{n_bytes / 1024 ** 2:,.1f}M\t{OUTPUT_DIR}"
print(f"Disk usage: {disk_usage}")

# ==================== CELL 6: Instructions ====================
# Computed first so the "next steps" text can tell whether this was the last batch.
total_batches = (len(all_subjects) + SUBJECTS_PER_BATCH - 1) // SUBJECTS_PER_BATCH

if BATCH_NUMBER < total_batches:
    next_steps = (
        "3. For next batch, create new notebook with:\n"
        f"   BATCH_NUMBER = {BATCH_NUMBER + 1}\n"
        "\n"
        "4. Repeat until all batches complete."
    )
else:
    next_steps = f"3. Batch {BATCH_NUMBER} was the last of {total_batches} - all batches complete."

print(f"\n{'='*60}")
print("NEXT STEPS:")
print(f"{'='*60}")
print(f"""
1. Download this batch's output from:
   {OUTPUT_DIR}

2. Upload to OneDrive folder:
   /thesis_data/batch_{BATCH_NUMBER:02d}/

{next_steps}

BATCH STATUS:
""")

for b in range(1, total_batches + 1):
    status = "✓ DONE" if b < BATCH_NUMBER else ("→ CURRENT" if b == BATCH_NUMBER else "  pending")
    start = (b - 1) * SUBJECTS_PER_BATCH
    end = min(start + SUBJECTS_PER_BATCH, len(all_subjects))
    # Use the IDs directly; rebuilding them from split('-')[1] truncates IDs with extra dashes.
    subjects_range = f"{all_subjects[start]} to {all_subjects[end - 1]}" if end > start else "none"
    print(f"  Batch {b}: {subjects_range} {status}")