### TODO:

- ~~Create git repo~~
- Neural network architecture to solve the image colourization problem
- Data augmentation
- Cross-validation
- Testing a few optimizers
- Testing various loss functions

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

# Select the compute device: use CUDA when PyTorch can see a GPU, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

class ColorizationDataset(Dataset):
    """Paired (grayscale input, colour target) image dataset for colourization.

    Pairs are formed by position after sorting each directory's file names,
    so the i-th grayscale file is matched with the i-th colour file. Keep the
    file names of the two directories aligned (ideally identical) so the
    pairing is correct.

    Args:
        grayscale_dir: Directory of input images; each is loaded and converted
            to single-channel ``'L'`` mode.
        colorized_dir: Directory of target images; each is loaded and converted
            to ``'RGB'``.
        transform: Optional callable applied separately to the input and to the
            target image. Because it is called once per image, it must be
            deterministic (e.g. resize/ToTensor); a random augmentation would
            transform the two images of a pair differently.

    Raises:
        ValueError: If the two directories contain a different number of
            image files.
    """

    def __init__(self, grayscale_dir, colorized_dir, transform=None):
        self.grayscale_dir = grayscale_dir
        self.colorized_dir = colorized_dir
        # os.listdir returns entries in arbitrary, filesystem-dependent order,
        # so both listings are sorted before they are paired by index.
        self.grayscale_images = self._list_images(grayscale_dir)
        self.colorized_images = self._list_images(colorized_dir)
        if len(self.grayscale_images) != len(self.colorized_images):
            raise ValueError(
                f"{grayscale_dir!r} has {len(self.grayscale_images)} images but "
                f"{colorized_dir!r} has {len(self.colorized_images)}; cannot pair them"
            )
        self.transform = transform

    @staticmethod
    def _list_images(directory):
        """Return the sorted names of the regular, non-hidden files in ``directory``."""
        # Hidden files (e.g. .DS_Store) and subdirectories are not images and
        # would make Image.open fail in __getitem__.
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.grayscale_images)

    def __getitem__(self, idx):
        grayscale_path = os.path.join(self.grayscale_dir, self.grayscale_images[idx])
        colorized_path = os.path.join(self.colorized_dir, self.colorized_images[idx])

        # The context managers close the underlying files; convert() returns a
        # new, fully loaded image, so it stays valid after the file is closed.
        with Image.open(grayscale_path) as img:
            grayscale_image = img.convert('L')
        with Image.open(colorized_path) as img:
            colorized_image = img.convert('RGB')

        if self.transform:
            grayscale_image = self.transform(grayscale_image)
            colorized_image = self.transform(colorized_image)

        return grayscale_image, colorized_image

# Input pipeline settings: where the split directories live, the size every
# image is resized to, and how many samples go into one batch.
DATA_ROOT = 'cv_p3_images_split'
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 32

# ColorizationDataset applies this transform separately to the grayscale input
# and to the colour target, so it must stay deterministic: a random
# augmentation here would draw different parameters for the two images of a pair.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor()
])

train_dataset = ColorizationDataset(os.path.join(DATA_ROOT, 'train', 'grayscale'),
                                    os.path.join(DATA_ROOT, 'train', 'colored'),
                                    transform=transform)
val_dataset = ColorizationDataset(os.path.join(DATA_ROOT, 'validation', 'grayscale'),
                                  os.path.join(DATA_ROOT, 'validation', 'colored'),
                                  transform=transform)

# Page-locked (pinned) host memory speeds up host-to-GPU copies. It only helps
# when training on CUDA, so it is disabled on CPU.
PIN_MEMORY = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY)

class ColorizationNet(nn.Module):
    """Fully convolutional network mapping a 1-channel image to a 3-channel one.

    Every convolution uses a 3x3 kernel with padding 1 and no stride, so the
    spatial size of the input is preserved end to end. The channel widths are
    1 -> 64 -> 128 -> 256 in ``encoder`` and 256 -> 128 -> 64 -> 3 in
    ``decoder``. ReLU follows every convolution except the last, which is
    followed by a sigmoid so outputs lie in (0, 1).
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the same order as before (conv, act, conv, ...)
        # so parameter initialisation and state_dict keys are unchanged.
        encoder_layers = []
        for in_ch, out_ch in ((1, 64), (64, 128), (128, 256)):
            encoder_layers += [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.ReLU()]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = []
        for in_ch, out_ch in ((256, 128), (128, 64)):
            decoder_layers += [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), nn.ReLU()]
        decoder_layers += [nn.Conv2d(64, 3, kernel_size=3, padding=1), nn.Sigmoid()]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it into a 3-channel prediction of the same size."""
        features = self.encoder(x)
        return self.decoder(features)

# Build the network and move its parameters to the selected device.
model = ColorizationNet()
model = model.to(device)

# Pixel-wise mean squared error between predicted and target RGB tensors.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Train for a fixed number of epochs, evaluating on the validation split after
# each one. Both reported losses are dataset-wide averages: each batch's mean
# MSE is multiplied by its batch size, summed, then divided by the dataset
# length, so a smaller final batch is weighted correctly.
# NOTE(review): the network keeps the full 256x256 resolution through all six
# convolutions (no pooling or stride), which makes every step very compute- and
# memory-heavy. The recorded run below shows ~76 s per training batch.
# Downsampling in the encoder is the likely fix; TODO confirm on the target GPU.
# NOTE(review): loop variables (outputs, loss, batch tensors) stay in the
# notebook's global namespace after this cell runs.
num_epochs = 10
for epoch in range(num_epochs):
    # Training phase: put the model in train mode and update weights per batch.
    model.train()
    train_loss = 0.0
    for grayscale_images, colorized_images in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        grayscale_images = grayscale_images.to(device)
        colorized_images = colorized_images.to(device)
        
        # Clear gradients accumulated by the previous step before backprop.
        optimizer.zero_grad()
        outputs = model(grayscale_images)
        loss = criterion(outputs, colorized_images)
        loss.backward()
        optimizer.step()
        
        # Undo the per-batch mean so batches of different sizes are weighted by sample count.
        train_loss += loss.item() * grayscale_images.size(0)
    
    train_loss /= len(train_loader.dataset)
    
    # Validation phase: eval mode, no autograd graph, no weight updates.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for grayscale_images, colorized_images in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            grayscale_images = grayscale_images.to(device)
            colorized_images = colorized_images.to(device)
            
            outputs = model(grayscale_images)
            loss = criterion(outputs, colorized_images)
            
            val_loss += loss.item() * grayscale_images.size(0)
    
    val_loss /= len(val_loader.dataset)
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

# Persist only the learned parameters (not the pickled module object). Reload
# them into a fresh ColorizationNet with load_state_dict().
torch.save(obj=model.state_dict(), f='colorization_model.pth')


Using device: cuda


Epoch 1/10 - Training:   2%|▏         | 3/185 [03:49<3:50:27, 75.98s/it]