In [None]:
import os
import shutil
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F

In [None]:
def delete_all_in_folder(folder_path):
    """Empty ``folder_path`` without removing the folder itself.

    Regular files and symlinks are unlinked; sub-directories are removed
    recursively. A failure on one entry is printed and does not stop the
    remaining entries from being processed. If ``folder_path`` does not
    exist, a message is printed and nothing else happens.
    """
    if not os.path.exists(folder_path):
        print(f'The folder {folder_path} does not exist.')
        return

    for entry_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, entry_name)
        try:
            # Symlinks are unlinked rather than followed, even when they point at a directory.
            unlinkable = os.path.isfile(file_path) or os.path.islink(file_path)
            if unlinkable:
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            # Best effort: report the failure and carry on with the next entry.
            print(f'Failed to delete {file_path}. Reason: {e}')


# Clear the Kaggle working directory before the run.
delete_all_in_folder('/kaggle/working/')

In [None]:
def set_seed(seed_value: int):
    """
    Seed every random number generator used in this notebook and make cuDNN deterministic.
    Args:
    - seed_value (int): Seed applied to Python's `random`, NumPy, and PyTorch (CPU and CUDA).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_value)

    # Disable the cuDNN autotuner and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_value = 42
set_seed(seed_value)

In [None]:
class CIFAR100ColorizationDataset(Dataset):
    """Image-folder dataset yielding ``(grayscale_input, colour_target)`` pairs.

    ``folder`` is expected to contain one sub-directory per class; every
    ``.jpg`` / ``.jpeg`` / ``.png`` file (case-insensitive) found directly
    inside those sub-directories becomes a sample. Files placed directly in
    ``folder`` are ignored.

    Args:
        folder: Root directory holding the per-class sub-directories.
        transform: Optional callable applied to the RGB PIL image. If its
            result is not a tensor (or no transform is given) the image is
            converted with ``to_tensor`` so the grayscale step always works.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, folder, transform=None):
        self.folder = folder
        self.transform = transform
        self.samples = []
        for subdir in sorted(os.listdir(folder)):
            subdir_path = os.path.join(folder, subdir)
            if not os.path.isdir(subdir_path):
                continue
            # os.listdir order is filesystem-dependent; sort so that sample
            # indices (and therefore any seeded split) are reproducible.
            for img_name in sorted(os.listdir(subdir_path)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append(os.path.join(subdir_path, img_name))

    def __len__(self):
        """Return the number of image files discovered under ``folder``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(gray_image, image)``.

        ``image`` is the (transformed) RGB tensor; ``gray_image`` is its
        single-channel grayscale version repeated to three channels.
        """
        img_path = self.samples[idx]
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        if not torch.is_tensor(image):
            # Without a tensor-producing transform the .repeat() below would fail.
            image = transforms.functional.to_tensor(image)
        gray_image = transforms.functional.rgb_to_grayscale(image)
        gray_image = gray_image.repeat(3, 1, 1)
        return gray_image, image
class ColorizationResNet(nn.Module):
    """Encoder/decoder colourisation network built on a ResNet-50 trunk.

    The encoder is a randomly initialised torchvision ResNet-50 with its
    pooling and classification head removed, split into five stages. Each
    decoder stage is a stride-2 transposed convolution followed by batch norm
    and ReLU; stages 4 to 1 take the previous decoder output concatenated with
    the matching encoder feature map (skip connection). The head maps the
    result to 3 channels, resizes it to ``output_size`` and applies a sigmoid,
    so predictions lie in [0, 1].

    Args:
        output_size: ``(height, width)`` of the predicted image. Defaults to
            ``(224, 224)``, the size that was previously hard-coded.
    """

    def __init__(self, output_size=(224, 224)):
        super().__init__()
        # weights=None is the documented way to request random initialisation;
        # a boolean here goes through torchvision's deprecated legacy path.
        resnet = models.resnet50(weights=None)
        # Drop the last two children (global pooling and the fully connected head).
        self.encoder_layers = list(resnet.children())[:-2]
        self.encoder1 = nn.Sequential(*self.encoder_layers[:4])
        self.encoder2 = self.encoder_layers[4]
        self.encoder3 = self.encoder_layers[5]
        self.encoder4 = self.encoder_layers[6]
        self.encoder5 = self.encoder_layers[7]

        # Input channels of decoder4..decoder1 include the concatenated skip feature map.
        self.decoder5 = self._up_block(2048, 1024)
        self.decoder4 = self._up_block(1024 + 1024, 512)
        self.decoder3 = self._up_block(512 + 512, 256)
        self.decoder2 = self._up_block(256 + 256, 128)
        self.decoder1 = self._up_block(128 + 64, 64)

        self.final_layer = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Upsample(size=tuple(output_size), mode='bilinear', align_corners=False),
            nn.Sigmoid()
        )

    @staticmethod
    def _up_block(in_channels, out_channels):
        """Build one decoder stage: stride-2 transposed conv, batch norm, ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    @staticmethod
    def _resize_like(tensor, reference):
        """Bilinearly resize ``tensor`` to the spatial size of ``reference``."""
        return nn.functional.interpolate(tensor, size=reference.shape[2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        """Return the predicted 3-channel image for the batch ``x``."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        enc5 = self.encoder5(enc4)

        # Each decoder output is resized to its skip partner before concatenation
        # so the channel-wise cat is always shape-compatible.
        dec5 = self._resize_like(self.decoder5(enc5), enc4)
        dec4 = self._resize_like(self.decoder4(torch.cat([dec5, enc4], dim=1)), enc3)
        dec3 = self._resize_like(self.decoder3(torch.cat([dec4, enc3], dim=1)), enc2)
        dec2 = self._resize_like(self.decoder2(torch.cat([dec3, enc2], dim=1)), enc1)
        dec1 = self.decoder1(torch.cat([dec2, enc1], dim=1))
        output = self.final_layer(dec1)
        return output

In [None]:
def denormalize(image, mean=None, std=None):
    """Undo a per-channel normalisation and clamp the result to [0, 1].

    Args:
        image: Tensor whose last three dimensions are (channels, height, width).
        mean: Optional scalar or per-channel sequence added back after scaling.
        std: Optional scalar or per-channel sequence the image is multiplied by.

    Returns:
        A new tensor equal to ``image * std + mean`` (each step skipped when the
        statistic is ``None``), clamped to [0, 1]. With both left at ``None``
        this is a plain clamp, which is the previous behaviour.
    """
    def _per_channel(stat):
        # Shape (C, 1, 1) broadcasts over (C, H, W) and (B, C, H, W) alike.
        return torch.as_tensor(stat, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)

    if std is not None:
        image = image * _per_channel(std)
    if mean is not None:
        image = image + _per_channel(mean)
    return torch.clamp(image, 0, 1)
def plot_examples(dataset, num_examples=10):
    """Show randomly chosen (colour, grayscale) pairs from ``dataset``.

    Args:
        dataset: Indexable dataset returning ``(gray_image, color_image)``
            tensors in channel-first layout.
        num_examples: Number of rows to draw. Indices are sampled with
            replacement, so the same image may appear more than once.
    """
    # squeeze=False keeps `axes` two-dimensional so axes[i, j] also works for one row.
    fig, axes = plt.subplots(num_examples, 2, figsize=(10, 5 * num_examples), squeeze=False)
    fig.suptitle('Original Color Image and Grayscale Image')
    for i in range(num_examples):
        idx = random.randint(0, len(dataset) - 1)
        gray_image, color_image = dataset[idx]
        # matplotlib expects channel-last (H, W, C) arrays.
        color_image_np = color_image.permute(1, 2, 0).numpy()
        gray_image_np = gray_image.permute(1, 2, 0).numpy()
        axes[i, 0].imshow(color_image_np)
        axes[i, 0].set_title('Original Color Image')
        axes[i, 0].axis('off')
        # Fix the colour limits; otherwise imshow stretches each image to its own min/max.
        axes[i, 1].imshow(gray_image_np[..., 0], cmap='gray', vmin=0, vmax=1)
        axes[i, 1].set_title('Grayscale Image')
        axes[i, 1].axis('off')
    plt.tight_layout()
    plt.show()
def plot_predictions(model, dataset, device, num_examples=5):
    """Plot original, grayscale input and model prediction for random samples.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` again.

    Args:
        model: Network mapping a batch of grayscale inputs to colour images.
        dataset: Indexable dataset returning ``(gray_image, color_image)`` tensors.
        device: Device the model lives on; inputs are moved there.
        num_examples: Number of rows to draw (indices sampled with replacement).
    """
    model.eval()
    # squeeze=False keeps `axes` two-dimensional so axes[i, j] also works for one row.
    fig, axes = plt.subplots(num_examples, 3, figsize=(15, 5 * num_examples), squeeze=False)
    fig.suptitle('Original, Grayscale, and Predicted Images')
    with torch.no_grad():
        for i in range(num_examples):
            idx = random.randint(0, len(dataset) - 1)
            gray_image, color_image = dataset[idx]
            # Add the batch dimension the model expects, then drop it again.
            gray_image = gray_image.unsqueeze(0).to(device)
            predicted_image = model(gray_image)
            predicted_image = predicted_image.squeeze(0).cpu()
            color_image_np = color_image.permute(1, 2, 0).numpy()
            predicted_image_np = predicted_image.permute(1, 2, 0).numpy()
            gray_image_np = gray_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
            axes[i, 0].imshow(color_image_np)
            axes[i, 0].set_title('Original Color Image')
            axes[i, 0].axis('off')
            # Fix the colour limits; otherwise imshow stretches each image to its own min/max.
            axes[i, 1].imshow(gray_image_np[..., 0], cmap='gray', vmin=0, vmax=1)
            axes[i, 1].set_title('Grayscale Image')
            axes[i, 1].axis('off')
            axes[i, 2].imshow(predicted_image_np)
            axes[i, 2].set_title('Predicted Color Image')
            axes[i, 2].axis('off')
    plt.tight_layout()
    plt.show()
    # Called repeatedly during training; release the figure so memory does not accumulate.
    plt.close(fig)

In [None]:
def save_predictions(model, dataset, device, epoch, num_examples=5, save_dir='/kaggle/working'):
    """Write grayscale input, prediction and original for random samples as PNGs.

    Files are stored under ``<save_dir>/epoch<epoch>/`` as
    ``example_<k>_gray.png``, ``example_<k>_predicted.png`` and
    ``example_<k>_original.png`` with ``k`` starting at 1. The model is
    switched to eval mode and left there.

    Args:
        model: Network mapping a batch of grayscale inputs to colour images.
        dataset: Indexable dataset returning ``(gray_image, color_image)`` tensors.
        device: Device the model lives on; inputs are moved there.
        epoch: Value used verbatim in the output directory name.
        num_examples: Number of samples to save (indices sampled with replacement).
        save_dir: Parent directory for the per-epoch folder.
    """
    model.eval()
    epoch_dir = os.path.join(save_dir, f'epoch{epoch}')
    os.makedirs(epoch_dir, exist_ok=True)
    with torch.no_grad():
        for i in range(num_examples):
            idx = random.randint(0, len(dataset) - 1)
            gray_image, color_image = dataset[idx]
            gray_image = gray_image.unsqueeze(0).to(device)
            predicted_image = model(gray_image)
            predicted_image = predicted_image.squeeze(0).cpu()
            # matplotlib expects channel-last (H, W, C) arrays.
            gray_image_np = gray_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
            predicted_image_np = predicted_image.permute(1, 2, 0).numpy()
            color_image_np = color_image.permute(1, 2, 0).numpy()
            gray_image_path = os.path.join(epoch_dir, f'example_{i+1}_gray.png')
            predicted_image_path = os.path.join(epoch_dir, f'example_{i+1}_predicted.png')
            original_image_path = os.path.join(epoch_dir, f'example_{i+1}_original.png')
            # Without vmin/vmax, imsave rescales a 2-D array to its own min/max,
            # so the saved file would not show the true grayscale intensities.
            plt.imsave(gray_image_path, gray_image_np[..., 0], cmap='gray', vmin=0, vmax=1)
            plt.imsave(predicted_image_path, predicted_image_np)
            plt.imsave(original_image_path, color_image_np)
    print(f"Saved predictions for epoch {epoch} in {epoch_dir}")

In [None]:
def train_and_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs=25):
    """Train ``model`` and validate after every epoch, keeping the best weights.

    After each epoch the sample-weighted mean training and validation losses
    are logged. Whenever the validation loss improves, the state dict is saved
    to ``bestmodel.pth`` in the current working directory. Example predictions
    drawn from the validation dataset are plotted and saved each epoch.

    Args:
        model: Network to optimise; batches are moved to the device of its parameters.
        train_loader: DataLoader yielding ``(gray_images, color_images)`` training batches.
        val_loader: DataLoader yielding validation batches; its ``dataset`` is
            also used for the per-epoch example predictions.
        criterion: Loss called as ``criterion(outputs, color_images)``.
        optimizer: Optimiser already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    # Derive these from the arguments instead of relying on notebook globals.
    device = next(model.parameters()).device
    preview_set = val_loader.dataset

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", leave=False)
        for gray_images, color_images in progress_bar:
            gray_images, color_images = gray_images.to(device), color_images.to(device)
            optimizer.zero_grad()
            outputs = model(gray_images)
            loss = criterion(outputs, color_images)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch does not skew the epoch mean.
            running_loss += loss.item() * gray_images.size(0)
            progress_bar.set_postfix(loss=loss.item())
        epoch_loss = running_loss / len(train_loader.dataset)
        tqdm.write(f'Training Loss: {epoch_loss:.4f}')

        model.eval()
        running_val_loss = 0.0
        progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", leave=False)
        with torch.no_grad():
            for gray_images, color_images in progress_bar:
                gray_images, color_images = gray_images.to(device), color_images.to(device)
                outputs = model(gray_images)
                loss = criterion(outputs, color_images)
                running_val_loss += loss.item() * gray_images.size(0)
                progress_bar.set_postfix(loss=loss.item())
        epoch_val_loss = running_val_loss / len(val_loader.dataset)
        tqdm.write(f'Validation Loss: {epoch_val_loss:.5f}')

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), 'bestmodel.pth')
            tqdm.write(f"Best model saved with validation loss: {best_val_loss:.5f}")

        # Both helpers leave the model in eval mode; model.train() at the top of
        # the loop restores training mode. Note the saved folder uses the
        # zero-based epoch index while the progress bars are one-based.
        plot_predictions(model, preview_set, device, num_examples=5)
        save_predictions(model, preview_set, device, epoch, num_examples=5, save_dir='/kaggle/working')
def test(model, test_loader, criterion):
    """Evaluate the best saved checkpoint on ``test_loader`` and print the mean loss.

    Loads ``bestmodel.pth`` from the current working directory into ``model``
    (overwriting its current weights), switches it to eval mode and reports
    the sample-weighted mean of ``criterion`` over the whole test dataset.

    Args:
        model: Network whose architecture matches the saved state dict.
        test_loader: DataLoader yielding ``(gray_images, color_images)`` batches.
        criterion: Loss called as ``criterion(outputs, color_images)``.
    """
    # Use the model's own device rather than a notebook global.
    device = next(model.parameters()).device
    # map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load('bestmodel.pth', map_location=device))
    model.eval()
    running_test_loss = 0.0
    progress_bar = tqdm(test_loader, desc="Testing")
    with torch.no_grad():
        for gray_images, color_images in progress_bar:
            gray_images, color_images = gray_images.to(device), color_images.to(device)
            outputs = model(gray_images)
            loss = criterion(outputs, color_images)
            # Weight by batch size so a smaller final batch does not skew the mean.
            running_test_loss += loss.item() * gray_images.size(0)
            progress_bar.set_postfix(loss=loss.item())
    test_loss = running_test_loss / len(test_loader.dataset)
    print(f'Test Loss: {test_loss:.5f}')

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Images are only resized and converted to [0, 1] tensors. No mean/std
# normalisation is applied: the model ends in a Sigmoid and is trained with MSE
# against these tensors, and the plotting helpers display them directly.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = CIFAR100ColorizationDataset(folder='/kaggle/input/cifar100/cifar100/train', transform=transform)
test_dataset = CIFAR100ColorizationDataset(folder='/kaggle/input/cifar100/cifar100/test', transform=transform)

# The class label of a sample is the name of the directory that contains it.
class_labels = [os.path.basename(os.path.dirname(path)) for path in train_dataset.samples]
# 80/20 split that preserves class proportions; seeded explicitly so it does
# not depend on how much global NumPy randomness was consumed beforehand.
train_idx, val_idx = train_test_split(
    list(range(len(train_dataset))),
    test_size=0.20,
    stratify=class_labels,
    random_state=seed_value,
)
train_set = Subset(train_dataset, train_idx)
val_set = Subset(train_dataset, val_idx)
print(f'Training set size: {len(train_set)}')
print(f'Validation set size: {len(val_set)}')
print(f'Test set size: {len(test_dataset)}')

batch_size = 64
# Never request more workers than the machine has CPUs.
num_workers = min(8, os.cpu_count() or 1)
# Pinned host memory only helps when batches are copied to a CUDA device.
pin_memory = device.type == "cuda"
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

model = ColorizationResNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Visual sanity check of the input/target pairs before training starts.
plot_examples(val_set, num_examples=10)

In [None]:
# Train for 80 epochs; the best weights by validation loss are written to bestmodel.pth.
train_and_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs=80)

In [None]:
# Reload bestmodel.pth into `model` and report the loss on the held-out test set.
test(model, test_loader, criterion)