In [0]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import pickle
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import random

In [0]:
class Generator(nn.Module):
    #network that makes 3 * 256 * 256 images from images of the same size
    def __init__(self):
        super(Generator, self).__init__()

        self.conv1 = nn.Conv2d(6, 30, 3)
        self.conv2 = nn.Conv2d(30, 30, 4)
        
        self.fc1 = nn.Linear(30 * 14 * 14, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 120)
        self.fc4 = nn.Linear(120, 30 * 14 * 14)

        self.deconv1 = nn.ConvTranspose2d(60, 30, 4)
        self.deconv2 = nn.ConvTranspose2d(60, 3, 3)

        self.do = nn.Dropout(0.5)

    def forward(self, masked_imgs, masks):
        avg_masks = F.avg_pool2d(masks, 3, 1, 1)
        x = torch.cat([masked_imgs, avg_masks], dim=1)

        conv_img_1 = F.relu(self.conv1(x))
        x, max_pool_1_idx = F.max_pool2d(conv_img_1, 2, return_indices=True)
        conv_img_2 = F.relu(self.conv2(x))
        x, max_pool_2_idx = F.max_pool2d(conv_img_2, 2, return_indices=True)
        
        x = x.view(-1, self.num_flat_features(x))

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.do(x)
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))

        x = self.do(x)

        x = x.view(-1, 30, 14, 14)

        x = F.max_unpool2d(x, max_pool_2_idx, 2)
        x = torch.cat([x, conv_img_2], dim=1)
        x = F.relu(self.deconv1(x))

        x = self.do(x)

        x = F.max_unpool2d(x, max_pool_1_idx, 2)
        x = torch.cat([x, conv_img_1], dim=1)
        x = torch.sigmoid(self.deconv2(x))

        x = masks * masked_imgs + (1 - masks) * x
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In [0]:
class Classifier(nn.Module):
    #network that takes 3 * 256 * 256 images and gives a value between 0 and 1
    def __init__(self):
        super(Classifier, self).__init__()

        self.conv1 = nn.Conv2d(3, 20, 3)
        self.conv2 = nn.Conv2d(20, 20, 3)
        self.conv3 = nn.Conv2d(20, 20, 3)
        self.conv4 = nn.Conv2d(20, 20, 3)

        self.fc1 = nn.Linear(20 * 13 * 13, 200)
        self.fc2 = nn.Linear(200, 80)
        self.fc3 = nn.Linear(80, 1)

        self.do1 = nn.Dropout(0.5)
        self.do2 = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(F.relu(self.conv4(x)), 2)
        
        x = x.view(-1, self.num_flat_features(x))

        x = F.relu(self.fc1(x))
        x = self.do1(x)
        x = F.relu(self.fc2(x))
        x = self.do2(x)
        x = torch.sigmoid(self.fc3(x))

        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In [0]:
# connect to google drive
offline_use = False

if offline_use:
  directory = os.path.dirname(os.path.abspath(__file__))
else:
  from google.colab import drive
  drive.mount('/content/drive', force_remount=True)
  directory = '/content/drive/My Drive/'

directory = os.path.join(directory, 'gan_data')

Mounted at /content/drive


In [0]:
real_img_directory = os.path.join(directory, 'real_img')
fake_img_directory = os.path.join(directory, 'fake_img')
os.makedirs(real_img_directory, exist_ok=True)
os.makedirs(fake_img_directory, exist_ok=True)

In [0]:
def read_img_file_names():
  if os.path.isdir(real_img_directory):
    real_photo_paths = [os.path.join(path, name) for path, subdirs, files in os.walk(real_img_directory) for name in files]
    if real_photo_paths == []:
      print('No real photos found.')
  else:
    print('No real photos directory found.')
    real_photo_paths = []

  if os.path.isdir(fake_img_directory):
    fake_photo_paths = [os.path.join(path, name) for path, subdirs, files in os.walk(fake_img_directory) for name in files]
    if fake_photo_paths == []:
      print('No fake photos found.')
  else:
    print('No fake photos directory found.')
    fake_photo_paths = []

  return real_photo_paths, fake_photo_paths

real_img_list, fake_img_list = read_img_file_names()

No fake photos found.


In [0]:
# print(training_photos)

In [0]:
mask_directory = os.path.join(directory, 'mask')

mask_paths = []
filtered_paths = []

n_subfolders = 1

if os.path.isdir(mask_directory):
  for i in range(n_subfolders):
    folder_path_2 = os.path.join(mask_directory, str(i))
    if not os.path.isdir(folder_path_2):
      continue
    for folder in os.listdir(folder_path_2):
      folder_path = os.path.join(folder_path_2, folder)

      if not os.path.isdir(folder_path):
        continue

      mask_path = ""
      filtered_path = ""

      for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)

        if not os.path.isfile(file_path):
          continue

        if "mask" in file_name:
          mask_path = file_path
        elif "combined" in file_name:
          filtered_path = file_path

      # if mask_path == "":
      #   print("No mask found in folder: {}".format(folder_path))
      # if filtered_path == "":
      #   print("No filtered image found in folder: {}".format(folder_path))

      if mask_path != "" and filtered_path != "":
        mask_paths.append(mask_path)
        filtered_paths.append(filtered_path)

  if mask_paths == []:
    print('No masks found.')
else:
  print('No mask directory found.')

masks_df = pd.DataFrame({'mask_path': mask_paths, 'filtered_path': filtered_paths})

In [0]:
import pathlib
import cv2
import pickle
from google.colab.patches import cv2_imshow

def load_img(file_path):
  extension = ''.join(pathlib.Path(file_path).suffixes)

  if extension == '.pickle':
    # load pickle
    with open(file_path, 'rb') as pickle_file:
      img = np.array(pickle.load(pickle_file))
  elif extension == '.npy':
    # load npy
    img = np.load(file_path)
  elif extension in ['.jpg', '.png', '.bmp']:
    # load img
    img = cv2.imread(file_path)
    img.resize(64, 64, 3)
    img = np.array(img)
  else:
    # unknown extension
    print('Image file has an unknown extension: {}'.format(file_path))

  # cv2_imshow(img)

  transposed_img = torchvision.transforms.ToTensor()(img).numpy()

  return transposed_img

def load_imgs(file_paths):
  x = [load_img(file_path) for file_path in file_paths]
  # torchvision.utils.make_grid(torch.Tensor(x))
  return x

In [0]:
from torchvision.utils import save_image

idx_img = 0
def save_img(directory, img):
  global idx_img
  file_path = os.path.join(directory, 'z{}.png'.format(idx_img))
  save_image(img, file_path)
  idx_img += 1
  return file_path

def save_imgs(directory, imgs_tensor):
  file_paths = []
  # torchvision.utils.make_grid(imgs_tensor)
  for img_tensor in imgs_tensor:
    img_tensor = img_tensor.detach()
    file_path = save_img(directory, img_tensor)
    file_paths.append(file_path)
  return file_paths

In [0]:
def train_classifier(classifier_network, training_imgs, targets, loss_function, classifier_optimizer):
  training_imgs = torch.tensor(training_imgs).float().cuda()
  targets = torch.tensor(targets).float().cuda()

  # print('training')
  # torchvision.utils.make_grid(training_imgs)

  out = cla.forward(training_imgs)
  loss = loss_function(out, targets)
  loss.backward()
  classifier_optimizer.step()

  print('Classifier loss: {}'.format(loss))

In [0]:
def train_generator(generator_network, classifier_network, masks, masked_imgs, loss_function, generator_optimizer):
  # print('masked')
  # torchvision.utils.make_grid(torch.Tensor(masked_imgs))

  masks = torch.tensor(masks).float().cuda()
  masked_imgs = torch.tensor(masked_imgs).float().cuda()
  targets = torch.ones([len(masks), 1]).float().cuda()

  generated_imgs = generator_network.forward(masked_imgs, masks)
  classifications = classifier_network.forward(generated_imgs)

  # print('generated imgs')
  # torchvision.utils.make_grid(torch.Tensor(generated_imgs.cpu()))

  loss = loss_function(classifications, targets)
  loss.backward()
  generator_optimizer.step()

  print('Generator loss: {}'.format(loss))

  return generated_imgs

In [0]:
n_epochs = 1900
n_repetitions = 10
n_img_sample = 50
n_mask_sample = 50

lr = 0.001
beta1 = 0.5

In [0]:
# print(fake_img_list)
fake_img_list = []

In [0]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('ConvTranspose') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.001)
    elif classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.01)
    # elif classname.find('BatchNorm') != -1:
    #     nn.init.normal_(m.weight.data, 1.0, 0.02)
    #     nn.init.constant_(m.bias.data, 0)

In [0]:
from random import shuffle

In [0]:
# gen = Generator().float().cuda()
# cla = Classifier().float().cuda()

# gen.apply(weights_init)
# cla.apply(weights_init)

loss_function = nn.BCELoss()
# cla_optimizer = torch.optim.SGD(cla.parameters(), lr=0.002)
# gen_optimizer = torch.optim.SGD(gen.parameters(), lr=0.002)

cla_optimizer = torch.optim.Adam(cla.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=0.98)
gen_optimizer = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta1, 0.999), weight_decay=0.98)

for i_epoch in range(n_epochs):
  print('Running epoch {} out of {}'.format(i_epoch + 1, n_epochs))
  
  sample_real_imgs_fn = random.sample(real_img_list, min(len(real_img_list), n_img_sample))
  sample_fake_imgs_fn = random.sample(fake_img_list, min(len(fake_img_list), n_img_sample))
  sample_masks_fn_df = masks_df.sample(n = n_mask_sample, replace=True)

  real_imgs = load_imgs(sample_real_imgs_fn)
  fake_imgs = load_imgs(sample_fake_imgs_fn)
  masks = load_imgs(sample_masks_fn_df['mask_path'])
  masked_imgs = load_imgs(sample_masks_fn_df['filtered_path'])

  for i_repetition in range(n_repetitions):
    if len(real_imgs) > 0:
      shuffle(real_imgs)
      train_classifier(cla, real_imgs, [[1] for i in range(len(real_imgs))], loss_function, cla_optimizer)
    if len(fake_imgs) > 0:
      shuffle(fake_imgs)
      train_classifier(cla, fake_imgs, [[0] for i in range(len(fake_imgs))], loss_function, cla_optimizer)
    generated_imgs = train_generator(gen, cla, masks, masked_imgs, loss_function, gen_optimizer)
    fake_img_list += save_imgs(fake_img_directory, generated_imgs)

  n_img_sample = 200

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Classifier loss: 0.025787517428398132
Generator loss: 4.30033969678334e-06
Classifier loss: 0.10351395606994629
Classifier loss: 0.05170556902885437
Generator loss: 0.04000495374202728
Classifier loss: 0.05208441615104675
Classifier loss: 0.049865782260894775
Generator loss: 0.010923806577920914
Classifier loss: 0.06948782503604889
Classifier loss: 0.027047041803598404
Generator loss: 3.00409652709277e-07
Classifier loss: 0.06400252133607864
Classifier loss: 0.027450930327177048
Generator loss: 0.0
Classifier loss: 0.03434578329324722
Classifier loss: 0.02774358168244362
Generator loss: 0.0
Classifier loss: 0.07278061658143997
Classifier loss: 0.028069550171494484
Generator loss: 0.0
Classifier loss: 0.048233989626169205
Classifier loss: 0.025001730769872665
Generator loss: 1.6689307713591006e-08
Running epoch 75 out of 1900
Classifier loss: 0.09458958357572556
Classifier loss: 0.5375844240188599
Generator loss: 0.0
Class