# Dataset creation

In [2]:
import numpy as np
import argparse
import cv2

In [3]:
def config(argv=None):
    """Build the command-line configuration for dataset creation.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]`` as before; pass
            ``[]`` to get the defaults, e.g. from inside a notebook kernel
            whose own command line argparse cannot understand.

    Returns:
        argparse.Namespace with ``training_image``, ``name``, ``edgelength``,
        ``stride`` and ``output_dir`` attributes.
    """
    parser = argparse.ArgumentParser(description="Slice a training image into square patches.")
    parser.add_argument("--training_image", type=str, default="../data/TI/strebelle.png", help="Path to training image")
    parser.add_argument("--name", type=str, default="strebelle", help="Name of dataset")
    parser.add_argument("--edgelength", type=int, default=250, help="Edge length (pixels) of each square patch")
    parser.add_argument("--stride", type=int, default=32, help="Step (pixels) between neighbouring patches")
    parser.add_argument("--output_dir", type=str, default="output/", help="Path to output the training images slices")
    return parser.parse_args(argv)

# param = config()

# Jupyter notebook: argparse would try to parse the kernel's own command line,
# so the dataset parameters are spelled out here instead.
param = dict(
    training_image="../data/TI/strebelle.png",
    name="strebelle",
    edgelength=64,
    stride=16,
    output_dir="output/",
)

# Echo the active configuration, one "key:value" pair per line.
print("Parameters for dataset creation:")
print("\n".join(f"{key}:{value}" for key, value in param.items()))

Parameters for dataset creation:
training_image:../data/TI/strebelle.png
name:strebelle
edgelength:64
stride:16
output_dir:output/


In [3]:
import os
from tqdm import tqdm
from skimage.io import imsave
from skimage.util import view_as_windows

# Load the training image as a single-channel uint8 array. The second argument of
# cv2.imread must be an IMREAD_* read mode (not a cvtColor code such as
# COLOR_BGR2GRAY), otherwise an RGB file would load with 3 channels.
ti = cv2.imread(param["training_image"], cv2.IMREAD_GRAYSCALE)
if ti is None:
    # cv2.imread reports a missing/unreadable file by returning None, not raising.
    raise FileNotFoundError(f"Could not read training image: {param['training_image']}")

edge = param["edgelength"]
# All edge x edge windows; shape (n_rows, n_cols, edge, edge).
# NOTE(review): param["stride"] is intentionally not applied. Dense step=1 sampling
# gives ~35k patches for a 250x250 image (187 x 187 in the recorded run), which the
# training cell's batch_size / n_critic settings appear to rely on — confirm before
# switching to step=param["stride"].
win_ti = view_as_windows(ti, window_shape=(edge, edge), step=1)

# torchvision ImageFolder (used for training on 'output/') expects one
# sub-directory per class, so patches go into <output_dir>/real/ — create it.
patch_dir = os.path.join(param["output_dir"], "real")
os.makedirs(patch_dir, exist_ok=True)

for i, row in enumerate(tqdm(win_ti)):
    for j, patch in enumerate(row):
        # Patches are saved as uint8 directly (no float round trip / lossy-conversion
        # warning); check_contrast=False silences the warning for uniform patches.
        imsave(os.path.join(patch_dir, f"{param['name']}_patch_{i}_{j}.png"),
               patch, check_contrast=False)





















































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































































187it [04:17,  1.38s/it]


In [4]:
import torch.utils.data as data
from torch import Tensor
from os import listdir
from os.path import join
import numpy as np
import h5py

In [5]:
import torchvision.transforms as transforms
import torch

# Augmenting preprocessing pipeline: blur, force 1 channel, convert to a [0, 1]
# tensor, randomly erase a rectangle 30% of the time, then rescale to [-1, 1].
# NOTE(review): not referenced by the visible training cells, which build their
# own `transf`.
transform_funcs = transforms.Compose([
    transforms.GaussianBlur(kernel_size=(3, 3)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.3),
    # Normalize(mean, std) only; a third positional argument would bind to
    # `inplace`, which is not what the previous `[0.5]` was meant to be.
    transforms.Normalize([0.5], [0.5])])

# WGAN-GP

In [4]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
from torch.autograd import grad
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import shutil
import statistics as st
import matplotlib.image as mpimg
import pylab

%pylab inline
pylab.rcParams['figure.figsize'] = (10, 10)

class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 1-channel image.

    A bias-free linear layer (+ BatchNorm1d + ReLU) projects the ``dim_in``
    input to ``dim*8`` feature maps of size 4x4. Four stride-2 transposed
    convolutions then double the spatial size each time
    (4 -> 8 -> 16 -> 32 -> 64) while the channels go
    dim*8 -> dim*4 -> dim*2 -> dim -> 1. The last stage ends in tanh, so pixel
    values lie in [-1, 1].

    Args:
        dim_in: Length of the latent input vector.
        dim: Base channel width (default 64).
    """

    def __init__(self, dim_in, dim=64):
        super(Generator, self).__init__()

        def upsample(in_ch, out_ch, final=False):
            # kernel 5, stride 2, padding 2, output_padding 1 => exactly 2x upsampling.
            # Hidden stages are bias-free (BatchNorm follows); the output stage keeps
            # its bias and squashes with tanh.
            deconv = nn.ConvTranspose2d(in_channels=in_ch,
                                        out_channels=out_ch,
                                        kernel_size=5,
                                        stride=2,
                                        padding=2,
                                        output_padding=1,
                                        bias=final)
            if final:
                return nn.Sequential(deconv, nn.Tanh())
            return nn.Sequential(deconv, nn.BatchNorm2d(out_ch), nn.ReLU())

        self.prepare = nn.Sequential(
            nn.Linear(dim_in, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU(),
        )

        widths = [dim * 8, dim * 4, dim * 2, dim]
        stages = [upsample(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        stages.append(upsample(dim, 1, final=True))
        self.generate = nn.Sequential(*stages)

    def forward(self, x):
        features = self.prepare(x)
        # Unflatten (batch, dim*8*16) into (batch, dim*8, 4, 4).
        features = features.view(features.size(0), -1, 4, 4)
        return self.generate(features)

    
class Critic(nn.Module):
    """Convolutional WGAN critic producing one unbounded score per image.

    Four 5x5 stride-2 convolutions (padding 2) halve the spatial size each time
    (64 -> 32 -> 16 -> 8 -> 4 for 64x64 inputs) while the channels go
    dim -> dim*2 -> dim*4 -> dim*8. A final 4x4 convolution collapses the 4x4
    map to a single value. There is no output activation: a Wasserstein critic
    outputs raw scores.

    Args:
        dim_in: Number of input image channels (1 for the grayscale patches).
        dim: Base channel width (default 64).
    """

    def __init__(self, dim_in, dim=64):
        super(Critic, self).__init__()

        def critic_block(dim_in, dim_out):
            # Downsampling conv + per-sample normalization. InstanceNorm rather than
            # BatchNorm keeps each sample's score independent of the rest of the
            # batch, which the per-sample gradient penalty assumes.
            block = nn.Sequential(nn.Conv2d(in_channels=dim_in,
                                            out_channels=dim_out,
                                            kernel_size=5,
                                            stride=2,
                                            padding=2),
                                  nn.InstanceNorm2d(dim_out, affine=True),
                                  nn.LeakyReLU(0.2))
            return block

        self.analyze = nn.Sequential(nn.Conv2d(in_channels=dim_in,
                                               out_channels=dim,
                                               kernel_size=5,
                                               stride=2,
                                               padding=2),
                                     nn.LeakyReLU(0.2),
                                     critic_block(dim, dim * 2),
                                     critic_block(dim * 2, dim * 4),
                                     critic_block(dim * 4, dim * 8),
                                     nn.Conv2d(in_channels=dim * 8,
                                               out_channels=1,
                                               kernel_size=4))

    def forward(self, x):
        x = self.analyze(x)
        # For 64x64 inputs the map is (batch, 1, 1, 1); flatten to (batch,).
        x = x.view(-1)
        return x

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


`%matplotlib` prevents importing * from pylab and numpy
  warn("pylab import has clobbered these variables: %s"  % clobbered +


In [5]:
def gradient_penalty(x, y, f):
    """WGAN-GP gradient penalty (Gulrajani et al., 2017).

    Evaluates the critic at random points on the segments between paired real
    and generated samples. Penalizes the per-sample gradient norm for deviating
    from 1:

        mean((||grad_z f(z)||_2 - 1)^2),   z = x + alpha * (y - x),   alpha ~ U[0, 1)

    Args:
        x: Batch of real samples, shape (batch, ...).
        y: Batch of generated samples, same shape as ``x``.
        f: Critic callable. Its output is differentiated with ``grad_outputs``
            of ones, i.e. as if summed.

    Returns:
        Scalar tensor. The graph is kept (``create_graph=True``) so the
        penalty can be back-propagated into ``f``'s parameters.
    """
    # One interpolation coefficient per sample, broadcast over the other dims.
    shape = [x.size(0)] + [1] * (x.dim() - 1)
    alpha = torch.rand(shape, dtype=x.dtype, device=x.device)
    # Fresh leaf: gradients are taken w.r.t. the interpolates only.
    z = (x + alpha * (y - x)).detach().requires_grad_(True)
    o = f(z)
    g = grad(o, z, grad_outputs=torch.ones_like(o), create_graph=True)[0].reshape(z.size(0), -1)
    # Two-sided penalty: push each sample's gradient norm towards 1.
    gp = ((g.norm(p=2, dim=1) - 1) ** 2).mean()
    return gp
#%%
def save_checkpoint(state, save_path, is_best=False, max_keep=None):
    """Save ``state`` with torch.save and maintain a rolling checkpoint index.

    The basenames of saved checkpoints are tracked newest-first, one per line,
    in a ``latest_checkpoint`` text file in the same directory as
    ``save_path``.

    Args:
        state: Object passed to ``torch.save``.
        save_path: Destination file path of the checkpoint.
        is_best: If true, also copy the checkpoint to ``best_model.ckpt`` in the
            same directory.
        max_keep: If not None, keep only the ``max_keep`` newest entries of the
            index and delete the files of the older ones.
    """
    torch.save(state, save_path)

    save_dir = os.path.dirname(save_path)
    list_path = os.path.join(save_dir, 'latest_checkpoint')
    ckpt_name = os.path.basename(save_path)

    ckpt_names = []
    if os.path.exists(list_path):
        with open(list_path) as f:
            ckpt_names = [line.rstrip('\r\n') for line in f]
        ckpt_names = [name for name in ckpt_names if name]
    # Drop any older entry with the same name: a duplicate could otherwise fall
    # past max_keep and get the freshly written file deleted.
    ckpt_names = [ckpt_name] + [name for name in ckpt_names if name != ckpt_name]

    if max_keep is not None:
        for stale in ckpt_names[max_keep:]:
            stale_path = os.path.join(save_dir, stale)
            if os.path.exists(stale_path):
                os.remove(stale_path)
        del ckpt_names[max_keep:]

    with open(list_path, 'w') as f:
        f.writelines(name + '\n' for name in ckpt_names)

    if is_best:
        # Copy from the full path; the bare basename only resolves when the
        # current working directory happens to be save_dir.
        shutil.copyfile(save_path, os.path.join(save_dir, 'best_model.ckpt'))
#%%
def load_checkpoint(ckpt_dir_or_file, map_location=None, load_best=False):
    """Load a checkpoint file, or the newest/best one from a checkpoint directory.

    Args:
        ckpt_dir_or_file: Path to a checkpoint file, or to a directory that
            contains a ``latest_checkpoint`` index (newest name on the first
            line) and optionally ``best_model.ckpt``.
        map_location: Forwarded to ``torch.load``.
        load_best: For a directory, load ``best_model.ckpt`` instead of the
            newest indexed checkpoint.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        FileNotFoundError: If the index or checkpoint file is missing, or the
            index lists no checkpoint.
    """
    if os.path.isdir(ckpt_dir_or_file):
        if load_best:
            ckpt_path = os.path.join(ckpt_dir_or_file, 'best_model.ckpt')
        else:
            with open(os.path.join(ckpt_dir_or_file, 'latest_checkpoint')) as f:
                # rstrip rather than [:-1]: a last line without a trailing newline
                # must not lose its final character.
                newest = f.readline().rstrip('\r\n')
            if not newest:
                raise FileNotFoundError(
                    'No checkpoint listed in %s' % os.path.join(ckpt_dir_or_file, 'latest_checkpoint'))
            ckpt_path = os.path.join(ckpt_dir_or_file, newest)
    else:
        ckpt_path = ckpt_dir_or_file
    ckpt = torch.load(ckpt_path, map_location=map_location)
    print(' [*] Loading checkpoint from %s succeed!' % ckpt_path)
    return ckpt

In [6]:
epochs = 100
batch_size = 180
n_critic = 5          # critic updates per generator update
lr = 0.0002
z_dim = 100           # latent vector length


# ToTensor -> [0, 1]; Grayscale reduces the loader's RGB images to 1 channel;
# Normalize(0.5, 0.5) maps to [-1, 1], the range of the generator's tanh output.
transf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Grayscale(num_output_channels=1),
        transforms.Normalize([0.5], [0.5])])
data = torchvision.datasets.ImageFolder('output/', transform=transf)
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                         shuffle=True, num_workers=4)
C = Critic(1)
G = Generator(z_dim)
print("==================================================================\nGenerator : ")
print(G)
print("==================================================================\nCritic")
print(C)
G_opt = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
C_opt = torch.optim.Adam(C.parameters(), lr=lr, betas=(0.5, 0.999))

checkpoint = './checkpoints/wgan_gp'
save_dir = './sample_images/wgan_gp'
for path in (checkpoint, save_dir):
    os.makedirs(path, exist_ok=True)

start_epoch = 0
try:
    ckpt = load_checkpoint(checkpoint)
except OSError:
    # No (readable) checkpoint yet: start from scratch. Anything else — e.g. a
    # state_dict that no longer matches the model — is deliberately not
    # swallowed, because training from scratch would then prune the existing
    # checkpoints via save_checkpoint(max_keep=2).
    print(' [*] No checkpoint!')
else:
    start_epoch = ckpt['epoch']
    C.load_state_dict(ckpt['D'])
    G.load_state_dict(ckpt['G'])
    C_opt.load_state_dict(ckpt['d_optimizer'])
    G_opt.load_state_dict(ckpt['g_optimizer'])

# Fixed latent batch used to render the per-epoch sample grid.
z_sample = torch.randn(100, z_dim)

1
Generator : 
Generator(
  (prepare): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=False)
    (1): BatchNorm1d(8192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (generate): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

In [None]:
import time
for epoch in range(start_epoch, epochs):
    start_time = time.time()
    C_loss = []
    G_loss = []
    G.train()
    for i, (images, _) in enumerate(dataloader):
        # Global 1-based iteration counter; drives the n_critic schedule across epochs.
        step = epoch * len(dataloader) + i + 1
        batch = images.size(0)

        # --- Critic update: maximise E[C(real)] - E[C(fake)] with gradient penalty ---
        z = torch.randn(batch, z_dim)
        # detach(): this step only updates C, so no graph through G is needed
        # (G's gradients were zeroed before its own update anyway).
        generated = G(z).detach()
        real_criticized = C(images)
        fake_criticized = C(generated)

        em_distance = real_criticized.mean() - fake_criticized.mean()
        grad_penalty = gradient_penalty(images.data, generated.data, C)

        CriticLoss = -em_distance + grad_penalty * 10
        C_loss.append(CriticLoss.item())
        C.zero_grad()
        CriticLoss.backward()
        C_opt.step()

        # --- Generator update every n_critic steps: maximise E[C(G(z))] ---
        if step % n_critic == 0:
            z = torch.randn(batch, z_dim)
            generated = G(z)
            fake_criticized = C(generated)
            GenLoss = -fake_criticized.mean()
            G_loss.append(GenLoss.item())
            C.zero_grad()
            G.zero_grad()
            GenLoss.backward()
            G_opt.step()
            print("Epoch {} : {}/{} :: {} mins".format(epoch+1, i+1, len(dataloader), (time.time()-start_time)/60), end='\r')
    print("Epoch {} completed".format(epoch+1))
    G.eval()
    # Sampling only: no autograd graph needed. Map tanh output [-1, 1] -> [0, 1].
    with torch.no_grad():
        fake_gen_images = (G(z_sample) + 1) / 2.0
    torchvision.utils.save_image(fake_gen_images, save_dir+'/Epoch '+str(epoch+1)+".jpg", nrow=10)
    save_checkpoint({'epoch': epoch + 1,
                     'D': C.state_dict(),
                     'G': G.state_dict(),
                     'd_optimizer': C_opt.state_dict(),
                     'g_optimizer': G_opt.state_dict()},
                    '%s/Epoch_(%d).ckpt' % (checkpoint, epoch + 1),
                    max_keep=2)