In [None]:
!wget -q https://www.dropbox.com/s/5ji7jl7httso9ny/person_images.zip
!wget -q https://raw.githubusercontent.com/sizhky/deep-fake-util/main/random_warp.py
!pip install torch_summary
!pip install torch_snippets
!mkdir cropped_faces_personA
!mkdir cropped_faces_personB

In [None]:
from torch_snippets import *
from random_warp import get_training_data
from torchsummary import summary

In [None]:
# Frontal-face Haar cascade bundled with OpenCV. The shipped file is named
# 'haarcascade_frontalface_default.xml' (singular "haarcascade").
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
# A wrong path leaves the classifier empty instead of failing at construction,
# so check here rather than failing later inside detectMultiScale.
assert not face_cascade.empty(), 'failed to load haarcascade_frontalface_default.xml'

def crop_face(img):
  """Crop a detected face out of `img` and resize it to 256x256.

  The image is converted to grayscale with cv2.COLOR_BGR2GRAY and passed to
  the module-level `face_cascade` (scale factor 1.3, min neighbours 5).

  Returns:
    (crop, True) when at least one face is detected; the crop is taken from
    the LAST detection in the returned list.
    (img, False) with the untouched input when no face is detected.
  """
  grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  detections = face_cascade.detectMultiScale(gray := grayscale, 1.3, 5) if False else face_cascade.detectMultiScale(grayscale, 1.3, 5)
  # Guard clause: nothing found, hand the input back unchanged.
  if len(detections) == 0:
    return img, False
  # Only the final bounding box is used, matching a loop that overwrites
  # its result on every iteration.
  x, y, w, h = detections[-1]
  cropped = img[y:(y+h), x:(x+w), :]
  return cv2.resize(cropped, (256, 256)), True

def crop_images(folders):
  """Crop the face from every .jpg in `folders` and save the crops.

  Each image is loaded with `read(path, 1)` and passed to `crop_face`.
  Images without a detected face are skipped. Crops are written to
  'cropped_faces_<folders>/<index>.jpg', where <index> is the position of
  the source file in the glob result, after a cv2.COLOR_RGB2BGR conversion.
  """
  source_paths = Glob(folders + "/*.jpg")
  for idx, path in enumerate(source_paths):
    cropped, found = crop_face(read(path, 1))
    if found == False:
      continue
    destination = 'cropped_faces_' + folders + '/' + str(idx) + '.jpg'
    cv2.imwrite(destination, cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR))

In [None]:
class ImageDataset(Dataset):
  """Pairs of face images for two people, held fully in memory.

  Args:
    personA: iterable of image paths for person A.
    personB: iterable of image paths for person B.

  Every file is loaded with `read(f, 1)`, the images are stacked along a new
  leading axis and divided by 255. The images of A are then shifted so that
  their mean over axes (0, 1, 2) equals that of B.
  """
  def __init__(self, personA, personB):
    self.itemsA = np.concatenate([read(f,1)[None] for f in personA])/255.
    self.itemsB = np.concatenate([read(f,1)[None] for f in personB])/255.
    # Align the mean over the last axis (presumably colour channels) of A to B.
    # The attribute is `itemsA`; writing to `itemA` raised AttributeError.
    self.itemsA += self.itemsB.mean(axis = (0,1,2)) - self.itemsA.mean(axis = (0,1,2))
  
  def __len__(self):
    """Number of images of person A; defines the length of one epoch."""
    return len(self.itemsA)
  
  def __getitem__(self, index):
    """Return a randomly chosen (A, B) image pair; `index` is ignored."""
    a, b = choose(self.itemsA), choose(self.itemsB)
    return a, b

  def collate_fn(self, batch):
    """Turn a list of (A, B) pairs into four tensors on `device`.

    Returns:
      (imsA, target_A, imsB, target_B), each permuted from axis order
      (0, 1, 2, 3) to (0, 3, 1, 2). The inputs and targets come from
      `get_training_data` (presumably warped inputs and their
      reconstruction targets; confirm against random_warp.py).
    """
    a, b = list(zip(*batch))
    imsA, target_A = get_training_data(a, len(a))
    imsB, target_B = get_training_data(b, len(b))
    imsA, target_A, imsB, target_B = [torch.Tensor(i).permute(0,3,1,2).to(device) for i in [imsA, target_A, imsB, target_B]]
    return imsA, target_A, imsB, target_B

In [None]:
def conv_layer(input_features, output_features):
  """Build a downsampling block: 5x5 convolution (stride 2, padding 2) + ReLU.

  Args:
    input_features: number of input channels.
    output_features: number of output channels.

  Returns:
    nn.Sequential of [Conv2d, ReLU].
  """
  downsample = nn.Conv2d(
      in_channels = input_features,
      out_channels = output_features,
      kernel_size = 5,
      stride = 2,
      padding = 2,
  )
  activation = nn.ReLU()
  return nn.Sequential(downsample, activation)

def up_scale(input_features, output_features):
  """Build an upsampling block: 2x2 transposed convolution (stride 2) + ReLU.

  Args:
    input_features: number of input channels.
    output_features: number of output channels.

  Returns:
    nn.Sequential of [ConvTranspose2d, ReLU].
  """
  upsample = nn.ConvTranspose2d(
      in_channels = input_features,
      out_channels = output_features,
      kernel_size = 2,
      stride = 2,
      padding = 0,
  )
  activation = nn.ReLU()
  return nn.Sequential(upsample, activation)

class Reshape(nn.Module):
  """Module that views its input as a batch of 1024-channel 4x4 maps."""
  def forward(self, input):
    """Return `input` viewed with shape (-1, 1024, 4, 4)."""
    return input.view(-1, 1024, 4, 4)

class AutoEncoder(nn.Module):
  """Auto-encoder with one shared encoder and two decoders, "A" and "B".

  The encoder stacks four `conv_layer` blocks (3 -> 128 -> 256 -> 512 -> 1024
  channels), flattens, applies two linear layers (1024*4*4 -> 1024 ->
  1024*4*4), reshapes back to (1024, 4, 4) maps and upsamples once to 512
  channels. The first linear layer requires 1024*4*4 features, i.e. 4x4 maps
  after the four stride-2 convolutions (presumably 64x64 inputs; confirm).

  Both decoders have the same layout: three `up_scale` blocks (512 -> 256 ->
  128 -> 64 channels) followed by a 3x3 convolution to 3 channels and a sigmoid.
  """
  def __init__(self):
    super().__init__()

    def build_decoder():
      # Same layer layout for both decoders; each call creates fresh weights.
      return nn.Sequential(
          up_scale(512, 256),
          up_scale(256, 128),
          up_scale(128, 64),
          nn.Conv2d(64, 3, kernel_size = 3, padding = 1),
          nn.Sigmoid()
      )

    downsampling = [
        conv_layer(3, 128),
        conv_layer(128, 256),
        conv_layer(256, 512),
        conv_layer(512, 1024),
    ]
    bottleneck = [
        nn.Flatten(),
        nn.Linear(1024 * 4 * 4, 1024),
        nn.Linear(1024, 1024 * 4 * 4),
        Reshape(),
    ]
    # Layer order is kept flat so the state_dict keys stay encoder.0 ... encoder.8.
    self.encoder = nn.Sequential(*downsampling, *bottleneck, up_scale(1024, 512))

    self.decoderA = build_decoder()
    self.decoderB = build_decoder()

  def forward(self, x, select = "A"):
    """Encode `x` and decode it with decoder A when `select == "A"`, else B."""
    decoder = self.decoderA if select == "A" else self.decoderB
    return decoder(self.encoder(x))



In [None]:
def train_batch(model, data, criterion, optA, optB):
  """Run one optimisation step for both identities.

  Args:
    model: module called as `model(x, 'A')` / `model(x, 'B')` to select the decoder.
    data: tuple `(imsA, target_A, imsB, target_B)`.
    criterion: loss called as `criterion(prediction, target)`.
    optA: optimizer stepped for the A reconstruction.
    optB: optimizer stepped for the B reconstruction.

  Returns:
    (loss_A, loss_B) as Python floats.
  """
  # The optimizers are parameters. They used to be rebound from an undefined
  # global `optimizers`, which raised NameError.
  optA.zero_grad()
  optB.zero_grad()
  imsA, target_A, imsB, target_B = data
  # Each identity goes through its own decoder. With the default select="A"
  # for both, decoder B would never receive a gradient.
  _imsA = model(imsA, 'A')
  _imsB = model(imsB, 'B')
  loss_A = criterion(_imsA, target_A)
  loss_B = criterion(_imsB, target_B)
  # Both backward passes run before either step, so gradients from the two
  # losses accumulate on any parameters the two paths share.
  loss_A.backward()
  loss_B.backward()
  optA.step()
  optB.step()

  return loss_A.item(), loss_B.item()

model = AutoEncoder().to(device)

# Each optimizer updates the shared encoder plus its own decoder.
optimizerA = optim.Adam([{'params': model.encoder.parameters()}, {'params': model.decoderA.parameters()}], lr=5e-5, betas=(0.5, 0.999))
optimizerB = optim.Adam([{'params': model.encoder.parameters()}, {'params': model.decoderB.parameters()}], lr=5e-5, betas=(0.5, 0.999))
criterion = nn.L1Loss()

# NOTE(review): `dataloader` is not defined in any visible cell; it must be
# created (with ImageDataset.collate_fn) before this cell runs.
n_epochs = 1000
log = Report(n_epochs)
N = len(dataloader)  # batches per epoch; constant, so computed once
for ex in range(n_epochs):
  for bx, data in enumerate(tqdm(dataloader)):
    lossA, lossB = train_batch(model, data , criterion, optimizerA, optimizerB)
    # '\r' (carriage return) rewrites the progress line in place; '/r' was a typo.
    log.record(ex + (1 + bx)/N, lossA = lossA, lossB = lossB, end = '\r')
  log.report_avgs(ex+1)
  if (ex+1)%100 == 0:
    # Checkpoint every 100 epochs.
    state = {'state': model.state_dict(),'epoch': ex }
    torch.save(state, 'autoencoder.pth')

    # Preview reconstructions and swaps on the last batch of the epoch.
    bs = 5
    # The batch is (imsA, target_A, imsB, target_B): lower-case names are the
    # model inputs, upper-case names the targets shown in the top row.
    a,A,b,B = data
    with torch.no_grad():  # no gradients are needed for previews
      line('A to B')
      _a = model(a[:bs], 'A')
      _b = model(a[:bs], 'B')
      x = torch.cat([A[:bs],_a,_b])
      subplots(x, nc=bs, figsize=(bs*2, 5))
      line('B to A')
      _a = model(b[:bs], 'A')
      _b = model(b[:bs], 'B')
      x = torch.cat([B[:bs],_a,_b])
      subplots(x, nc=bs, figsize=(bs*2, 5))
log.plot_epochs()