## Install and Import Libraries

In [1]:
! pip install sentencepiece
! pip install transformers
! pip install datasets
! pip install sacremoses
! pip install accelerate
! pip install sacrebleu
! pip install bert_score
! pip install sentence_transformers

Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: sacremoses
Successfully installed sacremoses-0.1.1
Collecting sacrebleu
  Downloading sacrebleu-2.5.1-py3-none-any.whl.metadata (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-3.2.0-py3-none-any.whl.metadata (8.7 kB)
Downloading sacrebleu-2.5.1-py3-none-any.whl (104 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading portalocker-3.2.0-py3-none-any.whl (22 kB)
Installing collected packages: portalocker, sacrebleu
Successfully installed 

In [2]:
import os
import glob
import json
import random
import numpy as np
import pandas as pd
import cv2
import nibabel as nib
import matplotlib.pyplot as plt
import tensorflow as tf
import torch
import sacrebleu

from tqdm import tqdm
from pathlib import Path
from sklearn.model_selection import train_test_split
from skimage.measure import label, regionprops
from transformers import (
    VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor,
    AutoModelForSeq2SeqLM, AutoTokenizer as AutoTokenizerLM,
    DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
)
from datasets import Dataset
from sentence_transformers import SentenceTransformer, util
from bert_score import score as bert_score

## Load Dataset

Sementara hanya menggunakan dataset GLI 

In [None]:
# Load data split radgenome
def load_json_files():
    with open('/kaggle/input/radgenome-caption/train.json', 'r') as f:
        train_data = json.load(f)
    with open('/kaggle/input/radgenome-caption/val.json', 'r') as f:
        val_data = json.load(f)
    with open('/kaggle/input/radgenome-caption/test.json', 'r') as f:
        test_data = json.load(f)
    return train_data, val_data, test_data

# Load caption dari dataset radgenome
def load_captions():
    caption_files = {
        # 'BraTS_MET': '/kaggle/input/radgenome-caption/BraTS_MET/modal_wise_finding.json',
        # 'BraTS_MEN': '/kaggle/input/radgenome-caption/BraTS_MEN/modal_wise_finding.json',
        'BraTS_GLI': '/kaggle/input/radgenome-caption/BraTS_GLI/modal_wise_finding.json'
    }
    captions = {}
    for dataset, file_path in caption_files.items():
        with open(file_path, 'r') as f:
            captions[dataset] = json.load(f)
    return captions

# Fungsi untuk dapatkan path image
def get_image_path(filename):
    if filename.startswith('BraTS-GLI'):
        prefix = filename.rsplit('-', 1)[0]
        return f"/kaggle/input/brats2023-part-1/{prefix}/{filename}.nii"
    elif filename.startswith('BraTS-MEN'):
        prefix = filename.rsplit('-', 1)[0]
        path1 = f"/kaggle/input/brats-men/BraTS-MEN-Train/{prefix}/{filename}.nii"
        path2 = f"/kaggle/input/meningits-part2/BraTS-MEN-Train2/{prefix}/{filename}.nii"
        return path1 if os.path.exists(path1) else path2
    elif filename.startswith('BraTS-MET'):
        prefix = filename.rsplit('-', 1)[0]
        train_path = f"/kaggle/input/brats2023/brats2023/brats2023-training/{prefix}/{filename}.nii"
        val_path = f"/kaggle/input/brats2023/brats2023/brats2023-validation/{prefix}/{filename}.nii"
        return train_path if os.path.exists(train_path) else val_path
    return None

# Ambil caption dan kategori (ini buat captioning nanti, segmentasi blm perlu)
def get_caption(filename, captions):
    if filename.startswith('BraTS-GLI'):
        return captions['BraTS_GLI'].get(filename, '')
    elif filename.startswith('BraTS-MEN'):
        return captions['BraTS_MEN'].get(filename, '')
    elif filename.startswith('BraTS-MET'):
        return captions['BraTS_MET'].get(filename, '')
    return ''

def get_category(filename):
    if filename.startswith('BraTS-GLI'):
        return 'GLI'
    elif filename.startswith('BraTS-MEN'):
        return 'MEN'
    elif filename.startswith('BraTS-MET'):
        return 'MET'
    return 'Unknown'

# Dapatkan path segmentasi dari image path
def get_segmentation_path(row):
    image_path = row['image_path']
    category = row['category']
    if image_path is None:
        return None
    if category in ['GLI', 'MEN', 'MET']:
        seg_filename = row['filename'].rsplit('-', 1)[0] + '-seg.nii'
        return os.path.join(os.path.dirname(image_path), seg_filename)
    return None

def process_dataset(filenames, captions):
    data = []
    for filename in filenames:
        # sementara hanya proses GLI
        if not filename.startswith('BraTS-GLI'):
            continue
            
        if filename.startswith('sub-strokecase'):
            continue
            
        if filename.startswith(('GE3T', 'Singapore', 'Utrecht')):
            continue  # skip WMH
            
        image_path = get_image_path(filename)
        if image_path is None or not os.path.exists(image_path):
            continue
        caption = get_caption(filename, captions)
        category = get_category(filename)
        data.append({
            'filename': filename,
            'image_path': image_path,
            'caption': caption,
            'category': category
        })
    df = pd.DataFrame(data)
    df['segmentation_path'] = df.apply(get_segmentation_path, axis=1)
    df = df[df['segmentation_path'].notna()]
    df = df[df['segmentation_path'].apply(lambda x: isinstance(x, str) and os.path.exists(x))].reset_index(drop=True)
    return df

# # Eksekusi all pipeline
# train_data, val_data, test_data = load_json_files()
# captions = load_captions()

# # Jika ingin menggunakan semua dataset (MEN dan MET)
# # all_filenames = train_data + val_data + test_data

# # Sementara gunakan dataset GLI terlebih dahulu
# all_filenames = [fn for fn in (train_data + val_data + test_data) if fn.startswith('BraTS-GLI')]
# df = process_dataset(all_filenames, captions)

In [None]:
# split train val test sesuai split dari radgenome di awal tadi
train_data, val_data, test_data = load_json_files()
captions = load_captions()

# process_dataset dipanggil untuk tiap split & gabungkan
train_df_src = process_dataset(train_data, captions)
val_df_src   = process_dataset(val_data, captions)
test_df_src  = process_dataset(test_data, captions)
combined = pd.concat([train_df_src, val_df_src, test_df_src], ignore_index=True)

# Base Extraction
combined['base'] = combined['filename'].apply(lambda x: str(x).rsplit("-",1)[0])
bases = combined.groupby('base').first().reset_index()

# stratified split di level base
train_frac = 0.60
temp_frac = 1.0 - train_frac
val_frac_of_temp = 0.5
RANDOM_STATE = 42

train_bases, temp_bases = train_test_split(
    bases,
    stratify=bases['category'],
    test_size=(1.0 - train_frac),
    random_state=RANDOM_STATE
)

val_bases, test_bases = train_test_split(
    temp_bases,
    stratify=temp_bases['category'],
    test_size=val_frac_of_temp,
    random_state=RANDOM_STATE
)

train_df = combined[combined['base'].isin(train_bases['base'])].drop(columns=['base']).reset_index(drop=True)
validation_df = combined[combined['base'].isin(val_bases['base'])].drop(columns=['base']).reset_index(drop=True)
test_df = combined[combined['base'].isin(test_bases['base'])].drop(columns=['base']).reset_index(drop=True)

# simpan hasil rebalanced
train_df.to_csv("train_dataset_rebalanced.csv", index=False)
validation_df.to_csv("val_dataset_rebalanced.csv", index=False)
test_df.to_csv("test_dataset_rebalanced.csv", index=False)

print("Train dataset:")
train_counts = train_df['category'].value_counts()
for category, count in train_counts.items():
    print(f"{category}: {count}")
print(f"Total: {len(train_df)}")

print("\nValidation dataset:")
val_counts = validation_df['category'].value_counts()
for category, count in val_counts.items():
    print(f"{category}: {count}")
print(f"Total: {len(validation_df)}")

print("\nTest dataset:")
test_counts = test_df['category'].value_counts()
for category, count in test_counts.items():
    print(f"{category}: {count}")
print(f"Total: {len(test_df)}")

In [None]:
# # perlu rebalance split data (karena split yang dibikin dari si radgenome ga fair, ada kategori yang samplenya kurang di bbrp split)

# np.random.seed(42)

# train_indices = train_df.index.tolist()
# sample_size = min(400, len(train_indices))
# random_indices = np.random.choice(train_indices, size=sample_size, replace=False)

# train_to_val = train_df.loc[random_indices]
# validation_df = pd.concat([validation_df, train_to_val], ignore_index=True)

# train_df = train_df.drop(random_indices)

# met_train_data = train_df[train_df['category'] == 'MET']
# met_sample_size = min(90, len(met_train_data))
# met_to_test_indices = np.random.choice(met_train_data.index, size=met_sample_size, replace=False)
# met_to_test = train_df.loc[met_to_test_indices]

# test_df = pd.concat([test_df, met_to_test], ignore_index=True)
# train_df = train_df.drop(met_to_test_indices)

# train_df = train_df.reset_index(drop=True)
# validation_df = validation_df.reset_index(drop=True)
# test_df = test_df.reset_index(drop=True)

# train_df.to_csv('train_dataset_rebalanced.csv', index=False)
# validation_df.to_csv('val_dataset_rebalanced.csv', index=False)
# test_df.to_csv('test_dataset_rebalanced.csv', index=False)

# print("Rebalanced Train dataset:")
# train_counts = train_df['category'].value_counts()
# for category, count in train_counts.items():
#     print(f"{category}: {count}")
# print(f"Total: {len(train_df)}")

# print("\nRebalanced Validation dataset:")
# val_counts = validation_df['category'].value_counts()
# for category, count in val_counts.items():
#     print(f"{category}: {count}")
# print(f"Total: {len(validation_df)}")

# print("\nRebalanced Test dataset:")
# test_counts = test_df['category'].value_counts()
# for category, count in test_counts.items():
#     print(f"{category}: {count}")
# print(f"Total: {len(test_df)}")

In [None]:
print(train_df.head())

## Load Model Segmentasi & Generate Predicted Segmentation

In [None]:
# Load model segmentasi
SEG_MODEL_PATH = "/kaggle/input/3d-mri-brain-segmentation/keras/default/1/3D_MRI_Brain_tumor_segmentation.h5"
PRED_OUT_DIR = Path("/kaggle/working/predicted_masks")
PRED_OUT_DIR.mkdir(parents=True, exist_ok=True)
IMG_SIZE = 128  
TEST_RUN = False  # set False to run full dataset (but test True first)
TEST_N = 12

In [None]:
# dice loss
def dice_coef(y_true, y_pred, smooth=1.0):
    class_num = 4
    for i in range(class_num):
        y_true_f = K.flatten(y_true[:,:,:,i])
        y_pred_f = K.flatten(y_pred[:,:,:,i])
        intersection = K.sum(y_true_f * y_pred_f)
        loss = ((2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))
        
        if i == 0:
            total_loss = loss
        else:
            total_loss = total_loss + loss
    total_loss = total_loss / class_num

    return total_loss

# define per class evaluation of dice coef
# inspired by https://github.com/keras-team/keras/issues/9395
def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,1] * y_pred[:,:,:,1]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,1])) + K.sum(K.square(y_pred[:,:,:,1])) + epsilon)

def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,2] * y_pred[:,:,:,2]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,2])) + K.sum(K.square(y_pred[:,:,:,2])) + epsilon)

def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,3] * y_pred[:,:,:,3]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,3])) + K.sum(K.square(y_pred[:,:,:,3])) + epsilon)

# Computing Precision 
def precision(y_true, y_pred):
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision

    
# Computing Sensitivity      
def sensitivity(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())


# Computing Specificity
def specificity(y_true, y_pred):
    true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred), 0, 1)))
    possible_negatives = K.sum(K.round(K.clip(1-y_true, 0, 1)))
    return true_negatives / (possible_negatives + K.epsilon())

In [None]:
seg_model = tf.keras.models.load_model(
    SEG_MODEL_PATH,
    custom_objects={
        'dice_coef': dice_coef,
        'precision': precision,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'dice_coef_necrotic': dice_coef_necrotic,
        'dice_coef_edema': dice_coef_edema,
        'dice_coef_enhancing': dice_coef_enhancing
    },
    compile=False
)
print("Loaded segmentation model. GPU available:", bool(tf.config.list_physical_devices("GPU")))
print("Model input shape:", getattr(seg_model, "input_shape", None))

In [None]:
def preprocess_slice_from_vol(vol, idx, img_size=IMG_SIZE):
    sl = vol[:, :, idx].astype(np.float32)
    slr = cv2.resize(sl, (img_size, img_size), interpolation=cv2.INTER_LINEAR)
    maxv = slr.max()
    if maxv > 0:
        slr = slr / maxv
    return slr

def predict_full_volume_singlefile(image_path, seg_out_path, mode='duplicate'):
    vol_obj = nib.load(image_path)
    vol = vol_obj.get_fdata()
    affine = vol_obj.affine
    H, W, Z = vol.shape
    seg_resized = np.zeros((IMG_SIZE, IMG_SIZE, Z), dtype=np.uint8)

    for k in range(Z):
        s1 = preprocess_slice_from_vol(vol, k)
        if mode == 'duplicate':
            s2 = s1
        elif mode == 'zeros':
            s2 = np.zeros_like(s1)
        else:
            s2 = s1

        X = np.zeros((1, IMG_SIZE, IMG_SIZE, 2), dtype=np.float32)
        X[0, :, :, 0] = s1
        X[0, :, :, 1] = s2

        pred = seg_model.predict(X, verbose=0)
        lbl = np.argmax(pred[0], axis=-1).astype(np.uint8)
        seg_resized[:, :, k] = lbl

    # resize back to original H,W using nearest neighbor (labels)
    seg_out = np.zeros((H, W, Z), dtype=np.uint8)
    for k in range(Z):
        seg_out[:, :, k] = cv2.resize(seg_resized[:, :, k], (W, H), interpolation=cv2.INTER_NEAREST)

    nib.save(nib.Nifti1Image(seg_out, affine=affine), str(seg_out_path))
    return True

def predict_full_volume_pair(image_path_A, image_path_B, seg_out_path):
    """
    image_path_A: FLAIR-like (t2f) path
    image_path_B: T1-CE path (must exist)
    Produces seg_out_path (canonical per-base).
    """
    # load volumes
    objA = nib.load(image_path_A); volA = objA.get_fdata()
    objB = nib.load(image_path_B); volB = objB.get_fdata()
    # ensure shapes match (they should in BraTS)
    if volA.shape != volB.shape:
        raise ValueError(f"Shape mismatch: {image_path_A} {volA.shape} vs {image_path_B} {volB.shape}")

    affine = objA.affine
    H, W, Z = volA.shape
    seg_resized = np.zeros((IMG_SIZE, IMG_SIZE, Z), dtype=np.uint8)

    for k in range(Z):
        a = preprocess_slice_from_vol(volA, k)
        b = preprocess_slice_from_vol(volB, k)
        X = np.zeros((1, IMG_SIZE, IMG_SIZE, 2), dtype=np.float32)
        X[0, :, :, 0] = a
        X[0, :, :, 1] = b

        pred = seg_model.predict(X, verbose=0)
        lbl = np.argmax(pred[0], axis=-1).astype(np.uint8)
        seg_resized[:, :, k] = lbl

    seg_out = np.zeros((H, W, Z), dtype=np.uint8)
    for k in range(Z):
        seg_out[:, :, k] = cv2.resize(seg_resized[:, :, k], (W, H), interpolation=cv2.INTER_NEAREST)

    nib.save(nib.Nifti1Image(seg_out, affine=affine), str(seg_out_path))
    return True

def generate_predseg_per_row_and_update_df(df, out_dir=PRED_OUT_DIR):
    """
    For each base in df: if BOTH t2f (FLAIR-like) and t1c (T1-CE) exist,
    generate one canonical <base>-predseg.nii using predict_full_volume_pair(),
    then propagate that path to all rows (all modalities) of that base.

    If t1c missing for a base -> leave pred_seg_path empty for that base.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    pred_map = {}  # base -> pred_path (only for bases where pair exists)

    # choose iteration set for pass1: debug small set if TEST_RUN True
    rows_for_pass1 = list(df.head(TEST_N).itertuples(index=False)) if TEST_RUN else list(df.itertuples(index=False))

    # PASS 1: generate pair preds for bases that have both t2f and t1c
    for r in tqdm(rows_for_pass1, desc="Pass1: generate pair preds"):
        fname = str(getattr(r, 'filename'))
        base = fname.rsplit('-', 1)[0]
        if base in pred_map:
            continue

        # gather rows for this base from full df
        rows_base = df[df['filename'].str.startswith(base)]
        p_t2 = None
        p_t1c = None
        for _, rr in rows_base.iterrows():
            fn = str(rr['filename']).lower()
            p = rr['image_path']
            if 't2f' in fn or 'flair' in fn:
                p_t2 = p
            if 't1c' in fn or 't1ce' in fn:
                p_t1c = p

        # require both paths to exist on disk
        if not p_t2 or not p_t1c:
            # either no t2f or no t1c → skip this base
            continue
        if not Path(p_t2).exists() or not Path(p_t1c).exists():
            # missing file on disk → skip
            continue

        # canonical output name per base (no modality suffix)
        out_name = f"{base}-predseg.nii"
        out_path = out_dir / out_name

        if not out_path.exists():
            try:
                predict_full_volume_pair(p_t2, p_t1c, out_path)
            except Exception as e:
                print("pair predict failed for", base, p_t2, p_t1c, "error:", e)
                continue

        pred_map[base] = str(out_path)

    # PASS 2: propagate pred_map to all rows of df
    pred_paths = []
    for _, row in df.iterrows():
        base = str(row['filename']).rsplit('-', 1)[0]
        pred_paths.append(pred_map.get(base, ""))

    df['pred_seg_path'] = pred_paths
    return df

In [None]:
train_df = generate_predseg_per_row_and_update_df(train_df)
train_df.to_csv('train_dataset_rebalanced_with_predseg.csv', index=False)

val_df = generate_predseg_per_row_and_update_df(validation_df)
val_df.to_csv('val_dataset_rebalanced_with_predseg.csv', index=False)

test_df = generate_predseg_per_row_and_update_df(test_df)
test_df.to_csv('test_dataset_rebalanced_with_predseg.csv', index=False)

In [None]:
# Fungsi untuk mengganti path /kaggle/working ke path dataset yang sudah dibuat
PRED_INPUT_DIR = Path("/kaggle/input/hasil-segmentasi-2/results (2)/predicted_masks")

def update_predseg_paths_from_input(df, pred_input_dir=PRED_INPUT_DIR):
    """
    Untuk tiap base di df, cari file <base>-predseg.nii di pred_input_dir.
    Jika ada, set df['pred_seg_path'] = str(path). Jika tidak ada, set ke "" (kosong).
    Fungsi tidak membuat file baru, hanya update/overwrite kolom pred_seg_path.
    """
    pred_input_dir = Path(pred_input_dir)
    pred_map = {}

    # buat set base unik untuk efisiensi
    bases = df['filename'].apply(lambda x: str(x).rsplit('-', 1)[0]).unique()

    for base in tqdm(bases, desc="Mapping existing predicted masks"):
        candidate = pred_input_dir / f"{base}-predseg.nii"
        if candidate.exists():
            pred_map[base] = str(candidate)
        else:
            # kalau tidak ada, kosongkan / atau bisa biarkan nilai lama
            pred_map[base] = ""

    # terapkan ke semua baris (tidak menambah kolom baru, langsung menimpa/isi kolom `pred_seg_path`)
    df['pred_seg_path'] = df['filename'].apply(lambda fn: pred_map.get(str(fn).rsplit('-', 1)[0], ""))

    return df

train_df = update_predseg_paths_from_input(train_df)
val_df   = update_predseg_paths_from_input(validation_df)
test_df  = update_predseg_paths_from_input(test_df)

# simpan kembali ke CSV — ini hanya menulis CSV, bukan membuat predseg baru
train_df.to_csv('train_dataset_rebalanced_with_predseg.csv', index=False)
val_df.to_csv('val_dataset_rebalanced_with_predseg.csv', index=False)
test_df.to_csv('test_dataset_rebalanced_with_predseg.csv', index=False)

In [None]:
# Load csv
csv_path = "/kaggle/input/hasil-segmentasi-2/results (2)/train_dataset_rebalanced_with_predseg.csv"
train_df = pd.read_csv(csv_path)

# Fungsi untuk visualisasi
def pick_and_show_sample(df, slice_idx=None, force_new=False):
    global _SELECTED_SAMPLE
    if force_new or '_SELECTED_SAMPLE' not in globals():
        _SELECTED_SAMPLE = df.sample(1).iloc[0]
        print("New random sample selected:", _SELECTED_SAMPLE['filename'])
    else:
        print("Using cached sample:", _SELECTED_SAMPLE['filename'])
    row = _SELECTED_SAMPLE
    img = np.nan_to_num(nib.load(row['image_path']).get_fdata())
    mask = np.nan_to_num(nib.load(row['segmentation_path']).get_fdata())
    num_slices = img.shape[2]
    if slice_idx is None:
        slice_idx = num_slices // 2
    img_slice = img[:, :, slice_idx]
    mask_slice = mask[:, :, slice_idx]
    if img_slice.max() > img_slice.min():
        norm_img = (img_slice - img_slice.min()) / (img_slice.max() - img_slice.min() + 1e-8)
    else:
        norm_img = np.zeros_like(img_slice)
    img_rgb = np.stack([norm_img]*3, axis=-1)
    mask_bin = (mask_slice > 0).astype(float)
    overlay = img_rgb.copy()
    overlay[:, :, 0] += mask_bin * 0.5
    overlay = np.clip(overlay, 0, 1)
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(img_rgb)
    axes[0].set_title("Original Image")
    axes[1].imshow(mask_bin, cmap='gray')
    axes[1].set_title("Mask")
    axes[2].imshow(overlay)
    axes[2].set_title("Overlay")
    for ax in axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()

pick_and_show_sample(train_df, force_new=True)

In [None]:
pick_and_show_sample(train_df, slice_idx=50, force_new=True) 

In [None]:
pick_and_show_sample(train_df, slice_idx=65) 

In [None]:
pick_and_show_sample(train_df, slice_idx=77) #(slice tengah biasanya di slice 77)

## Predicted Segmentation Policy

In [3]:
# Fungsi menghitung dice bin karena ada beberapa predicted segmentation yang jelek, jika jelek pakai Ground Truth segmentation
def dice_bin(a, b, eps=1e-6):
    a = (a > 0).astype(np.uint8)
    b = (b > 0).astype(np.uint8)
    inter = np.sum(a * b)
    return (2 * inter) / (np.sum(a) + np.sum(b) + eps)

def compute_dice_map(df):
    if 'base' not in df.columns:
        df['base'] = df['filename'].apply(lambda x: str(x).rsplit('-', 1)[0])
    bases = sorted(df['base'].unique())
    dice_map = {}
    for base in tqdm(bases, desc="Compute dice per base"):
        rows = df[df['base']==base]
        pred = next((p for p in rows['pred_seg_path'].fillna("").tolist() if p), "")
        gt = next((g for g in rows['segmentation_path'].fillna("").tolist() if g), "")
        if pred and gt and Path(pred).exists() and Path(gt).exists():
            try:
                a = nib.load(pred).get_fdata().astype(np.uint8)
                b = nib.load(gt).get_fdata().astype(np.uint8)
                dice_map[base] = float(dice_bin(a,b)) if a.shape==b.shape else np.nan
            except Exception:
                dice_map[base] = np.nan
        else:
            dice_map[base] = np.nan
    return dice_map

def apply_pred_vs_gt_policy(df, threshold=0.6, show_summary=True):
    """
    Update df['pred_seg_path'] in-place according to policy:
      - compute dice per base when both pred & gt exist and shapes match
      - if dice < threshold -> use GT (segmentation_path) as pred_seg_path (if GT exists)
      - else keep pred (if exists)
      - if neither exists -> empty string

    Returns updated df (same object) and dice_map (dict base->dice or np.nan)
    """
    # ensure base column exists or compute it
    if 'base' not in df.columns:
        df['base'] = df['filename'].apply(lambda x: str(x).rsplit('-', 1)[0])

    bases = sorted(df['base'].unique())
    dice_map = {}
    replaced_bases = []
    kept_bases = []
    missing_bases = []

    for base in tqdm(bases, desc="Computing dice per base"):
        rows_base = df[df['base'] == base]
        # pick first non-empty pred and gt for this base (if any)
        pred = ""
        gt = ""
        # prefer non-empty strings and paths that exist
        for p in rows_base['pred_seg_path'].fillna("").unique():
            if p:
                pred = str(p)
                break
        for g in rows_base['segmentation_path'].fillna("").unique():
            if g:
                gt = str(g)
                break

        # normalize to Path for existence check
        pred_path = Path(pred) if pred else None
        gt_path = Path(gt) if gt else None

        if pred_path and pred_path.exists() and gt_path and gt_path.exists():
            try:
                a = nib.load(str(pred_path)).get_fdata()
                b = nib.load(str(gt_path)).get_fdata()
            except Exception as e:
                # unreadable file -> mark as nan
                dice_map[base] = np.nan
                missing_bases.append(base)
                continue

            if a.shape == b.shape:
                d = float(dice_bin(a, b))
                dice_map[base] = d
            else:
                dice_map[base] = np.nan
            # apply policy
            if not np.isnan(dice_map[base]) and dice_map[base] < threshold:
                # prefer GT if available
                df.loc[df['base'] == base, 'pred_seg_path'] = str(gt_path)
                replaced_bases.append(base)
            else:
                # dice >= threshold -> keep pred (no-op)
                kept_bases.append(base)
        else:
            # missing either pred or gt
            dice_map[base] = np.nan
            # If GT exists but pred missing -> set pred_seg_path to GT
            if gt_path and gt_path.exists():
                df.loc[df['base'] == base, 'pred_seg_path'] = str(gt_path)
                replaced_bases.append(base)
            elif pred_path and pred_path.exists():
                # only pred exists -> keep it
                kept_bases.append(base)
            else:
                # neither exists -> set empty
                df.loc[df['base'] == base, 'pred_seg_path'] = ""
                missing_bases.append(base)

    if show_summary:
        total = len(bases)
        print(f"Total bases evaluated: {total}")
        print(f"Kept pred (dice >= {threshold} or only pred exists): {len(kept_bases)}")
        print(f"Replaced by GT (dice < {threshold} or pred missing but GT exists): {len(replaced_bases)}")
        print(f"No mask available (neither pred nor GT): {len(missing_bases)}")
        # optional: some dice stats (exclude nan)
        dice_vals = [v for v in dice_map.values() if not np.isnan(v)]
        if len(dice_vals) > 0:
            print(f"Dice stats on computed bases: mean {np.mean(dice_vals):.3f}, median {np.median(dice_vals):.3f}, count {len(dice_vals)}")
        else:
            print("No valid dice values computed (no pairs with both files and matching shapes).")

    return df, dice_map
    
# Fungsi Policy untuk Menentukan Segmentation yang akan Digunakan
def apply_policy_overwrite_pred_with_gt(df, dice_map, threshold=0.6, verbose=True):
    if 'base' not in df.columns:
        df['base'] = df['filename'].apply(lambda x: str(x).rsplit('-',1)[0])
    for base, d in dice_map.items():
        rows_base = (df['base'] == base)
        if not np.isnan(d):
            if d < threshold:
                gt = df.loc[rows_base, 'segmentation_path'].fillna("").iloc[0] if any(df.loc[rows_base,'segmentation_path'].fillna("")!="") else ""
                if gt and Path(gt).exists():
                    df.loc[rows_base, 'pred_seg_path'] = gt
        else:
            # fallback: if pred missing but GT exists -> set pred to GT
            pred_exists = any(df.loc[rows_base,'pred_seg_path'].fillna("")!="")
            gt_exists = any(df.loc[rows_base,'segmentation_path'].fillna("")!="")
            if (not pred_exists) and gt_exists:
                gt = df.loc[rows_base, 'segmentation_path'].fillna("").iloc[0]
                if gt and Path(gt).exists():
                    df.loc[rows_base, 'pred_seg_path'] = gt
    if verbose:
        vals = [v for v in dice_map.values() if not np.isnan(v)]
        if vals:
            print(f"Dice computed for {len(vals)} bases (mean={np.mean(vals):.3f}, median={np.median(vals):.3f})")
    return df

In [4]:
# # Penggunaan
# train_df, train_dice_map = apply_pred_vs_gt_policy(train_df, threshold=0.6)
# train_df.to_csv('train_dataset_rebalanced_with_predseg.csv', index=False)

# # validation/test
# validation_df, val_dice_map = apply_pred_vs_gt_policy(validation_df, threshold=0.6)
# validation_df.to_csv('val_dataset_rebalanced_with_predseg.csv', index=False)

# test_df, test_dice_map = apply_pred_vs_gt_policy(test_df, threshold=0.6)
# test_df.to_csv('test_dataset_rebalanced_with_predseg.csv', index=False)

In [5]:
import pandas as pd
import os

# Tentukan base path ke folder hasil Anda
base_path = "/kaggle/input/hasil-segmentasi-2/"

# Definisikan nama file
train_file = "train_dataset_rebalanced_with_predseg (12).csv"
val_file = "val_dataset_rebalanced_with_predseg (2).csv"
test_file = "test_dataset_rebalanced_with_predseg (3).csv"

# Load dataset dari file CSV

train_df = pd.read_csv(os.path.join(base_path, train_file))
validation_df = pd.read_csv(os.path.join(base_path, val_file))
test_df = pd.read_csv(os.path.join(base_path, test_file))
    
print("Berhasil me-load train_df, validation_df, dan test_df dari file CSV.")
print(f"Jumlah data train: {len(train_df)}")
print(f"Jumlah data validasi: {len(validation_df)}")
print(f"Jumlah data test: {len(test_df)}")

Berhasil me-load train_df, validation_df, dan test_df dari file CSV.
Jumlah data train: 280
Jumlah data validasi: 92
Jumlah data test: 96


In [6]:
train_df.head()

Unnamed: 0,filename,image_path,caption,category,segmentation_path,pred_seg_path,base
0,BraTS-GLI-00106-000-t1c,/kaggle/input/brats2023-part-1/BraTS-GLI-00106...,Post-contrast T1-weighted images reveal promin...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00106...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00106-000
1,BraTS-GLI-00106-000-t1n,/kaggle/input/brats2023-part-1/BraTS-GLI-00106...,There is an irregular lesion in the right fron...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00106...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00106-000
2,BraTS-GLI-00106-000-t2f,/kaggle/input/brats2023-part-1/BraTS-GLI-00106...,"On FLAIR sequence, the lesion shows high signa...",GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00106...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00106-000
3,BraTS-GLI-00106-000-t2w,/kaggle/input/brats2023-part-1/BraTS-GLI-00106...,The lesion in the right frontal lobe displays ...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00106...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00106-000
4,BraTS-GLI-00024-000-t1c,/kaggle/input/brats2023-part-1/BraTS-GLI-00024...,The lesion demonstrates ring-enhancement post-...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00024...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00024-000


In [7]:
validation_df.head()

Unnamed: 0,filename,image_path,caption,category,segmentation_path,pred_seg_path,base
0,BraTS-GLI-00686-000-t1c,/kaggle/input/brats2023-part-1/BraTS-GLI-00686...,Post-contrast T1-weighted images reveal signif...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00686...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00686-000
1,BraTS-GLI-00686-000-t1n,/kaggle/input/brats2023-part-1/BraTS-GLI-00686...,"In the right frontal, temporal, and insular lo...",GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00686...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00686-000
2,BraTS-GLI-00686-000-t2f,/kaggle/input/brats2023-part-1/BraTS-GLI-00686...,"On FLAIR images, the lesion in the right front...",GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00686...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00686-000
3,BraTS-GLI-00686-000-t2w,/kaggle/input/brats2023-part-1/BraTS-GLI-00686...,"The lesion in the right frontal, temporal, and...",GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00686...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00686-000
4,BraTS-GLI-00488-000-t1c,/kaggle/input/brats2023-part-1/BraTS-GLI-00488...,"After contrast administration, the lesion in t...",GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00488...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00488-000


In [8]:
test_df.head()

Unnamed: 0,filename,image_path,caption,category,segmentation_path,pred_seg_path,base
0,BraTS-GLI-00291-000-t1c,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,The lesions exhibit uneven enhancement post co...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00291-000
1,BraTS-GLI-00291-000-t1n,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,Two lesions in the right parietal lobe show is...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00291-000
2,BraTS-GLI-00291-000-t2f,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,Two lesions in the right parietal lobe show mi...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00291-000
3,BraTS-GLI-00291-000-t2w,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,Two lesions in the right parietal lobe show hi...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00291-000
4,BraTS-GLI-00706-000-t1c,/kaggle/input/brats2023-part-1/BraTS-GLI-00706...,"After contrast administration, the lesion in t...",GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00706...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00706-000


## Find Keywords

### Extract Keywords from Segmentation Mask

In [10]:
# KEYWORD PAKAI LOKASI

def extract_mask_keywords(mask_nii_path):
    # Load NIfTI
    nii = nib.load(mask_nii_path)
    seg_vol = nii.get_fdata().astype(int)
    vox_dims = nii.header.get_zooms()
    voxel_vol = np.prod(vox_dims)

    keywords = {}
    solidity = {}
    eccentricity = {}

    # --- 1. Analisis Per-Kelas (Volume & Bentuk) ---
    for c in [1, 2, 3]:
        coords = np.argwhere(seg_vol == c)
        if len(coords) == 0:
            continue
            
        # Volume
        cnt = coords.shape[0]
        vol_cm3 = cnt * voxel_vol / 1000
        
        if vol_cm3 < 20: size = "small"
        elif vol_cm3 <= 50: size = "moderate"
        else: size = "large"

        keywords[f"class_{c}_vol"] = round(vol_cm3, 1)
        keywords[f"class_{c}_size"] = size

        # Shape (Solidity & Eccentricity per slice)
        sols, eccs = [], []
        for k in np.unique(coords[:, 2]):
            mask2d = (seg_vol[:, :, k] == c).astype(int)
            if mask2d.sum() == 0: continue
            rp = regionprops(label(mask2d))[0]
            sols.append(rp.solidity)
            eccs.append(rp.eccentricity)
        
        if sols:
            solidity[c] = np.mean(sols)
            keywords[f"class_{c}_shape"] = "compact" if solidity[c] >= 0.8 else "irregular"
        
        if eccs:
            eccentricity[c] = np.mean(eccs)
            keywords[f"class_{c}_form"] = "rounded" if eccentricity[c] < 0.7 else "elongated"

    # --- 2. Analisis Lokasi (Global Tumor) ---
    # Menggunakan gabungan seluruh tumor (seg > 0) untuk lokasi
    tumor_mask = seg_vol > 0
    
    if tumor_mask.sum() == 0:
        keywords['location'] = "unknown"
    else:
        # Tentukan titik tengah sumbu X
        x_mid = seg_vol.shape[0] // 2
        
        # Hitung voxel di sisi kiri dan kanan IMAGE array
        # (Asumsi NIfTI standar)
        image_left_voxels = tumor_mask[:x_mid, :, :].sum()
        image_right_voxels = tumor_mask[x_mid:, :, :].sum()
        
        # Logika Koreksi: Radiological View Flip
        # Sisi Kanan Image = Sisi Kiri Pasien (Anatomical LEFT)
        # Sisi Kiri Image  = Sisi Kanan Pasien (Anatomical RIGHT)
        
        if image_right_voxels > image_left_voxels * 1.2:
            loc_str = "left"    # Dominan di image kanan -> Pasien Kiri
        elif image_left_voxels > image_right_voxels * 1.2:
            loc_str = "right"   # Dominan di image kiri -> Pasien Kanan
        else:
            loc_str = "bilateral" # Tersebar di kedua sisi
            
        keywords['location'] = loc_str

    return keywords

In [13]:
import pprint
from pathlib import Path

# Tes fungsi 
df = train_df  

# contoh 10 base unik (kamu bisa ubah jumlahnya)
sample_bases = df['base'].unique().tolist()[:10]
print("Contoh bases (N={}):".format(len(sample_bases)))
print(sample_bases)

# pilih index (0..len(sample_bases)-1)
i = 8
base = sample_bases[i]
print("Memeriksa base:", base)

# ambil pred path (first non-empty)
rows = df[df['base']==base]
pred = next((p for p in rows['pred_seg_path'].fillna("").tolist() if p), "")
gt   = next((g for g in rows['segmentation_path'].fillna("").tolist() if g), "")

print("pred_seg_path:", pred)
print("segmentation_path (GT):", gt)

# cek keberadaan file
if pred and Path(pred).exists():
    print("Pred file exists — running extract_mask_keywords(pred)...")
    kws_pred = extract_mask_keywords(pred)
    pprint.pprint(kws_pred)
else:
    print("Pred file tidak ada atau kosong — skip extract pred.")

# opsional: juga cek GT
if gt and Path(gt).exists():
    print("\nGT file exists — running extract_mask_keywords(gt)...")
    kws_gt = extract_mask_keywords(gt)
    pprint.pprint(kws_gt)
else:
    print("GT file tidak ada atau kosong — skip extract GT.")


Contoh bases (N=10):
['BraTS-GLI-00106-000', 'BraTS-GLI-00024-000', 'BraTS-GLI-00604-000', 'BraTS-GLI-00734-001', 'BraTS-GLI-00322-000', 'BraTS-GLI-00598-000', 'BraTS-GLI-00397-000', 'BraTS-GLI-00443-000', 'BraTS-GLI-00652-000', 'BraTS-GLI-00478-000']
Memeriksa base: BraTS-GLI-00652-000
pred_seg_path: /kaggle/input/hasil-segmentasi-2/results (2)/predicted_masks/BraTS-GLI-00652-000-predseg.nii
segmentation_path (GT): /kaggle/input/brats2023-part-1/BraTS-GLI-00652-000/BraTS-GLI-00652-000-seg.nii
Pred file exists — running extract_mask_keywords(pred)...
{'class_1_form': 'elongated',
 'class_1_shape': 'irregular',
 'class_1_size': 'moderate',
 'class_1_vol': 32.4,
 'class_2_form': 'elongated',
 'class_2_shape': 'irregular',
 'class_2_size': 'large',
 'class_2_vol': 109.0,
 'class_3_form': 'elongated',
 'class_3_shape': 'irregular',
 'class_3_size': 'large',
 'class_3_vol': 85.6,
 'location': 'right'}

GT file exists — running extract_mask_keywords(gt)...
{'class_1_form': 'elongated',
 'cla

In [14]:
# Fungsi untuk membuat keyword map per base
def build_keywords_map_from_df(df, extractor_fn):
    bases = sorted(df['base'].unique())
    kw_map = {}
    for base in tqdm(bases, desc="Extract keywords per base"):
        rows = df[df['base']==base]
        pred = next((p for p in rows['pred_seg_path'].fillna("").tolist() if p), "")
        if pred and Path(pred).exists():
            try:
                kws = extractor_fn(pred)
                kw_map[base] = kws
            except Exception:
                kw_map[base] = {}
        else:
            kw_map[base] = {}
    return kw_map

In [15]:
subset_df = df[df['base'].isin(sample_bases)].copy()
kw_map_small = build_keywords_map_from_df(subset_df, extract_mask_keywords)

# tampilkan ringkasan
print("Total bases processed:", len(kw_map_small))
# contoh 5 base pertama dengan isi
import itertools
for base, kws in itertools.islice(kw_map_small.items(), 10):
    print("\nBase:", base)
    if kws:
        print("  keys:", list(kws.keys()))
        # tampilkan beberapa nilai penting kalau ada
        for k in ['class_1_vol','class_1_size','location','class_1_shape','class_1_form']:
            if k in kws:
                print(f"   {k}: {kws[k]}")
    else:
        print("  (no keywords found)")

Extract keywords per base: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]

Total bases processed: 10

Base: BraTS-GLI-00024-000
  keys: ['class_1_vol', 'class_1_size', 'class_1_shape', 'class_1_form', 'class_2_vol', 'class_2_size', 'class_2_shape', 'class_2_form', 'class_3_vol', 'class_3_size', 'class_3_shape', 'class_3_form', 'location']
   class_1_vol: 20.9
   class_1_size: moderate
   location: left
   class_1_shape: irregular
   class_1_form: elongated

Base: BraTS-GLI-00106-000
  keys: ['class_1_vol', 'class_1_size', 'class_1_shape', 'class_1_form', 'class_2_vol', 'class_2_size', 'class_2_shape', 'class_2_form', 'class_3_vol', 'class_3_size', 'class_3_shape', 'class_3_form', 'location']
   class_1_vol: 13.5
   class_1_size: small
   location: right
   class_1_shape: irregular
   class_1_form: elongated

Base: BraTS-GLI-00322-000
  keys: ['class_1_vol', 'class_1_size', 'class_1_shape', 'class_1_form', 'class_2_vol', 'class_2_size', 'class_2_shape', 'class_2_form', 'class_3_vol', 'class_3_size', 'class_3_shape', 'class_3_form', 'location']
   class_1_vol: 




In [16]:
# # Fungsi untuk konversi keyword menjadi frasa natural singkat
def keywords_to_phrase(kws):
    if not kws:
        return ""
    parts = []
    if 'class_1_size' in kws and 'class_1_vol' in kws:
        parts.append(f"a {kws['class_1_size']} lesion (~{kws['class_1_vol']} cm³)")
    if 'location' in kws:
        parts.append(f"located in the {kws['location']}")
    if 'class_1_shape' in kws:
        parts.append(f"with {kws['class_1_shape']} morphology")
    if 'class_1_form' in kws:
        parts.append(f"and {kws['class_1_form']} form")
    phrase = ", ".join(parts)
    return phrase

In [18]:
# coba beberapa contoh phrase
for base in sample_bases[:10]:
    kws = kw_map_small.get(base, {}) or {}
    phrase = keywords_to_phrase(kws)
    print(f"\nBase: {base}")
    print(" kws:", kws)
    print(" phrase:", repr(phrase))


Base: BraTS-GLI-00106-000
 kws: {'class_1_vol': 13.5, 'class_1_size': 'small', 'class_1_shape': 'irregular', 'class_1_form': 'elongated', 'class_2_vol': 75.9, 'class_2_size': 'large', 'class_2_shape': 'irregular', 'class_2_form': 'elongated', 'class_3_vol': 146.0, 'class_3_size': 'large', 'class_3_shape': 'irregular', 'class_3_form': 'elongated', 'location': 'right'}
 phrase: 'a small lesion (~13.5 cm³), located in the right, with irregular morphology, and elongated form'

Base: BraTS-GLI-00024-000
 kws: {'class_1_vol': 20.9, 'class_1_size': 'moderate', 'class_1_shape': 'irregular', 'class_1_form': 'elongated', 'class_2_vol': 76.9, 'class_2_size': 'large', 'class_2_shape': 'irregular', 'class_2_form': 'elongated', 'class_3_vol': 9.6, 'class_3_size': 'small', 'class_3_shape': 'compact', 'class_3_form': 'elongated', 'location': 'left'}
 phrase: 'a moderate lesion (~20.9 cm³), located in the left, with irregular morphology, and elongated form'

Base: BraTS-GLI-00604-000
 kws: {'class_1_v

In [20]:
def inject_keywords_into_caption(original_caption, kw_phrase, method='append'):
    if not kw_phrase:
        return original_caption
    if method == 'integrate':
        import re
        sents = re.split(r'(?<=[.!?])\s+', original_caption.strip())
        if len(sents) > 1:
            sents[0] = sents[0].rstrip('.!?') + f", which is {kw_phrase}."
            return " ".join(sents)
        else:
            return original_caption.rstrip('.') + f". The segmentation suggests {kw_phrase}."
    else:
        return original_caption.rstrip('.') + f". The segmentation suggests {kw_phrase}."

In [21]:
base = sample_bases[0]
rows = df[df['base']==base].iloc[0]  # ambil satu row untuk caption dan preview
orig_caption = rows['caption']
kws = kw_map_small.get(base, {}) or {}
phrase = keywords_to_phrase(kws)
injected_preview = inject_keywords_into_caption(orig_caption, phrase, method='append')

print("Original caption:\n", orig_caption)
print("\nPhrase to inject:\n", phrase)
print("\nInjected preview:\n", injected_preview)


Original caption:
 Post-contrast T1-weighted images reveal prominent garland-like enhancement of the lesion, with unclear margins.

Phrase to inject:
 a small lesion (~13.5 cm³), located in the right, with irregular morphology, and elongated form

Injected preview:
 Post-contrast T1-weighted images reveal prominent garland-like enhancement of the lesion, with unclear margins. The segmentation suggests a small lesion (~13.5 cm³), located in the right, with irregular morphology, and elongated form.


In [22]:
def prepare_injected_caption_df(df, extractor_fn, threshold=0.6):
    dice_map = compute_dice_map(df)
    df = apply_policy_overwrite_pred_with_gt(df, dice_map, threshold=threshold, verbose=True)
    kw_map = build_keywords_map_from_df(df, extractor_fn)
    if 'base' not in df.columns:
        df['base'] = df['filename'].apply(lambda x: str(x).rsplit('-',1)[0])
    injected = []
    for _, row in df.iterrows():
        base = row['base']
        kws = kw_map.get(base, {}) or {}
        phrase = keywords_to_phrase(kws)
        newcap = inject_keywords_into_caption(str(row['caption']), phrase, method='append')
        injected.append(newcap)
    df['caption_injected'] = injected
    return df, dice_map, kw_map

## Create injected captions

In [23]:
# prepare
threshold = 0.6
df_prepared, dice_map, kw_map = prepare_injected_caption_df(train_df.copy(), extract_mask_keywords, threshold=threshold)
# agar mudah inspeksi, simpan juga df_prepared sebagai train_prepared
train_prepared = df_prepared
print("Done. Rows:", len(train_prepared))

Compute dice per base: 100%|██████████| 70/70 [00:20<00:00,  3.35it/s]


Dice computed for 70 bases (mean=0.827, median=0.853)


Extract keywords per base: 100%|██████████| 70/70 [00:40<00:00,  1.72it/s]

Done. Rows: 280





In [24]:
import numpy as np
from collections import Counter

# dice summary
dice_vals = [v for v in dice_map.values() if not np.isnan(v)]
print("Bases total:", len(dice_map))
print("Dice computed (non-nan):", len(dice_vals))
if dice_vals:
    print("  mean: {:.3f}, median: {:.3f}, min: {:.3f}, max: {:.3f}".format(np.mean(dice_vals), np.median(dice_vals), np.min(dice_vals), np.max(dice_vals)))

# bases with dice < threshold
low_dice_bases = [b for b,v in dice_map.items() if (not np.isnan(v)) and (v < threshold)]
print("Bases with dice < {:.2f}: {}".format(threshold, len(low_dice_bases)))

# how many bases have no dice computed (nan)
nan_bases = [b for b,v in dice_map.items() if np.isnan(v)]
print("Bases with no dice (nan):", len(nan_bases))


Bases total: 70
Dice computed (non-nan): 70
  mean: 0.827, median: 0.853, min: 0.614, max: 1.000
Bases with dice < 0.60: 0
Bases with no dice (nan): 0


## Train Preparation & Finetuning

In [26]:
# load model + tokenizer + feature_extractor

import os
from pathlib import Path
import random
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, Trainer, TrainingArguments

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
os.environ["WANDB_DISABLED"] = "true" 
MODEL_DIR = "/kaggle/input/3dvit-biomedbert/tfjs/default/1/image caption model/model"
TOKENIZER_DIR = "/kaggle/input/3dvit-biomedbert/tfjs/default/1/image caption model/tokenizer"
FEATURE_EXTRACTOR_DIR = "/kaggle/input/3dvit-biomedbert/tfjs/default/1/image caption model/model"
T5_OUT_DIR = "./t5_edit_model"   
RESULTS_DIR = Path.cwd()
RESULTS_DIR.mkdir(exist_ok=True)

print("Loading model/tokenizer/feature_extractor")
vision_model = VisionEncoderDecoderModel.from_pretrained(MODEL_DIR).to(DEVICE)
vision_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
feature_extractor = ViTFeatureExtractor.from_pretrained(FEATURE_EXTRACTOR_DIR)

# config / generation defaults
vision_model.config.decoder_start_token_id = vision_tokenizer.cls_token_id if vision_tokenizer.cls_token_id is not None else vision_tokenizer.bos_token_id
vision_model.config.pad_token_id = vision_tokenizer.pad_token_id
vision_model.config.eos_token_id = vision_tokenizer.eos_token_id
vision_model.config.max_length = 64
vision_model.config.num_beams = 4

print("Model loaded.")

Device: cuda
Loading model/tokenizer/feature_extractor


The following encoder weights were not tied to the decoder ['vision_encoder_decoder/layernorm', 'vision_encoder_decoder/embeddings', 'vision_encoder_decoder/encoder', 'vision_encoder_decoder/pooler']
The following encoder weights were not tied to the decoder ['vision_encoder_decoder/layernorm', 'vision_encoder_decoder/embeddings', 'vision_encoder_decoder/encoder', 'vision_encoder_decoder/pooler']
The following encoder weights were not tied to the decoder ['vision_encoder_decoder/layernorm', 'vision_encoder_decoder/embeddings', 'vision_encoder_decoder/encoder', 'vision_encoder_decoder/pooler']


Model loaded.




In [27]:
from PIL import Image
import numpy as np
import nibabel as nib

def nifti_to_pil_slice(nifti_path, slice_idx=None, normalize=True):
    """Load NIfTI, pick slice (center jika None), return PIL.Image (uint8 grayscale)."""
    img = nib.load(str(nifti_path)).get_fdata()
    # ensure 3D: if 4D, pick first volume
    if img.ndim == 4:
        img = img[..., 0]
    if img.ndim != 3:
        raise ValueError(f"Unsupported nii shape: {img.shape}")
    z = img.shape[2]
    if slice_idx is None:
        slice_idx = z // 2
    sl = np.array(img[:, :, slice_idx], dtype=float)
    if normalize:
        mn, mx = sl.min(), sl.max()
        if mx > mn:
            sl = (sl - mn) / (mx - mn)
        else:
            sl = sl * 0.0
    # convert to 0-255 uint8
    arr = (sl * 255).astype(np.uint8)
    pil = Image.fromarray(arr)
    # convert to RGB if feature_extractor expects 3-channel
    if pil.mode != "RGB":
        pil = pil.convert("RGB")
    return pil


In [28]:
def generate_captions_from_df_safe(df, vision_model, feature_extractor, vision_tokenizer,
                                   image_col='image_path', out_col='caption_generated',
                                   batch_size=8, slice_idx=None, device=None):
    if device is None:
        device = DEVICE
    vision_model = vision_model.to(device)
    vision_model.eval()
    generated = []
    paths = df[image_col].tolist()
    for i in tqdm(range(0, len(paths), batch_size), desc="Generating captions"):
        batch_paths = paths[i:i+batch_size]
        imgs = []
        for p in batch_paths:
            try:
                pil = nifti_to_pil_slice(p, slice_idx=slice_idx)
                imgs.append(pil)
            except Exception:
                imgs.append(None)
        imgs_input = [im for im in imgs if im is not None]
        if len(imgs_input) == 0:
            generated.extend([""]*len(imgs)); continue
        enc = feature_extractor(images=imgs_input, return_tensors="pt")
        pixel_values = enc['pixel_values'].to(device)
        with torch.no_grad():
            out = vision_model.generate(pixel_values=pixel_values,
                                        max_length=getattr(vision_model.config,'max_length',64),
                                        num_beams=getattr(vision_model.config,'num_beams',4))
        decoded = vision_tokenizer.batch_decode(out, skip_special_tokens=True)
        it = iter(decoded)
        for im in imgs:
            generated.append("" if im is None else next(it))
    df[out_col] = generated
    return df

In [29]:
train_df = generate_captions_from_df_safe(train_df, vision_model, feature_extractor, vision_tokenizer, 
                                          image_col='image_path', out_col='caption_generated', batch_size=4)

train_df[['filename','caption_generated']].head()

Generating captions: 100%|██████████| 70/70 [02:44<00:00,  2.35s/it]


Unnamed: 0,filename,caption_generated
0,BraTS-GLI-00106-000-t1c,"after contrast administration, the lesion in t..."
1,BraTS-GLI-00106-000-t1n,in the right basal ganglia and insular - tempo...
2,BraTS-GLI-00106-000-t2f,"on flair imaging, the lesion in the right fron..."
3,BraTS-GLI-00106-000-t2w,the lesion in the right frontal lobe shows hig...
4,BraTS-GLI-00024-000-t1c,"after contrast administration, the lesion in t..."


In [30]:
validation_df = generate_captions_from_df_safe(validation_df, vision_model, feature_extractor, vision_tokenizer, 
                                          image_col='image_path', out_col='caption_generated', batch_size=4)

# quick peek
validation_df[['filename','caption_generated']].head()

Generating captions: 100%|██████████| 23/23 [00:54<00:00,  2.38s/it]


Unnamed: 0,filename,caption_generated
0,BraTS-GLI-00686-000-t1c,"after contrast administration, the lesion in t..."
1,BraTS-GLI-00686-000-t1n,in the right basal ganglia and insular - tempo...
2,BraTS-GLI-00686-000-t2f,"on flair sequence, the lesion in the right bas..."
3,BraTS-GLI-00686-000-t2w,the lesion in the right frontal lobe shows a m...
4,BraTS-GLI-00488-000-t1c,"after contrast administration, the lesion in t..."


In [210]:
train_kw_map = build_keywords_map_from_df(train_prepared, extract_mask_keywords)
print(train_kw_map)

Extract keywords per base: 100%|██████████| 70/70 [00:41<00:00,  1.69it/s]

{'BraTS-GLI-00006-000': {'class_1_vol': 52.1, 'class_1_size': 'large', 'class_1_shape': 'irregular', 'class_1_form': 'elongated', 'class_2_vol': 71.8, 'class_2_size': 'large', 'class_2_shape': 'irregular', 'class_2_form': 'elongated', 'class_3_vol': 30.3, 'class_3_size': 'moderate', 'class_3_shape': 'irregular', 'class_3_form': 'elongated', 'location': 'left'}, 'BraTS-GLI-00017-000': {'class_1_vol': 20.4, 'class_1_size': 'moderate', 'class_1_shape': 'irregular', 'class_1_form': 'elongated', 'class_2_vol': 62.6, 'class_2_size': 'large', 'class_2_shape': 'irregular', 'class_2_form': 'elongated', 'class_3_vol': 7.1, 'class_3_size': 'small', 'class_3_shape': 'compact', 'class_3_form': 'elongated', 'location': 'left'}, 'BraTS-GLI-00021-000': {'class_1_vol': 0.9, 'class_1_size': 'small', 'class_1_shape': 'compact', 'class_1_form': 'elongated', 'class_2_vol': 81.6, 'class_2_size': 'large', 'class_2_shape': 'irregular', 'class_2_form': 'elongated', 'class_3_vol': 6.8, 'class_3_size': 'small', 




In [278]:
def remove_location_words(text):
    return re.sub(r'\b(left|right|bilateral)\b', '[LOC]', text, flags=re.IGNORECASE)

# Prepare dataset for T5 fine-tuning (edit task)
def build_train_examples_from_df(df, kw_map):
    inputs, targets = [], []
    for _, row in df.iterrows():
        base = row['base']
        kws = kw_map.get(base, {}) or {}
        phrase = keywords_to_phrase(kws)
        clean_vit_caption = remove_location_words(row['caption_generated'])
        
        # Prompt structure
        src = f"refine: {clean_vit_caption.strip()} Keywords: {phrase}"
        # src = f"Constraints: {phrase}. Refine this caption to match constraints: {clean_vit_caption.strip()}"
        tgt = row['caption_injected'].strip()
        
        inputs.append(src)
        targets.append(tgt)
    return inputs, targets

In [279]:
# Fine-tune T5 (small) with HuggingFace Trainer
def fine_tune_t5(train_inputs, train_targets, val_inputs=None, val_targets=None,
                 model_name="t5-small", out_dir="./t5_edit_model", 
                 num_train_epochs=5,  
                 per_device_train_batch_size=8, 
                 per_device_eval_batch_size=8,
                 learning_rate=3e-4,  
                 warmup_steps=500,    
                 weight_decay=0.01,   
                 load_best_model_at_end=True,
                 metric_for_best_model="eval_loss",
                 save_total_limit=3): 

    tokenizer = AutoTokenizer.from_pretrained(model_name)  
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))
    
    # tokenization
    def preprocess(examples):
        model_inputs = tokenizer(examples["input"], max_length=256, truncation=True, padding="max_length")
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(examples["target"], max_length=256, truncation=True, padding="max_length")
        
        labels_ids = np.array(labels["input_ids"])
        labels_ids[labels_ids == tokenizer.pad_token_id] = -100
    
        model_inputs["labels"] = labels_ids.tolist()
        return model_inputs

    from datasets import Dataset, DatasetDict
    train_ds = Dataset.from_dict({"input": train_inputs, "target": train_targets})
    val_ds = Dataset.from_dict({"input": val_inputs, "target": val_targets}) if (val_inputs and val_targets) else None

    tokenized_train = train_ds.map(preprocess, batched=True, remove_columns=["input", "target"])
    tokenized_val = val_ds.map(preprocess, batched=True, remove_columns=["input", "target"]) if val_ds else None

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir=out_dir,
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        warmup_steps=warmup_steps,
        weight_decay=weight_decay,
        evaluation_strategy="epoch" if tokenized_val else "no", # Logika lama dipertahankan
        save_strategy="epoch",
        load_best_model_at_end=load_best_model_at_end,
        metric_for_best_model=metric_for_best_model,
        predict_with_generate=True,
        logging_strategy="steps",
        logging_steps=50,
        fp16=torch.cuda.is_available(),
        save_total_limit=save_total_limit
    )
    # ----------------------------------------------------

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val if tokenized_val else None,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

    trainer.train()
    trainer.save_model(out_dir)
    return out_dir, tokenizer, model

In [270]:
# Prepare injected captions (policy + extraction)
train_df, train_dice_map, train_kw_map = prepare_injected_caption_df(train_df, extract_mask_keywords, threshold=0.6)
validation_df, val_dice_map, val_kw_map = prepare_injected_caption_df(validation_df, extract_mask_keywords, threshold=0.6)

Compute dice per base:   0%|          | 0/70 [00:00<?, ?it/s]

Dice computed for 70 bases (mean=0.827, median=0.853)


Extract keywords per base:   0%|          | 0/70 [00:00<?, ?it/s]

Compute dice per base:   0%|          | 0/23 [00:00<?, ?it/s]

Dice computed for 23 bases (mean=0.818, median=0.814)


Extract keywords per base:   0%|          | 0/23 [00:00<?, ?it/s]

In [280]:
# Build train examples for T5
train_inputs, train_targets = build_train_examples_from_df(train_df, train_kw_map)
val_inputs, val_targets = build_train_examples_from_df(validation_df, val_kw_map)

for i in range(10):
    print("SRC:", train_inputs[i])
    print("TGT:", train_targets[i])
    print("---")
    
# quick statistics
empty_src = sum(1 for s in train_inputs if not s.strip())
empty_tgt = sum(1 for t in train_targets if not t.strip())
print("Train pairs:", len(train_inputs), "empty src:", empty_src, "empty tgt:", empty_tgt)

SRC: refine: after contrast administration, the lesion in the [LOC] frontal lobe shows marked ring - like enhancement, with unclear, and measures approximately 12 * 12 * 73. is accompanied by a breakdown in size of the blood - like area of the [LOC] lateral ventricle, and is associated with no mention of the midline structures. Keywords: a small lesion (~13.5 cm³), located in the right, with irregular morphology, and elongated form
TGT: Post-contrast T1-weighted images reveal prominent garland-like enhancement of the lesion, with unclear margins. The segmentation suggests a small lesion (~13.5 cm³), located in the right, with irregular morphology, and elongated form.
---
SRC: refine: in the [LOC] basal ganglia and insular - temporal lobe, there is an irregular lesion with isointense to low signal, indistinct boundaries, on t1 - weighted images, and surrounding brain parenchyma edema is present. the [LOC] lateral ventricle exhibiting isointensity, with no displacement of the midline str

In [281]:
# Verify format
print("=== VERIFY NEW FORMAT ===")
print("Input example:")
print(train_inputs[0])
print("\nTarget example:")
print(train_targets[0])
print()

=== VERIFY NEW FORMAT ===
Input example:
refine: after contrast administration, the lesion in the [LOC] frontal lobe shows marked ring - like enhancement, with unclear, and measures approximately 12 * 12 * 73. is accompanied by a breakdown in size of the blood - like area of the [LOC] lateral ventricle, and is associated with no mention of the midline structures. Keywords: a small lesion (~13.5 cm³), located in the right, with irregular morphology, and elongated form

Target example:
Post-contrast T1-weighted images reveal prominent garland-like enhancement of the lesion, with unclear margins. The segmentation suggests a small lesion (~13.5 cm³), located in the right, with irregular morphology, and elongated form.



In [282]:
# BARU
out_dir, t5_tokenizer, t5_model = fine_tune_t5(
    train_inputs, train_targets, 
    val_inputs, val_targets,
    model_name="t5-small", 
    out_dir="./t5_edit_model"
)

Map:   0%|          | 0/280 [00:00<?, ? examples/s]



Map:   0%|          | 0/92 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):


Epoch,Training Loss,Validation Loss
1,No log,2.86369
2,No log,2.561599
3,2.870500,2.210345
4,2.870500,1.940326
5,2.870500,1.734925


  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


In [274]:
# === KODE DEBUGGING UNTUK VERIFIKASI DATA PELATIHAN T5 ===
print("--- Memeriksa Data Pelatihan (TRAIN) ---")
# Pastikan train_inputs dan train_targets ada dan tidak kosong
if 'train_inputs' in locals() and len(train_inputs) > 0:
    print(f"Total data pelatihan: {len(train_inputs)}")
    # Ambil 3 contoh acak
    for i in sorted(random.sample(range(len(train_inputs)), 3)):
        print(f"\n[Contoh Train {i}]")
        print(f"  SRC (Input T5): \n    {repr(train_inputs[i])}")
        print(f"  TGT (Target T5): \n    {repr(train_targets[i])}")
else:
    print("ERROR: 'train_inputs' tidak ditemukan atau kosong.")

print("\n--- Memeriksa Data Validasi (VALIDATION) ---")
# Periksa juga data validasi
if 'val_inputs' in locals() and len(val_inputs) > 0:
    print(f"Total data validasi: {len(val_inputs)}")
    # Ambil 1 contoh acak
    i = random.randint(0, len(val_inputs) - 1)
    print(f"\n[Contoh Val {i}]")
    print(f"  SRC (Input T5): \n    {repr(val_inputs[i])}")
    print(f"  TGT (Target T5): \n    {repr(val_targets[i])}")
else:
    print("ERROR: 'val_inputs' tidak ditemukan atau kosong.")
# ==========================================================

--- Memeriksa Data Pelatihan (TRAIN) ---
Total data pelatihan: 757

[Contoh Train 25]
  SRC (Input T5): 
    'refine: on flair sequence, the lesion in the [LOC] frontal lobe and basal ganglia region shows mixed high and low signal intensities. there is extensive surrounding brain to the midline structures show a large areas of high signal intensities, and a large area of edema in the report does not described in the sulci, and low Keywords: a moderate lesion (~28.2 cm³), located in the right, with irregular morphology, and elongated form'
  TGT (Target T5): 
    'On the T1W sequence, there are irregular foci with mixed high and low signal intensities in the right frontal-parietal-temporal-occipital lobes and bilateral periventricular areas. The lesion, measuring approximately 81*96*82mm, crosses the midline, associated with significant surrounding edema, compression of the right lateral ventricle, and deviation of midline structures to the left. The segmentation suggests a moderate les

## Inference & Evaluation

In [283]:
test_kw_map = build_keywords_map_from_df(test_df, extract_mask_keywords)
print(test_kw_map)

Extract keywords per base:   0%|          | 0/24 [00:00<?, ?it/s]

{'BraTS-GLI-00012-000': {'class_1_vol': 125.1, 'class_1_size': 'large', 'class_1_shape': 'compact', 'class_1_form': 'elongated', 'class_2_vol': 115.3, 'class_2_size': 'large', 'class_2_shape': 'irregular', 'class_2_form': 'elongated', 'class_3_vol': 22.4, 'class_3_size': 'moderate', 'class_3_shape': 'compact', 'class_3_form': 'elongated', 'location': 'left'}, 'BraTS-GLI-00032-001': {'class_1_vol': 7.6, 'class_1_size': 'small', 'class_1_shape': 'irregular', 'class_1_form': 'rounded', 'class_2_vol': 82.2, 'class_2_size': 'large', 'class_2_shape': 'irregular', 'class_2_form': 'elongated', 'class_3_vol': 49.4, 'class_3_size': 'moderate', 'class_3_shape': 'compact', 'class_3_form': 'elongated', 'location': 'left'}, 'BraTS-GLI-00121-000': {'class_1_vol': 0.7, 'class_1_size': 'small', 'class_1_shape': 'compact', 'class_1_form': 'rounded', 'class_2_vol': 48.1, 'class_2_size': 'moderate', 'class_2_shape': 'irregular', 'class_2_form': 'elongated', 'class_3_vol': 17.9, 'class_3_size': 'small', 'c

In [105]:
# 2) Generate baseline captions on test using explicit variables:
test_df = generate_captions_from_df_safe(test_df, vision_model, feature_extractor, vision_tokenizer,
                                             image_col='image_path', out_col='caption_generated', batch_size=4)

# quick peek
test_df[['filename','caption_generated']].head()

In [284]:
import re

# 1. Definisi Fungsi Masking/Removal
def remove_location_words(text):
    if not isinstance(text, str): return ""
    # Ganti kata lokasi dengan token netral [LOC] atau hapus saja
    return re.sub(r'\b(left|right|bilateral)\b', '[LOC]', text, flags=re.IGNORECASE)

# 2. Terapkan pada pembuatan T5 source strings
test_df['t5_src'] = test_df.apply(
    lambda r: f"refine: {remove_location_words(r['caption_generated'])} Keywords: {keywords_to_phrase(test_kw_map.get(r['base'], {}))}",
    axis=1
)

# # Terapkan format BARU: "Constraints: {keyword}. Refine...: {caption}"
# test_df['t5_src'] = test_df.apply(
#     lambda r: f"Constraints: {keywords_to_phrase(test_kw_map.get(r['base'], {}))}. Refine this caption to match constraints: {remove_location_words(r['caption_generated']).strip()}",
#     axis=1
# )

t5_tokenizer = AutoTokenizer.from_pretrained(T5_OUT_DIR) # Pastikan class tokenizer sesuai saat training
t5_model = AutoModelForSeq2SeqLM.from_pretrained(T5_OUT_DIR).to(DEVICE)
t5_model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32100, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32100, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [285]:
# run T5 inference
generated = []
batch_size = 8

for i in tqdm(range(0, len(test_df), batch_size), desc="T5 Inferencing"):
    batch = test_df['t5_src'].tolist()[i:i+batch_size]
    
    inputs = t5_tokenizer(
        batch, 
        return_tensors="pt", 
        padding=True, 
        truncation=True, 
        max_length=256
    ).to(DEVICE)
    
    with torch.no_grad():
        input_tokens = t5_tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        input_len = input_tokens.input_ids.shape[1]
        
        outs = t5_model.generate(
            **inputs, 
            max_length=int(input_len * 0.7),
            min_length=20,
            num_beams=4,
            length_penalty=0.8,
            no_repeat_ngram_size=3,
            early_stopping=True,
            decoder_start_token_id=t5_tokenizer.pad_token_id,
            pad_token_id=t5_tokenizer.pad_token_id,
            eos_token_id=t5_tokenizer.eos_token_id,
        )
    
    dec = t5_tokenizer.batch_decode(outs, skip_special_tokens=True)
    generated.extend(dec)

test_df['caption_generated_injected'] = generated

T5 Inferencing:   0%|          | 0/12 [00:00<?, ?it/s]

In [249]:
def evaluate_corpus(hyps, refs):
    bleu = sacrebleu.corpus_bleu(hyps, [refs])
    P, R, F1 = bert_score(hyps, refs, lang='en', rescale_with_baseline=True)
    sbert = SentenceTransformer('all-mpnet-base-v2')
    emb_h = sbert.encode(hyps, convert_to_tensor=True)
    emb_r = sbert.encode(refs, convert_to_tensor=True)
    cosines = util.cos_sim(emb_h, emb_r).diag().cpu().numpy()
    return {"bleu": float(bleu.score), "bert_f1_mean": float(F1.mean().item()), "sbert_cosine_mean": float(cosines.mean())}

In [250]:
# 4) Evaluate:
refs = test_df['caption'].fillna("").tolist()

print("\n--- 1. HASIL EVALUASI: Baseline (Tahap 1: vision_model) ---")
hyps_baseline = test_df['caption_generated'].tolist()
# Pastikan refs dan hyps memiliki panjang yang sama
if len(hyps_baseline) == len(refs):
    metrics_baseline = evaluate_corpus(hyps_baseline, refs)
    print(metrics_baseline)
else:
    print(f"Error: Panjang hipotesis ({len(hyps_baseline)}) tidak sama dengan referensi ({len(refs)})")


--- 1. HASIL EVALUASI: Baseline (Tahap 1: vision_model) ---


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

{'bleu': 12.083650208569985, 'bert_f1_mean': 0.2993832528591156, 'sbert_cosine_mean': 0.7617081999778748}


In [251]:
print("\n--- 2. HASIL EVALUASI: Model Akhir (Tahap 1 + Tahap 2: T5 Editor) ---")
hyps_final = test_df['caption_generated_injected'].tolist()
# Pastikan refs dan hyps memiliki panjang yang sama
if len(hyps_final) == len(refs):
    metrics_final = evaluate_corpus(hyps_final, refs)
    print(metrics_final)
else:
    print(f"Error: Panjang hipotesis ({len(hyps_final)}) tidak sama dengan referensi ({len(refs)})")


--- 2. HASIL EVALUASI: Model Akhir (Tahap 1 + Tahap 2: T5 Editor) ---


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

{'bleu': 14.180616482283071, 'bert_f1_mean': 0.31815317273139954, 'sbert_cosine_mean': 0.7493132948875427}


In [252]:
print("\nContoh hasil akhir (head):")
print(test_df[['filename', 'caption', 'caption_generated', 't5_src', 'caption_generated_injected']].head())


Contoh hasil akhir (head):
                  filename                                            caption  \
0  BraTS-GLI-00291-000-t1c  The lesions exhibit uneven enhancement post co...   
1  BraTS-GLI-00291-000-t1n  Two lesions in the right parietal lobe show is...   
2  BraTS-GLI-00291-000-t2f  Two lesions in the right parietal lobe show mi...   
3  BraTS-GLI-00291-000-t2w  Two lesions in the right parietal lobe show hi...   
4  BraTS-GLI-00706-000-t1c  After contrast administration, the lesion in t...   

                                   caption_generated  \
0  post - contrast t1 - weighted imaging reveals ...   
1  in the right parietal lobe, there is a mass - ...   
2  the lesion in the right parietal lobe shows mi...   
3  the lesion in the right parietal lobe appears ...   
4  after contrast administration, the lesion in t...   

                                              t5_src  \
0  refine: post - contrast t1 - weighted imaging ...   
1  refine: in the [LOC] parietal lob

In [286]:
# 5) Save results:
test_df.to_csv(RESULTS_DIR / "test_with_generated_injected.csv", index=False)

In [255]:
test_df.head()

Unnamed: 0,filename,image_path,caption,category,segmentation_path,pred_seg_path,base,caption_generated,t5_src,caption_generated_injected
0,BraTS-GLI-00291-000-t1c,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,The lesions exhibit uneven enhancement post co...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00291-000,post - contrast t1 - weighted imaging reveals ...,refine: post - contrast t1 - weighted imaging ...,post-contrast t1 - weighted imaging reveals ri...
1,BraTS-GLI-00291-000-t1n,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,Two lesions in the right parietal lobe show is...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00291-000,"in the right parietal lobe, there is a mass - ...","refine: in the [LOC] parietal lobe, there is a...","In the right parietal lobe, there is a mass - ..."
2,BraTS-GLI-00291-000-t2f,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,Two lesions in the right parietal lobe show mi...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00291-000,the lesion in the right parietal lobe shows mi...,refine: the lesion in the [LOC] parietal lobe ...,The lesion in the right parietal lobe shows mi...
3,BraTS-GLI-00291-000-t2w,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,Two lesions in the right parietal lobe show hi...,GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00291...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00291-000,the lesion in the right parietal lobe appears ...,refine: the lesion in the [LOC] parietal lobe ...,The lesion in the right parietal lobe appears ...
4,BraTS-GLI-00706-000-t1c,/kaggle/input/brats2023-part-1/BraTS-GLI-00706...,"After contrast administration, the lesion in t...",GLI,/kaggle/input/brats2023-part-1/BraTS-GLI-00706...,/kaggle/input/hasil-segmentasi-2/results (2)/p...,BraTS-GLI-00706-000,"after contrast administration, the lesion in t...","refine: after contrast administration, the les...","After contrast administration, the lesion in t..."


In [256]:
import sacrebleu
from sacrebleu.metrics import BLEU 
from bert_score import score as bert_score
from sentence_transformers import SentenceTransformer, util
import torch
import numpy as np

sbert_model = SentenceTransformer('all-mpnet-base-v2').to(DEVICE) 

def evaluate_corpus_detailed(hyps, refs, device=DEVICE):
    bleu1_obj = BLEU(max_ngram_order=1)
    bleu1 = bleu1_obj.corpus_score(hyps, [refs]).score
    
    bleu2_obj = BLEU(max_ngram_order=2)
    bleu2 = bleu2_obj.corpus_score(hyps, [refs]).score
    
    bleu3_obj = BLEU(max_ngram_order=3)
    bleu3 = bleu3_obj.corpus_score(hyps, [refs]).score
    
    bleu4_obj = BLEU(max_ngram_order=4)
    bleu4 = bleu4_obj.corpus_score(hyps, [refs]).score
    # ---------------------------------
    
    P, R, F1 = bert_score(hyps, refs, lang='en', rescale_with_baseline=True, device=device)
    
    emb_h = sbert_model.encode(hyps, convert_to_tensor=True, show_progress_bar=False)
    emb_r = sbert_model.encode(refs, convert_to_tensor=True, show_progress_bar=False)
    cosines = util.cos_sim(emb_h, emb_r).diag().cpu().numpy()

    return {
        "BLEU_corpus": float(bleu4), 
        
        "BLEU1": float(bleu1),
        "BLEU2": float(bleu2),
        "BLEU3": float(bleu3),
        "BLEU4": float(bleu4),
        
        "BERTScore_F1": float(F1.mean().item()),
        "SBERT_Cosine": float(cosines.mean())
    }



In [257]:
import pprint

# 4) Evaluate:
refs = test_df['caption'].fillna("").tolist()

print("\n--- 1. HASIL EVALUASI: Baseline (Tahap 1: vision_model) ---")
hyps_baseline = test_df['caption_generated'].tolist()

if len(hyps_baseline) == len(refs):
    metrics_baseline = evaluate_corpus_detailed(hyps_baseline, refs) 
    pprint.pprint(metrics_baseline) 
else:
    print(f"Error: Panjang hipotesis ({len(hyps_baseline)}) tidak sama dengan referensi ({len(refs)})")

print("\n--- 2. HASIL EVALUASI: Model Akhir (Tahap 1 + Tahap 2: T5 Editor) ---")
hyps_final = test_df['caption_generated_injected'].tolist()

if len(hyps_final) == len(refs):
    metrics_final = evaluate_corpus_detailed(hyps_final, refs) 
    pprint.pprint(metrics_final) 
else:
    print(f"Error: Panjang hipotesis ({len(hyps_final)}) tidak sama dengan referensi ({len(refs)})")


--- 1. HASIL EVALUASI: Baseline (Tahap 1: vision_model) ---


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'BERTScore_F1': 0.2993832528591156,
 'BLEU1': 35.97527722232321,
 'BLEU2': 23.54651975407515,
 'BLEU3': 16.4210980605062,
 'BLEU4': 12.083650208569985,
 'BLEU_corpus': 12.083650208569985,
 'SBERT_Cosine': 0.7617081999778748}

--- 2. HASIL EVALUASI: Model Akhir (Tahap 1 + Tahap 2: T5 Editor) ---


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'BERTScore_F1': 0.31815317273139954,
 'BLEU1': 37.53454947484798,
 'BLEU2': 25.464847374693814,
 'BLEU3': 18.618058795277936,
 'BLEU4': 14.180616482283071,
 'BLEU_corpus': 14.180616482283071,
 'SBERT_Cosine': 0.7493132948875427}


In [259]:
# Print generated text length distribution
import numpy as np

stage1_lens = [len(t.split()) for t in test_df['caption_generated']]
stage2_lens = [len(t.split()) for t in test_df['caption_generated_injected']]
gt_lens = [len(t.split()) for t in test_df['caption']]

print(f"Ground Truth avg length: {np.mean(gt_lens):.1f} words")
print(f"Stage 1 avg length:      {np.mean(stage1_lens):.1f} words")
print(f"Stage 2 avg length:      {np.mean(stage2_lens):.1f} words")

Ground Truth avg length: 35.3 words
Stage 1 avg length:      51.6 words
Stage 2 avg length:      48.2 words


In [260]:
# Check if T5 just copies input
from difflib import SequenceMatcher

similarities = []
for _, row in test_df.iterrows():
    sim = SequenceMatcher(None, 
                          row['caption_generated'], 
                          row['caption_generated_injected']).ratio()
    similarities.append(sim)

avg_sim = np.mean(similarities)
print(f"Avg similarity between Stage 1 and Stage 2: {avg_sim:.2%}")

if avg_sim > 0.9:
    print("⚠️ WARNING: T5 is mostly copying input without editing!")

Avg similarity between Stage 1 and Stage 2: 74.87%


In [263]:
# Analisis kasus dengan BLEU terendah
import pandas as pd
from nltk.translate.bleu_score import sentence_bleu

results = []
for idx, row in test_df.iterrows():
    ref = [row['caption'].split()]
    hyp = row['caption_generated_injected'].split()
    bleu = sentence_bleu(ref, hyp)
    results.append({
        'idx': idx,
        'bleu': bleu,
        'caption_gen': row['caption_generated'],
        'caption_injected': row['caption_generated_injected'],
        'caption_gt': row['caption']
    })

results_df = pd.DataFrame(results).sort_values('bleu')

print("=== 10 WORST CASES ===")
print(results_df.head(10))

print("\n=== 10 BEST CASES ===")
print(results_df.tail(10))

=== 10 WORST CASES ===
    idx      bleu                                        caption_gen  \
62   62  0.069977  on flair sequence, the lesion in the right fro...   
85   85  0.071520  the lesion located in the anterior interhemisp...   
22   22  0.078574  on flair sequences, the lesion in the left fro...   
16   16  0.079883  after contrast administration, the lesion demo...   
32   32  0.080706  after contrast administration, the lesion in t...   
82   82  0.084669  on flair sequence, the lesion in the right fro...   
73   73  0.089176  in the right frontal lobe, there is an irregul...   
78   78  0.089666  on flair sequences, the lesion in the right fr...   
47   47  0.092100  the lesion in the right frontal lobe shows a m...   
28   28  0.093506  after contrast administration, the lesion in t...   

                                     caption_injected  \
62  On flair sequence, the lesion in the left fron...   
85  The lesion located in the left anterior interh...   
22  On flair 

Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


In [264]:
# Check training data distribution
from collections import Counter

# Extract locations from captions
def extract_location(caption):
    # Simple regex untuk detect hemisphere
    import re
    if re.search(r'\bleft\b', caption.lower()):
        return 'left'
    elif re.search(r'\bright\b', caption.lower()):
        return 'right'
    elif re.search(r'\bbilateral\b', caption.lower()):
        return 'bilateral'
    return 'unknown'

train_locations = [extract_location(cap) for cap in train_df['caption']]
print("Training data distribution:")
print(Counter(train_locations))

test_locations = [extract_location(cap) for cap in test_df['caption']]
print("\nTest data distribution:")
print(Counter(test_locations))

Training data distribution:
Counter({'left': 161, 'right': 89, 'unknown': 23, 'bilateral': 7})

Test data distribution:
Counter({'left': 58, 'right': 23, 'unknown': 15})


In [265]:
import re
from collections import Counter

print("=== LOCATION ACCURACY EVALUATION ===\n")

# Helper function
def extract_location_from_text(text):
    """Extract location (left/right/bilateral) from text."""
    text_lower = text.lower()
    if 'bilateral' in text_lower or ('left' in text_lower and 'right' in text_lower):
        return 'bilateral'
    elif 'left' in text_lower:
        return 'left'
    elif 'right' in text_lower:
        return 'right'
    else:
        return 'unknown'

# 1. Ground Truth Location Distribution
gt_locations = [extract_location_from_text(cap) for cap in test_df['caption']]
print("1. GROUND TRUTH LOCATIONS:")
print(Counter(gt_locations))

# 2. Keyword Location Distribution
kw_locations = [test_kw_map.get(row['base'], {}).get('location', 'unknown') 
                for _, row in test_df.iterrows()]
print("\n2. KEYWORD LOCATIONS (from segmentation):")
print(Counter(kw_locations))

# 3. ViT Prediction Location Distribution
vit_locations = [extract_location_from_text(cap) for cap in test_df['caption_generated']]
print("\n3. ViT PREDICTION LOCATIONS:")
print(Counter(vit_locations))

# 4. T5 Output Location Distribution
t5_locations = [extract_location_from_text(cap) for cap in test_df['caption_generated_injected']]
print("\n4. T5 OUTPUT LOCATIONS (after adding keyword):")
print(Counter(t5_locations))

# 5. Accuracy Metrics
print("\n" + "="*60)
print("ACCURACY METRICS:")
print("="*60)

# Keyword vs GT accuracy
kw_correct = sum(1 for kw, gt in zip(kw_locations, gt_locations) 
                 if kw != 'unknown' and gt != 'unknown' and kw == gt)
kw_total = sum(1 for kw, gt in zip(kw_locations, gt_locations) 
               if kw != 'unknown' and gt != 'unknown')
kw_accuracy = kw_correct / kw_total if kw_total > 0 else 0

print(f"\nKeyword Location Accuracy: {kw_correct}/{kw_total} = {kw_accuracy*100:.1f}%")

# ViT vs GT accuracy
vit_correct = sum(1 for vit, gt in zip(vit_locations, gt_locations) 
                  if vit != 'unknown' and gt != 'unknown' and vit == gt)
vit_total = sum(1 for vit, gt in zip(vit_locations, gt_locations) 
                if vit != 'unknown' and gt != 'unknown')
vit_accuracy = vit_correct / vit_total if vit_total > 0 else 0

print(f"ViT Location Accuracy:     {vit_correct}/{vit_total} = {vit_accuracy*100:.1f}%")

# T5 vs GT accuracy
t5_correct = sum(1 for t5, gt in zip(t5_locations, gt_locations) 
                 if t5 != 'unknown' and gt != 'unknown' and t5 == gt)
t5_total = sum(1 for t5, gt in zip(t5_locations, gt_locations) 
               if t5 != 'unknown' and gt != 'unknown')
t5_accuracy = t5_correct / t5_total if t5_total > 0 else 0

print(f"T5 Location Accuracy:      {t5_correct}/{t5_total} = {t5_accuracy*100:.1f}%")

# 6. Detailed Mismatch Analysis
print("\n" + "="*60)
print("MISMATCH ANALYSIS:")
print("="*60)

mismatches = []
for idx, row in test_df.iterrows():
    gt_loc = extract_location_from_text(row['caption'])
    kw_loc = test_kw_map.get(row['base'], {}).get('location', 'unknown')
    vit_loc = extract_location_from_text(row['caption_generated'])
    t5_loc = extract_location_from_text(row['caption_generated_injected'])
    
    if gt_loc != 'unknown' and t5_loc != gt_loc:
        mismatches.append({
            'idx': idx,
            'gt': gt_loc,
            'keyword': kw_loc,
            'vit': vit_loc,
            't5': t5_loc,
            'kw_correct': kw_loc == gt_loc,
            't5_follows_kw': t5_loc == kw_loc,
            't5_follows_vit': t5_loc == vit_loc,
        })

print(f"\nTotal mismatches: {len(mismatches)}/{len(test_df)}")

if mismatches:
    print("\nFirst 10 mismatches:")
    print(f"{'Idx':<5} {'GT':<10} {'Keyword':<10} {'ViT':<10} {'T5':<10} {'KW✓':<6} {'T5→KW':<8} {'T5→ViT':<8}")
    print("-" * 80)
    for m in mismatches[:10]:
        print(f"{m['idx']:<5} {m['gt']:<10} {m['keyword']:<10} {m['vit']:<10} {m['t5']:<10} "
              f"{str(m['kw_correct']):<6} {str(m['t5_follows_kw']):<8} {str(m['t5_follows_vit']):<8}")

# 7. Key Question: Does T5 use keywords?
print("\n" + "="*60)
print("KEY QUESTION: Does T5 Use Keyword Location?")
print("="*60)

t5_follows_kw = sum(1 for m in mismatches if m['t5_follows_kw'])
t5_follows_vit = sum(1 for m in mismatches if m['t5_follows_vit'])
t5_follows_neither = len(mismatches) - t5_follows_kw - t5_follows_vit

print(f"\nIn {len(mismatches)} mismatched cases:")
print(f"  T5 follows Keyword: {t5_follows_kw} ({t5_follows_kw/len(mismatches)*100:.1f}%)")
print(f"  T5 follows ViT:     {t5_follows_vit} ({t5_follows_vit/len(mismatches)*100:.1f}%)")
print(f"  T5 follows Neither: {t5_follows_neither} ({t5_follows_neither/len(mismatches)*100:.1f}%)")

if t5_follows_vit > t5_follows_kw:
    print("\n⚠️  WARNING: T5 is following ViT MORE than Keywords!")
    print("    This means T5 hasn't learned to trust keyword location.")

=== LOCATION ACCURACY EVALUATION ===

1. GROUND TRUTH LOCATIONS:
Counter({'left': 36, 'bilateral': 24, 'right': 21, 'unknown': 15})

2. KEYWORD LOCATIONS (from segmentation):
Counter({'left': 64, 'right': 32})

3. ViT PREDICTION LOCATIONS:
Counter({'right': 71, 'left': 11, 'bilateral': 7, 'unknown': 7})

4. T5 OUTPUT LOCATIONS (after adding keyword):
Counter({'left': 52, 'right': 26, 'bilateral': 14, 'unknown': 4})

ACCURACY METRICS:

Keyword Location Accuracy: 56/81 = 69.1%
ViT Location Accuracy:     27/78 = 34.6%
T5 Location Accuracy:      49/79 = 62.0%

MISMATCH ANALYSIS:

Total mismatches: 32/96

First 10 mismatches:
Idx   GT         Keyword    ViT        T5         KW✓    T5→KW    T5→ViT  
--------------------------------------------------------------------------------
4     right      right      right      bilateral  True   False    False   
8     right      right      right      unknown    True   False    False   
14    left       left       left       bilateral  True   False   