In [1]:
# Show all Jupyter output
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"


In [2]:
import sys
sys.path.append('..')


import torch
import torch.utils.data as data
from ignite.engine import Engine

from src.datasets import WarpDataset
from src.loss import PerPixelCrossEntropyLoss
from src.nets import Discriminator
from src.warping_module import WarpingModule



In [3]:
clothing_dir = "test_resources/clothing_segmentation"
body_dir = "test_resources/body_segmentation"

In [4]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [5]:
dataset = WarpDataset(clothing_seg_dir=clothing_dir, body_seg_dir=body_dir)
data_loader = torch.utils.data.DataLoader(dataset)

In [6]:
# Loss Warp function
reconstruction_loss = PerPixelCrossEntropyLoss()
adversarial_loss = torch.nn.BCELoss()
# later: warp_loss = reconstruction_loss(i, t) + adversarial_loss(d, c)
# warp_loss.backward()



In [7]:
generator = WarpingModule()
discriminator = Discriminator()


generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

WarpingModule(
  (pose_encoder): PoseEncoder(
    (down1): UNetDown(
      (model): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): LeakyReLU(negative_slope=0.2)
      )
    )
    (down2): UNetDown(
      (model): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (down3): UNetDown(
      (model): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2)
        (3): Dropout(p=0.5)
      )
    )
    (down4): UNetDown(
      (model): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bia

Discriminator(
  (model): Sequential(
    (0): Conv2d(19, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): ZeroPad2d(padding=(1, 0, 1, 0), value=0.0)
    (9): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1))
  )
)

In [8]:
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.99))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.99))