In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.autograd import Variable
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
import torchvision
import os

In [None]:
class Generator(nn.Module):
    """Encoder / bottleneck / decoder generator built from strided convolutions.

    Shape walk-through for a ``(N, 3, 128, 128)`` input:

    * encoder: five ``kernel=4, stride=2, padding=1`` convolutions halve the
      spatial size each time (128 -> 4), ending with 512 channels;
    * ``channel_wise``: an unpadded 3x3 convolution to 4000 channels (4 -> 2);
    * decoder: five transposed convolutions with the same geometry double the
      spatial size each time (2 -> 64) and finish with ``Tanh``, so the output
      is ``(N, 3, 64, 64)`` with values in ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()

        # Geometry shared by every (transposed) convolution in encoder/decoder.
        resample = dict(kernel_size=4, stride=2, padding=1)

        # First encoder stage has no batch norm; the remaining four do.
        down_layers = [nn.Conv2d(3, 64, **resample), nn.LeakyReLU(0.2)]
        for c_in, c_out in ((64, 64), (64, 128), (128, 256), (256, 512)):
            down_layers += [
                nn.Conv2d(c_in, c_out, **resample),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        self.encoder = nn.Sequential(*down_layers)

        self.channel_wise = nn.Conv2d(512, 4000, kernel_size=3)

        # Four normalised upsampling stages, then the un-normalised RGB head.
        up_layers = []
        for c_in, c_out in ((4000, 512), (512, 256), (256, 128), (128, 64)):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, **resample),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        up_layers += [nn.ConvTranspose2d(64, 3, **resample), nn.Tanh()]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode ``x``, pass it through the bottleneck and decode the result."""
        latent = self.channel_wise(self.encoder(x))
        return self.decoder(latent)
            

In [None]:
def test():
    """Smoke test: run one random 128x128 RGB sample through a fresh Generator and print the output shape."""
    sample = torch.randn((1, 3, 128, 128))
    output = Generator()(sample)
    print(output.shape)


test()

In [None]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier for RGB patches.

    Four ``kernel=4, stride=2, padding=1`` convolutions halve the spatial size
    each time; a final unpadded 4x4 convolution followed by ``Sigmoid``
    produces one score in ``[0, 1]``. A ``(N, 3, 64, 64)`` input therefore
    yields a ``(N, 1, 1, 1)`` output.
    """

    def __init__(self):
        super().__init__()

        halve = dict(kernel_size=4, stride=2, padding=1)
        widths = (64, 128, 256, 512)

        # Input stage is not batch-normalised.
        stack = [nn.Conv2d(3, widths[0], **halve), nn.LeakyReLU(0.2)]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stack += [
                nn.Conv2d(c_in, c_out, **halve),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        # Collapse the remaining 4x4 map to a single score.
        stack += [
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*stack)

    def forward(self, x):
        """Return the sigmoid score map for ``x``."""
        scores = self.classifier(x)
        return scores

In [None]:
def test():
    """Smoke test: run one random 64x64 RGB sample through a fresh Discriminator and print the output shape."""
    sample = torch.randn((1, 3, 64, 64))
    scores = Discriminator()(sample)
    print(scores.shape)


test()

In [None]:
def weights_init(m):
    """DCGAN-style parameter initialiser, meant for ``module.apply(weights_init)``.

    * ``Conv2d`` / ``ConvTranspose2d``: weights drawn from N(0, 0.02); biases
      are left at their default initialisation.
    * ``BatchNorm2d``: scale drawn from N(1, 0.02) and shift set to exactly 0,
      so each layer starts out close to an identity normalisation.
    * Any other module type is left untouched.

    Args:
        m: the sub-module currently visited by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # affine=False layers have no learnable scale/shift to initialise
        if m.weight is not None:
            # The scale must be centred on 1, not 0: a near-zero gamma would
            # wipe out almost all of the normalised signal.
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            # normal_(bias, 0) would use the default std of 1.0; the shift
            # is meant to start at exactly zero.
            torch.nn.init.constant_(m.bias, 0.0)

In [None]:
# Loss functions: criterion1 (binary cross-entropy) scores discriminator
# predictions, criterion2 (mean squared error) compares the generated patch
# with the ground-truth patch.
criterion1 = nn.BCELoss()
criterion2 = nn.MSELoss()

n_epochs = 20        # full passes over the dataloader
img_channels = 3     # NOTE(review): not referenced by the visible cells
lr = 0.0002          # Adam learning rate for both networks
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
beta_1 = 0.5         # Adam betas
beta_2 = 0.999
steps = 500          # number of batches between progress reports

In [None]:
# Instantiate both networks on the target device.
gen = Generator().to(device)
disc = Discriminator().to(device)

# Re-initialise the parameters in place (Module.apply returns the module
# itself, so the existing names keep pointing at the same objects).
gen.apply(weights_init)
disc.apply(weights_init)

# One Adam optimizer per network, sharing learning rate and betas.
gen_opt = torch.optim.Adam(
    gen.parameters(),
    lr=lr,
    betas=(beta_1, beta_2),
)
disc_opt = torch.optim.Adam(
    disc.parameters(),
    lr=lr,
    betas=(beta_1, beta_2),
)

In [None]:
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import numpy as np

class ImageDataset(Dataset):
    """Folder-of-images dataset that yields centre-masked inpainting samples.

    Every entry of ``root_dir`` is opened as an RGB image. Each item is the
    triple ``(image, masked_part, masked_img)`` where ``masked_part`` is the
    central square cut out of ``image`` and ``masked_img`` is a copy of
    ``image`` with that square overwritten by ``1``.

    Args:
        root_dir: directory containing the image files.
        transform: callable applied to the loaded PIL image. It must return a
            ``(C, H, W)`` tensor, because masking relies on tensor slicing and
            ``.clone()``.
        mask_size: side length in pixels of the square hole (default 64).
    """

    def __init__(self, root_dir, transform = None, mask_size = 64):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sort them so that an
        # index always refers to the same file across runs and machines.
        self.list_files = sorted(os.listdir(self.root_dir))
        self.transform = transform
        self.mask_size = mask_size

    def __len__(self):
        return len(self.list_files)

    def mask(self, img):
        """Cut a centred ``mask_size`` x ``mask_size`` hole out of ``img``.

        Args:
            img: ``(C, H, W)`` tensor.

        Returns:
            ``(masked_part, masked_img)``: the central patch (a view into
            ``img``) and a clone of ``img`` whose central patch is set to 1.

        Raises:
            ValueError: if the hole does not fit inside the image.
        """
        size = self.mask_size
        height, width = img.shape[-2:]
        if size > height or size > width:
            raise ValueError(
                f"mask_size {size} does not fit into an image of size {height}x{width}"
            )
        # Centre the hole using the actual image size rather than assuming 128.
        top = (height - size) // 2
        left = (width - size) // 2

        masked_part = img[:, top:top + size, left:left + size]
        masked_img = img.clone()
        masked_img[:, top:top + size, left:left + size] = 1
        return masked_part, masked_img

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = Image.open(img_path).convert('RGB')

        if self.transform :
            image = self.transform(image)
        masked_part, masked_img = self.mask(image)
        return image, masked_part, masked_img
        

In [None]:
# Preprocessing: resize to 128x128, convert to a tensor in [0, 1], then
# normalise each channel with mean 0.5 / std 0.5, i.e. rescale to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = ImageDataset('../horse2zebra/trainB/', transform = transform)
# Shuffle so that every epoch sees differently composed mini-batches; without
# it the loader replays the files in the same fixed order each epoch.
dataloader = DataLoader(dataset, batch_size = 16, shuffle = True)

# for img, mask_part, mask_img in dataloader:
#         print(img.shape)
#         print(mask_part.shape)
#         save_image(img, "sat.png")
#         save_image(mask_part, "real.png")
#         save_image(mask_img, "real2.png")
#         break;

In [None]:
def matplotlib_imshow(img, one_channel=False):
    """Display an image tensor that was normalised to ``[-1, 1]``.

    Args:
        img: image tensor with the channel axis first.
        one_channel: if True, average over the channel axis and draw the
            result with the "Greys" colormap; otherwise draw it as a colour
            image (channels moved to the last axis).
    """
    if one_channel:
        img = img.mean(dim=0)
    # undo the mean-0.5 / std-0.5 normalisation, then hand over to numpy
    pixels = (img / 2 + 0.5).detach().cpu().numpy()
    if not one_channel:
        plt.imshow(np.transpose(pixels, (1, 2, 0)))
    else:
        plt.imshow(pixels, cmap="Greys")
    plt.show()

In [None]:
# Adversarial training loop.
#
# Per batch: (1) update the discriminator on real centre patches vs. generated
# ones, (2) update the generator with a weighted sum of the adversarial loss
# and the MSE reconstruction loss. Every `steps` batches the running mean
# losses are printed and sample grids are displayed.

# Generator objective weights: mostly reconstruction, small adversarial term.
lambda_rec = 0.999
lambda_adv = 1 - lambda_rec
n_preview = 5            # number of samples shown in each preview grid

current_step = 0
gen_losses = 0
disc_losses = 0

for epoch in range(n_epochs):
    for img, img_mask, masked_img in dataloader:
        img = img.to(device)
        img_mask = img_mask.to(device)
        masked_img = masked_img.to(device)

        # A single generator forward pass serves both updates: the generator's
        # parameters do not change between the two steps, so a second pass
        # would only reproduce the same output.
        fake_img_mask = gen(masked_img)

        # ---- discriminator step ----
        disc_opt.zero_grad()
        # detach() keeps this step from back-propagating into the generator
        disc_preds_fake = disc(fake_img_mask.detach())
        disc_fake_loss = criterion1(disc_preds_fake, torch.zeros_like(disc_preds_fake))
        disc_preds_real = disc(img_mask)
        disc_real_loss = criterion1(disc_preds_real, torch.ones_like(disc_preds_real))
        disc_loss = disc_fake_loss + disc_real_loss
        disc_loss.backward()
        disc_opt.step()
        disc_losses += disc_loss.item()

        # ---- generator step ----
        gen_opt.zero_grad()
        # score the fakes with the freshly updated discriminator
        gen_preds_fake = disc(fake_img_mask)
        gen_fake_loss = criterion1(gen_preds_fake, torch.ones_like(gen_preds_fake))
        gen_mse_loss = criterion2(fake_img_mask, img_mask)
        gen_loss = lambda_adv * gen_fake_loss + lambda_rec * gen_mse_loss
        gen_loss.backward()
        gen_opt.step()
        gen_losses += gen_loss.item()

        # Count the batch before reporting so that every reporting window
        # holds exactly `steps` batches, matching the divisor used below.
        current_step += 1

        if current_step % steps == 0:
            print(f"Epochs: {epoch} Step: {current_step} Generator loss: {gen_losses / steps}, discriminator loss: {disc_losses / steps}")

            fake_patches = fake_img_mask.detach()[:n_preview]
            # Paste the generated patches into the centre of the masked inputs
            # to visualise the completed images.
            completed = masked_img[:n_preview].clone()
            patch_h, patch_w = fake_patches.shape[-2:]
            top = (completed.shape[-2] - patch_h) // 2
            left = (completed.shape[-1] - patch_w) // 2
            completed[:, :, top:top + patch_h, left:left + patch_w] = fake_patches

            img_grid_real = torchvision.utils.make_grid(img[:n_preview], nrow = n_preview)
            img_grid_mask = torchvision.utils.make_grid(masked_img[:n_preview], nrow = n_preview)
            img_grid_fake = torchvision.utils.make_grid(fake_patches, nrow = n_preview)
            img_grid_completed = torchvision.utils.make_grid(completed, nrow = n_preview)

            matplotlib_imshow(img_grid_real, one_channel=False)
            matplotlib_imshow(img_grid_mask, one_channel=False)
            matplotlib_imshow(img_grid_fake, one_channel=False)
            matplotlib_imshow(img_grid_completed, one_channel=False)

            gen_losses = 0
            disc_losses = 0