In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🧠 03 - GAN-based Artifact Removal\n",
    "\n",
    "This notebook applies a pretrained GAN generator to a CT slice to suppress motion and beam-hardening artifacts, then compares the result with an artifact-free reference, both visually and with PSNR/SSIM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 📦 Imports\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import make_grid\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim\n",
    "\n",
    "# Compute device: the GPU when PyTorch can see one, otherwise the CPU.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device('cuda') if use_cuda else torch.device('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🔄 Load Data (Artifact + Clean Pair)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths to the paired slices and the working resolution.\n",
    "ARTIFACT_PATH = \"../data/artifact/sample_ct_artifact_slice.png\"\n",
    "CLEAN_PATH = \"../data/raw/sample_ct_slice.png\"\n",
    "# The generator below halves the resolution twice and doubles it twice, so its\n",
    "# output only matches the input size when IMG_SIZE is divisible by 4.\n",
    "IMG_SIZE = 256\n",
    "assert IMG_SIZE % 4 == 0, \"IMG_SIZE must be divisible by 4\"\n",
    "\n",
    "\n",
    "def load_image(path):\n",
    "    \"\"\"Open the image at `path` and return it as an 8-bit grayscale (mode 'L') PIL image.\n",
    "\n",
    "    The `with` block closes the underlying file before returning.\n",
    "    \"\"\"\n",
    "    with Image.open(path) as img:\n",
    "        return img.convert('L')\n",
    "\n",
    "\n",
    "# Resize to a fixed square, then convert to a float tensor of shape (1, H, W) in [0, 1].\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((IMG_SIZE, IMG_SIZE)),\n",
    "    transforms.ToTensor()\n",
    "])\n",
    "\n",
    "# Add a batch dimension: both tensors are (1, 1, IMG_SIZE, IMG_SIZE).\n",
    "img_artifact = transform(load_image(ARTIFACT_PATH)).unsqueeze(0).to(device)\n",
    "img_clean = transform(load_image(CLEAN_PATH)).unsqueeze(0).to(device)\n",
    "assert img_artifact.shape == img_clean.shape == (1, 1, IMG_SIZE, IMG_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🧬 Load Pretrained GAN Generator (pix2pix-style encoder–decoder; no U-Net skip connections)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CHECKPOINT_PATH = \"../src/gan_artifact_removal/unet_generator.pth\"\n",
    "\n",
    "\n",
    "class UNetGenerator(nn.Module):\n",
    "    \"\"\"Convolutional encoder-decoder generator for single-channel images.\n",
    "\n",
    "    Despite the name, there are no U-Net skip connections: the encoder output\n",
    "    feeds the decoder directly. Keep the layer layout unchanged so parameter\n",
    "    names and shapes match the saved checkpoint.\n",
    "\n",
    "    Maps (N, 1, H, W) to (N, 1, H, W) for H and W divisible by 4; the final\n",
    "    Tanh bounds the output to [-1, 1].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Two stride-2 convolutions: (N, 1, H, W) -> (N, 128, H/4, W/4).\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(1, 64, 4, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 128, 4, stride=2, padding=1),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        # Two stride-2 transposed convolutions back to (N, 1, H, W).\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),\n",
    "            nn.Tanh()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.decoder(self.encoder(x))\n",
    "\n",
    "\n",
    "# weights_only=True limits unpickling to tensors and plain containers, so a\n",
    "# tampered checkpoint file cannot execute arbitrary code on load.\n",
    "model = UNetGenerator().to(device)\n",
    "state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ✨ Remove Artifacts via GAN Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The generator ends in Tanh, so its raw output lies in [-1, 1], while the\n",
    "# images loaded above are in [0, 1]. Following the pix2pix convention, we\n",
    "# assume the model was also trained on inputs scaled to [-1, 1].\n",
    "# NOTE(review): confirm the input scaling against the training script.\n",
    "def to_model_range(x):\n",
    "    \"\"\"Map an image tensor from [0, 1] to [-1, 1].\"\"\"\n",
    "    return x * 2 - 1\n",
    "\n",
    "\n",
    "def from_model_range(y):\n",
    "    \"\"\"Map a tensor from [-1, 1] back to [0, 1], clamping any overshoot.\"\"\"\n",
    "    return ((y + 1) / 2).clamp(0, 1)\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    output = from_model_range(model(to_model_range(img_artifact)))\n",
    "\n",
    "# Convert (1, 1, H, W) tensors in [0, 1] to 8-bit grayscale PIL images.\n",
    "# ToPILImage multiplies float tensors by 255 before casting to uint8, so\n",
    "# values outside [0, 1] would not map to valid intensities.\n",
    "to_img = transforms.ToPILImage()\n",
    "artifact_img = to_img(img_artifact[0].cpu())\n",
    "clean_img = to_img(img_clean[0].cpu())\n",
    "output_img = to_img(output[0].cpu())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🖼️ Visual Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin every panel to the full 8-bit range; without vmin/vmax, imshow\n",
    "# stretches each image to its own min/max, which hides intensity differences.\n",
    "panels = [\n",
    "    (artifact_img, 'Input w/ Artifact'),\n",
    "    (output_img, 'GAN Output'),\n",
    "    (clean_img, 'Ground Truth'),\n",
    "]\n",
    "fig, axs = plt.subplots(1, len(panels), figsize=(12, 4))\n",
    "for ax, (panel_img, title) in zip(axs, panels):\n",
    "    ax.imshow(panel_img, cmap='gray', vmin=0, vmax=255)\n",
    "    ax.set_title(title)\n",
    "    ax.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📊 Quantitative Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics on the 8-bit images; data_range=255 states the intensity range\n",
    "# explicitly instead of relying on dtype inference.\n",
    "clean_np = np.array(clean_img)\n",
    "output_np = np.array(output_img)\n",
    "artifact_np = np.array(artifact_img)\n",
    "assert clean_np.shape == output_np.shape == artifact_np.shape\n",
    "\n",
    "gan_psnr = psnr(clean_np, output_np, data_range=255)\n",
    "gan_ssim = ssim(clean_np, output_np, data_range=255)\n",
    "\n",
    "# Baseline: the artifact-affected input against the same ground truth, so the\n",
    "# GAN scores can be read as an improvement (or not) rather than in isolation.\n",
    "input_psnr = psnr(clean_np, artifact_np, data_range=255)\n",
    "input_ssim = ssim(clean_np, artifact_np, data_range=255)\n",
    "\n",
    "print(\"Artifact Input     - PSNR: {:.2f}, SSIM: {:.4f}\".format(input_psnr, input_ssim))\n",
    "print(\"GAN Reconstruction - PSNR: {:.2f}, SSIM: {:.4f}\".format(gan_psnr, gan_ssim))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ✅ Summary\n",
    "\n",
    "- The GAN output is expected to show fewer artifacts than the input CT slice; verify this in the visual comparison above.\n",
    "- PSNR/SSIM against the ground truth quantify the change; this is a single-slice check, so treat the numbers as indicative only.\n",
    "- Output can feed into registration/VR pipelines.\n",
    "\n",
    "**Next**: Move to `04_registration_pipeline.ipynb` to align time-series or 4D CT frames."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": ""
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
