In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import os

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed torch's global RNG (CPU and all CUDA devices) so that weight
# initialization and DataLoader shuffling repeat on Restart & Run All.
# NOTE: some cuDNN kernels are nondeterministic, so GPU runs may still differ slightly.
SEED = 42
torch.manual_seed(SEED)

Using device: cuda


In [3]:
# Create the sample-image output folder and the dataset download root (no-op if present).
for required_dir in ('Gray-to-RGB-Outputs', 'DataSet/CIFAR10'):
    os.makedirs(required_dir, exist_ok=True)

In [4]:
# Upsample images to the 64x64 resolution both networks are built for, convert
# to a tensor in [0, 1], then rescale each channel to [-1, 1] so real images
# share the range of the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [5]:
# CIFAR-10 training split, downloaded on first run and preprocessed by `transform`.
data = datasets.CIFAR10(
    root='./DataSet/CIFAR10',
    train=True,
    transform=transform,
    download=True,
)
# Shuffled mini-batches of 64; the final partial batch is kept (drop_last defaults to False).
data_loader = DataLoader(dataset=data, batch_size=64, shuffle=True)

In [6]:
# Sanity check: pull a single batch and show its shape.
img, _ = next(iter(data_loader))
print(img.shape)

torch.Size([64, 3, 64, 64])


In [7]:
def rgb2gray(x):
    """Convert a batch of RGB image tensors to single-channel grayscale.

    Delegates to torchvision's ``rgb_to_grayscale``, which forms the weighted
    sum 0.2989*R + 0.587*G + 0.114*B over the channel axis.

    Args:
        x: image tensor with 3 channels on dim -3, e.g. (batch, 3, H, W).
           In this notebook the values are normalized to [-1, 1]. Because the
           weights sum to ~1, the output stays in roughly the same range.

    Returns:
        Tensor with the same leading/spatial dims and a single channel on dim -3.
    """
    gray = transforms.functional.rgb_to_grayscale(x, num_output_channels=1)
    return gray

In [None]:
class Generator(nn.Module):
    """Encoder-decoder that maps a 1x64x64 grayscale image to a 3x64x64 color image.

    The encoder halves the spatial size four times (64 -> 4) while widening the
    channels to 512. The decoder mirrors it back up to 64x64 with 3 output
    channels squashed to [-1, 1] by Tanh. There are no skip connections, so all
    information reaches the decoder through the 512x4x4 bottleneck.
    """

    def __init__(self):
        super().__init__()

        # Encoder: 1x64x64 -> 64x32x32 (no BatchNorm on the first stage)
        down = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 64x32x32 -> 128x16x16 -> 256x8x8 -> 512x4x4
        for ch_in, ch_out in ((64, 128), (128, 256), (256, 512)):
            down += [
                nn.Conv2d(ch_in, ch_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(ch_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.encoder = nn.Sequential(*down)

        # Decoder: 512x4x4 -> 256x8x8 -> 128x16x16 -> 64x32x32
        up = []
        for ch_in, ch_out in ((512, 256), (256, 128), (128, 64)):
            up += [
                nn.ConvTranspose2d(ch_in, ch_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        # 64x32x32 -> 3x64x64, squashed to [-1, 1] to match the normalized targets
        up += [
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Colorize ``x`` (batch, 1, 64, 64) into a (batch, 3, 64, 64) tensor in [-1, 1]."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [9]:
class Discriminator(nn.Module):
    """Conditional discriminator that scores (grayscale, color) image pairs.

    The 1-channel grayscale condition and a 3-channel real or generated color
    image are stacked into a 4x64x64 input. Strided convolutions reduce it to
    one sigmoid probability per sample, suitable for ``nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()

        # 4x64x64 -> 64x32x32 (no BatchNorm on the first stage)
        stages = [
            nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 64x32x32 -> 128x16x16 -> 256x8x8 -> 512x4x4
        for n_in, n_out in ((64, 128), (128, 256), (256, 512)):
            stages.extend([
                nn.Conv2d(n_in, n_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(n_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # 512x4x4 -> 1x1x1 (valid 4x4 conv), mapped to a probability
        stages.extend([
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ])
        self.model = nn.Sequential(*stages)

    def forward(self, gray, color):
        """Score each (gray, color) pair; returns (batch, 1) probabilities for 64x64 inputs."""
        pair = torch.cat((gray, color), dim=1)  # stack along channels -> [batch, 4, 64, 64]
        scores = self.model(pair)
        # Flatten the 1x1x1 map per sample so the shape matches the (batch, 1) labels.
        return scores.view(scores.size(0), -1)


In [10]:
# Instantiate both networks on the selected device (G first, then D).
G = Generator().to(device)
D = Discriminator().to(device)

# Adversarial loss on D's sigmoid outputs, plus a pixel-wise L1 reconstruction loss for G.
criterion_GAN = nn.BCELoss()
criterion_L1 = nn.L1Loss()

# G and D use identical Adam hyperparameters.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = optim.Adam(G.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(D.parameters(), **adam_kwargs)

In [12]:
EPOCHS = 10
lambda_L1 = 100  # Weight of the pixel-wise L1 term relative to the adversarial term
N_SAMPLES = 25   # Max images per saved sample grid (5 x 5 with nrow=5)

print("Starting training...")

for epoch in range(EPOCHS):
    for i, (imgs, _) in enumerate(data_loader):
        batch_size = imgs.size(0)

        # Real color images and their grayscale conditioning input. rgb2gray
        # keeps the input's device, so no extra .to(device) is needed.
        real_imgs = imgs.to(device)
        gray_imgs = rgb2gray(real_imgs)

        # BCELoss targets, shaped like D's (batch, 1) output, allocated directly on device.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ---- Train Discriminator ----
        optimizer_D.zero_grad()

        # Generate fake colored images
        fake_imgs = G(gray_imgs)

        # Real pairs should score 1
        real_output = D(gray_imgs, real_imgs)
        real_loss = criterion_GAN(real_output, real_labels)

        # Fake pairs should score 0; detach so this update does not backprop into G
        fake_output = D(gray_imgs, fake_imgs.detach())
        fake_loss = criterion_GAN(fake_output, fake_labels)

        # Average of the real/fake terms (halves D's gradient vs. summing them)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        optimizer_G.zero_grad()

        # Adversarial term: G wants the freshly updated D to label its output as real
        fake_output = D(gray_imgs, fake_imgs)
        gan_loss = criterion_GAN(fake_output, real_labels)

        # Reconstruction term: pixel-wise similarity to the ground-truth colors
        l1_loss = criterion_L1(fake_imgs, real_imgs)

        # Total generator loss
        g_loss = gan_loss + lambda_L1 * l1_loss

        g_loss.backward()
        optimizer_G.step()

    # Report losses of the epoch's last batch; `i` is 0-based, so show i + 1.
    print(f"Epoch [{epoch+1}/{EPOCHS}], Step [{i+1}/{len(data_loader)}], "
          f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}, "
          f"GAN_loss: {gan_loss.item():.4f}, L1_loss: {l1_loss.item():.4f}")

    # Save sample images. In train mode, BatchNorm updates its running
    # statistics even under torch.no_grad(), so switch G to eval mode for this
    # pass and restore train mode afterwards.
    G.eval()
    with torch.no_grad():
        # Up to N_SAMPLES images from the last batch. The DataLoader keeps the
        # partial final batch, so it may contain fewer than N_SAMPLES images.
        sample_gray = gray_imgs[:N_SAMPLES]
        sample_fake = G(sample_gray)
        sample_real = real_imgs[:N_SAMPLES]

        # Denormalize for saving: [-1, 1] -> [0, 1]
        sample_fake = (sample_fake + 1) / 2
        sample_real = (sample_real + 1) / 2
        sample_gray = (sample_gray + 1) / 2

        # Save images
        save_image(sample_gray, f"Gray-to-RGB-Outputs/epoch_{epoch+1}_gray.png", nrow=5, normalize=False)
        save_image(sample_fake, f"Gray-to-RGB-Outputs/epoch_{epoch+1}_fake.png", nrow=5, normalize=False)
        save_image(sample_real, f"Gray-to-RGB-Outputs/epoch_{epoch+1}_real.png", nrow=5, normalize=False)
    G.train()

    print(f"Epoch {epoch+1}/{EPOCHS} completed!")

print("Training completed!")


Starting training...
Epoch [1/10], Step [781/782], D_loss: 0.2090, G_loss: 18.0711, GAN_loss: 2.3404, L1_loss: 0.1573
Epoch 1/10 completed!
Epoch [2/10], Step [781/782], D_loss: 0.0002, G_loss: 28.0465, GAN_loss: 10.5507, L1_loss: 0.1750
Epoch 2/10 completed!
Epoch [3/10], Step [781/782], D_loss: 0.0009, G_loss: 24.2854, GAN_loss: 7.0287, L1_loss: 0.1726
Epoch 3/10 completed!
Epoch [4/10], Step [781/782], D_loss: 0.3760, G_loss: 15.3767, GAN_loss: 3.0372, L1_loss: 0.1234
Epoch 4/10 completed!
Epoch [5/10], Step [781/782], D_loss: 0.0001, G_loss: 28.2118, GAN_loss: 10.9626, L1_loss: 0.1725
Epoch 5/10 completed!
Epoch [6/10], Step [781/782], D_loss: 0.0000, G_loss: 30.8276, GAN_loss: 11.6354, L1_loss: 0.1919
Epoch 6/10 completed!
Epoch [7/10], Step [781/782], D_loss: 0.0002, G_loss: 30.4374, GAN_loss: 13.4866, L1_loss: 0.1695
Epoch 7/10 completed!
Epoch [8/10], Step [781/782], D_loss: 0.0000, G_loss: 29.3135, GAN_loss: 11.6819, L1_loss: 0.1763
Epoch 8/10 completed!
Epoch [9/10], Step [78

In [None]:
# Save trained model weights (uncomment to persist both networks)
# torch.save(G.state_dict(), 'generator.pth')
# torch.save(D.state_dict(), 'discriminator.pth')