In [3]:
import os
import json
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
from datasets import load_dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== Dataset Preparation ==================
transform_grayscale = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
transform_colorization_target = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
transform_hr = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

class CombinedDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_split):
        self.dataset_split = dataset_split

    def __len__(self):
        return len(self.dataset_split)

    def __getitem__(self, idx):
        image = self.dataset_split[idx]["image"].convert("RGB")
        grayscale_image = image.convert("L")  # Grayscale
        colorization_target = transform_colorization_target(image)  # Downsampled for colorization
        hr_image = transform_hr(image)  # High-Resolution
        return transform_grayscale(grayscale_image), colorization_target, hr_image

# ================== Model Definitions ==================
class UNet(nn.Module):
    def __init__(self, output_size=(64, 64)):
        super(UNet, self).__init__()
        self.output_size = output_size
        self.encoder1 = self.conv_block(1, 32)
        self.encoder2 = self.conv_block(32, 64)
        self.encoder3 = self.conv_block(64, 128)
        self.encoder4 = self.conv_block(128, 256)
        self.bottleneck = self.conv_block(256, 512)
        self.decoder4 = self.upconv_block(512, 256)
        self.decoder3 = self.upconv_block(256, 128)
        self.decoder2 = self.upconv_block(128, 64)
        self.decoder1 = self.upconv_block(64, 32)
        self.final_conv = nn.Conv2d(32, 3, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(2, 2, ceil_mode=True)
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        bottleneck = self.bottleneck(enc4)
        dec4 = self.decoder4(bottleneck) + enc4
        dec3 = self.decoder3(dec4) + enc3
        dec2 = self.decoder2(dec3) + enc2
        dec1 = self.decoder1(dec2) + enc1
        output = self.final_conv(dec1)
        return torch.sigmoid(F.interpolate(output, size=self.output_size, mode="bilinear", align_corners=False))

class ResidualUNet(UNet):
    def __init__(self, input_channels=3, output_size=(128, 128)):
        super().__init__(output_size=output_size)
        self.encoder1 = self.conv_block(input_channels, 32)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        bottleneck = self.bottleneck(enc4)
        dec4 = self.decoder4(bottleneck) + enc4
        dec3 = self.decoder3(dec4) + enc3
        dec2 = self.decoder2(dec3) + enc2
        dec1 = self.decoder1(dec2) + enc1
        output = self.final_conv(dec1)
        output = F.interpolate(output, size=self.output_size, mode="bilinear", align_corners=False)
        x_resized = F.interpolate(x, size=output.shape[2:], mode="bilinear", align_corners=False)
        return torch.sigmoid(output) + x_resized

# ================== Helper Functions ==================
def calculate_ssim(pred, target, win_size=3):
    pred = pred.permute(0, 2, 3, 1).cpu().numpy()
    target = target.permute(0, 2, 3, 1).cpu().numpy()
    ssim_scores = [ssim(target[i], pred[i], multichannel=True, data_range=1.0, win_size=win_size) for i in range(pred.shape[0])]
    return np.mean(ssim_scores)

def visualize_results(input_image, predicted_output, target_output, hr_image, save_path='', is_colorization=True):
    """
    Visualize and save results of the training/testing phase.
    Args:
        input_image: Tensor of the input image (grayscale or low-resolution).
        predicted_output: Tensor of the model's predicted output.
        target_output: Tensor of the target output.
        hr_image: Tensor of the high-resolution ground truth.
        save_path: Path to save the visualization.
        is_colorization: Boolean indicating if the task is colorization or super-resolution.
    """
    # Squeeze batch dimensions for visualization
    input_image = input_image[0] if input_image.dim() == 4 else input_image
    predicted_output = predicted_output[0] if predicted_output.dim() == 4 else predicted_output
    target_output = target_output[0] if target_output.dim() == 4 else target_output
    hr_image = hr_image[0] if hr_image.dim() == 4 else hr_image

    # Convert tensors to NumPy arrays
    input_np = input_image.squeeze(0).cpu().numpy() if is_colorization else input_image.permute(1, 2, 0).cpu().numpy()
    predicted_np = predicted_output.permute(1, 2, 0).cpu().numpy()
    target_np = target_output.permute(1, 2, 0).cpu().numpy()
    hr_np = hr_image.permute(1, 2, 0).cpu().numpy()

    # Plot and save visualizations
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    if is_colorization:
        axs[0].imshow(input_np, cmap='gray')
        axs[0].set_title("Grayscale Input")
    else:
        axs[0].imshow(input_np)
        axs[0].set_title("Low-Resolution Input")

    axs[1].imshow(predicted_np)
    axs[1].set_title("Predicted Output")
    axs[2].imshow(target_np)
    axs[2].set_title("Target Output")
    axs[3].imshow(hr_np)
    axs[3].set_title("Ground Truth High-Res")
    plt.savefig(save_path)
    plt.close()



# ================== Training ==================
def train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, model_name, is_colorization=True):
    os.makedirs(f"results/visualizations/{model_name}", exist_ok=True)
    os.makedirs(f"results/json", exist_ok=True)

    results = {"epochs": []}

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()
        running_loss = 0.0
        batch_count = 0  # Counter for batches

        for batch in tqdm(train_loader, desc="Training"):
            batch_count += 1
            inputs = batch[0].to(device) if is_colorization else batch[1].to(device)
            targets = batch[1].to(device) if is_colorization else batch[2].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_loss = running_loss / batch_count  # Use batch_count instead of len(train_loader)

        # Validation Phase
        model.eval()
        val_loss, val_ssim = 0.0, 0.0
        with torch.no_grad():
            for grayscale, colorization_target, hr_image in tqdm(val_loader, desc="Validating"):
                inputs = grayscale.to(device) if is_colorization else colorization_target.to(device)
                targets = colorization_target.to(device) if is_colorization else hr_image.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item()
                val_ssim += calculate_ssim(outputs, targets)

        val_loss /= len(val_loader)
        val_ssim /= len(val_loader)

        print(f"Epoch {epoch + 1}/{num_epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val SSIM: {val_ssim:.4f}")

        results["epochs"].append({"epoch": epoch + 1, "train_loss": train_loss, "val_loss": val_loss, "val_ssim": val_ssim})

        # Save visualizations for one batch
        visualize_results(
            inputs[0].cpu(),
            outputs[0].cpu(),
            targets[0].cpu(),
            hr_image[0].cpu(),
            save_path=f"results/visualizations/{model_name}/epoch_{epoch + 1}.png",
            is_colorization=is_colorization,
        )

    # Save results to JSON
    with open(f"results/json/{model_name}_metrics.json", "w") as f:
        json.dump(results, f, indent=4)

# ================== Testing ==================
def test_model(model, test_loader, criterion, model_name, is_colorization=True):
    model.eval()
    test_loss, test_ssim = 0.0, 0.0
    os.makedirs(f"results/test_visualizations/{model_name}", exist_ok=True)

    with torch.no_grad():
        for idx, (grayscale, colorization_target, hr_image) in enumerate(tqdm(test_loader, desc="Testing")):
            inputs = grayscale.to(device) if is_colorization else colorization_target.to(device)
            targets = colorization_target.to(device) if is_colorization else hr_image.to(device)
            outputs = model(inputs)
            test_loss += criterion(outputs, targets).item()
            test_ssim += calculate_ssim(outputs, targets)

            if idx < 5:  # Save visualizations for the first 5 batches
                save_path = f"results/test_visualizations/{model_name}/batch_{idx + 1}.png"
                visualize_results(inputs[0].cpu(), outputs[0].cpu(), targets[0].cpu(), hr_image[0].cpu(), save_path, is_colorization=is_colorization)

    test_loss /= len(test_loader)
    test_ssim /= len(test_loader)
    print(f"Test Loss: {test_loss:.4f}, Test SSIM: {test_ssim:.4f}")

    # Save test metrics
    test_metrics = {"test_loss": test_loss, "test_ssim": test_ssim}
    with open(f"results/json/{model_name}_test_results.json", "w") as f:
        json.dump(test_metrics, f, indent=4)

# ================== Main ==================
if __name__ == "__main__":
    dataset = load_dataset("zh-plus/tiny-imagenet")["train"]
    train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset))
    test_size = len(dataset) - train_size - val_size
    train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size])

    train_loader = DataLoader(CombinedDataset(train_data), batch_size=32, shuffle=True)
    val_loader = DataLoader(CombinedDataset(val_data), batch_size=32, shuffle=False)
    test_loader = DataLoader(CombinedDataset(test_data), batch_size=32, shuffle=False)

    # Train Colorization Model
    colorization_model = UNet(output_size=(64, 64)).to(device)
    optimizer = torch.optim.AdamW(colorization_model.parameters(), lr=1e-4, weight_decay=1e-5)
    criterion = nn.MSELoss()
    train_model(colorization_model, train_loader, val_loader, optimizer, criterion, num_epochs=10, model_name="colorization", is_colorization=True)

    # Super-Resolution Training Data Generator
    def super_resolution_train_loader():
      for grayscale, _, hr_image in train_loader:

        with torch.no_grad():
            colorized_output = colorization_model(grayscale.to(device)).cpu()
        yield colorized_output, colorized_output, hr_image  # Provide a consistent tuple


    # Train Super-Resolution Model
    sr_model = ResidualUNet(input_channels=3, output_size=(128, 128)).to(device)
    optimizer = torch.optim.AdamW(sr_model.parameters(), lr=1e-4, weight_decay=1e-5)
    train_model(sr_model, super_resolution_train_loader(), val_loader, optimizer, criterion, num_epochs=10, model_name="super_resolution", is_colorization=False)

    # Testing
    print("Testing Colorization Model...")
    test_model(colorization_model, test_loader, criterion, model_name="colorization", is_colorization=True)
    print("Testing Super-Resolution Model...")
    test_model(sr_model, test_loader, criterion, model_name="super_resolution", is_colorization=False)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/3.90k [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/3.52k [00:00<?, ?B/s]

(…)-00000-of-00001-1359597a978bc4fa.parquet:   0%|          | 0.00/146M [00:00<?, ?B/s]

(…)-00000-of-00001-70d52db3c749a935.parquet:   0%|          | 0.00/14.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/100000 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Epoch 1/10


Training: 100%|██████████| 2500/2500 [02:38<00:00, 15.74it/s]
Validating: 100%|██████████| 313/313 [00:28<00:00, 10.81it/s]


Epoch 1/10 - Train Loss: 0.0184, Val Loss: 0.0152, Val SSIM: 0.4594
Epoch 2/10


Training: 100%|██████████| 2500/2500 [02:36<00:00, 16.01it/s]
Validating: 100%|██████████| 313/313 [00:29<00:00, 10.59it/s]


Epoch 2/10 - Train Loss: 0.0154, Val Loss: 0.0147, Val SSIM: 0.4816
Epoch 3/10


Training: 100%|██████████| 2500/2500 [02:36<00:00, 15.96it/s]
Validating: 100%|██████████| 313/313 [00:28<00:00, 10.85it/s]


Epoch 3/10 - Train Loss: 0.0149, Val Loss: 0.0143, Val SSIM: 0.4971
Epoch 4/10


Training: 100%|██████████| 2500/2500 [02:37<00:00, 15.87it/s]
Validating: 100%|██████████| 313/313 [00:29<00:00, 10.75it/s]


Epoch 4/10 - Train Loss: 0.0144, Val Loss: 0.0141, Val SSIM: 0.4984
Epoch 5/10


Training: 100%|██████████| 2500/2500 [02:37<00:00, 15.89it/s]
Validating: 100%|██████████| 313/313 [00:29<00:00, 10.70it/s]


Epoch 5/10 - Train Loss: 0.0141, Val Loss: 0.0142, Val SSIM: 0.4986
Epoch 6/10


Training: 100%|██████████| 2500/2500 [02:36<00:00, 15.98it/s]
Validating: 100%|██████████| 313/313 [00:29<00:00, 10.47it/s]


Epoch 6/10 - Train Loss: 0.0137, Val Loss: 0.0140, Val SSIM: 0.5054
Epoch 7/10


Training: 100%|██████████| 2500/2500 [02:35<00:00, 16.09it/s]
Validating: 100%|██████████| 313/313 [00:28<00:00, 10.87it/s]


Epoch 7/10 - Train Loss: 0.0132, Val Loss: 0.0141, Val SSIM: 0.5070
Epoch 8/10


Training: 100%|██████████| 2500/2500 [02:36<00:00, 16.00it/s]
Validating: 100%|██████████| 313/313 [00:28<00:00, 10.92it/s]


Epoch 8/10 - Train Loss: 0.0126, Val Loss: 0.0143, Val SSIM: 0.5006
Epoch 9/10


Training: 100%|██████████| 2500/2500 [02:35<00:00, 16.06it/s]
Validating: 100%|██████████| 313/313 [00:29<00:00, 10.58it/s]


Epoch 9/10 - Train Loss: 0.0122, Val Loss: 0.0143, Val SSIM: 0.5017
Epoch 10/10


Training: 100%|██████████| 2500/2500 [02:35<00:00, 16.12it/s]
Validating: 100%|██████████| 313/313 [00:29<00:00, 10.77it/s]


Epoch 10/10 - Train Loss: 0.0117, Val Loss: 0.0144, Val SSIM: 0.4928
Epoch 1/10


Training: 2500it [02:50, 14.64it/s]
Validating: 100%|██████████| 313/313 [01:09<00:00,  4.53it/s]


Epoch 1/10 - Train Loss: 0.0488, Val Loss: 0.0022, Val SSIM: 0.9471
Epoch 2/10


Training: 0it [00:00, ?it/s]


ZeroDivisionError: float division by zero