In [None]:
import os
import io
import numpy as np
import albumentations as A
from google.cloud import storage
from sklearn.utils import shuffle
import gc

# --- Google Cloud Storage layout ---
# Bucket the .npy inputs are read from, and bucket the final dataset is written to.
SOURCE_BUCKET = "autism-processed-images"
DEST_BUCKET = "mri-final-dataset-test"
# Object-name prefixes used to list each class inside SOURCE_BUCKET.
AUTISTIC_PREFIX = "autistic_processed/"
CONTROL_PREFIX = "control_processed/"

# --- Batching / augmentation settings ---
BATCH_SIZE = 10  # autistic/control file pairs handled per batch
AUGMENT_RATIO = 3  # augmented copies generated per training sample

# --- GCS handles (one shared client, one handle per bucket) ---
client = storage.Client()
source_bucket, dest_bucket = (
    client.bucket(bucket_name) for bucket_name in (SOURCE_BUCKET, DEST_BUCKET)
)

# --- Albumentations pipeline ---
# Each transform fires independently with its own probability `p`.
augment = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.3),
        A.ShiftScaleRotate(
            shift_limit=0.05,
            scale_limit=0.05,
            rotate_limit=20,
            p=0.3,
        ),
    ]
)

def list_npy_files(prefix):
    """Return the names of all ``.npy`` objects in SOURCE_BUCKET under *prefix*.

    Objects whose name does not end in ``.npy`` are skipped.
    """
    listing = client.list_blobs(SOURCE_BUCKET, prefix=prefix)
    object_names = (entry.name for entry in listing)
    return [name for name in object_names if name.endswith(".npy")]

def download_npy(blob_name):
    """Download *blob_name* from the source bucket and load it with ``np.load``."""
    payload = source_bucket.blob(blob_name).download_as_bytes()
    # np.load needs a file-like object, so wrap the raw bytes in memory.
    stream = io.BytesIO(payload)
    return np.load(stream)

def upload_npy(array, folder, filename):
    """Serialize *array* with ``np.save`` and upload it to the destination bucket.

    The object is stored under ``<folder>/<filename>``; a confirmation line is
    printed once the upload call returns.
    """
    target = dest_bucket.blob(f"{folder}/{filename}")
    with io.BytesIO() as payload:
        np.save(payload, array)
        # Move back to the start so the upload reads the whole serialized array.
        payload.seek(0)
        target.upload_from_file(payload, rewind=True)
    print(f"✅ Uploaded: {folder}/{filename}")

def _split_train_test(samples, labels, train_fraction=0.8):
    """Cut already-shuffled samples/labels positionally into train and test parts.

    Returns ``(train_samples, test_samples, train_labels, test_labels)``.
    The split is purely positional (not stratified), so the class ratio of
    each part depends on the preceding shuffle.
    """
    cut = int(train_fraction * len(samples))
    return samples[:cut], samples[cut:], labels[:cut], labels[cut:]


def _augment_training_set(samples, labels, copies, transform):
    """Return samples/labels extended with *copies* augmented variants per sample.

    *transform* is called as ``transform(image=sample)["image"]``. The original
    samples come first in the result, followed by the augmented ones.
    """
    extra_samples, extra_labels = [], []
    for sample, label in zip(samples, labels):
        for _ in range(copies):
            extra_samples.append(transform(image=sample)["image"])
            extra_labels.append(label)

    if not extra_samples:
        # Nothing was generated (copies == 0 or no training samples).
        # np.concatenate cannot join an N-d array with an empty list, and for
        # the 1-d labels it would silently upcast them to float.
        return samples, labels

    return (
        np.concatenate([samples, extra_samples]),
        np.concatenate([labels, extra_labels]),
    )


# List the available files for each class
autistic_files = list_npy_files(AUTISTIC_PREFIX)
control_files = list_npy_files(CONTROL_PREFIX)

# Balance the dataset by truncating both classes to the smaller one
min_len = min(len(autistic_files), len(control_files))
autistic_files = autistic_files[:min_len]
control_files = control_files[:min_len]

# Pair autistic and control files and shuffle the pairs (fixed seed)
paired_files = list(zip(autistic_files, control_files))
paired_files = shuffle(paired_files, random_state=42)

for batch_index, batch_id in enumerate(range(0, min_len, BATCH_SIZE)):
    batch_pairs = paired_files[batch_id:batch_id + BATCH_SIZE]

    # Download both volumes of every pair; label 1 = autistic, 0 = control
    X_batch, y_batch = [], []
    for a_file, c_file in batch_pairs:
        X_batch.append(download_npy(a_file))
        y_batch.append(1)
        X_batch.append(download_npy(c_file))
        y_batch.append(0)

    # Shuffle the current batch (seeded by its start offset)
    X_batch, y_batch = shuffle(np.array(X_batch), np.array(y_batch), random_state=batch_id)

    # 80-20 train-test split
    # NOTE(review): the split is positional, so a single batch's train/test
    # parts are not guaranteed to be class-balanced.
    X_train, X_test, y_train, y_test = _split_train_test(X_batch, y_batch)

    # Augment only training data, then combine original + augmented
    X_train_final, y_train_final = _augment_training_set(
        X_train, y_train, AUGMENT_RATIO, augment
    )

    # Upload to GCS
    upload_npy(X_train_final, "train", f"images_batch_{batch_index}.npy")
    upload_npy(y_train_final, "train", f"labels_batch_{batch_index}.npy")
    upload_npy(X_test, "test", f"images_batch_{batch_index}.npy")
    upload_npy(y_test, "test", f"labels_batch_{batch_index}.npy")

    # Free memory before the next batch is downloaded
    del X_batch, y_batch, X_train, X_test, y_train, y_test, X_train_final, y_train_final
    gc.collect()

print("🎉 All batches processed and uploaded with unique and balanced train-test splits.")

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                  ┃ Output Shape              ┃         Param # ┃ Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)      │ (None, 80, 128, 128, 1)   │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ conv3d (Conv3D)               │ (None, 80, 128, 128, 16)  │             448 │ input_layer[0][0]          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ batch_normalization           │ (None, 80, 128, 128, 16)  │              64 │ conv3d[0][0]               │
│ (BatchNormalization)          │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ activation (Activation)       │ (None, 80, 128, 128, 16)  │               0 │ batch_normalization[0][0]  │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ conv3d_1 (Conv3D)             │ (None, 80, 128, 128, 16)  │           6,928 │ activation[0][0]           │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ batch_normalization_1         │ (None, 80, 128, 128, 16)  │              64 │ conv3d_1[0][0]             │
│ (BatchNormalization)          │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ activation_1 (Activation)     │ (None, 80, 128, 128, 16)  │               0 │ batch_normalization_1[0][… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ conv3d_2 (Conv3D)             │ (None, 80, 128, 128, 16)  │           6,928 │ activation_1[0][0]         │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ batch_normalization_2         │ (None, 80, 128, 128, 16)  │              64 │ conv3d_2[0][0]             │
│ (BatchNormalization)          │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ add (Add)                     │ (None, 80, 128, 128, 16)  │               0 │ activation[0][0],          │
│                               │                           │                 │ batch_normalization_2[0][… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ activation_2 (Activation)     │ (None, 80, 128, 128, 16)  │               0 │ add[0][0]                  │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ max_pooling3d (MaxPooling3D)  │ (None, 40, 64, 64, 16)    │               0 │ activation_2[0][0]         │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ conv3d_3 (Conv3D)             │ (None, 40, 64, 64, 32)    │          13,856 │ max_pooling3d[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ batch_normalization_3         │ (None, 40, 64, 64, 32)    │             128 │ conv3d_3[0][0]             │
│ (BatchNormalization)          │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ activation_3 (Activation)     │ (None, 40, 64, 64, 32)    │               0 │ batch_normalization_3[0][… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ conv3d_5 (Conv3D)             │ (None, 40, 64, 64, 32)    │             544 │ max_pooling3d[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ conv3d_4 (Conv3D)             │ (None, 40, 64, 64, 32)    │          27,680 │ activation_3[0][0]         │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ batch_normalization_5         │ (None, 40, 64, 64, 32)    │             128 │ conv3d_5[0][0]             │
│ (BatchNormalization)          │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ batch_normalization_4         │ (None, 40, 64, 64, 32)    │             128 │ conv3d_4[0][0]             │
│ (BatchNormalization)          │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ add_1 (Add)                   │ (None, 40, 64, 64, 32)    │               0 │ batch_normalization_5[0][… │
│                               │                           │                 │ batch_normalization_4[0][… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ activation_4 (Activation)     │ (None, 40, 64, 64, 32)    │               0 │ add_1[0][0]                │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ max_pooling3d_1               │ (None, 20, 32, 32, 32)    │               0 │ activation_4[0][0]         │
│ (MaxPooling3D)                │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ conv3d_6 (Conv3D)             │ (None, 20, 32, 32, 64)    │          55,360 │ max_pooling3d_1[0][0]      │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ batch_normalization_6         │ (None, 20, 32, 32, 64)    │             256 │ conv3d_6[0][0]             │
│ (BatchNormalization)          │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ activation_5 (Activation)     │ (None, 20, 32, 32, 64)    │               0 │ batch_normalization_6[0][… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ conv3d_8 (Conv3D)             │ (None, 20, 32, 32, 64)    │           2,112 │ max_pooling3d_1[0][0]      │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ conv3d_7 (Conv3D)             │ (None, 20, 32, 32, 64)    │         110,656 │ activation_5[0][0]         │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ batch_normalization_8         │ (None, 20, 32, 32, 64)    │             256 │ conv3d_8[0][0]             │
│ (BatchNormalization)          │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ batch_normalization_7         │ (None, 20, 32, 32, 64)    │             256 │ conv3d_7[0][0]             │
│ (BatchNormalization)          │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ add_2 (Add)                   │ (None, 20, 32, 32, 64)    │               0 │ batch_normalization_8[0][… │
│                               │                           │                 │ batch_normalization_7[0][… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ activation_6 (Activation)     │ (None, 20, 32, 32, 64)    │               0 │ add_2[0][0]                │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ max_pooling3d_2               │ (None, 10, 16, 16, 64)    │               0 │ activation_6[0][0]         │
│ (MaxPooling3D)                │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ global_average_pooling3d      │ (None, 64)                │               0 │ max_pooling3d_2[0][0]      │
│ (GlobalAveragePooling3D)      │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ dense (Dense)                 │ (None, 64)                │           4,160 │ global_average_pooling3d[… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ dropout (Dropout)             │ (None, 64)                │               0 │ dense[0][0]                │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ dense_1 (Dense)               │ (None, 1)                 │              65 │ dropout[0][0]              │
└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
 Total params: 688,901 (2.63 MB)
 Trainable params: 229,409 (896.13 KB)
 Non-trainable params: 672 (2.62 KB)
 Optimizer params: 458,820 (1.75 MB)