This repo contains a PyTorch implementation of the method from the paper Deep Supervised Image Retargeting (available [here](https://cic.tju.edu.cn/faculty/zhangjiawan/Jiawan_Zhang_files/paper/meiyijing2021.pdf)), along with the evaluation outputs. Specifically, suppose x, y, and y_pred denote the original image, the ground-truth (GT) retargeted image, and this method's retargeted image, respectively, and M denotes a function measuring similarity under a certain metric (higher is better); then we would like to have M(y, y_pred) > M(x, y). Let G be the model function, i.e., y_pred = G(x). Consider, for instance, the two extreme cases:

 1. G(x) = y. In this case, M(y, y_pred) = M(y, y) attains its maximum (infinity for a metric such as PSNR), i.e., we always beat M(x, y), so G is a perfect map.

 2. G(x) = x. In this case, M(y, y_pred) = M(y, x) = M(x, y), so we can never have M(y, y_pred) > M(x, y), and G is the worst map. The fact that M(x, G(x)) = M(x, x) is maximal (infinity) is not an indication that we retargeted the image successfully.

 **To be able to use this code,** please download the TIReD [dataset](https://github.com/TIReD2020/TIReD), search for the line **Path = "/path/to/TIReD"**, and
 **change** the path to yours.

 The paper presents 7 different models depending on the choice of the loss function and calls the best one "ours" (the one that combines all the loss functions below). Thus, please **keep loss_mode = "ours"** (search for that line) if the goal is the best-performing model; otherwise, **change** "ours" to any of "no_Lcon", "no_LP", "no_Ltv", "no_Lm_tv", "no_L1", "no_LSSIM" for benchmarking purposes.



 We trained the model on the TIReD dataset provided by the authors, available for download [here](https://github.com/TIReD2020/TIReD). Its folder layout is shown in the setup steps below; train_A and test_A contain the original input images, and train_B and test_B contain the ground-truth retargeted images.


**To run this code,** you first need to

1. pip install torch torchvision piq opencv-python scikit-image matplotlib tqdm pillow

2. Download the TIReD dataset [here](https://github.com/TIReD2020/TIReD). Say its path is /path/to/TIReD/  (please **modify** the line Path = "/path/to/TIReD" to have your path) . Then the folder looks like:

In [2]:
# /path/to/TIReD/
#   AVA/
#     train/train_A/
#     train/train_B/
#     train/train_B_mask/
#     test/test_A/
#     test/test_B/
#   COCO/
#     train_224_224/train_A/
#     train_224_224/train_B/
#     train_224_224/train_B_mask/
#     train_300_300/train_A/
#     train_300_300/train_B/
#     test/test_A/
#     test/test_B/
#   HKU-IS/
#     train/train_A/
#     train/train_B/
#     train/train_B_mask/
#     test/test_A/
#     test/test_B/
#   Waterloo Exploration/
#     train/train_A/
#     train/train_B/
#     train/train_B_mask/
#     test/test_A/
#     test/test_B/


3. Choose which model you'd like to use from "ours", "no_Lcon", "no_LP", "no_Ltv", "no_Lm_tv", "no_L1", "no_LSSIM". To do so, modify the line loss_mode = "ours". If you leave it unchanged, **"ours"** is used, which is the **best-performing model**.

4. Finally, run python deep_supervised_image_retargeting.py. This will

    a. train the model

    b.  save the best checkpoint (based on highest validation SSIM metric) to mrgan_<loss_mode>_best.pth

    c. use best checkpoint to evaluate the model on the test datasets

    d. print the PSNR, SSIM, FSIM, VIF metrics for input vs GT and output vs GT

    e. save input vs GT vs our output images to ./mrgan_examples/<dataset_name>/*.png

The evaluation outputs are also available in this repo.

In [2]:
import os
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
import torchvision.transforms as T
from torchvision import models
from PIL import Image

import piq
from tqdm.auto import tqdm

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

Models: "ours", "no_Lcon", "no_LP", "no_Ltv", "no_Lm_tv", "no_L1", "no_LSSIM"


In [3]:
# Every supported ablation of the generator objective. "ours" enables all loss
# terms; each "no_<term>" mode switches exactly one term off.
VALID_LOSS_MODES = ("ours", "no_Lcon", "no_LP", "no_Ltv", "no_Lm_tv", "no_L1", "no_LSSIM")

loss_mode = "ours"
if loss_mode not in VALID_LOSS_MODES:
  # Without this check a misspelled mode would silently train the full model.
  raise ValueError(f"Unknown loss_mode {loss_mode!r}; expected one of {VALID_LOSS_MODES}")

# One switch per loss term: a term is active unless its "no_<term>" mode is selected.
use_Lmtv = loss_mode != "no_Lm_tv"
use_L1 = loss_mode != "no_L1"
use_Lcon = loss_mode != "no_Lcon"
use_LP = loss_mode != "no_LP"
use_Ltv = loss_mode != "no_Ltv"
use_Lssim = loss_mode != "no_LSSIM"

Dataset: (original image, retargeted image, mask)

In [4]:
class TIRed(Dataset):
  """Paired TIReD samples: (original image, ground-truth retargeted image, mask).

  Args:
    A: directory holding the original images.
    B: directory holding the ground-truth retargeted images; each file is
      looked up under the same file name as in ``A``.
    mask: optional directory of masks, looked up under the same file names.
      When it is ``None`` or not an existing directory, ``__getitem__``
      returns an all-zero mask instead.
    image_size: images and masks are resized to ``image_size x image_size``.
    normalize: when true, images are standardized with the ImageNet mean/std.
  """

  def __init__(self, A, B, mask=None, image_size=224, normalize=True):
    self.A = A
    self.B = B
    self.mask = mask
    self.image_size = image_size

    # The mask directory is either usable for every file or for none.
    has_masks = mask is not None and os.path.isdir(mask)
    self.data = [
      (
        os.path.join(A, f),
        os.path.join(B, f),
        os.path.join(mask, f) if has_masks else None,
      )
      for f in sorted(os.listdir(A))
    ]

    size = (image_size, image_size)
    steps = [T.Resize(size), T.ToTensor()]
    if normalize:
      steps.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    self.transform = T.Compose(steps)

    # Nearest-neighbour resizing keeps mask values from being blended.
    self.maskTransform = T.Compose([
      T.Resize(size, interpolation=T.InterpolationMode.NEAREST),
      T.ToTensor(),
    ])

  def __len__(self):
    return len(self.data)

  @staticmethod
  def _load(path, mode):
    """Open ``path``, convert it to ``mode`` and close the file handle."""
    with Image.open(path) as img:
      return img.convert(mode)

  def __getitem__(self, idx):
    """Return ``(image_A, image_B, mask)`` tensors for sample ``idx``."""
    path_a, path_b, path_m = self.data[idx]

    if path_m is not None:
      m = self.maskTransform(self._load(path_m, "L"))
    else:
      # No mask available for this sample: use an all-zero placeholder.
      m = torch.zeros(1, self.image_size, self.image_size)

    image_a = self.transform(self._load(path_a, "RGB"))
    image_b = self.transform(self._load(path_b, "RGB"))
    return (image_a, image_b, m)


We now implement the Generator and Discriminator

We need this convolution layer pretty often.
Each convolutional layer contains the activation layer and a batch normalization after convolution.
We use LeakyReLU in the encoder, and ReLU in the decoder.

In [5]:
class Conv(nn.Module):
    """Conv2d followed by optional BatchNorm2d and an optional activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        ksize: convolution kernel size.
        stride: convolution stride.
        padding: explicit padding; when ``None`` it is ``(ksize // 2) * dilation``.
        dilation: convolution dilation.
        bn: add batch normalization; the convolution bias is dropped when it is on.
        act: ``"relu"``, ``"leaky_relu"`` (slope 0.2), or anything else for no
            activation.
    """

    def __init__(self, in_channels, out_channels, ksize=3, stride=1,
                 padding=None, dilation=1, bn=True, act="relu"):
        super().__init__()
        pad = (ksize // 2) * dilation if padding is None else padding

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            ksize,
            stride=stride,
            padding=pad,
            dilation=dilation,
            bias=not bn,
        )
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.activation = self._make_activation(act)

    @staticmethod
    def _make_activation(act):
        """Map the ``act`` name to an activation module, or ``None``."""
        if act == "relu":
            return nn.ReLU(inplace=True)
        if act == "leaky_relu":
            return nn.LeakyReLU(0.2, inplace=True)
        return None

    def forward(self, x):
        out = self.conv(x)
        # Apply whichever of the two optional stages were configured, in order.
        for stage in (self.bn, self.activation):
            if stage is not None:
                out = stage(out)
        return out

The encoder needs to be followed by a ResNet module that contains six residual blocks, so we first construct

ResNet(x)= ReLU(x+ F(x))

In [6]:
class ResNet(nn.Module):
    """Residual block computing ``ReLU(x + F(x))``.

    ``F`` is two 3x3 ``Conv`` layers over ``channels`` channels; the first
    uses ReLU and the second has no activation.
    """

    def __init__(self, channels):
        super().__init__()
        residual_layers = [
            Conv(channels, channels, ksize=3, act="relu"),
            Conv(channels, channels, ksize=3, act=None),
        ]
        self.Fx = nn.Sequential(*residual_layers)
        self.activation = nn.ReLU()

    def forward(self, x):
        residual = self.Fx(x)
        return self.activation(residual + x)

From shallow to deep, these blocks contain 5, 5, 4, 4 and 2 convolution layers respectively, including standard 3x3 convolutions, dilated convolutions, and 1x1 convolutions. The first four blocks end with convolution of a stride 2. In each block, the feature maps of the first convolution layer and the dilated convolution layer(s) are concatenated and we send it to the next layer.

In [7]:
class Encoder(nn.Module):
    """One encoder block: a 3x3 conv, dilated conv(s), optional 1x1 fusion and
    an optional stride-2 downsampling conv. All layers use LeakyReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by every layer of the block.
        num_convs: 5 adds a second dilated conv; 4 or more adds the 1x1 fusion
            conv and allows the stride-2 conv; smaller values keep only the
            first conv and one dilated conv.
        rates: dilation rates; only ``rates[1]`` and ``rates[2]`` are read.
        downsample: build the stride-2 conv (only when ``num_convs >= 4``).

    ``forward`` returns ``(y, skip)``, where ``skip`` is the feature map taken
    before the stride-2 conv.
    """

    def __init__(self, in_channels, out_channels, num_convs, rates=(1,2,3), downsample=True):
        super().__init__()
        self.downsample = downsample
        fuses = num_convs >= 4

        self.conv1 = Conv(in_channels, out_channels, ksize=3, act="leaky_relu")
        self.atrous1 = Conv(out_channels, out_channels, ksize=3, dilation=rates[1], act="leaky_relu")
        self.atrous2 = (
            Conv(out_channels, out_channels, ksize=3, dilation=rates[2], act="leaky_relu")
            if num_convs == 5
            else None
        )
        # The 1x1 conv consumes the concatenation of two out_channels-wide maps.
        self.conv1x1 = (
            Conv(out_channels * 2, out_channels, ksize=1, act="leaky_relu")
            if fuses
            else None
        )
        self.down_conv = (
            Conv(out_channels, out_channels, ksize=3, stride=2, act="leaky_relu")
            if downsample and fuses
            else None
        )

    def forward(self, x):
        first = self.conv1(x)
        out = self.atrous1(first)
        if self.atrous2 is not None:
            out = self.atrous2(out)

        if self.conv1x1 is not None:
            # Fuse the first conv's features with the dilated features.
            out = self.conv1x1(torch.cat([first, out], dim=1))

        skip = out
        if self.down_conv is None:
            return out, skip
        return self.down_conv(out), skip


The decoder is completely symmetric to the encoder and uses skip-connections to fuse multi-scale features. We use a resize-convolution scheme instead of naive deconvolution layers

In [None]:
class Decoder(nn.Module):
    """Decoder block: upsample, concatenate the encoder skip, then two 3x3 convs.

    Args:
        in_channels: channels of the feature map coming from the deeper level.
        skip_channels: channels of the encoder skip connection.
        out_channels: channels produced by both 3x3 ReLU convs.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = nn.Sequential(
            Conv(in_channels + skip_channels, out_channels, ksize=3, act="relu"),
            Conv(out_channels, out_channels, ksize=3, act="relu"),
        )

    def forward(self, x, skip):
        """Upsample ``x`` to the resolution of ``skip`` and fuse the two maps."""
        height, width = x.shape[-2:]
        if tuple(skip.shape[-2:]) == (2 * height, 2 * width):
            x = self.upsample(x)
        else:
            # A stride-2 encoder conv on an odd-sized map yields a skip that is
            # not exactly twice as large; resize to the skip so the
            # concatenation below stays valid.
            x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)


Generator is: U-Net with a 6-block ResNet bottleneck and the 5,5,4,4,2 encoder.

5 encoder blocks with channel sizes: 64, 128, 256, 512, 512

6 residual blocks

4 decoder blocks

In [None]:
class Generator(nn.Module):
    """U-Net style generator: 5 encoder blocks, 6 residual blocks, 4 decoder
    blocks and a final 3x3 convolution.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        channels: base width; the encoder widths are 1x, 2x, 4x, 8x and 8x this.
    """

    def __init__(self, in_channels=3, out_channels=3, channels=64):
        super().__init__()

        # Widths per level; 64, 128, 256, 512, 512 for the default base width.
        c1, c2, c3, c4, c5 = (channels * factor for factor in (1, 2, 4, 8, 8))
        rates = (1, 2, 3)

        # Encoder: blocks with 5, 5, 4, 4 and 2 convolutions; the last one
        # does not downsample.
        self.enc1 = Encoder(in_channels, c1, num_convs=5, rates=rates, downsample=True)
        self.enc2 = Encoder(c1, c2, num_convs=5, rates=rates, downsample=True)
        self.enc3 = Encoder(c2, c3, num_convs=4, rates=rates, downsample=True)
        self.enc4 = Encoder(c3, c4, num_convs=4, rates=rates, downsample=True)
        self.enc5 = Encoder(c4, c5, num_convs=2, rates=rates, downsample=False)

        # Bottleneck: six residual blocks at the deepest width.
        self.bottleneck = nn.Sequential(*(ResNet(c5) for _ in range(6)))

        # Decoder: each block consumes the skip of the matching encoder level.
        self.dec1 = Decoder(in_channels=c5, skip_channels=c4, out_channels=c4)
        self.dec2 = Decoder(in_channels=c4, skip_channels=c3, out_channels=c3)
        self.dec3 = Decoder(in_channels=c3, skip_channels=c2, out_channels=c2)
        self.dec4 = Decoder(in_channels=c2, skip_channels=c1, out_channels=c1)

        # Projection back to image channels.
        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=3, padding=1)

    def forward(self, x, return_skips: bool=False):
        """Generate an image from ``x``.

        When ``return_skips`` is true, also return the list of the five
        encoder skip feature maps, shallowest first.
        """
        skips = []
        y = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            y, skip = encoder(y)
            skips.append(skip)

        y = self.bottleneck(y)

        # Decode from the deepest downsampled level (s4) back to the first (s1).
        decoders = (self.dec1, self.dec2, self.dec3, self.dec4)
        for decoder, skip in zip(decoders, reversed(skips[:4])):
            y = decoder(y, skip)

        y = self.final_conv(y)

        if return_skips:
            return y, skips
        return y


PatchGAN-style discriminator.

In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator producing a map of per-patch logits.

    Args:
        in_channels: channels of the input (6 by default).
        channels: width of the first block; later blocks use 2x, 4x and 8x.
    """

    def __init__(self, in_channels=6, channels=64):
        super().__init__()
        w1, w2, w3, w4 = channels, channels * 2, channels * 4, channels * 8
        # All four blocks are 4x4 stride-2 LeakyReLU convs.
        shared = dict(ksize=4, stride=2, padding=1, act="leaky_relu")

        # Only the first block skips batch normalization.
        self.block1 = Conv(in_channels, w1, bn=False, **shared)
        self.block2 = Conv(w1, w2, bn=True, **shared)
        self.block3 = Conv(w2, w3, bn=True, **shared)
        self.block4 = Conv(w3, w4, bn=True, **shared)

        # Final plain convolution: one logit per location, no activation.
        self.final_conv = nn.Conv2d(w4, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        y = x
        for block in (self.block1, self.block2, self.block3, self.block4):
            y = block(y)
        return self.final_conv(y)


Next, we implement the loss functions

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)


In [None]:
# ImageNet channel statistics, shaped (1, 3, 1, 1) to broadcast over image batches.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

def denorm_imagenet(x):
    """Undo ImageNet standardization: return ``x * std + mean``.

    The statistics are moved to ``x``'s device first, so the function also
    works for tensors that do not live on the global ``device``.
    """
    std = IMAGENET_STD.to(x.device)
    mean = IMAGENET_MEAN.to(x.device)
    return x * std + mean


Pretrained VGG19 network's outputs at different blocks

We will then use these features and compare them with the encoder features via loss functions

block 1: 0-3

block 2: 4-8

block 3: 9-17

block 4: 18-26

block 5: 27-35

In [None]:
class VGG19Features(nn.Module):
    """Frozen ImageNet-pretrained VGG19 that returns the output of each block.

    ``block_end`` holds the index of the last layer of every block inside
    ``features``: 3, 8, 17, 26 and 35.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)

        # Keep only the convolutional part, in eval mode and without gradients
        # for its weights.
        self.features = backbone.features.eval()
        self.features.requires_grad_(False)

        self.block_end = [3, 8, 17, 26, 35]

    def forward(self, x):
        """Return a list with one feature map per entry of ``block_end``."""
        features = []
        out = x
        start = 0
        for end in self.block_end:
            # Run the layers start..end (inclusive) of the current block.
            out = self.features[start:end + 1](out)
            features.append(out)
            start = end + 1
        return features


In [None]:
# Shared frozen VGG19 feature extractor used by the content and perceptual losses.
vgg19 = VGG19Features()
vgg19 = vgg19.to(device)
vgg19.eval()


Gradient

In [None]:
def gradient(l):
    """Return forward differences ``(dy, dx)`` of ``l`` along dims 2 and 3.

    Each output is one element shorter than ``l`` along its own dimension.
    """
    vertical = torch.diff(l, dim=2)
    horizontal = torch.diff(l, dim=3)
    return vertical, horizontal

L_con: Content loss (Equation 1). Compares encoder skip features and VGG features of ground truth

In [None]:
def L_con(skip_features, vgg_features, alpha=0.1):
    """Content loss: mean L1 distance between encoder skips and VGG features.

    The two lists are compared pairwise from index 1 onward (the first level
    is not compared), over as many levels as both lists provide.

    Args:
        skip_features: encoder skip feature maps.
        vgg_features: VGG feature maps to match them against.
        alpha: weight applied to the averaged loss.

    Returns:
        A scalar tensor. It is a zero tensor when fewer than two levels are
        available, so callers can always use tensor methods on the result.
    """
    n = min(len(skip_features), len(vgg_features))
    if n <= 1:
        if n == 1:
            # Stay on the same device/dtype as the features that were given.
            return skip_features[0].new_zeros(())
        return torch.zeros(())

    loss = 0.0
    for skip, target in zip(skip_features[1:n], vgg_features[1:n]):
        loss += F.l1_loss(skip, target)
    return alpha * loss / (n - 1)


L_tv: Total variation loss (Equation 2). Computes difference between gradients of generated image and ground truth image

In [None]:
def L_tv(y_gen, y_gt):
    """Total-variation loss: mean absolute difference between the gradients
    of the generated image and of the ground-truth image, summed over the
    vertical and horizontal directions.
    """
    gen_dy, gen_dx = gradient(y_gen)
    gt_dy, gt_dx = gradient(y_gt)
    vertical = torch.abs(gen_dy - gt_dy).mean()
    horizontal = torch.abs(gen_dx - gt_dx).mean()
    return vertical + horizontal


L_mtv: Masked TV loss (eq 3). Multiplies gradient differences by the edge mask

In [None]:
def L_mtv(y_gen, y_gt, mask):
    """Masked total-variation loss.

    Same gradient differences as ``L_tv``, each weighted by ``mask`` before
    averaging. The mask is cropped by one row / column to match the gradient
    maps and expanded to their shape.
    """
    gen_dy, gen_dx = gradient(y_gen)
    gt_dy, gt_dx = gradient(y_gt)

    weight_dy = mask[:, :, 1:, :].expand_as(gen_dy)
    weight_dx = mask[:, :, :, 1:].expand_as(gen_dx)

    vertical = ((gen_dy - gt_dy).abs() * weight_dy).mean()
    horizontal = ((gen_dx - gt_dx).abs() * weight_dx).mean()
    return vertical + horizontal


L1: L1 reconstruction loss (eq 5). L1 distance between generated image and ground truth retargeted image.

In [None]:
def L1(y_gen, y_gt):
    """Reconstruction loss: mean absolute difference between the generated
    and ground-truth images."""
    return torch.mean(torch.abs(y_gen - y_gt))


L_P: Perceptual loss (eq 6). Compares VGG19 features of GT and generated images.

In [None]:
def L_P(y_gen, y_gt, beta4=1.0, beta5=2.0):
    """Perceptual loss on the 4th and 5th VGG19 block outputs.

    Returns ``beta4 * MSE(block 4) + beta5 * MSE(block 5)`` between the
    features of the generated image and those of the ground truth.
    """
    gt_block4, gt_block5 = vgg19(y_gt)[3:5]
    gen_block4, gen_block5 = vgg19(y_gen)[3:5]

    loss_block4 = F.mse_loss(gen_block4, gt_block4)
    loss_block5 = F.mse_loss(gen_block5, gt_block5)
    return beta4 * loss_block4 + beta5 * loss_block5


L_SSIM: SSIM loss (eq 7). Computes the SSIM between the generated and GT images and uses 1 - SSIM as the loss, so higher similarity means smaller loss.

In [None]:
def gaussian(size, sigma, channels, device):
    """Build a normalized 2D Gaussian window for depthwise convolution.

    Returns a ``(channels, 1, size, size)`` tensor on ``device`` whose every
    channel holds the same ``size x size`` kernel summing to 1.
    """
    # Sample positions relative to index size // 2.
    offsets = torch.arange(size, device=device).float() - size // 2
    weights = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    weights = weights / weights.sum()

    # Outer product of the 1D profile with itself gives the separable 2D kernel.
    kernel = torch.matmul(weights.unsqueeze(1), weights.unsqueeze(0))

    # Same kernel for every channel: (channels, 1, H, W).
    return kernel.view(1, 1, size, size).expand(channels, 1, size, size)

In [None]:
def ssim(image1, image2, size=11, sigma=1.5):
  """Mean SSIM between two image batches using a Gaussian window.

  Args:
    image1, image2: 4D image batches with the same shape.
    size: side length of the Gaussian window.
    sigma: standard deviation of the Gaussian window.

  Returns:
    The SSIM map averaged over all elements, as a scalar tensor.
  """
  C1 = 0.01 ** 2
  C2 = 0.03 ** 2
  _, channels, _, _ = image1.shape
  window = gaussian(size, sigma, channels, image1.device)

  def smooth(t):
    # Per-channel (depthwise) Gaussian filtering with "same" padding.
    return F.conv2d(t, window, padding=size // 2, groups=channels)

  mu1 = smooth(image1)
  mu2 = smooth(image2)
  mu1_sq = mu1 ** 2
  mu2_sq = mu2 ** 2
  mu12 = mu1 * mu2

  # Local variances and covariance: E[xy] - E[x]E[y].
  var1 = smooth(image1 * image1) - mu1_sq
  var2 = smooth(image2 * image2) - mu2_sq
  cov12 = smooth(image1 * image2) - mu12

  numerator = (2 * mu12 + C1) * (2 * cov12 + C2)
  denominator = (mu1_sq + mu2_sq + C1) * (var1 + var2 + C2)
  return (numerator / denominator).mean()

In [None]:
def L_SSIM(y_gen, y_gt):
  """SSIM loss: ``1 - SSIM`` between the de-normalized images clamped to [0, 1]."""
  generated = torch.clamp(denorm_imagenet(y_gen), 0.0, 1.0)
  target = torch.clamp(denorm_imagenet(y_gt), 0.0, 1.0)
  return 1.0 - ssim(generated, target)

L_D: binary cross entropy on the discriminator outputs.

L_GAN: GAN losses (eq 4). Generator tries to fool D

In [None]:
# Binary cross entropy on raw logits, shared by both adversarial losses.
bce = nn.BCEWithLogitsLoss()


def L_D(d_real, d_gen):
    """Discriminator loss: real logits should be 1, generated logits 0."""
    real_term = bce(d_real, torch.ones_like(d_real))
    fake_term = bce(d_gen, torch.zeros_like(d_gen))
    return real_term + fake_term


def L_GAN(d_gen):
    """Generator adversarial loss: generated logits should be classified as 1."""
    real_label = torch.ones_like(d_gen)
    return bce(d_gen, real_label)


L_G: Full generator loss (eq 8). Combines the loss functions above

In [None]:
# Weights of the individual terms in the full generator objective.
Alpha_Lcon = 0.1
Lambda_L1 = 1000.0
Lambda_P = 10.0
Lambda_Tv = 0.01
Lambda_Mtv = 0.003
Lambda_Ssim = 70.0

def L_G(y_gen, y_gt, mask, d_gen, skip_feats):
  """Full generator loss: adversarial term plus every enabled weighted term.

  Args:
    y_gen: generated image batch.
    y_gt: ground-truth retargeted image batch.
    mask: mask for the masked TV term, or ``None`` to skip that term.
    d_gen: discriminator logits for the generated images.
    skip_feats: encoder skip features used by the content loss.

  Returns:
    ``(total, parts)``: the scalar loss tensor to back-propagate and a dict
    of plain floats with every unweighted term plus the total under
    ``"L_total"``. Terms disabled by the ``use_*`` switches are reported as 0.
  """
  zero = y_gen.new_zeros(())

  L_adv = L_GAN(d_gen)
  Lcon_val = L_con(skip_feats, vgg19(y_gt), alpha=Alpha_Lcon) if use_Lcon else zero
  Lp_val = L_P(y_gen, y_gt) if use_LP else zero
  L1_val = L1(y_gen, y_gt) if use_L1 else zero
  Ltv_val = L_tv(y_gen, y_gt) if use_Ltv else zero
  Lmtv_val = L_mtv(y_gen, y_gt, mask) if (use_Lmtv and mask is not None) else zero
  Lssim_val = L_SSIM(y_gen, y_gt) if use_Lssim else zero

  # The content term already carries its own weight (alpha) inside L_con.
  LG_val = (
    L_adv
    + Lcon_val
    + Lambda_P * Lp_val
    + Lambda_L1 * L1_val
    + Lambda_Tv * Ltv_val
    + Lambda_Mtv * Lmtv_val
    + Lambda_Ssim * Lssim_val
  )

  terms = (
    ("L_adv", L_adv),
    ("L_con", Lcon_val),
    ("L_P", Lp_val),
    ("L_1", L1_val),
    ("L_tv", Ltv_val),
    ("L_m_tv", Lmtv_val),
    ("L_SSIM", Lssim_val),
    ("L_total", LG_val),
  )
  # float() accepts scalar tensors as well as plain Python numbers, so a term
  # that degenerates to a float (L_con with too few levels) does not crash.
  parts = {name: float(value) for name, value in terms}
  return LG_val, parts



Training and Evaluation

Set Path to the location of the TIReD dataset before running the code

In [None]:
# Root folder of the downloaded TIReD dataset; every train/test directory
# below is resolved relative to it. Replace the placeholder with your path.
Path = "/path/to/TIReD"

TIRed contains 4 subfolders: AVA, COCO, HKU-IS, Waterloo Exploration

Each subfolder contains subfolders:

train subfolder containing subfolders: train_A: original image, train_B: ground truth retargeted image (and, where provided, train_B_mask: mask)

test subfolder containing subfolders: test_A: original image, test_B: ground truth retargeted image

In [None]:
def _train_dirs(*parts, with_mask=True):
  """Return the (train_A, train_B, mask-or-None) directories under Path/parts."""
  base = os.path.join(Path, *parts)
  mask_dir = os.path.join(base, "train_B_mask") if with_mask else None
  return os.path.join(base, "train_A"), os.path.join(base, "train_B"), mask_dir


ava_train = TIRed(*_train_dirs("AVA", "train"))
coco224_train = TIRed(*_train_dirs("COCO", "train_224_224"))
# The 300x300 COCO split is loaded without a mask directory.
coco300_train = TIRed(*_train_dirs("COCO", "train_300_300", with_mask=False))
hku_train = TIRed(*_train_dirs("HKU-IS", "train"))
waterloo_train = TIRed(*_train_dirs("Waterloo Exploration", "train"))

# All training subsets are merged into a single dataset.
full_train_dataset = ConcatDataset(
  [ava_train, coco224_train, coco300_train, hku_train, waterloo_train]
)
print("Total train images:", len(full_train_dataset))


Split into train and validation

In [None]:
# Hold out 5% of the merged training data for validation (fixed seed 123).
_n_total = len(full_train_dataset)
n_val = int(0.05 * _n_total)
n_train = _n_total - n_val

_split_rng = torch.Generator().manual_seed(123)
train_ds, val_ds = random_split(full_train_dataset, [n_train, n_val], generator=_split_rng)

batch_size = 8
num_workers = 8

# Settings shared by the training and validation loaders.
_loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

# Training batches are shuffled and an incomplete last batch is dropped.
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **_loader_kwargs)

val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

Test dataset

In [None]:
# One DataLoader per TIReD test subset, keyed by a short dataset name.
test_loaders: Dict[str, DataLoader] = {}


def _test_dirs(folder):
  """Return the (test_A, test_B) directories of the subset stored in ``folder``."""
  root = os.path.join(Path, folder, "test")
  return os.path.join(root, "test_A"), os.path.join(root, "test_B")


def _test_loader(dataset):
  """Wrap ``dataset`` in an unshuffled loader with the shared batch settings."""
  return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)


# AVA
ava_A, ava_B = _test_dirs("AVA")
ava_test = TIRed(ava_A, ava_B, None, image_size=224)
test_loaders["AVA"] = _test_loader(ava_test)
print("AVA:", len(ava_test), "test images")

# COCO
coco_A, coco_B = _test_dirs("COCO")
coco_test = TIRed(coco_A, coco_B, None, image_size=224)
test_loaders["COCO"] = _test_loader(coco_test)
print("COCO:", len(coco_test), "test images")

# HKU-IS
hku_A, hku_B = _test_dirs("HKU-IS")
hku_test = TIRed(hku_A, hku_B, None, image_size=224)
test_loaders["HKU-IS"] = _test_loader(hku_test)
print("HKU-IS:", len(hku_test), "test images")

# Waterloo (the folder name is longer than the key used for reporting)
wat_A, wat_B = _test_dirs("Waterloo Exploration")
wat_test = TIRed(wat_A, wat_B, None, image_size=224)
test_loaders["Waterloo"] = _test_loader(wat_test)
print("Waterloo:", len(wat_test), "test images")

# All test subsets combined, reusing the ConcatDataset imported at the top of
# the file instead of importing it a second time under an alias.
all_test_dataset = ConcatDataset([dl.dataset for dl in test_loaders.values()])
all_test_loader = _test_loader(all_test_dataset)


Model

In [None]:
# Let cuDNN benchmark convolution algorithms.
torch.backends.cudnn.benchmark = True
print("Using device:", device)

G = Generator().to(device)
D = Discriminator().to(device)

# Both networks use Adam with the same learning rate and betas.
_adam_kwargs = dict(lr=2e-4, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), **_adam_kwargs)
opt_D = torch.optim.Adam(D.parameters(), **_adam_kwargs)

# The VGG feature extractor stays on the training device in eval mode.
vgg19.to(device).eval()

Helper for training and evaluation

In [None]:
@torch.no_grad()
def evaluate(loader, G, device, max_batches=None):
    """Compute PSNR, SSIM, FSIM and VIF against the ground truth.

    For every metric two values are reported: ``*_in`` compares the input
    image with the ground truth and ``*_out`` compares the generator output
    with the ground truth. Images are de-normalized and clamped to [0, 1]
    before the metrics are computed.

    Args:
        loader: yields ``(x, y, mask)`` batches; the mask is ignored.
        G: generator to evaluate; it is switched to eval mode.
        device: device the batches are moved to.
        max_batches: optional cap on the number of batches to process.

    Returns:
        Dict of per-image averages keyed ``psnr_in``, ``psnr_out``, ...,
        ``vif_out``. Every value is NaN when no batch was processed.
    """
    G.eval()
    metric_fns = {
        "psnr": piq.psnr,
        "ssim": piq.ssim,
        "fsim": piq.fsim,
        "vif": piq.vif_p,
    }
    sums = {f"{name}_{side}": 0.0 for name in metric_fns for side in ("in", "out")}
    n_samples = 0

    for i, (x, y, m) in enumerate(loader):
        if (max_batches is not None) and (i >= max_batches):
            break

        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        y_pred = G(x)

        x_dn = torch.clamp(denorm_imagenet(x), 0.0, 1.0)
        y_dn = torch.clamp(denorm_imagenet(y), 0.0, 1.0)
        y_pred_dn = torch.clamp(denorm_imagenet(y_pred), 0.0, 1.0)

        # Each metric call returns the batch mean; weight it by the batch size
        # so that a smaller final batch does not count as much as a full one.
        bs = x.size(0)
        for name, fn in metric_fns.items():
            sums[f"{name}_in"] += fn(x_dn, y_dn, data_range=1.0).item() * bs
            sums[f"{name}_out"] += fn(y_pred_dn, y_dn, data_range=1.0).item() * bs

        n_samples += bs

    if n_samples == 0:
        return {k: float("nan") for k in sums}

    return {k: v / n_samples for k, v in sums.items()}


def print_metrics(name, metrics):
    """Print one line with the in/out value of every metric for ``name``.

    ``metrics`` must contain the ``<metric>_in`` / ``<metric>_out`` keys for
    psnr, ssim, fsim and vif. PSNR is shown with 3 decimals, the others with 4.
    """
    # (label as printed, key prefix in ``metrics``, number format)
    columns = (
        ("PSNR", "psnr", ".3f"),
        ("SSIM", "ssim", ".4f"),
        ("FSIM", "fsim", ".4f"),
        ("VIF ", "vif", ".4f"),
    )
    cells = [
        f"{label} in/out: "
        f"{format(metrics[key + '_in'], spec)} / {format(metrics[key + '_out'], spec)}"
        for label, key, spec in columns
    ]
    print(f"{name:10s} | " + "  | ".join(cells))

Training

In [None]:
# Adversarial training: every step updates the discriminator first and the
# generator second. After each epoch the generator is evaluated on the
# validation loader and the checkpoint with the highest validation SSIM is kept.
epochs=50
best_val_ssim=-1.0

# The checkpoint file name encodes the selected loss mode.
best_path=f"mrgan_{loss_mode}_best.pth"

# Per-epoch average losses and validation metrics.
train_history={"G":[], "D":[]}
val_history=[]

for epoch in range(1, epochs+1):
  # evaluate() switches G to eval mode, so train mode is restored every epoch.
  G.train()
  D.train()

  # Running sums used for the epoch averages.
  error_G=0.0
  error_D=0.0
  steps=0

  data=tqdm(train_loader, desc=f"[{loss_mode}] Epoch {epoch}/{epochs}", ncols=120)

  for x, y, m in data:
    x=x.to(device)
    y=y.to(device)
    m=m.to(device)

    # Update D: the fake image is produced without gradient tracking, so this
    # step only back-propagates through the discriminator.
    with torch.no_grad():
      y_fake_temp=G(x)

    # The discriminator sees the input concatenated with the real / fake
    # output along the channel dimension.
    d_real=D(torch.cat([x, y], dim=1))
    d_fake=D(torch.cat([x, y_fake_temp], dim=1))

    loss_D=L_D(d_real, d_fake)
    opt_D.zero_grad(set_to_none=True)
    loss_D.backward()
    opt_D.step()

    # Update G: a second forward pass with gradient tracking, scored by the
    # freshly updated discriminator. Only opt_G steps here.
    y_fake, skip_feats=G(x, return_skips=True)
    d_fake=D(torch.cat([x, y_fake], dim=1))

    loss_G, parts_G=L_G(y_fake, y, m, d_fake, skip_feats)
    opt_G.zero_grad(set_to_none=True)
    loss_G.backward()
    opt_G.step()

    steps+=1
    error_G+=loss_G.item()
    error_D+=loss_D.item()

    # Refresh the running averages on the progress bar every 50 steps.
    if steps%50==0:
      data.set_postfix({
        "G_loss":f"{error_G/steps:.3f}",
        "D_loss":f"{error_D/steps:.3f}",
      })

  # max(1, steps) avoids a division by zero when the loader yields no batch.
  avg_G=error_G/max(1, steps)
  avg_D=error_D/max(1, steps)
  train_history["G"].append(avg_G)
  train_history["D"].append(avg_D)

  val_metrics=evaluate(val_loader, G, device)
  val_history.append(val_metrics)
  val_ssim_out=val_metrics["ssim_out"]

  print(
    f"\nEpoch {epoch:3d}/{epochs} ({loss_mode}): "
    f"G_loss={avg_G:.4f}, D_loss={avg_D:.4f}, "
    f"Val PSNR_out={val_metrics['psnr_out']:.3f} dB, "
    f"Val SSIM_out={val_ssim_out:.4f}"
  )

  # Keep the checkpoint with the highest validation SSIM of the output.
  # NOTE(review): a NaN SSIM never compares greater, so in that case no
  # checkpoint is written.
  if val_ssim_out>best_val_ssim:
    best_val_ssim=val_ssim_out
    torch.save(
      {
        "epoch":epoch,
        "G_state":G.state_dict(),
        "D_state":D.state_dict(),
        "opt_G_state":opt_G.state_dict(),
        "opt_D_state":opt_D.state_dict(),
        "val_metrics":val_metrics,
        "loss_mode":loss_mode,
      },
      best_path,
    )
    print(f"  -> saved new best {best_path} with SSIM_out = {best_val_ssim:.4f}")


Outputs

In [None]:
# Restore the generator weights saved at the best validation epoch.
best = torch.load(best_path, map_location=device)
G.load_state_dict(best["G_state"])
G.eval()
print(
  f"\nLoaded best model from epoch {best['epoch']} "
  f"(val SSIM_out = {best['val_metrics']['ssim_out']:.4f}, mode = {loss_mode})"
)

# Report the metrics of every test subset separately ...
print("\nPer-dataset TIReD test metrics (PSNR / SSIM / FSIM / VIF)")
for name, loader in test_loaders.items():
  m = evaluate(loader, G, device)
  print_metrics(name, m)

# ... and once more over all test subsets combined.
m_all = evaluate(all_test_loader, G, device)
print("\nOverall TIReD test set")
print_metrics("TIReD_all", m_all)


@torch.no_grad()
def save_outputs(G, loaders_dict, device, out_root="mrgan_examples", per_dataset=8):
  """Save side-by-side figures (input, ground truth, generator output).

  For every ``name -> loader`` entry, the first ``per_dataset`` samples are
  written as PNG files to ``<out_root>/<name>/``.

  Args:
    G: generator; it is switched to eval mode.
    loaders_dict: mapping from dataset name to a loader of ``(x, y, mask)``.
    device: device the batches are moved to.
    out_root: root directory for the figures (created when missing).
    per_dataset: number of examples saved per dataset.
  """
  def as_image(batch, index):
    # De-normalize one sample, clamp to [0, 1] and convert to an HWC array.
    sample = torch.clamp(denorm_imagenet(batch[index:index+1]), 0.0, 1.0)[0]
    return sample.cpu().permute(1,2,0).numpy()

  titles = ("Input", "GT Retargeted", "MRGAN Output")

  G.eval()
  os.makedirs(out_root, exist_ok=True)

  for name, loader in loaders_dict.items():
    save_dir = os.path.join(out_root, name)
    os.makedirs(save_dir, exist_ok=True)

    print(f"Saving outputs for {name} into {save_dir} ...")
    saved = 0

    for x, y, m in loader:
      if saved >= per_dataset:
        break

      x = x.to(device)
      y = y.to(device)
      y_pred = G(x)

      # Take only as many samples from this batch as are still needed.
      take = min(x.size(0), per_dataset - saved)
      for i in range(take):
        panels = (as_image(x, i), as_image(y, i), as_image(y_pred, i))

        fig, axes = plt.subplots(1,3, figsize=(9,3))
        for axis, panel, title in zip(axes, panels, titles):
          axis.imshow(panel)
          axis.set_title(title)
          axis.axis("off")

        fname = os.path.join(save_dir, f"{name}_{saved:03d}_{loss_mode}.png")
        fig.tight_layout()
        fig.savefig(fname, dpi=150, bbox_inches="tight")
        plt.close(fig)

        saved += 1

    print(f"  -> saved {saved} examples for {name}\n")


# Write 8 example figures per test dataset under ./mrgan_examples/.
save_outputs(G, test_loaders, device, out_root="mrgan_examples", per_dataset=8)
