# RT-DETR-X Training

Training RT-DETR-X with ResNet-101 backbone for WBC Classification on Raabin-WBC dataset.

## Model Details
- **Backbone**: ResNet-101
- **Training**: Pretrained weights (fine-tuning)
- **Dataset**: Raabin-WBC with 5 cell types

## 1. Setup and Imports

In [None]:
# %pip install -U ultralytics torch torchvision pillow tqdm scikit-learn seaborn timm

In [None]:
%matplotlib inline

import os
import json
import yaml
from datetime import datetime

import numpy as np
import torch

from sklearn.metrics import classification_report

# Import common training utilities
from training_utils import (
    create_sampled_dataset,
    create_full_dataset_config,
    train_model,
    evaluate_model,
    save_results,
    print_training_summary,
)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Configuration

In [None]:
# =============================================================================
# MODEL CONFIGURATION
# =============================================================================
MODEL_NAME = "RT-DETR-X"
BACKBONE = "ResNet-101"
IS_PRETRAINED = True  # Using pretrained weights

# Pretrained model file
MODEL_FILE = "rtdetr-x.pt"

# =============================================================================
# BASE DIRECTORY
# =============================================================================
NOTEBOOK_DIR = os.getcwd()
BASE_DIR = os.path.join(NOTEBOOK_DIR, "output")

# Dataset path (contains separate Train and val folders)
DATA_ROOT = r"C:\D drive\mydata\MSML\DataSets\Raabin_datsets_withlabels"

print(f"Notebook directory: {NOTEBOOK_DIR}")
print(f"Base directory: {BASE_DIR}")
print(f"Data root: {DATA_ROOT}")

# =============================================================================
# SAMPLING CONFIGURATION
# =============================================================================
USE_FULL_DATASET = True  # Set to True to use ALL images, False for sampling

# Sample sizes per class (only used when USE_FULL_DATASET=False)
TRAIN_SAMPLE_SIZE = 100   # Number of training samples per class
VAL_SAMPLE_SIZE = 20      # Number of validation samples per class

# =============================================================================
# CHECKPOINT CONFIGURATION (for resume training on full dataset)
# =============================================================================
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints", MODEL_NAME)
CHECKPOINT_MODEL_PATH = os.path.join(CHECKPOINT_DIR, "last.pt")
CHECKPOINT_META_PATH = os.path.join(CHECKPOINT_DIR, "training_meta.json")

# Create checkpoint directory
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Data paths (separate train and validation directories)
TRAIN_IMAGES_DIR = os.path.join(DATA_ROOT, "Train", "images")
TRAIN_LABELS_DIR = os.path.join(DATA_ROOT, "Train", "labels")
VAL_IMAGES_DIR = os.path.join(DATA_ROOT, "val", "images")
VAL_LABELS_DIR = os.path.join(DATA_ROOT, "val", "labels")

# For evaluation (uses training images by default)
IMAGES_DIR = TRAIN_IMAGES_DIR

# Output directories
os.makedirs(BASE_DIR, exist_ok=True)
MODEL_DIR = os.path.join(BASE_DIR, "models")
RESULTS_DIR = os.path.join(BASE_DIR, "results")
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Class definitions
CLASSES = {
    "Basophil": 0,
    "Eosinophil": 1,
    "Lymphocyte": 2,
    "Monocyte": 3,
    "Neutrophil": 4
}
ID2LABEL = {v: k for k, v in CLASSES.items()}
NUM_CLASSES = len(CLASSES)

print(f"\nUsing device: {DEVICE}")
if USE_FULL_DATASET:
    print(f"Dataset mode: FULL DATASET")
    print(f"Checkpoint directory: {CHECKPOINT_DIR}")
    # Check for existing checkpoint
    if os.path.exists(CHECKPOINT_MODEL_PATH) and os.path.exists(CHECKPOINT_META_PATH):
        with open(CHECKPOINT_META_PATH, 'r') as f:
            meta = json.load(f)
        print(f"  -> Found existing checkpoint: {meta['total_epochs']} epochs completed")
        print(f"  -> Training will RESUME from epoch {meta['total_epochs'] + 1}")
    else:
        print(f"  -> No checkpoint found. Training will start from scratch.")
else:
    print(f"Dataset mode: SAMPLED (Train: {TRAIN_SAMPLE_SIZE}/class, Val: {VAL_SAMPLE_SIZE}/class)")
    print(f"  -> Sampled mode: Always starts fresh (no resume)")
print(f"\nTraining data: {TRAIN_IMAGES_DIR}")
print(f"Validation data: {VAL_IMAGES_DIR}")
print(f"\nModel: {MODEL_NAME} ({BACKBONE})")
print(f"Training mode: {'Pretrained (fine-tuning)' if IS_PRETRAINED else 'From scratch'}")

## 3. Training Hyperparameters

In [3]:
# =============================================================================
# TRAINING HYPERPARAMETERS (PRETRAINED CONFIG)
# =============================================================================
# Fewer epochs needed since backbone is already trained

TRAINING_CONFIG = {
    "epochs": 1,            
    "imgsz": 640,
    "batch": 4,
    "lr0": 0.01,            
    "lrf": 0.0001,
    "momentum": 0.937,
    "weight_decay": 0.0005,
    "workers": 8,
    "patience": 15,
    "cos_lr": True,
    "warmup_epochs": 1,
    "warmup_momentum": 0.8,
    "warmup_bias_lr": 0.1,
}

print("Training Configuration (Pretrained Fine-tuning):")
print("="*60)
for k, v in TRAINING_CONFIG.items():
    print(f"  {k}: {v}")

Training Configuration (Pretrained Fine-tuning):
  epochs: 1
  imgsz: 640
  batch: 4
  lr0: 0.01
  lrf: 0.0001
  momentum: 0.937
  weight_decay: 0.0005
  workers: 8
  patience: 15
  cos_lr: True
  warmup_epochs: 1
  warmup_momentum: 0.8
  warmup_bias_lr: 0.1


## 4. Data Preparation

In [None]:
# create_sampled_dataset is imported from training_utils.py
# See training_utils.py for the implementation

In [None]:
# Create data configuration
if USE_FULL_DATASET:
    print("Using FULL DATASET\n")
    print(f"Training: {TRAIN_IMAGES_DIR}")
    print(f"Validation: {VAL_IMAGES_DIR}")
    
    DATA_YAML = create_full_dataset_config(DATA_ROOT, BASE_DIR, NUM_CLASSES, ID2LABEL)
    print(f"\nData config: {DATA_YAML}")
else:
    print(f"Creating SAMPLED dataset...")
    print(f"  Train samples: {TRAIN_SAMPLE_SIZE} per class")
    print(f"  Val samples: {VAL_SAMPLE_SIZE} per class\n")
    
    DATA_YAML = create_sampled_dataset(
        DATA_ROOT, 
        BASE_DIR, 
        CLASSES, 
        train_samples_per_class=TRAIN_SAMPLE_SIZE,
        val_samples_per_class=VAL_SAMPLE_SIZE,
        random_seed=42
    )

## 5. Training

In [None]:
# train_model is imported from training_utils.py
# See training_utils.py for the implementation

In [None]:
# Train the model
training_result = train_model(
    model_source=MODEL_FILE,
    model_name=MODEL_NAME,
    data_yaml=DATA_YAML,
    training_config=TRAINING_CONFIG,
    base_dir=BASE_DIR,
    use_full_dataset=USE_FULL_DATASET,
    checkpoint_dir=CHECKPOINT_DIR if USE_FULL_DATASET else None,
    default_warmup_epochs=1  # Pretrained model needs less warmup
)

print(f"\nTraining completed in {training_result['training_time']:.1f}s")
print(f"Best model saved to: {training_result['best_model_path']}")

if training_result['resumed']:
    print(f"\nResumed from epoch {training_result['previous_epochs'] + 1}")
print(f"Total epochs trained: {training_result['total_epochs']}")

## 6. Evaluation

In [None]:
# evaluate_model is imported from training_utils.py
# See training_utils.py for the implementation

In [9]:
# Evaluate the model
CONF_THRESH = 0.1
EVAL_PER_CLASS = 100

print(f"Evaluating: {MODEL_NAME}")
evaluation_result = evaluate_model(
    model_path=training_result["best_model_path"],
    images_dir=IMAGES_DIR,
    classes=CLASSES,
    id2label=ID2LABEL,
    conf_thresh=CONF_THRESH,
    eval_per_class=EVAL_PER_CLASS,
)

print(f"\nResults:")
print(f"  Accuracy: {evaluation_result['accuracy']:.4f}")
print(f"  Avg inference time: {evaluation_result['avg_inference_time']*1000:.2f}ms")
print(f"  No predictions: {evaluation_result['no_prediction_count']}/{evaluation_result['total_samples']}")

Evaluating: RT-DETR-X


                                                                                                                       


Results:
  Accuracy: 0.8160
  Avg inference time: 53.54ms
  No predictions: 0/500




In [10]:
# Print classification report
if evaluation_result["classification_report"] is not None:
    y_true = np.array(evaluation_result["y_true"])
    y_pred = np.array(evaluation_result["y_pred"])
    valid = y_pred != -1
    
    print(f"\n--- {MODEL_NAME} Classification Report ---")
    print(classification_report(
        y_true[valid],
        y_pred[valid],
        target_names=list(CLASSES.keys()),
        labels=list(range(NUM_CLASSES)),
        zero_division=0
    ))


--- RT-DETR-X Classification Report ---
              precision    recall  f1-score   support

    Basophil       0.65      1.00      0.79       100
  Eosinophil       0.94      0.81      0.87       100
  Lymphocyte       0.98      0.63      0.77       100
    Monocyte       0.76      0.91      0.83       100
  Neutrophil       0.96      0.73      0.83       100

    accuracy                           0.82       500
   macro avg       0.86      0.82      0.82       500
weighted avg       0.86      0.82      0.82       500



## 7. Save Results to Disk

In [None]:
# Save results to JSON
results_file = save_results(
    results_dir=RESULTS_DIR,
    model_name=MODEL_NAME,
    backbone=BACKBONE,
    is_pretrained=IS_PRETRAINED,
    training_result=training_result,
    evaluation_result=evaluation_result,
    training_config=TRAINING_CONFIG,
    classes=CLASSES
)

print(f"Results saved to: {results_file}")

In [None]:
# Print training summary
print_training_summary(
    model_name=MODEL_NAME,
    backbone=BACKBONE,
    training_result=training_result,
    evaluation_result=evaluation_result,
    training_config=TRAINING_CONFIG,
    checkpoint_model_path=CHECKPOINT_MODEL_PATH if USE_FULL_DATASET else None,
    results_file=results_file
)