In [None]:
!pip install -q evaluate

In [None]:
import torch
import gc
# Collect unreachable Python objects first so that any CUDA tensors they still
# hold are freed; only then can empty_cache() release those cached blocks.
gc.collect()
torch.cuda.empty_cache()
print("GPU Memory Flushed")

In [None]:
import os
import torch
import multiprocessing
import subprocess
import time
import numpy as np
import evaluate
from datasets import Dataset, Image as DSImage
from transformers import (
    SegformerImageProcessor,
    SegformerForSemanticSegmentation,
    TrainingArguments,
    Trainer,
    logging
)
from torch import nn

# Set the verbosity of the `logging` helper imported from transformers to INFO.
logging.set_verbosity_info()

# Pretrained checkpoint that both the model and the image processor load from.
MODEL_NAME = "nvidia/mit-b5"
# Directory the Trainer writes intermediate checkpoints to (scanned on resume).
OUTPUT_DIR = "/kaggle/working/checkpoints_b5_512"
# Directory the final model and processor are saved to after training.
FINAL_MODEL_DIR = "/kaggle/working/final_rat_model_b5_512"

# (image_dir, mask_dir) pairs; a mask is matched to an image by file stem.
DATASET_PATHS = [
    (
        "/kaggle/input/rodent-b5-dataset/b5-dataset/b5-processed/images",
        "/kaggle/input/rodent-b5-dataset/b5-dataset/b5-processed/masks",
    ),
]

EPOCHS = 30
LEARNING_RATE = 6e-5

# Per-device batch size and number of accumulation steps; each optimizer step
# therefore sees BATCH_SIZE * GRAD_ACCUMULATION samples per device.
BATCH_SIZE = 4
GRAD_ACCUMULATION = 4

# --- 1. GPU MONITOR ---
def monitor_gpu(interval=60):
    """Print GPU utilisation and memory usage every *interval* seconds, forever.

    Polls ``nvidia-smi`` and prints one summary line covering every GPU it
    reports. The loop never returns, so it is meant to run in a background
    process or thread. Monitoring is best-effort: a failed poll is reported
    on stdout and the loop carries on.

    Args:
        interval: Seconds to sleep between two polls.
    """
    query = [
        "nvidia-smi",
        "--query-gpu=utilization.gpu,memory.used,memory.total",
        "--format=csv,noheader,nounits",
    ]
    while True:
        try:
            rows = subprocess.check_output(query).decode().strip().split('\n')
            stats = []
            for i, row in enumerate(rows):
                # Split each CSV row once; a malformed row raises ValueError
                # and is reported below instead of being silently dropped.
                util, used, total = (field.strip() for field in row.split(',')[:3])
                stats.append(f"GPU {i}: {util}% Util | {used}/{total} MB")
            print("\n[GPU MONITOR] " + " | ".join(stats) + "\n")
        except (OSError, subprocess.CalledProcessError, ValueError) as exc:
            # Never let a monitoring hiccup escape, but make the failure visible.
            print(f"\n[GPU MONITOR] poll failed: {exc!r}\n")
        time.sleep(interval)

# --- 2. DATA LOAD (COMBINED) ---
def load_dataset():
    """Collect image/mask pairs from DATASET_PATHS and split them 90/10.

    For every ``(image_dir, mask_dir)`` entry, each ``.jpg``/``.png`` image is
    paired with the mask that has the same file stem (``.png`` is tried before
    ``.jpg``). Images without a mask are ignored, and entries whose image
    directory does not exist are skipped with a message.

    Returns:
        The result of ``train_test_split(test_size=0.10, seed=42)`` on a
        dataset whose "image" and "label" columns are cast to image features.

    Raises:
        ValueError: If no image/mask pair was found in any directory.
    """
    print("--- LOADING COMBINED DATASETS ---")

    final_image_paths = []
    final_mask_paths = []

    for img_dir, mask_dir in DATASET_PATHS:
        if not os.path.exists(img_dir):
            print(f" Skipping missing directory: {img_dir}")
            continue

        print(f"Scanning: {img_dir}...")
        # os.listdir returns entries in arbitrary order. Sorting fixes the row
        # order, so the seeded split below selects the same samples on every
        # run (important when resuming training in a new session).
        all_images = sorted(
            f for f in os.listdir(img_dir) if f.endswith(('.jpg', '.png'))
        )
        img_map = {os.path.splitext(f)[0]: f for f in all_images}

        valid_count = 0
        for base_name, img_file in img_map.items():
            # Try png first, then jpg; keep the first mask that exists.
            candidates = (
                os.path.join(mask_dir, f"{base_name}{ext}")
                for ext in (".png", ".jpg")
            )
            mask_path = next((c for c in candidates if os.path.exists(c)), None)
            if mask_path is None:
                continue

            final_image_paths.append(os.path.join(img_dir, img_file))
            final_mask_paths.append(mask_path)
            valid_count += 1
        print(f"   -> Found {valid_count} pairs.")

    print(f"--- TOTAL DATA: {len(final_image_paths)} pairs ---")

    if not final_image_paths:
        raise ValueError("CRITICAL: No images found! Check your DATASET_PATHS.")

    ds = Dataset.from_dict({"image": final_image_paths, "label": final_mask_paths})
    ds = ds.cast_column("image", DSImage())
    ds = ds.cast_column("label", DSImage())
    return ds.train_test_split(test_size=0.10, seed=42)

# --- 3. PROCESSOR & TRANSFORMS ---
# Shared image processor: resizes inputs to 512x512 and leaves label ids as-is.
# `do_reduce_labels` is the current name of the option; the older
# `reduce_labels` spelling is deprecated in the Transformers image processors.
processor = SegformerImageProcessor.from_pretrained(
    MODEL_NAME,
    do_reduce_labels=False,
    do_resize=True,
    size={"height": 512, "width": 512},
)

def train_transforms(example_batch):
    """Convert a raw batch of images and masks into processor output tensors.

    Every image is converted to RGB. Every mask is converted to grayscale and
    binarised: non-zero pixels become 1, zero pixels stay 0 (dtype uint8).
    Both lists are then passed to the module-level ``processor``, which
    returns PyTorch tensors.

    Args:
        example_batch: Mapping with an "image" list and a "label" list whose
            items expose ``.convert(mode)``.

    Returns:
        Whatever ``processor(images, masks, return_tensors="pt")`` returns.
    """
    rgb_images = [img.convert("RGB") for img in example_batch["image"]]
    binary_masks = [
        (np.array(mask.convert("L")) > 0).astype(np.uint8)
        for mask in example_batch["label"]
    ]
    return processor(rgb_images, binary_masks, return_tensors="pt")

# --- 4. SANITY CHECK ---
def sanity_check(ds):
    """Check that the first training sample yields a mask with no value above 1.

    The sample is pushed through ``train_transforms`` as a one-item batch and
    the distinct values of the resulting "labels" tensor are printed.

    Args:
        ds: Dataset mapping with a "train" split whose rows carry "image" and
            "label" entries.

    Raises:
        ValueError: If any processed mask value is greater than 1.
    """
    print("--- RUNNING SANITY CHECK ---")
    first = ds["train"][0]
    one_item_batch = {"image": [first["image"]], "label": [first["label"]]}
    processed = train_transforms(one_item_batch)

    unique_vals = torch.unique(processed["labels"]).tolist()
    print(f"Processed Mask Values: {unique_vals}")

    out_of_range = [v for v in unique_vals if v > 1]
    if out_of_range:
        raise ValueError(f" CRITICAL: Mask contains values {unique_vals}. Must be only [0, 1].")
    print(" DATA IS SAFE.")

# --- 5. METRICS & MODEL ---
# Two-class label maps handed to the model; id2label is the inverse of label2id.
label2id = {"background": 0, "rat": 1}
id2label = {idx: name for name, idx in label2id.items()}
# Scorer loaded by name through `evaluate`; used in compute_metrics.
metric = evaluate.load("mean_iou")

def compute_metrics(eval_pred):
    """Score evaluation predictions with the module-level ``metric``.

    The logits are bilinearly resized to the labels' last two dimensions and
    turned into class ids by an argmax over dimension 1 before scoring.

    Args:
        eval_pred: A ``(logits, labels)`` pair; ``logits`` is a NumPy array
            and ``labels`` exposes ``.shape``.

    Returns:
        Dict of metric name to value, with NumPy array values converted to
        plain lists.
    """
    logits, labels = eval_pred

    with torch.no_grad():
        resized = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        predicted_ids = resized.argmax(dim=1)

        scores = metric.compute(
            predictions=predicted_ids.numpy(),
            references=labels,
            num_labels=2,
            ignore_index=255,
            reduce_labels=False,
        )

    result = {}
    for key, value in scores.items():
        # NumPy arrays are turned into lists; other values pass through unchanged.
        result[key] = value.tolist() if isinstance(value, np.ndarray) else value
    return result

# --- 6. TRAINER ---
class WeightedTrainer(Trainer):
    """Trainer whose loss is a class-weighted cross-entropy on upsampled logits.

    The model's logits are bilinearly resized to the label resolution and
    scored with ``nn.CrossEntropyLoss`` using per-class weights, so the rarer
    class contributes more to the loss.
    """

    # Per-class loss weights indexed by label id (0 = background, 1 = rat).
    # Heavy weight on class 1 to force the model to find it.
    loss_class_weights = (1.0, 5.0)
    # Label value excluded from the loss; same value compute_metrics ignores.
    loss_ignore_index = 255

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the weighted cross-entropy loss for one batch.

        Args:
            model: The model to run.
            inputs: Batch mapping; must contain "labels" next to the model
                inputs.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only the loss.
            **kwargs: Extra arguments passed by the Trainer; unused here.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs.get("labels")
        # Run the forward pass without labels: the loss is computed below, so
        # letting the model also compute its built-in loss is wasted work.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")

        upsampled_logits = nn.functional.interpolate(
            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
        )

        # Create the weights directly on the logits' device (no CPU round trip).
        weights = torch.tensor(self.loss_class_weights, device=logits.device)
        loss_fct = nn.CrossEntropyLoss(
            weight=weights, ignore_index=self.loss_ignore_index
        )

        loss = loss_fct(upsampled_logits, labels)
        return (loss, outputs) if return_outputs else loss

# --- 7. MAIN ---
def main():
    """Run the full fine-tuning pipeline and save the resulting model.

    Steps: load and sanity-check the dataset, attach the on-the-fly
    transforms to both splits, build the two-class model from MODEL_NAME,
    train with ``WeightedTrainer`` (resuming from the newest checkpoint in
    OUTPUT_DIR when one exists), then save model and processor to
    FINAL_MODEL_DIR.
    """
    # Local import keeps this function self-contained; the helper ships with
    # the transformers package the file already depends on.
    from transformers.trainer_utils import get_last_checkpoint

    ds = load_dataset()
    sanity_check(ds)

    ds["train"].set_transform(train_transforms)
    ds["test"].set_transform(train_transforms)

    model = SegformerForSemanticSegmentation.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        learning_rate=LEARNING_RATE,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUMULATION,
        fp16=True,
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="mean_iou",
        report_to="none",
        # The transforms read the raw "image"/"label" columns, so keep them.
        remove_unused_columns=False,
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["test"],
        compute_metrics=compute_metrics,
    )

    # Use the library helper rather than parsing directory names by hand: it
    # only considers "checkpoint-<number>" directories, so stray entries such
    # as "checkpoint-best" or plain files cannot crash or mislead the scan.
    last_checkpoint = None
    if os.path.isdir(OUTPUT_DIR):
        last_checkpoint = get_last_checkpoint(OUTPUT_DIR)
        if last_checkpoint is not None:
            print(f"!!! RESUMING FROM: {last_checkpoint} !!!")

    print("--- TRAINING START: MIT-B5 @ 512x512 ---")
    trainer.train(resume_from_checkpoint=last_checkpoint)

    print(f"--- SAVING TO {FINAL_MODEL_DIR} ---")
    trainer.save_model(FINAL_MODEL_DIR)
    processor.save_pretrained(FINAL_MODEL_DIR)
    print("DONE.")

if __name__ == "__main__":
    # Poll GPU stats from a separate daemon process so it never blocks training.
    p = multiprocessing.Process(target=monitor_gpu, daemon=True)
    p.start()
    try:
        main()
    finally:
        # In a notebook the kernel outlives main(), so the daemon would keep
        # printing indefinitely; stop it once training finishes or fails.
        p.terminate()
        p.join()