In [2]:
!pip install gradio

Collecting gradio
  Downloading gradio-5.27.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.9.0 (from gradio)
  Downloading gradio_client-1.9.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6 (

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import os
import matplotlib.pyplot as plt
import numpy as np
import datetime
import gradio as gr
import pickle

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
  print(torch.cuda.current_device())
  print(torch.cuda.device(0))
  print(torch.cuda.device_count())
  print(torch.cuda.get_device_name(0))
else:
  print("No NVIDIA driver found. Using CPU")

# Define dataset for historical images
class HistoricalDataset(Dataset):
    def __init__(self, image_dir, transform=None, period_labels=None):
        self.image_dir = image_dir
        self.image_list = os.listdir(image_dir)
        self.transform = transform
        # Dictionary mapping image filenames to period labels (e.g., "1900s", "1950s")
        self.period_labels = period_labels if period_labels else {}

        # Default period mapping (if not provided explicitly)
        self.default_periods = {
            "1850-1900": 0,
            "1900-1940": 1,
            "1940-1970": 2,
            "1970-2000": 3,
            "2000-present": 4
        }

        # Extract year from filename if possible
        if not period_labels:
            for img_name in self.image_list:
                # Try to extract year from filename (assuming format like "1942_battle.jpg")
                try:
                    year_str = ''.join([c for c in img_name.split('_')[0] if c.isdigit()])
                    if year_str and len(year_str) == 4:
                        year = int(year_str)
                        if 1850 <= year < 1900:
                            self.period_labels[img_name] = 0
                        elif 1900 <= year < 1940:
                            self.period_labels[img_name] = 1
                        elif 1940 <= year < 1970:
                            self.period_labels[img_name] = 2
                        elif 1970 <= year < 2000:
                            self.period_labels[img_name] = 3
                        elif 2000 <= year:
                            self.period_labels[img_name] = 4
                except:
                    pass

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_name = self.image_list[idx]
        img_path = os.path.join(self.image_dir, img_name)
        color_img = Image.open(img_path).convert("RGB")

        if self.transform:
            color_img = self.transform(color_img)

        # Get period label if available, default to 0 (1850-1900) if not found
        period = self.period_labels.get(img_name, 0)

        return color_img, period

# CIFAR-10 based feature extractor (encoder)
class CIFAR10Encoder(nn.Module):
    def __init__(self):
        super(CIFAR10Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.relu(self.conv3(x))
        return x

# Time Period Classifier Network
class TimePeriodClassifier(nn.Module):
    def __init__(self, num_classes=5):
        super(TimePeriodClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Enhanced Colorization Network with CIFAR-10 pretrained encoder and Era-Specific Pathways
class EnhancedColorizationNet(nn.Module):
    def __init__(self, num_eras=5, pretrained_encoder=None):
        super(EnhancedColorizationNet, self).__init__()

        # The encoder base from CIFAR-10 model
        if pretrained_encoder:
            self.encoder = pretrained_encoder
        else:
            self.encoder = CIFAR10Encoder()

        # Era-specific pathways (decoders)
        self.era_pathways = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
                nn.Sigmoid()
            ) for _ in range(num_eras)
        ])

        # Era classifier branch
        self.era_classifier = TimePeriodClassifier(num_eras)

    def forward(self, x, era_idx=None):
        # Shared feature extraction through the encoder
        x_features = self.encoder(x)

        if era_idx is None:
            # If era_idx is not provided, use the classifier to predict it
            # Create a downsized version for the classifier
            x_small = nn.functional.interpolate(x, size=(32, 32), mode='bilinear', align_corners=True)
            era_logits = self.era_classifier(x_small)
            era_probs = torch.softmax(era_logits, dim=1)
            _, era_idx = torch.max(era_probs, dim=1)

            # Initialize output tensor
            outputs = []
            for i in range(len(self.era_pathways)):
                # Apply each era's colorization pathway
                era_output = self.era_pathways[i](x_features)
                outputs.append(era_output)

            # Stack outputs
            stacked_outputs = torch.stack(outputs, dim=1)

            # Create batch indices
            batch_size = x.size(0)
            batch_indices = torch.arange(batch_size, device=x.device)

            # Get era-specific outputs
            final_output = stacked_outputs[batch_indices, era_idx]

            return final_output, era_idx
        else:
            # If era_idx is provided as an integer, use that specific pathway
            if isinstance(era_idx, int):
                return self.era_pathways[era_idx](x_features), torch.tensor([era_idx], device=x.device)
            # If era_idx is a tensor, use the provided indices
            else:
                outputs = []
                for i in range(len(self.era_pathways)):
                    era_output = self.era_pathways[i](x_features)
                    outputs.append(era_output)

                stacked_outputs = torch.stack(outputs, dim=1)
                batch_size = x.size(0)
                batch_indices = torch.arange(batch_size, device=x.device)

                final_output = stacked_outputs[batch_indices, era_idx]
                return final_output, era_idx

# Helper functions
def rgb_to_gray(img):
    """Convert RGB image to grayscale."""
    return img.mean(dim=1, keepdim=True)

def torch_rgb_to_hsv(rgb):
    """Convert RGB image tensor to HSV format."""
    r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]
    max_val, _ = torch.max(rgb, dim=1)
    min_val, _ = torch.min(rgb, dim=1)
    diff = max_val - min_val

    # Compute H
    h = torch.zeros_like(r)
    mask = (max_val == r) & (g >= b)
    h[mask] = (g[mask] - b[mask]) / diff[mask]
    mask = (max_val == r) & (g < b)
    h[mask] = (g[mask] - b[mask]) / diff[mask] + 6.0
    mask = max_val == g
    h[mask] = (b[mask] - r[mask]) / diff[mask] + 2.0
    mask = max_val == b
    h[mask] = (r[mask] - g[mask]) / diff[mask] + 4.0
    h = h / 6.0
    h[diff == 0.0] = 0.0

    # Compute S
    s = torch.zeros_like(r)
    s[diff != 0.0] = diff[diff != 0.0] / max_val[diff != 0.0]

    # V is just max_val
    v = max_val

    return torch.stack([h, s, v], dim=1)

def torch_hsv_to_rgb(hsv):
    """Convert HSV image tensor to RGB format."""
    h, s, v = hsv[:, 0, :, :], hsv[:, 1, :, :], hsv[:, 2, :, :]
    i = (h * 6.0).floor()
    f = h * 6.0 - i
    p = v * (1.0 - s)
    q = v * (1.0 - s * f)
    t = v * (1.0 - s * (1.0 - f))

    i_mod = i % 6
    r = torch.zeros_like(h)
    g = torch.zeros_like(h)
    b = torch.zeros_like(h)

    r[i_mod == 0.0] = v[i_mod == 0.0]
    g[i_mod == 0.0] = t[i_mod == 0.0]
    b[i_mod == 0.0] = p[i_mod == 0.0]

    r[i_mod == 1.0] = q[i_mod == 1.0]
    g[i_mod == 1.0] = v[i_mod == 1.0]
    b[i_mod == 1.0] = p[i_mod == 1.0]

    r[i_mod == 2.0] = p[i_mod == 2.0]
    g[i_mod == 2.0] = v[i_mod == 2.0]
    b[i_mod == 2.0] = t[i_mod == 2.0]

    r[i_mod == 3.0] = p[i_mod == 3.0]
    g[i_mod == 3.0] = q[i_mod == 3.0]
    b[i_mod == 3.0] = v[i_mod == 3.0]

    r[i_mod == 4.0] = t[i_mod == 4.0]
    g[i_mod == 4.0] = p[i_mod == 4.0]
    b[i_mod == 4.0] = v[i_mod == 4.0]

    r[i_mod == 5.0] = v[i_mod == 5.0]
    g[i_mod == 5.0] = p[i_mod == 5.0]
    b[i_mod == 5.0] = q[i_mod == 5.0]

    return torch.stack([r, g, b], dim=1)

def apply_era_specific_enhancements(images, era_idx):
    """
    Apply era-specific color enhancements

    Parameters:
    - images: tensor of shape (batch_size, 3, height, width)
    - era_idx: tensor of shape (batch_size) with values 0-4 representing the era

    Returns:
    - enhanced_images: tensor with era-specific color adjustments
    """
    # Convert to [0,1] range
    images = torch.clamp(images, 0, 1)

    # Convert to HSV
    images_hsv = torch_rgb_to_hsv(images)
    batch_size = images.size(0)

    # Era-specific enhancements
    # 1850-1900: Sepia-like, brownish tones
    # 1900-1940: Lower saturation, slightly bluish
    # 1940-1970: Kodachrome look - vibrant but distinct palette
    # 1970-2000: Slightly oversaturated with warm tones
    # 2000-present: Modern digital look - balanced, accurate colors

    for i in range(batch_size):
        if era_idx[i] == 0:  # 1850-1900
            # Sepia effect
            images_hsv[i, 0, :, :] = 0.08  # Hue shift toward yellow-brown
            images_hsv[i, 1, :, :] *= 0.7  # Lower saturation
            images_hsv[i, 2, :, :] = torch.clamp(images_hsv[i, 2, :, :] * 0.9, 0, 1)  # Slightly darker

        elif era_idx[i] == 1:  # 1900-1940
            # Early film look
            images_hsv[i, 0, :, :] = torch.clamp(images_hsv[i, 0, :, :] + 0.05, 0, 1)  # Slight hue shift
            images_hsv[i, 1, :, :] *= 0.8  # Lower saturation
            # Add slight blue tint to shadows
            shadow_mask = images_hsv[i, 2, :, :] < 0.4
            images_hsv[i, 0, shadow_mask] = 0.6  # Blue-ish hue

        elif era_idx[i] == 2:  # 1940-1970
            # Kodachrome look
            images_hsv[i, 1, :, :] = torch.clamp(images_hsv[i, 1, :, :] * 1.2, 0, 1)  # More saturation
            # Enhance reds and yellows
            red_mask = (images_hsv[i, 0, :, :] < 0.05) | (images_hsv[i, 0, :, :] > 0.95)
            images_hsv[i, 1, red_mask] = torch.clamp(images_hsv[i, 1, red_mask] * 1.3, 0, 1)

        elif era_idx[i] == 3:  # 1970-2000
            # Film of the late 20th century
            images_hsv[i, 1, :, :] = torch.clamp(images_hsv[i, 1, :, :] * 1.1, 0, 1)  # Slightly increased saturation
            images_hsv[i, 0, :, :] = torch.clamp(images_hsv[i, 0, :, :] - 0.02, 0, 1)  # Slight warm shift

        # For era_idx 4 (2000-present), we keep colors as they are - digital look

    # Convert back to RGB
    enhanced_images = torch_hsv_to_rgb(images_hsv)
    return enhanced_images

# Function to load CIFAR-10 pretrained model
def load_cifar10_pretrained_model(pth_path):
    try:
        # Load the PyTorch model state dict
        cifar_state_dict = torch.load(cifar10_model.pth)

        # Create a model instance to load the state dict into
        cifar_model = CIFAR10Encoder()
        cifar_model.load_state_dict(cifar_state_dict)

        print(f"Successfully loaded CIFAR-10 model from {pth_path}")
        return cifar_model
    except Exception as e:
        print(f"Error loading CIFAR-10 model: {str(e)}")
        return None

# Function to extract encoder from CIFAR-10 model and adapt it
def extract_encoder_from_cifar10(cifar_model):
    try:
        # Create a new encoder model
        encoder = CIFAR10Encoder()

        # Try to map weights from CIFAR-10 model to our encoder
        # This is a placeholder - actual implementation depends on the structure of your pickle file

        # Option 1: If the pickle contains a state_dict
        if hasattr(cifar_model, 'state_dict'):
            # Map relevant layers
            encoder_dict = encoder.state_dict()
            cifar_dict = cifar_model.state_dict()

            # Mapping will depend on the actual structure of your CIFAR-10 model
            # This is a simplified example:
            for name, param in cifar_dict.items():
                if 'conv1' in name and name.replace('conv1', '') in encoder_dict:
                    encoder_dict[name] = param

            encoder.load_state_dict(encoder_dict)
            print("Loaded weights from CIFAR-10 model state_dict")

        # Option 2: If the pickle is the model itself or has a different structure
        elif hasattr(cifar_model, 'conv1'):
            # Transfer weights directly if the architecture is compatible
            encoder.conv1.weight.data = cifar_model.conv1.weight.data
            encoder.conv1.bias.data = cifar_model.conv1.bias.data
            # Similarly for other layers...
            print("Transferred weights directly from CIFAR-10 model attributes")

        else:
            print("CIFAR-10 model structure not recognized, using random initialization")

        return encoder
    except Exception as e:
        print(f"Error adapting CIFAR-10 model: {str(e)}")
        print("Returning an untrained encoder")
        return CIFAR10Encoder()

# Training function for CIFAR-10 dataset (in case we need to train from scratch)
def train_on_cifar10(encoder, num_epochs=40):
    print("Starting CIFAR-10 training...")

    # Data loading for CIFAR-10
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Load CIFAR-10 training set
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                            shuffle=True, num_workers=2)

    # Setup a simple classifier on top of the encoder for training
    class CIFAR10Classifier(nn.Module):
        def __init__(self, encoder):
            super(CIFAR10Classifier, self).__init__()
            self.encoder = encoder
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(256, 10)

        def forward(self, x):
            # CIFAR is RGB, our encoder expects grayscale - convert
            x_gray = rgb_to_gray(x)
            features = self.encoder(x_gray)
            x = self.avgpool(features)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            return x

    # Create classifier with the encoder
    classifier = CIFAR10Classifier(encoder).to(device)

    # Set up loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=0.001)

    # Training loop
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 200 == 199:
                print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200:.3f}')
                running_loss = 0.0

    print('Finished CIFAR-10 Training')

    # Save CIFAR-10 trained model if needed
    torch.save(encoder.state_dict(), 'cifar10_encoder.pth')

    return encoder

# Fine-tuning function
def fine_tune_model(model, train_loader, optimizer, criterion, num_epochs=40, freeze_encoder=True):
    model.train()

    # Optionally freeze the encoder layers
    if freeze_encoder:
        for param in model.encoder.parameters():
            param.requires_grad = False

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, period_labels) in enumerate(train_loader):
            grayscale_images = rgb_to_gray(images).to(device)
            images = images.to(device)
            period_labels = period_labels.to(device)

            optimizer.zero_grad()

            # Forward pass with era labels
            outputs, predicted_era = model(grayscale_images, period_labels)
            loss = criterion(outputs, images)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if i % 20 == 19:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/20:.4f}")
                running_loss = 0.0

    print("Finished Fine-tuning")
    return model

# Era labels
era_labels = [
    "1850-1900 (Victorian/Early Photography)",
    "1900-1940 (Early 20th Century)",
    "1940-1970 (Mid-Century)",
    "1970-2000 (Late 20th Century)",
    "2000-present (Digital Era)"
]

# Define the colorization function for Gradio
def colorize_image(input_img, era_selection):
    global model
    model.eval()

    # Map era selection string to index
    era_idx = era_labels.index(era_selection)

    # Convert input image to grayscale
    if input_img is None:
        return None, None, "No image provided"

    # Create grayscale version
    gray_img = Image.fromarray(input_img).convert("L")

    # Transform for model input
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    # Convert to tensor
    gray_tensor = transform(gray_img).unsqueeze(0).to(device)

    # Process with the model
    with torch.no_grad():
        colorized_tensor, predicted_era = model(gray_tensor, era_idx)

        # Apply era-specific enhancements
        enhanced_tensor = apply_era_specific_enhancements(colorized_tensor, predicted_era)

    # Convert tensor back to PIL image and then to numpy array
    colorized_img = transforms.ToPILImage()(enhanced_tensor.squeeze(0).cpu())

    # Resize back to original dimensions
    original_size = (input_img.shape[1], input_img.shape[0])  # Width, Height
    colorized_img = colorized_img.resize(original_size)

    # Convert to numpy array for Gradio
    colorized_array = np.array(colorized_img)

    # Create grayscale version for display
    gray_array = np.array(Image.fromarray(input_img).convert("L").convert("RGB"))

    # Get detected era for display
    if isinstance(predicted_era, torch.Tensor) and len(predicted_era) > 0:
        detected_era_idx = predicted_era[0].item()
        detected_era = f"Detected Era: {era_labels[detected_era_idx]}"
    else:
        detected_era = "Era detection failed"

    return gray_array, colorized_array, detected_era

# Function to initialize dataset and model
def prepare_model_and_dataset(cifar10_model_path=None):
    global model

    # Step 1: Attempt to load the CIFAR-10 pretrained model
    cifar10_model = None
    encoder = None

    if cifar10_model_path:
        cifar10_model = load_cifar10_pretrained_model(cifar10_model_path)
        if cifar10_model:
            # Extract and adapt the encoder from CIFAR-10 model
            encoder = extract_encoder_from_cifar10(cifar10_model)
        else:
            print("Failed to load CIFAR-10 model from pickle, will attempt to train on CIFAR-10")

    # If we couldn't load or adapt the CIFAR-10 model, train one from scratch
    if encoder is None:
        encoder = CIFAR10Encoder()
        try:
            # Train on CIFAR-10
            encoder = train_on_cifar10(encoder)
        except Exception as e:
            print(f"CIFAR-10 training failed: {str(e)}")
            print("Using untrained encoder")

    # Step 2: Create the full model with the encoder
    model = EnhancedColorizationNet(num_eras=5, pretrained_encoder=encoder).to(device)

    # Data transformations for historical dataset
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor()
    ])

    # Step 3: Try to load the full model weights if they exist
    model_loaded = False
    try:
        model.load_state_dict(torch.load("era_sensitive_colorization_model.pth"))
        print("Loaded pre-trained full model weights")
        model_loaded = True
    except:
        print("Pre-trained full model not found, will attempt to fine-tune with historical dataset")

    # Step 4: If we couldn't load the full model, fine-tune with historical dataset
    if not model_loaded:
        try:
            # Create dataset (adjust the path as needed)
            train_dataset = HistoricalDataset("images", transform=transform)
            train_size = int(0.8 * len(train_dataset))
            test_size = len(train_dataset) - train_size
            train_data, test_data = random_split(train_dataset, [train_size, test_size])

            train_loader = DataLoader(train_data, batch_size=4, shuffle=True)
            test_loader = DataLoader(test_data, batch_size=4, shuffle=False)

            print(f"Historical dataset loaded with {len(train_dataset)} images")
            print(f"Fine-tuning with {train_size} images, testing with {test_size} images")

            # Fine-tune the model
            criterion = nn.MSELoss()
            # Only train decoder parts initially
            optimizer = optim.Adam(
                list(model.era_pathways.parameters()) + list(model.era_classifier.parameters()),
                lr=0.001
            )
            model = fine_tune_model(model, train_loader, optimizer, criterion, num_epochs=40, freeze_encoder=True)

            # Optional: Unfreeze the encoder for a few more epochs with lower learning rate
            for param in model.encoder.parameters():
                param.requires_grad = True

            optimizer = optim.Adam(model.parameters(), lr=0.0001)
            model = fine_tune_model(model, train_loader, optimizer, criterion, num_epochs=40, freeze_encoder=False)

            # Save the trained model
            torch.save(model.state_dict(), "era_sensitive_colorization_model.pth")
            print("Model fine-tuned and saved to era_sensitive_colorization_model.pth")

        except Exception as e:
            print(f"Could not fine-tune model: {str(e)}")
            print("Using untrained model - results may not be optimal")

    return "Model preparation complete"

# Main execution function
def main():
    # Initialize global model variable
    global model

    # Prepare the model and dataset
    cifar10_model_path = "/content/cifar10_model.pth"  # Changed from .pkl to .pth

    # Prepare the model and dataset with both file paths
    status = prepare_model_and_dataset(
        cifar10_model_path=cifar10_model_path
    )
    print(status)
    # Create Gradio interface
    with gr.Blocks(title="Historical Image Colorization") as demo:
        gr.Markdown("# Era-Sensitive Historical Image Colorization")
        gr.Markdown("Upload a black and white historical photo and select an era for colorization.")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload B&W Image", type="numpy")
                era_dropdown = gr.Dropdown(
                    choices=era_labels,
                    value=era_labels[2],  # Default to Mid-Century
                    label="Select Era for Colorization"
                )
                colorize_btn = gr.Button("Colorize Image")

            with gr.Column():
                grayscale_output = gr.Image(label="Grayscale Image")
                colorized_output = gr.Image(label="Colorized Result")
                era_info = gr.Textbox(label="Era Information")

        # Set up the click event
        colorize_btn.click(
            fn=colorize_image,
            inputs=[input_image, era_dropdown],
            outputs=[grayscale_output, colorized_output, era_info]
        )

        gr.Markdown("""
        ## About This Tool

        This tool uses deep learning to colorize black and white historical photos in a way that's sensitive to different historical time periods:

        - **1850-1900**: Victorian-era sepia tones
        - **1900-1940**: Early 20th century film look
        - **1940-1970**: Kodachrome-inspired mid-century colors
        - **1970-2000**: Late 20th century film photography
        - **2000-present**: Modern digital photography colors

        The model will attempt to detect the appropriate era automatically, but you can also select a specific era for stylistic colorization.
        """)

    # Launch the Gradio interface
    print("Starting Gradio server...")
    demo.launch(share=True)

if __name__ == "__main__":
    main()

0
<torch.cuda.device object at 0x7efaf65bac10>
1
Tesla T4
Error loading CIFAR-10 model: name 'cifar10_model' is not defined
Failed to load CIFAR-10 model from pickle, will attempt to train on CIFAR-10
Starting CIFAR-10 training...
[1, 200] loss: 2.154
[1, 400] loss: 1.859
[1, 600] loss: 1.745
[2, 200] loss: 1.616
[2, 400] loss: 1.574
[2, 600] loss: 1.507
[3, 200] loss: 1.440
[3, 400] loss: 1.399
[3, 600] loss: 1.378
[4, 200] loss: 1.339
[4, 400] loss: 1.311
[4, 600] loss: 1.289
[5, 200] loss: 1.246
[5, 400] loss: 1.235
[5, 600] loss: 1.208
[6, 200] loss: 1.206
[6, 400] loss: 1.171
[6, 600] loss: 1.163
[7, 200] loss: 1.138
[7, 400] loss: 1.134
[7, 600] loss: 1.114
[8, 200] loss: 1.097
[8, 400] loss: 1.063
[8, 600] loss: 1.100
[9, 200] loss: 1.046
[9, 400] loss: 1.045
[9, 600] loss: 1.058
[10, 200] loss: 1.027
[10, 400] loss: 1.022
[10, 600] loss: 0.997
[11, 200] loss: 0.983
[11, 400] loss: 0.968
[11, 600] loss: 0.975
[12, 200] loss: 0.939
[12, 400] loss: 0.946
[12, 600] loss: 0.967
[13,