In [1]:
# Colab-only setup: mount Google Drive at /content/drive so that paths under
# that mount point can be read by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install torchio

In [None]:
import os
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import random
from scipy.ndimage import rotate, zoom, gaussian_filter

def normalize_intensity(image):
    """Z-score an array using the statistics of its non-zero entries.

    The mean and standard deviation are measured over the entries that differ
    from 0, but the resulting shift and scale are applied to every entry
    (zeros included). An array without any non-zero entry is returned as-is,
    and a standard deviation that is not strictly positive is replaced by
    1e-8 so the division is always defined.
    """
    foreground = image[image != 0]
    if foreground.size == 0:
        return image
    center = np.mean(foreground)
    spread = np.std(foreground)
    if not spread > 0:
        spread = 1e-8
    return (image - center) / spread

def load_images_with_lesions_from_root(root_dir):
    """Load every case under ``root_dir`` whose mask has a positive sum.

    Each sub-directory of ``root_dir`` is treated as one case and must contain
    ``flair_time02_registered_to_time01.nii.gz`` (image) and
    ``mask_time02_registered_to_time01.nii.gz`` (mask); directories missing
    either file are reported and skipped. The image is passed through
    ``normalize_intensity``; the mask is kept as loaded.

    Cases are visited in sorted folder-name order, because ``os.scandir``
    itself yields entries in arbitrary order and the order of the returned
    list feeds the random augmentation that follows.

    Args:
        root_dir: Directory holding one sub-directory per case.

    Returns:
        List of ``(image, mask, folder_name)`` tuples for the cases whose
        mask sums to more than 0.
    """
    data_with_lesions = []
    empty_masks = 0

    # Close the directory iterator deterministically and fix the visiting order.
    with os.scandir(root_dir) as entries:
        case_dirs = sorted(
            (entry for entry in entries if entry.is_dir()),
            key=lambda entry: entry.name,
        )

    for folder in case_dirs:
        image_path = os.path.join(folder.path, 'flair_time02_registered_to_time01.nii.gz')
        mask_path = os.path.join(folder.path, 'mask_time02_registered_to_time01.nii.gz')
        if not os.path.exists(image_path) or not os.path.exists(mask_path):
            print(f"Sar peste {folder.name}, lipsesc fișierele necesare.")
            continue
        img = normalize_intensity(nib.load(image_path).get_fdata())
        mask = nib.load(mask_path).get_fdata()
        if np.sum(mask) > 0:
            data_with_lesions.append((img, mask, folder.name))
        else:
            empty_masks += 1

    print(f"✅ Cazuri CU leziuni: {len(data_with_lesions)}")
    print(f"❌ Cazuri FĂRĂ leziuni: {empty_masks}")
    return data_with_lesions

def show_middle_slice(image, mask, axis=2):
    """Plot the central slice of a volume next to the same slice with its mask.

    The slice index is the midpoint of ``image.shape[axis]``. ``axis`` values
    0 and 1 slice along that axis; every other value slices along the third
    axis. The left panel shows the image, the right panel shows the image with
    the mask overlaid in semi-transparent red. Both panel titles always read
    "Axial", whatever ``axis`` is.
    """
    mid = image.shape[axis] // 2

    # Build one index tuple shared by image and mask.
    slot = axis if axis in (0, 1) else 2
    picker = [slice(None), slice(None), slice(None)]
    picker[slot] = mid
    picker = tuple(picker)
    img_slice = image[picker]
    mask_slice = mask[picker]

    plt.figure(figsize=(12, 6))

    left = plt.subplot(1, 2, 1)
    left.set_title("Axial - Image")
    left.imshow(img_slice.T, cmap="gray", origin="lower")

    right = plt.subplot(1, 2, 2)
    right.set_title("Axial - Image+Mask")
    right.imshow(img_slice.T, cmap="gray", origin="lower")
    right.imshow(mask_slice.T, cmap="Reds", alpha=0.4, origin="lower")

    plt.show()

def augment_patch(img_patch, mask_patch):
    """Apply random geometric and intensity augmentation to one patch pair.

    Steps, each gated by its own random draw:
      1. rotation by 90/180/270 degrees in each of the three axis planes
         (linear interpolation for the image, nearest-neighbour for the mask);
      2. flip along each of the three axes;
      3. additive Gaussian noise (sigma 0.05) on the image;
      4. multiplicative intensity scaling in [0.9, 1.1] on the image;
      5. circular shift of up to 2 voxels per axis (``np.roll`` wraps around
         the patch border), applied identically to image and mask.
    The image is then re-normalised with ``normalize_intensity`` and clipped
    to [-3, 3].

    The inputs are never modified: every step produces a new array. Integer
    images are promoted to float64 first so that interpolation, noise and
    scaling are not truncated or rejected by an in-place cast.

    Args:
        img_patch: 3-D image patch.
        mask_patch: Mask patch with the same shape as ``img_patch``.

    Returns:
        ``(img_patch, mask_patch)``. A patch whose standard deviation is
        below 0.01 is returned untouched (same objects).
    """
    if np.std(img_patch) < 0.01:
        return img_patch, mask_patch

    img_patch = np.asarray(img_patch)
    if not np.issubdtype(img_patch.dtype, np.floating):
        img_patch = img_patch.astype(np.float64)

    for plane in [(0, 1), (1, 2), (0, 2)]:
        if random.random() < 0.5:
            angle = random.choice([90, 180, 270])
            img_patch = rotate(img_patch, angle, axes=plane, reshape=False, order=1)
            mask_patch = rotate(mask_patch, angle, axes=plane, reshape=False, order=0)

    for axis in range(3):
        if random.random() < 0.5:
            img_patch = np.flip(img_patch, axis=axis).copy()
            mask_patch = np.flip(mask_patch, axis=axis).copy()

    if random.random() < 0.5:
        noise = np.random.normal(0, 0.05, img_patch.shape)
        # Out-of-place add: the caller's array must not change under its feet.
        img_patch = img_patch + noise.astype(img_patch.dtype)

    if random.random() < 0.5:
        img_patch = img_patch * random.uniform(0.9, 1.1)

    if random.random() < 0.3:
        shift = np.random.randint(-2, 3, size=3)
        img_patch = np.roll(img_patch, shift=tuple(shift), axis=(0, 1, 2))
        mask_patch = np.roll(mask_patch, shift=tuple(shift), axis=(0, 1, 2))

    img_patch = normalize_intensity(img_patch)
    img_patch = np.clip(img_patch, -3, 3)
    return img_patch, mask_patch

def axial_subsampling(image, factor=2):
    """Blur a volume along its third axis by down- then up-sampling it.

    The volume is first resampled with zoom ``(1, 1, 1/factor)`` and then
    resampled again with the per-axis ratio between the original and the
    reduced shape. Both passes use linear interpolation (``order=1``).
    """
    shrink = (1, 1, 1 / factor)
    coarse = zoom(image, shrink, order=1)

    # Per-axis factor that maps the reduced grid back onto the input grid.
    restore = np.array(image.shape) / np.array(coarse.shape)
    restored = zoom(coarse, restore, order=1)
    return restored

def gin_augmentation(image, alpha=0.8, sigma=4):
    """Add a smooth random intensity field to an image and re-normalise it.

    White Gaussian noise of the image's shape is smoothed with a Gaussian
    filter of width ``sigma``, scaled by ``alpha`` and added to the image.
    The sum is passed through ``normalize_intensity``.
    """
    white_noise = np.random.randn(*image.shape)
    smooth_field = gaussian_filter(white_noise, sigma=sigma)
    perturbed = image + alpha * smooth_field
    return normalize_intensity(perturbed)

def extract_lesions(mask):
    """Return the index coordinates of every strictly positive mask entry.

    The result has one row per positive entry and one column per mask axis.
    """
    is_lesion = mask > 0
    return np.argwhere(is_lesion)

def apply_carvemix(src_img, src_mask, tgt_img, tgt_mask, num_lesions=3,
                   patch_size=16, blend_weight=0.5, max_mean_diff=1.5):
    """Blend cubes centred on source lesion voxels into random target spots.

    For up to ``num_lesions`` iterations a random positive voxel of
    ``src_mask`` is chosen, a cube spanning ``[centre - half, centre + half)``
    (``half = patch_size // 2``) is cut from the source, and a random location
    where the cube fits is drawn in the target. The target image region
    becomes ``(1 - blend_weight) * target + blend_weight * source`` and the
    target mask region becomes the element-wise maximum of both masks.
    An iteration is skipped when the source cube would leave the source
    volume, or when the mean intensities of the two regions differ by more
    than ``max_mean_diff``.

    ``tgt_img`` and ``tgt_mask`` are modified IN PLACE and also returned.

    Args:
        src_img, src_mask: 3-D source volume and its mask.
        tgt_img, tgt_mask: 3-D target volume and its mask (mutated).
        num_lesions: Maximum number of paste attempts.
        patch_size: Edge length of the pasted cube (effective length is
            ``2 * (patch_size // 2)``).
        blend_weight: Weight of the source intensities in the blend.
        max_mean_diff: Largest accepted absolute difference between the mean
            of the source cube and the mean of the target region.

    Returns:
        ``(tgt_img, tgt_mask)`` — the same objects that were passed in.

    Raises:
        ValueError: If ``patch_size`` is smaller than 2.
    """
    if patch_size < 2:
        raise ValueError("patch_size must be at least 2")

    voxels = extract_lesions(src_mask)
    if len(voxels) == 0:
        return tgt_img, tgt_mask

    # Loop-invariant geometry.
    half = np.full(3, patch_size // 2)
    src_shape = np.array(src_img.shape[:3])
    tgt_shape = np.array(tgt_img.shape[:3])

    # No cube fits in a target smaller than the cube itself; bail out instead
    # of letting randint raise on an empty range.
    if np.any(tgt_shape < 2 * half):
        return tgt_img, tgt_mask

    for _ in range(min(num_lesions, len(voxels))):
        c = voxels[np.random.randint(len(voxels))]
        s, e = c - half, c + half
        # ``e`` is an exclusive slice end, so e == shape is still inside.
        if np.any(s < 0) or np.any(e > src_shape):
            continue
        src_region = (slice(s[0], e[0]), slice(s[1], e[1]), slice(s[2], e[2]))
        src_p, src_m = src_img[src_region], src_mask[src_region]

        # Valid target centres run from half to shape - half inclusive.
        tc = np.array([
            np.random.randint(half[i], tgt_shape[i] - half[i] + 1)
            for i in range(3)
        ])
        ts, te = tc - half, tc + half
        tgt_region = (slice(ts[0], te[0]), slice(ts[1], te[1]), slice(ts[2], te[2]))

        tr = tgt_img[tgt_region]
        if np.abs(np.mean(tr) - np.mean(src_p)) > max_mean_diff:
            continue

        tgt_img[tgt_region] = tr * (1 - blend_weight) + src_p * blend_weight
        tgt_mask[tgt_region] = np.maximum(tgt_mask[tgt_region], src_m)

    return tgt_img, tgt_mask

def augment_patch_advanced(img_patch, mask_patch, all_data=None):
    """Run the basic augmentation plus optional resolution/intensity/mix steps.

    Pipeline:
      1. A patch that is flat from the start (std below 0.008 or value range
         below 0.04) is returned untouched.
      2. ``augment_patch`` is applied.
      3. With probability 0.5 exactly one of ``axial_subsampling`` or
         ``gin_augmentation`` is applied to the image (never both).
      4. If the mask is empty and ``all_data`` offers at least one source,
         with probability 0.5 a random source is blended in through
         ``apply_carvemix``.

    No flatness filtering is done on the result; callers that want to drop
    degenerate outputs must check the returned patch themselves.

    Args:
        img_patch: Image patch.
        mask_patch: Mask patch matching ``img_patch``.
        all_data: Optional sequence of source items whose first two elements
            are the source image and source mask. ``None`` or an empty
            sequence disables the mixing step.

    Returns:
        ``(img_patch, mask_patch)`` after augmentation.
    """
    # Skip patches that are completely flat from the start.
    if np.std(img_patch) < 0.008 or (np.max(img_patch) - np.min(img_patch)) < 0.04:
        return img_patch, mask_patch

    # Basic geometric / intensity augmentation.
    img_patch, mask_patch = augment_patch(img_patch, mask_patch)

    # Either subsampling or GIN, never both at once.
    if random.random() < 0.5:
        if random.random() < 0.5:
            img_patch = axial_subsampling(img_patch)
        else:
            img_patch = gin_augmentation(img_patch)

    # Mix lesions in ONLY for negative patches, and only when a source exists
    # (random.choice would raise IndexError on an empty pool).
    has_sources = all_data is not None and len(all_data) > 0
    if has_sources and np.sum(mask_patch) == 0 and random.random() < 0.5:
        other_case = random.choice(all_data)
        img_patch, mask_patch = apply_carvemix(other_case[0], other_case[1], img_patch, mask_patch)

    return img_patch, mask_patch

def extract_patches_from_case_balanced(image, mask, patch_size=(32,32,32), max_patches=20):
    """Extract a 90/10 mix of lesion-centred and lesion-free patches.

    Positive patches are centred on randomly ordered positive mask voxels;
    candidates whose patch would leave the volume are skipped. Negative
    patches are drawn at random locations and kept only if their mask sums to
    0 and their image standard deviation exceeds 0.01; at most
    ``10 * num_neg`` draws are attempted.

    A patch starts at ``centre - patch_size // 2`` and spans exactly
    ``patch_size`` voxels per axis, so odd sizes are honoured.

    Args:
        image: 3-D image volume.
        mask: 3-D mask volume with the same shape as ``image``.
        patch_size: Patch edge lengths, one per axis.
        max_patches: Upper bound on the number of returned patches;
            ``int(0.9 * max_patches)`` of them are reserved for positives.

    Returns:
        List of ``(image_patch, mask_patch)`` tuples. The patches are slices
        (views) of the inputs, not copies. The list is empty when the volume
        is smaller than ``patch_size`` along any axis.
    """
    patches = []
    size = np.array(patch_size)
    half = size // 2
    shape = np.array(image.shape)

    # 90% WITH lesions, 10% WITHOUT lesions.
    num_pos = int(0.9 * max_patches)
    num_neg = max_patches - num_pos

    # No patch of the requested size fits into this volume.
    if np.any(shape < size):
        return patches

    # === Patches WITH lesions ===
    pos = np.argwhere(mask > 0)
    np.random.shuffle(pos)
    for c in pos:
        if len(patches) >= num_pos:
            break
        s = c - half
        e = s + size
        # ``e`` is an exclusive slice end, so e == shape is still inside.
        if np.any(s < 0) or np.any(e > shape):
            continue
        ip = image[s[0]:e[0], s[1]:e[1], s[2]:e[2]]
        mp = mask[s[0]:e[0], s[1]:e[1], s[2]:e[2]]
        patches.append((ip, mp))

    # === Patches WITHOUT lesions ===
    # Centres for which the whole patch stays inside the volume
    # (upper bound is exclusive, as randint expects).
    centre_low = half
    centre_high = shape - size + half + 1
    attempts, nc = 0, 0
    max_attempts = num_neg * 10
    while nc < num_neg and attempts < max_attempts:
        rc = np.array([np.random.randint(centre_low[i], centre_high[i]) for i in range(3)])
        s = rc - half
        e = s + size
        mp = mask[s[0]:e[0], s[1]:e[1], s[2]:e[2]]
        if np.sum(mp) == 0:
            ip = image[s[0]:e[0], s[1]:e[1], s[2]:e[2]]
            if np.std(ip) > 0.01:
                patches.append((ip, mp))
                nc += 1
        attempts += 1

    return patches

# --- Build the training set: extract patches per case and augment them ---

root_dir = '/content/drive/MyDrive/training_small/'
augmentations_per_patch = 30

# The helpers draw from both `random` and `np.random`; seed both so that a
# fresh "Restart & Run All" reproduces the same patches.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

data_with_lesions = load_images_with_lesions_from_root(root_dir)
all_patches, lesion_voxel_counts = [], []

# Pool of lesion-containing patches used as mixing sources for negative patches.
all_lesions_pool = [
    (ip.copy(), mp.copy())
    for img, mask, _ in data_with_lesions
    for ip, mp in extract_patches_from_case_balanced(img, mask, max_patches=30)
    if np.sum(mp) > 0
]

for img, mask, name in data_with_lesions:
    print(f"{name} — Media: {np.mean(img):.4f}, STD: {np.std(img):.4f}")
    patches = extract_patches_from_case_balanced(img, mask, max_patches=40)
    for img_patch, mask_patch in patches:
        # Keep the original patch ...
        all_patches.append((img_patch, mask_patch))
        lesion_voxel_counts.append(np.sum(mask_patch))

        # ... followed by its augmented variants.
        for _ in range(augmentations_per_patch):
            aug_img, aug_mask = augment_patch_advanced(
                img_patch.copy(),
                mask_patch.copy(),
                all_lesions_pool
            )

            # Drop only the most extreme, completely flat results.
            if np.std(aug_img) < 0.008 or (np.max(aug_img) - np.min(aug_img)) < 0.04:
                continue

            all_patches.append((aug_img, aug_mask))
            lesion_voxel_counts.append(np.sum(aug_mask))

print(f"Total patch-uri extrase + augmentate: {len(all_patches)}")
if lesion_voxel_counts:
    print("\U0001f522 Voxeli pozitivi în mască:")
    print(f"   Minim: {np.min(lesion_voxel_counts)}")
    print(f"   Maxim: {np.max(lesion_voxel_counts)}")
    print(f"   Medie: {np.mean(lesion_voxel_counts):.2f}")
else:
    # np.min / np.max raise ValueError on an empty sequence.
    print("Niciun patch extras — nu există statistici de afișat.")