In [None]:
!export PATH="${HOME}/.local/bin:${PATH}" && uv pip uninstall --system --quiet jax tensorflow tensorflow-tpu

In [None]:
!export PATH="${HOME}/.local/bin:${PATH}" && uv pip install --system --quiet tensorflow=="2.18.0"
!export PATH="${HOME}/.local/bin:${PATH}" && uv pip install --system --quiet imagehash
!export PATH="${HOME}/.local/bin:${PATH}" && uv pip install --system --quiet tensorflow-tpu=="2.18.0" --find-links https://storage.googleapis.com/libtpu-tf-releases/index.html

In [None]:
import os
import cv2
import gc
import glob
import json
import shutil
import imagehash
import numpy as np
import pandas as pd
import tensorflow as tf
import concurrent
import albumentations as A 
import xml.etree.ElementTree as ET
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import precision_recall_curve
from concurrent.futures import ProcessPoolExecutor
from tensorflow.keras import layers, models, backend as K
from sklearn.model_selection import StratifiedGroupKFold
# ==========================================
# 0. CACHE CLEANUP
# ==========================================
# Best-effort removal of lock files and cache entries found in /kaggle/working.
# Only filesystem errors are tolerated (and reported); anything else, including
# KeyboardInterrupt, is allowed to propagate instead of being silently swallowed.
for f in glob.glob('/kaggle/working/*.lockfile') + glob.glob('/kaggle/working/*cache*'):
    try:
        if os.path.isdir(f):
            shutil.rmtree(f)
        else:
            os.remove(f)
    except OSError as cleanup_err:
        print(f"Could not remove {f}: {cleanup_err}")
# ==========================================
# 1. TPU/GPU CONFIGURATION
# ==========================================
def get_strategy():
    """Select a ``tf.distribute`` strategy, preferring TPU, then GPU(s), then the default.

    Returns:
        A ``TPUStrategy`` when a TPU environment variable is present and TPU
        initialisation succeeds; otherwise ``MirroredStrategy`` (more than one GPU),
        ``OneDeviceStrategy`` on ``/GPU:0`` (exactly one GPU), or the current
        default strategy from ``tf.distribute.get_strategy()``.
    """
    # --- Attempt 1: TPU ---
    try:
        tpu_env_present = 'TPU_NAME' in os.environ or 'TPU_ACCELERATOR_TYPE' in os.environ
        if not tpu_env_present:
            raise RuntimeError("No TPU environment variables found.")

        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        tpu_strategy = tf.distribute.TPUStrategy(resolver)
        print("Running on TPU")
        return tpu_strategy
    except Exception as e:
        # Any failure (including the deliberate RuntimeError above) falls through to GPU/CPU.
        print(f"TPU init failed or not detected: {type(e).__name__}. Falling back to GPU/CPU...")

    # --- Attempt 2: GPU(s) ---
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            for device in gpus:
                tf.config.experimental.set_memory_growth(device, True)

            if len(gpus) == 1:
                gpu_strategy = tf.distribute.OneDeviceStrategy(device="/GPU:0")
                print("Running on single GPU")
            else:
                gpu_strategy = tf.distribute.MirroredStrategy()
                print(f"Running on {len(gpus)} GPUs (MirroredStrategy)")

            return gpu_strategy
        except RuntimeError as e:
            print("GPU init failed:", e)

    # --- Fallback: default strategy ---
    print("Running on CPU")
    return tf.distribute.get_strategy()

# Build the distribution strategy once at module level; the batch size and learning
# rate in the data-pipeline cell are derived from `strategy.num_replicas_in_sync`.
strategy = get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

In [None]:
# =================================
# 2. DATA CLEANING & PRE-PROCESSING
# =================================
# Locations of the raw datasets under /kaggle/input.
USNS_DATA_DIR = '/kaggle/input/ultrasound-nerve-segmentation/train'
BUSI_DATA_DIR = '/kaggle/input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT'
CAMUS_DATA_DIR = '/kaggle/input/camus-echocardiography-image-dataset'
DDTI_DATA_DIR = '/kaggle/input/ddti-thyroid-ultrasound-images'
THYROID_RAW_DIR = '/kaggle/input/ultrasounddataset-andmat/Thyroid Dataset'

# Output locations under /kaggle/working (created right below).
WORKING_DIR = '/kaggle/working/data_png'
GEN_MASKS_DIR = '/kaggle/working/generated_masks'          # masks rasterised from DDTI XML
PROCESSED_USNS_DIR = '/kaggle/working/processed_usns'
PROCESSED_THYROID_DIR = '/kaggle/working/processed_thyroid'
PROCESSED_OTHERS_DIR = '/kaggle/working/processed_others'  # BUSI + CAMUS

# Size of the canvas every image/mask is resized and padded to.
TARGET_HEIGHT = 512
TARGET_WIDTH = 512

# Create every output directory up front. `exist_ok=True` replaces the
# check-then-create pattern, so re-running the cell is safe and race-free.
for d in [WORKING_DIR, GEN_MASKS_DIR, PROCESSED_USNS_DIR, PROCESSED_THYROID_DIR, PROCESSED_OTHERS_DIR]:
    os.makedirs(d, exist_ok=True)

def get_patient_id(file_path, source):
    """
    Extract a patient/group ID from a file path using per-dataset naming rules.

    NOTE(review): a second module-level ``get_patient_id`` is defined further down
    in this notebook and rebinds the name, so any call executed after that point
    uses the later rules rather than these. Consider keeping only one definition.

    Args:
        file_path: Path to an image file.
        source: Dataset tag; one of 'USNS', 'DDTI', 'CAMUS', 'BUSI'.

    Returns:
        A string ID prefixed with the lowercase dataset name, or "unknown" when
        ``source`` is not recognised.
    """
    filename = os.path.basename(file_path)

    if source == 'USNS':
        # Token before the first underscore.
        return f"usns_{filename.split('_')[0]}"

    if source == 'DDTI':
        # Second underscore-separated token, extension stripped. Names without an
        # underscore fall back to the whole filename (explicit check instead of a
        # bare `except`).
        parts = filename.split('_')
        if len(parts) < 2:
            return f"ddti_{filename}"
        return f"ddti_{parts[1].split('.')[0]}"

    if source == 'CAMUS':
        # Name of the directory that directly contains the file.
        parent_dir = os.path.basename(os.path.dirname(file_path))
        return f"camus_{parent_dir}"

    if source == 'BUSI':
        # Every image is its own group.
        return f"busi_{filename.replace('.png', '')}"

    return "unknown"

def create_meta_dataframe(img_paths, mask_paths):
    """Build the metadata table for paired image/mask paths.

    The dataset tag is inferred from the image path: the first of 'BUSI', 'DDTI',
    'CAMUS' found as a substring wins, otherwise 'USNS'.

    Args:
        img_paths: Iterable of image paths.
        mask_paths: Iterable of mask paths, paired positionally with ``img_paths``.

    Returns:
        DataFrame with columns image_path, mask_path, source, patient_id,
        hash (None placeholder) and has_nerve (-1 placeholder).
    """
    def _infer_source(path):
        for tag in ('BUSI', 'DDTI', 'CAMUS'):
            if tag in path:
                return tag
        return 'USNS'

    print("Đang tạo bảng Meta-data...")

    records = []
    for image_file, mask_file in zip(img_paths, mask_paths):
        origin = _infer_source(image_file)
        records.append({
            'image_path': image_file,
            'mask_path': mask_file,
            'source': origin,
            'patient_id': get_patient_id(image_file, origin),
            'hash': None,
            'has_nerve': -1,
        })

    return pd.DataFrame(records)

def process_metadata(df):
    """
    Compute a perceptual hash per image (for de-duplication) and a binary label
    per mask (for the classification target).

    Args:
        df: DataFrame with 'image_path' and 'mask_path' columns. It is modified
            in place: the 'hash' and 'has_nerve' columns are (over)written.

    Returns:
        The same DataFrame. 'hash' holds the pHash as a string, or "error" when
        the image could not be opened or hashed. 'has_nerve' is 1 when the mask
        has any non-zero pixel and 0 otherwise (including unreadable masks).
    """
    hashes = []
    labels = []

    print(f"Hashing & Labeling...")
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        try:
            # Context manager closes the file handle; the original leaked one per row.
            with Image.open(row['image_path']) as img:
                image_hash = str(imagehash.phash(img))
        except Exception:
            # Unreadable or corrupt image: mark the row so it can be filtered later.
            image_hash = "error"
        hashes.append(image_hash)

        try:
            msk = cv2.imread(row['mask_path'], 0)
            has_foreground = msk is not None and np.max(msk) > 0
        except Exception:
            has_foreground = False
        labels.append(1 if has_foreground else 0)

    df['hash'] = hashes
    df['has_nerve'] = labels
    return df

def remove_duplicates(df, threshold=5):
    """Drop rows with a repeated or failed image hash.

    Only exact hash matches count as duplicates; the first occurrence is kept.
    Rows whose hash is the sentinel "error" are removed as well.

    Args:
        df: DataFrame with a 'hash' column.
        threshold: Accepted but not used by this implementation.

    Returns:
        A filtered DataFrame that keeps the original index labels.
    """
    n_before = len(df)
    print(f"Before: {n_before} samples")

    first_occurrence = ~df.duplicated(subset=['hash'], keep='first')
    hashed_ok = df['hash'] != "error"
    df_clean = df[first_occurrence & hashed_ok]

    n_after = len(df_clean)
    print(f"After: {n_after} samples")
    return df_clean

def split_data(df, n_splits=5):
    """Assign each row to a validation fold, stratified by label and grouped by patient.

    Args:
        df: DataFrame with 'image_path', 'has_nerve' and 'patient_id' columns.
        n_splits: Number of folds passed to ``StratifiedGroupKFold``.

    Returns:
        A re-indexed copy of ``df`` with an integer 'fold' column holding the index
        of the fold in which the row is used for validation.
    """
    df = df.reset_index(drop=True)
    df['fold'] = -1

    y = df['has_nerve']
    splitter = StratifiedGroupKFold(n_splits=n_splits)
    fold_pairs = splitter.split(df['image_path'], y, df['patient_id'])

    print("Fold Spliting...")
    for fold_idx, (train_idx, val_idx) in enumerate(fold_pairs):
        df.loc[val_idx, 'fold'] = fold_idx

        # Per-fold class balance report.
        train_y, val_y = y.iloc[train_idx], y.iloc[val_idx]
        print(f"Fold {fold_idx}:")
        print(f"  Train: {len(train_idx)} (Pos: {train_y.sum()}, Neg: {len(train_y)-train_y.sum()})")
        print(f"  Val  : {len(val_idx)} (Pos: {val_y.sum()}, Neg: {len(val_y)-val_y.sum()})")

    return df

# Offline augmentation pipeline, invoked through offline_aug_fn on image/mask pairs.
# Two stages, each an Albumentations `OneOf` group:
#   1. geometric: elastic / grid / optical distortion or a small affine, group p=0.8
#   2. intensity: CLAHE or multiplicative noise, group p=0.4
# NOTE(review): inside `OneOf` the members' own `p` values are presumably used as
# relative selection weights (1.0 vs 0.5 in the second group) — confirm against
# the installed Albumentations version.
offline_aug = A.Compose([
    A.OneOf([
        A.ElasticTransform(alpha=70, sigma=120 * 0.05, p=1.0),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=1.0),
        A.OpticalDistortion(distort_limit=0.3, p=1.0),
        A.Affine(scale=(0.9, 1.1), translate_percent=(-0.05, 0.05), rotate=(-10, 10), p=1.0),
    ], p=0.8),
    A.OneOf([
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
        A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=0.5, elementwise=True),
    ], p=0.4)
], p=1.0)

def _parse_ddti_annotations(root):
    """Collect polygon annotations from a parsed DDTI XML root, keyed by image index text.

    Each <mark> element contributes the polygons found in the JSON text of its
    <svg> child. Marks without an <image> or <svg> payload are skipped; malformed
    JSON or point data is reported and the rest of that mark is skipped.
    """
    image_annotations = {}
    for mark in root.findall('mark'):
        image_tag = mark.find('image')
        svg_tag = mark.find('svg')
        img_idx = image_tag.text if image_tag is not None else None
        if not (img_idx and svg_tag is not None and svg_tag.text):
            continue

        try:
            shapes = json.loads(svg_tag.text)
            polygons = image_annotations.setdefault(img_idx, [])
            for shape in shapes:
                if "points" in shape:
                    pts = np.array([[p['x'], p['y']] for p in shape["points"]], dtype=np.int32)
                    # (N, 1, 2) is the point layout handed to cv2.fillPoly below.
                    polygons.append(pts.reshape((-1, 1, 2)))
        except (ValueError, KeyError, TypeError, OverflowError) as err:
            # json.JSONDecodeError is a ValueError; the others cover unexpected shapes/keys.
            print(f"Skipping malformed annotation for image {img_idx}: {err}")
    return image_annotations


def generate_ddti_masks(ddti_root_dir, output_mask_dir):
    """Rasterise DDTI XML polygon annotations into binary PNG masks.

    For every ``*.xml`` in ``ddti_root_dir`` the <number> tag gives the patient ID;
    every non-XML file named ``<patient_id>_<index>.<ext>`` gets a mask of the same
    height/width written to ``output_mask_dir`` as ``<name>_mask.png``. Annotated
    polygons are filled with 255; images without annotations get an all-zero mask.

    Args:
        ddti_root_dir: Directory containing the XML files and their images.
        output_mask_dir: Directory the masks are written to (created if missing).

    Returns:
        Number of masks successfully written. (The original tracked this count but
        discarded it; returning it is backward-compatible for callers that ignore
        the result.)
    """
    os.makedirs(output_mask_dir, exist_ok=True)
    xml_files = glob.glob(os.path.join(ddti_root_dir, "*.xml"))
    print(f"{len(xml_files)} file XML. Preprocessing...")

    success_count = 0

    for xml_path in tqdm(xml_files):
        try:
            root = ET.parse(xml_path).getroot()
            number_tag = root.find('number')
            if number_tag is None:
                continue

            patient_id = number_tag.text
            image_annotations = _parse_ddti_annotations(root)

            for img_path in glob.glob(os.path.join(ddti_root_dir, f"{patient_id}_*.*")):
                filename = os.path.basename(img_path)
                if filename.endswith('.xml'):
                    continue
                name_no_ext = os.path.splitext(filename)[0]
                parts = name_no_ext.split('_')
                if len(parts) < 2:
                    continue

                # The last underscore-separated token is matched against <image> text.
                current_img_idx = parts[-1]
                img = cv2.imread(img_path)
                if img is None:
                    continue
                h, w = img.shape[:2]

                mask = np.zeros((h, w), dtype=np.uint8)
                for pts in image_annotations.get(current_img_idx, []):
                    cv2.fillPoly(mask, [pts], color=255)

                save_path = os.path.join(output_mask_dir, name_no_ext + "_mask.png")
                # Count only masks that were actually written.
                if cv2.imwrite(save_path, mask):
                    success_count += 1

        except Exception as e:
            print(f"Lỗi xử lý {xml_path}: {e}")

    return success_count

def offline_aug_fn(image, mask):
    """Apply the module-level ``offline_aug`` pipeline to one image/mask pair.

    Returns whatever the pipeline returns; callers index the result with
    'image' and 'mask'.
    """
    augmented = offline_aug(image=image, mask=mask)
    return augmented

def get_patient_id(file_path, source):
    """Derive the grouping key (patient ID) used for the group-aware fold split.

    Args:
        file_path: Path to a processed image file.
        source: Dataset tag: 'USNS', 'DDTI', 'CAMUS', 'BUSI' or 'THYROID_EXT'.

    Returns:
        A string ID prefixed with the lowercase dataset name, or "unknown" for
        any other ``source``.

    For 'THYROID_EXT' the ID is the file stem with any ``_aug_<i>`` suffix
    removed, so an offline-augmented copy shares the group of its source image.
    Previously the suffix stayed in the ID, which allowed an augmented copy to
    land in a training fold while its original was in the validation fold.

    NOTE(review): processed files carry a dataset prefix ("ddti_", "camus_",
    "busi_"), so the first-token rules for DDTI/CAMUS/BUSI below appear to collapse
    to one or two groups per dataset — confirm that this coarse grouping is intended.
    """
    filename = os.path.basename(file_path)

    if source == 'USNS':
        return f"usns_{filename.split('_')[0]}"

    if source == 'DDTI':
        # str.split never raises here, so the former bare try/except was dead code.
        return f"ddti_{filename.split('_')[0]}"

    if source == 'CAMUS':
        return f"camus_{filename.split('_')[0]}"

    if source == 'BUSI':
        return f"busi_{filename.split(' ')[0]}"

    if source == 'THYROID_EXT':
        stem = filename.replace('.png', '')
        # Strip the augmentation suffix written by process_single_augment_task.
        return f"thyroid_{stem.split('_aug_')[0]}"

    return "unknown"

def resize_and_pad_offline(image, target_h, target_w, interpolation=cv2.INTER_LINEAR):
    """Resize an image to fit a target canvas (aspect ratio kept) and zero-pad it centred.

    Args:
        image: Array of shape (H, W) or (H, W, C).
        target_h: Canvas height in pixels.
        target_w: Canvas width in pixels.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        A ``uint8`` array of shape (target_h, target_w) plus the channel axis of
        the resized image, if it has one. The resized image sits in the centre
        and the border is zero.

    Changes from the original:
      * the resized height/width are clamped to at least 1 px, so an extreme
        aspect ratio no longer truncates a side to 0 (an invalid resize size);
      * the canvas takes its channel count from the resized image instead of
        assuming exactly 3 channels.
    """
    h, w = image.shape[:2]

    scale = min(target_h / h, target_w / w)
    # int() truncation is kept; the clamp only guards the degenerate 0 case and
    # any float round-off above the target size.
    new_h = min(target_h, max(1, int(h * scale)))
    new_w = min(target_w, max(1, int(w * scale)))

    resized = cv2.resize(image, (new_w, new_h), interpolation=interpolation)

    # resized.shape[2:] is () for 2-D results and (C,) for multi-channel ones.
    canvas = np.zeros((target_h, target_w) + resized.shape[2:], dtype=np.uint8)

    pad_top = (target_h - new_h) // 2
    pad_left = (target_w - new_w) // 2
    canvas[pad_top:pad_top + new_h, pad_left:pad_left + new_w] = resized

    return canvas

def process_single_augment_task(args):
    """Worker: resize/pad one image/mask pair, save it, and optionally save augmented copies.

    Args:
        args: Tuple ``(img_path, mask_path, output_dir, num_aug, prefix)``.
            ``num_aug`` augmented pairs are written only when the mask has
            foreground; a non-empty ``prefix`` is prepended to output names.

    Outputs are ``<base>.png`` / ``<base>_mask.png`` plus
    ``<base>_aug_<i>.png`` / ``<base>_aug_<i>_mask.png``. Unreadable inputs are
    skipped silently; any other exception is printed, not raised.
    """
    img_path, mask_path, output_dir, num_aug, prefix = args

    def _dest(file_name):
        return os.path.join(output_dir, file_name)

    try:
        image = cv2.imread(img_path)
        mask = cv2.imread(mask_path, 0)
        if image is None or mask is None:
            return

        # Cubic for the image; nearest for the mask so that label values stay discrete.
        image = resize_and_pad_offline(image, TARGET_HEIGHT, TARGET_WIDTH, cv2.INTER_CUBIC)
        mask = resize_and_pad_offline(mask, TARGET_HEIGHT, TARGET_WIDTH, cv2.INTER_NEAREST)

        stem = os.path.splitext(os.path.basename(img_path))[0]
        base_name = f"{prefix}_{stem}" if prefix else stem

        cv2.imwrite(_dest(f"{base_name}.png"), image)
        cv2.imwrite(_dest(f"{base_name}_mask.png"), mask)

        # Empty masks (and num_aug <= 0) get no augmented copies.
        if num_aug <= 0 or np.sum(mask) <= 0:
            return

        for i in range(num_aug):
            aug = offline_aug_fn(image, mask)
            cv2.imwrite(_dest(f"{base_name}_aug_{i}.png"), aug['image'])
            cv2.imwrite(_dest(f"{base_name}_aug_{i}_mask.png"), aug['mask'])
    except Exception as e:
        print(f"Error processing {img_path}: {e}")

def prep_usns(raw_dir, out_dir):
    """Build preprocessing tasks for the USNS dataset.

    Pairs every ``*_mask.tif`` in ``raw_dir`` with the image of the same name
    without the ``_mask`` suffix; masks whose image is missing are skipped.

    Args:
        raw_dir: Directory with the raw ``.tif`` images and masks.
        out_dir: Output directory recorded in each task.

    Returns:
        List of ``(img_path, mask_path, out_dir, 0, "")`` tuples (no augmentation,
        no filename prefix), ordered by mask path.
    """
    mask_suffix = '_mask.tif'

    tasks = []
    for m_path in sorted(glob.glob(os.path.join(raw_dir, "*" + mask_suffix))):
        # Swap only the trailing suffix; str.replace would also rewrite a matching
        # substring anywhere else in the path (e.g. in a directory name).
        img_path = m_path[:-len(mask_suffix)] + '.tif'
        if os.path.exists(img_path):
            tasks.append((img_path, m_path, out_dir, 0, ""))

    # Report the number of tasks actually created, not the number of mask files.
    print(f"Processing USNS: {len(tasks)} tasks...")
    return tasks

def prep_thyroid_merged(root_dir, ddti_img_dir, ddti_mask_dir, out_dir):
    """Build preprocessing tasks for the thyroid datasets (TG3K, TN3K and DDTI).

    Args:
        root_dir: Root of the external thyroid dataset with ``tg3k``/``tn3k`` subfolders.
        ddti_img_dir: Directory with the DDTI images.
        ddti_mask_dir: Directory with the generated ``<name>_mask.png`` DDTI masks.
        out_dir: Output directory recorded in each task.

    Returns:
        List of ``(img_path, mask_path, out_dir, num_aug, prefix)`` tuples.
        TG3K/TN3K use ``num_aug=0``; DDTI uses ``num_aug=2`` and prefix "ddti".

    Changes from the original:
      * DDTI images of any extension are accepted (only ``*.png`` before), which
        matches generate_ddti_masks writing a mask for every readable image;
      * file stems come from ``os.path.splitext`` so names that contain dots are
        not truncated at the first dot;
      * listings are sorted so the task order is deterministic.
    """
    tasks = []
    subsets = [
        (os.path.join(root_dir, "tg3k/thyroid-image"), os.path.join(root_dir, "tg3k/thyroid-mask"), "tg3k"),
        (os.path.join(root_dir, "tn3k/trainval-image"), os.path.join(root_dir, "tn3k/trainval-mask"), "tn3k_tr"),
        (os.path.join(root_dir, "tn3k/test-image"), os.path.join(root_dir, "tn3k/test-mask"), "tn3k_ts")
    ]
    # Mask naming variants, tried in this order; the first existing file wins.
    mask_suffixes = ['.png', '.jpg', '_mask.png', '_mask.jpg']

    for img_d, msk_d, prefix in subsets:
        if not os.path.exists(img_d):
            continue
        for i_path in sorted(glob.glob(os.path.join(img_d, "*"))):
            stem = os.path.splitext(os.path.basename(i_path))[0]
            candidates = (os.path.join(msk_d, stem + ext) for ext in mask_suffixes)
            m_path = next((c for c in candidates if os.path.exists(c)), None)
            if m_path:
                tasks.append((i_path, m_path, out_dir, 0, prefix))

    for i_path in sorted(glob.glob(os.path.join(ddti_img_dir, "*"))):
        # Annotation files and sub-directories are not images.
        if i_path.lower().endswith('.xml') or not os.path.isfile(i_path):
            continue
        stem = os.path.splitext(os.path.basename(i_path))[0]
        m_path = os.path.join(ddti_mask_dir, stem + "_mask.png")
        if os.path.exists(m_path):
            tasks.append((i_path, m_path, out_dir, 2, "ddti"))

    print(f"Processing Thyroid Merged: {len(tasks)} tasks...")
    return tasks

def prep_others(busi_dir, camus_dir, out_dir):
    """Build preprocessing tasks for the BUSI and CAMUS datasets.

    BUSI: in the ``benign`` and ``malignant`` subfolders, every ``.png`` whose name
    does not contain "mask" is paired with ``<name>_mask.png`` when that file exists.
    CAMUS: every ``frames/*.png`` is paired with the file in ``masks/`` whose name
    has "frame_" replaced by "mask_", when that file exists.

    Args:
        busi_dir: Root of the BUSI dataset.
        camus_dir: Root of the CAMUS dataset (with ``frames`` and ``masks``).
        out_dir: Output directory recorded in each task.

    Returns:
        List of ``(img_path, mask_path, out_dir, 4, prefix)`` tuples with prefix
        "busi" or "camus".
    """
    NUM_AUG = 4
    tasks = []

    # --- BUSI ---
    for category in ('benign', 'malignant'):
        category_dir = os.path.join(busi_dir, category)
        if not os.path.exists(category_dir):
            continue
        for entry in os.listdir(category_dir):
            if 'mask' in entry or not entry.endswith('.png'):
                continue
            image_file = os.path.join(category_dir, entry)
            mask_file = os.path.join(category_dir, entry.replace('.png', '_mask.png'))
            if os.path.exists(mask_file):
                tasks.append((image_file, mask_file, out_dir, NUM_AUG, "busi"))

    # --- CAMUS ---
    frames_dir = os.path.join(camus_dir, 'frames')
    masks_dir = os.path.join(camus_dir, 'masks')
    if os.path.exists(frames_dir):
        for frame_file in glob.glob(os.path.join(frames_dir, "*.png")):
            mask_name = os.path.basename(frame_file).replace('frame_', 'mask_')
            mask_file = os.path.join(masks_dir, mask_name)
            if os.path.exists(mask_file):
                tasks.append((frame_file, mask_file, out_dir, NUM_AUG, "camus"))

    print(f"Processing Others (BUSI/CAMUS): {len(tasks)} tasks...")
    return tasks

def main_processing():
    """Run the offline preprocessing pipeline end to end.

    Steps:
      1. Rasterise DDTI XML annotations into masks.
      2. Collect resize/augment tasks for every dataset and run them in 4 worker processes.
      3. Index the processed image/mask pairs, then hash, label and de-duplicate them.
      4. Assign 5 stratified, patient-grouped folds and write the table to CSV.

    Changes from the original:
      * processed files are listed in sorted order, so "keep first" de-duplication
        and the fold assignment no longer depend on filesystem listing order;
      * the image path is derived by swapping only the trailing ``_mask.png``;
      * the CSV is written to the absolute path that the training cell reads
        (``/kaggle/working/unified_medical_data_5folds.csv``) rather than to the
        current working directory.
    """
    generate_ddti_masks(DDTI_DATA_DIR, GEN_MASKS_DIR)

    all_tasks = []
    all_tasks.extend(prep_usns(USNS_DATA_DIR, PROCESSED_USNS_DIR))
    all_tasks.extend(prep_thyroid_merged(THYROID_RAW_DIR, DDTI_DATA_DIR, GEN_MASKS_DIR, PROCESSED_THYROID_DIR))
    all_tasks.extend(prep_others(BUSI_DATA_DIR, CAMUS_DATA_DIR, PROCESSED_OTHERS_DIR))

    print(f"Total task: {len(all_tasks)}")

    with ProcessPoolExecutor(max_workers=4) as executor:
        # list(...) drains the lazy map so every task finishes before indexing starts.
        list(tqdm(executor.map(process_single_augment_task, all_tasks), total=len(all_tasks)))

    mask_suffix = '_mask.png'
    data_entries = []

    for folder, src_name in [(PROCESSED_USNS_DIR, 'USNS'),
                             (PROCESSED_THYROID_DIR, 'THYROID_EXT'),
                             (PROCESSED_OTHERS_DIR, 'OTHERS')]:

        for m_path in sorted(glob.glob(os.path.join(folder, "*" + mask_suffix))):
            i_path = m_path[:-len(mask_suffix)] + '.png'
            if not os.path.exists(i_path):
                continue

            # Refine the folder-level source using the filename prefix added during processing.
            real_source = src_name
            fname = os.path.basename(i_path)
            if 'busi' in fname:
                real_source = 'BUSI'
            elif 'camus' in fname:
                real_source = 'CAMUS'
            elif 'tn3k' in fname or 'tg3k' in fname:
                real_source = 'THYROID_EXT'

            data_entries.append({
                'image_path': i_path,
                'mask_path': m_path,
                'source': real_source,
                'patient_id': get_patient_id(i_path, real_source),
                'is_augmented': '_aug_' in fname
            })

    df = pd.DataFrame(data_entries)

    df = process_metadata(df)
    df = remove_duplicates(df)

    df = split_data(df, n_splits=5)

    df.to_csv('/kaggle/working/unified_medical_data_5folds.csv', index=False)

# Run the offline preprocessing pipeline (mask generation, resize/augment, fold CSV).
main_processing()

In [None]:
# ==========================
# 3. NATIVE TF DATA PIPELINE
# ==========================
# Global batch size = per-replica batch size x number of replicas in the strategy.
BATCH_SIZE_PER_REPLICA = 8
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 50
# Base rate 1e-4, scaled linearly with the per-replica batch (relative to 8) and the replica count.
LEARNING_RATE = 1e-4 * (BATCH_SIZE_PER_REPLICA / 8) * strategy.num_replicas_in_sync
# Fold table with image/mask paths, `fold` and `is_augmented` columns.
df = pd.read_csv('/kaggle/working/unified_medical_data_5folds.csv')

# Hold out one fold for validation. Augmented copies are excluded from the
# validation split only; the training split keeps them.
fold_idx = 0
df_train = df[df['fold'] != fold_idx]
df_val = df[(df['fold'] == fold_idx) & (df['is_augmented'] == False)]

train_imgs = df_train['image_path'].tolist()
train_masks = df_train['mask_path'].tolist()

val_imgs = df_val['image_path'].tolist()
val_masks = df_val['mask_path'].tolist()

def load_raw_data(img_path, mask_path):
    """Read one PNG image/mask pair from disk and return them as NumPy arrays.

    Args:
        img_path: Path to the image PNG (decoded with 3 channels).
        mask_path: Path to the mask PNG (decoded with 1 channel).

    Returns:
        ``(image, mask)``: the uint8 image, and a uint8 mask that is 1 where the
        stored mask is greater than 0 and 0 elsewhere. ``.numpy()`` is called on
        the tensors, so this needs eager execution.
    """
    img = tf.image.decode_png(tf.io.read_file(img_path), channels=3)
    mask = tf.image.decode_png(tf.io.read_file(mask_path), channels=1)

    # Defensive: coerce to uint8 if the decoder returned another dtype.
    if img.dtype != tf.uint8:
        img = tf.cast(tf.clip_by_value(img, 0, 255), tf.uint8)

    binary_mask = tf.cast(tf.greater(mask, 0), tf.uint8)

    return img.numpy(), binary_mask.numpy()

def load_dataset_into_ram(img_paths, mask_paths):
    """Load all image/mask pairs into two pre-allocated uint8 arrays using 16 threads.

    Args:
        img_paths: Sequence of image paths.
        mask_paths: Sequence of mask paths, paired positionally with ``img_paths``.

    Returns:
        ``(X_data, y_data)`` with shapes ``(N, TARGET_HEIGHT, TARGET_WIDTH, 3)`` and
        ``(N, TARGET_HEIGHT, TARGET_WIDTH, 1)``. A sample that fails to load is
        reported and its slot stays all zeros.
    """
    n_samples = len(img_paths)
    X_data = np.zeros((n_samples, TARGET_HEIGHT, TARGET_WIDTH, 3), dtype=np.uint8)
    y_data = np.zeros((n_samples, TARGET_HEIGHT, TARGET_WIDTH, 1), dtype=np.uint8)

    def load_single(idx, p, m):
        """Load one pair; returns (idx, image, mask), or None on failure."""
        try:
            img_raw, mask_raw = load_raw_data(p, m)
            # Guarantee a trailing channel axis on both arrays.
            if mask_raw.ndim == 2:
                mask_raw = mask_raw[..., np.newaxis]
            if img_raw.ndim == 2:
                img_raw = np.stack([img_raw, img_raw, img_raw], axis=-1)
        except Exception as e:
            print(f"Lỗi ảnh {p}: {e}")
            return None
        return idx, img_raw, mask_raw

    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        pending = [
            executor.submit(load_single, i, p, m)
            for i, (p, m) in enumerate(zip(img_paths, mask_paths))
        ]

        # Results arrive in completion order; `idx` places each one in its slot.
        for finished in tqdm(concurrent.futures.as_completed(pending), total=n_samples):
            loaded = finished.result()
            if loaded is None:
                continue
            idx, img, msk = loaded
            X_data[idx] = img
            y_data[idx] = msk

    print(f"RAM: {X_data.nbytes / 1e9:.2f} GB")
    return X_data, y_data

# Materialise both splits in memory as uint8 arrays (images: 3 channels, masks: 1 channel).
X_train, y_train = load_dataset_into_ram(train_imgs, train_masks)
X_val, y_val = load_dataset_into_ram(val_imgs, val_masks)

def get_unified_dataset(img, mask, batch_size=BATCH_SIZE, training=True):
    """Wrap in-memory image/mask arrays in a batched, prefetched ``tf.data.Dataset``.

    Args:
        img: Array of images, first axis is the sample axis.
        mask: Array of masks aligned with ``img``.
        batch_size: Batch size for training; evaluation uses ``max(1, batch_size // 2)``.
        training: If True, shuffle over the full dataset and repeat indefinitely.

    Returns:
        A dataset of ``(image, mask)`` batches with ``drop_remainder=True`` in
        both modes and the DATA auto-shard policy set.
    """
    shard_options = tf.data.Options()
    shard_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA

    pipeline = tf.data.Dataset.from_tensor_slices((img, mask)).cache()

    if training:
        pipeline = (
            pipeline
            .shuffle(buffer_size=len(img))
            .repeat()
            .batch(batch_size, drop_remainder=True)
        )
    else:
        eval_batch_size = max(1, batch_size // 2)
        pipeline = pipeline.batch(eval_batch_size, drop_remainder=True)

    return pipeline.prefetch(tf.data.AUTOTUNE).with_options(shard_options)

# Training dataset: shuffled, infinitely repeated. Validation dataset: single pass, half batch size.
train_ds = get_unified_dataset(X_train, y_train, training=True)
val_ds = get_unified_dataset(X_val, y_val, training=False)

In [None]:
# ==============================================
# 3. ARGA MULTI-TASK MODEL WITH BOUNDARY LEARNING
# ==============================================
@tf.keras.utils.register_keras_serializable()
class MedicalPreprocessingLayer(tf.keras.layers.Layer):
    """Scale images to [0, 1], optionally augment, and derive multi-task targets.

    Given an ``(image, mask)`` pair the layer returns the processed image plus a
    dict with the segmentation mask ('seg_out'), an edge map computed from the mask
    ('edge_out') and a per-sample presence label ('cls_out'). Given only an image
    it returns the scaled image.
    """
    def __init__(self, target_height=512, target_width=512, **kwargs):
        super().__init__(**kwargs)
        self.target_height = target_height
        self.target_width = target_width
        # Window size of the max-pool based dilation/erosion used for the edge map.
        self.kernel_size = 3

    def call(self, inputs, training=True):
        """Process an image, or an (image, mask) pair.

        Args:
            inputs: Either an image tensor, or a list/tuple ``(image, mask)``.
                Rank 3 (single sample) and rank 4 (batch) are both handled.
            training: When truthy and a mask is given, random augmentation is applied.
                NOTE(review): the default is True, so callers that omit the argument
                get augmentation — confirm this is intended for validation use.

        Returns:
            ``image`` alone when no mask was given, otherwise
            ``(image, {'seg_out': mask, 'edge_out': edge, 'cls_out': label})``.
        """
        if isinstance(inputs, (list, tuple)):
            image, mask = inputs
        else:
            image = inputs
            mask = None

        image = tf.cast(image, tf.float32) / 255.0
        if mask is not None:
            mask = tf.cast(mask, tf.float32)

        if training and mask is not None:
            # Stack image (3 ch) and mask (1 ch) so geometric ops hit both identically.
            combined = tf.concat([image, mask], axis=-1)
            # Zero-pad by PAD on every side, then crop back to the target size:
            # this acts as a random translation of up to +/- PAD pixels.
            PAD = 40
            combined = tf.image.pad_to_bounding_box(
                combined, PAD, PAD,
                self.target_height + 2*PAD, self.target_width + 2*PAD
            )

            shape_tensor = tf.shape(image)
            if len(image.shape) == 4:
                # Batched input: keep the (dynamic) batch dimension in the crop size.
                # NOTE(review): this presumably draws one crop window for the whole
                # batch rather than one per sample — confirm.
                crop_size = [shape_tensor[0], self.target_height, self.target_width, 4]
            else:
                crop_size = [self.target_height, self.target_width, 4]

            combined = tf.image.random_crop(combined, crop_size)
            combined = tf.image.random_flip_left_right(combined)

            # Split the stack back into image and mask.
            image = combined[..., :3]
            mask = combined[..., 3:]

            # Photometric jitter on the image only; clip to stay in [0, 1].
            image = tf.image.random_brightness(image, max_delta=0.1)
            image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
            image = tf.clip_by_value(image, 0.0, 1.0)

        if mask is not None:
            # max_pool2d needs rank 4; temporarily add a batch axis for single samples.
            mask_for_pool = mask
            if len(mask.shape) == 3:
                mask_for_pool = tf.expand_dims(mask, 0)

            # Morphological gradient: dilation (max-pool) minus erosion (min-pool via negation).
            dilated = tf.nn.max_pool2d(mask_for_pool, ksize=self.kernel_size, strides=1, padding='SAME')
            eroded = -tf.nn.max_pool2d(-mask_for_pool, ksize=self.kernel_size, strides=1, padding='SAME')
            edge = dilated - eroded

            if len(mask.shape) == 3:
                edge = tf.squeeze(edge, axis=0)

            # Classification target: 1.0 when the mask has any value above zero.
            if len(mask.shape) == 4:
                max_vals = tf.reduce_max(mask, axis=[1, 2, 3])
                has_object = max_vals > 0.0
                label = tf.cast(has_object, tf.float32)
                label = tf.reshape(label, [-1, 1])
            else:
                has_object = tf.reduce_max(mask) > 0.0
                label = tf.cast(has_object, tf.float32)
                label = tf.reshape(label, [1])

            return image, {'seg_out': mask, 'edge_out': edge, 'cls_out': label}

        return image

    def get_config(self):
        """Return the layer config, including the target size, for serialisation."""
        config = super().get_config()
        config.update({
            "target_height": self.target_height,
            "target_width": self.target_width
        })
        return config
        
class InstanceNormalization(layers.Layer):
    """
    Instance normalisation with a learnable per-channel scale and shift.

    Mean and variance are computed over axes 1 and 2 separately for every
    sample and channel; ``epsilon`` is added to the variance before the
    reciprocal square root.
    """
    def __init__(self, epsilon=1e-7, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create ``gamma`` (ones) and ``beta`` (zeros), one value per channel."""
        param_shape = (input_shape[-1],)
        self.gamma = self.add_weight(
            name='gamma', shape=param_shape, initializer='ones', trainable=True
        )
        self.beta = self.add_weight(
            name='beta', shape=param_shape, initializer='zeros', trainable=True
        )
        super().build(input_shape)

    def call(self, inputs):
        """Normalise ``inputs`` over axes 1 and 2, then apply scale and shift."""
        mean, variance = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        inv_std = tf.math.rsqrt(variance + self.epsilon)
        return self.gamma * (inputs - mean) * inv_std + self.beta

    def get_config(self):
        """Return the base layer config extended with ``epsilon``."""
        return {**super().get_config(), 'epsilon': self.epsilon}

@tf.custom_gradient
def grad_scaler(x, scale):
    """Identity in the forward pass; multiplies the incoming gradient by ``scale`` on the way back."""
    def _backward(upstream):
        scaled = upstream * tf.cast(scale, upstream.dtype)
        # No gradient is propagated to `scale` itself.
        return scaled, None
    return x, _backward

class GradientScaler(layers.Layer):
    """
    Pass-through layer: the forward signal is unchanged, but the gradient is
    multiplied by ``scale`` during back-propagation (a value below 1 weakens it).
    """
    def __init__(self, scale=0.1, **kwargs):
        super().__init__(**kwargs)
        self.scale = tf.cast(scale, dtype=tf.float32)

    def call(self, x):
        """Return ``x`` unchanged, with the scaled-gradient rule attached."""
        return grad_scaler(x, self.scale)

    def get_config(self):
        """Return the base layer config extended with ``scale`` as a Python float."""
        base_config = super().get_config()
        return {**base_config, 'scale': float(self.scale)}

class ChannelMean(layers.Layer):
    """Reduce the last axis to its mean, keeping it as a size-1 axis."""
    def call(self, x):
        channel_mean = tf.math.reduce_mean(x, axis=-1, keepdims=True)
        return channel_mean

class ChannelMax(layers.Layer):
    """Reduce the last axis to its maximum, keeping it as a size-1 axis."""
    def call(self, x):
        channel_max = tf.math.reduce_max(x, axis=-1, keepdims=True)
        return channel_max

def CBAM(inputs, ratio=16, name="cbam"):
    """Channel attention followed by spatial attention on a feature map.

    Args:
        inputs: 4-D feature tensor whose last dimension is statically known.
        ratio: Reduction ratio of the shared channel-attention MLP.
        name: Prefix for the names of the created layers.

    Returns:
        A tensor of the same shape as ``inputs``, re-weighted first per channel
        and then per spatial position.

    The MLP bottleneck width is clamped to at least 1 unit. ``channel // ratio``
    is 0 whenever ``channel < ratio``, which used to create a zero-unit Dense
    layer. Widths for ``channel >= ratio`` are unchanged.
    """
    channel = inputs.shape[-1]
    hidden_units = max(1, channel // ratio)

    # --- Channel attention: a shared MLP over global average- and max-pooled descriptors ---
    avg_pool = layers.GlobalAveragePooling2D(name=f"{name}_gap")(inputs)
    max_pool = layers.GlobalMaxPooling2D(name=f"{name}_gmp")(inputs)

    mlp = models.Sequential([
        layers.Dense(hidden_units, activation='swish', use_bias=False, name=f"{name}_mlp_1"),
        layers.Dense(channel, use_bias=False, name=f"{name}_mlp_2")
    ], name=f"{name}_mlp")

    channel_att = layers.Add(name=f"{name}_channel_add")([mlp(avg_pool), mlp(max_pool)])
    channel_att = layers.Activation('sigmoid', name=f"{name}_channel_sigmoid")(channel_att)
    # Reshape to (1, 1, C) so the weights broadcast over the spatial axes.
    channel_att = layers.Reshape((1, 1, channel), name=f"{name}_channel_reshape")(channel_att)

    x = layers.Multiply(name=f"{name}_channel_mult")([inputs, channel_att])

    # --- Spatial attention: 7x7 conv over the channel-wise mean and max maps ---
    avg_pool_s = ChannelMean()(x)
    max_pool_s = ChannelMax()(x)

    concat = layers.Concatenate(axis=-1, name=f"{name}_spatial_concat")([avg_pool_s, max_pool_s])

    spatial_att = layers.Conv2D(1, 7, padding='same', activation='sigmoid', use_bias=False, name=f"{name}_spatial_conv")(concat)
    x = layers.Multiply(name=f"{name}_spatial_mult")([x, spatial_att])
    return x

def RGCM(inputs, filters, groups=8, reduction=2, name="rgcm"):
    """Residual bottleneck block with a grouped 3x3 convolution.

    Layout: 1x1 conv -> grouped 3x3 conv -> 1x1 conv, each followed by
    InstanceNormalization (swish after the first two), then a residual add and a
    final swish. The shortcut is a 1x1 conv + InstanceNormalization when the input
    channel count differs from ``filters``, otherwise the input itself.

    Args:
        inputs: 4-D feature tensor.
        filters: Number of output channels.
        groups: Number of groups of the 3x3 convolution.
        reduction: Bottleneck reduction factor (``filters // reduction`` channels,
            rounded up to a multiple of ``groups``).
        name: Prefix for the names of the created layers.

    Returns:
        Output tensor with ``filters`` channels.
    """
    # Bottleneck width, rounded up to the next multiple of `groups` if needed.
    mid_filters = filters // reduction
    leftover = mid_filters % groups
    if leftover:
        mid_filters += groups - leftover

    def _conv_in(tensor, n_out, kernel, tag, **conv_kwargs):
        """Bias-free conv followed by instance normalisation; `tag` numbers the layer names."""
        tensor = layers.Conv2D(n_out, kernel, use_bias=False,
                               name=f"{name}_conv{tag}", **conv_kwargs)(tensor)
        return InstanceNormalization(name=f"{name}_in{tag}")(tensor)

    # Block 1: channel reduction
    x = _conv_in(inputs, mid_filters, 1, 1)
    x = layers.Activation('swish', name=f"{name}_relu1")(x)

    # Block 2: grouped spatial convolution
    x = _conv_in(x, mid_filters, 3, 2, padding='same', groups=groups)
    x = layers.Activation('swish', name=f"{name}_relu2")(x)

    # Block 3: channel expansion (activation comes after the residual add)
    x = _conv_in(x, filters, 1, 3)

    # Skip connection
    if inputs.shape[-1] == filters:
        residual = inputs
    else:
        residual = layers.Conv2D(filters, 1, use_bias=False, name=f"{name}_shortcut_conv")(inputs)
        residual = InstanceNormalization(name=f"{name}_shortcut_in")(residual)

    merged = layers.Add(name=f"{name}_add")([x, residual])
    return layers.Activation('swish', name=f"{name}_out_relu")(merged)

def GABM(xl, xg, filters, name="gabm"):
    """Attention gate: re-weight ``xl`` with a single-channel map computed from ``xl`` and ``xg``.

    Both inputs are projected to ``filters`` channels with a 1x1 conv +
    InstanceNormalization, summed and passed through swish. A further 1x1 conv +
    InstanceNormalization + sigmoid yields a one-channel gate that multiplies ``xl``.

    Args:
        xl: Feature tensor to be gated.
        xg: Gating tensor; it is added to the projection of ``xl``, so the shapes
            must be compatible.
        filters: Channel count of the intermediate projections.
        name: Prefix for the names of the created layers.

    Returns:
        ``xl`` multiplied element-wise by the gate.
    """
    def _project(tensor, n_out, tag):
        """1x1 bias-free conv followed by instance normalisation."""
        tensor = layers.Conv2D(n_out, 1, use_bias=False, name=f"{name}_{tag}_conv")(tensor)
        return InstanceNormalization(name=f"{name}_{tag}_in")(tensor)

    theta_x = _project(xl, filters, "theta")
    phi_g = _project(xg, filters, "phi")

    joint = layers.Add(name=f"{name}_add")([theta_x, phi_g])
    joint = layers.Activation('swish', name=f"{name}_relu")(joint)

    gate = _project(joint, 1, "psi")
    gate = layers.Activation('sigmoid', name=f"{name}_psi_sigmoid")(gate)

    return layers.Multiply(name=f"{name}_out_mult")([xl, gate])

def ASPP(inputs, out_filters=320, name="aspp"):
    """Atrous Spatial Pyramid Pooling over a feature map.

    Five parallel branches are concatenated and projected back to ``out_filters``:
    a 1x1 separable conv, three 3x3 separable convs with dilation 6/12/18, and an
    image-level branch (global average pool -> 1x1 conv -> resize to the input's
    spatial size).

    Args:
        inputs: 4-D feature tensor whose channel dimension is statically known.
        out_filters: Channel count of every branch and of the output.
        name: Prefix for the names of the created layers.

    Returns:
        Tensor with the spatial size of ``inputs`` and ``out_filters`` channels.

    Fix: the image-level branch used InstanceNormalization on a 1x1 spatial map.
    Its statistics are taken over the spatial axes, so on a single pixel the mean
    equals the value and the variance is zero. The output was therefore always
    ``beta``, regardless of the input, and the branch carried no signal. The branch
    now normalises over the channel axis with ``layers.LayerNormalization``. The
    layer keeps its name and still owns one gamma/beta vector of length ``out_filters``.
    """

    def conv_block(x, kernel_size, dilation_rate=1, block_name="block"):
        """Separable conv -> InstanceNormalization -> swish."""
        x = layers.SeparableConv2D(out_filters, kernel_size, padding='same',
                          dilation_rate=dilation_rate, use_bias=False, name=f"{name}_{block_name}_conv")(x)
        x = InstanceNormalization(name=f"{name}_{block_name}_in")(x)
        x = layers.Activation('swish', name=f"{name}_{block_name}_relu")(x)
        return x

    b1 = conv_block(inputs, 1, dilation_rate=1, block_name="b1")
    b2 = conv_block(inputs, 3, dilation_rate=6, block_name="b2")
    b3 = conv_block(inputs, 3, dilation_rate=12, block_name="b3")
    b4 = conv_block(inputs, 3, dilation_rate=18, block_name="b4")

    # Image-level branch: global context as a 1x1 map, later broadcast by resizing.
    b5 = layers.GlobalAveragePooling2D(name=f"{name}_b5_gap")(inputs)
    b5 = layers.Reshape((1, 1, inputs.shape[-1]), name=f"{name}_b5_reshape")(b5)
    b5 = layers.Conv2D(out_filters, 1, padding='same', use_bias=False, name=f"{name}_b5_conv")(b5)
    # Channel-wise normalisation: spatial statistics are degenerate on a 1x1 map.
    b5 = layers.LayerNormalization(axis=-1, name=f"{name}_b5_in")(b5)
    b5 = layers.Activation('swish', name=f"{name}_b5_relu")(b5)
    # Resize in float32, then cast back to the dtype of `inputs`; the target size
    # is read from the runtime shape of `inputs`.
    b5 = layers.Lambda(
        lambda x: tf.cast(
            tf.image.resize(
                tf.cast(x[0], tf.float32),
                tf.shape(x[1])[1:3]
            ),
            dtype=x[1].dtype
        ),
        name=f"{name}_b5_resize_safe"
    )([b5, inputs])
    x = layers.Concatenate(name=f"{name}_concat")([b1, b2, b3, b4, b5])
    x = conv_block(x, 1, block_name="proj")

    return x

def build_multitask_arga_unet():
    """Assemble the three-output U-Net style model named ``"ARGA_Unet"``.

    Layout:
        * stem convolution followed by four encoder stages (RGCM + 2x max-pool),
        * bottleneck (RGCM + ASPP),
        * classification head fed from the bottleneck through a GradientScaler,
        * four decoder stages (upsample, GABM-gated skip, conv, RGCM, CBAM),
        * an edge head and a segmentation head on the last decoder output.

    Layer names and creation order match the hand-unrolled version of this
    function, since saved weights are loaded back into models built here.

    Returns:
        A ``models.Model`` with outputs ``[seg_out, edge_out, cls_out]``.
    """
    inputs = layers.Input(shape=(TARGET_HEIGHT, TARGET_WIDTH, 3), name="input_1")

    def _decoder_stage(tensor, skip, filters, stage):
        # Upsample, gate the encoder skip with GABM, fuse, then refine.
        tag = f"dec_stage{stage}"
        upsampled = layers.UpSampling2D(2, interpolation='nearest', name=f"{tag}_up")(tensor)
        gated_skip = GABM(xl=skip, xg=upsampled, filters=filters, name=f"{tag}_gabm")
        fused = layers.Concatenate(name=f"{tag}_concat")([upsampled, gated_skip])

        fused = layers.Conv2D(filters, 3, padding='same', use_bias=False, name=f"{tag}_conv1")(fused)
        fused = InstanceNormalization(name=f"{tag}_in1")(fused)
        fused = layers.Activation('swish', name=f"{tag}_relu1")(fused)
        fused = RGCM(fused, filters, name=f"{tag}_rgcm")
        return CBAM(fused, name=f"{tag}_cbam")

    # --- ENCODER ---
    stem = layers.Conv2D(32, 3, padding='same', name="enc_stem_conv")(inputs)
    stem = InstanceNormalization(name="enc_stem_in")(stem)
    stem = layers.Activation('swish', name="enc_stem_relu")(stem)

    skips = []  # pre-pooling features of stages 1..4, reused by the decoder
    current = stem
    for stage, filters in enumerate((32, 64, 128, 256), start=1):
        features = RGCM(current, filters, name=f"enc_stage{stage}_rgcm")
        skips.append(features)
        current = layers.MaxPooling2D(2, name=f"enc_stage{stage}_pool")(features)

    # --- BOTTLENECK ---
    bottleneck = RGCM(current, 512, name="bot_rgcm")
    bottleneck = ASPP(bottleneck, out_filters=320, name="bot_aspp")

    # --- CLASSIFIER HEAD ---
    cls_features = GradientScaler(scale=0.1, name="head_cls_scaler")(bottleneck)
    cls_features = layers.Conv2D(
        256, 3, padding='same', activation='swish',
        kernel_initializer='he_normal', name="head_cls_conv",
    )(cls_features)
    cls_features = InstanceNormalization(name="head_cls_in")(cls_features)

    pooled_avg = layers.GlobalAveragePooling2D(name="head_cls_gap")(cls_features)
    pooled_max = layers.GlobalMaxPooling2D(name="head_cls_gmp")(cls_features)

    cls_vector = layers.Concatenate(name="head_cls_concat")([pooled_avg, pooled_max])
    cls_vector = layers.Dense(
        256, activation='swish', kernel_initializer='he_normal', name="head_cls_dense1",
    )(cls_vector)
    cls_vector = layers.Dropout(0.5, name="head_cls_dropout")(cls_vector)

    cls_out = layers.Dense(
        1, activation='sigmoid', kernel_initializer='glorot_uniform',
        name='head_cls_out', dtype='float32',
    )(cls_vector)

    # --- DECODER --- stages 6..9 consume the encoder skips deepest-first.
    decoded = bottleneck
    for stage, filters, skip in zip((6, 7, 8, 9), (256, 128, 64, 32), reversed(skips)):
        decoded = _decoder_stage(decoded, skip, filters, stage)

    # Edge head reads the decoder output before spatial dropout is applied.
    edge_features = layers.Conv2D(32, 3, padding='same', activation='swish', name="head_edge_conv")(decoded)
    edge_out = layers.Conv2D(1, 1, activation='sigmoid', name='edge_out', dtype='float32')(edge_features)

    decoded = layers.SpatialDropout2D(0.1)(decoded)
    seg_out = layers.Conv2D(1, 1, activation='sigmoid', name='seg_out', dtype='float32')(decoded)

    return models.Model(inputs=inputs, outputs=[seg_out, edge_out, cls_out], name="ARGA_Unet")

In [None]:
# ==========================================
# 4. LOSS & METRICS & SUPPORT FUNCTION
# ==========================================
def dice_coef(y_true, y_pred, smooth=1e-5):
    """Smoothed Dice coefficient reduced over axes 1 and 2 only.

    No reduction is applied to the remaining axes, so the result keeps one
    value per entry along them.
    """
    reduce_axes = [1, 2]
    overlap = tf.reduce_sum(y_true * y_pred, axis=reduce_axes)
    mass = tf.reduce_sum(y_true, axis=reduce_axes) + tf.reduce_sum(y_pred, axis=reduce_axes)
    return (2. * overlap + smooth) / (mass + smooth)

def focal_tversky_loss(y_true, y_pred, alpha=0.4, beta=0.6, gamma=2.0, smooth=1.0):
    """Focal Tversky loss averaged over all remaining axes.

    The Tversky index is computed per sample from sums over axes 1 and 2:

        TI = (TP + smooth) / (TP + alpha * FN + beta * FP + smooth)

    and the loss is ``mean((1 - TI) ** gamma)``.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted probabilities; clipped to ``[1e-7, 1 - 1e-7]``.
        alpha: Weight of the false-negative term.
        beta: Weight of the false-positive term. Previously this argument was
            ignored and ``1 - alpha`` was used instead; with the defaults
            (0.4 / 0.6) the result is unchanged.
        gamma: Focal exponent.
        smooth: Additive smoothing applied to numerator and denominator.

    Returns:
        Scalar loss tensor.
    """
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
    tp = tf.reduce_sum(y_true * y_pred, axis=[1, 2])
    fp = tf.reduce_sum((1 - y_true) * y_pred, axis=[1, 2])
    fn = tf.reduce_sum(y_true * (1 - y_pred), axis=[1, 2])
    # Use the caller-supplied beta for the FP weight instead of deriving it
    # from alpha, so the two weights can be tuned independently.
    tversky_index = (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)
    loss = tf.pow((1 - tversky_index), gamma)

    return tf.reduce_mean(loss)

class WeightedEdgeLoss(tf.keras.losses.Loss):
    """Binary cross-entropy whose positive-class term is scaled by ``pos_weight``."""

    def __init__(self, pos_weight=10.0, **kwargs):
        super().__init__(**kwargs)
        self.pos_weight = pos_weight

    def call(self, y_true, y_pred):
        """Return the mean weighted cross-entropy over every element."""
        eps = K.epsilon()
        # Clip so neither logarithm below can receive 0.
        probs = tf.clip_by_value(y_pred, eps, 1 - eps)
        positive_term = self.pos_weight * y_true * tf.math.log(probs)
        negative_term = (1 - y_true) * tf.math.log(1 - probs)
        return tf.reduce_mean(-(positive_term + negative_term))

def log_cosh_dice_loss(y_true, y_pred, smooth=1e-5):
    """Log-cosh of the Dice loss, with sums taken over axes 1 and 2.

    The result is not reduced over the remaining axes, so callers that need a
    scalar must average it themselves.
    """
    reduce_axes = [1, 2]
    overlap = tf.reduce_sum(y_true * y_pred, axis=reduce_axes)
    mass = tf.reduce_sum(y_true, axis=reduce_axes) + tf.reduce_sum(y_pred, axis=reduce_axes)
    dice = (2. * overlap + smooth) / (mass + smooth)
    return tf.math.log(tf.math.cosh(1.0 - dice))

def edge_dice_loss(y_true, y_pred, smooth=1e-5):
    """Dice loss computed over the whole flattened batch; returns a scalar."""
    flat_true = tf.reshape(y_true, [-1])
    flat_pred = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(flat_true * flat_pred)
    denominator = tf.reduce_sum(flat_true) + tf.reduce_sum(flat_pred) + smooth
    return 1.0 - (2. * overlap + smooth) / denominator

def boundary_consistency_loss(seg_pred, edge_pred):
    """Mean absolute difference between the Sobel boundary of ``seg_pred`` and ``edge_pred``.

    The Sobel gradient magnitude of the segmentation prediction is divided by
    its maximum over the whole tensor before being compared to ``edge_pred``.
    The 3x3 filters are shaped ``[3, 3, 1, 1]``.
    """

    def _sobel_filter(rows):
        return tf.reshape(tf.constant(rows, tf.float32), [3, 3, 1, 1])

    kernel_x = _sobel_filter([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
    kernel_y = _sobel_filter([[-1, -2, -1], [0, 0, 0], [1, 2, 1]])

    conv_kwargs = dict(strides=[1, 1, 1, 1], padding='SAME')
    grad_x = tf.nn.depthwise_conv2d(seg_pred, kernel_x, **conv_kwargs)
    grad_y = tf.nn.depthwise_conv2d(seg_pred, kernel_y, **conv_kwargs)

    # The 1e-7 terms keep sqrt and the division away from zero.
    magnitude = tf.sqrt(tf.square(grad_x) + tf.square(grad_y) + 1e-7)
    magnitude = magnitude / (tf.reduce_max(magnitude) + 1e-7)

    return tf.reduce_mean(tf.abs(magnitude - edge_pred))

def scale_consistency_loss(model, images, pred_orig, scale_factor=0.75):
    """Masked MSE between predictions on a shrunken input and the shrunken original prediction.

    ``images`` are resized by ``scale_factor`` and zero-padded (bottom/right)
    back to ``TARGET_HEIGHT x TARGET_WIDTH`` before being fed to ``model``
    with ``training=True``. ``pred_orig`` goes through the same resize + pad,
    and the squared difference is averaged over the un-padded region only.

    Args:
        model: Callable model; element 0 of its output is compared.
        images: Rank-4 tensor with a static batch dimension.
        pred_orig: Prediction for the un-scaled ``images``.
        scale_factor: Shrink factor applied to the target height/width.

    Returns:
        Scalar loss tensor.
    """
    batch, _, _, _ = images.shape
    scaled_h = int(TARGET_HEIGHT * scale_factor)
    scaled_w = int(TARGET_WIDTH * scale_factor)

    def _pad_to_target(tensor):
        # Anchor content at the top-left corner of a full-size canvas.
        return tf.image.pad_to_bounding_box(tensor, 0, 0, TARGET_HEIGHT, TARGET_WIDTH)

    shrunk_images = _pad_to_target(tf.image.resize(images, [scaled_h, scaled_w]))
    pred_seg_scaled = model(shrunk_images, training=True)[0]

    pred_reference = _pad_to_target(tf.image.resize(pred_orig, [scaled_h, scaled_w]))

    # 1 inside the resized content, 0 inside the padding.
    valid = _pad_to_target(tf.ones((batch, scaled_h, scaled_w, 1)))

    squared_error = tf.square(pred_seg_scaled - pred_reference)
    return tf.reduce_sum(squared_error * valid) / (tf.reduce_sum(valid) + 1e-5)

def mixup_batch(inputs, alpha=0.2):
    """Blend every sample with another sample of the same batch.

    Partners are obtained by rolling the batch by one random shift in
    ``[1, batch_size)``. The per-sample mixing factor is drawn uniformly from
    ``[0, 0.3)`` and flipped to ``1 - factor`` with probability 0.5.

    Args:
        inputs: ``(images, labels)`` where ``labels`` is a dict that may
            contain ``'seg_out'``, ``'edge_out'`` and ``'cls_out'``.
        alpha: Accepted for interface compatibility; not read by this function.

    Returns:
        ``(mixed_images, mixed_labels)``. Segmentation and class labels are
        blended linearly; edge labels are taken wholesale from whichever
        sample has the larger mixing factor.
    """
    images, labels = inputs
    batch_size = tf.shape(images)[0]

    shift = tf.random.uniform([], minval=1, maxval=batch_size, dtype=tf.int32)

    def _partner(tensor):
        return tf.roll(tensor, shift=shift, axis=0)

    partner_images = _partner(images)
    partner_labels = {key: _partner(value) for key, value in labels.items()}

    lam = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 0.3)
    flip = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0) > 0.5
    lam = tf.where(flip, 1.0 - lam, lam)

    # Same factor reshaped to broadcast against the classification labels.
    lam_cls = tf.reshape(lam, [batch_size, 1])

    mixed_images = lam * images + (1 - lam) * partner_images

    mixed_labels = {}

    if 'seg_out' in labels:
        mixed_labels['seg_out'] = lam * labels['seg_out'] + (1 - lam) * partner_labels['seg_out']

    if 'edge_out' in labels:
        # Hard selection: edges are not blended.
        keep_own = tf.cast(lam > 0.5, tf.float32)
        mixed_labels['edge_out'] = keep_own * labels['edge_out'] + (1 - keep_own) * partner_labels['edge_out']

    if 'cls_out' in labels:
        mixed_labels['cls_out'] = lam_cls * labels['cls_out'] + (1 - lam_cls) * partner_labels['cls_out']

    return mixed_images, mixed_labels

In [None]:
# ==========================================
# 5. PRETRAIN THE MODEL ON THE FULL DATASET
# ==========================================
# Create the model inside the distribution-strategy scope, then print its
# layer summary for inspection.
with strategy.scope():
    model = build_multitask_arga_unet()
    model.build((None, TARGET_HEIGHT, TARGET_WIDTH, 3))

model.summary()

In [None]:
# ====================================
# CUSTOM TRAINING LOOP (used instead of model.fit)
# ====================================

# Number of full batches available per pass over the training set.
N_TRAIN = len(X_train)
STEPS_PER_EPOCH = N_TRAIN // BATCH_SIZE
# Each epoch runs TPU_LOOPS_COUNT device loops of STEPS_PER_LOOP steps each;
# these two values are the defaults of run_tpu_training below.
TPU_LOOPS_COUNT = 1 
STEPS_PER_LOOP = STEPS_PER_EPOCH

class CustomScheduler:
    """Checkpoint-on-improvement, reduce-LR-on-plateau and early stopping in one object.

    Call :meth:`step` once per epoch with the monitored value. Weights are
    saved to ``save_name`` whenever the value improves; the learning rate is
    multiplied by ``factor`` after ``patience_lr`` epochs without improvement,
    and :meth:`step` returns ``True`` after ``patience_stop`` such epochs.
    """

    def __init__(self, model, optimizer, mode='min', patience_lr=5, patience_stop=15, 
                 factor=0.5, min_lr=1e-8, save_name="best_model.weights.h5"):
        self.model = model
        self.optimizer = optimizer
        self.mode = mode
        self.patience_lr = patience_lr
        self.patience_stop = patience_stop
        self.factor = factor
        self.min_lr = min_lr
        self.save_name = save_name

        # Epochs since the last improvement (lr_wait also resets on every LR cut).
        self.lr_wait = 0
        self.stop_wait = 0

        # Any mode other than 'min' is treated as "higher is better".
        minimise = mode == 'min'
        self.best_value = float('inf') if minimise else -float('inf')
        self.monitor_op = np.less if minimise else np.greater

    def load_best_weights(self):
        """Reload the last saved checkpoint; failures are reported, not raised."""
        print(f"   Restoring best weights from {self.save_name}...")
        try:
            self.model.load_weights(self.save_name)
        except Exception as e:
            print(f"   Restore failed: {e}")
        else:
            print("   Restore successful.")

    def _reduce_lr(self):
        """Scale the optimizer LR by ``factor`` unless that would reach ``min_lr``."""
        current_lr = float(self.optimizer.learning_rate.numpy())
        candidate_lr = current_lr * self.factor

        if candidate_lr <= self.min_lr:
            print(f"   LR reached min ({self.min_lr}). Cannot reduce further.")
            return

        self.optimizer.learning_rate.assign(candidate_lr)
        print(f"   Reduced LR to {candidate_lr:.1e}")
        self.lr_wait = 0

    def step(self, current_value):
        """Record one epoch's monitored value.

        Returns:
            ``True`` when training should stop, ``False`` otherwise.
        """
        if self.monitor_op(current_value, self.best_value):
            print(f"   Improved ({self.best_value:.4f} -> {current_value:.4f}). Saving weights...")
            self.best_value = current_value
            self.model.save_weights(self.save_name)
            self.lr_wait = 0
            self.stop_wait = 0
            return False

        self.lr_wait += 1
        self.stop_wait += 1
        print(f"   No improv. LR Wait: {self.lr_wait}/{self.patience_lr} | Stop Wait: {self.stop_wait}/{self.patience_stop}")

        if self.lr_wait >= self.patience_lr:
            self._reduce_lr()

        if self.stop_wait >= self.patience_stop:
            print("   Early Stopping triggered.")
            return True

        return False

def run_tpu_training(model, 
                     train_ds, 
                     val_ds, 
                     save_name='pretrain.weights.h5',
                     mode='min',
                     patience_lr=5,
                     patience_st=15,
                     value='dice',
                     epochs=50, 
                     lr=LEARNING_RATE, 
                     weight_loss=[1.0, 0.5, 0.1],
                     steps_per_loop=STEPS_PER_LOOP,
                     loops_count=TPU_LOOPS_COUNT):
    """Train ``model`` with a custom distributed loop and restore its best checkpoint.

    Args:
        model: Network whose call returns ``[seg, edge, cls]`` predictions.
        train_ds: Training dataset. It is distributed with ``strategy`` and read
            through one iterator shared by all epochs, so it must be able to
            supply ``epochs * loops_count * steps_per_loop`` batches.
        val_ds: Validation dataset, iterated in full once per epoch.
        save_name: Path used by the scheduler to save / restore the best weights.
        mode: ``'min'`` to minimise the monitored value, anything else to maximise.
        patience_lr: Epochs without improvement before the LR is reduced.
        patience_st: Epochs without improvement before training stops.
        value: ``'dice'`` monitors validation Dice; any other value monitors the
            validation classification loss.
        epochs: Maximum number of epochs.
        lr: Initial AdamW learning rate.
        weight_loss: ``[seg, edge, cls]`` loss weights; a weight ``<= 0``
            disables that term entirely.
        steps_per_loop: Training steps executed per compiled device loop.
        loops_count: Number of device loops per epoch.

    Returns:
        ``model`` after ``CustomScheduler.load_best_weights`` has been called.
    """
    
    with strategy.scope():
        optimizer = tf.keras.optimizers.AdamW(learning_rate=lr)
        val_dice = tf.keras.metrics.Mean(name='val_dice')
        val_cls_loss = tf.keras.metrics.Mean(name='val_cls_loss')
        aug_layer = MedicalPreprocessingLayer(TARGET_HEIGHT, TARGET_WIDTH)
        
        @tf.function(jit_compile=True)
        def train_step(inputs):
            images, labels = aug_layer(inputs, training=True)
            images, targets = mixup_batch((images, labels), alpha=0.2)
            with tf.GradientTape() as tape:
                preds = model(images, training=True)
                
                pred_seg = preds[0]
                pred_edge = preds[1]
                pred_cls = preds[2]
                
                # The weight checks below are plain Python and are resolved
                # once, when the step function is traced.
                if weight_loss[0] > 0:
                    target_seg = tf.cast(targets['seg_out'], tf.float32)
                    l_tversky = focal_tversky_loss(target_seg, pred_seg)
                    # log_cosh_dice_loss is per-sample; average it so the
                    # total loss is a scalar like every other term.
                    l_logcosh = tf.reduce_mean(log_cosh_dice_loss(target_seg, pred_seg))
                    l_seg = (l_tversky + l_logcosh) * weight_loss[0]
                else:
                    l_seg = 0.0

                if weight_loss[1] > 0:
                    target_edge = tf.cast(targets['edge_out'], tf.float32)
                    l_bce = WeightedEdgeLoss()(target_edge, pred_edge)
                    l_dice = edge_dice_loss(target_edge, pred_edge)
                    l_edge = (l_bce + l_dice) * weight_loss[1]
                else:
                    l_edge = 0.0

                if weight_loss[2] > 0:
                    target_cls = tf.cast(targets['cls_out'], tf.float32)
                    l_cls_raw = tf.keras.losses.binary_crossentropy(target_cls, pred_cls, from_logits=False)
                    l_cls = tf.reduce_mean(l_cls_raw) * weight_loss[2]
                else:
                    l_cls = 0.0
                
                l_consist = 0.0
                l_consist_scale = 0.0
                
                if (weight_loss[0] > 0) or (weight_loss[1] > 0):
                    
                    l_consist = boundary_consistency_loss(pred_seg, pred_edge) * 0.1
                    
                    rand_val = tf.random.uniform([], 0, 1)

                    def compute_scale():
                        return scale_consistency_loss(model, images, pred_seg)

                    # Apply the (expensive) scale-consistency term on ~30% of
                    # the steps. The tf.cond result used to be stored in an
                    # unused variable, so this term never reached the loss.
                    l_consist_scale = tf.cond(
                        rand_val < 0.3, 
                        compute_scale, 
                        lambda: tf.constant(0.0, dtype=tf.float32)
                    ) * 0.1
                                
                loss = l_seg + l_edge + l_cls + l_consist + l_consist_scale

            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            return loss     

        @tf.function
        def train_loop_body(iterator, steps):
            # Runs `steps` training steps in one traced loop and returns the
            # mean of the per-step losses (each summed across replicas).
            total_loss = 0.0
            for _ in tf.range(steps):
                inputs = next(iterator)
                loss = strategy.run(train_step, args=(inputs,))
                total_loss += tf.reduce_sum(strategy.reduce(tf.distribute.ReduceOp.SUM, loss, axis=None))
                
            return total_loss / tf.cast(steps, tf.float32)

        @tf.function(jit_compile=True)
        def val_step_fn(inputs):
            images, targets = aug_layer(inputs, training=False)
            preds = model(images, training=False)
            cls_loss = tf.keras.losses.binary_crossentropy(targets['cls_out'], preds[2], from_logits=False)
            cls_loss = tf.reduce_mean(cls_loss)

            d_score = dice_coef(targets['seg_out'], preds[0])
    
            return cls_loss, d_score
    

    train_dataset = strategy.experimental_distribute_dataset(train_ds)

    scheduler = CustomScheduler(
        model=model, 
        optimizer=optimizer, 
        mode=mode,
        patience_lr=patience_lr,
        patience_stop=patience_st, 
        save_name=save_name
    )

    train_iter = iter(train_dataset)

    for epoch in range(epochs):
        loss_sum = 0.0
        
        pbar = tqdm(range(loops_count), desc=f"Ep {epoch+1}", leave=False)
        for _ in pbar:
            loss_value = float(train_loop_body(train_iter, tf.constant(steps_per_loop)))
            loss_sum += loss_value
            pbar.set_postfix({'loss': f"{loss_value:.4f}"})

        val_dice.reset_state()
        val_cls_loss.reset_state()
        # NOTE(review): val_ds is iterated un-distributed, so every replica
        # appears to receive the same batch. Switching to a distributed
        # dataset would need handling for an uneven final batch — confirm
        # before changing.
        for batch in val_ds:
            per_replica_losses, per_replica_dice = strategy.run(val_step_fn, args=(batch,))

            # Accumulate every batch. These updates used to sit outside the
            # loop, so only the last validation batch was ever measured.
            batch_cls_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            batch_dice = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_dice, axis=None)

            val_cls_loss.update_state(batch_cls_loss)
            val_dice.update_state(batch_dice)
            
        v_dice = float(val_dice.result())
        v_cls = float(val_cls_loss.result())
        curr_lr = float(optimizer.learning_rate.numpy())
        print(f"Epoch {epoch+1} | Loss: {loss_sum/loops_count:.4f} | Dice: {v_dice:.4f} | Cls loss: {v_cls:.4f} | LR: {curr_lr:.1e}")
        
        monitored = v_dice if value == 'dice' else v_cls
        if scheduler.step(monitored):
            break

    scheduler.load_best_weights()
    return model

# Pretraining run: monitor validation Dice (higher is better) and keep the
# default checkpoint name, 'pretrain.weights.h5', which the fine-tuning cell
# loads later.
print("Pretrain Multi-Task...")
model = run_tpu_training(
    model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    epochs=EPOCHS,
    value='dice',
    mode='max'
)

In [None]:
# Drop the pretraining arrays/datasets from the notebook namespace so their
# memory can be reclaimed before the fine-tuning data is loaded.
for var in ('X_train', 'y_train', 'X_val', 'y_val', 'train_ds', 'val_ds', 'train_dataset', 'val_dataset'):
    # pop() with a default silently ignores names that were never defined.
    globals().pop(var, None)
gc.collect()

In [None]:
# ===============================
# 6. TẠO DATASET ĐÍCH ĐỂ FINETUNE
# ===============================
def check_label(mask_path):
    """Return ``(mask_path, label)`` where label is 1 if the mask has any non-zero pixel, else 0.

    The mask is read with OpenCV in single-channel mode (flag ``0``).
    """
    mask = cv2.imread(mask_path, 0)
    has_foreground = int(np.max(mask) > 0)
    return mask_path, has_foreground

def prepare_finetune_dataframe(data_dir):
    """Index the image/mask pairs found in ``data_dir``.

    Every ``*_mask.png`` whose matching ``<name>.png`` exists becomes one row.
    Labels are computed with ``check_label`` on a 16-thread pool.

    Returns:
        DataFrame with columns ``image_path``, ``mask_path``, ``has_nerve``
        and ``patient_id`` (the part of the image file name before the first
        underscore).
    """
    mask_candidates = glob.glob(os.path.join(data_dir, "*_mask.png"))

    print("Đang quét dữ liệu và tạo nhãn...")

    img_paths = []
    valid_mask_paths = []
    for mask_file in mask_candidates:
        image_file = mask_file.replace('_mask.png', '.png')
        if not os.path.exists(image_file):
            continue  # mask without its image: skip the pair
        img_paths.append(image_file)
        valid_mask_paths.append(mask_file)

    # executor.map yields results in submission order, so labels line up
    # with valid_mask_paths without an intermediate lookup table.
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        results = list(tqdm(executor.map(check_label, valid_mask_paths), total=len(valid_mask_paths)))
    final_labels = [label for _, label in results]

    def _patient_of(path):
        return os.path.basename(path).split('_')[0]

    df = pd.DataFrame({
        'image_path': img_paths,
        'mask_path': valid_mask_paths,
        'has_nerve': final_labels
    })
    df['patient_id'] = df['image_path'].apply(_patient_of)

    return df
    
def merge_and_remove_duplicates(df):
    """
    Collapse duplicate images into one row and merge their masks with an
    element-wise MAX (union).

    Images are considered duplicates when their perceptual hash
    (``imagehash.phash``) is identical. For each group the first image path
    encountered is kept. Groups with several masks get a merged mask written
    to ``merged_masks/<image name>_merged_mask.png``.

    Args:
        df: DataFrame with ``image_path`` and ``mask_path`` columns.

    Returns:
        New DataFrame with columns ``image_path``, ``mask_path``,
        ``has_nerve`` and ``patient_id``.

    Raises:
        FileNotFoundError: If a mask cannot be read.
        OSError: If a merged mask cannot be written.
    """
    print(f"Original size: {len(df)}")
    hash_to_masks = {}
    hash_to_img_path = {}
    skipped = []  # (image_path, error) for images that could not be hashed

    def _read_mask(path):
        # cv2.imread returns None instead of raising on unreadable files.
        mask = cv2.imread(path, 0)
        if mask is None:
            raise FileNotFoundError(f"Cannot read mask: {path}")
        return mask

    print("Grouping duplicates...")
    for _, row in tqdm(df.iterrows(), total=len(df)):
        try:
            # Context manager closes the file handle once the hash is computed.
            with Image.open(row['image_path']) as image:
                h = str(imagehash.phash(image))
        except Exception as e:
            # Best effort: keep going, but remember what was dropped so the
            # loss of rows is visible instead of silent.
            skipped.append((row['image_path'], e))
            continue

        if h not in hash_to_masks:
            hash_to_masks[h] = []
            hash_to_img_path[h] = row['image_path']

        hash_to_masks[h].append(row['mask_path'])

    if skipped:
        first_path, first_err = skipped[0]
        print(f"Skipped {len(skipped)} unreadable image(s); first: {first_path} ({first_err})")

    final_data = []
    print("Merging masks...")
    os.makedirs("merged_masks", exist_ok=True)
    
    for h, mask_paths in tqdm(hash_to_masks.items()):
        if len(mask_paths) == 1:
            final_data.append({
                'image_path': hash_to_img_path[h],
                'mask_path': mask_paths[0],
                'has_nerve': 1 if _read_mask(mask_paths[0]).max() > 0 else 0
            })
        else:
            merged_mask = None
            for mp in mask_paths:
                m = _read_mask(mp)
                merged_mask = m if merged_mask is None else np.maximum(merged_mask, m)
            base_name = os.path.basename(hash_to_img_path[h]).replace('.png', '')
            new_mask_path = f"merged_masks/{base_name}_merged_mask.png"
            if not cv2.imwrite(new_mask_path, merged_mask):
                raise OSError(f"Cannot write merged mask: {new_mask_path}")
            
            final_data.append({
                'image_path': hash_to_img_path[h],
                'mask_path': new_mask_path,
                'has_nerve': 1 if np.max(merged_mask) > 0 else 0
            })

    # Explicit columns keep the frame well-formed even when no row survived.
    new_df = pd.DataFrame(final_data, columns=['image_path', 'mask_path', 'has_nerve'])
    new_df['patient_id'] = new_df['image_path'].apply(lambda x: os.path.basename(x).split('_')[0])
    
    print(f"Final size after merging: {len(new_df)}")
    return new_df
    
def split_finetune_data(df, n_splits=5):
    """Assign a validation fold to every row using ``StratifiedGroupKFold``.

    Folds are stratified on ``has_nerve`` and grouped by ``patient_id``, so
    the splitter keeps each patient's rows in a single fold.

    Args:
        df: DataFrame with ``image_path``, ``has_nerve`` and ``patient_id``
            columns. It is modified in place: a ``fold`` column is added
            (or overwritten).
        n_splits: Number of folds.

    Returns:
        The same ``df`` object, with ``fold`` in ``0 .. n_splits - 1``.
    """
    sgkf = StratifiedGroupKFold(n_splits=n_splits)
    df['fold'] = -1
    fold_col = df.columns.get_loc('fold')

    X = df['image_path']
    y = df['has_nerve']
    groups = df['patient_id']

    for fold_idx, (_, val_idx) in enumerate(sgkf.split(X, y, groups)):
        # split() yields positional indices, so write with .iloc; label-based
        # .loc is only equivalent when the frame has a default RangeIndex.
        df.iloc[val_idx, fold_col] = fold_idx

    return df

# Build the fine-tuning index: scan the processed directory, merge duplicate
# images, then assign 5 patient-grouped folds.
df_finetune = prepare_finetune_dataframe(PROCESSED_USNS_DIR)
df_finetune = merge_and_remove_duplicates(df_finetune)
df_finetune = split_finetune_data(df_finetune, n_splits=5)

# Fold used for validation; all other folds are used for training.
fold_idx = 0

df_train = df_finetune[df_finetune['fold'] != fold_idx]

df_val = df_finetune[df_finetune['fold'] == fold_idx]

print(f"Final Train Size: {len(df_train)}")
print(f"Final Val Size  : {len(df_val)}")

train_imgs = df_train['image_path'].tolist()
train_masks = df_train['mask_path'].tolist()

val_imgs = df_val['image_path'].tolist()
val_masks = df_val['mask_path'].tolist()
# Positive-only subsets (rows whose mask is non-empty).
df_train_pos = df_train[df_train['has_nerve'] == 1].copy()
df_val_pos = df_val[df_val['has_nerve'] == 1].copy()

# NOTE(review): the positive rows are loaded here and again as part of the
# full train/val sets below, so they are read into RAM twice.
X_train_pos, y_train_pos = load_dataset_into_ram(df_train_pos['image_path'].tolist(), df_train_pos['mask_path'].tolist())
X_val_pos, y_val_pos = load_dataset_into_ram(df_val_pos['image_path'].tolist(), df_val_pos['mask_path'].tolist())
X_train, y_train = load_dataset_into_ram(train_imgs, train_masks)
X_val, y_val = load_dataset_into_ram(val_imgs, val_masks)

train_ds_pos = get_unified_dataset(X_train_pos, y_train_pos, training=True)
val_ds_pos = get_unified_dataset(X_val_pos, y_val_pos, training=False)

train_ds = get_unified_dataset(X_train, y_train, training=True)
val_ds = get_unified_dataset(X_val, y_val, training=False)

# NOTE(review): steps are derived from half the batch size; the reason is not
# visible in this cell — confirm against get_unified_dataset.
STEPS_PER_EPOCH = len(X_train) // (BATCH_SIZE // 2)

# NOTE(review): the name is misspelled ("SETPS"); the fine-tuning cell reads
# it under this same spelling, so both places must be renamed together.
SETPS_PER_EPOCH_POS = len(X_train_pos) // (BATCH_SIZE // 2)

In [None]:
# ==============================
# 7. FINE-TUNE ON THE TARGET DATASET
# ==============================
# Four consecutive stages, each a separate run_tpu_training call:
#   1. only layers named 'dec_*' trainable, positive-only data
#   2. everything except 'head_cls_*' trainable, positive-only data
#   3. only 'head_cls_*' trainable, full data, classification loss only
#   4. all layers trainable, full data, all three losses
print("Training Seg Branch ...")
with strategy.scope():
    model = build_multitask_arga_unet()

    model.load_weights('/kaggle/working/pretrain.weights.h5', skip_mismatch=True)

    # NOTE(review): the prefix test also freezes 'head_edge_conv', 'edge_out'
    # and 'seg_out', whose names do not start with 'dec_' — confirm that the
    # output heads are meant to stay frozen in this stage.
    for layer in model.layers:
        if layer.name.startswith('dec_'):
            layer.trainable = True
        else:
            layer.trainable = False

# Stage 1: decoder only, segmentation + edge losses (classification weight 0).
model = run_tpu_training(
    model=model,
    train_ds=train_ds_pos,
    val_ds=val_ds_pos,
    mode='max',
    patience_lr=5,
    patience_st=12,
    value='dice',
    epochs=8,
    weight_loss=[1.0, 0.5, 0.0],
    lr=1e-3,
    save_name="stage1_best.weights.h5",
    steps_per_loop=SETPS_PER_EPOCH_POS
)
with strategy.scope():
    for layer in model.layers:
        if layer.name.startswith('head_cls_'):
            layer.trainable = False
        else:
            layer.trainable = True

# Stage 2: everything but the classifier head. Writes to the same checkpoint
# file as stage 1; each call starts a fresh scheduler, so its first epoch
# overwrites the stage-1 file.
model = run_tpu_training(
    model=model,
    train_ds=train_ds_pos,
    val_ds=val_ds_pos,
    mode='max',
    patience_lr=5,
    patience_st=12,
    value='dice',
    epochs=20,
    weight_loss=[4.0, 2.0, 0.0],
    lr=3.2e-4,
    save_name="stage1_best.weights.h5",
    steps_per_loop=SETPS_PER_EPOCH_POS
)

print("Training CLS branch...")
with strategy.scope():
    for layer in model.layers:
        if layer.name.startswith('head_cls_'):
            layer.trainable = True
        else:
            layer.trainable = False

# Stage 3: classifier head only, monitored on validation classification loss.
model = run_tpu_training(
    model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    mode='min',
    patience_lr=3,
    patience_st=15,
    value='cls',
    epochs=20,
    weight_loss=[0.0, 0.0, 1.0],
    lr=8e-5,
    save_name="stage2_best.weights.h5",
    steps_per_loop=STEPS_PER_EPOCH
)

print("Finetune sync...")
with strategy.scope():
    for layer in model.layers:
        layer.trainable = True

# Stage 4: joint fine-tuning of all layers; this checkpoint is the one loaded
# by the threshold-tuning cell.
model = run_tpu_training(
    model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    mode='max',
    patience_lr=5,
    patience_st=15,
    value='dice',
    epochs=30,
    weight_loss=[0.5, 0.5, 1.5],
    lr=8e-5,
    save_name="finetune_best.weights.h5",
    steps_per_loop=STEPS_PER_EPOCH
)

In [None]:
# =============================================
# 8. THRESHOLD TUNING ON THE SIMULATED TARGET SET
# =============================================
# Number of stochastic forward passes per image used by predict_step.
MC_SAMPLES = 16
# Rebuild the network and load the final fine-tuned checkpoint; aug_layer is
# the preprocessing layer predict_step applies with training=False.
with strategy.scope(): 
    model = build_multitask_arga_unet()
    model.load_weights('/kaggle/working/finetune_best.weights.h5', skip_mismatch=True)
    aug_layer = MedicalPreprocessingLayer(TARGET_HEIGHT, TARGET_WIDTH)

def val_adapter(img, mask):
    """Attach a derived classification label to an ``(img, mask)`` batch.

    The label is 1.0 for every sample whose mask has a value above 0 (maximum
    taken over all axes except axis 0), else 0.0, reshaped to ``[-1, 1]``.

    Returns:
        ``(img, {'seg_out': mask, 'cls_out': label})``.
    """
    non_batch_axes = tf.range(1, tf.rank(mask))
    peak = tf.reduce_max(tf.cast(mask, tf.float32), axis=non_batch_axes)
    label_cls = tf.reshape(tf.cast(peak > 0.0, tf.float32), [-1, 1])
    return img, {'seg_out': mask, 'cls_out': label_cls}

val_dataset_inference = val_ds.map(val_adapter, num_parallel_calls=tf.data.AUTOTUNE)

@tf.function
def predict_step(inputs):
    """Predict with horizontal-flip TTA and ``MC_SAMPLES`` stochastic passes.

    Each image and its left-right flip are repeated ``MC_SAMPLES`` times and
    run through ``model`` with ``training=True``. Per image, the mean and
    variance over the repeats are computed, flipped predictions are flipped
    back, and the two views are averaged.

    Returns:
        ``(final_seg, final_cls, final_unc)``: mean segmentation map, mean
        classification score and mean segmentation variance.
    """
    current_batch_size = tf.shape(inputs)[0]
    preprocessed = aug_layer(inputs, training=False)

    # The preprocessing layer may hand back a tuple/list; the image comes first.
    img_normalized = preprocessed[0] if isinstance(preprocessed, (list, tuple)) else preprocessed

    # Layout after repeat: [img0 x MC, img1 x MC, ..., flip0 x MC, flip1 x MC, ...]
    tta_batch = tf.concat([img_normalized, tf.image.flip_left_right(img_normalized)], axis=0)
    outputs = model(tf.repeat(tta_batch, MC_SAMPLES, axis=0), training=True)

    seg_all = outputs[0]
    cls_all = outputs[-1]

    if tf.shape(cls_all)[-1] > 1:
        cls_all = cls_all[..., 0:1]

    view_count = 2 * current_batch_size
    seg_stack = tf.reshape(seg_all, [view_count, MC_SAMPLES, TARGET_HEIGHT, TARGET_WIDTH, 1])
    cls_stack = tf.reshape(cls_all, [view_count, MC_SAMPLES, 1])

    # Statistics over the MC axis.
    seg_mean = tf.reduce_mean(seg_stack, axis=1)
    seg_var = tf.math.reduce_variance(seg_stack, axis=1)
    cls_mean = tf.reduce_mean(cls_stack, axis=1)

    # First half = original views, second half = flipped views.
    seg_mean_plain, seg_mean_mirrored = tf.split(seg_mean, num_or_size_splits=2, axis=0)
    seg_var_plain, seg_var_mirrored = tf.split(seg_var, num_or_size_splits=2, axis=0)
    cls_mean_plain, cls_mean_mirrored = tf.split(cls_mean, num_or_size_splits=2, axis=0)

    # Undo the flip on spatial outputs before averaging the two views.
    final_seg = (seg_mean_plain + tf.image.flip_left_right(seg_mean_mirrored)) / 2.0
    final_unc = (seg_var_plain + tf.image.flip_left_right(seg_var_mirrored)) / 2.0
    final_cls = (cls_mean_plain + cls_mean_mirrored) / 2.0

    return final_seg, final_cls, final_unc
    
def tune_all_parameters(strategy, val_ds, predict_step_fn, target_pos_ratio=0.2):
    """Tune the classification, segmentation and uncertainty thresholds on ``val_ds``.

    Steps:
        1. Run ``predict_step_fn`` on every batch and cache predictions/labels.
        2. Pick the classification threshold with the highest F1 on the
           precision-recall curve.
        3. On positive samples only, gate the segmentation map by the class
           score and an uncertainty penalty, then pick the binarisation
           threshold (0.10 .. 0.85, step 0.05) with the best mean Dice.
        4. Collect the mean uncertainty of every predicted connected component,
           split by whether it overlaps the ground truth, and suggest the
           midpoint of the two medians as uncertainty threshold.

    Args:
        strategy: Distribution strategy used to distribute, run and gather.
        val_ds: Dataset yielding ``(inputs, labels)`` where ``labels`` is a
            dict with ``'cls_out'`` and ``'seg_out'``.
        predict_step_fn: Callable returning ``(seg, cls, uncertainty)``.
        target_pos_ratio: Accepted for interface compatibility; not read here.

    Returns:
        ``(best_cls_thresh, best_final_thresh, suggested_unc_thresh)``.

    Raises:
        ValueError: If a batch's labels are not a dict.
    """
    print("--- Hyper-parameter tunning processing ---")
    
    cache_pred_cls = []
    cache_pred_seg = []
    cache_pred_unc = [] 
    cache_true_cls = []
    cache_true_seg = []
    
    dist_val = strategy.experimental_distribute_dataset(val_ds)
    
    print("1. Đang chạy Inference trên toàn bộ tập Validation...")
    
    for batch_inputs, batch_labels in tqdm(dist_val):
        if not isinstance(batch_labels, dict):
            # Fail here with the real cause; printing and breaking used to
            # surface later as an unrelated np.concatenate ValueError.
            raise ValueError("Error: Dataset format not dict! Kiểm tra lại pipeline.")

        lbl_cls = strategy.gather(batch_labels['cls_out'], axis=0).numpy()
        lbl_seg = strategy.gather(batch_labels['seg_out'], axis=0).numpy()

        preds = strategy.run(predict_step_fn, args=(batch_inputs,))
        
        pred_seg = strategy.gather(preds[0], axis=0).numpy()
        pred_cls = strategy.gather(preds[1], axis=0).numpy()
            
        pred_unc = strategy.gather(preds[2], axis=0).numpy()
        
        cache_pred_cls.append(pred_cls)
        cache_pred_seg.append(pred_seg)
        cache_pred_unc.append(pred_unc)
        cache_true_cls.append(lbl_cls)
        cache_true_seg.append(lbl_seg)

    all_pred_cls = np.concatenate(cache_pred_cls).flatten()
    all_true_cls = np.concatenate(cache_true_cls).flatten()
    all_pred_seg = np.concatenate(cache_pred_seg)
    all_pred_unc = np.concatenate(cache_pred_unc)
    all_true_seg = np.concatenate(cache_true_seg)

    # Masks stored as 0..255 are rescaled to 0..1.
    if np.max(all_true_seg) > 1:
        all_true_seg = all_true_seg / 255.0
    print(f"-> Total: {len(all_pred_cls)} samples.")

    precision, recall, thresholds = precision_recall_curve(all_true_cls, all_pred_cls)
    f1_scores = 2 * recall * precision / (recall + precision + 1e-7)
    
    # precision/recall have one more entry than thresholds (the final
    # precision=1, recall=0 point has no threshold), so exclude it from the
    # search to keep best_idx a valid index into thresholds.
    best_idx = np.argmax(f1_scores[:-1])
    best_cls_thresh = thresholds[best_idx]
    best_f1 = f1_scores[best_idx]
    
    print(f"-> Best CLS Threshold: {best_cls_thresh:.4f} (Max F1: {best_f1:.4f})")
    
    print("-> Applying Soft Gating & Uncertainty Penalty...")
    
    pos_idx = np.where(all_true_cls == 1)[0]
    
    if len(pos_idx) == 0:
        # No positive sample to tune on: fall back to a neutral threshold.
        best_final_thresh = 0.5
    else:
        sub_pred_seg = all_pred_seg[pos_idx]
        sub_pred_cls = all_pred_cls[pos_idx].reshape(-1, 1, 1, 1)
        sub_pred_unc = all_pred_unc[pos_idx]
        sub_true_seg = all_true_seg[pos_idx]
        
        # Subtract a fraction of the uncertainty, then scale by the class score.
        penalty_weight = 0.6
        soft_mask = (sub_pred_seg - penalty_weight * sub_pred_unc) * sub_pred_cls
        
        soft_mask = np.clip(soft_mask, 0.0, 1.0)
    
        best_dice = 0.0
        best_final_thresh = 0.5
        
        for t in np.arange(0.1, 0.9, 0.05):
            pred_bin = (soft_mask > t).astype(np.float32)
            
            intersection = np.sum(sub_true_seg * pred_bin, axis=(1, 2, 3))
            union = np.sum(sub_true_seg, axis=(1, 2, 3)) + np.sum(pred_bin, axis=(1, 2, 3))
            dices = (2. * intersection + 1e-5) / (union + 1e-5)
            mean_dice = np.mean(dices)
            
            if mean_dice > best_dice:
                best_dice = mean_dice
                best_final_thresh = t
                
        print(f"-> Best COMBINED Threshold: {best_final_thresh:.4f} (Dice: {best_dice:.4f})")

    print("Tune UNCERTAINTY Threshold...")
    tp_uncs = []
    fp_uncs = []
    
    print("   Analyzing connected components...")
    for i in tqdm(range(len(all_pred_seg))):
        pred_mask_raw = all_pred_seg[i, ..., 0]
        unc_map_raw = all_pred_unc[i, ..., 0]
        true_mask_raw = all_true_seg[i, ..., 0]
        
        # NOTE(review): the threshold tuned on the gated soft mask above is
        # applied here to the raw segmentation map — confirm this is intended.
        pred_mask_bin = (pred_mask_raw > best_final_thresh).astype(np.uint8)
        # Integer cast truncates: only ground-truth values >= 1 count as foreground.
        true_mask_bin = true_mask_raw.astype(np.uint8)
        
        num, labels_im = cv2.connectedComponents(pred_mask_bin)
        
        # Label 0 is the background component.
        for region_idx in range(1, num):
            region_mask = (labels_im == region_idx).astype(np.uint8)
            mean_u = np.mean(unc_map_raw[region_mask == 1])
            
            intersection = np.sum(region_mask * true_mask_bin)
            
            if intersection > 0:
                tp_uncs.append(mean_u)
            else:
                fp_uncs.append(mean_u)
                
    print(f"   Collected TP regions: {len(tp_uncs)}")
    print(f"   Collected FP regions: {len(fp_uncs)}")
    
    median_tp = np.median(tp_uncs) if len(tp_uncs) > 0 else 0
    median_fp = np.median(fp_uncs) if len(fp_uncs) > 0 else 1.0
    
    suggested_unc_thresh = (median_tp + median_fp) / 2
    print(f"-> Best UNCERTAINTY Threshold: {suggested_unc_thresh:.4f}")
    
    return best_cls_thresh, best_final_thresh, suggested_unc_thresh

# Tune the three thresholds on the validation set prepared above.
# NOTE(review): target_pos_ratio is passed here, but tune_all_parameters never
# reads that argument, so the value has no effect.
BEST_CLS, BEST_SEG, BEST_UNC = tune_all_parameters(
    strategy, 
    val_dataset_inference,
    predict_step, 
    target_pos_ratio=0.5
)

In [None]:
# ====================
# 9. RUN PREDICTION
# ====================
TEST_DIR = '/kaggle/input/ultrasound-nerve-segmentation/test'
TEST_PNG_DIR = '/kaggle/working/test_usns_png'
BATCH_SIZE_PREDICT = 4 * strategy.num_replicas_in_sync

os.makedirs(TEST_PNG_DIR, exist_ok=True)
test_files_tif = glob.glob(os.path.join(TEST_DIR, "*.tif"))
print(f"Converting {len(test_files_tif)} test images to PNG...")

unreadable_tifs = []
for t_path in tqdm(test_files_tif):
    # splitext swaps only the extension, never a '.tif' inside the stem.
    base_name = os.path.splitext(os.path.basename(t_path))[0] + '.png'
    img = cv2.imread(t_path)
    if img is None:
        # cv2.imread signals failure with None; record it so missing
        # predictions can be traced back to their source file.
        unreadable_tifs.append(t_path)
        continue
    cv2.imwrite(os.path.join(TEST_PNG_DIR, base_name), img)

if unreadable_tifs:
    print(f"Warning: {len(unreadable_tifs)} test image(s) could not be read, e.g. {unreadable_tifs[0]}")
print("Test conversion complete.")

test_png_files = glob.glob(os.path.join(TEST_PNG_DIR, "*.png"))
test_png_files.sort() 

# Pad the file list to a multiple of the batch size, because the test dataset
# is batched with drop_remainder=True. N_ORIGINAL marks where padding starts.
N_ORIGINAL = len(test_png_files)
remainder = N_ORIGINAL % BATCH_SIZE_PREDICT
if remainder != 0:
    pad_len = BATCH_SIZE_PREDICT - remainder
    # Cycle through the originals so the padding is complete even when there
    # are fewer images than pad_len (a plain [:pad_len] slice would come up
    # short and real images would be dropped with the remainder).
    test_png_files += [test_png_files[i % N_ORIGINAL] for i in range(pad_len)]
    print(f"Padding dataset: {N_ORIGINAL} -> {len(test_png_files)} images (Added {pad_len})")
else:
    print("No padding needed.")

def read_test_image(img_path):
    """Load one PNG file as a float32, 3-channel tensor ready for inference.

    Args:
        img_path: Path of the PNG file to read.

    Returns:
        Tensor with static shape [TARGET_HEIGHT, TARGET_WIDTH, 3]; pixel values
        are the decoded bytes divided by 255 and then resized.
    """
    target_size = [TARGET_HEIGHT, TARGET_WIDTH]

    raw_bytes = tf.io.read_file(img_path)
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    scaled = tf.cast(decoded, tf.float32) / 255.0
    resized = tf.image.resize(scaled, target_size)
    # Pin the static shape so downstream batching sees fully defined dimensions.
    resized.set_shape(target_size + [3])
    return resized

def get_test_dataset(image_paths):
    """Build the batched, prefetched inference pipeline over `image_paths`.

    Args:
        image_paths: Sequence of PNG file paths, in the order predictions
            should come back.

    Returns:
        A tf.data.Dataset yielding image batches of BATCH_SIZE_PREDICT; a
        trailing partial batch is dropped.
    """
    return (
        tf.data.Dataset.from_tensor_slices(image_paths)
        .map(read_test_image, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE_PREDICT, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )

# Build the test pipeline over the (padded) file list, then hand it to the
# distribution strategy so each step's batch is split across its replicas.
test_dataset = get_test_dataset(test_png_files)
dist_test_dataset = strategy.experimental_distribute_dataset(test_dataset)

def filter_by_uncertainty(pred_mask, unc_map, threshold=0.5):
    """Keep only the connected components whose mean uncertainty is low.

    Args:
        pred_mask: 2-D predicted mask; it is cast to uint8 for labelling.
        unc_map: Uncertainty map indexed with the same pixel grid as `pred_mask`.
        threshold: A component is kept when its mean uncertainty is strictly
            below this value.

    Returns:
        Array with the shape and dtype of `pred_mask`, set to 1 on kept
        components and 0 everywhere else.
    """
    n_components, component_ids, _stats, _centroids = cv2.connectedComponentsWithStats(
        pred_mask.astype(np.uint8)
    )

    kept = np.zeros_like(pred_mask)

    # Label 0 is the background, so real components start at 1.
    for component_id in range(1, n_components):
        region = component_ids == component_id
        if not (np.mean(unc_map[region]) < threshold):
            continue
        kept[region] = 1

    return kept

# Progress banner: it is printed before the helper definitions that follow; the
# inference loop itself only starts after them.
print("Start predicting...")
def rle_encode(mask):
    """Run-length encode a binary mask as a space-separated string.

    Pixels are numbered from 1 in column-major order (down each column, then
    across columns). The output alternates "start length" pairs, one pair per
    run of foreground pixels; an all-zero mask encodes to the empty string.

    Args:
        mask: 2-D binary array (0 = background, non-zero = foreground).

    Returns:
        The encoded string, e.g. "2 1 4 2".
    """
    flat = mask.flatten(order='F')
    # Zero sentinels on both ends make runs touching the borders show up as
    # value changes too.
    padded = np.concatenate([[0], flat, [0]])
    changed = padded[1:] != padded[:-1]
    runs = np.where(changed)[0] + 1
    # Even slots hold run starts, odd slots run ends; turn the ends into lengths.
    runs[1::2] -= runs[::2]
    return ' '.join(map(str, runs))
    
def remove_small_objects(mask, min_size=100):
    """Drop connected foreground components smaller than `min_size` pixels.

    Args:
        mask: 2-D binary mask (non-zero = foreground). It is cast to uint8
            before labelling, as keep_largest_component does, because OpenCV's
            connected-components routine expects an 8-bit image; bool or wider
            integer masks are therefore accepted too.
        min_size: Minimum component area, in pixels, for a component to be kept.

    Returns:
        uint8 array of the same shape, 1 on kept components and 0 elsewhere.
    """
    mask = mask.astype(np.uint8)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    cleaned_mask = np.zeros_like(mask)
    # Label 0 is the background, so real components start at 1.
    for i in range(1, num_labels):
        if stats[i, cv2.CC_STAT_AREA] >= min_size:
            cleaned_mask[labels == i] = 1
    return cleaned_mask
    
def keep_largest_component(mask):
    """Return a mask containing only the largest 8-connected component.

    Args:
        mask: 2-D binary mask; it is cast to uint8 before labelling.

    Returns:
        uint8 array of the same shape with 1 on the largest component and 0
        elsewhere. When there is no foreground component, the uint8 copy of
        the input is returned as is.
    """
    binary = mask.astype(np.uint8)
    n_components, component_ids, component_stats, _ = cv2.connectedComponentsWithStats(
        binary, connectivity=8
    )
    has_foreground = n_components >= 2
    if not has_foreground:
        return binary

    # Row 0 of the stats table describes the background; skip it and shift the
    # argmax back into label space.
    foreground_areas = component_stats[1:, cv2.CC_STAT_AREA]
    winner = np.argmax(foreground_areas) + 1
    return (component_ids == winner).astype(np.uint8)

def fill_holes(mask):
    """Fill the interior of every foreground region of a binary mask.

    Only external contours are retrieved and each one is drawn filled, so any
    hole enclosed by a region becomes foreground.

    Args:
        mask: 2-D binary mask (non-zero = foreground). It is cast to uint8
            first, as keep_largest_component does, so bool or wider integer
            masks are accepted by OpenCV.

    Returns:
        uint8 array of the same shape, 1 inside the filled regions and 0
        elsewhere.
    """
    mask = mask.astype(np.uint8)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    filled_mask = np.zeros_like(mask)
    # contourIdx=-1 draws all contours; colour 1 keeps the output binary.
    cv2.drawContours(filled_mask, contours, -1, 1, thickness=cv2.FILLED)
    return filled_mask

# Rows of the submission, filled in by the post-processing loop further down.
data = []
# Not used anywhere in this section; kept in case another cell refers to it.
file_idx = 0
# The dataset is batched with drop_remainder=True, so it yields exactly
# floor(len / batch) batches; adding 1 here left the progress bar short of 100%.
total_batches = len(test_png_files) // BATCH_SIZE_PREDICT

# Per-batch outputs gathered from all replicas, concatenated after the loop.
all_masks = []
all_labels = []
all_uncs = []

for batch_images in tqdm(dist_test_dataset, total=total_batches):
    batch_outputs = strategy.run(predict_step, args=(batch_images,))

    # Merge the per-replica results along the batch axis. Judging by how they
    # are used below, the outputs are (masks, labels, uncertainty maps).
    masks_tensor = strategy.gather(batch_outputs[0], axis=0)
    labels_tensor = strategy.gather(batch_outputs[1], axis=0)
    uncs_tensor = strategy.gather(batch_outputs[2], axis=0)

    all_masks.append(masks_tensor.numpy())
    all_labels.append(labels_tensor.numpy())
    all_uncs.append(uncs_tensor.numpy())

final_masks = np.concatenate(all_masks, axis=0)
final_labels = np.concatenate(all_labels, axis=0)
final_uncs = np.concatenate(all_uncs, axis=0)

# Discard the predictions made for the padding entries appended to the file list.
final_masks = final_masks[:N_ORIGINAL]
final_labels = final_labels[:N_ORIGINAL]
final_uncs = final_uncs[:N_ORIGINAL]

original_files = test_png_files[:N_ORIGINAL]

# Post-process every prediction and collect one submission row per test image.
for i, file_path in enumerate(tqdm(original_files)):
    # The file stem is the numeric image id.
    file_stem = os.path.splitext(os.path.basename(file_path))[0]
    img_id = int(file_stem)

    nerve_prob = final_labels[i, 0]
    mask_prob = final_masks[i, :, :, 0]
    unc_map_prob = final_uncs[i, :, :, 0]

    # cv2.resize takes the size as (width, height): outputs are 420 rows x 580 columns.
    mask_prob_resized = cv2.resize(mask_prob, (580, 420), interpolation=cv2.INTER_LINEAR)
    unc_map_resized = cv2.resize(unc_map_prob, (580, 420), interpolation=cv2.INTER_LINEAR)

    # Penalise uncertain pixels, then weight the whole map by the image-level score.
    final_prob_map = (mask_prob_resized - 0.5 * unc_map_resized) * nerve_prob

    mask_binary = (final_prob_map > BEST_SEG).astype(np.uint8)

    # Clean-up chain: uncertainty filter -> drop small blobs -> fill holes -> largest blob.
    mask_clean = filter_by_uncertainty(mask_binary, unc_map_resized, threshold=BEST_UNC)
    mask_clean = remove_small_objects(mask_clean, min_size=100)
    mask_clean = fill_holes(mask_clean)
    mask_clean = keep_largest_component(mask_clean)

    # A surviving mask under 100 pixels is treated as "no nerve".
    mask_area = np.sum(mask_clean)
    if mask_area < 100:
        mask_clean = np.zeros_like(mask_clean)

    encoded = rle_encode(mask_clean) if np.sum(mask_clean) != 0 else ""

    data.append({"img": img_id, "pixels": encoded})

# Save the submission, ordered by image id.
data.sort(key=lambda row: row['img'])
df = pd.DataFrame(data)
df.to_csv('submission.csv', index=False)