In [None]:
# Mount Google Drive (Colab only) so the dataset folders under
# /content/drive/MyDrive used by the dataset cells are readable.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
!pip install piq pytorch-msssim

Collecting piq
  Downloading piq-0.8.0-py3-none-any.whl.metadata (17 kB)
Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->pytorch-msssim)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->pytorch-msssim)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->pytorch-msssim)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->pytorch-msssim)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->pytorch-msssim)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

In [13]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions wrapped in an identity skip connection.

    Both convolutions keep the channel count and, through ``padding=1``, the
    spatial size, so the output has exactly the shape of the input.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        same_size_conv = dict(kernel_size=3, padding=1)
        # conv1 is created before conv2 so parameter initialisation draws
        # from the RNG in the same order as before.
        self.conv1 = nn.Conv2d(channels, channels, **same_size_conv)
        self.conv2 = nn.Conv2d(channels, channels, **same_size_conv)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Return ``conv2(leaky_relu(conv1(x))) + x``."""
        out = self.conv2(self.relu(self.conv1(x)))
        # In-place add only touches the fresh tensor produced by conv2,
        # never the caller's input.
        out += x
        return out


class FeatureEncoderModule(nn.Module):
    """Convolutional encoder that turns an image into a feature map.

    Pipeline: stride-2 conv -> LeakyReLU -> ``num_res_blocks`` residual
    blocks -> stride-2 conv. No activation follows the last convolution.
    Each stride-2 convolution (kernel 3, padding 1) roughly halves height
    and width.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced feature map.
        num_res_blocks: How many ``ResidualBlock`` layers sit between the
            two strided convolutions.
    """

    def __init__(self, in_channels=1, out_channels=64, num_res_blocks=4):
        super().__init__()
        downsample = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **downsample)
        residual_stack = [ResidualBlock(out_channels) for _ in range(num_res_blocks)]
        self.res_blocks = nn.Sequential(*residual_stack)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **downsample)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Encode ``x`` and return the feature map."""
        features = self.relu(self.conv1(x))
        features = self.res_blocks(features)
        return self.conv2(features)


class FeatureDecoderModule(nn.Module):
    """Transposed-convolution decoder that maps a feature map back to an image.

    Pipeline: stride-2 transposed conv -> LeakyReLU -> ``num_res_blocks``
    residual blocks -> stride-2 transposed conv -> Sigmoid. With kernel 3,
    stride 2, padding 1 and output_padding 1, each transposed convolution
    doubles height and width. The sigmoid keeps outputs in (0, 1).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the reconstructed image.
        num_res_blocks: How many ``ResidualBlock`` layers sit between the
            two transposed convolutions.
    """

    def __init__(self, in_channels=64, out_channels=1, num_res_blocks=4):
        super().__init__()
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels, in_channels, **upsample)
        residual_stack = [ResidualBlock(in_channels) for _ in range(num_res_blocks)]
        self.res_blocks = nn.Sequential(*residual_stack)
        self.conv2 = nn.ConvTranspose2d(in_channels, out_channels, **upsample)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.output_activation = nn.Sigmoid()

    def forward(self, x):
        """Decode feature map ``x`` and return the reconstructed image."""
        hidden = self.relu(self.conv1(x))
        hidden = self.res_blocks(hidden)
        return self.output_activation(self.conv2(hidden))

class FeatureFusionEncryptionModule(nn.Module):
    """Additive two-step coupling of image features with a fusion tensor.

    Given features ``Ki`` and a fusion tensor ``alpha_i`` of the same shape:

        EKi = Ki + F(alpha_i)
        EPi = alpha_i + G(EKi)

    where ``F`` and ``G`` are residual blocks owned by this module.

    Args:
        channels: Channel count of ``Ki`` / ``alpha_i``.
    """

    def __init__(self, channels):
        super().__init__()
        # F is created before G; keep that order so initialisation is unchanged.
        self.F_block = ResidualBlock(channels)
        self.G_block = ResidualBlock(channels)

    def forward(self, Ki, alpha_i):
        """Return the pair ``(EKi, EPi)`` defined in the class docstring."""
        f_of_alpha = self.F_block(alpha_i)
        encrypted_key = Ki + f_of_alpha
        g_of_key = self.G_block(encrypted_key)
        encrypted_feature = alpha_i + g_of_key
        return encrypted_key, encrypted_feature


class FeatureFusionDecryptionModule(nn.Module):
    """Undo the additive coupling applied by the encryption module.

    Given ``EKi`` and ``EPi``:

        alpha_i = EPi - G(EKi)
        Ki      = EKi - F(alpha_i)

    This module creates its own ``F_block`` / ``G_block``; they are not tied
    to the blocks of ``FeatureFusionEncryptionModule``, so the round trip is
    exact only when both sets of weights agree.

    Args:
        channels: Channel count of ``EKi`` / ``EPi``.
    """

    def __init__(self, channels):
        super().__init__()
        # F is created before G; keep that order so initialisation is unchanged.
        self.F_block = ResidualBlock(channels)
        self.G_block = ResidualBlock(channels)

    def forward(self, EKi, EPi):
        """Return the pair ``(Ki, alpha_i)`` recovered from ``(EKi, EPi)``."""
        g_of_key = self.G_block(EKi)
        recovered_alpha = EPi - g_of_key
        f_of_alpha = self.F_block(recovered_alpha)
        recovered_key = EKi - f_of_alpha
        return recovered_key, recovered_alpha



In [14]:
class MedicalImageEncryptionModel(nn.Module):
    """End-to-end pipeline: encode -> encrypt -> decrypt -> decode.

    Args:
        channels: Width of the latent feature maps. The same value is passed
            to the encoder (as its output width), to both fusion modules and
            to the decoder (as its input width) so all stages agree.
    """

    def __init__(self, channels=64):
        super().__init__()
        # Forward `channels` to the encoder/decoder as well; leaving them on
        # their 64-channel defaults makes any other `channels` value fail
        # with a channel mismatch inside the fusion modules.
        self.encoder = FeatureEncoderModule(out_channels=channels)
        self.encryption = FeatureFusionEncryptionModule(channels)
        self.decryption = FeatureFusionDecryptionModule(channels)
        self.decoder = FeatureDecoderModule(in_channels=channels)

    def forward(self, x, alpha_i=None):
        """Run the full encrypt/decrypt round trip on a batch of images.

        Args:
            x: Input image batch accepted by the encoder.
            alpha_i: Optional fusion tensor with the same shape as the encoded
                features. When omitted, a fresh uniform-random tensor is
                drawn on every call (so repeated calls are not deterministic
                unless the RNG is seeded or ``alpha_i`` is supplied).

        Returns:
            Tuple ``(reconstructed_x, Ki, EKi, EPi, recovered_Ki)``: the
            decoded image, the encoded features, the two encrypted tensors
            and the features recovered by the decryption module.
        """
        # Step 1: encode image features.
        Ki = self.encoder(x)

        # Step 2: fusion tensor; rand_like already matches Ki's shape, dtype and device.
        if alpha_i is None:
            alpha_i = torch.rand_like(Ki)

        # Step 3: encrypt features.
        EKi, EPi = self.encryption(Ki, alpha_i)

        # Step 4: decrypt features (the recovered alpha is not needed downstream).
        recovered_Ki, _ = self.decryption(EKi, EPi)

        # Step 5: decode the recovered features.
        reconstructed_x = self.decoder(recovered_Ki)

        return reconstructed_x, Ki, EKi, EPi, recovered_Ki

In [None]:
class CustomImageDataset(Dataset):
    """Grayscale image dataset backed by a single flat directory.

    Every entry of ``root_dir`` whose name ends in one of
    ``IMAGE_EXTENSIONS`` (compared case-insensitively) becomes one sample.
    Entries are sorted by name so indices are stable between runs;
    ``os.listdir`` alone returns them in arbitrary order.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to the loaded PIL image.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.image_files = [
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        self.transform = transform

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, filename)`` for sample ``idx``.

        The image is converted to single-channel ("L") mode and, if a
        transform was given, passed through it. ``filename`` is the base
        name of the file, without the directory.
        """
        img_path = self.image_files[idx]
        # Open inside a context manager so the underlying file handle is
        # closed right away; convert() returns an independent, loaded image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("L")
        filename = os.path.basename(img_path)

        if self.transform is not None:
            image = self.transform(image)

        return image, filename

In [None]:
# Preprocessing applied to every image: resize to a fixed 256x256 grid, then
# convert the PIL image to a tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

In [None]:
# Training and test images live in two separate Drive folders; both go
# through the same preprocessing pipeline.
train_dataset = CustomImageDataset(root_dir='/content/drive/MyDrive/temp_pnu_selection', transform=transform)
test_dataset = CustomImageDataset(root_dir='/content/drive/MyDrive/mtg_processed_final_path', transform=transform)

# Batches of 8; only the training batches are shuffled, the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


In [None]:
import torch.nn.functional as F
from pytorch_msssim import SSIM

# SSIM module for single-channel inputs, averaged over the batch. Despite the
# name this yields the SSIM *similarity*; hybrid_loss converts it to a loss
# via 1 - SSIM. data_range=1.0 presumably matches inputs scaled to [0, 1].
ssim_loss_fn = SSIM(data_range=1.0, size_average=True, channel=1)

def hybrid_loss(reconstructed, original, alpha=0.84):
    """
    Blend a structural term (1 - SSIM) with a pixel-wise term (MSE).

    Args:
        reconstructed: Model output to be scored.
        original: Reference image the output is compared against.
        alpha: Weight of the (1 - SSIM) term; the MSE term is weighted by
            (1 - alpha). Expected to lie between 0 and 1.

    Returns:
        alpha * (1 - SSIM) + (1 - alpha) * MSE.
    """
    # The two terms are independent of each other.
    pixel_term = F.mse_loss(reconstructed, original)
    structural_term = 1 - ssim_loss_fn(reconstructed, original)

    weighted_structural = alpha * structural_term
    weighted_pixel = (1 - alpha) * pixel_term
    return weighted_structural + weighted_pixel


In [15]:
import torch
import torch.optim as optim
from tqdm import tqdm

# Seed PyTorch before the model is built so weight initialisation (and the
# random fusion tensors drawn in forward) can be reproduced between runs.
SEED = 42
torch.manual_seed(SEED)

# Initialize model; prefer the GPU but fall back to CPU when CUDA is unavailable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MedicalImageEncryptionModel().to(device)

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop settings
num_epochs = 250
log_interval = 10  # Print the averaged training loss every 10 epochs (it is computed every epoch)

def train_model():
    """Train the module-level ``model`` for ``num_epochs`` epochs.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``,
    ``hybrid_loss``, ``num_epochs`` and ``log_interval``. Batches are moved
    to whatever device the model's parameters live on, so the loop works on
    both GPU and CPU. The epoch-mean loss is printed every ``log_interval``
    epochs.

    Returns:
        list[float]: Mean training loss of each epoch, in epoch order.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    # Follow the model's device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    epoch_losses = []
    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = 0.0
        num_batches = 0

        for images, _ in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
            images = images.to(device)
            optimizer.zero_grad()

            # Forward pass; only the reconstruction is needed for the loss.
            reconstructed = model(images)[0]

            # Compute loss and update weights.
            loss = hybrid_loss(reconstructed, images)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            # Averaging over zero batches would divide by zero; fail with a clear message.
            raise ValueError("train_loader produced no batches; check the training image directory")

        avg_train_loss = train_loss / num_batches
        epoch_losses.append(avg_train_loss)

        if epoch % log_interval == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Loss: {avg_train_loss:.6f}")

    return epoch_losses

# Run training. In the recorded run each epoch (50 batches) took about 8-9 s,
# so the full 250 epochs take roughly 35-40 minutes.
train_model()


Epoch 1/250: 100%|██████████| 50/50 [00:09<00:00,  5.28it/s]
Epoch 2/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]
Epoch 3/250: 100%|██████████| 50/50 [00:08<00:00,  5.63it/s]
Epoch 4/250: 100%|██████████| 50/50 [00:08<00:00,  5.69it/s]
Epoch 5/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 6/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 7/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 8/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 9/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 10/250: 100%|██████████| 50/50 [00:08<00:00,  5.70it/s]


Epoch [10/250], Loss: 0.043005


Epoch 11/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 12/250: 100%|██████████| 50/50 [00:08<00:00,  5.67it/s]
Epoch 13/250: 100%|██████████| 50/50 [00:08<00:00,  5.69it/s]
Epoch 14/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 15/250: 100%|██████████| 50/50 [00:08<00:00,  5.70it/s]
Epoch 16/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 17/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 18/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 19/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 20/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]


Epoch [20/250], Loss: 0.024792


Epoch 21/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 22/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 23/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 24/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 25/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 26/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 27/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 28/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 29/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 30/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]


Epoch [30/250], Loss: 0.017503


Epoch 31/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 32/250: 100%|██████████| 50/50 [00:08<00:00,  5.86it/s]
Epoch 33/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 34/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 35/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 36/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 37/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 38/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 39/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 40/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]


Epoch [40/250], Loss: 0.012410


Epoch 41/250: 100%|██████████| 50/50 [00:08<00:00,  5.85it/s]
Epoch 42/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 43/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 44/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 45/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 46/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 47/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 48/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 49/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 50/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]


Epoch [50/250], Loss: 0.010321


Epoch 51/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 52/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 53/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 54/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]
Epoch 55/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 56/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 57/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 58/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 59/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 60/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]


Epoch [60/250], Loss: 0.008783


Epoch 61/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 62/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 63/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 64/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 65/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 66/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 67/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 68/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 69/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 70/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]


Epoch [70/250], Loss: 0.008019


Epoch 71/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 72/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 73/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 74/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 75/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 76/250: 100%|██████████| 50/50 [00:08<00:00,  5.70it/s]
Epoch 77/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 78/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 79/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]
Epoch 80/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]


Epoch [80/250], Loss: 0.006792


Epoch 81/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 82/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 83/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 84/250: 100%|██████████| 50/50 [00:08<00:00,  5.65it/s]
Epoch 85/250: 100%|██████████| 50/50 [00:08<00:00,  5.68it/s]
Epoch 86/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 87/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 88/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 89/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 90/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]


Epoch [90/250], Loss: 0.005875


Epoch 91/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 92/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 93/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 94/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 95/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 96/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 97/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 98/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 99/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 100/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]


Epoch [100/250], Loss: 0.004587


Epoch 101/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 102/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 103/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 104/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 105/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 106/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 107/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 108/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 109/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 110/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]


Epoch [110/250], Loss: 0.003745


Epoch 111/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 112/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 113/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 114/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 115/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 116/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 117/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 118/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 119/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 120/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]


Epoch [120/250], Loss: 0.003241


Epoch 121/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 122/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 123/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 124/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 125/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 126/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 127/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 128/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 129/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 130/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]


Epoch [130/250], Loss: 0.002651


Epoch 131/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 132/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 133/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]
Epoch 134/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 135/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 136/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 137/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 138/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 139/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 140/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]


Epoch [140/250], Loss: 0.002364


Epoch 141/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 142/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 143/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 144/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 145/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 146/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 147/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 148/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 149/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 150/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]


Epoch [150/250], Loss: 0.001974


Epoch 151/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 152/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 153/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 154/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 155/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]
Epoch 156/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 157/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 158/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 159/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 160/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]


Epoch [160/250], Loss: 0.001724


Epoch 161/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 162/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 163/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 164/250: 100%|██████████| 50/50 [00:08<00:00,  5.69it/s]
Epoch 165/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 166/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 167/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 168/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 169/250: 100%|██████████| 50/50 [00:08<00:00,  5.70it/s]
Epoch 170/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]


Epoch [170/250], Loss: 0.001757


Epoch 171/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 172/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 173/250: 100%|██████████| 50/50 [00:08<00:00,  5.69it/s]
Epoch 174/250: 100%|██████████| 50/50 [00:08<00:00,  5.85it/s]
Epoch 175/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 176/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 177/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 178/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 179/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 180/250: 100%|██████████| 50/50 [00:08<00:00,  5.85it/s]


Epoch [180/250], Loss: 0.001320


Epoch 181/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 182/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 183/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 184/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 185/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 186/250: 100%|██████████| 50/50 [00:08<00:00,  5.85it/s]
Epoch 187/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 188/250: 100%|██████████| 50/50 [00:08<00:00,  5.68it/s]
Epoch 189/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 190/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]


Epoch [190/250], Loss: 0.001204


Epoch 191/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 192/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 193/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 194/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 195/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 196/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 197/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 198/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 199/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 200/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]


Epoch [200/250], Loss: 0.001744


Epoch 201/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 202/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 203/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 204/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 205/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 206/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 207/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 208/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 209/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 210/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]


Epoch [210/250], Loss: 0.001006


Epoch 211/250: 100%|██████████| 50/50 [00:08<00:00,  5.79it/s]
Epoch 212/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 213/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 214/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 215/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 216/250: 100%|██████████| 50/50 [00:08<00:00,  5.82it/s]
Epoch 217/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]
Epoch 218/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 219/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 220/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]


Epoch [220/250], Loss: 0.000895


Epoch 221/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 222/250: 100%|██████████| 50/50 [00:08<00:00,  5.71it/s]
Epoch 223/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 224/250: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]
Epoch 225/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 226/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]
Epoch 227/250: 100%|██████████| 50/50 [00:08<00:00,  5.68it/s]
Epoch 228/250: 100%|██████████| 50/50 [00:08<00:00,  5.70it/s]
Epoch 229/250: 100%|██████████| 50/50 [00:08<00:00,  5.70it/s]
Epoch 230/250: 100%|██████████| 50/50 [00:08<00:00,  5.70it/s]


Epoch [230/250], Loss: 0.000850


Epoch 231/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 232/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 233/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 234/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 235/250: 100%|██████████| 50/50 [00:08<00:00,  5.81it/s]
Epoch 236/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 237/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]
Epoch 238/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 239/250: 100%|██████████| 50/50 [00:08<00:00,  5.74it/s]
Epoch 240/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]


Epoch [240/250], Loss: 0.000854


Epoch 241/250: 100%|██████████| 50/50 [00:08<00:00,  5.76it/s]
Epoch 242/250: 100%|██████████| 50/50 [00:08<00:00,  5.70it/s]
Epoch 243/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 244/250: 100%|██████████| 50/50 [00:08<00:00,  5.77it/s]
Epoch 245/250: 100%|██████████| 50/50 [00:08<00:00,  5.72it/s]
Epoch 246/250: 100%|██████████| 50/50 [00:08<00:00,  5.80it/s]
Epoch 247/250: 100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
Epoch 248/250: 100%|██████████| 50/50 [00:08<00:00,  5.75it/s]
Epoch 249/250: 100%|██████████| 50/50 [00:08<00:00,  5.78it/s]
Epoch 250/250: 100%|██████████| 50/50 [00:08<00:00,  5.83it/s]

Epoch [250/250], Loss: 0.000853





In [20]:
import torch
import numpy as np
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

def evaluate_model():
    """Score the module-level ``model`` on ``test_loader`` with SSIM and PSNR.

    Uses the module-level ``model``, ``test_loader``, ``ssim`` and ``psnr``.
    Each test image is passed through the model and the reconstruction is
    compared with the input, one image at a time. Prints the average SSIM,
    the average PSNR and the number of images scored.

    Returns:
        tuple[float, float]: ``(average SSIM, average PSNR in dB)``.

    Raises:
        ValueError: If ``test_loader`` yields no images.
    """
    model.eval()

    # Follow the model's device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    total_ssim = 0.0
    total_psnr = 0.0
    num_samples = 0

    with torch.no_grad():  # Disable gradient tracking during evaluation
        for images, _ in tqdm(test_loader, desc="Evaluating Model"):
            images = images.to(device)

            # Forward pass; only the reconstruction is scored.
            reconstructed_x = model(images)[0]

            # (N, 1, H, W) -> (N, H, W). Index the channel axis explicitly so
            # a batch holding a single image keeps its batch dimension
            # (a bare squeeze() would collapse it as well).
            original_np = images[:, 0].cpu().numpy()
            recon_np = reconstructed_x[:, 0].cpu().numpy()

            # SSIM window: at most 7, never larger than the image, and odd
            # (skimage rejects even window sizes).
            win_size = min(7, *original_np.shape[1:])
            if win_size % 2 == 0:
                win_size -= 1

            for reference, reconstruction in zip(original_np, recon_np):
                total_ssim += ssim(reference, reconstruction, data_range=1.0, win_size=win_size)
                total_psnr += psnr(reference, reconstruction, data_range=1.0)
                num_samples += 1

    if num_samples == 0:
        # Averaging over zero images would divide by zero; fail with a clear message.
        raise ValueError("test_loader produced no images; check the test image directory")

    # Compute average SSIM and PSNR
    avg_ssim = total_ssim / num_samples
    avg_psnr = total_psnr / num_samples

    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average PSNR: {avg_psnr:.4f} dB")
    print(num_samples)

    return avg_ssim, avg_psnr

# Run evaluation on the test loader; prints the average SSIM, the average
# PSNR and the number of images scored.
evaluate_model()


Evaluating Model: 100%|██████████| 18/18 [00:02<00:00,  8.07it/s]

Average SSIM: 0.9950
Average PSNR: 36.6538 dB
138



