{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 🔬 Mask2Former Iris Segmentation Training\n",
        "\n",
        "## 📋 Setup Checklist\n",
        "- [ ] GPU enabled (P100 or T4)\n",
        "- [ ] Internet enabled\n",
        "- [ ] Dataset added: `iris-segmentation-ubiris-v2`\n",
        "- [ ] Code added: `iris-segmentation-code`\n",
        "\n",
        "**Expected training time: 8-10 hours**\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# CELL 1: Check GPU and Environment\n",
        "# ============================================\n",
        "print(\"üîç Checking environment...\\n\")\n",
        "\n",
        "# Check GPU\n",
        "import subprocess\n",
        "result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)\n",
        "print(result.stdout)\n",
        "\n",
        "# Check PyTorch\n",
        "import torch\n",
        "print(f\"\\n‚úÖ PyTorch: {torch.__version__}\")\n",
        "print(f\"‚úÖ CUDA available: {torch.cuda.is_available()}\")\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    print(f\"‚úÖ GPU: {torch.cuda.get_device_name(0)}\")\n",
        "    memory = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
        "    print(f\"‚úÖ Memory: {memory:.1f} GB\")\n",
        "    \n",
        "    if memory >= 15:\n",
        "        print(f\"‚úÖ Can use batch_size = 8\")\n",
        "    else:\n",
        "        print(f\"‚ö†Ô∏è  Reduce batch_size to 4\")\n",
        "else:\n",
        "    print(\"‚ùå GPU NOT AVAILABLE - Check Settings ‚Üí Accelerator ‚Üí GPU!\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# CELL 2: Install Required Packages\n",
        "# ============================================\n",
        "print(\"üì¶ Installing dependencies...\\n\")\n",
        "\n",
        "# Install packages (quiet mode)\n",
        "!pip install -q transformers==4.57.3\n",
        "!pip install -q albumentations==2.0.8\n",
        "!pip install -q timm einops\n",
        "\n",
        "print(\"\\n✅ Checking installations...\")\n",
        "\n",
        "# Verify installations\n",
        "import transformers\n",
        "import albumentations\n",
        "import timm\n",
        "\n",
        "print(f\"✅ Transformers: {transformers.__version__}\")\n",
        "print(f\"✅ Albumentations: {albumentations.__version__}\")\n",
        "print(f\"✅ Timm: {timm.__version__}\")\n",
        "print(\"\\n✅ All dependencies installed!\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# CELL 3: Verify Datasets\n",
        "# ============================================\n",
        "from pathlib import Path\n",
        "\n",
        "print(\"üìÇ Checking datasets...\\n\")\n",
        "\n",
        "# List all inputs\n",
        "print(\"üìã Available datasets:\")\n",
        "!ls -la /kaggle/input/\n",
        "\n",
        "# ‚ö†Ô∏è IMPORTANT: Update these names to match YOUR dataset names!\n",
        "DATASET_NAME = 'iris-segmentation-ubiris-v2'  # ‚Üê CHANGE THIS\n",
        "CODE_NAME = 'iris-segmentation-code'          # ‚Üê CHANGE THIS\n",
        "\n",
        "print(f\"\\n\" + \"=\"*70)\n",
        "print(f\"Using dataset: {DATASET_NAME}\")\n",
        "print(f\"Using code: {CODE_NAME}\")\n",
        "print(\"=\"*70)\n",
        "\n",
        "# Check dataset\n",
        "dataset_path = Path(f'/kaggle/input/{DATASET_NAME}/dataset')\n",
        "\n",
        "print(f\"\\nüìä Dataset: {DATASET_NAME}\")\n",
        "print(f\"   Path: {dataset_path}\")\n",
        "print(f\"   Exists: {dataset_path.exists()}\")\n",
        "\n",
        "if dataset_path.exists():\n",
        "    images_dir = dataset_path / 'images'\n",
        "    masks_dir = dataset_path / 'masks'\n",
        "    \n",
        "    if images_dir.exists():\n",
        "        images = list(images_dir.glob('*'))\n",
        "        print(f\"   ✅ Images: {len(images)} files\")\n",
        "    else:\n",
        "        print(f\"   ❌ Images dir not found at {images_dir}\")\n",
        "    \n",
        "    if masks_dir.exists():\n",
        "        masks = list(masks_dir.glob('*'))\n",
        "        print(f\"   ✅ Masks: {len(masks)} files\")\n",
        "    else:\n",
        "        print(f\"   ❌ Masks dir not found at {masks_dir}\")\n",
        "else:\n",
        "    print(\"   ❌ Dataset not found!\")\n",
        "    print(\"\\n💡 Update DATASET_NAME in this cell to match your dataset name\")\n",
        "    print(f\"   Available: {[d.name for d in Path('/kaggle/input/').iterdir()]}\")\n",
        "\n",
        "# Check code\n",
        "code_path = Path(f'/kaggle/input/{CODE_NAME}')\n",
        "\n",
        "print(f\"\\nüì¶ Code: {CODE_NAME}\")\n",
        "print(f\"   Path: {code_path}\")\n",
        "print(f\"   Exists: {code_path.exists()}\")\n",
        "\n",
        "if code_path.exists():\n",
        "    print(\"   ✅ Code files found\")\n",
        "    !ls -la /kaggle/input/{CODE_NAME}/\n",
        "else:\n",
        "    print(\"   ❌ Code not found!\")\n",
        "    print(\"\\n💡 Update CODE_NAME in this cell to match your code dataset name\")"
      ]
    },
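    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Optional pairing check.** Matching file counts in Cell 3 do not guarantee that every image has its own mask. The cell below is a minimal sketch that is not part of the original workflow; it assumes an image and its mask share the same file stem (only the extension may differ), so adjust the matching if your masks are named differently."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# OPTIONAL: Check Image/Mask Pairing\n",
        "# ============================================\n",
        "from pathlib import Path\n",
        "\n",
        "DATASET_NAME = 'iris-segmentation-ubiris-v2'  # ← keep in sync with Cell 3\n",
        "dataset_path = Path(f'/kaggle/input/{DATASET_NAME}/dataset')\n",
        "\n",
        "# Assumption: a mask has the same stem as its image (e.g. C100_S1_I1.*)\n",
        "image_stems = {p.stem for p in (dataset_path / 'images').glob('*') if p.is_file()}\n",
        "mask_stems = {p.stem for p in (dataset_path / 'masks').glob('*') if p.is_file()}\n",
        "\n",
        "missing_masks = sorted(image_stems - mask_stems)\n",
        "orphan_masks = sorted(mask_stems - image_stems)\n",
        "\n",
        "print(f\"✅ Paired: {len(image_stems & mask_stems)}\")\n",
        "print(f\"⚠️  Images without a mask: {len(missing_masks)} {missing_masks[:5]}\")\n",
        "print(f\"⚠️  Masks without an image: {len(orphan_masks)} {orphan_masks[:5]}\")"
      ]
    },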
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# CELL 4: Extract Code to Working Directory\n",
        "# ============================================\n",
        "import shutil\n",
        "from pathlib import Path\n",
        "import os\n",
        "\n",
        "print(\"üì• Extracting code...\\n\")\n",
        "\n",
        "# Use same CODE_NAME from Cell 3\n",
        "CODE_NAME = 'iris-segmentation-code'  # ← Make sure this matches Cell 3\n",
        "\n",
        "code_path = Path(f'/kaggle/input/{CODE_NAME}')\n",
        "dest_path = Path('/kaggle/working/code')\n",
        "\n",
        "# Remove if exists\n",
        "if dest_path.exists():\n",
        "    print(\"   Removing old code...\")\n",
        "    shutil.rmtree(dest_path)\n",
        "\n",
        "# Check structure and copy\n",
        "if (code_path / 'kaggle_code').exists():\n",
        "    # Code uploaded in kaggle_code/ subfolder\n",
        "    print(\"   Copying from kaggle_code/ subfolder...\")\n",
        "    shutil.copytree(code_path / 'kaggle_code', dest_path)\n",
        "elif (code_path / 'src').exists():\n",
        "    # Code uploaded directly\n",
        "    print(\"   Copying from root...\")\n",
        "    shutil.copytree(code_path, dest_path)\n",
        "else:\n",
        "    print(\"   ❌ Unexpected structure. Files in code dataset:\")\n",
        "    !ls -la /kaggle/input/{CODE_NAME}/\n",
        "    raise FileNotFoundError(\"Cannot find code structure (expected kaggle_code/ or src/). Check your code dataset.\")\n",
        "\n",
        "print(f\"✅ Code extracted to: {dest_path}\")\n",
        "\n",
        "# Change to code directory\n",
        "os.chdir(dest_path)\n",
        "print(f\"✅ Changed directory to: {os.getcwd()}\")\n",
        "\n",
        "# List files\n",
        "print(\"\\n📋 Code structure:\")\n",
        "!ls -la\n",
        "print(\"\\n📂 Source files:\")\n",
        "!ls -la src/"
      ]
    },
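    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Optional entry-point check.** Later cells call `train_mask2former.py`, `class_weights_util.py`, and `configs/mask2former_config_kaggle.json` relative to the working directory set above. This sketch, an addition rather than the author's code, fails early if any of them is missing from the extracted code instead of failing hours later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# OPTIONAL: Verify Files Used by Later Cells\n",
        "# ============================================\n",
        "from pathlib import Path\n",
        "\n",
        "# Paths referenced by Cells 5-7, relative to /kaggle/working/code\n",
        "required = [\n",
        "    'train_mask2former.py',\n",
        "    'class_weights_util.py',\n",
        "    'configs/mask2former_config_kaggle.json',\n",
        "]\n",
        "\n",
        "missing = [p for p in required if not Path(p).exists()]\n",
        "if missing:\n",
        "    raise FileNotFoundError(f\"Missing from {Path.cwd()}: {missing}\")\n",
        "print(\"✅ All required files present\")"
      ]
    },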
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# CELL 5: Update Config with Correct Paths\n",
        "# ============================================\n",
        "import json\n",
        "from pathlib import Path\n",
        "\n",
        "print(\"üîß Updating config...\\n\")\n",
        "\n",
        "# Load config\n",
        "config_path = Path('configs/mask2former_config_kaggle.json')\n",
        "\n",
        "if not config_path.exists():\n",
        "    print(f\"❌ Config not found: {config_path}\")\n",
        "    print(\"Available files in configs/:\")\n",
        "    !ls -la configs/\n",
        "    raise FileNotFoundError(f\"Config file not found: {config_path}\")\n",
        "\n",
        "with open(config_path, 'r') as f:\n",
        "    config = json.load(f)\n",
        "\n",
        "# ⚠️ IMPORTANT: Update dataset name to match YOUR dataset\n",
        "DATASET_NAME = 'iris-segmentation-ubiris-v2'  # ← CHANGE THIS to match Cell 3\n",
        "\n",
        "# Update dataset paths\n",
        "config['data']['dataset_dir'] = f'/kaggle/input/{DATASET_NAME}/dataset'\n",
        "config['data']['dataset_root'] = f'/kaggle/input/{DATASET_NAME}/dataset'\n",
        "config['data']['images_dir'] = f'/kaggle/input/{DATASET_NAME}/dataset/images'\n",
        "config['data']['masks_dir'] = f'/kaggle/input/{DATASET_NAME}/dataset/masks'\n",
        "\n",
        "# Update output paths\n",
        "config['output_dir'] = '/kaggle/working/outputs/mask2former_iris'\n",
        "config['checkpointing']['save_dir'] = '/kaggle/working/outputs/mask2former_iris'\n",
        "config['visualization']['save_dir'] = '/kaggle/working/outputs/mask2former_iris/visualizations'\n",
        "\n",
        "# Disable WandB logging unless you have a WandB API key\n",
        "config['logging']['use_wandb'] = False\n",
        "\n",
        "# Save updated config\n",
        "with open(config_path, 'w') as f:\n",
        "    json.dump(config, f, indent=2)\n",
        "\n",
        "print(\"✅ Config updated!\")\n",
        "print(\"\\n📋 Key settings:\")\n",
        "print(f\"   Dataset: {config['data']['dataset_root']}\")\n",
        "print(f\"   Batch size: {config['data']['batch_size']}\")\n",
        "print(f\"   Epochs: {config['training']['num_epochs']}\")\n",
        "print(f\"   Output: {config['output_dir']}\")\n",
        "print(f\"   Mixed precision: {config['training']['mixed_precision']}\")\n",
        "\n",
        "# Verify dataset exists\n",
        "dataset_path = Path(config['data']['dataset_root'])\n",
        "if dataset_path.exists():\n",
        "    images = list((dataset_path / 'images').glob('*'))\n",
        "    masks = list((dataset_path / 'masks').glob('*'))\n",
        "    print(f\"\\n✅ Dataset verified: {len(images)} images, {len(masks)} masks\")\n",
        "    print(\"\\n🎯 Ready to train!\")\n",
        "else:\n",
        "    print(f\"\\n❌ Dataset not found at: {dataset_path}\")\n",
        "    print(\"\\n💡 Update DATASET_NAME in this cell to match your dataset!\")\n",
        "    print(\"   Available datasets:\")\n",
        "    for d in Path('/kaggle/input/').iterdir():\n",
        "        print(f\"   • {d.name}\")"
      ]
    },
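    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Optional batch-size override.** The troubleshooting notes recommend a smaller batch size when memory runs out, but Cell 5 only prints `batch_size`. The sketch below applies the same rule of thumb as Cell 1 (8 on a GPU with at least 15 GB, otherwise 4) and rewrites the config. It assumes the working directory is still `/kaggle/working/code` from Cell 4 and that this threshold suits your model; it is a hedged helper, not part of the original workflow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# OPTIONAL: Set Batch Size from GPU Memory\n",
        "# ============================================\n",
        "import json\n",
        "from pathlib import Path\n",
        "import torch\n",
        "\n",
        "config_path = Path('configs/mask2former_config_kaggle.json')\n",
        "with open(config_path) as f:\n",
        "    config = json.load(f)\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
        "    # Same threshold as Cell 1: >= 15 GB -> 8, otherwise 4\n",
        "    config['data']['batch_size'] = 8 if memory_gb >= 15 else 4\n",
        "    with open(config_path, 'w') as f:\n",
        "        json.dump(config, f, indent=2)\n",
        "    print(f\"✅ batch_size = {config['data']['batch_size']} ({memory_gb:.1f} GB GPU)\")\n",
        "else:\n",
        "    print(\"❌ No GPU detected; batch_size left unchanged\")"
      ]
    },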
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# CELL 6: Calculate Class Weights (Optional)\n",
        "# ============================================\n",
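        "print(\"⚖️ Calculating class weights...\")\n",
        "print(\"This may take 5-10 minutes...\\n\")\n",
        "print(\"💡 You can skip this cell if class_weights are already in the config\\n\")\n",
        "\n",
        "# Run class weights utility\n",
        "!python class_weights_util.py\n",
        "\n",
        "# `!` commands do not raise on failure; IPython stores the exit status in _exit_code\n",
        "if _exit_code == 0:\n",
        "    print(\"\\n✅ Class weights calculated!\")\n",
        "    print(\"Weights saved to: class_weights.pt\")\n",
        "else:\n",
        "    print(f\"\\n❌ class_weights_util.py exited with code {_exit_code}; see the output above\")"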
      ]
    },
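    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Optional: inspect the saved weights.** Cell 6 reports that the weights go to `class_weights.pt`. This sketch loads that file and prints it so you can sanity-check the values before training. It assumes the file lands in the current working directory (`/kaggle/working/code`) and contains something `torch.load` can read, such as a tensor of per-class weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# OPTIONAL: Inspect Class Weights\n",
        "# ============================================\n",
        "from pathlib import Path\n",
        "import torch\n",
        "\n",
        "# Assumed location: the working directory set in Cell 4\n",
        "weights_path = Path('class_weights.pt')\n",
        "\n",
        "if weights_path.exists():\n",
        "    weights = torch.load(weights_path, map_location='cpu', weights_only=False)\n",
        "    print(f\"Type: {type(weights).__name__}\")\n",
        "    print(weights)\n",
        "else:\n",
        "    print(f\"⚠️  {weights_path.resolve()} not found; check the Cell 6 output\")"
      ]
    },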
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# CELL 7: START TRAINING üöÄ\n",
        "# ============================================\n",
        "print(\"=\" * 70)\n",
        "print(\"üöÄ STARTING MASK2FORMER TRAINING\")\n",
        "print(\"=\" * 70)\n",
        "print(\"\\nüìä Expected timeline:\")\n",
        "print(\"   ‚Ä¢ Initialization: 5-10 minutes\")\n",
        "print(\"   ‚Ä¢ Per epoch: ~6-8 minutes (P100/T4)\")\n",
        "print(\"   ‚Ä¢ Total: 8-10 hours for 100 epochs\")\n",
        "print(\"   ‚Ä¢ Checkpoints saved every 20 epochs\")\n",
        "print(\"\\nüéØ Target metrics:\")\n",
        "print(\"   ‚Ä¢ Val mIoU: ‚â• 0.90\")\n",
        "print(\"   ‚Ä¢ Val Dice: ‚â• 0.93\")\n",
        "print(\"\\n‚ö†Ô∏è  Keep this tab open! Training will stop if you close it.\")\n",
        "print(\"\\n\" + \"=\" * 70 + \"\\n\")\n",
        "\n",
        "# Start training\n",
        "!python train_mask2former.py --config configs/mask2former_config_kaggle.json\n",
        "\n",
        "print(\"\\n\" + \"=\" * 70)\n",
        "print(\"‚úÖ TRAINING COMPLETED!\")\n",
        "print(\"=\" * 70)\n",
        "print(\"\\nüìä Check Cell 8 for results\")\n",
        "print(\"üíæ Run Cell 9 to download checkpoint\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# CELL 8: Check Training Results\n",
        "# ============================================\n",
        "import torch\n",
        "from pathlib import Path\n",
        "\n",
        "print(\"üìä Checking training results...\\n\")\n",
        "\n",
        "checkpoint_dir = Path('/kaggle/working/outputs/mask2former_iris/checkpoints')\n",
        "\n",
        "if not checkpoint_dir.exists():\n",
        "    print(\"❌ No checkpoints found. Training may have failed.\")\n",
        "    print(\"   Check training logs in Cell 7\")\n",
        "else:\n",
        "    checkpoints = list(checkpoint_dir.glob('*.pt'))\n",
        "    print(f\"✅ Found {len(checkpoints)} checkpoints\\n\")\n",
        "    \n",
        "    # List all checkpoints\n",
        "    print(\"📋 Available checkpoints:\")\n",
        "    for ckpt in sorted(checkpoints):\n",
        "        size_mb = ckpt.stat().st_size / (1024 * 1024)\n",
        "        print(f\"   • {ckpt.name} ({size_mb:.1f} MB)\")\n",
        "    \n",
        "    # Check best checkpoint\n",
        "    best_ckpt = checkpoint_dir / 'best.pt'\n",
        "    \n",
        "    if best_ckpt.exists():\n",
        "        print(f\"\\n\" + \"=\"*70)\n",
        "        print(f\"üèÜ BEST CHECKPOINT FOUND\")\n",
        "        print(\"=\"*70)\n",
        "        \n",
        "        # Load and show metrics\n",
        "        ckpt = torch.load(best_ckpt, map_location='cpu', weights_only=False)\n",
        "        \n",
        "        print(f\"\\nüìà Final Results:\")\n",
        "        print(f\"   Epoch: {ckpt.get('epoch', 'N/A')}\")\n",
        "        \n",
        "        metrics = ckpt.get('metrics', {})\n",
        "        mean_iou = metrics.get('mean_iou', 0)\n",
        "        mean_dice = metrics.get('mean_dice', 0)\n",
        "        iris_iou = metrics.get('class_1_iou', 0)\n",
        "        iris_dice = metrics.get('class_1_dice', 0)\n",
        "        \n",
        "        print(f\"\\nüéØ Key Metrics:\")\n",
        "        print(f\"   Val mIoU:    {mean_iou:.4f} {'‚úÖ' if mean_iou >= 0.90 else '‚ö†Ô∏è'}  (target ‚â• 0.90)\")\n",
        "        print(f\"   Val Dice:    {mean_dice:.4f} {'‚úÖ' if mean_dice >= 0.93 else '‚ö†Ô∏è'}  (target ‚â• 0.93)\")\n",
        "        print(f\"   Iris IoU:    {iris_iou:.4f}\")\n",
        "        print(f\"   Iris Dice:   {iris_dice:.4f}\")\n",
        "        \n",
        "        if 'boundary_f1' in metrics:\n",
        "            print(f\"   Boundary F1: {metrics['boundary_f1']:.4f}\")\n",
        "        \n",
        "        # Overall assessment\n",
        "        print(f\"\\nüìä Overall Assessment:\")\n",
        "        if mean_iou >= 0.90 and mean_dice >= 0.93:\n",
        "            print(\"   ‚úÖ EXCELLENT - Both targets achieved!\")\n",
        "        elif mean_iou >= 0.85 and mean_dice >= 0.90:\n",
        "            print(\"   üü° GOOD - Close to targets\")\n",
        "        else:\n",
        "            print(\"   ‚ö†Ô∏è  NEEDS IMPROVEMENT - Consider training longer\")\n",
        "        \n",
        "        # File info\n",
        "        size_mb = best_ckpt.stat().st_size / (1024 * 1024)\n",
        "        print(f\"\\nüíæ Checkpoint:\")\n",
        "        print(f\"   File: {best_ckpt.name}\")\n",
        "        print(f\"   Size: {size_mb:.1f} MB\")\n",
        "        print(f\"   Path: {best_ckpt}\")\n",
        "        \n",
        "        print(f\"\\n\" + \"=\"*70)\n",
        "        print(\"✅ Ready to download! Run Cell 9\")\n",
        "        print(\"=\"*70)\n",
        "    else:\n",
        "        print(\"\\n⚠️  best.pt not found\")\n",
        "        print(\"   Training may not have completed;\")\n",
        "        print(\"   the most recent checkpoint in the list above is the fallback\")"
      ]
    },
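    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Optional fallback when `best.pt` is missing.** If training was interrupted, for example by a session timeout, the most recently written checkpoint is usually the one to inspect. The sketch below picks it by file modification time and prints the same metric keys Cell 8 reads; it assumes every `.pt` file in the directory is a dictionary shaped like `best.pt`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# OPTIONAL: Inspect the Most Recent Checkpoint\n",
        "# ============================================\n",
        "import torch\n",
        "from pathlib import Path\n",
        "\n",
        "checkpoint_dir = Path('/kaggle/working/outputs/mask2former_iris/checkpoints')\n",
        "# Oldest first, so the last entry is the most recently written file\n",
        "checkpoints = sorted(checkpoint_dir.glob('*.pt'), key=lambda p: p.stat().st_mtime)\n",
        "\n",
        "if checkpoints:\n",
        "    latest = checkpoints[-1]\n",
        "    ckpt = torch.load(latest, map_location='cpu', weights_only=False)\n",
        "    metrics = ckpt.get('metrics', {})\n",
        "    print(f\"📋 Most recent checkpoint: {latest.name}\")\n",
        "    print(f\"   Epoch: {ckpt.get('epoch', 'N/A')}\")\n",
        "    for key in ('mean_iou', 'mean_dice', 'class_1_iou', 'class_1_dice'):\n",
        "        if key in metrics:\n",
        "            print(f\"   {key}: {metrics[key]:.4f}\")\n",
        "else:\n",
        "    print(f\"❌ No .pt files found in {checkpoint_dir}\")"
      ]
    },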
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# CELL 9: Download Checkpoint ⬇️\n",
        "# ============================================\n",
        "from IPython.display import FileLink, display\n",
        "import os\n",
        "import shutil\n",
        "from pathlib import Path\n",
        "\n",
        "print(\"📦 Preparing downloads...\\n\")\n",
        "\n",
        "# FileLink resolves paths relative to the notebook root (/kaggle/working),\n",
        "# so undo the chdir from Cell 4 and pass relative paths\n",
        "WORKING_DIR = Path('/kaggle/working')\n",
        "os.chdir(WORKING_DIR)\n",
        "\n",
        "def show_link(path):\n",
        "    display(FileLink(str(Path(path).relative_to(WORKING_DIR))))\n",
        "\n",
        "# Best checkpoint\n",
        "best_ckpt = WORKING_DIR / 'outputs/mask2former_iris/checkpoints/best.pt'\n",
        "\n",
        "if best_ckpt.exists():\n",
        "    size_mb = best_ckpt.stat().st_size / (1024 * 1024)\n",
        "    print(\"✅ Best checkpoint ready!\")\n",
        "    print(f\"   Size: {size_mb:.1f} MB\")\n",
        "    print(f\"   File: {best_ckpt.name}\")\n",
        "    print(\"\\n⬇️  Click the link below to download best.pt:\")\n",
        "    show_link(best_ckpt)\n",
        "    \n",
        "    # Also prepare a compressed archive of all outputs\n",
        "    print(\"\\n📦 Compressing all results (checkpoints + visualizations)...\")\n",
        "    output_zip = WORKING_DIR / 'mask2former_results'\n",
        "    \n",
        "    try:\n",
        "        # make_archive returns the path of the created .zip file\n",
        "        zip_path = Path(shutil.make_archive(str(output_zip), 'zip', str(WORKING_DIR / 'outputs')))\n",
        "        zip_size = zip_path.stat().st_size / (1024 * 1024)\n",
        "        print(f\"✅ All results compressed: {zip_size:.1f} MB\")\n",
        "        print(\"\\n⬇️  Click the link below to download the full results:\")\n",
        "        show_link(zip_path)\n",
        "    except Exception as e:\n",
        "        print(f\"⚠️  Could not create zip: {e}\")\n",
        "        print(\"   Download best.pt above instead\")\n",
        "    \n",
        "    print(\"\\n\" + \"=\"*70)\n",
        "    print(\"💡 IMPORTANT: Download files before closing the notebook!\")\n",
        "    print(\"=\"*70)\n",
        "    print(\"\\n📋 What to download:\")\n",
        "    print(\"   1. best.pt - Main checkpoint (for inference)\")\n",
        "    print(\"   2. mask2former_results.zip - Full results (optional)\")\n",
        "    \n",
        "    print(\"\\n🎯 Next steps:\")\n",
        "    print(\"   1. Download best.pt\")\n",
        "    print(\"   2. Copy to local project: outputs/mask2former_iris/checkpoints/\")\n",
        "    print(\"   3. Run inference: python infer_mask2former.py --checkpoint best.pt\")\n",
        "    \n",
        "else:\n",
        "    print(\"❌ Checkpoint not found at:\", best_ckpt)\n",
        "    print(\"\\n💡 Check if training completed successfully in Cell 7\")\n",
        "    print(\"   Or look for other checkpoints in Cell 8\")\n",
        "    \n",
        "    # Try to find any checkpoint\n",
        "    checkpoint_dir = best_ckpt.parent\n",
        "    if checkpoint_dir.exists():\n",
        "        checkpoints = sorted(checkpoint_dir.glob('*.pt'))\n",
        "        if checkpoints:\n",
        "            print(\"\\n📋 Found these checkpoints instead:\")\n",
        "            for ckpt in checkpoints:\n",
        "                print(f\"   • {ckpt.name}\")\n",
        "                show_link(ckpt)"
      ]
    },
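    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Optional integrity check.** A checksum confirms that the file you downloaded matches the one on Kaggle. This sketch prints the SHA-256 digest of `best.pt` using only the Python standard library; compare it with the digest of your local copy, e.g. `sha256sum best.pt` on Linux or `shasum -a 256 best.pt` on macOS."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ============================================\n",
        "# OPTIONAL: Checksum for Download Verification\n",
        "# ============================================\n",
        "import hashlib\n",
        "from pathlib import Path\n",
        "\n",
        "def sha256sum(path, chunk_size=1 << 20):\n",
        "    \"\"\"Hash the file in 1 MiB chunks so large checkpoints are not read into memory at once.\"\"\"\n",
        "    digest = hashlib.sha256()\n",
        "    with open(path, 'rb') as f:\n",
        "        for chunk in iter(lambda: f.read(chunk_size), b''):\n",
        "            digest.update(chunk)\n",
        "    return digest.hexdigest()\n",
        "\n",
        "best_ckpt = Path('/kaggle/working/outputs/mask2former_iris/checkpoints/best.pt')\n",
        "if best_ckpt.exists():\n",
        "    print(f\"SHA-256 of {best_ckpt.name}: {sha256sum(best_ckpt)}\")\n",
        "else:\n",
        "    print(f\"❌ {best_ckpt} not found\")"
      ]
    },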
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "---\n",
        "\n",
        "## 🎉 Training Complete!\n",
        "\n",
        "### What to do next:\n",
        "\n",
        "1. ✅ Download `best.pt` from Cell 9\n",
        "2. ✅ Copy to local: `outputs/mask2former_iris/checkpoints/best.pt`\n",
        "3. ✅ Run inference locally:\n",
        "   ```bash\n",
        "   python infer_mask2former.py \\\n",
        "       --checkpoint outputs/mask2former_iris/checkpoints/best.pt \\\n",
        "       --image dataset/images/C100_S1_I1.tiff \\\n",
        "       --output results/\n",
        "   ```\n",
        "\n",
        "### Troubleshooting:\n",
        "\n",
        "- **No checkpoint?** → Check the Cell 7 logs for errors\n",
        "- **Low metrics?** → Increase `num_epochs` in the config or tune other hyperparameters\n",
        "- **Out of memory?** → Lower `config['data']['batch_size']` (e.g. to 4) in Cell 5 before the config is saved\n",
        "\n",
        "### Resources:\n",
        "\n",
        "- [Mask2Former Paper](https://arxiv.org/abs/2112.01527)\n",
        "- [UBIRIS Dataset](http://iris.di.ubi.pt/ubiris2.html)\n",
        "- [Transformers Docs](https://huggingface.co/docs/transformers)\n",
        "\n",
        "---\n",
        "\n",
        "**Good luck with your iris segmentation project! üöÄ**"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}