In [8]:
import torch
import numpy as np
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision.models import vgg19
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import os
from google.colab.patches import cv2_imshow
from google.colab import drive
# Mount Google Drive so the dataset and checkpoint paths under /content/drive resolve.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)
# Seed torch's global RNG so module initialization and torch.randn draws below are
# reproducible; the returned Generator is the cell's displayed output.
torch.manual_seed(seed=0)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


<torch._C.Generator at 0x7b0954451c50>

In [9]:
def load_image(image_path, size=(576, 576), interpolation=transforms.InterpolationMode.BILINEAR):
    """
    Load one image from disk as a float tensor of shape [3, H, W] with values in [0, 1].

    No batch dimension is added here; use tensor_create for that.

    Parameters:
        image_path: path to an image PIL can open. It is converted to RGB, so the
            result always has 3 channels (grayscale inputs are replicated).
        size: (height, width) passed to transforms.Resize.
        interpolation: resampling mode for the resize. The default (bilinear) is
            torchvision's own default, so existing calls behave as before; pass
            transforms.InterpolationMode.NEAREST for label maps so ids are not
            blended along region boundaries.
    Returns:
        torch.Tensor of shape [3, size[0], size[1]], dtype float32, values in [0, 1].
    """
    transform = transforms.Compose([
        transforms.Resize(size, interpolation=interpolation),
        transforms.ToTensor(),          # HWC uint8 -> CHW float in [0, 1]
    ])

    # The context manager closes the file handle once convert() has decoded the pixels.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')

    return transform(rgb)

def tensor_create(image):
    """
    Prepend a batch axis of size 1 to a single image tensor.

    Example: an image of shape [3, H, W] becomes [1, 3, H, W]. The result is a
    view that shares storage with ``image``.
    """
    return image[None, ...]

In [10]:
# Build the training set: each original photo is paired with its edited counterpart
# and its label-id mask, and both photos are multiplied by the mask.
n_source_images = 13   # files are numbered 001 .. 013
dup = 20               # each pair appears `dup` times per training epoch
pants_root = "/content/drive/MyDrive/AIClothes/Pants"

unique_edited, unique_edited_mask = [], []
unique_real, unique_real_mask = [], []

for idx in range(1, n_source_images + 1):
    stem = f"{idx:03d}"   # zero-padded to 3 digits: 001, ..., 013

    real_image = load_image(f"{pants_root}/OriginalImages/PantsOn/PantsOn_{stem}.jpeg")
    edited_image = load_image(f"{pants_root}/FinalInput/EditedPantsOn_{stem}.jpg")
    mask_image = load_image(f"{pants_root}/Annotations/PantsOn/PantsOn_{stem}_gtFine_labelIds.png")

    # ToTensor divided the 8-bit pixel values by 255; multiplying by 255 restores the raw ids.
    # NOTE(review): the photos are multiplied by these raw ids, so this acts as a clean
    # 0/1 mask only if the annotation uses ids 0 and 1 (and the PNG is grayscale, not
    # palette-mode) — confirm against the label map. The mask is also resized with
    # load_image's bilinear default, which can blend ids along region edges.
    label_ids = mask_image * 255

    unique_edited.append(tensor_create(edited_image))
    unique_edited_mask.append(tensor_create(edited_image * label_ids))
    unique_real.append(tensor_create(real_image))
    unique_real_mask.append(tensor_create(real_image * label_ids))

# Read each file once, then repeat the tensors in the same order the nested loop
# produced (001..013, 001..013, ...) instead of re-reading from Drive `dup` times.
# NOTE(review): repeats share the same tensor objects; the training code does not
# modify them in place, but keep that in mind if adding in-place augmentation.
images_edited = unique_edited * dup
images_edited_mask = unique_edited_mask * dup
images_real = unique_real * dup
images_real_mask = unique_real_mask * dup

In [11]:
class ContractingBlock(nn.Module):
    '''
    ContractingBlock Class
    Downsamples once: a single 3x3 convolution with stride 2 and padding 1 halves
    H and W (rounding up) and doubles the channel count. It is followed by optional
    normalization, optional dropout, and a LeakyReLU(0.2) activation.
    Values:
        input_channels: the number of channels to expect from a given input;
            the output has input_channels * 2 channels
        use_dropout: if True, apply nn.Dropout() after normalization
        use_bn: if True, normalize after the convolution. Despite the flag and the
            attribute name `batchnorm`, the layer is nn.InstanceNorm2d (affine=False);
            the name is kept for compatibility with existing code.
    '''
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size=3, stride=2, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        if use_bn:
            self.batchnorm = nn.InstanceNorm2d(input_channels * 2)
        self.use_bn = use_bn
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    def forward(self, x):
        '''
        Function for completing a forward pass of ContractingBlock:
        conv (stride 2) -> optional norm -> optional dropout -> LeakyReLU.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        Returns:
            tensor of shape (batch size, 2 * channels, ceil(height / 2), ceil(width / 2))
        '''
        x = self.conv1(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x

class ExpandingBlock(nn.Module):
    '''
    ExpandingBlock Class
    Upsamples once and merges a skip connection:
      1. 2x2 transposed conv, stride 2: doubles H and W, keeps input_channels channels
      2. conv1 (3x3, padding 1): input_channels -> input_channels // 2
      3. optional norm / dropout (no activation at this point)
      4. concatenate with skip_con_x along the channel axis
      5. conv1 again, then the same norm / dropout, then ReLU
    Step 5 reuses the *same* conv1 and norm modules as step 2 (shared weights), so
    skip_con_x must have input_channels - input_channels // 2 channels and the
    upsampled spatial size.
    NOTE(review): the weight sharing may be unintentional (an earlier description
    said "two more convolutions"); adding separate convs would change the
    state_dict and break existing checkpoints, so it is left as-is.
    Values:
        input_channels: the number of channels to expect from a given input
        use_dropout: if True, apply nn.Dropout() after each normalization
        use_bn: if True, normalize with nn.InstanceNorm2d(affine=True)
            (stored under the attribute name `batchnorm`)
    '''
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super(ExpandingBlock, self).__init__()
        self.upsample = nn.ConvTranspose2d(input_channels, input_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=3, padding=1)
        if use_bn:
            self.batchnorm = nn.InstanceNorm2d(input_channels // 2, affine=True)
        self.use_bn = use_bn
        self.activation = nn.ReLU()
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    def forward(self, x, skip_con_x):
        '''
        Function for completing a forward pass of ExpandingBlock.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
            skip_con_x: tensor from the contracting path with
                channels - channels // 2 channels and spatial size (2 * height, 2 * width)
        Returns:
            tensor of shape (batch size, channels // 2, 2 * height, 2 * width),
            non-negative because it ends in ReLU
        '''
        x = self.upsample(x)
        x = self.conv1(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = torch.cat([x, skip_con_x], dim=1)
        # Second pass through the same conv1 / norm modules (shared weights).
        x = self.conv1(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x

class FeatureMapBlock(nn.Module):
    '''
    FeatureMapBlock Class
    Per-pixel channel remapping via a single 1x1 convolution; used at both ends of
    the U-Net and at the start of the discriminator. Spatial size is unchanged.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to produce
    '''

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        '''
        Map x of shape (batch size, input_channels, height, width) to
        (batch size, output_channels, height, width).
        '''
        return self.conv(x)

class UNet(nn.Module):
    '''
    UNet Class
    A 1x1 upfeature layer, 6 contracting blocks, 6 expanding blocks with skip
    connections, and a 1x1 downfeature layer followed by a sigmoid, mapping an
    input image to a paired output image of the same spatial size with values in (0, 1).
    Each contracting block halves H and W, so both must be divisible by 2**6 = 64.
    NOTE(review): a 64-pixel side would give a 1x1 bottleneck, which InstanceNorm2d
    rejects; in practice use sides >= 128 (576 is used in this notebook).
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
        hidden_channels: channel width after the upfeature layer; doubles per contracting block
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=32):
        super(UNet, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels)  # this layer could have dropout
        self.contract2 = ContractingBlock(hidden_channels * 2)  # this layer could have dropout
        self.contract3 = ContractingBlock(hidden_channels * 4)  # this layer could have dropout
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.contract5 = ContractingBlock(hidden_channels * 16)
        self.contract6 = ContractingBlock(hidden_channels * 32)
        self.expand0 = ExpandingBlock(hidden_channels * 64)
        self.expand1 = ExpandingBlock(hidden_channels * 32)
        self.expand2 = ExpandingBlock(hidden_channels * 16)
        self.expand3 = ExpandingBlock(hidden_channels * 8)
        self.expand4 = ExpandingBlock(hidden_channels * 4)
        self.expand5 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        '''
        Function for completing a forward pass of UNet.
        Parameters:
            x: image tensor of shape (batch size, input_channels, height, width),
               with height and width divisible by 64
        Returns:
            tensor of shape (batch size, output_channels, height, width) in (0, 1)
        Raises:
            ValueError: if height or width is not divisible by 64 (the skip
                connections would otherwise fail with a tensor-size mismatch)
        '''
        height, width = x.shape[-2:]
        if height % 64 or width % 64:
            raise ValueError(
                f"UNet needs height and width divisible by 64 (one halving per "
                f"contracting block); got {height}x{width}"
            )
        x0 = self.upfeature(x)
        x1 = self.contract1(x0)
        x2 = self.contract2(x1)
        x3 = self.contract3(x2)
        x4 = self.contract4(x3)
        x5 = self.contract5(x4)
        x6 = self.contract6(x5)
        # Each expanding block is paired with the contracting output at its resolution.
        x7 = self.expand0(x6, x5)
        x8 = self.expand1(x7, x4)
        x9 = self.expand2(x8, x3)
        x10 = self.expand3(x9, x2)
        x11 = self.expand4(x10, x1)
        x12 = self.expand5(x11, x0)
        xn = self.downfeature(x12)
        return self.sigmoid(xn)

In [12]:
# Smoke test: push one random 576x576 image through a fresh U-Net and check the
# output keeps the input's shape. no_grad avoids building an autograd graph for
# this one-off forward pass; the real model and optimizers are created further down.
gen = UNet(3, 3).to('cpu')
with torch.no_grad():
    out = gen(torch.randn(1, 3, 576, 576))
assert out.shape == (1, 3, 576, 576)
out.shape

torch.Size([1, 3, 576, 576])

In [13]:
class Discriminator(nn.Module):
    '''
    Conditional discriminator: scores an image together with the condition image.

    The two inputs are concatenated along the channel axis, and a 1x1 conv lifts
    them to ``hidden_channels``. Four stride-2 ContractingBlocks follow, each halving
    H and W and doubling the channels. A final 1x1 conv reduces the result to a
    single channel. The output is a map of raw logits (no sigmoid), one per spatial
    location, suitable for BCEWithLogitsLoss.
    Values:
        input_channels: total channels of x and y combined
        hidden_channels: channel width after the first 1x1 conv
    '''

    def __init__(self, input_channels, hidden_channels=32):
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        # Normalization is enabled in every block (use_bn defaults to True).
        self.contract1 = ContractingBlock(hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=(1, 1))

    def forward(self, x, y):
        '''
        Parameters:
            x: candidate image (real or generated), shape (N, C_x, H, W)
            y: condition image, shape (N, C_y, H, W), with C_x + C_y == input_channels
        Returns:
            logit map of shape (N, 1, H', W'), where each contracting block halves H and W
        '''
        features = self.upfeature(torch.cat((x, y), dim=1))
        for block in (self.contract1, self.contract2, self.contract3, self.contract4):
            features = block(features)
        return self.final(features)

In [14]:
import torch.nn.functional as F
# Loss functions and their weights for the generator objective.
adv_criterion = nn.BCEWithLogitsLoss()   # the discriminator outputs raw logits
recon_criterion_l1 = nn.L1Loss()
recon_criterion_l2 = nn.MSELoss()
l1_recon = 100      # weight on the L1 reconstruction term
l2_recon = 10       # weight on the MSE reconstruction term
lambda_perc = 1     # weight on the VGG perceptual term; 0 skips the VGG pass in get_gen_loss

# Training setup.
n_epochs = 30
input_dim = 3       # channels of the generator input (condition image)
real_dim = 3        # channels of the target / generated image
display_step = 200
batch_size = 4      # NOTE(review): the training loop feeds one image at a time and does not read this
lr = 0.0002
# Side length used when viewing tensors as square images. It must match the 576x576
# size load_image produces; 448 would not match any tensor in this notebook.
target_shape = 576
device = 'cpu'

In [15]:
# transform = transforms.Compose([
#     transforms.ToTensor(),
# ])

# import torchvision
# dataset = torchvision.datasets.ImageFolder("maps", transform=transform)

In [16]:
# Networks: the generator maps a condition image to an output image; the
# discriminator sees the candidate and condition stacked along channels.
gen = UNet(input_dim, real_dim).to(device)
disc = Discriminator(input_dim + real_dim).to(device)

# One Adam optimizer per network, both with the configured learning rate.
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def weights_init(m):
    """
    Per-module initializer intended for ``nn.Module.apply``.

    - Conv2d / ConvTranspose2d: Xavier-normal weights (biases keep PyTorch's default init).
    - BatchNorm2d with affine parameters: weight ~ N(1, 0.02), bias = 0.

    xavier_normal_ cannot be used on a BatchNorm weight: it is 1-D and fan-in/fan-out
    are undefined, so it raises ValueError. Non-affine BatchNorm (weight is None) is
    skipped. InstanceNorm2d layers (used by the models in this notebook) are not
    matched and keep their default init.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.xavier_normal_(m.weight)
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Set pretrained to True to resume from a saved checkpoint instead of training from scratch.
pretrained = False
checkpoint_path = "/content/drive/MyDrive/AIClothes/PreTrainedModels/pix2pix_15000.pth"
if pretrained:
    # map_location remaps tensors saved on another device (e.g. from a GPU run) onto
    # the configured `device`; without it a CUDA-saved checkpoint fails to load on CPU.
    loaded_state = torch.load(checkpoint_path, map_location=device)
    gen.load_state_dict(loaded_state["gen"])
    gen_opt.load_state_dict(loaded_state["gen_opt"])
    disc.load_state_dict(loaded_state["disc"])
    disc_opt.load_state_dict(loaded_state["disc_opt"])
else:
    # Module.apply visits every submodule recursively and returns the module itself.
    gen = gen.apply(weights_init)
    disc = disc.apply(weights_init)

In [17]:
# VGG19 feature extractor for specific layers
class VGG19Features(nn.Module):
    """
    Frozen ImageNet-pretrained VGG19 convolutional trunk returning selected activations.

    Parameters:
        layer_ids: indices into ``vgg19().features`` whose outputs are collected.
            Features are returned in layer order. The default (12, 21, 30) presumably
            corresponds to the conv3_2, conv4_2 and conv5_2 positions.

    NOTE(review): torchvision's VGG follows each conv with ``nn.ReLU(inplace=True)``,
    so a tensor captured at a conv index is overwritten in place by the next ReLU.
    The returned features are therefore effectively post-ReLU. Append ``x.clone()``
    instead if pre-activation features are wanted; this would change loss values.
    """

    def __init__(self, layer_ids=(12, 21, 30)):
        super().__init__()
        # weights="IMAGENET1K_V1" selects the same weights as the deprecated pretrained=True.
        self.vgg = vgg19(weights="IMAGENET1K_V1").features
        self.layer_ids = layer_ids
        for param in self.vgg.parameters():
            param.requires_grad = False  # Freeze weights

    def forward(self, x):
        """Run x through the whole trunk and return the list of requested activations."""
        wanted = set(self.layer_ids)
        features = []
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i in wanted:
                features.append(x)
        return features

# Standard ImageNet per-channel statistics used with the torchvision VGG19 weights.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Normalization applied to [0, 1] RGB tensors before they are fed to VGG19.
preprocess = transforms.Compose([transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)])

def prepare_image(img):
    """
    Map an image batch into the input space VGG19 expects.

    Values are first clipped to [0, 1], then normalized per channel by `preprocess`.
    Presumably img is [batch, 3, H, W] — TODO confirm for any new callers.
    """
    clipped = torch.clamp(img, 0, 1)
    return preprocess(clipped)

# Perceptual loss function
def perceptual_loss(generated, target, vgg_extractor, weights=(1.0, 1.0, 1.0)):
    """
    Weighted mean L1 distance between VGG feature maps of `generated` and `target`.

    Parameters:
        generated, target: image batches; both go through prepare_image, so they are
            clamped to [0, 1] and ImageNet-normalized before the extractor runs.
        vgg_extractor: callable returning a list of feature tensors per input.
        weights: one weight per returned feature map.
    Returns:
        sum_k weights[k] * mean(|gen_k - target_k|), divided by len(weights).
    Raises:
        ValueError: if len(weights) differs from the number of feature maps; zip
            would otherwise silently drop layers or weights.
    """
    gen_features = vgg_extractor(prepare_image(generated))
    target_features = vgg_extractor(prepare_image(target))
    if len(weights) != len(gen_features):
        raise ValueError(
            f"perceptual_loss got {len(weights)} weights for {len(gen_features)} feature maps"
        )

    loss = 0
    for gen_f, tar_f, w in zip(gen_features, target_features, weights):
        loss += w * torch.mean(torch.abs(gen_f - tar_f))  # L1 loss per layer
    return loss / len(weights)

# Frozen VGG19 trunk used by perceptual_loss during training. It is placed on the
# same device as the generator so its weights and inputs live together.
vgg_extractor = VGG19Features(layer_ids=[12, 21, 30]).to(device).eval()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:08<00:00, 67.4MB/s]


In [18]:
# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CLASS: get_gen_loss
def get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion_l1, recon_criterion_l2, l1_recon, l2_recon, lambda_perc, vgg_extractor):
    '''
    Return the total generator loss for one batch.
    Parameters:
        gen: the generator; takes the condition and returns generated images
        disc: the discriminator; takes (images, condition) and returns logit maps
        real: the target images used for the reconstruction and perceptual terms
        condition: the source images the generator is conditioned on
        adv_criterion: loss on (disc logits, all-ones labels); the generator wants
            its output classified as real
        recon_criterion_l1: L1-style reconstruction loss, called as (real, generated)
        recon_criterion_l2: L2-style reconstruction loss, called as (real, generated)
        l1_recon: weight of the L1 reconstruction term
        l2_recon: weight of the L2 reconstruction term
        lambda_perc: weight of the perceptual term; when 0 the term is skipped and
            vgg_extractor is never used
        vgg_extractor: feature extractor passed to perceptual_loss
    Returns:
        adv + l1_recon * L1 + l2_recon * L2 [+ lambda_perc * perceptual], as a tensor
    '''
    gen_img = gen(condition)
    disc_fake_pred = disc(gen_img, condition)
    adv = adv_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    recon_l1 = recon_criterion_l1(real, gen_img)
    recon_l2 = recon_criterion_l2(real, gen_img)

    gen_loss = adv + (recon_l1 * l1_recon) + (recon_l2 * l2_recon)
    if lambda_perc != 0:
        # Only run the (expensive) VGG passes when the term actually contributes.
        gen_loss = gen_loss + perceptual_loss(gen_img, real, vgg_extractor) * lambda_perc
    return gen_loss

In [None]:
#from skimage import color
import numpy as np


# Per-epoch mean losses, reported after every epoch.
mean_generator_loss = 0
mean_discriminator_loss = 0
cur_step = 0

# Each entry pairs a masked edited image (generator input) with its masked real image (target).
training_pairs = list(zip(images_edited_mask, images_real_mask))
if not training_pairs:
    raise RuntimeError("No training pairs loaded; run the data-loading cell first.")


def _show_epoch_samples(fake, condition, real):
    """Plot the first sample of the generator output, its input and its target side by side."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))  # 1 row, 3 columns
    panels = ((fake, "GAN Output"), (condition, "Input Image"), (real, "Real Image"))
    for ax, (tensor, title) in zip(axes, panels):
        ax.imshow(tensor[0].detach().cpu().permute(1, 2, 0))
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    plt.close(fig)


for epoch in range(n_epochs):
    epoch_gen_loss = 0.0
    epoch_disc_loss = 0.0

    for condition, real in training_pairs:
        condition = condition.to(device)
        real = real.to(device)

        ### Update discriminator ###
        disc_opt.zero_grad()
        # The generator is not updated in this step, so no graph is needed for its output.
        with torch.no_grad():
            fake = gen(condition)
        disc_fake_hat = disc(fake, condition)
        disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
        disc_real_hat = disc(real, condition)
        disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        # No retain_graph: the generator step below runs its own discriminator forward pass.
        disc_loss.backward()
        disc_opt.step()

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion_l1, recon_criterion_l2, l1_recon, l2_recon, lambda_perc, vgg_extractor)
        gen_loss.backward()
        gen_opt.step()

        epoch_disc_loss += disc_loss.item()
        epoch_gen_loss += gen_loss.item()
        cur_step += 1

    mean_discriminator_loss = epoch_disc_loss / len(training_pairs)
    mean_generator_loss = epoch_gen_loss / len(training_pairs)

    # Visualize the last sample of the epoch.
    _show_epoch_samples(fake, condition, real)
    print(f"Epoch {epoch}: Generator (U-Net) loss: {mean_generator_loss:.4f}, "
          f"Discriminator loss: {mean_discriminator_loss:.4f}")