In [1]:
# ==============================================================================
# IMPORTS
# ==============================================================================
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import Compose, Resize, ToTensor
from torch.utils.data import DataLoader, Dataset
from skimage.color import lab2rgb, rgb2lab
from tqdm import tqdm
import matplotlib.pyplot as plt  # If you're using it for saving images


In [2]:
# ==============================================================================
# DATASET CLASS DEFINITION
# ==============================================================================
class LabColorizationDataset(Dataset):
    """Dataset that turns RGB image files into (L, AB) tensor pairs in CIELAB space.

    Each item is a tuple ``(L, AB)``:
      * ``L``  -- float32 tensor of shape (1, H, W): lightness divided by 100.
      * ``AB`` -- float32 tensor of shape (2, H, W): a/b channels divided by 128.

    Args:
        image_dir: Directory whose regular files are all treated as images.
        transform: Optional callable applied to the PIL image. It may return
            either a (C, H, W) tensor (e.g. when it ends with ``ToTensor``) or
            a PIL image / (H, W, C) array.
    """

    def __init__(self, image_dir, transform=None):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and platforms.
        candidates = (os.path.join(image_dir, name) for name in os.listdir(image_dir))
        self.image_files = sorted(path for path in candidates if os.path.isfile(path))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # convert() loads the pixel data into a new image, so the underlying
        # file can be closed as soon as the with-block exits.
        with Image.open(self.image_files[idx]) as handle:
            image = handle.convert('RGB')
        if self.transform:
            image = self.transform(image)

        if torch.is_tensor(image):
            # Tensor transforms yield channel-first data; rgb2lab wants (H, W, 3).
            rgb = image.detach().cpu().numpy().transpose((1, 2, 0))
        else:
            # PIL images (transform is None or lacks ToTensor) are already (H, W, 3).
            rgb = np.asarray(image)

        lab_image = rgb2lab(rgb)
        L = lab_image[:, :, 0] / 100
        AB = lab_image[:, :, 1:] / 128

        # Build channel-first float32 tensors explicitly so the dtype always
        # matches the model's float32 weights, whatever dtype rgb2lab returned.
        L = torch.from_numpy(np.ascontiguousarray(L, dtype=np.float32)).unsqueeze(0)
        AB = torch.from_numpy(np.ascontiguousarray(AB.transpose((2, 0, 1)), dtype=np.float32))
        return L, AB

In [3]:
# ==============================================================================
# MODEL DEFINITION
# ==============================================================================
class ConvAutoencoder(nn.Module):
    """Convolutional encoder/decoder mapping a 1-channel input to a 2-channel output.

    The encoder halves the spatial size three times (stride-2 convolutions);
    the decoder doubles it three times (nearest-neighbour upsampling), so the
    output has the input's height/width when both are divisible by 8. The
    final Tanh bounds every output value to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Down-sampling path: 1 -> 64 -> 128 -> 256 (each stride 2),
        # followed by a 256 -> 512 -> 256 bottleneck at constant resolution.
        encoder_layers = [
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Up-sampling path: size-preserving transposed convolutions with an
        # x2 nearest upsample between stages, ending in a 2-channel Tanh head.
        decoder_layers = [
            nn.ConvTranspose2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64, 2, kernel_size=3, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode then decode ``x`` and return the decoder output."""
        return self.decoder(self.encoder(x))


In [4]:
# ==============================================================================
# CHECK GPU AVAILABILITY
# ==============================================================================
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device.

    The choice is also printed (including the GPU's name when CUDA is used).
    """
    if not torch.cuda.is_available():
        print("Using CPU")
        return torch.device("cpu")

    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print("Using GPU:", gpu_name)
    return torch.device("cuda")

In [5]:
# ==============================================================================
# TRAINING FUNCTION
# ==============================================================================
def train_model(model, train_loader, epochs=100, lr=0.001, device=None):
    """Train ``model`` with Adam and MSE loss, printing the mean loss per epoch.

    Args:
        model: Module to optimise; it is moved to ``device`` in place.
        train_loader: Iterable yielding ``(inputs, targets)`` tensor batches.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device to train on. Defaults to CUDA when available, else CPU.

    Returns:
        List with the mean batch loss of each epoch, in order.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Make sure train-time behaviour is active even if the caller left the
    # model in eval mode (e.g. after a previous prediction run).
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        # Count batches while iterating instead of calling len(), so loaders
        # without __len__ work and an empty loader is reported clearly.
        n_batches = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("train_loader yielded no batches; nothing to train on")

        epoch_loss = running_loss / n_batches
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}')

    return epoch_losses

In [6]:
# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================
if __name__ == "__main__":
    # Pick the compute device; get_device() also prints which one it chose
    device = get_device()

    # Every image is resized to 256x256 and converted to a tensor
    preprocessing_steps = [Resize((256, 256)), ToTensor()]
    transform = Compose(preprocessing_steps)

    # Training split: colour images, shuffled every epoch
    train_dataset = LabColorizationDataset('dataset_2/data/train_color', transform=transform)
    train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

    # Test split, black-and-white inputs (kept in file order)
    test_dataset_black = LabColorizationDataset('dataset_2/data/test_black', transform=transform)
    test_loader_black = DataLoader(dataset=test_dataset_black, batch_size=16, shuffle=False)

    # Test split, colour references (kept in file order)
    test_dataset_color = LabColorizationDataset('dataset_2/data/test_color', transform=transform)
    test_loader_color = DataLoader(dataset=test_dataset_color, batch_size=16, shuffle=False)

    # Build the network and place it on the chosen device
    model = ConvAutoencoder()
    model = model.to(device)

    # Fit on the training loader only
    train_model(model, train_loader, epochs=100)

    # Persist the learned parameters (state_dict only, not the whole module)
    torch.save(model.state_dict(), 'autoencoder_model.pth')


Using GPU: NVIDIA GeForce RTX 3050 Laptop GPU
Epoch 1/100, Loss: 0.012742999044005958
Epoch 2/100, Loss: 0.01240967680363895
Epoch 3/100, Loss: 0.012302816562806837
Epoch 4/100, Loss: 0.012319290345481123
Epoch 5/100, Loss: 0.012278032524826617
Epoch 6/100, Loss: 0.012115717480203118
Epoch 7/100, Loss: 0.012053664196461153
Epoch 8/100, Loss: 0.01197997181447217
Epoch 9/100, Loss: 0.011826080761064357
Epoch 10/100, Loss: 0.011790103945559778
Epoch 11/100, Loss: 0.011776264783697197
Epoch 12/100, Loss: 0.011687585011350747
Epoch 13/100, Loss: 0.011646198839270554
Epoch 14/100, Loss: 0.011642711237072945
Epoch 15/100, Loss: 0.011572456939485127
Epoch 16/100, Loss: 0.011543767778173137
Epoch 17/100, Loss: 0.011488292160958717
Epoch 18/100, Loss: 0.011522910330063713
Epoch 19/100, Loss: 0.011416517397442375
Epoch 20/100, Loss: 0.01138936878649143
Epoch 21/100, Loss: 0.011301609242674166
Epoch 22/100, Loss: 0.011269349091576217
Epoch 23/100, Loss: 0.011179256077391652
Epoch 24/100, Loss: 0.0

In [7]:
import torch
import torchvision
from torchviz import make_dot

# Choose the device first, then build a fresh ConvAutoencoder directly on it.
# NOTE: this rebinds the notebook-level `model` to a newly initialised
# (untrained) instance.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvAutoencoder().to(device)

# Random stand-in for a single 1-channel 256x256 image; swap in real data if desired
x = torch.randn(1, 1, 256, 256).to(device)

# Run one forward pass so autograd records the graph to be drawn
out = model(x)

# Render the recorded computation graph to disk with torchviz
dot = make_dot(out, params=dict(model.named_parameters()))
dot.format = 'png'  # other formats such as 'pdf' or 'svg' work too
dot.render("conv_autoencoder_graph")


'conv_autoencoder_graph.png'

In [8]:
# ==============================================================================
# PREDICTION FUNCTION
# ==============================================================================
def predict_and_save(model, dataset_loader, save_dir, device):
    """Colorize every sample from ``dataset_loader`` and save it as an image.

    For each sample the model's predicted AB channels are combined with the
    input L channel, converted from LAB to RGB and written to ``save_dir``
    under the base name of the corresponding source file.

    Args:
        model: Network mapping an L batch to an AB batch; switched to eval mode.
        dataset_loader: Loader yielding ``(L, AB)`` batches whose ``.dataset``
            exposes ``image_files``. Any batch size is supported. Output names
            are taken from ``image_files`` in iteration order, so the loader
            must not shuffle.
        save_dir: Output directory; created if missing.
        device: Device on which the forward pass runs.
    """
    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    model.eval()  # Set the model to evaluation mode

    image_files = dataset_loader.dataset.image_files
    # Running sample index: the batch index alone only matches the file index
    # when batch_size == 1.
    file_idx = 0

    with torch.no_grad():  # No need to track gradients for predictions
        for L, _ in dataset_loader:
            # Undo the /100 and /128 scaling on CPU NumPy arrays.
            # assumes L is (B, 1, H, W) and the model returns (B, 2, H, W),
            # as produced by the dataset/model in this notebook.
            AB_pred = model(L.to(device)).cpu().numpy() * 128
            L_batch = L.cpu().numpy() * 100

            # Handle each sample of the batch separately instead of relying
            # on squeeze(), which only works for a batch of one.
            for l_sample, ab_sample in zip(L_batch, AB_pred):
                pred_lab_image = np.stack((l_sample[0], ab_sample[0], ab_sample[1]), axis=2)
                pred_rgb_image = lab2rgb(pred_lab_image)

                # Save the colorized image under the original file's name
                original_filename = image_files[file_idx]
                save_path = os.path.join(save_dir, os.path.basename(original_filename))
                plt.imsave(save_path, pred_rgb_image)
                file_idx += 1


In [9]:
def plot_image_comparisons(base_dir, dirs, n_images=10):
    """Show a grid comparing same-named images across several directories.

    One row is drawn per file found in the first directory of ``dirs`` (at
    most ``n_images`` rows, in sorted filename order) and one column per entry
    of ``dirs``. A white placeholder is shown where a directory lacks the file.

    Args:
        base_dir: Directory containing every sub-directory listed in ``dirs``.
        dirs: Mapping of column label -> sub-directory name. The first entry
            determines which filenames are shown.
        n_images: Maximum number of rows; values below 1 are treated as 1.

    Raises:
        ValueError: If ``dirs`` is empty.
    """
    if not dirs:
        raise ValueError("dirs must contain at least one label -> sub-directory entry")

    # Ensure the number of images is a positive integer
    n_images = max(1, n_images)

    # Filenames come from the first directory only; sorted for a stable order
    first_dir_path = os.path.join(base_dir, next(iter(dirs.values())))
    filenames = sorted(
        f for f in os.listdir(first_dir_path)
        if os.path.isfile(os.path.join(first_dir_path, f))
    )[:n_images]

    if not filenames:
        # A zero-row grid cannot be created; report and stop instead.
        print(f"No images found in {first_dir_path}")
        return

    # Size the grid by the files actually found so no blank rows are drawn;
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    n_rows = len(filenames)
    fig, axes = plt.subplots(nrows=n_rows, ncols=len(dirs),
                             figsize=(15, 5 * n_rows), squeeze=False)

    # Load and plot each image
    for i, filename in enumerate(filenames):
        for j, (label, sub_dir) in enumerate(dirs.items()):
            img_path = os.path.join(base_dir, sub_dir, filename)
            ax = axes[i, j]
            if os.path.exists(img_path):
                with Image.open(img_path) as img:
                    # Convert to RGB so single-channel (grayscale) files are
                    # not rendered through matplotlib's default colormap.
                    ax.imshow(img.convert('RGB'))
            else:
                # Display a placeholder if the image does not exist
                placeholder = Image.new('RGB', (256, 256), color='white')
                ax.text(0.5, 0.5, 'Image not available', fontsize=12,
                        ha='center', va='center', transform=ax.transAxes)
                ax.imshow(placeholder)
            ax.set_title(f'{label} - {filename}')
            ax.axis('off')

    plt.tight_layout()
    plt.show()



In [None]:
if __name__ == "__main__":
    # ==============================================================================
    # Initialize the device
    # ==============================================================================
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ==============================================================================
    # Load the model
    # ==============================================================================
    model = ConvAutoencoder()
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine.
    model.load_state_dict(torch.load('autoencoder_model.pth', map_location=device))
    model.to(device)  # Move model to the designated device

    # ==============================================================================
    # Define a function to handle dataset predictions and saving
    # ==============================================================================
    def process_dataset(data_path, save_path):
        """Colorize every image under ``data_path`` and write results to ``save_path``."""
        pred_transform = Compose([
            Resize((256, 256)),
            ToTensor()
        ])
        pred_dataset = LabColorizationDataset(data_path, transform=pred_transform)
        # shuffle=False keeps loader order aligned with the dataset's file
        # list, which predict_and_save uses to name its outputs.
        pred_loader = DataLoader(pred_dataset, batch_size=1, shuffle=False)
        predict_and_save(model, pred_loader, save_path, device)

    # ==============================================================================
    # Process each dataset
    # ==============================================================================
    # Maps input sub-directory -> output sub-directory (both under dataset_2/data)
    datasets_info = {
        'pred_black': 'pred_pytorch',
        'pred_black_train': 'pred_pytorch_train',
        'pred_black_test': 'pred_pytorch_test'
    }

    for dataset, save_folder in datasets_info.items():
        data_path = f'dataset_2/data/{dataset}'
        save_path = f'dataset_2/data/{save_folder}'
        process_dataset(data_path, save_path)

    # ==============================================================================
    # Plotting comparison results
    # ==============================================================================
    base_dir = 'dataset_2/data'
    dirs_test1 = {
        'Original': 'pred_black',
        'Official Colored': 'pred_color',
        'Model Output': 'pred_pytorch'
    }
    dirs_test2 = {
        'Original': 'pred_black_train',
        'Official Colored': 'pred_color_train',
        'Model Output': 'pred_pytorch_train'
    }
    dirs_test3 = {
        'Original': 'pred_black_test',
        'Official Colored': 'pred_color_test',
        'Model Output': 'pred_pytorch_test'
    }

    # (heading, column mapping, max rows) for each comparison figure
    comparisons = [
        ("Displaying Test Images Comparison:", dirs_test1, 9),
        ("Displaying Training Images Comparison:", dirs_test2, 10),
        ("Displaying Additional Test Images Comparison:", dirs_test3, 10),
    ]
    for heading, dirs_for_plot, max_rows in comparisons:
        print(heading)
        plot_image_comparisons(base_dir, dirs_for_plot, n_images=max_rows)
