In [None]:
import os
import glob
import json
import math
from types import SimpleNamespace

import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from imgaug import augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

In [None]:
class ContentOrientedDataset(Dataset):
    """Images with face and structure masks, randomly flipped, rescaled and cropped.

    Expected layout under ``root``:
      * ``images/**/*.jpg`` and ``images/**/*.png`` - input images (searched recursively).
      * ``face_coords.json`` - maps each image path relative to ``images/`` to a list of
        boxes; box ``c`` covers rows ``c[1]:c[3]`` and columns ``c[0]:c[2]``.
      * ``structure_masks/<relative path with .png extension>`` - structure mask per image.

    Each item is ``([img, face_mask, structure_mask], bpp)`` where ``img`` is a float
    tensor, the masks are ``int8`` tensors with a leading channel dimension, and ``bpp``
    is the bits-per-pixel of the image file on disk.

    Args:
        root: dataset root directory.
        crop_size: side length of the square random crop.
        normalize: if truthy, images are mapped from [0, 1] to [-1, 1].
        **kwargs: accepted and ignored.
    """

    def __init__(self, root='', crop_size=256,
                 normalize=False, **kwargs):
        super().__init__()

        self.data_dir = root
        img_extensions = ['.jpg', '.png']
        found = []
        for ext in img_extensions:
            found += glob.glob(os.path.join(self.data_dir, f'images/**/*{ext}'), recursive=True)
        # glob order is filesystem-dependent; sort so index -> file is reproducible.
        self.imgs = sorted(found)
        self.crop_size = crop_size
        self.image_dims = (3, self.crop_size, self.crop_size)
        self.normalize = normalize

        json_file_path = os.path.join(self.data_dir, "face_coords.json")
        with open(json_file_path, 'r') as json_file:
            self.face_coords = json.load(json_file)

    def __len__(self):
        """Number of images found; required by DataLoader's default sampler."""
        return len(self.imgs)

    def _augment(self, img, face_masks, structure_masks):
        """Randomly flip, rescale and crop an image together with its masks.

        The scale factor is drawn uniformly from
        ``[max(crop/shortest_side, 0.75), max(that, 0.95)]`` so the resized image is
        never smaller than ``crop_size`` along its shortest side.

        Args:
            img: (H, W, C) image array.
            face_masks, structure_masks: (H, W, 1) integer masks aligned with ``img``.

        Returns:
            (img, face_mask, structure_mask); masks are (crop, crop) arrays.
        """
        SCALE_MIN = 0.75
        SCALE_MAX = 0.95
        H, W = img.shape[:2]
        minimum_scale_factor = float(self.crop_size) / float(min(H, W))
        scale_low = max(minimum_scale_factor, SCALE_MIN)
        scale_high = max(scale_low, SCALE_MAX)
        scale = np.random.uniform(scale_low, scale_high)

        # Built per call: the resize target depends on this image and the sampled scale.
        # A dict is required here: a 2-int tuple would make imgaug sample ONE side length
        # from [h..w] and use it for both dimensions, distorting the aspect ratio.
        augmentations = iaa.Sequential([
            iaa.Fliplr(0.5),  # horizontally flip 50% of the images
            iaa.Resize({"height": math.ceil(scale * H), "width": math.ceil(scale * W)}),
            iaa.size.CropToFixedSize(self.crop_size, self.crop_size),
        ])

        # Stack both masks so they receive exactly the same geometric transform.
        masks = SegmentationMapsOnImage(np.dstack([face_masks, structure_masks]), shape=img.shape)
        img, masks = augmentations(image=img, segmentation_maps=masks)
        masks = masks.get_arr()
        return img, masks[:, :, 0], masks[:, :, 1]

    def _transforms(self, img, face_mask, structure_mask):
        """Convert the image and masks to tensors.

        The image goes through ``ToTensor`` (and ``Normalize`` to [-1, 1] when
        ``self.normalize`` is truthy). Masks are integer arrays, which ``ToTensor`` does
        not rescale; they gain a leading channel dim and are cast to ``int8``.
        """
        to_tensor = transforms.ToTensor()
        transforms_list = [to_tensor]
        if self.normalize:
            transforms_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        img = transforms.Compose(transforms_list)(img)

        face_mask = to_tensor(face_mask).type(torch.int8)
        structure_mask = to_tensor(structure_mask).type(torch.int8)
        return img, face_mask, structure_mask

    def get_face_mask(self, idx, shape):
        """Rasterise the face boxes of image ``idx`` into a binary mask.

        Args:
            idx: index into ``self.imgs``.
            shape: image shape; only the first two entries (H, W) are used.

        Returns:
            (H, W, 1) int32 array: 1 inside any face box, 0 elsewhere.
        """
        H, W = shape[:2]
        mask = np.zeros((H, W, 1), dtype=np.int32)
        key = os.path.relpath(self.imgs[idx], os.path.join(self.data_dir, "images"))
        for coord in self.face_coords[key]:
            mask[coord[1]:coord[3], coord[0]:coord[2]] = 1
        return mask

    def get_structure_mask(self, idx):
        """Load the structure mask of image ``idx`` as an (H, W, 1) int8 0/1 array.

        The mask is converted to 8-bit grayscale first, so bilevel ('1'), palette or RGB
        files are handled too, then thresholded at mid-gray. The old ``/ 255`` truncation
        turned every value below 255 into 0, e.g. all of a mode-'1' mask.
        """
        rel_path = os.path.relpath(self.imgs[idx], os.path.join(self.data_dir, "images"))
        rel_path = os.path.splitext(rel_path)[0] + ".png"
        with PIL.Image.open(os.path.join(self.data_dir, "structure_masks", rel_path)) as mask_img:
            gray = np.array(mask_img.convert('L'))
        return np.expand_dims(gray >= 128, axis=-1).astype(np.int8)

    def __getitem__(self, idx):
        """Return ``([img, face_mask, structure_mask], bpp)`` for image ``idx``."""
        img_path = self.imgs[idx]
        filesize = os.path.getsize(img_path)
        with PIL.Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        H, W = img.shape[:2]
        # Bits per pixel of the file as stored on disk, computed before any augmentation.
        bpp = filesize * 8. / (H * W)
        face_mask = self.get_face_mask(idx, img.shape)
        structure_mask = self.get_structure_mask(idx)
        img, face_mask, structure_mask = self._augment(img, face_mask, structure_mask)
        img, face_mask, structure_mask = self._transforms(img, face_mask, structure_mask)
        return [img, face_mask, structure_mask], bpp



        

In [None]:
class LaplacianLoss(nn.Module):
    """Per-pixel squared difference between the Laplacians of two images.

    3-channel inputs are converted to grayscale first; any other channel count is
    filtered channel-by-channel. The loss is returned unreduced (same shape as the
    filtered inputs) so callers can apply a spatial mask before averaging.
    """

    def __init__(self):
        super(LaplacianLoss, self).__init__()

    def forward(self, img1, img2):
        """Return ``(laplacian(img1) - laplacian(img2)) ** 2`` element-wise."""
        if img1.shape[1] == 3:
            img1 = TF.rgb_to_grayscale(img1)
        if img2.shape[1] == 3:
            img2 = TF.rgb_to_grayscale(img2)
        return F.mse_loss(self.laplacian(img1), self.laplacian(img2), reduction="none")

    def laplacian(self, img):
        """Apply the 4-neighbour 3x3 Laplacian kernel to every channel of ``img``.

        Args:
            img: (N, C, H, W) floating-point tensor; reflect padding needs H, W >= 2.

        Returns:
            Tensor with the same shape, dtype and device as ``img``.
        """
        channels = img.shape[1]
        # Build the kernel in the input's dtype: a fixed float32 kernel makes conv2d
        # fail for float64/float16 inputs.
        kernel = torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=img.dtype, device=img.device)
        # One kernel copy per channel, applied depthwise (groups=channels).
        kernel = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        padded = F.pad(img, (1, 1, 1, 1), mode='reflect')
        return F.conv2d(padded, kernel, groups=channels)
    
class ContentOrientedLoss(nn.Module):
    """Content-oriented compression loss with separate face, structure and texture terms.

    ``process_masks`` partitions pixels into three disjoint regions (face wins over
    structure; texture is everything else):

    * face      -> MSE, weighted by ``args.gamma``
    * structure -> Laplacian MSE, weighted by ``args.epsilon``
    * texture   -> ``args.alpha`` * L1 + ``args.beta`` * perceptual distance
                   + ``args.delta`` * generator GAN loss

    Args:
        args: namespace with ``alpha``, ``beta``, ``delta``, ``epsilon``, ``gamma``,
            ``normalize_input_image`` and optionally ``gan_loss_type``
            (``"non_saturating"`` (default) or ``"least_squares"``).
        discriminator: module applied to a batch of images; its output is split in half
            along dim 0 into scores for real and reconstructed images.
        perceptual_loss: optional callable ``(img1, img2, normalize=True)`` returning
            per-sample perceptual distances (e.g. an LPIPS module). It can also be
            assigned later to ``vgg_perceptual_loss``; ``lpips`` raises while it is unset.
    """

    def __init__(self, args, discriminator, perceptual_loss=None):
        super(ContentOrientedLoss, self).__init__()
        self.args = args
        self.vgg_perceptual_loss = perceptual_loss
        self.laplacian = LaplacianLoss()
        self.Discriminator = discriminator

    def forward(self, orig_imgs, recon_imgs, face_masks, structure_masks):
        """Compute the generator-side and discriminator-side losses.

        Args:
            orig_imgs, recon_imgs: image batches. They are mapped from [-1, 1] to [0, 1]
                first when ``args.normalize_input_image`` is set.
            face_masks, structure_masks: 0/1 integer (or bool) masks broadcastable against
                the images. Presumably (N, 1, H, W); TODO confirm for non-dataset callers.

        Returns:
            (rate + content-oriented loss, discriminator loss). The discriminator loss uses
            a detached reconstruction, so no gradient from it reaches the reconstruction.
        """
        weighted_rate = self.rate_loss()
        if self.args.normalize_input_image:
            # [-1., 1.] -> [0., 1.]
            orig_imgs = orig_imgs * 0.5 + 0.5
            recon_imgs = recon_imgs * 0.5 + 0.5
        content_oriented_loss = self.content_oriented_loss(orig_imgs, recon_imgs, face_masks, structure_masks)
        rate_content_oriented_loss = weighted_rate + content_oriented_loss
        # Discriminator sees real images vs. reconstructions whose non-texture pixels are real.
        _, _, texture_masks = self.process_masks(face_masks, structure_masks)
        replaced_recon_imgs = self.real_value_replacement(orig_imgs, recon_imgs, texture_masks)
        D_loss = self.gan_loss(orig_imgs, replaced_recon_imgs, "discriminator_loss")
        return rate_content_oriented_loss, D_loss

    def rate_loss(self):
        """Placeholder rate term: a 0-dim zero (there is no entropy model here).

        It must be 0-dim: an empty tensor (``torch.zeros((0,))``) would broadcast the total
        loss to an empty tensor, and ``backward()`` on it would fail.
        """
        return torch.zeros(())

    def content_oriented_loss(self, original_img, reconstructed_img, face_mask, structure_mask):
        """Weighted sum ``L_tex + epsilon * L_struc + gamma * L_face`` over disjoint regions."""
        face_mask, structure_mask, texture_mask = self.process_masks(face_mask, structure_mask)
        L_tex = self.loss_texture(original_img, reconstructed_img, texture_mask)
        L_struc = self.loss_structure(original_img, reconstructed_img, structure_mask)
        L_face = self.loss_face(original_img, reconstructed_img, face_mask)
        return L_tex + self.args.epsilon * L_struc + self.args.gamma * L_face

    def loss_face(self, original_img, reconstructed_img, face_mask):
        """Mean of the face-masked per-element squared error."""
        return torch.mean(face_mask * F.mse_loss(original_img, reconstructed_img, reduction="none"))

    def loss_structure(self, original_img, reconstructed_img, structure_mask):
        """Mean of the structure-masked per-pixel Laplacian loss."""
        return torch.mean(structure_mask * self.laplacian(original_img, reconstructed_img))

    def loss_texture(self, original_img, reconstructed_img, texture_mask):
        """Texture term: masked L1 + perceptual distance + generator GAN loss.

        The perceptual and adversarial terms see the reconstruction with every
        non-texture pixel replaced by the original one.
        """
        replaced_reconstruction = self.real_value_replacement(original_img, reconstructed_img, texture_mask)
        l1 = torch.mean(texture_mask * F.l1_loss(original_img, reconstructed_img, reduction="none"))
        perceptual = self.lpips(original_img, replaced_reconstruction)
        adversarial = self.gan_loss(original_img, replaced_reconstruction, "generator_loss")
        return self.args.alpha * l1 + self.args.beta * perceptual + self.args.delta * adversarial

    def lpips(self, original_img, reconstructed_img):
        """Batch-mean perceptual distance from ``vgg_perceptual_loss`` (inputs in [0, 1]).

        Raises:
            RuntimeError: if no perceptual-loss module has been provided.
        """
        if self.vgg_perceptual_loss is None:
            raise RuntimeError(
                "ContentOrientedLoss.vgg_perceptual_loss is not set; pass perceptual_loss=... "
                "or assign an LPIPS-style module before computing the loss")
        # Call the module (not .forward) so nn.Module hooks run.
        LPIPS_loss = self.vgg_perceptual_loss(original_img, reconstructed_img, normalize=True)
        return torch.mean(LPIPS_loss)

    def discriminator_forward(self, orig_imgs, recon_imgs, train_generator):
        """Score real and reconstructed batches in one discriminator pass.

        When ``train_generator`` is falsy the reconstructions are detached, so the
        discriminator update does not back-propagate into the generator.

        Returns:
            The squeezed discriminator output for ``cat([orig, recon])``; the first half
            along dim 0 belongs to the real images.
        """
        if not train_generator:
            recon_imgs = recon_imgs.detach()
        D_in = torch.cat([orig_imgs, recon_imgs], dim=0)
        return torch.squeeze(self.Discriminator(D_in))

    def gan_loss(self, original_img, replaced_reconstruction, mode="generator_loss"):
        """Adversarial loss for the given ``mode``.

        Args:
            mode: ``"generator_loss"`` (returns G_loss; gradients reach the generator),
                ``"generator_discriminator_loss"`` (returns ``[D_loss, G_loss]``) or any
                other value (returns D_loss).
        """
        disc_out = self.discriminator_forward(
            original_img, replaced_reconstruction, train_generator=(mode == "generator_loss"))
        D_real, D_gen = torch.chunk(disc_out, 2, dim=0)
        D_loss, G_loss = self._adversarial_losses(D_real, D_gen)
        if mode == "generator_discriminator_loss":
            return [D_loss, G_loss]
        if mode == "generator_loss":
            return G_loss
        return D_loss

    def _adversarial_losses(self, D_real, D_gen):
        """Return ``(D_loss, G_loss)`` for ``args.gan_loss_type``.

        ``"non_saturating"`` treats the discriminator outputs as logits. NOTE(review): a
        discriminator ending in Sigmoid outputs (0, 1) scores, not logits; consider
        ``"least_squares"`` for such a discriminator.
        """
        loss_type = getattr(self.args, "gan_loss_type", "non_saturating")
        if loss_type == "non_saturating":
            D_loss = (F.binary_cross_entropy_with_logits(D_real, torch.ones_like(D_real))
                      + F.binary_cross_entropy_with_logits(D_gen, torch.zeros_like(D_gen)))
            G_loss = F.binary_cross_entropy_with_logits(D_gen, torch.ones_like(D_gen))
        elif loss_type == "least_squares":
            D_loss = torch.mean(torch.square(D_real - 1.0)) + torch.mean(torch.square(D_gen))
            G_loss = torch.mean(torch.square(D_gen - 1.0))
        else:
            raise ValueError(f"Unknown gan_loss_type: {loss_type!r}")
        return D_loss, G_loss

    def real_value_replacement(self, original_img, reconstructed_img, texture_mask):
        """Keep reconstructed pixels inside the texture mask and original pixels elsewhere."""
        return (1 - texture_mask) * original_img + texture_mask * reconstructed_img

    def process_masks(self, face_mask, structure_mask):
        """Make the masks disjoint: structure excludes face; texture is the remainder.

        Uses bitwise ops, so masks must be bool or 0/1 integer tensors (not float).
        """
        structure_mask = ~face_mask & structure_mask
        texture_mask = torch.ones_like(face_mask) & ~face_mask & ~structure_mask
        return face_mask, structure_mask, texture_mask


In [None]:
class Model(nn.Module):
    """Toy fully-connected autoencoder with a separate discriminator head.

    Args:
        image_size: side length of the square input images (default 256).
        channels: number of image channels (default 3).
        hidden_dim: encoder latent width and discriminator output width (default 256).
    """

    def __init__(self, image_size=256, channels=3, hidden_dim=256):
        super(Model, self).__init__()
        image_shape = (channels, image_size, image_size)
        num_values = channels * image_size * image_size
        # Encoder layers
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_values, hidden_dim),
            nn.ReLU(),
        )
        # Decoder layers. Unflatten restores the (C, H, W) layout so reconstructions can
        # be compared element-wise with the original images. It has no parameters, so
        # existing state_dict keys are unchanged.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, num_values),
            nn.Sigmoid(),  # pixel values in (0, 1), assuming images in the [0, 1] range
            nn.Unflatten(1, image_shape),
        )
        # Discriminator layers
        self.discriminator = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_values, hidden_dim),
            nn.Sigmoid(),  # squashes scores to (0, 1); these are not logits
        )

    def forward(self, x):
        """Encode and decode ``x``; the output has the same (N, C, H, W) shape."""
        return self.decoder(self.encoder(x))

In [None]:
NUM_EPOCHS = 10
BATCH_SIZE = 4
DATA_ROOT = ''  # directory containing images/, structure_masks/ and face_coords.json
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TODO: placeholder hyper-parameters; replace with the experiment's real configuration.
args = SimpleNamespace(
    alpha=1.0,    # texture L1 weight
    beta=1.0,     # texture perceptual (LPIPS) weight
    delta=1.0,    # texture generator-GAN weight
    epsilon=1.0,  # structure (Laplacian) weight
    gamma=1.0,    # face MSE weight
    gan_loss_type="least_squares",
    normalize_input_image=False,  # must match the dataset's `normalize` flag below
    learning_rate=1e-4,
)

dataset = ContentOrientedDataset(root=DATA_ROOT, crop_size=256, normalize=args.normalize_input_image)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
model = Model().to(device)
# The loss must score images with the same discriminator that optimizer_discriminator updates.
criterion = ContentOrientedLoss(args, discriminator=model.discriminator)
# TODO: the texture loss needs a perceptual-loss module before training, e.g.
#   criterion.vgg_perceptual_loss = lpips.LPIPS(net="vgg").to(device)
optimizer_encoder_decoder = torch.optim.Adam(
    list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=args.learning_rate)
optimizer_discriminator = torch.optim.Adam(model.discriminator.parameters(), lr=args.learning_rate)

model.train()
phase = "ED"  # alternate encoder/decoder ("ED") and discriminator ("D") steps per batch
for epoch in range(NUM_EPOCHS):
    for (orig_imgs, face_masks, structure_masks), bpp in dataloader:
        orig_imgs = orig_imgs.to(device)
        face_masks = face_masks.to(device)
        structure_masks = structure_masks.to(device)
        # Clear both optimizers every step: the generator loss also back-propagates into
        # the discriminator, and those gradients must not leak into the next D step.
        optimizer_encoder_decoder.zero_grad()
        optimizer_discriminator.zero_grad()
        recon_imgs = model(orig_imgs)
        compression_loss, D_loss = criterion(orig_imgs, recon_imgs, face_masks, structure_masks)
        if phase == "ED":
            compression_loss.backward()
            optimizer_encoder_decoder.step()
            phase = "D"
        else:
            D_loss.backward()
            optimizer_discriminator.step()
            phase = "ED"
