In [2]:
import pydicom
import numpy as np
import cv2

def dicom_to_png(dicom_path, save_path):
    """Convert a DICOM file to an 8-bit image file on disk.

    The pixel data is min-max rescaled to the 0-255 range, cast to uint8 and
    written with OpenCV (the output format follows the extension of
    ``save_path``).

    Args:
        dicom_path: Path of the DICOM file to read.
        save_path: Destination path of the exported image.

    Raises:
        IOError: If OpenCV reports that the image could not be written.
    """
    dicom = pydicom.dcmread(dicom_path)
    image = dicom.pixel_array
    # Stretch the stored intensity range to 0-255 before the uint8 cast so the
    # cast does not wrap or truncate wider pixel values.
    image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX)
    image = image.astype(np.uint8)
    # cv2.imwrite signals failure through its return value instead of raising;
    # surface it so a missing directory or bad path is not silently ignored.
    if not cv2.imwrite(save_path, image):
        raise IOError(f"Failed to write image to {save_path!r}")




In [3]:
import torchvision.transforms as transforms
# Preprocessing applied to every dataset image: single-channel, 128x128,
# converted to a tensor and normalised with mean 0.5 / std 0.5.
# NOTE(review): the Generator defined in this notebook upsamples 1x1 -> 32x32,
# which does not match the 128x128 size used here - confirm the intended size.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)


In [5]:
import nibabel as nib
def load_nifti(nifti_path):
    """Load a NIfTI file and return its voxel data.

    Args:
        nifti_path: Path of the NIfTI file to read.

    Returns:
        The array produced by ``get_fdata()`` on the loaded image.
    """
    return nib.load(nifti_path).get_fdata()


In [6]:
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 1-channel image.

    The transposed-convolution stack upsamples a ``latent_dim x 1 x 1`` input
    through 4x4, 8x8 and 16x16 feature maps to a single-channel 32x32 output
    squashed by ``tanh``.

    Args:
        latent_dim: Number of channels of the latent input. Defaults to 100,
            the size previously hard-coded in the first layer.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim
        # Layer order is kept unchanged so state_dict keys (model.0, model.1,
        # ...) stay compatible with previously saved checkpoints.
        self.model = nn.Sequential(
            # latent_dim x 1 x 1 -> 512 x 4 x 4
            nn.ConvTranspose2d(latent_dim, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # 512 x 4 x 4 -> 256 x 8 x 8
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 256 x 8 x 8 -> 128 x 16 x 16
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # 128 x 16 x 16 -> 1 x 32 x 32
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Output in range [-1, 1]
        )

    def forward(self, z):
        """Generate images from latent noise.

        Args:
            z: Latent tensor of shape ``(batch, latent_dim, 1, 1)``; a flat
                ``(batch, latent_dim)`` tensor is also accepted.

        Returns:
            Tensor of shape ``(batch, 1, 32, 32)`` with values in [-1, 1].
        """
        if z.dim() == 2:
            # Give flat noise vectors the 1x1 spatial extent the stack expects.
            z = z[:, :, None, None]
        return self.model(z)



In [7]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring 1-channel images as real or fake.

    Three stride-2 convolutions halve the spatial size each time and a final
    4x4 convolution plus sigmoid produces a probability map. For 32x32 inputs
    that map is 1x1; for larger inputs it holds several patch scores.
    """

    def __init__(self):
        super().__init__()
        # Layer order is kept unchanged so state_dict keys stay compatible
        # with previously saved checkpoints.
        self.model = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Return one real/fake probability per input image.

        Args:
            img: Tensor of shape ``(batch, 1, H, W)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        scores = self.model(img)
        # Flatten per sample and average the patch scores. For 32x32 inputs
        # there is exactly one score, so the result is unchanged; for larger
        # inputs (e.g. the 128x128 images produced by this notebook's
        # transform) this keeps the batch dimension intact, whereas
        # view(-1, 1) folded every spatial position into the batch axis.
        return scores.flatten(start_dim=1).mean(dim=1, keepdim=True)


In [11]:
from torch.utils.data import DataLoader
import torchvision


# Training data for the GAN (example: chest X-rays). Images are read from
# ./medical_dataset, passed through `transform`, and served in shuffled
# batches of 64.
dataloader = DataLoader(
    dataset=torchvision.datasets.ImageFolder(
        root="./medical_dataset",
        transform=transform,
    ),
    batch_size=64,
    shuffle=True,
)

# Use the same GAN training loop as before


FileNotFoundError: [Errno 2] No such file or directory: './medical_dataset'

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Load trained generator
# `device` is defined here so this cell also works on a freshly restarted
# kernel; it is not defined in any other visible cell of the notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only machine.
generator.load_state_dict(torch.load("generator_medical.pth", map_location=device))
# Inference mode: BatchNorm layers use their running statistics.
generator.eval()

# Load and preprocess input medical image
def preprocess_image(image_path):
    """Load an image file and turn it into a normalised model input.

    The image is converted to grayscale, resized to 128x128, converted to a
    tensor, normalised with mean 0.5 / std 0.5, given a leading batch
    dimension and moved to the global ``device``.

    Args:
        image_path: Path of the image file to load.

    Returns:
        Tensor of shape ``(1, 1, 128, 128)`` on ``device``.
    """
    # Named `pipeline` so it does not shadow the notebook-level `transform`.
    pipeline = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # The context manager closes the underlying file promptly instead of
    # leaving the handle to garbage collection; convert() returns an
    # independent in-memory copy, so it stays usable after the file is closed.
    with Image.open(image_path) as source:
        image = source.convert("L")
    image = pipeline(image).unsqueeze(0).to(device)
    return image

# Anomaly detection functions
def _reconstruction_error(image_path):
    """Compare an input image with a generator sample (shared scoring helper).

    Args:
        image_path: Path of the image file to score.

    Returns:
        Tuple ``(original_image, reconstructed_image, error)`` where both
        images are tensors with matching shapes and ``error`` is their mean
        absolute difference as a Python float.
    """
    original_image = preprocess_image(image_path)

    # NOTE(review): z is sampled at random and is not fitted to the input, so
    # the "reconstruction" is an arbitrary generator sample and the score
    # varies between calls - confirm this is the intended scoring scheme.
    z = torch.randn(1, 100, 1, 1).to(device)
    with torch.no_grad():  # inference only; no autograd graph is needed
        reconstructed_image = generator(z)

    # The generator in this notebook emits 32x32 images while the input is
    # preprocessed to 128x128; subtracting them directly fails to broadcast.
    # Resample the generated image onto the input's spatial size first.
    if reconstructed_image.shape[-2:] != original_image.shape[-2:]:
        reconstructed_image = torch.nn.functional.interpolate(
            reconstructed_image,
            size=original_image.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )

    # Mean absolute error over all pixels.
    error = torch.mean(torch.abs(original_image - reconstructed_image)).item()
    return original_image, reconstructed_image, error


def detect_anomaly(image_path, threshold=0.3):
    """Score an image, plot it next to the generated image, print a verdict.

    Args:
        image_path: Path of the image file to check.
        threshold: Mean-absolute-error value above which the image is
            reported as an anomaly.
    """
    original_image, reconstructed_image, error = _reconstruction_error(image_path)

    # Show images
    original_image_np = original_image.cpu().numpy().squeeze()
    reconstructed_image_np = reconstructed_image.cpu().numpy().squeeze()

    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    axes[0].imshow(original_image_np, cmap="gray")
    axes[0].set_title("Original Image")
    axes[1].imshow(reconstructed_image_np, cmap="gray")
    axes[1].set_title("Reconstructed Image")
    plt.show()

    # Decide if it's an anomaly
    if error > threshold:
        print(f"⚠️ Anomaly Detected! Error: {error:.4f}")
    else:
        print(f"✅ Normal Image. Error: {error:.4f}")


def detect_tumor(image_path, threshold=0.3):
    """Score an image and print a tumor / no-tumor verdict (no plotting).

    Args:
        image_path: Path of the image file to check.
        threshold: Mean-absolute-error value above which a tumor is reported.
    """
    _, _, error = _reconstruction_error(image_path)

    if error > threshold:
        print(f"⚠️ Tumor Detected! Error: {error:.4f}")
    else:
        print(f"✅ No Tumor Detected. Error: {error:.4f}")


# Score the sample X-ray and display the comparison figure and verdict.
image_path = "test_xray.png"
detect_anomaly(image_path=image_path)
