In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Device configuration (GPU if available, else CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
num_epochs = 10
batch_size = 32
learning_rate = 0.001

# Custom dataset class for black and white images
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        self.data = datasets.ImageFolder(root='images', transform=transform)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, _ = self.data[idx]
        return img

# Transformations for the dataset (resize, convert to tensor, normalize)
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Resize images to 64x64
    transforms.Grayscale(num_output_channels=3),  # Convert images to RGB (3 channels)
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize pixel values to [-1, 1]
])

# Load the dataset and create data loaders
dataset = CustomDataset(transform=transform)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define the colorization model
class ColorizationModel(nn.Module):
    def __init__(self):
        super(ColorizationModel, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Create the model, optimizer, and loss function
model = ColorizationModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Training loop
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, images in enumerate(tqdm(data_loader, desc=f'Epoch {epoch+1}/{num_epochs}')):
        images = images.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, images)  # Compare colorized output with original grayscale images
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Display loss information
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')
            
# Save the trained model
torch.save(model.state_dict(), 'colorization_model.pth')


Epoch 1/10: 100%|█████████████████████████████████| 1/1 [00:00<00:00,  8.83it/s]
Epoch 2/10: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 36.92it/s]
Epoch 3/10: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 36.73it/s]
Epoch 4/10: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 35.81it/s]
Epoch 5/10: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 36.31it/s]
Epoch 6/10: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 37.17it/s]
Epoch 7/10: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 25.59it/s]
Epoch 8/10: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 26.21it/s]
Epoch 9/10: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 38.16it/s]
Epoch 10/10: 100%|████████████████████████████████| 1/1 [00:00<00:00, 36.93it/s]


In [4]:
import torch
import torchvision.transforms as transforms
from PIL import Image

# Load the saved model
model = ColorizationModel()
model.load_state_dict(torch.load('colorization_model.pth'))
model.eval()  # Set the model to evaluation mode

# Define transformations for input images
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Resize images to 64x64
    transforms.Grayscale(num_output_channels=3),  # Convert images to RGB (3 channels)
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize pixel values to [-1, 1]
])

# Load and preprocess the black and white image
bw_image_path = 'images/class1/butterfly.png'
bw_image = Image.open(bw_image_path).convert('L')  # Convert to grayscale
input_image = transform(bw_image).unsqueeze(0)  # Add batch dimension

# Colorize the image using the model
with torch.no_grad():
    output_image = model(input_image)

# Post-process the colorized image (if needed)
output_image = output_image.squeeze(0)  # Remove batch dimension
output_image = output_image.permute(1, 2, 0)  # Rearrange dimensions for PIL compatibility
output_image = (output_image * 0.5 + 0.5) * 255  # Denormalize pixel values
output_image = output_image.byte().cpu().numpy()  # Convert tensor to NumPy array

# Save or display the colorized image
colorized_image = Image.fromarray(output_image)
colorized_image.save('colorized_image.jpg')
colorized_image.show()
