# Colab-co-mod-gan-pytorch
Original TensorFlow version: [zsyzzsoft/co-mod-gan](https://github.com/zsyzzsoft/co-mod-gan)

PyTorch version: [zengxianyu/co-mod-gan-pytorch](https://github.com/zengxianyu/co-mod-gan-pytorch)

My fork: [styler00dollar/Colab-co-mod-gan-pytorch](https://github.com/styler00dollar/Colab-co-mod-gan-pytorch)

In [None]:
# Print the GPU attached to this runtime; the batch-processing cell further down runs the model on "cuda".
!nvidia-smi

In [None]:
#@title setup
!git clone https://github.com/zengxianyu/co-mod-gan-pytorch
!mkdir "/content/output"
!mkdir "/content/input"
#@title download models
%cd /content/co-mod-gan-pytorch/
!sh download/ffhq512.sh
!sh download/ffhq1024.sh
!sh download/places512.sh
# install ninja
%cd /content
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

Models: 
```
co-mod-gan-places2-050000.pth # 512px
co-mod-gan-ffhq-9-025000.pth # 512px
co-mod-gan-ffhq-10-025000.pth # 1024px
```

In [None]:
#@title resize/invert mask color if needed
import cv2
# Resize the input image and its mask to 512x512 and overwrite both files in place.
image = cv2.imread("/content/input.png")
if image is None:
  # cv2.imread signals a missing/unreadable file by returning None, not by raising.
  raise FileNotFoundError("could not read /content/input.png")
# The interpolation flag must be passed by keyword: the third positional
# parameter of cv2.resize is `dst`, so a positional flag is not used as the
# interpolation mode.
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_NEAREST)
cv2.imwrite("/content/input.png", image)

image = cv2.imread("/content/mask.png", cv2.IMREAD_GRAYSCALE)
if image is None:
  raise FileNotFoundError("could not read /content/mask.png")
# Nearest-neighbour keeps the mask free of interpolated in-between grey values.
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_NEAREST)
# invert mask if needed
# white = original area
# black = inpainting
image = 255 - image
cv2.imwrite("/content/mask.png", image)

In [None]:
# Run the repository's test.py once on /content/input.png and /content/mask.png,
# using the Places2 checkpoint and writing to /content/output.png.
# NOTE(review): -i/-m/-o/-c presumably mean input, mask, output and checkpoint — confirm against test.py.
%cd /content/co-mod-gan-pytorch
!python test.py -i /content/input.png -m /content/mask.png -o /content/output.png -c checkpoints/co-mod-gan-places2-050000.pth

In [None]:
#@title batch_process folder
#@markdown Mark the area to mask out in pure green (RGB 0,255,0); those pixels get mask value 0, all other pixels get 1
%cd /content/co-mod-gan-pytorch
import argparse
import numpy as np
import torch
from co_mod_gan import Generator
from PIL import Image
import glob
import cv2
from tqdm import tqdm
import os

output_path = "/content/output" #@param
rootdir = "/content/input" #@param

# Create the destination up front so saving the first result cannot fail on a
# missing folder.
os.makedirs(output_path, exist_ok=True)

# Every .png below rootdir, including sub-folders.
files = glob.glob(rootdir + '/**/*.png', recursive=True)

device = "cuda"

net = Generator()
net.load_state_dict(torch.load("checkpoints/co-mod-gan-places2-050000.pth"))
net.eval()
net = net.to(device)

# Inference only: without no_grad() every forward pass would also record an
# autograd graph that is never used.
with torch.no_grad():
  for f in tqdm(files):
    images = cv2.imread(f)
    if images is None:
      # cv2.imread returns None for unreadable files; skip them instead of
      # aborting the whole batch inside cvtColor.
      tqdm.write(f"skipping unreadable image: {f}")
      continue
    images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
    # Mask is 0 where a pixel is exactly pure green (0,255,0) and 1 elsewhere.
    masks = 1-np.all(images == [0,255,0], axis=-1).astype(int)
    # Add batch and channel axes: (H, W) -> (1, 1, H, W).
    masks = torch.from_numpy(masks).unsqueeze(0).unsqueeze(0).float()
    # (H, W, 3) uint8 -> (1, 3, H, W) float scaled from [0, 255] to [-1, 1].
    images = (torch.from_numpy(images).unsqueeze(0).permute(0,3,1,2)/255)*2-1

    # A fresh random latent per image.
    latents_in = torch.randn(1, 512)

    images = images.to(device)
    masks = masks.to(device)
    latents_in = latents_in.to(device)

    result = net(images, masks, [latents_in], truncation=None)
    result = result.detach().cpu().numpy()
    # Undo the input scaling: [-1, 1] -> [0, 255], channels last.
    result = (result+1)/2
    result = (result[0].transpose((1,2,0)))*255
    # Saved under the source file's base name directly in output_path.
    Image.fromarray(result.clip(0,255).astype(np.uint8)).save(os.path.join(output_path, os.path.basename(f)))

In [None]:
#@title delelete folders and recreate them
%cd /content/
!sudo rm -rf /content/input
!sudo rm -rf /content/output
!mkdir /content/input
!mkdir /content/output