In [2]:
import gc
import glob
import os
import pdb
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from numpy import random
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler, TensorDataset
from torchvision.utils import save_image

from models import Unet

torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def random_cutout(image, label, max_h=50, max_w=50):
    """
    Apply random cutout to both image and label.
    Args:
        image: The input image tensor.
        label: The label tensor.
        max_h: Maximum height of the cutout box.
        max_w: Maximum width of the cutout box.
    Returns:
        image: Image after cutout.
        label: Label after cutout.
    """
    _, h, w = image.shape
    cutout_height = random.randint(10, max_h)
    cutout_width = random.randint(10, max_w)

    # Randomly choose the position for the cutout
    top = random.randint(0, h - cutout_height)
    left = random.randint(0, w - cutout_width)

    # Apply the cutout to the image and label (set to 0)
    image[:, top:top + cutout_height, left:left + cutout_width] = 0
    label[:, top:top + cutout_height, left:left + cutout_width] = 0

    return image, label


def adjust_brightness(image, label, brightness_factor=0.2):
    """
    Adjust the brightness of the image and label.
    Args:
        image: The input image tensor.
        label: The label tensor.
        brightness_factor: Factor by which brightness is adjusted.
    Returns:
        image: Image after brightness adjustment.
        label: Label after brightness adjustment.
    """
    image = TF.adjust_brightness(image, 1 + (random.random() * 2 - 1) * brightness_factor)
    # Note: Brightness doesn't affect the label, so we leave it unchanged
    return image, label

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF

import torch
import torchvision.transforms.functional as TF

def random_jitter(image, max_jitter=0.1):
    """
    Apply random jittering to an image by adding noise to its pixel values.
    Args:
        image: The input image tensor.
        max_jitter: The maximum amount of jitter to apply to each pixel.
    Returns:
        image: The image with jitter applied.
    """
    # Generate random noise with a normal distribution
    noise = torch.randn_like(image) * max_jitter  # Gaussian noise
    image = image + noise

    # Clip the values to be in the valid range [0, 1] for images
    image = torch.clamp(image, 0, 1)

    return image


def motion_blur(image, kernel_size=5, angle=45):
    """
    Apply motion blur to an image using a convolution with a motion blur kernel.
    Args:
        image: The input image tensor.
        kernel_size: The size of the blur kernel.
        angle: The angle of the motion.
    Returns:
        image: The image after motion blur.
    """
    # Create motion blur kernel
    kernel = torch.zeros((kernel_size, kernel_size))

    # Define the direction of the blur (this could be any angle, here we use horizontal motion)
    center = kernel_size // 2
    angle_rad = torch.tensor(angle * torch.pi / 180)  # Convert to radians

    # Apply a simple horizontal motion blur
    for i in range(kernel_size):
        kernel[center, i] = 1

    # Normalize the kernel
    kernel = kernel / kernel.sum()

    # Reshape kernel for convolution (batch size, channels, kernel size)
    kernel = kernel.unsqueeze(0).unsqueeze(0)  # Shape (1, 1, kernel_size, kernel_size)

    # Apply the kernel using convolution
    blurred_image = F.conv2d(image.unsqueeze(0), kernel, padding=kernel_size//2)
    return blurred_image.squeeze(0)

class FaceMapDataset(Dataset):
    #def __init__(self, data_file="data/dolensek_facemap_softlabels_224.pt", 
    #def __init__(self, data_file="data\dolensek_facemap_softlabels_224_TEST_DIF_KP.pt", 
    def __init__(self, data_file="data\dolensek_facemap_softlabels_224_WITH_KP_REGIONS.pt", 
                 transform=None, 
                 rotation_degrees=(15, 30),  # Rotation angle range from 15 to 30 degrees
                 zoom_range=(0.8, 1.5),  # Zoom range from 0.8 (zoom out) to 1.5 (zoom in)
                 blur_radius=(1, 2),  # Tuple for Gaussian blur radius range
                 cutout_prob=0.2,  # Probability of applying cutout
                 brightness_prob=0.2,  # Probability of applying brightness adjustment
                 brightness_factor=0.5,  # Max factor for brightness adjustment
                 motion_blur_prob=0.2,  # Probability of applying motion blur
                 motion_blur_kernel_size=5,  # Size of the motion blur kernel
                 motion_blur_angle=45,  # Angle of the motion blur
                 jitter_prob=0.2,  # Probability of applying random jitter
                 jitter_max=0.1):  # Maximum jitter value (standard deviation)
        super().__init__()
        self.transform = transform
        self.rotation_degrees = rotation_degrees
        self.zoom_range = zoom_range
        self.blur_radius = blur_radius
        self.cutout_prob = cutout_prob
        self.brightness_prob = brightness_prob
        self.brightness_factor = brightness_factor
        self.motion_blur_prob = motion_blur_prob
        self.motion_blur_kernel_size = motion_blur_kernel_size
        self.motion_blur_angle = motion_blur_angle
        self.jitter_prob = jitter_prob
        self.jitter_max = jitter_max
        self.data, _, self.targets = torch.load(data_file)

    #def __len__(self):
    #    return len(self.data) * 5  # Return length * 5 for augmented versions
    def __len__(self):
        return len(self.data) * 10  # Return length * 10 for augmented versions

    def __getitem__(self, index):
        # Ensure the index stays within bounds by using modulo with the original dataset size
        base_index = index % len(self.data)  # This will prevent out-of-bounds errors
        aug_type = index // len(self.data)   # This will determine which augmentation to apply

        # Load the original image and label
        image, label = self.data[base_index].clone(), self.targets[base_index].clone()

        # Apply the augmentation based on the `aug_type`
        if self.transform is not None:
            if aug_type == 1:  # Flipping
                image = image.flip([2])
                label = label.flip([2])
            elif aug_type == 2:  # Rotation
                angle = random.uniform(-self.rotation_degrees[1], self.rotation_degrees[1])
                image = TF.rotate(image, angle)
                label = TF.rotate(label, angle)
            elif aug_type == 3:  # Zooming
                scale_factor = random.uniform(self.zoom_range[0], self.zoom_range[1])
                image = self.zoom(image, scale_factor)
                label = self.zoom(label, scale_factor)
            elif aug_type == 4:  # Gaussian Blur
                radius = (torch.rand(1).item() * (self.blur_radius[1] - self.blur_radius[0])
                          + self.blur_radius[0])
                image = TF.gaussian_blur(image, kernel_size=int(radius))
                # Do not apply blur to the label

            # Apply random cutout with probability
            if random.random() < self.cutout_prob:
                image, label = random_cutout(image, label)

            # Apply random brightness adjustment with probability
            if random.random() < self.brightness_prob:
                image, _ = adjust_brightness(image, label, self.brightness_factor)
                # Note that the label is not being adjusted, only the image

            # Apply motion blur with probability
            if random.random() < self.motion_blur_prob:
                image = motion_blur(image, self.motion_blur_kernel_size, self.motion_blur_angle)

            # Apply random jittering with probability
            if random.random() < self.jitter_prob:
                image = random_jitter(image, self.jitter_max)

        return image, label

    def zoom(self, img, scale_factor):
        # Calculate new dimensions
        _, h, w = img.shape
        new_h, new_w = int(h * scale_factor), int(w * scale_factor)

        # Resize and center-crop back to the original size
        img = TF.resize(img, [new_h, new_w])
        img = TF.center_crop(img, [h, w])
        return img


### Make dataset
dataset = FaceMapDataset(transform="test")

x = dataset[0][0]
dim = x.shape[-1]
print(f"Using {dim} size of images")
N = len(dataset)

# Randomization
indices = np.random.permutation(N)
train_indices = indices[:int(0.6 * N)]
valid_indices = indices[int(0.6 * N):int(0.8 * N)]
test_indices = indices[int(0.8 * N):]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
test_sampler = SubsetRandomSampler(test_indices)


batch_size = 4
# Initialize loss and metrics
loss_fun = torch.nn.MSELoss(reduction="sum")

# Initiliaze input dimensions
num_train = len(train_sampler)
num_valid = len(valid_sampler)
num_test = len(test_sampler)
print(f"Num. train = {num_train}, Num. val = {num_valid}, Num. test = {num_test}")

# Initialize dataloaders
loader_train = DataLoader(
    dataset=dataset,
    drop_last=False,
    num_workers=0,
    batch_size=batch_size,
    pin_memory=True,
    sampler=train_sampler,
)
loader_valid = DataLoader(
    dataset=dataset,
    drop_last=True,
    num_workers=0,
    batch_size=batch_size,
    pin_memory=True,
    sampler=valid_sampler,
)
loader_test = DataLoader(
    dataset=dataset,
    drop_last=True,
    num_workers=0,
    batch_size=1,
    pin_memory=True,
    sampler=test_sampler,
)

nValid = len(loader_valid)
nTrain = len(loader_train)
nTest = len(loader_test)

### hyperparam
lr = 5e-4
num_epochs = 300

model = Unet(nhid=8) # change nhid=2,4,6,8 (8 is original)

model = model.to(device)
nParam = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {nParam / 1e6:.2f} M")

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
minLoss = 1e6
convIter = 0
patience = 100
train_loss = []
valid_loss = []

for epoch in range(num_epochs):
    tr_loss = 0
    for i, (inputs, labels) in enumerate(loader_train):
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores, _ = model(inputs)
        loss = loss_fun((scores), ((labels)))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{nTrain}], Loss: {loss.item():.4f}")
        tr_loss += loss.item()
    train_loss.append(tr_loss / (i + 1))

    with torch.no_grad():
        val_loss = 0
        for i, (inputs, labels) in enumerate(loader_valid):
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores, fmap = model(inputs)
            loss = loss_fun((scores), ((labels)))
            val_loss += loss.item()
        val_loss = val_loss / (i + 1)

        valid_loss.append(val_loss)

        print(f"Val. loss: {val_loss:.4f}")

        labels = labels.squeeze().detach().cpu().numpy()
        scores = scores.squeeze().detach().cpu().numpy()
        img = inputs.squeeze().detach().cpu().numpy()
        fmap = inputs.mean(1).squeeze().detach().cpu().numpy()

        plt.clf()
        plt.figure(figsize=(16, 12))
        for i in range(batch_size):
            plt.subplot(batch_size, 3, 3 * i + 1)
            plt.imshow(labels[i])
            plt.subplot(batch_size, 3, 3 * i + 2)
            plt.imshow(scores[i] * img[i] )
            plt.subplot(batch_size, 3, 3 * i + 3)
            plt.imshow(fmap[i], cmap = "gray")
        plt.tight_layout()
        plt.savefig(f"logs/epoch_{epoch:03d}.jpg")
        plt.close()  # prevent 'fail to allocate bitmap' error at epoch 262
        gc.collect()

        if minLoss > val_loss:
            convEpoch = epoch
            minLoss = val_loss
            convIter = 0
            torch.save(model.state_dict(), "models/best_model.pt")
        else:
            convIter += 1

        if convIter == patience:
            print(f"Converged at epoch {convEpoch + 1} with val. loss {minLoss:.4f}")
            break

plt.clf()
plt.plot(train_loss, label="Training")
plt.plot(valid_loss, label="Valid")
plt.plot(convEpoch, valid_loss[convEpoch], "x", label="Final Model")
plt.legend()
plt.tight_layout()
plt.savefig("loss_curve.pdf")

### Load best model for inference
with torch.no_grad():
    val_loss = 0
    for i, (inputs, labels) in enumerate(loader_test):
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores, fmap = model(inputs)
        loss = loss_fun(scores, labels)
        val_loss += loss.item()

        img = inputs.squeeze().detach().cpu().numpy()
        pred = scores.squeeze().detach().cpu().numpy()
        labels_np = labels.squeeze().cpu().numpy()
        fmap_mean = fmap.mean(1).squeeze().cpu().numpy()
        fmap_each = fmap.squeeze().cpu().numpy()

        # Extract individual normalized feature maps
        fmap_1 = fmap_each[0]
        fmap_2 = fmap_each[1]
        fmap_3 = fmap_each[2]
        fmap_4 = fmap_each[3]
        fmap_5 = fmap_each[4]
        fmap_6 = fmap_each[5]
        fmap_7 = fmap_each[6]
        fmap_8 = fmap_each[7]

        # Plotting code
        plt.clf()
        plt.figure(figsize=(12, 9))

        # Display the main images
        plt.subplot(3, 4, 1)
        plt.imshow(img, cmap="gray")
        plt.title("Input Image")

        plt.subplot(3, 4, 2)
        plt.imshow(labels_np)
        plt.title("Ground Truth")

        plt.subplot(3, 4, 3)
        plt.imshow(pred)
        plt.title("Prediction")

        plt.subplot(3, 4, 4)
        plt.imshow(fmap_mean)
        plt.title("Normalized Feature Map Mean")

        # Display each individual normalized feature map
        plt.subplot(3, 4, 5)
        plt.imshow(fmap_1)
        plt.title("Feature Map 1")

        plt.subplot(3, 4, 6)
        plt.imshow(fmap_2)
        plt.title("Feature Map 2")

        plt.subplot(3, 4, 7)
        plt.imshow(fmap_3)
        plt.title("Feature Map 3")

        plt.subplot(3, 4, 8)
        plt.imshow(fmap_4)
        plt.title("Feature Map 4")

        plt.subplot(3, 4, 9)
        plt.imshow(fmap_5)
        plt.title("Feature Map 5")

        plt.subplot(3, 4, 10)
        plt.imshow(fmap_6)
        plt.title("Feature Map 6")

        plt.subplot(3, 4, 11)
        plt.imshow(fmap_7)
        plt.title("Feature Map 7")

        plt.subplot(3, 4, 12)
        plt.imshow(fmap_8)
        plt.title("Feature Map 8")

        plt.tight_layout()
        plt.savefig(f"preds/test_{i:03d}.jpg")
        plt.close()
        gc.collect()

    val_loss = val_loss / (i + 1)
    print(f"Test loss: {val_loss:.4f}")


  def __init__(self, data_file="data\dolensek_facemap_softlabels_224_WITH_KP_REGIONS.pt",


Using 224 size of images
Num. train = 1722, Num. val = 574, Num. test = 574
Number of parameters: 0.03 M
Epoch [1/300], Step [1/431], Loss: 65291.2656
Epoch [1/300], Step [2/431], Loss: 64794.7109
Epoch [1/300], Step [3/431], Loss: 64672.9766
Epoch [1/300], Step [4/431], Loss: 64634.2891
Epoch [1/300], Step [5/431], Loss: 51295.1328
Epoch [1/300], Step [6/431], Loss: 64608.3828
Epoch [1/300], Step [7/431], Loss: 64379.6367
Epoch [1/300], Step [8/431], Loss: 64525.8359
Epoch [1/300], Step [9/431], Loss: 60864.6328
Epoch [1/300], Step [10/431], Loss: 61569.9688
Epoch [1/300], Step [11/431], Loss: 64539.2422
Epoch [1/300], Step [12/431], Loss: 65527.0547
Epoch [1/300], Step [13/431], Loss: 57806.1094
Epoch [1/300], Step [14/431], Loss: 64470.6641
Epoch [1/300], Step [15/431], Loss: 64355.1523
Epoch [1/300], Step [16/431], Loss: 55663.0117
Epoch [1/300], Step [17/431], Loss: 64426.1953
Epoch [1/300], Step [18/431], Loss: 64359.5742
Epoch [1/300], Step [19/431], Loss: 58902.4922
Epoch [1/30

<Figure size 640x480 with 0 Axes>