In [None]:
#Things to do:
#Locally:
# - create new notebook for end-to-end predictions
# - use best of n predictions to generate final
# - look into applying poisson blending
# - Use argparse to take inputs (input filepath, output filepath, best of n predictions, etc)
# - run code on updated models and put it into a formatted python file
# - write documentation and put model diagrams on github



#On server:
# - model neds to be trained on cropped celebA images (use 128 img size then crop to 64)
# - Fine tune model on faces from the dataset

In [77]:
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import os
import time
import logging
from PIL import Image, ImageDraw
from facenet_pytorch import MTCNN
import imageio

logging.root.setLevel(logging.NOTSET)

parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--repo-dir', default="c:/Users/James/git/de-identification", help='Path to github repository')
parser.add_argument('--image-path', default="c:/Users/James/git/de-identification/downloaded-data/num-faces/train/image_data/16070.jpg", help='Path to github repository')
parser.add_argument('--save-folder', default="c:/Users/James/git/de-identification/dev-notebooks/16070", help='Folder to save results in')
parser.add_argument('--border-factor', default=0.2, help='Width of border used for context in infilling GAN generation')
parser.add_argument('--progress-images', default=True, help='Save progress images & gif in folder')
parser.add_argument('--model-version', default='cropped-10-epochs', help='Version of model to use')
parser.add_argument('--lr', default=0.002, help='learning rate to use in training')
parser.add_argument('--lam', default=0.1, help='perceptual loss factor')
parser.add_argument('--iterations', default=500, help='Number of iterations to train for')
parser.add_argument('--eval-interval', default=25, help='Number of iterations between evaluation')
parser.add_argument('--best-of-n', default=5, help='Number of best predictions to use for final prediction')

args = parser.parse_args("")
repo_dir = args.repo_dir
image_path = args.image_path
save_folder = args.save_folder
border_factor = args.border_factor
progress_images = args.progress_images
model_version = args.model_version
lr = args.lr
lam = args.lam
iterations = args.iterations
eval_interval = args.eval_interval
best_of_n = args.best_of_n

if not os.path.exists(save_folder):
    os.makedirs(save_folder)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    logging.warning(" No GPU detected, using CPU instead.")

mtcnn = MTCNN(keep_all=True, device=device) #face detection model

def generate_boxes(img, threshold=0.7):
    all_boxes, probs, landmarks = mtcnn.detect(img.copy(), landmarks=True)
    if all_boxes is None: return []
    all_boxes = [[int(x) for x in box] for box in all_boxes] 
    #gives box values outside of image, e.g. [-6, 135, 69, 229],
    all_boxes = [[max(0, box[0]), max(0, box[1]), box[2], box[3]] for box in all_boxes]
    boxes = []
    for box, prob in zip(all_boxes, probs):
        # width = box[2] - box[0]
        # height = box[3] - box[1]
        #minimum requirement??
        if prob >= threshold:
            boxes.append(box)
        else:
            print(box, prob)
    return boxes

def draw_boxes(img, boxes, masks=None):
    """Draws boxes (& masks if wanted) on image)"""
    frame_draw = img.copy()
    draw = ImageDraw.Draw(frame_draw)
    for box in boxes:
        colour = (255, 0, 0) if len(box) == 4 else box[4]
        draw.rectangle(box[:4], outline=colour, width=3) # box = (x1, y1, x2, y2)

    if masks is not None:
        for mask in masks:
            draw.rectangle(mask[:4], fill=(255,255,255))
    return frame_draw

def crop_face(face, x, y):
    """Turns image into a square by cropping"""
    height, width, _ = face.shape
    if height > width:
        diff = height - width
        top_crop = diff // 2
        bottom_crop = diff - top_crop
        face = face[top_crop:-bottom_crop, :]
        y+=top_crop
    elif width > height:
        diff = width - height
        left_crop = diff // 2
        right_crop = diff - left_crop
        face = face[:, left_crop:-right_crop]
        x+=left_crop

    assert face.shape[0] == face.shape[1], "Face is not square"
    return face, x, y

img = Image.open(image_path)
boxes = generate_boxes(img)
np_img = np.array(img)

faces = [] # [[square_face, x, y], ...]

squares = []
masks = []
for box in boxes:
    x1, y1, x2, y2 = box
    face = np_img[y1:y2, x1:x2]
    square_face, x, y = crop_face(face, x1, y1)
    faces.append([square_face, x, y])
    #4 lines below are only for masking visualisation
    square_size = square_face.shape[0]
    squares.append([x, y, x+square_size, y+square_size, (0, 255, 0)])
    square_border = int(square_size * border_factor)
    masks.append([x+square_border, y+square_border, x+square_size-square_border, y+square_size-square_border])

if progress_images:
    img.save(os.path.join(save_folder, "1-original.jpg"))
    draw_boxes(img, boxes).save(os.path.join(save_folder, "2-boxes.jpg"))
    draw_boxes(img, boxes + squares).save(os.path.join(save_folder, "3-boxes_squares.jpg"))
    draw_boxes(img, boxes + squares, masks).save(os.path.join(save_folder, "4-boxes_squares_masks.jpg"))
    logging.info(f"Saved progress images in {save_folder}")

######### MODELS #########

### Hyperparameters
workers = 2
image_size = 64 # use 128 but only generates central 64x64
border = int(image_size * border_factor)
ngpu = 1 # Number of GPUs available. 
nc = 3 # Number of channels in the training images. 
nz = 100 # Size of z latent vector (i.e. size of generator input)
ngf = 64 # Size of feature maps in generator
ndf = 64 # Size of feature maps in discriminator

# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
    
#load models
g_path = os.path.join(repo_dir, f"pretrained-models/DCGAN-{model_version}-netG.pth")
d_path = os.path.join(repo_dir, f"pretrained-models/DCGAN-{model_version}-netD.pth")

netG_loaded = Generator(ngpu).to(device)
netG_loaded.load_state_dict(torch.load(g_path, map_location=device))
netG_loaded.eval()

netD_loaded = Discriminator(ngpu).to(device)
netD_loaded.load_state_dict(torch.load(d_path, map_location=device))
netD_loaded.eval()

logging.info(" Model loaded")

criterion = nn.BCELoss()

img_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

def display_img(transformed_img):
    img_grid = vutils.make_grid(transformed_img, padding=2, normalize=True).permute(1, 2, 0)
    return Image.fromarray((img_grid.numpy() * 255).astype(np.uint8))

transformed_images = []
for face in faces:
    transformed_images.append(img_transforms(Image.fromarray(face[0])))

images = torch.stack(transformed_images, dim = 0)
# display_img(transformed_images)

zhats = torch.randn(images.shape[0], nz, 1, 1, device=device).requires_grad_()
mask = torch.ones((images.shape[0], 3, image_size, image_size)).to(device)
mask[:, :, border:-border, border:-border] = 0
results = [netG_loaded(zhats).clone()]


logging.info(" Model training")
optimizer = optim.Adam([zhats], lr=lr)
t_start = time.time()
for i in range(iterations):
    generated = netG_loaded(zhats)
    contextual_loss = nn.functional.l1_loss(mask*generated, mask*images) # keep outside obscured region the same

    real_label = torch.full((images.shape[0],), 1., dtype=torch.float, device=device)
    output = netD_loaded(generated.detach()).view(-1)
    perceptual_loss = criterion(output, real_label) #g_loss

    complete_loss = contextual_loss + lam*perceptual_loss
    optimizer.zero_grad()
    complete_loss.backward()
    optimizer.step()

    if i % eval_interval == eval_interval-1:
        logging.info(f" i: [{i}/{iterations}] Losses:: Complete:{complete_loss:.4f}, contextual:{contextual_loss:.4f}, perceptual:{lam*perceptual_loss:.4f} (after x0.1), time: {time.time()-t_start:.2f}s")
        results.append(generated.clone())

results = torch.stack(results, dim = 0)
if progress_images:
    training_images = images.expand([results.shape[0]] + list(images.shape))*mask + results*(1-mask)
    for face_idx in range(training_images.shape[1]):
        training_img_filepath = os.path.join(save_folder, f"5-training-progress-img-{face_idx}.jpg")
        display_img(training_images[:, face_idx]).save(training_img_filepath)
        logging.info(f" Training progress images saved: {training_img_filepath}")


def overlay_generations(generated_faces, img):
    np_img = np.array(img)
    for i, (face, x, y), in enumerate(faces):
        target_size = face.shape[0]
        generated_face = generated_faces[i]
        resize_transform = transforms.Resize(target_size, antialias = False)
        generated_face = resize_transform(generated_face)
        generated_face = display_img(generated_face)

        square_border = int(target_size * border_factor)
        cropped_img = np.array(generated_face)[square_border:-square_border, square_border:-square_border]
        np_img[y+square_border:y+target_size-square_border, x+square_border:x+target_size-square_border] = cropped_img
    
    return Image.fromarray(np_img)

#Generating gif
final_images =  []
final_images_annotated = []
for generated_faces in results:
    final_image = overlay_generations(generated_faces.cpu(), img)
    final_images.append(np.array(final_image))

    final_image_annotated = draw_boxes(overlay_generations(generated_faces.cpu(), img), boxes + squares) 
    final_images_annotated.append(np.array(final_image_annotated))

gif_path = os.path.join(save_folder, f"6-training_gif.gif")
imageio.mimsave(gif_path, final_images, duration=0.2)

gif_annotated_path = os.path.join(save_folder, f"7-training_gif_annotations.gif")
imageio.mimsave(gif_annotated_path, final_images_annotated, duration=0.5)

with torch.no_grad():
    generated_faces = netG_loaded(zhats).cpu()

final_annotated_path = os.path.join(save_folder, f"8-annotated_final.jpg")
final_save_path = os.path.join(save_folder, f"final_image.jpg")

draw_boxes(overlay_generations(generated_faces, img), boxes + squares).save(final_annotated_path)
logging.info(f"Saved final annotated image {final_annotated_path}")
overlay_generations(generated_faces, img).save(final_save_path)
logging.info(f"Saved final image {final_save_path}!")

INFO:root:Saved progress images in c:/Users/James/git/de-identification/dev-notebooks/16070
INFO:root: Model loaded
INFO:root: Model training
INFO:root: i: [24/500] Losses:: Complete:0.4263, contextual:0.2701, perceptual:0.1561 (after x0.1), time: 2.85s
INFO:root: i: [49/500] Losses:: Complete:0.3955, contextual:0.2454, perceptual:0.1501 (after x0.1), time: 5.32s
INFO:root: i: [74/500] Losses:: Complete:0.4036, contextual:0.2243, perceptual:0.1793 (after x0.1), time: 7.83s
INFO:root: i: [99/500] Losses:: Complete:0.4212, contextual:0.2032, perceptual:0.2180 (after x0.1), time: 10.29s
INFO:root: i: [124/500] Losses:: Complete:0.4183, contextual:0.1871, perceptual:0.2313 (after x0.1), time: 12.63s
INFO:root: i: [149/500] Losses:: Complete:0.4129, contextual:0.1763, perceptual:0.2366 (after x0.1), time: 15.12s
INFO:root: i: [174/500] Losses:: Complete:0.3989, contextual:0.1695, perceptual:0.2294 (after x0.1), time: 17.16s
INFO:root: i: [199/500] Losses:: Complete:0.3894, contextual:0.1647