In [1]:
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from matplotlib import pyplot as plt
from PIL import Image
import random

In [2]:
data_path = "data/raw/cctv_footage/surveillance_cameras_all"
low_res_img_paths = os.listdir(data_path)

idx = 3
img = Image.open(data_path + f"/{low_res_img_paths[idx]}")

transformation = transforms.Compose([
    transforms.Resize((286, 286)),
    transforms.RandomRotation((-15, 15)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomPerspective(0.5),
    transforms.RandomCrop((256, 256)),
    transforms.ToTensor(),
])

transformation(img).shape

torch.Size([3, 256, 256])

In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 32
EPOCHS = 3
LR = 0.0003
CHANNELS = 3
IMAGE_SIZE = 256

In [4]:
lowResImagesPath = "data/raw/cctv_footage/surveillance_cameras_all"
lowResImages = os.listdir(lowResImagesPath) # 2860 

highResImagesPath = "data/raw/high_quality_images/mugshot_frontal_original_all"
highResImages = os.listdir(highResImagesPath)  # 130

transformation = transforms.Compose([
    transforms.Resize((286, 286)),
    transforms.RandomRotation((-15, 15)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomPerspective(0.5),
    transforms.RandomCrop((256, 256)),
    transforms.ToTensor(),
])

target_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
])

class CreateDataset(Dataset):
    def __init__(self, lowResImagesPath, highResImagesPath, transform=None, target_transform=None):
        self.lowResImagesPath = lowResImagesPath
        self.highResImagesPath = highResImagesPath
        self.lowResImages = os.listdir(lowResImagesPath)
        self.highResImages = os.listdir(highResImagesPath)
        
        self.transform = transform
        self.target_transform = target_transform
        
    def __len__(self):
        return len(self.lowResImages)
    
    def __getitem__(self, idx):
        _low_image_name = self.lowResImages[idx]
        _low_res_image_path = os.path.join(self.lowResImagesPath, _low_image_name)
        _subject = _low_image_name.split("_")[0]
        _high_image_name = _subject + "_frontal.jpg"
        _high_image_path = os.path.join(self.highResImagesPath, _high_image_name)
        
        _low_res_img = Image.open(_low_res_image_path)
        _high_res_img = Image.open(_high_image_path)
        if self.transform:
            _low_res_img = self.transform(_low_res_img)
        if self.target_transform:
            _high_res_img = self.target_transform(_high_res_img)
        
        return _low_res_img, _high_res_img

In [5]:
train_dataset = CreateDataset(lowResImagesPath="data/processed/train/low_res",
                              highResImagesPath="data/processed/train/high_res",
                              transform=transformation,
                              target_transform=target_transform)

test_dataset = CreateDataset(lowResImagesPath="data/processed/test/low_res",
                             highResImagesPath="data/processed/test/high_res",
                             transform=transformation,
                             target_transform=target_transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
next(iter(train_loader))[0].shape

torch.Size([32, 3, 256, 256])

In [7]:
def cnn_block(in_channels,out_channels,kernel_size,stride=1,padding=0, first_layer = False):

   if first_layer:
       return nn.Conv2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding)
   else:
       return nn.Sequential(
           nn.Conv2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
           nn.BatchNorm2d(out_channels,momentum=0.1,eps=1e-5),
           )

def tcnn_block(in_channels,out_channels,kernel_size,stride=1,padding=0,output_padding=0, first_layer = False):
   if first_layer:
       return nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,output_padding=output_padding)

   else:
       return nn.Sequential(
           nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,output_padding=output_padding),
           nn.BatchNorm2d(out_channels,momentum=0.1,eps=1e-5),
           )

In [8]:
class Generator(nn.Module):
 def __init__(self,instance_norm=False):#input : 256x256
   super(Generator,self).__init__()
   self.e1 = cnn_block(c_dim,gf_dim,4,2,1, first_layer = True)
   self.e2 = cnn_block(gf_dim,gf_dim*2,4,2,1,)
   self.e3 = cnn_block(gf_dim*2,gf_dim*4,4,2,1,)
   self.e4 = cnn_block(gf_dim*4,gf_dim*8,4,2,1,)
   self.e5 = cnn_block(gf_dim*8,gf_dim*8,4,2,1,)
   self.e6 = cnn_block(gf_dim*8,gf_dim*8,4,2,1,)
   self.e7 = cnn_block(gf_dim*8,gf_dim*8,4,2,1,)
   self.e8 = cnn_block(gf_dim*8,gf_dim*8,4,2,1, first_layer=True)

   self.d1 = tcnn_block(gf_dim*8,gf_dim*8,4,2,1)
   self.d2 = tcnn_block(gf_dim*8*2,gf_dim*8,4,2,1)
   self.d3 = tcnn_block(gf_dim*8*2,gf_dim*8,4,2,1)
   self.d4 = tcnn_block(gf_dim*8*2,gf_dim*8,4,2,1)
   self.d5 = tcnn_block(gf_dim*8*2,gf_dim*4,4,2,1)
   self.d6 = tcnn_block(gf_dim*4*2,gf_dim*2,4,2,1)
   self.d7 = tcnn_block(gf_dim*2*2,gf_dim*1,4,2,1)
   self.d8 = tcnn_block(gf_dim*1*2,c_dim,4,2,1, first_layer = True)#256x256
   self.tanh = nn.Tanh()

 def forward(self,x):
   e1 = self.e1(x)
   e2 = self.e2(nn.LeakyReLU(0.2)(e1))
   e3 = self.e3(nn.LeakyReLU(0.2)(e2))
   e4 = self.e4(nn.LeakyReLU(0.2)(e3))
   e5 = self.e5(nn.LeakyReLU(0.2)(e4))
   e6 = self.e6(nn.LeakyReLU(0.2)(e5))
   e7 = self.e7(nn.LeakyReLU(0.2)(e6))
   e8 = self.e8(nn.LeakyReLU(0.2)(e7))
   d1 = torch.cat([nn.Dropout(0.5)(self.d1(nn.ReLU()(e8))),e7],1)
   d2 = torch.cat([nn.Dropout(0.5)(self.d2(nn.ReLU()(d1))),e6],1)
   d3 = torch.cat([nn.Dropout(0.5)(self.d3(nn.ReLU()(d2))),e5],1)
   d4 = torch.cat([self.d4(nn.ReLU()(d3)),e4],1)
   d5 = torch.cat([self.d5(nn.ReLU()(d4)),e3],1)
   d6 = torch.cat([self.d6(nn.ReLU()(d5)),e2],1)
   d7 = torch.cat([self.d7(nn.ReLU()(d6)),e1],1)
   d8 = self.d8(nn.ReLU()(d7))

   return self.tanh(d8)

In [9]:
class Discriminator(nn.Module):
 def __init__(self,instance_norm=False):#input : 256x256
   super(Discriminator,self).__init__()
   self.conv1 = cnn_block(c_dim*2,df_dim,4,2,1, first_layer=True) # 128x128
   self.conv2 = cnn_block(df_dim,df_dim*2,4,2,1)# 64x64
   self.conv3 = cnn_block(df_dim*2,df_dim*4,4,2,1)# 32 x 32
   self.conv4 = cnn_block(df_dim*4,df_dim*8,4,1,1)# 31 x 31
   self.conv5 = cnn_block(df_dim*8,1,4,1,1, first_layer=True)# 30 x 30

   self.sigmoid = nn.Sigmoid()
 def forward(self, x, y):
   O = torch.cat([x,y],dim=1)
   O = nn.LeakyReLU(0.2)(self.conv1(O))
   O = nn.LeakyReLU(0.2)(self.conv2(O))
   O = nn.LeakyReLU(0.2)(self.conv3(O))
   O = nn.LeakyReLU(0.2)(self.conv4(O))
   O = self.conv5(O)

   return self.sigmoid(O)

In [10]:
# Define parameters
batch_size = 4
workers = 2

epochs = 30

gf_dim = 64
df_dim = 64

L1_lambda = 100.0

in_w = in_h = 256
c_dim = 3

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
low, high = next(iter(train_loader))
low.shape, high.shape

(torch.Size([32, 3, 256, 256]), torch.Size([32, 3, 256, 256]))

In [12]:
g_model = Generator()
op = g_model(low)


In [13]:
op_n = op[0].squeeze().detach().permute(1, 2, 0).type(torch.float).numpy() * 256

op_n.shape

(256, 256, 3)

In [14]:
plt.imshow(op_n)
plt.show()

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


: 

In [None]:
model_g = Generator().to(DEVICE)
model_d = Discriminator().to(DEVICE)

optimizer_g = torch.optim.Adam(model_g.parameters(), lr=2e-4,betas=(0.5,0.999))
optimizer_d = torch.optim.Adam(model_d.parameters(), lr=2e-5,betas=(0.5,0.999))

bce_criterion = nn.BCELoss()
L1_criterion = nn.L1Loss()

In [None]:
for epoch in range(EPOCHS):
    for id, (low_res, high_res) in enumerate(train_loader):
        real_images = high_res.to(DEVICE)

        # train the discriminator
        b_size = low_res.size(0)

        real_labels = torch.ones((b_size, 1, 30, 30)).to(DEVICE)
        fake_labels = torch.zeros((b_size, 1, 30, 30)).to(DEVICE)

        fake_images = model_g(low_res.to(DEVICE))
        real_patch = model_d(low_res.to(DEVICE), real_images)

        fake_patch = model_d(low_res.to(DEVICE), fake_images.detach())

        d_loss_real = bce_criterion(real_patch, real_labels)
        d_loss_fake = bce_criterion(fake_patch, fake_labels)
        d_loss = d_loss_real + d_loss_fake

        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # train the generator
        fake_images = model_g(low_res.to(DEVICE))  # Generate new fake images
        fake_patch = model_d(low_res.to(DEVICE), fake_images)
        
        fake_gan_loss = bce_criterion(fake_patch, real_labels)
        L1_loss = L1_criterion(fake_images, real_images)
        g_loss = fake_gan_loss + L1_lambda * L1_loss

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        print(f"Epoch [{epoch+1}/{EPOCHS}], Batch [{id+1}/{len(train_loader)}], G_loss: {g_loss.item():.4f}, D_loss: {d_loss.item():.4f}")