## Human Pose Generation through Deformable GANs
- by Siarohin et al.

In [1]:
import os
import torch
from torch.utils.data import DataLoader

# your modules
from datasets.market_dataset import MarketPoseDataset
from models.generator import DeformableGenerator
from models.discriminator import PatchDiscriminator
from train import train_model
from infer import infer_model
from evaluate import evaluate_model
from utils.visualization import tensor_to_image, save_comparison_grid
from utils.checkpoint import load_checkpoint, load_models_from_checkpoint

In [2]:
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# ===== Cell 2: Configuration & Hyperparameters =====

# Data paths
images_dir = "data/1/Market-1501-v15.09.15/bounding_box_train"
pose_maps_dir = "data/processed/pose_maps"
train_pairs = "data/splits/train_pairs.csv"
val_pairs = "data/splits/val_pairs.csv"
test_pairs = val_pairs  # or a separate test set

# Training hyperparams
batch_size = 16
num_epochs = 20
lr = 2e-4
betas = (0.5, 0.999)
lambda_nn = 1.0
lambda_perceptual = 0.1
lambda_offset = 0.1
use_perceptual = True

# Output dirs
ckpt_dir = "outputs/ckpts"
sample_dir = "outputs/samples"
infer_out_dir = "outputs/infer"

os.makedirs(ckpt_dir, exist_ok=True)
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(infer_out_dir, exist_ok=True)

In [4]:
from torch.utils.data import Subset
import numpy as np

train_ds = MarketPoseDataset(images_dir, pose_maps_dir, train_pairs)
val_ds = MarketPoseDataset(images_dir, pose_maps_dir, val_pairs)
test_ds = MarketPoseDataset(images_dir, pose_maps_dir, test_pairs)

# Assuming train_dataset and val_dataset are already created
reduce_factor = 0.25  # use 25% of training data

num_train = len(train_ds)
num_val = len(val_ds)

subset_train_indices = np.random.choice(num_train, int(num_train * reduce_factor), replace=False)
subset_val_indices = np.random.choice(num_val, int(num_val * reduce_factor), replace=False)

train_dataset_small = Subset(train_ds, subset_train_indices)
val_dataset_small = Subset(val_ds, subset_val_indices)


train_loader = DataLoader(train_dataset_small, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset_small, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4)

print(f"Train set size: {len(train_ds)}, Val set size: {len(val_ds)}, Test set size: {len(test_ds)}")

Skipping missing pair: 0696_c6s2_046643_03.jpg, 0696_c6s2_054793_02.jpg
Skipping missing pair: 0432_c5s1_105273_03.jpg, 0432_c5s1_105323_05.jpg
Skipping missing pair: 0105_c6s1_017301_04.jpg, 0105_c5s1_037501_01.jpg
Skipping missing pair: 0635_c6s2_033343_01.jpg, 0635_c5s2_040055_02.jpg
Skipping missing pair: 0158_c6s1_037651_01.jpg, 0158_c5s1_037801_01.jpg
Skipping missing pair: 0810_c4s4_049385_03.jpg, 0810_c3s3_030553_03.jpg
Skipping missing pair: 1428_c6s3_090842_03.jpg, 1428_c6s3_090867_04.jpg
Skipping missing pair: 0605_c6s2_071193_01.jpg, 0605_c6s2_071268_02.jpg
Skipping missing pair: 0139_c6s1_023576_01.jpg, 0139_c1s1_023301_02.jpg
Skipping missing pair: 0673_c1s3_058901_04.jpg, 0673_c6s2_038818_03.jpg
Skipping missing pair: 0696_c6s2_055093_06.jpg, 0696_c6s2_054868_03.jpg
Skipping missing pair: 0706_c1s4_007706_02.jpg, 0706_c5s2_068752_03.jpg
Skipping missing pair: 0973_c6s3_002092_03.jpg, 0973_c2s2_137652_01.jpg
Skipping missing pair: 0637_c6s2_021768_04.jpg, 0637_c5s2_028105

In [5]:
pose_ch = train_ds[0]['src_pose'].shape[0]
netG = DeformableGenerator(in_ch=3, pose_ch=pose_ch).to(device)
netD = PatchDiscriminator(in_ch_img=3, pose_ch=pose_ch).to(device)

print(netG)
print(netD)

DeformableGenerator(
  (enc1): DownsampleBlock(
    (conv): ConvBlock(
      (main): Sequential(
        (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
      )
    )
  )
  (enc2): DownsampleBlock(
    (conv): ConvBlock(
      (main): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
      )
    )
  )
  (enc3): DownsampleBlock(
    (conv): ConvBlock(
      (main): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(inplace=True)
      )
    )
  )
  (enc4): DownsampleBlock(
    (conv): ConvBlock(
      (main): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), st

In [6]:
ckpt = torch.load("outputs/ckpts/checkpoint_latest.pth")
netG.load_state_dict(ckpt['netG'])
netD.load_state_dict(ckpt['netD'])

  ckpt = torch.load("outputs/ckpts/checkpoint_latest.pth")


<All keys matched successfully>

In [None]:
train_model(
    netG, netD,
    train_loader, val_loader,
    device=device,
    num_epochs=10,
    lr=lr/3,
    betas=betas,
    lambda_nn=lambda_nn,
    lambda_p=lambda_perceptual,
    lambda_off=lambda_offset,
    use_perceptual=use_perceptual,
    checkpoint_dir=ckpt_dir,
    sample_dir=sample_dir,
    num_gen_updates=2
)

In [12]:
from utils.checkpoint import save_checkpoint 
checkpoint_dir="outputs/ckpts"

ckpt = {
            'epoch': 7,
            'netG': netG.state_dict(),
            'netD': netD.state_dict(),
            'optimG': "",
            'optimD': "",
        }

save_checkpoint(ckpt, checkpoint_dir, filename="checkpoint_latest.pth")

Saved checkpoint: outputs/ckpts/checkpoint_latest.pth


In [15]:
infer_model(netG, test_loader, device=device, output_dir=infer_out_dir)

Inference done; results saved to outputs/infer
