<a href="https://colab.research.google.com/github/CanKeles5/ColorizeFacesAutoencoder/blob/master/ColorizeFacesAutoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as DS
import torch.optim as optim
import torch.utils.data.sampler
import matplotlib.pyplot as plt

from torch.utils.data import *
from torch.autograd import Variable
from torch.utils.data import DataLoader as DL

In [0]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

In [0]:
# Preprocessing applied to every loaded image: resize to a fixed 256x256 so
# images stack into batches (256 also divides evenly by the network's 16x
# total down-sampling), then convert the PIL image to a tensor.
tfms = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)

In [0]:
# Image dataset rooted at the aligned-CelebA directory; every image goes
# through `tfms`. ImageFolder yields (image, class_index) pairs.
train_set = DS.ImageFolder(
    root='../input/celeba-dataset/img_align_celeba',
    transform=tfms,
)
# Shuffled loader over the whole dataset, 4 images per batch.
# NOTE(review): the next cell rebinds `train_loader` to a subset loader.
train_loader = DL(dataset=train_set, batch_size=4, shuffle=True)

print(len(train_set))

In [0]:
# Train on (at most) the first 500 images so each epoch stays short.
# SubsetRandomSampler visits these indices in a fresh random order every epoch.
# Clamping to the dataset size avoids an IndexError mid-epoch on small datasets.
n_subset = min(500, len(train_set))
subset_indices = range(n_subset)
train_loader = DL(
    train_set,
    batch_size=4,
    sampler=torch.utils.data.SubsetRandomSampler(subset_indices),
)

print(len(subset_indices))

In [0]:
def double_conv(in_channels, out_channels):
    """Build two stacked 3x3 convolutions, each followed by an in-place ReLU.

    The first convolution maps ``in_channels`` to ``out_channels`` and the
    second keeps ``out_channels``. ``padding=1`` with a 3x3 kernel preserves
    the spatial size (H, W).

    Returns:
        nn.Sequential laid out as [Conv2d, ReLU, Conv2d, ReLU].
    """
    layers = []
    for source_channels in (in_channels, out_channels):
        layers.append(nn.Conv2d(source_channels, out_channels, 3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

class AutoEncoder(nn.Module):
    """U-Net style encoder/decoder mapping a 1-channel image to 3 channels.

    Encoder: four ``double_conv`` stages, each followed by 2x2 max pooling,
    then a ``double_conv`` bottleneck. Decoder: four stride-2 transposed
    convolutions; after each, the matching encoder activation (taken before
    pooling) is concatenated on the channel axis and fused by an
    ``upconv_*`` stage.

    Input:  (N, 1, H, W). H and W must be divisible by 16 so the four 2x
            poolings line up with the four 2x upsamplings at ``torch.cat``.
    Output: (N, 3, H, W), non-negative because ``double_conv`` ends in ReLU.
    """

    def __init__(self):
        super().__init__()

        # Encoder; channel width doubles per stage: 1 -> 16 -> 32 -> 64 -> 128 -> 256.
        # NOTE: attribute names and registration order are part of the
        # state_dict / seeded-initialisation contract; keep them stable.
        self.dconv_1 = double_conv(in_channels=1, out_channels=16)
        self.dconv_2 = double_conv(in_channels=16, out_channels=32)
        self.dconv_3 = double_conv(in_channels=32, out_channels=64)
        self.dconv_4 = double_conv(in_channels=64, out_channels=128)
        self.dconv_5 = double_conv(in_channels=128, out_channels=256)  # bottleneck

        # One parameter-free pooling module shared by every encoder stage.
        self.maxpool = nn.MaxPool2d(kernel_size=2)

        # Decoder fusion stages; input width = upsampled channels + skip channels.
        self.upconv_4 = double_conv(in_channels=256, out_channels=128)  # 128 + 128
        self.upconv_3 = double_conv(in_channels=128, out_channels=64)   # 64 + 64
        self.upconv_2 = double_conv(in_channels=64, out_channels=32)    # 32 + 32
        self.upconv_1 = double_conv(in_channels=20, out_channels=3)     # 4 + 16 -> 3

        # Kernel 2 / stride 2 / no padding exactly doubles H and W.
        self.TConv4 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0)
        self.TConv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0)
        self.TConv2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2, padding=0)
        self.TConv1 = nn.ConvTranspose2d(32, 4, kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        # Encoder: remember each stage's pre-pool activation for the skips.
        skips = []
        for encode in (self.dconv_1, self.dconv_2, self.dconv_3, self.dconv_4):
            features = encode(x)
            skips.append(features)
            x = self.maxpool(features)

        x = self.dconv_5(x)

        # Decoder: deepest skip first; upsample, concatenate on channels, fuse.
        decoder = (
            (self.TConv4, self.upconv_4),
            (self.TConv3, self.upconv_3),
            (self.TConv2, self.upconv_2),
            (self.TConv1, self.upconv_1),
        )
        for (upsample, fuse), skip in zip(decoder, reversed(skips)):
            x = fuse(torch.cat([upsample(x), skip], dim=1))

        return x

In [0]:
model = AutoEncoder()

# Move the weights to the device chosen in the setup cell. A hard-coded
# `.cuda()` crashes on CPU-only machines even though `device` falls back to
# 'cpu'. Build the optimizer afterwards so it holds the moved parameters.
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()  # default reduction='mean': one averaged value per batch

In [0]:
# Total number of scalar parameters in the model (shown as the cell's output).
sum(p.numel() for p in model.parameters())

In [0]:
model.train()

n_epochs = 100

for epoch in range(n_epochs):
    running_loss = 0.0
    for X, _ in train_loader:  # ImageFolder class labels are unused
        # The colour image itself is the reconstruction target.
        y = X.to(device)

        # RGB -> grayscale with the Rec. 601 luma weights; (N, 3, H, W) -> (N, 1, H, W).
        X = (0.2989 * X[:, 0] + 0.5870 * X[:, 1] + 0.1140 * X[:, 2]).unsqueeze(1)

        # NOTE(review): the input is multiplied by 255 while the target `y` is
        # left in the loader's range (presumably [0, 1] from ToTensor). The
        # inference cell applies the same x255 scaling, so change both together.
        X = X.to(device) * 255

        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, y)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # `criterion` returns a per-batch mean, so average over the number of
    # batches. The old hard-coded `/500` divided by the sample count instead,
    # which under-reported the loss by a factor of batch_size.
    n_batches = max(len(train_loader), 1)
    print("loss for epoch " + str(epoch) + ": " + str(running_loss / n_batches))

In [0]:
model.eval()

# Colourise the first (up to) five training images and save the predictions
# as 0.jpg, 1.jpg, ... in the working directory.
n_samples = min(5, len(train_set))
to_pil = transforms.ToPILImage()

with torch.no_grad():  # inference only: no autograd graph is needed
    for i in range(n_samples):
        im, _ = train_set[i]  # colour tensor; channels 0-2 are read as R, G, B

        # Preprocessing must match training: luma grayscale, then x255 scaling.
        gray = 0.2989 * im[0] + 0.5870 * im[1] + 0.1140 * im[2]   # (H, W)
        gray = gray.unsqueeze(0).unsqueeze(0).to(device) * 255    # (1, 1, H, W)

        # Drop the batch axis and clamp into the [0, 1] range ToPILImage expects for floats.
        output = model(gray).squeeze(0).clamp(0.0, 1.0)

        to_pil(output.cpu()).save(str(i) + ".jpg")