In [None]:
!pip install evaluate

In [None]:
import os
import torch
import multiprocessing
import subprocess
import time
import numpy as np
import evaluate
from datasets import Dataset, Image as DSImage
from transformers import (
    SegformerImageProcessor,
    SegformerForSemanticSegmentation,
    TrainingArguments,
    Trainer,
    logging
)
from torch import nn

# Raise the transformers library's log verbosity to INFO for this run.
logging.set_verbosity_info()

# --- CONFIGURATION ---
# Pretrained checkpoint to fine-tune.
MODEL_NAME = "nvidia/mit-b3"

# Output locations: intermediate checkpoints and the final exported model.
OUTPUT_DIR = "/kaggle/working/checkpoints_b3_1024_DICE"
FINAL_MODEL_DIR = "/kaggle/working/final_rat_model_b3_1024_DICE"

# Input data: images and their segmentation masks live in sibling folders.
IMAGE_DIR = "/kaggle/input/datasets/gonoszgonosz/rodent-data-2/processed/images"
MASK_DIR = "/kaggle/input/datasets/gonoszgonosz/rodent-data-2/processed/masks"

# Training hyperparameters.
EPOCHS = 30
LEARNING_RATE = 6e-5
# Samples per device step and number of steps accumulated per optimizer
# update; the effective batch per update is BATCH_SIZE * GRAD_ACCUMULATION.
BATCH_SIZE = 1
GRAD_ACCUMULATION = 16

# --- 1. GPU MONITOR ---
def monitor_gpu(interval=60, max_iterations=None):
    """Periodically print per-GPU utilization and memory usage from ``nvidia-smi``.

    Meant to run in a background daemon process alongside training.

    Args:
        interval: Seconds to wait between consecutive polls.
        max_iterations: Number of polls to perform before returning. ``None``
            (the default) polls forever.

    If the ``nvidia-smi`` executable cannot be found, a message is printed and
    the function returns. Other query failures are reported and retried on
    the next poll rather than being silently swallowed.
    """
    query = [
        "nvidia-smi",
        "--query-gpu=utilization.gpu,memory.used,memory.total",
        "--format=csv,noheader,nounits",
    ]
    polls = 0
    while max_iterations is None or polls < max_iterations:
        # Sleep between polls (not before the first one).
        if polls:
            time.sleep(interval)
        polls += 1
        try:
            output = subprocess.check_output(query).decode()
        except FileNotFoundError:
            # No nvidia-smi on this machine: retrying can never succeed.
            print("[GPU MONITOR] nvidia-smi not found; stopping GPU monitor.")
            return
        except (subprocess.CalledProcessError, OSError) as err:
            print(f"[GPU MONITOR] nvidia-smi query failed: {err}")
            continue

        stats = []
        for gpu_index, line in enumerate(output.strip().splitlines()):
            # Each line is "<util>, <used>, <total>" per the csv query above.
            fields = [field.strip() for field in line.split(",")]
            if len(fields) < 3:
                continue
            util, used, total = fields[:3]
            stats.append(f"GPU {gpu_index}: {util}% Util | {used}/{total} MB")
        print("\n[GPU MONITOR] " + " | ".join(stats) + "\n")

# --- 2. DATA LOAD ---
def load_dataset(test_size=0.10, seed=42):
    """Pair images with masks by file stem and split them into train/test.

    Files in IMAGE_DIR and MASK_DIR ending in ``.jpg``/``.png`` (matched
    case-insensitively) are paired on their name without extension. Files
    without a counterpart are skipped. Pairs are sorted by stem, so the split
    is reproducible for a given ``seed``.

    Args:
        test_size: Fraction of pairs placed in the "test" split.
        seed: Random seed for the split.

    Returns:
        The result of ``train_test_split``: "train" and "test" splits, each
        with "image" and "label" columns cast to ``datasets.Image``.

    Raises:
        ValueError: If no image has a matching mask.
    """
    extensions = (".jpg", ".png")

    def _index_by_stem(directory):
        # Map file stem -> file name. Iterating a sorted listing makes the
        # winner deterministic if two files share a stem (e.g. a.jpg and a.png).
        return {
            os.path.splitext(name)[0]: name
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(extensions)
        }

    img_map = _index_by_stem(IMAGE_DIR)
    mask_map = _index_by_stem(MASK_DIR)
    common_ids = sorted(img_map.keys() & mask_map.keys())
    if not common_ids:
        raise ValueError(
            f"No matching image/mask pairs found in {IMAGE_DIR!r} and {MASK_DIR!r}"
        )

    ds = Dataset.from_dict({
        "image": [os.path.join(IMAGE_DIR, img_map[i]) for i in common_ids],
        "label": [os.path.join(MASK_DIR, mask_map[i]) for i in common_ids],
    })
    ds = ds.cast_column("image", DSImage())
    ds = ds.cast_column("label", DSImage())
    return ds.train_test_split(test_size=test_size, seed=seed)

# Shared preprocessing for training, evaluation and the exported model:
# resize inputs to 1024x1024; all other settings come from MODEL_NAME's
# processor configuration.
processor = SegformerImageProcessor.from_pretrained(
    MODEL_NAME,
    size={"height": 1024, "width": 1024},
    do_resize=True,
)

def train_transforms(example_batch):
    """Convert a batch of PIL image/mask pairs into model-ready tensors.

    Images are converted to RGB. Masks are converted to grayscale and
    binarized, so every non-zero pixel becomes class 1 and zero stays class 0.
    Both are then passed through ``processor`` with PyTorch tensor output.
    """
    rgb_images = [image.convert("RGB") for image in example_batch["image"]]
    # NOTE(review): "> 0" treats any non-zero pixel as foreground; if masks
    # are lossy JPEGs, compression noise would also become foreground.
    binary_masks = [
        (np.array(mask.convert("L")) > 0).astype(np.uint8)
        for mask in example_batch["label"]
    ]
    return processor(rgb_images, binary_masks, return_tensors="pt")

# --- 4. DICE LOSS IMPLEMENTATION ---
class DiceLoss(nn.Module):
    """Soft Dice loss on the foreground class (channel 1) of segmentation logits.

    The loss is ``1 - dice``. Dice is computed over all pixels of the batch
    pooled together, between the softmax probability of class 1 and the
    target mask.

    Args:
        smooth: Constant added to numerator and denominator; keeps the ratio
            defined when prediction and target are both empty.
        ignore_index: Optional label value (e.g. 255) whose pixels are
            excluded from the loss. ``None`` (default) uses every pixel.
    """

    def __init__(self, smooth=1e-6, ignore_index=None):
        super().__init__()
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        """Return the scalar Dice loss.

        Args:
            logits: Raw class scores with classes on dim 1; channel 1 is the
                foreground. Presumably (batch, num_classes, H, W).
            targets: Integer mask whose element count matches one channel of
                ``logits``; 1 marks foreground.
        """
        # Foreground probability for every pixel, flattened across the batch.
        probs = torch.softmax(logits, dim=1)[:, 1].reshape(-1)
        targets = targets.reshape(-1)
        if self.ignore_index is not None:
            keep = targets != self.ignore_index
            probs = probs[keep]
            targets = targets[keep]
        targets = targets.float()

        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return 1 - dice

class DiceTrainer(Trainer):
    """``Trainer`` whose training/evaluation loss is ``DiceLoss`` on upsampled logits."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute Dice loss between upsampled model logits and the label masks.

        Args:
            model: The segmentation model being trained.
            inputs: Batch dict; must contain "labels" plus the model's inputs.
            return_outputs: If True, return ``(loss, outputs)`` instead of ``loss``.
            **kwargs: Extra Trainer arguments (e.g. ``num_items_in_batch``);
                unused because the Dice ratio is already normalized.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``.
        """
        labels = inputs.get("labels")
        # Forward everything except the labels. Given labels, the model computes
        # its own (cross-entropy) loss, which would be ignored here and only
        # cost an extra full-resolution upsample plus its autograd graph.
        # A new dict is built so the caller's ``inputs`` is left unmodified.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")

        # Logits come out at lower resolution than the masks; match the mask size.
        upsampled_logits = nn.functional.interpolate(
            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
        )

        loss = DiceLoss()(upsampled_logits, labels)
        return (loss, outputs) if return_outputs else loss

# --- 5. METRICS ---
metric = evaluate.load("mean_iou")


def compute_metrics(eval_pred, chunk_size=8):
    """Compute mean IoU for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(logits, labels)`` numpy pair from the Trainer. ``logits``
            are class scores of shape (N, num_classes, h, w) (4-D, as required
            by bilinear ``interpolate``). ``labels`` are presumably (N, H, W)
            integer masks.
        chunk_size: Number of samples upsampled at once. Upsampling the whole
            evaluation set in one call would materialize every full-resolution
            logit map simultaneously. Per-sample results do not depend on this.

    Returns:
        The ``mean_iou`` metric's outputs, with numpy arrays converted to lists
        so they can be logged.
    """
    with torch.no_grad():
        logits, labels = eval_pred
        # Defensive: take the logits if predictions arrive as a tuple of outputs.
        if isinstance(logits, tuple):
            logits = logits[0]

        target_size = labels.shape[-2:]
        chunk_size = max(1, chunk_size)
        chunk_predictions = []
        for start in range(0, len(logits), chunk_size):
            # fp16 evaluation can produce float16 logits; CPU half-precision
            # bilinear interpolation is not supported in every PyTorch version.
            chunk = torch.from_numpy(logits[start:start + chunk_size]).float()
            upsampled = nn.functional.interpolate(
                chunk, size=target_size, mode="bilinear", align_corners=False
            )
            chunk_predictions.append(upsampled.argmax(dim=1).numpy())
        predictions = np.concatenate(chunk_predictions, axis=0)

        metrics = metric.compute(
            predictions=predictions,
            references=labels,
            num_labels=2,
            ignore_index=255,
            reduce_labels=False,
        )
        return {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in metrics.items()}

# --- 6. MAIN ---
def main():
    """Fine-tune MODEL_NAME with Dice loss and save the best model.

    Loads the image/mask pairs and trains a two-class (background / rat)
    SegFormer with the hyperparameters from the configuration constants.
    Evaluation and checkpointing happen every 50 steps. Finally, the model
    held by the trainer (the best checkpoint by "mean_iou", since
    ``load_best_model_at_end`` is set) and the processor are written to
    FINAL_MODEL_DIR.
    """
    # Disable the Weights & Biases integration; report_to="none" below turns
    # off every other reporting integration.
    os.environ["WANDB_DISABLED"] = "true"

    ds = load_dataset()
    ds["train"].set_transform(train_transforms)
    # NOTE(review): the "test" split drives checkpoint selection, so its final
    # scores are not an unbiased held-out estimate.
    ds["test"].set_transform(train_transforms)

    model = SegformerForSemanticSegmentation.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        id2label={0: "background", 1: "rat"},
        label2id={"background": 0, "rat": 1},
        # Allow loading when checkpoint weight shapes differ from the 2-class model.
        ignore_mismatched_sizes=True,
    )

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        learning_rate=LEARNING_RATE,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUMULATION,
        fp16=True,
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="mean_iou",
        # train_transforms reads the raw "image"/"label" columns, so they must
        # not be dropped before the transform runs.
        remove_unused_columns=False,
        report_to="none",
    )

    trainer = DiceTrainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["test"],
        compute_metrics=compute_metrics,
    )

    print("--- TRAINING START: DICE LOSS BASELINE ---")
    trainer.train()

    print(f"--- SAVING TO {FINAL_MODEL_DIR} ---")
    trainer.save_model(FINAL_MODEL_DIR)
    processor.save_pretrained(FINAL_MODEL_DIR)

if __name__ == "__main__":
    # Poll GPU stats in a separate process so monitoring never blocks training.
    # It is started before main(), i.e. before any model or data setup.
    gpu_monitor = multiprocessing.Process(target=monitor_gpu, daemon=True)
    gpu_monitor.start()
    try:
        main()
    finally:
        # Daemon processes only die when the parent exits. In a long-lived
        # notebook kernel the monitor would keep printing after training, and
        # every re-run would start another one, so stop it explicitly.
        gpu_monitor.terminate()
        gpu_monitor.join()