{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🔗 04 - CT Registration Analysis\n",
    "\n",
    "This notebook implements image registration on CT slices to align sequential or distorted frames.\n",
    "\n",
    "**Techniques**:\n",
    "- Rigid registration (translation + rotation)\n",
    "- Non-rigid B-spline deformation\n",
    "- Visual comparison and quantitative evaluation (MSE, SSIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 📦 Imports\n",
    "import SimpleITK as sitk\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from skimage.metrics import structural_similarity as ssim, mean_squared_error as mse\n",
    "from PIL import Image\n",
    "\n",
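    "def load_image(path):\n",
    "    # Load a PNG as 8-bit grayscale, then convert to float32, because\n",
    "    # SimpleITK's ImageRegistrationMethod only accepts real-valued pixel types\n",
    "    return sitk.GetImageFromArray(np.asarray(Image.open(path).convert('L'), dtype=np.float32))"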
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📂 Load Moving and Fixed Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed = load_image(\"../data/raw/sample_ct_slice.png\")\n",
    "moving = load_image(\"../data/artifact/sample_ct_artifact_slice.png\")\n",
    "\n",
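    "# Side-by-side view (sitk.Show() would need an external viewer such as Fiji)\n",
    "fig, axs = plt.subplots(1, 2, figsize=(8, 4))\n",
    "for ax, img, title in zip(axs, [fixed, moving], ['Fixed', 'Moving']):\n",
    "    ax.imshow(sitk.GetArrayViewFromImage(img), cmap='gray')\n",
    "    ax.set_title(title)\n",
    "    ax.axis('off')\n",
    "plt.show()"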
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ⚙️ Rigid Registration (Translation + Rotation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
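    "# Initialize a rotation + translation (Euler 2D) transform that aligns the image centers\n",
    "initial_transform = sitk.CenteredTransformInitializer(\n",
    "    fixed, moving, sitk.Euler2DTransform(),\n",
    "    sitk.CenteredTransformInitializerFilter.GEOMETRY)\n",
    "\n",
    "registration_method = sitk.ImageRegistrationMethod()\n",
    "registration_method.SetMetricAsMeanSquares()  # assumes comparable intensities in both slices\n",
    "registration_method.SetInterpolator(sitk.sitkLinear)\n",
    "registration_method.SetOptimizerAsRegularStepGradientDescent(learningRate=1.0, minStep=0.01, numberOfIterations=200)\n",
    "# Rotation is in radians and translation in pixels; rescale so one step moves both comparably\n",
    "registration_method.SetOptimizerScalesFromPhysicalShift()\n",
    "registration_method.SetInitialTransform(initial_transform, inPlace=False)\n",
    "\n",
    "rigid_transform = registration_method.Execute(fixed, moving)\n",
    "# Resample the moving image onto the fixed image grid using the optimized transform\n",
    "moved_rigid = sitk.Resample(moving, fixed, rigid_transform, sitk.sitkLinear, 0.0, moving.GetPixelID())"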
   ]
  },
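  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a quick diagnostic for the rigid stage, written as a minimal sketch. It reports why the optimizer stopped, how many iterations it ran, the final metric value, and the optimized parameters. It assumes `registration_method` still holds the rigid setup, so run it before the B-spline cell reassigns that name. It also assumes the returned transform keeps the `Euler2DTransform` parameterization, ordered as (angle in radians, x translation, y translation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: inspect how the rigid optimization ended (run before the B-spline cell)\n",
    "print('Stop condition:', registration_method.GetOptimizerStopConditionDescription())\n",
    "print('Iterations:', registration_method.GetOptimizerIteration())\n",
    "print('Final mean-squares metric: {:.4f}'.format(registration_method.GetMetricValue()))\n",
    "\n",
    "# Assumed Euler2D ordering: (angle [rad], tx, ty); translation is in pixels because spacing is 1\n",
    "print('Optimized parameters:', rigid_transform.GetParameters())"
   ]
  },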
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🧠 Non-rigid (B-spline) Registration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize a B-spline transform with 8 mesh cells per axis over the fixed image domain\n",
    "transform_domain_mesh_size = [8] * fixed.GetDimension()\n",
    "bspline_transform = sitk.BSplineTransformInitializer(fixed, transform_domain_mesh_size)\n",
    "\n",
    "registration_method = sitk.ImageRegistrationMethod()\n",
    "# Mattes mutual information tolerates intensity differences between the two slices\n",
    "registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)\n",
    "registration_method.SetOptimizerAsLBFGSB(gradientConvergenceTolerance=1e-5,\n",
    "                                        numberOfIterations=100,\n",
    "                                        maximumNumberOfCorrections=5,\n",
    "                                        maximumNumberOfFunctionEvaluations=1000)\n",
    "registration_method.SetInterpolator(sitk.sitkLinear)\n",
    "registration_method.SetInitialTransform(bspline_transform, inPlace=False)\n",
    "\n",
    "# Register the rigidly aligned image, so this stage only models residual local deformation\n",
    "bspline_out = registration_method.Execute(fixed, moved_rigid)\n",
    "moved_bspline = sitk.Resample(moved_rigid, fixed, bspline_out, sitk.sitkLinear, 0.0, fixed.GetPixelID())"
   ]
  },
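  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the B-spline stage actually moves pixels, you can sample the optimized transform on the fixed image grid as a dense displacement field. This is a minimal sketch that assumes `bspline_out` and `fixed` from the cell above. Large magnitudes flag regions where the rigid alignment left residual local deformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample the B-spline transform on the fixed image grid as a dense displacement field\n",
    "to_field = sitk.TransformToDisplacementFieldFilter()\n",
    "to_field.SetReferenceImage(fixed)\n",
    "displacement = to_field.Execute(bspline_out)\n",
    "\n",
    "# Array shape is (rows, cols, 2): per-pixel x/y displacement in physical units (pixels here)\n",
    "disp_np = sitk.GetArrayFromImage(displacement)\n",
    "magnitude = np.linalg.norm(disp_np, axis=-1)\n",
    "\n",
    "plt.figure(figsize=(5, 4))\n",
    "plt.imshow(magnitude, cmap='magma')\n",
    "plt.colorbar(label='Displacement magnitude (px)')\n",
    "plt.title('B-spline displacement')\n",
    "plt.axis('off')\n",
    "plt.show()\n",
    "print('Max displacement: {:.2f} px'.format(magnitude.max()))"
   ]
  },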
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🖼️ Compare Registration Outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sitk_to_np(img):\n",
    "    return sitk.GetArrayFromImage(img)\n",
    "\n",
    "imgs = [fixed, moving, moved_rigid, moved_bspline]\n",
    "titles = [\"Fixed\", \"Moving (unregistered)\", \"Rigid\", \"Non-Rigid\"]\n",
    "fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n",
    "for ax, img, title in zip(axs, imgs, titles):\n",
    "    ax.imshow(sitk_to_np(img), cmap='gray')\n",
    "    ax.set_title(title)\n",
    "    ax.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
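  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Absolute difference images show residual misalignment more clearly than side-by-side panels. Well-aligned regions go dark, while misregistered edges stay bright. This is a minimal sketch that reuses the variables above. The unregistered panel resamples `moving` onto the fixed grid with an identity transform so that all arrays share one shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast to float before subtracting so integer images cannot wrap around\n",
    "fixed_f = sitk.GetArrayFromImage(fixed).astype(np.float32)\n",
    "candidates = {\n",
    "    'Unregistered': sitk.Resample(moving, fixed),  # identity transform onto the fixed grid\n",
    "    'Rigid': moved_rigid,\n",
    "    'Non-Rigid': moved_bspline,\n",
    "}\n",
    "fig, axs = plt.subplots(1, 3, figsize=(12, 4))\n",
    "for ax, (title, img) in zip(axs, candidates.items()):\n",
    "    diff = np.abs(fixed_f - sitk.GetArrayFromImage(img).astype(np.float32))\n",
    "    ax.imshow(diff, cmap='inferno')\n",
    "    ax.set_title('|Fixed - {}|'.format(title))\n",
    "    ax.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },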
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📊 Evaluate Metrics (MSE, SSIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
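    "fixed_np = sitk_to_np(fixed)\n",
    "baseline_np = sitk_to_np(sitk.Resample(moving, fixed))  # unregistered, on the fixed grid\n",
    "rigid_np = sitk_to_np(moved_rigid)\n",
    "nonrigid_np = sitk_to_np(moved_bspline)\n",
    "\n",
    "# SSIM needs an explicit data_range for float images; use the fixed image's intensity span\n",
    "data_range = float(fixed_np.max() - fixed_np.min())\n",
    "for name, img_np in [(\"Unregistered\", baseline_np), (\"Rigid\", rigid_np), (\"Non-Rigid\", nonrigid_np)]:\n",
    "    print(\"{:<13} MSE = {:.2f}, SSIM = {:.4f}\".format(\n",
    "        name + \":\", mse(fixed_np, img_np), ssim(fixed_np, img_np, data_range=data_range)))"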
   ]
  },
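  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A checkerboard composite interleaves tiles of the fixed and registered images. When alignment is good, anatomical edges run continuously across tile borders. This is a minimal sketch. `sitk.CheckerBoard` requires both inputs to share size and pixel type, which holds here because both registered images were resampled onto the fixed grid from images loaded the same way. The 8x8 tile pattern is an arbitrary choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternate 8x8 tiles of the fixed image and each registered result\n",
    "checker_rigid = sitk.CheckerBoard(fixed, moved_rigid, [8, 8])\n",
    "checker_bspline = sitk.CheckerBoard(fixed, moved_bspline, [8, 8])\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
    "for ax, img, title in zip(axs, [checker_rigid, checker_bspline],\n",
    "                          ['Checkerboard: Rigid', 'Checkerboard: Non-Rigid']):\n",
    "    ax.imshow(sitk.GetArrayViewFromImage(img), cmap='gray')\n",
    "    ax.set_title(title)\n",
    "    ax.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },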
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ✅ Summary\n",
    "\n",
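    "- Rigid registration corrects global translation and rotation; the B-spline stage then models residual local deformation.\n",
    "- Non-rigid registration can raise SSIM further when local deformation remains after the rigid stage; confirm this from the printed MSE (lower is better) and SSIM (higher is better) rather than assuming it.\n",
    "- Registration is a prerequisite for motion modeling in 4D CT and for dose optimization.\n",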
    "\n",
    "**Next**: Proceed to `05_vr_visualization.ipynb` to visualize registered volumes in VR space."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": ""
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
