In [2]:
#from google.colab import drive
#drive.mount('/content/drive')
#!unzip /content/drive/MyDrive/data.zip -d /content/drive/MyDrive/

!nvidia-smi

Thu May 23 01:50:33 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 546.29                 Driver Version: 546.29       CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce GTX 1650 Ti   WDDM  | 00000000:01:00.0 Off |                  N/A |
| N/A   46C    P8               4W /  50W |    183MiB /  4096MiB |     16%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
%pip install pytorch-msssim

Note: you may need to restart the kernel to use updated packages.


In [4]:
%pip install torchmetrics

Note: you may need to restart the kernel to use updated packages.


In [5]:
import torch
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torchmetrics
from PIL import Image
import torch.optim as optim
from torchvision import transforms
from pytorch_msssim import MS_SSIM

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
class ColorblindDataset(torch.utils.data.Dataset):
    """Dataset of (original, colorblind-simulated) image pairs.

    Expected directory layout under ``data_dir``::

        data_dir/original/<filename>
        data_dir/<colorblind_type>/<filename>

    Every image file in each requested ``<colorblind_type>`` folder is paired
    with the file of the same name in ``original/``, so one original image
    appears once per requested colorblind type.
    """

    # File suffixes (case-sensitive) treated as images when scanning folders.
    IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")

    def __init__(self, data_dir, colorblind_types, transform=None):
        """
        Args:
            data_dir (str): Root directory containing ``original/`` and one
                sub-directory per colorblind type.
            colorblind_types (list): Colorblindness types (sub-directory names)
                to include, e.g. ["protanopia", "deuteranopia", "tritanopia"].
            transform (callable, optional): Transform applied to both images of
                each pair. When None, a default pipeline is used: resize to
                256x256, convert to a tensor, and normalize every channel from
                [0, 1] to [-1, 1].
        """
        self.data_dir = data_dir
        self.colorblind_types = colorblind_types
        # Honour the caller's transform; only fall back to the default pipeline
        # when none is supplied.
        self.transform = transform if transform is not None else self._default_transform()
        self.image_paths = self.load_data()

    @staticmethod
    def _default_transform():
        """Resize to 256x256, convert to a tensor, and map each channel to [-1, 1]."""
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def load_data(self):
        """Return a list of ``(original_path, colorblind_path)`` tuples.

        File names are sorted so the pair order is deterministic across runs and
        platforms (``os.listdir`` returns entries in arbitrary order). The
        existence of the matching original is not checked here; a missing file
        surfaces when the pair is loaded in ``__getitem__``.
        """
        original_dir = os.path.join(self.data_dir, "original")
        image_paths = []
        for colorblind_type in self.colorblind_types:
            type_dir = os.path.join(self.data_dir, colorblind_type)
            for filename in sorted(os.listdir(type_dir)):
                if filename.endswith(self.IMAGE_EXTENSIONS):
                    image_paths.append((
                        os.path.join(original_dir, filename),
                        os.path.join(type_dir, filename),
                    ))
        return image_paths

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return ``(original_image, colorblind_image)`` for pair ``index``,
        both passed through ``self.transform``."""
        original_path, colorblind_path = self.image_paths[index]

        original_image = self._load_rgb(original_path)
        colorblind_image = self._load_rgb(colorblind_path)

        if self.transform is not None:
            original_image = self.transform(original_image)
            colorblind_image = self.transform(colorblind_image)

        return original_image, colorblind_image

In [7]:
# Pick the compute device: Apple MPS if present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(device)

cuda


In [8]:
class Autoencoder(nn.Module):
    """U-Net style convolutional autoencoder with skip connections.

    Encoder: five stride-2 stages (3x3 conv -> BatchNorm -> ReLU), each halving
    height and width, with channels 3 -> 32 -> 64 -> 128 -> 256 -> 512.
    Decoder: five stride-2 transposed-conv stages (4x4 kernel), each doubling
    height and width. Every decoder stage after the first receives the previous
    decoder output concatenated along channels with the encoder output of the
    same resolution. The last stage ends in a sigmoid, so outputs are in [0, 1].

    For a 3x256x256 input the bottleneck is 512x8x8. Height and width must be
    divisible by 32 for the skip concatenations to line up.
    """

    def __init__(self):
        super().__init__()

        # Attribute names and each stage's Sequential layout define the
        # state_dict keys (e.g. "enc1.0.weight"), so saved checkpoints depend
        # on them. Stages are built in the same order as the forward pass.
        self.enc1 = self._down_stage(3, 32)      # -> 32 x 128 x 128
        self.enc2 = self._down_stage(32, 64)     # -> 64 x 64 x 64
        self.enc3 = self._down_stage(64, 128)    # -> 128 x 32 x 32
        self.enc4 = self._down_stage(128, 256)   # -> 256 x 16 x 16
        self.enc5 = self._down_stage(256, 512)   # -> 512 x 8 x 8 (latent)

        # Stages fed by a skip concatenation take twice the channels.
        self.dec5 = self._up_stage(512, 256)       # -> 256 x 16 x 16
        self.dec4 = self._up_stage(256 * 2, 128)   # -> 128 x 32 x 32
        self.dec3 = self._up_stage(128 * 2, 64)    # -> 64 x 64 x 64
        self.dec2 = self._up_stage(64 * 2, 32)     # -> 32 x 128 x 128
        self.dec1 = nn.Sequential(                 # -> 3 x 256 x 256
            nn.ConvTranspose2d(32 * 2, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _down_stage(in_channels, out_channels):
        """3x3 stride-2 conv + BatchNorm + in-place ReLU (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    @staticmethod
    def _up_stage(in_channels, out_channels):
        """4x4 stride-2 transposed conv + BatchNorm + in-place ReLU (doubles H and W)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )

    def forward(self, x):
        # Run the encoder, keeping every stage's output for the skip connections.
        skips = []
        hidden = x
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            hidden = stage(hidden)
            skips.append(hidden)

        # The bottleneck (last encoder output) feeds dec5 directly; each later
        # stage pops the encoder output of matching resolution (enc4 .. enc1).
        hidden = self.dec5(skips.pop())
        for stage in (self.dec4, self.dec3, self.dec2, self.dec1):
            hidden = stage(torch.cat([hidden, skips.pop()], 1))
        return hidden

In [9]:
def preprocess_image(image):
    """Undo the dataset's [-1, 1] normalization.

    Casts ``image`` to float32 when it has any other dtype, then rescales
    values with ``(x + 1) * 0.5`` so that [-1, 1] maps to [0, 1]. This is the
    range the model's sigmoid output and the SSIM metrics (data_range=1) use.

    Args:
        image (torch.Tensor): Tensor produced by the dataset transform.

    Returns:
        torch.Tensor: A new float32 tensor with the rescaled values.
    """
    as_float = image if image.dtype == torch.float32 else image.to(torch.float32)
    return (as_float + 1) * 0.5


# SSIM metric reported per batch next to the loss in the train/val/test loops.
# data_range=1.0 matches images already rescaled to [0, 1] by preprocess_image.
ssim = torchmetrics.StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
# NOTE(review): `criterion` is not referenced by any visible cell (the loss is
# 1 - MS-SSIM); kept in case other code relies on it.
criterion = nn.MSELoss()
# Multi-scale SSIM module for 3-channel images in [0, 1], called by
# calculate_loss. Despite its name this is a callable module, not a value.
# size_average=True: per pytorch_msssim, the result is averaged over the batch.
# NOTE(review): pytorch_msssim requires the smaller image side to exceed
# (win_size - 1) * 2**4 (160 px with the default window); 256x256 inputs satisfy
# this, so confirm if the resize changes.
ms_ssim_value = MS_SSIM(data_range=1, size_average=True, channel=3)


def calculate_loss(model_output, colorblind_image):
    """Dissimilarity loss ``1 - MS-SSIM(model_output, colorblind_image)``.

    Delegates to the module-level ``ms_ssim_value`` callable, passing the
    prediction first and the target second. Lower is better; 0 means the
    MS-SSIM score is 1.
    """
    similarity = ms_ssim_value(model_output, colorblind_image)
    return 1 - similarity



In [10]:
def _first_image_as_hwc(batch):
    """Return ``batch[0]`` as a detached CPU NumPy array in HWC layout, clipped to [0, 1]."""
    return np.clip(batch[0].detach().cpu().permute(1, 2, 0).numpy(), 0, 1)


def visualize_images(input_image, real_image, output_image, target_colorblind_type, iteration,
                     save_dir=None):
    """Show the first sample of a batch: input, target and model output side by side.

    Args:
        input_image, real_image, output_image: Batched image tensors. Element
            ``[0]`` of each is permuted from CHW to HWC for display. Values are
            clipped to [0, 1], so callers are expected to pass images already in
            that range (as produced by ``preprocess_image`` / the model's sigmoid).
        target_colorblind_type (str): Label shown in the target/output titles.
        iteration (int): Used only in the saved file name when ``save_dir`` is set.
        save_dir (str, optional): When given, the figure is also written to
            ``<save_dir>/generated_image_<type>_<iteration>.jpg``. The directory
            is created if needed. Default None displays only.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    panels = (
        (input_image, 'Original Image'),
        (real_image, f'Target ({target_colorblind_type})'),
        (output_image, f'Generated Image ({target_colorblind_type})'),
    )
    for ax, (batch, title) in zip(axes, panels):
        ax.imshow(_first_image_as_hwc(batch))
        ax.axis('off')
        ax.set_title(title)

    fig.tight_layout()
    # Save before show(): some backends discard the figure once it is shown.
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(
            save_dir, f"generated_image_{target_colorblind_type}_{iteration}.jpg"))
    plt.show()
    # The training loop calls this twice per epoch; close explicitly so
    # non-inline backends do not keep every figure alive.
    plt.close(fig)

    # plt.close(fig)

In [None]:
# Hyperparameters
num_epochs = 50
batch_size = 64
learning_rate = 0.0001
seed = 42

# Seed torch's global RNG before building the model so weight initialisation
# and DataLoader shuffling are reproducible across Restart & Run All.
torch.manual_seed(seed)

# Create the Autoencoder model
model = Autoencoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

colorblind_types = ["protanopia", "tritanopia", "deuteranopia"]

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# The file lists do not change between epochs, so build the datasets and
# loaders once instead of re-scanning the directories every epoch.
data_dir = "data/train"
dataset = ColorblindDataset(data_dir, colorblind_types, transform)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True)

val_data_dir = "data/val"
val_dataset = ColorblindDataset(val_data_dir, colorblind_types, transform)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False)

# The per-epoch averages below divide by the number of batches.
if len(dataset) == 0 or len(val_dataset) == 0:
    raise RuntimeError(
        f"No image pairs found (train: {len(dataset)}, val: {len(val_dataset)}); "
        f"check {data_dir!r} and {val_data_dir!r}.")

# Training Loop
total_iterations = 0
train_losses_epoch = []  # Average loss per epoch (training)
train_ssim_epoch = []    # Average SSIM values per epoch (training)
val_losses_epoch = []    # Average loss per epoch (validation)
val_ssim_epoch = []      # Average SSIM values per epoch (validation)

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    # The torchmetrics object also accumulates internal state on every call.
    # Only its per-batch return value is used here, so clear that state.
    ssim.reset()
    train_loss = 0
    train_ssim = []

    for i, (image, colorblind_image) in enumerate(dataloader):
        original_images = preprocess_image(image).to(device)
        colorblind_images = preprocess_image(colorblind_image).to(device)

        optimizer.zero_grad()
        output_images = model(original_images)
        loss = calculate_loss(output_images, colorblind_images)
        loss.backward()
        optimizer.step()
        total_iterations += 1

        # Metric only: no need to track gradients through it.
        train_ssim_value = ssim(output_images.detach(), colorblind_images)

        # Batches are shuffled across all colorblind types, so no single type
        # label applies to a batch.
        print(f"Epoch: {epoch + 1}/{num_epochs}, Iteration: {i + 1}, "
              f"Training Loss: {loss.item():.4f}, SSIM: {train_ssim_value.item():.4f}")

        train_loss += loss.item()
        train_ssim.append(train_ssim_value.item())

    # Average loss and SSIM for the epoch (training), stored for plotting
    avg_train_loss = train_loss / len(dataloader)
    avg_train_ssim = sum(train_ssim) / len(train_ssim)
    train_losses_epoch.append(avg_train_loss)
    train_ssim_epoch.append(avg_train_ssim)

    target_colorblind_type = "mixed"
    visualize_images(original_images, colorblind_images,
                     output_images, target_colorblind_type, total_iterations - 1)

    # ---- Validation ----
    model.eval()
    ssim.reset()
    val_loss = 0
    val_ssim = []
    with torch.no_grad():
        for image, colorblind_image in val_dataloader:
            original_images = preprocess_image(image).to(device)
            colorblind_images = preprocess_image(colorblind_image).to(device)
            output_images = model(original_images)
            loss = calculate_loss(output_images, colorblind_images)

            val_ssim_value = ssim(output_images, colorblind_images)

            val_loss += loss.item()
            val_ssim.append(val_ssim_value.item())

    # The validation loader is unshuffled with batch_size=1, so the sample shown
    # is the last entry of image_paths. Its colorblind image lives in
    # <val_data_dir>/<type>/, so the parent folder name is its type.
    target_colorblind_type = os.path.basename(
        os.path.dirname(val_dataset.image_paths[-1][1]))
    visualize_images(original_images, colorblind_images,
                     output_images, target_colorblind_type, 0)

    avg_val_loss = val_loss / len(val_dataloader)
    avg_val_ssim = sum(val_ssim) / len(val_ssim)
    val_ssim_epoch.append(avg_val_ssim)
    val_losses_epoch.append(avg_val_loss)
    print(f"\nEpoch: {epoch + 1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}, "
          f"Training SSIM: {avg_train_ssim:.4f},\n"
          f"Validation Loss: {avg_val_loss:.4f},Validation SSIM: {avg_val_ssim:.4f}\n")

print("Training Completed Successfully!")

In [None]:
# Plot against 1-based epoch numbers so the x-axis matches the "Epoch: k/N"
# training log. Each series gets its own range in case a run stopped between
# the training and validation halves of an epoch.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(1, len(train_losses_epoch) + 1), train_losses_epoch,
        label='Training Loss', linestyle='--')
ax.plot(range(1, len(val_losses_epoch) + 1), val_losses_epoch,
        label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss [1-MSSIM] Over Epochs')
ax.legend()
fig.savefig('MSSIM_Loss_Graph.jpg')
plt.show()

In [None]:
# Plot against 1-based epoch numbers so the x-axis matches the "Epoch: k/N"
# training log. Each series gets its own range in case a run stopped between
# the training and validation halves of an epoch.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(1, len(train_ssim_epoch) + 1), train_ssim_epoch,
        label='Training SSIM', color='blue', linestyle='--')
ax.plot(range(1, len(val_ssim_epoch) + 1), val_ssim_epoch,
        label='Validation SSIM', color='green')
ax.set_xlabel('Epoch')
ax.set_ylabel('SSIM')
ax.set_title('Training and Validation SSIM Over Epochs')
ax.legend()
fig.savefig('SSIM_Accuracy_Graph.jpg')
plt.show()

In [17]:
# Save the weights of the model trained above. Do NOT re-create the model here:
# a fresh Autoencoder() would overwrite the checkpoint with untrained weights.
model_save_path = "/AutoencoderDaltonized/trainedAutoencoder256_SSIM_Model_latest.pth"
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(model.state_dict(), model_save_path)

print(f"Model saved to {model_save_path}")

Model saved to /AutoencoderDaltonized/trainedAutoencoder256_SSIM_Model_latest.pth


In [None]:
import time

# Load the trained weights written by the save cell.
model_path = "/AutoencoderDaltonized/trainedAutoencoder256_SSIM_Model_latest.pth"
model = Autoencoder().to(device)
# map_location lets a checkpoint saved on another device (e.g. CUDA) load here.
# weights_only avoids unpickling arbitrary objects; the file is a plain state_dict.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()

# Test Data
colorblind_types = ["protanopia", "tritanopia", "deuteranopia"]

test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Single source of truth for the test batch size (loader and timing both use it).
batch_size = 1

test_data_dir = "data/test"
test_dataset = ColorblindDataset(
    test_data_dir, colorblind_types, test_transform)
if len(test_dataset) == 0:
    raise RuntimeError(f"No test image pairs found under {test_data_dir!r}")
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False)


def _sync_device():
    """Block until queued accelerator work finishes so wall-clock timings include it."""
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()


test_losses = []
test_ssims = []
total_forward_time = 0
num_iterations = 0

# Only per-batch SSIM values are used; clear any state accumulated earlier.
ssim.reset()

# Testing Loop
with torch.no_grad():
    for i, (image, colorblind_image) in enumerate(test_dataloader):
        original_images = preprocess_image(image).to(device)
        colorblind_images = preprocess_image(colorblind_image).to(device)

        # GPU kernels launch asynchronously. Synchronize around the forward pass
        # so the interval measures the computation, not just kernel launches.
        _sync_device()
        start_time = time.perf_counter()
        output_images = model(original_images)
        _sync_device()
        end_time = time.perf_counter()
        forward_time = end_time - start_time

        # Accumulate forward pass time and count iterations
        total_forward_time += forward_time
        num_iterations += 1

        # Loss and SSIM Calculation
        loss = calculate_loss(output_images, colorblind_images)
        ssim_value = ssim(output_images, colorblind_images)

        test_losses.append(loss.item())
        test_ssims.append(ssim_value.item())

    # The loader is unshuffled, so the last batch starts at image_paths[i * batch_size].
    # Its colorblind image lives in <test_data_dir>/<type>/, so the parent
    # folder name is the type of the sample being shown.
    target_colorblind_type = os.path.basename(
        os.path.dirname(test_dataset.image_paths[i * batch_size][1]))
    visualize_images(original_images, colorblind_images,
                     output_images, target_colorblind_type, 0)

# Aggregate and Report Results
avg_test_loss = sum(test_losses) / len(test_losses)
avg_test_ssim = sum(test_ssims) / len(test_ssims)

# Average forward pass time; per image uses the true sample count so a partial
# last batch does not skew it.
avg_forward_time_per_batch = total_forward_time / num_iterations
avg_forward_time_per_image = total_forward_time / len(test_dataset)

print(f"\nAverage Test Loss: {avg_test_loss:.4f}")
print(f"Average Test SSIM: {avg_test_ssim:.4f}")
print(f"Average Forward Pass Time per Batch: {avg_forward_time_per_batch:.4f} seconds")
print(f"Average Forward Pass Time per Image: {avg_forward_time_per_image:.6f} seconds\n")