In [1]:
import torch
import torch.nn as nn
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image


In [2]:
class ImageEncoder(nn.Module):
    """
    Convolutional encoder that returns a pooled latent vector and, optionally,
    the un-pooled feature map the latent was pooled from (a skip connection).

    Architecture:
      * three conv(3x3)+ReLU+max-pool(2) stages, channels 3->64->128->256,
        each halving the spatial size (256x256 -> 32x32 for the intended input);
      * two 256-channel conv(3x3)+ReLU refinement layers at that resolution;
      * a conv(3x3)+ReLU projection to ``emb_dim`` channels (the skip map).
    The latent vector is the global average of the skip map over H and W.
    """

    def __init__(self, emb_dim=256):
        super().__init__()
        self.emb_dim = emb_dim

        modules = []
        # Downsampling stages; each ends with a 2x max-pool.
        for in_ch, out_ch in ((3, 64), (64, 128), (128, 256)):
            modules.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        # Resolution-preserving refinement convolutions.
        for _ in range(2):
            modules.extend([nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU()])
        # Projection to the embedding width; its output doubles as the skip feature map.
        modules.extend([nn.Conv2d(256, emb_dim, kernel_size=3, padding=1), nn.ReLU()])

        # Module order fixes the state_dict keys (encoder_body.<idx>.*) that saved
        # checkpoints depend on, so it must not be rearranged.
        self.encoder_body = nn.Sequential(*modules)
        # Global average pooling: [B, emb_dim, H, W] -> [B, emb_dim, 1, 1].
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, od_image, return_skip=False):
        """Encode ``od_image``.

        Returns the latent vector ``[B, emb_dim]``, or ``(latent, skip_map)``
        when ``return_skip`` is true.
        """
        skip_map = self.encoder_body(od_image)
        # Drop the two trailing singleton spatial dims left by the global pool.
        latent = self.avgpool(skip_map)[..., 0, 0]
        return (latent, skip_map) if return_skip else latent

In [3]:
class ImageDecoder(nn.Module):
    """
    Decoder that turns a latent vector (plus an optional skip feature map) into
    a 256x256, 3-channel image with values in [-1, 1] (Tanh output).

    The latent is linearly projected to an ``[B, emb_dim, 32, 32]`` map, the
    skip map (if given) is added element-wise, and three upsampling stages
    (transposed conv 4x4/stride 2 + ReLU, then conv 3x3 + ReLU) double the
    resolution 32 -> 64 -> 128 -> 256 while narrowing channels to 128/64/32.
    A final 3x3 conv maps to RGB.
    """

    def __init__(self, latent_dim=256, emb_dim=256):
        super().__init__()
        # Latent -> flattened [emb_dim, 32, 32] seed feature map.
        self.fc = nn.Linear(latent_dim, emb_dim * 32 * 32)

        modules = []
        for in_ch, out_ch in ((emb_dim, 128), (128, 64), (64, 32)):
            modules.extend([
                # 2x spatial upsampling.
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                # Refinement at the new resolution.
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
            ])
        # RGB head; Tanh matches inputs normalized to [-1, 1].
        modules.extend([nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1), nn.Tanh()])

        # Module order fixes the state_dict keys (deconv_layers.<idx>.*) used by checkpoints.
        self.deconv_layers = nn.Sequential(*modules)

    def forward(self, latent, skip_connection=None):
        """Decode ``latent``; ``skip_connection`` must match ``[B, emb_dim, 32, 32]`` if given."""
        projected = self.fc(latent)
        feature_map = projected.view(projected.size(0), -1, 32, 32)
        if skip_connection is not None:
            feature_map = feature_map + skip_connection
        return self.deconv_layers(feature_map)

In [4]:
class AutoEncoder(nn.Module):
    """
    Autoencoder combining ImageEncoder and ImageDecoder with a skip connection.

    The encoder produces a pooled latent vector and the feature map it was
    pooled from; the decoder projects the latent back to a spatial map and adds
    that feature map before upsampling. Because the skip carries the full
    feature map, reconstructions are not forced through the pooled latent alone.

    Args:
        latent_dim: Input width of the decoder's latent projection. Must equal
            ``emb_dim``: the encoder's latent is the global average of its
            ``emb_dim``-channel output, so it always has ``emb_dim`` features.
        emb_dim: Channel width of the encoder output / decoder seed map.

    Raises:
        ValueError: if ``latent_dim != emb_dim`` (otherwise this would only
            surface later as a shape error inside ``forward``).
    """

    def __init__(self, latent_dim=256, emb_dim=256):
        super(AutoEncoder, self).__init__()
        if latent_dim != emb_dim:
            raise ValueError(
                f"latent_dim ({latent_dim}) must equal emb_dim ({emb_dim}): the encoder's "
                f"latent vector is its pooled emb_dim-channel feature map"
            )
        self.encoder = ImageEncoder(emb_dim=emb_dim)
        self.decoder = ImageDecoder(latent_dim=latent_dim, emb_dim=emb_dim)

    def forward(self, x):
        """Reconstruct ``x`` (decoder output passes through Tanh, so lies in [-1, 1])."""
        latent, skip = self.encoder(x, return_skip=True)
        return self.decoder(latent, skip_connection=skip)

In [18]:
class RetinalImageDataset(Dataset):
    """Unlabeled dataset over every image file in a single directory.

    Files are listed once at construction, filtered by extension
    (case-insensitive), and sorted by path. Sorting makes index ``i`` refer to
    the same file on every run, independent of ``os.listdir`` order. This
    matters for callers that pair loader indices with filenames, such as the
    test-set inference loop.

    Each item is the image converted to RGB, with ``transform`` applied if one
    was given.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.image_dir = image_dir
        self.image_files = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # convert() materializes the pixel data, so the file can be closed right away.
        with Image.open(self.image_files[idx]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

In [6]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms

# Define a VGG-based feature extractor.
class VGGFeatureExtractor(nn.Module):
    """Frozen, ImageNet-pretrained VGG16 feature extractor for perceptual losses.

    Keeps the first ``layer_index`` modules of ``vgg16().features`` (indices
    ``0 .. layer_index-1``; the default 9 ends right after ``relu2_2``).
    Inputs are first rescaled from ``input_range`` to [0, 1] and then
    normalized with the ImageNet mean/std the pretrained weights expect.

    Args:
        layer_index: Number of leading modules of ``vgg16().features`` to keep.
        input_range: ``(low, high)`` value range of incoming images. The default
            ``(-1, 1)`` matches ``Normalize(mean=0.5, std=0.5)`` inputs and
            ``Tanh`` decoder outputs; pass ``(0, 1)`` for plain ``ToTensor()`` images.

    Raises:
        ValueError: if ``input_range`` is not an increasing ``(low, high)`` pair.
    """

    def __init__(self, layer_index=9, input_range=(-1.0, 1.0)):
        super(VGGFeatureExtractor, self).__init__()
        low, high = float(input_range[0]), float(input_range[1])
        if not high > low:
            raise ValueError(f"input_range must satisfy low < high, got {input_range!r}")
        self.input_range = (low, high)

        # Prefer the `weights=` API; `pretrained=True` is deprecated in newer torchvision.
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            vgg_pretrained = models.vgg16(weights=weights_enum.IMAGENET1K_V1).features
        else:
            vgg_pretrained = models.vgg16(pretrained=True).features
        self.features = nn.Sequential(*[vgg_pretrained[i] for i in range(layer_index)])
        # Frozen: the extractor is a fixed loss network, never trained.
        for param in self.features.parameters():
            param.requires_grad = False

        # Buffers follow .to(device) with the module; non-persistent so state_dict is unchanged.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def normalize(self, x):
        """Map ``x`` from ``input_range`` to [0, 1], then apply ImageNet mean/std."""
        low, high = self.input_range
        x = (x - low) / (high - low)
        return (x - self.mean) / self.std

    def forward(self, x):
        return self.features(self.normalize(x))

# Perceptual loss function (using L1 loss between feature maps)
class PerceptualLoss(nn.Module):
    """L1 distance between feature maps of two images under a fixed extractor.

    Args:
        feature_extractor: Module mapping an image batch to a feature tensor;
            it is applied to ``recon`` first, then to ``target``.
    """

    def __init__(self, feature_extractor):
        super().__init__()
        self.feature_extractor = feature_extractor
        # Mean absolute error over feature maps (nn.MSELoss() is a drop-in alternative).
        self.criterion = nn.L1Loss()

    def forward(self, recon, target):
        recon_feats, target_feats = [self.feature_extractor(img) for img in (recon, target)]
        return self.criterion(recon_feats, target_feats)

In [7]:
# Training setup: seed, model, optimizer, losses and data loaders.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and shuffling

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autoencoder = AutoEncoder(latent_dim=128, emb_dim=128).to(device)
optimizer = optim.Adam(autoencoder.parameters(), lr=1e-4)

# VGG feature extractor (first 9 modules) for the perceptual loss.
# Use `device` rather than a hard-coded 'cuda' so the notebook also runs on CPU.
vgg_extractor = VGGFeatureExtractor(layer_index=9).to(device)
perceptual_loss_fn = PerceptualLoss(vgg_extractor).to(device)
reconstruction_loss_fn = nn.L1Loss()  # or nn.MSELoss() for pixel space loss

# Resize to 256x256 and normalize to [-1, 1]. The decoder ends in Tanh, so both
# sides of every loss are in this range.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

# Directories with retinal images
image_dir = './dataset/train'
val_image_dir = './dataset/val'

dataset = RetinalImageDataset(image_dir, transform=transform)
val_dataset = RetinalImageDataset(val_image_dir, transform=transform)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4,
                        pin_memory=device.type == "cuda")
# Validation is not shuffled: the loss is order-independent, and a fixed first
# batch makes the saved sample reconstructions comparable across epochs.
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)



In [None]:
# Train with pixel-space L1 plus weighted perceptual loss and validate every epoch.
# A checkpoint (model + optimizer state + one batch of sample reconstructions) is
# saved whenever validation improves after a warm-up, and also periodically.
# To resume, load a saved checkpoint's 'model_state_dict' / 'optimizer_state_dict'
# into `autoencoder` / `optimizer` before running this cell.
num_epochs = 5001
BEST_TRACKING_START_EPOCH = 1000  # best-model checkpoints only for epochs after this
PERIODIC_SAVE_EVERY = 200         # also checkpoint every N epochs, improved or not
PERCEPTUAL_WEIGHT = 2             # weight of the perceptual term in the combined loss
CHECKPOINT_DIR = "./validation_results/saved_models"
SAMPLE_DIR = "validation_results"

best_val = float('inf')


def _combined_loss(recon, target):
    """Pixel-space reconstruction loss plus weighted perceptual loss."""
    rec_loss = reconstruction_loss_fn(recon, target)
    perc_loss = perceptual_loss_fn(recon, target)
    return rec_loss + PERCEPTUAL_WEIGHT * perc_loss


def _to_unit_range(images):
    """Map [-1, 1] tensors (normalized inputs / Tanh outputs) to [0, 1] for saving.

    A fixed mapping is used instead of save_image(normalize=True). That option
    min-max stretches the whole tensor, which distorts intensities and makes
    saved images incomparable across epochs.
    """
    return (images.clamp(-1, 1) + 1) / 2


def _save_checkpoint_and_samples(epoch, val_loss):
    """Save model/optimizer state and one batch of validation reconstructions for `epoch`.

    Expected to be called while `autoencoder` is in eval mode (right after validation).
    """
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': autoencoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
    }, os.path.join(CHECKPOINT_DIR, f"autoencoder_checkpoint_epoch{epoch}.pth"))

    os.makedirs(SAMPLE_DIR, exist_ok=True)
    with torch.no_grad():
        sample_batch = next(iter(val_dataloader)).to(device)
        recon_images = autoencoder(sample_batch)
    # Reconstructions and the matching originals, each saved as a 4-column grid.
    save_image(_to_unit_range(recon_images),
               os.path.join(SAMPLE_DIR, f"reconstructed_val_epoch{epoch}.png"), nrow=4)
    save_image(_to_unit_range(sample_batch),
               os.path.join(SAMPLE_DIR, f"original_val_epoch{epoch}.png"), nrow=4)


for epoch in range(num_epochs):
    # --- Training phase ---
    autoencoder.train()
    running_train_loss = 0.0
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon = autoencoder(batch)
        loss = _combined_loss(recon, batch)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample, not per-batch.
        running_train_loss += loss.item() * batch.size(0)
    train_loss = running_train_loss / len(dataset)

    # --- Validation phase ---
    autoencoder.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for val_batch in val_dataloader:
            val_batch = val_batch.to(device)
            val_recon = autoencoder(val_batch)
            running_val_loss += _combined_loss(val_recon, val_batch).item() * val_batch.size(0)
    val_loss = running_val_loss / len(val_dataloader.dataset)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    if epoch > BEST_TRACKING_START_EPOCH and val_loss < best_val:
        best_val = val_loss
        _save_checkpoint_and_samples(epoch, val_loss)
        print("Saved model and optimizer checkpoint and validation images.")
    elif epoch % PERIODIC_SAVE_EVERY == 0:
        _save_checkpoint_and_samples(epoch, val_loss)
        print("Saved periodic model and optimizer checkpoint and validation images.")


In [22]:
# Load a trained checkpoint and reconstruct every test image.
checkpoint_path = "./validation_results/saved_models/autoencoder_checkpoint_epoch4400.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)
autoencoder.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
autoencoder.to(device)
autoencoder.eval()

# Directory with retinal test images and where reconstructions are written.
test_image_dir = './dataset/test'
test_output_dir = "test_results"

test_dataset = RetinalImageDataset(test_image_dir, transform=transform)
# batch_size=1 and shuffle=False: loader index i is exactly test_dataset.image_files[i].
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

# save_image does not create directories; without this the first write fails.
os.makedirs(test_output_dir, exist_ok=True)

with torch.no_grad():
    for batch_idx, test_data in enumerate(test_dataloader):
        inputs = test_data.to(device)
        outputs = autoencoder(inputs)

        # Name each output after its input file (e.g. patient_00.jpg -> patient_00.png)
        # so it can be paired with the original for evaluation.
        input_path = test_dataset.image_files[batch_idx]
        stem = os.path.splitext(os.path.basename(input_path))[0]
        # Map the Tanh range [-1, 1] to [0, 1] directly. normalize=True would
        # min-max stretch each image and distort intensities.
        save_image((outputs.clamp(-1, 1) + 1) / 2,
                   os.path.join(test_output_dir, f"{stem}.png"), nrow=1)

['./dataset/test/patient_00.jpg', './dataset/test/patient_01.jpg', './dataset/test/patient_02.jpg', './dataset/test/patient_03.jpg', './dataset/test/patient_04.jpg', './dataset/test/patient_05.jpg', './dataset/test/patient_06.jpg', './dataset/test/patient_07.jpg', './dataset/test/patient_08.jpg', './dataset/test/patient_09.jpg', './dataset/test/patient_10.jpg', './dataset/test/patient_11.jpg', './dataset/test/patient_12.jpg', './dataset/test/patient_13.jpg', './dataset/test/patient_14.jpg', './dataset/test/patient_15.jpg', './dataset/test/patient_16.jpg', './dataset/test/patient_17.jpg', './dataset/test/patient_18.jpg', './dataset/test/patient_19.jpg', './dataset/test/patient_20.jpg', './dataset/test/patient_21.jpg', './dataset/test/patient_22.jpg', './dataset/test/patient_23.jpg', './dataset/test/patient_24.jpg', './dataset/test/patient_25.jpg', './dataset/test/patient_26.jpg', './dataset/test/patient_27.jpg', './dataset/test/patient_28.jpg', './dataset/test/patient_29.jpg', './datase

In [34]:
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure

# Compare each reconstruction with its original test image using SSIM.
# NOTE(review): the inference cell in this notebook writes to ./test_results/;
# this path points at a differently named folder. Confirm it is the intended run.
test_dir = "./test_results_3432/"
target_dir = "./dataset/test/"

# Same preprocessing as training: resize, convert to tensor, normalize to ~[-1, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

# data_range must match the transform: 2.0 for [-1, 1] inputs (1.0 without Normalize).
ssim_metric = StructuralSimilarityIndexMeasure(data_range=2.0)


def _load_as_batch(path):
    """Open an image as RGB, apply `transform`, and add a batch dim -> (1, C, H, W)."""
    with Image.open(path) as img:
        return transform(img.convert("RGB")).unsqueeze(0)


# Every .jpg target is paired with the same-stem .png reconstruction
# (patient_00.jpg <-> patient_00.png), instead of assuming a fixed image count.
target_names = sorted(f for f in os.listdir(target_dir) if f.lower().endswith(".jpg"))
if not target_names:
    raise FileNotFoundError(f"No .jpg target images found in {target_dir}")

ssim_scores = []
for target_name in target_names:
    stem = os.path.splitext(target_name)[0]
    test_file = os.path.join(test_dir, f"{stem}.png")
    target_file = os.path.join(target_dir, target_name)

    test_tensor = _load_as_batch(test_file)
    target_tensor = _load_as_batch(target_file)

    ssim_score = ssim_metric(test_tensor, target_tensor).item()
    print(f"SSIM for {stem}: {ssim_score}")
    ssim_scores.append(ssim_score)

num_images = len(ssim_scores)
average_ssim = sum(ssim_scores) / num_images
print(f"Average SSIM over {num_images} images: {average_ssim}")


SSIM for patient_00: 0.9444077014923096
SSIM for patient_01: 0.9275009036064148
SSIM for patient_02: 0.9302933812141418
SSIM for patient_03: 0.7561014294624329
SSIM for patient_04: 0.9409477114677429
SSIM for patient_05: 0.6331303119659424
SSIM for patient_06: 0.7453079223632812
SSIM for patient_07: 0.7672991752624512
SSIM for patient_08: 0.9386633038520813
SSIM for patient_09: 0.9399006366729736
SSIM for patient_10: 0.41984423995018005
SSIM for patient_11: 0.906661331653595
SSIM for patient_12: 0.9183079600334167
SSIM for patient_13: 0.9610039591789246
SSIM for patient_14: 0.44831550121307373
SSIM for patient_15: 0.625623881816864
SSIM for patient_16: 0.9654307961463928
SSIM for patient_17: 0.9540240168571472
SSIM for patient_18: 0.9481490254402161
SSIM for patient_19: 0.9337709546089172
SSIM for patient_20: 0.9026065468788147
SSIM for patient_21: 0.9482941627502441
SSIM for patient_22: 0.958624541759491
SSIM for patient_23: 0.9517131447792053
SSIM for patient_24: 0.9480114579200745
S