<a href="https://colab.research.google.com/github/ApurvaMayank-iitb/24D0894_IE643_NIFTI/blob/main/Interface_AnoGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import files
uploaded = files.upload()  # This will prompt you to upload both .pth files


Saving Discriminator_trained.pth to Discriminator_trained.pth
Saving Generator_trained.pth to Generator_trained.pth


In [3]:
import torch
import torch.nn as nn

# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Generator model definition
class Generator(nn.Module):
    def __init__(self, input_dim=25, output_channels=1):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            nn.ConvTranspose2d(input_dim, 128, kernel_size=4, stride=1, padding=0),  # 1x1 -> 4x4
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 16x16 -> 32x32
            nn.BatchNorm2d(16),
            nn.ReLU(True),

            nn.ConvTranspose2d(16, output_channels, kernel_size=8, stride=8, padding=0),  # 32x32 -> 256x256
            nn.Tanh()
        )

    def forward(self, x):
        x = x.view(x.size(0), x.size(1), 1, 1)  # Reshape to start with a 1x1 input
        return self.net(x)

# Discriminator model definition
class Discriminator(nn.Module):
    def __init__(self, input_channels=1):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, 16, kernel_size=4, stride=2, padding=1),  # 256x256 -> 128x128
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),  # 128x128 -> 64x64
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 64x64 -> 32x32
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=0),  # 16x16 -> 13x13
            nn.Sigmoid()  # Probability of real or fake
        )

    def forward(self, x):
        return self.net(x).view(-1, 1).squeeze(1)  # Flatten output

# Initialize models
latent_dim = 25
G = Generator(input_dim=latent_dim, output_channels=1).to(device)
D = Discriminator(input_channels=1).to(device)


In [4]:
# Load the trained model weights
G.load_state_dict(torch.load('Generator_trained.pth', map_location=device))
D.load_state_dict(torch.load('Discriminator_trained.pth', map_location=device))

# Set the models to evaluation mode
G.eval()
D.eval()


  G.load_state_dict(torch.load('Generator_trained.pth', map_location=device))
  D.load_state_dict(torch.load('Discriminator_trained.pth', map_location=device))


Discriminator(
  (net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(128, 1, kernel_size=(4, 4), stride=(1, 1))
    (12): Sigmoid()
  )
)

In [5]:
import gradio as gr
import nibabel as nib
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing and Averaging Function
def load_and_average_modalities(modalities, slice_axis=2, target_size=(256, 256)):
    processed_slices = []
    for file_path in modalities:
        # Load .nii file
        img = nib.load(file_path)
        data = img.get_fdata()

        # Extract the slice with the most non-zero pixels
        slices = [data.take(i, axis=slice_axis) for i in range(data.shape[slice_axis])]
        max_index = np.argmax([np.count_nonzero(slice_) for slice_ in slices])
        best_slice = slices[max_index]

        # Resize and normalize
        best_slice_resized = cv2.resize(best_slice, target_size)
        best_slice_normalized = best_slice_resized / np.max(best_slice_resized) if np.max(best_slice_resized) > 0 else best_slice_resized
        processed_slices.append(best_slice_normalized)

    # Average across modalities
    averaged_slice = np.mean(processed_slices, axis=0)
    return averaged_slice

# Main Function to Run the Anomaly Detection
def detect_anomaly(*modalities):
    # Step 1: Load and preprocess modalities
    averaged_slice = load_and_average_modalities(modalities)
    validation_tensor = torch.tensor(averaged_slice, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    # Step 2: Generate fake image and calculate reconstruction error
    noise = torch.randn((1, latent_dim), device=device)
    with torch.no_grad():
        generated_img = G(noise).squeeze().cpu().numpy()

    real_img_np = validation_tensor.squeeze().cpu().numpy()
    reconstruction_error = np.abs(real_img_np - generated_img)

    # Step 3: Thresholding for anomaly mask
    mean_error = np.mean(reconstruction_error)
    std_error = np.std(reconstruction_error)
    threshold = mean_error + 1.5 * std_error
    anomaly_mask = reconstruction_error > threshold
    anomaly_ratio = np.sum(anomaly_mask) / anomaly_mask.size
    result_text = "Tumor Found!" if anomaly_ratio > 0.04 else "No Tumor Found."

    # Step 4: Plot results
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    axes[0].imshow(real_img_np, cmap='gray')
    axes[0].set_title("Original Averaged Image")
    axes[1].imshow(generated_img, cmap='gray')
    axes[1].set_title("Generated Image")
    axes[2].imshow(reconstruction_error, cmap='hot')
    axes[2].set_title("Reconstruction Error")
    axes[3].imshow(anomaly_mask, cmap='gray')
    axes[3].set_title("Anomaly Mask")

    plt.tight_layout()
    plt.close(fig)
    return result_text, fig

# Define Gradio Interface
iface = gr.Interface(
    fn=detect_anomaly,
    inputs=[
        gr.inputs.File(label="Upload T1CE Modality (.nii)", type="file"),
        gr.inputs.File(label="Upload T2 Modality (.nii)", type="file"),
        # Add more modalities if needed
    ],
    outputs=[
        gr.outputs.Textbox(label="Detection Result"),
        gr.outputs.Image(type="plot", label="Anomaly Detection Visualization")
    ],
    title="MRI Anomaly Detection",
    description="Upload MRI scans for different modalities to detect anomalies."
)

# Launch the interface
iface.launch(debug=True)


ModuleNotFoundError: No module named 'gradio'

In [6]:
!pip install gradio nibabel


Collecting gradio
  Downloading gradio-5.4.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.4-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.4.0-py3-none-any.whl.metadata (2.9 kB)
Collecting gradio-client==1.4.2 (from gradio)
  Downloading gradio_client-1.4.2-py3-none-any.whl.metadata (7.1 kB)
Collecting huggingface-hub>=0.25.1 (from gradio)
  Downloading huggingface_hub-0.26.2-py3-none-any.whl.metadata (13 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart==0.0.12 (from gradio)
  Downloading python_multipart-0.0.12-py3-none-any.whl.metadata (1.9 kB)
Col

In [None]:
import gradio as gr
import nibabel as nib
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing and Averaging Function
def load_and_average_modalities(modalities, slice_axis=2, target_size=(256, 256)):
    processed_slices = []
    for file_path in modalities:
        # Load .nii file
        img = nib.load(file_path)
        data = img.get_fdata()

        # Extract the slice with the most non-zero pixels
        slices = [data.take(i, axis=slice_axis) for i in range(data.shape[slice_axis])]
        max_index = np.argmax([np.count_nonzero(slice_) for slice_ in slices])
        best_slice = slices[max_index]

        # Resize and normalize
        best_slice_resized = cv2.resize(best_slice, target_size)
        best_slice_normalized = best_slice_resized / np.max(best_slice_resized) if np.max(best_slice_resized) > 0 else best_slice_resized
        processed_slices.append(best_slice_normalized)

    # Average across modalities
    averaged_slice = np.mean(processed_slices, axis=0)
    return averaged_slice

# Main Function to Run the Anomaly Detection
def detect_anomaly(*modalities):
    # Step 1: Load and preprocess modalities
    averaged_slice = load_and_average_modalities(modalities)
    validation_tensor = torch.tensor(averaged_slice, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    # Step 2: Generate fake image and calculate reconstruction error
    noise = torch.randn((1, latent_dim), device=device)
    with torch.no_grad():
        generated_img = G(noise).squeeze().cpu().numpy()

    real_img_np = validation_tensor.squeeze().cpu().numpy()
    reconstruction_error = np.abs(real_img_np - generated_img)

    # Step 3: Thresholding for anomaly mask
    mean_error = np.mean(reconstruction_error)
    std_error = np.std(reconstruction_error)
    threshold = mean_error + 1.5 * std_error
    anomaly_mask = reconstruction_error > threshold
    anomaly_ratio = np.sum(anomaly_mask) / anomaly_mask.size
    result_text = "Tumor Found!" if anomaly_ratio > 0.04 else "No Tumor Found."

    # Step 4: Plot results
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    axes[0].imshow(real_img_np, cmap='gray')
    axes[0].set_title("Original Averaged Image")
    axes[1].imshow(generated_img, cmap='gray')
    axes[1].set_title("Generated Image")
    axes[2].imshow(reconstruction_error, cmap='hot')
    axes[2].set_title("Reconstruction Error")
    axes[3].imshow(anomaly_mask, cmap='gray')
    axes[3].set_title("Anomaly Mask")

    plt.tight_layout()
    plt.close(fig)
    return result_text, fig

# Define Gradio Interface
iface = gr.Interface(
    fn=detect_anomaly,
    inputs=[
        gr.File(label="Upload T1CE Modality (.nii)"),
        gr.File(label="Upload T2 Modality (.nii)"),
        # Add more modalities if needed
    ],
    outputs=[
        gr.Textbox(label="Detection Result"),
        gr.Plot(label="Anomaly Detection Visualization")
    ],
    title="MRI Anomaly Detection",
    description="Upload MRI scans for different modalities to detect anomalies."
)

# Launch the interface
iface.launch(debug=True)


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://ae643de7f04398a2e5.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
