# In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from pathlib import Path
import pydicom
from glob import glob


class NumpyDataset(Dataset):
    """Paired low-dose (LDCT) / normal-dose (NDCT) slices read from two folders.

    Files are paired by position after sorting: the i-th path (sorted) in
    ``ldct_dir`` is paired with the i-th path (sorted) in ``ndct_dir``. Only the
    file counts are checked, so both folders must sort into matching order.

    Args:
        ldct_dir: Folder with the low-dose slices (network input).
        ndct_dir: Folder with the normal-dose slices (training target).
        ext: File extension without the dot, ``'npy'`` or ``'dcm'``
            (case-insensitive). Other values raise ``ValueError`` on access.
        transform: Optional callable applied to each slice. It receives a
            float32 ndarray with a trailing channel axis, ``(H, W, 1)`` for a
            2-D slice, which is the layout ``torchvision.transforms.ToTensor``
            expects for ndarrays. It is called for the input and the target
            with the same torch RNG state, so random augmentations are
            identical within a pair.

    Raises:
        ValueError: If the two folders contain different numbers of files.

    NOTE(review): assumes every file holds a single 2-D slice — confirm.
    """

    def __init__(self, ldct_dir, ndct_dir, ext='npy', transform=None):
        self.ext = ext.lower()
        self.ldct_paths = sorted(glob(str(Path(ldct_dir) / f'*.{self.ext}')))
        self.ndct_paths = sorted(glob(str(Path(ndct_dir) / f'*.{self.ext}')))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(self.ldct_paths) != len(self.ndct_paths):
            raise ValueError(
                "LDCT and NDCT length mismatch "
                f"({len(self.ldct_paths)} vs {len(self.ndct_paths)} files)"
            )
        self.transform = transform

    def read_dicom(self, path):
        """Read a DICOM file as float32, applying RescaleSlope/RescaleIntercept when both exist."""
        ds = pydicom.dcmread(path)
        img = ds.pixel_array.astype(np.float32)
        if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
            # float() so the arithmetic stays plain float32 whatever numeric
            # type the DICOM attributes come back as.
            img = img * float(ds.RescaleSlope) + float(ds.RescaleIntercept)
        return img

    def _load(self, path):
        """Load one slice from ``path`` as float32, dispatching on ``self.ext``."""
        if self.ext == 'dcm':
            return self.read_dicom(path)
        if self.ext == 'npy':
            return np.load(path).astype(np.float32)
        raise ValueError("Unsupported file format")

    def __getitem__(self, index):
        """Return the ``(ldct, ndct)`` pair at ``index``.

        Without a transform both are float32 ndarrays with a leading channel
        axis (``(1, H, W)`` for 2-D slices); with a transform, whatever the
        transform returns.
        """
        ldct = self._load(self.ldct_paths[index])
        ndct = self._load(self.ndct_paths[index])

        if self.transform is None:
            return np.expand_dims(ldct, axis=0), np.expand_dims(ndct, axis=0)

        # ToTensor reads a 3-D ndarray as (H, W, C) and transposes it to
        # (C, H, W); a (1, H, W) input would come out as (W, 1, H). Hand the
        # transform channel-last data instead.
        ldct = np.expand_dims(ldct, axis=-1)
        ndct = np.expand_dims(ndct, axis=-1)

        # Replay the same torch RNG state for the target so every random
        # choice (RandomApply coin flips, blur sigma, ...) matches the input's.
        rng_state = torch.get_rng_state()
        ldct = self.transform(ldct)
        torch.set_rng_state(rng_state)
        ndct = self.transform(ndct)
        return ldct, ndct

    def __len__(self):
        """Number of slice pairs."""
        return len(self.ldct_paths)

# Transformations
def _scale_by_1000(x):
    """Return tensor ``x`` cast to float and divided by 1000.

    A named module-level function rather than a lambda so the composed
    transform can be pickled (a lambda cannot be).
    """
    return x.float() / 1000.0


def get_transforms():
    """Build the per-slice training transform.

    Pipeline:
        1. ``ToTensor`` — ndarray to tensor (an ``(H, W, C)`` ndarray becomes
           ``(C, H, W)``).
        2. Divide by 1000 (HU / 1000 when the slices are in Hounsfield units).
        3. 3x3 ``GaussianBlur`` applied with probability 0.4.

    ``ColorJitter`` was dropped on purpose. torchvision's brightness/contrast
    ops on float tensors clamp the result to [0, 1]. After the /1000 scaling
    the data is not confined to that range, so the jitter silently destroyed
    every negative or >1 value in the affected samples.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(_scale_by_1000),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.4),
    ])


def build_model(architecture, encoder, *, encoder_weights='imagenet', in_channels=1, classes=1):
    """Create a segmentation_models_pytorch network for image-to-image regression.

    Args:
        architecture: smp architecture name, e.g. ``"Unet"`` (lower-cased
            before lookup).
        encoder: smp encoder name, e.g. ``"inceptionv4"`` (lower-cased).
        encoder_weights: Pretrained encoder weights to load. Pass ``None`` to
            train from scratch. Defaults to ``'imagenet'``, the previously
            hard-coded value.
        in_channels: Number of input channels (1 for a single CT slice).
        classes: Number of output channels (1 for a single denoised slice).

    Returns:
        The model returned by ``smp.create_model``.
    """
    return smp.create_model(
        arch=architecture.lower(),
        encoder_name=encoder.lower(),
        in_channels=in_channels,
        classes=classes,
        encoder_weights=encoder_weights,
    )


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    When ``optimizer`` is given, the pass trains the model (train mode,
    backward, optimizer step). Otherwise it evaluates in eval mode with
    gradients disabled.

    Raises:
        ValueError: If the loader's dataset is empty (the mean is undefined).
    """
    n_samples = len(loader.dataset)
    if n_samples == 0:
        raise ValueError("loader has an empty dataset; cannot compute a mean loss")

    training = optimizer is not None
    model.train(training)
    total = 0.0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            if training:
                optimizer.zero_grad()
            loss = criterion(model(x), y)
            if training:
                loss.backward()
                optimizer.step()
            # Weight by batch size so a short last batch is not over-counted.
            total += loss.item() * x.size(0)
    return total / n_samples


def train_model(model, train_loader, val_loader, num_epochs=50, device='cuda', *,
                lr=1e-4, checkpoint_path="best_model.pth"):
    """Train ``model`` with MSE loss and Adam, checkpointing the best validation epoch.

    Each epoch trains on ``train_loader``, evaluates on ``val_loader`` and
    steps a ``ReduceLROnPlateau`` scheduler (factor 0.5, patience 5) on the
    validation loss. The ``state_dict`` is saved to ``checkpoint_path``
    whenever the validation loss improves.

    Args:
        model: The network to train. It is moved to ``device``.
        train_loader: Batches of ``(input, target)`` used for training.
        val_loader: Batches of ``(input, target)`` used for validation.
        num_epochs: Number of epochs to run.
        device: Torch device string.
        lr: Initial Adam learning rate.
        checkpoint_path: Where the best ``state_dict`` is written.

    Returns:
        A dict with per-epoch lists ``"train_loss"``, ``"val_loss"`` and
        ``"lr"``. ``"lr"`` is the learning rate after that epoch's scheduler
        step.

    Raises:
        ValueError: If either loader has an empty dataset.
    """
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    best_val_loss = float('inf')
    history = {"train_loss": [], "val_loss": [], "lr": []}

    for epoch in range(num_epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss = _run_epoch(model, val_loader, criterion, device)
        # Without this call the plateau scheduler never lowers the learning rate.
        scheduler.step(val_loss)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["lr"].append(optimizer.param_groups[0]["lr"])

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)

    return history

# Evaluation
def test_model(model, test_loader, device='cuda', *, checkpoint_path="best_model.pth",
               data_range=1.0, clip_range=(0.0, 1.0)):
    """Load the best checkpoint and report mean PSNR/SSIM over ``test_loader``.

    Metrics are computed per sample on channel 0 of each prediction/target pair.

    Args:
        model: Network whose ``state_dict`` is loaded from ``checkpoint_path``.
        test_loader: Batches of ``(input, target)`` tensors.
        device: Torch device string. The checkpoint is mapped onto it, so a
            checkpoint saved on CUDA can be evaluated on CPU.
        checkpoint_path: Path of the saved ``state_dict``.
        data_range: ``data_range`` passed to skimage's PSNR and SSIM.
        clip_range: ``(low, high)`` applied to predictions only, or ``None``
            to skip clipping.

    NOTE(review): the defaults assume data in [0, 1], but ``get_transforms``
    in this file only divides by 1000. Set ``data_range``/``clip_range`` to
    match the actual normalisation.

    Returns:
        ``(mean_psnr, mean_ssim)``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model = model.to(device)
    model.eval()
    psnr_list, ssim_list = [], []

    with torch.no_grad():
        for x, y in test_loader:
            preds = model(x.to(device)).cpu().numpy()
            targets = y.cpu().numpy()
            for p, t in zip(preds[:, 0], targets[:, 0]):
                if clip_range is not None:
                    p = np.clip(p, clip_range[0], clip_range[1])
                psnr_list.append(peak_signal_noise_ratio(t, p, data_range=data_range))
                ssim_list.append(structural_similarity(t, p, data_range=data_range))

    if not psnr_list:
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    mean_psnr = float(np.mean(psnr_list))
    mean_ssim = float(np.mean(ssim_list))
    print(f"Test PSNR (avg): {mean_psnr:.2f}")
    print(f"Test SSIM (avg): {mean_ssim:.4f}")
    return mean_psnr, mean_ssim

# Predicted image
def prediction_image(model, loader, device='cuda', idx=0):
    """Plot the LDCT input, NDCT ground truth and model output for one sample.

    Only the first batch of ``loader`` is used; ``idx`` selects the sample
    within that batch, and channel 0 is displayed. All three panels share
    the ground truth's intensity range, so the output can be compared
    visually with the input and the target. Independent autoscaling would
    hide intensity differences.

    Args:
        model: Network to run.
        loader: Yields ``(input, target)`` batches.
        device: Torch device string.
        idx: Index of the sample within the first batch.

    Raises:
        ValueError: If ``loader`` yields no batches.
        IndexError: If ``idx`` is out of range for the first batch.
    """
    model.to(device)
    model.eval()
    try:
        x, y = next(iter(loader))
    except StopIteration:
        raise ValueError("loader yielded no batches; nothing to plot") from None

    with torch.no_grad():
        pred = model(x.to(device)).cpu().numpy()
    x = x.cpu().numpy()
    y = y.cpu().numpy()

    target = y[idx, 0]
    vmin, vmax = float(target.min()), float(target.max())
    panels = (
        ("LDCT Input", x),
        ("NDCT Ground Truth", y),
        ("Denoised Output", pred),
    )

    plt.figure(figsize=(12, 4))
    for position, (title, images) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(images[idx, 0], cmap='gray', vmin=vmin, vmax=vmax)
        plt.axis("off")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    from torch.utils.data import Subset

    ldct_dir = ""  # folder path
    ndct_dir = ""  # folder path
    extension = "npy"  # or 'dcm'
    seed = 42  # fixes the train/val/test split so runs are reproducible
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_transform = get_transforms()
    # Validation/test data must not receive random augmentation. Keep only the
    # deterministic steps (everything that is not a RandomApply) of the
    # training pipeline, so the intensity scaling stays identical.
    eval_transform = transforms.Compose(
        [t for t in train_transform.transforms if not isinstance(t, transforms.RandomApply)]
    )

    # Two views of the same files: augmented for training, deterministic for evaluation.
    train_source = NumpyDataset(ldct_dir, ndct_dir, ext=extension, transform=train_transform)
    eval_source = NumpyDataset(ldct_dir, ndct_dir, ext=extension, transform=eval_transform)
    if len(train_source) == 0:
        raise SystemExit(f"No *.{extension} files found; set ldct_dir and ndct_dir.")

    # Same fractions as before: 80% train+val / 20% test, then 20% of the
    # train+val part held out for validation. The index split is shared by
    # both views.
    n_total = len(train_source)
    order = torch.randperm(n_total, generator=torch.Generator().manual_seed(seed)).tolist()
    train_size = int(0.8 * n_total)
    train_val_idx, test_idx = order[:train_size], order[train_size:]
    val_size = int(0.2 * len(train_val_idx))
    cut = len(train_val_idx) - val_size
    train_idx, val_idx = train_val_idx[:cut], train_val_idx[cut:]

    train_loader = DataLoader(Subset(train_source, train_idx), batch_size=4, shuffle=True)
    val_loader = DataLoader(Subset(eval_source, val_idx), batch_size=4)
    test_loader = DataLoader(Subset(eval_source, test_idx), batch_size=4)

    architecture = "Unet"
    encoder_name = "inceptionv4"
    model = build_model(architecture, encoder_name)

    train_model(model, train_loader, val_loader, num_epochs=50, device=device)
    test_model(model, test_loader, device=device)
    prediction_image(model, test_loader, device=device)