In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import gc
import pdb

import gc
import glob
import os
import pdb
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from numpy import random
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler, TensorDataset
from torchvision.utils import save_image

from models import Unet


def random_cutout(image, label, max_h=50, max_w=50):
    """
    Apply random cutout to both image and label.
    Args:
        image: The input image tensor.
        label: The label tensor.
        max_h: Maximum height of the cutout box.
        max_w: Maximum width of the cutout box.
    Returns:
        image: Image after cutout.
        label: Label after cutout.
    """
    _, h, w = image.shape
    cutout_height = random.randint(10, max_h)
    cutout_width = random.randint(10, max_w)

    # Randomly choose the position for the cutout
    top = random.randint(0, h - cutout_height)
    left = random.randint(0, w - cutout_width)

    # Apply the cutout to the image and label (set to 0)
    image[:, top:top + cutout_height, left:left + cutout_width] = 0
    label[:, top:top + cutout_height, left:left + cutout_width] = 0

    return image, label


def adjust_brightness(image, label, brightness_factor=0.2):
    """
    Adjust the brightness of the image and label.
    Args:
        image: The input image tensor.
        label: The label tensor.
        brightness_factor: Factor by which brightness is adjusted.
    Returns:
        image: Image after brightness adjustment.
        label: Label after brightness adjustment.
    """
    image = TF.adjust_brightness(image, 1 + (random.random() * 2 - 1) * brightness_factor)
    # Note: Brightness doesn't affect the label, so we leave it unchanged
    return image, label

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF

import torch
import torchvision.transforms.functional as TF

def random_jitter(image, max_jitter=0.1):
    """
    Apply random jittering to an image by adding noise to its pixel values.
    Args:
        image: The input image tensor.
        max_jitter: The maximum amount of jitter to apply to each pixel.
    Returns:
        image: The image with jitter applied.
    """
    # Generate random noise with a normal distribution
    noise = torch.randn_like(image) * max_jitter  # Gaussian noise
    image = image + noise

    # Clip the values to be in the valid range [0, 1] for images
    image = torch.clamp(image, 0, 1)

    return image


def motion_blur(image, kernel_size=5, angle=45):
    """
    Apply motion blur to an image using a convolution with a motion blur kernel.
    Args:
        image: The input image tensor.
        kernel_size: The size of the blur kernel.
        angle: The angle of the motion.
    Returns:
        image: The image after motion blur.
    """
    # Create motion blur kernel
    kernel = torch.zeros((kernel_size, kernel_size))

    # Define the direction of the blur (this could be any angle, here we use horizontal motion)
    center = kernel_size // 2
    angle_rad = torch.tensor(angle * torch.pi / 180)  # Convert to radians

    # Apply a simple horizontal motion blur
    for i in range(kernel_size):
        kernel[center, i] = 1

    # Normalize the kernel
    kernel = kernel / kernel.sum()

    # Reshape kernel for convolution (batch size, channels, kernel size)
    kernel = kernel.unsqueeze(0).unsqueeze(0)  # Shape (1, 1, kernel_size, kernel_size)

    # Apply the kernel using convolution
    blurred_image = F.conv2d(image.unsqueeze(0), kernel, padding=kernel_size//2)
    return blurred_image.squeeze(0)

class FaceMapDataset(Dataset):
    def __init__(self, data_file="data/dolensek_facemap_softlabels_224.pt", 
                 transform=None, 
                 rotation_degrees=(15, 30),  # Rotation angle range from 15 to 30 degrees
                 zoom_range=(0.8, 1.5),  # Zoom range from 0.8 (zoom out) to 1.5 (zoom in)
                 blur_radius=(1, 2),  # Tuple for Gaussian blur radius range
                 cutout_prob=0.2,  # Probability of applying cutout
                 brightness_prob=0.2,  # Probability of applying brightness adjustment
                 brightness_factor=0.5,  # Max factor for brightness adjustment
                 motion_blur_prob=0.2,  # Probability of applying motion blur
                 motion_blur_kernel_size=5,  # Size of the motion blur kernel
                 motion_blur_angle=45,  # Angle of the motion blur
                 jitter_prob=0.2,  # Probability of applying random jitter
                 jitter_max=0.1):  # Maximum jitter value (standard deviation)
        super().__init__()
        self.transform = transform
        self.rotation_degrees = rotation_degrees
        self.zoom_range = zoom_range
        self.blur_radius = blur_radius
        self.cutout_prob = cutout_prob
        self.brightness_prob = brightness_prob
        self.brightness_factor = brightness_factor
        self.motion_blur_prob = motion_blur_prob
        self.motion_blur_kernel_size = motion_blur_kernel_size
        self.motion_blur_angle = motion_blur_angle
        self.jitter_prob = jitter_prob
        self.jitter_max = jitter_max
        self.data, _, self.targets = torch.load(data_file)

    #def __len__(self):
    #    return len(self.data) * 5  # Return length * 5 for augmented versions
    def __len__(self):
        return len(self.data) * 10  # Return length * 10 for augmented versions

    def __getitem__(self, index):
        # Ensure the index stays within bounds by using modulo with the original dataset size
        base_index = index % len(self.data)  # This will prevent out-of-bounds errors
        aug_type = index // len(self.data)   # This will determine which augmentation to apply

        # Load the original image and label
        image, label = self.data[base_index].clone(), self.targets[base_index].clone()

        # Apply the augmentation based on the `aug_type`
        if self.transform is not None:
            if aug_type == 1:  # Flipping
                image = image.flip([2])
                label = label.flip([2])
            elif aug_type == 2:  # Rotation
                angle = random.uniform(-self.rotation_degrees[1], self.rotation_degrees[1])
                image = TF.rotate(image, angle)
                label = TF.rotate(label, angle)
            elif aug_type == 3:  # Zooming
                scale_factor = random.uniform(self.zoom_range[0], self.zoom_range[1])
                image = self.zoom(image, scale_factor)
                label = self.zoom(label, scale_factor)
            elif aug_type == 4:  # Gaussian Blur
                radius = (torch.rand(1).item() * (self.blur_radius[1] - self.blur_radius[0])
                          + self.blur_radius[0])
                image = TF.gaussian_blur(image, kernel_size=int(radius))
                # Do not apply blur to the label

            # Apply random cutout with probability
            if random.random() < self.cutout_prob:
                image, label = random_cutout(image, label)

            # Apply random brightness adjustment with probability
            if random.random() < self.brightness_prob:
                image, _ = adjust_brightness(image, label, self.brightness_factor)
                # Note that the label is not being adjusted, only the image

            # Apply motion blur with probability
            if random.random() < self.motion_blur_prob:
                image = motion_blur(image, self.motion_blur_kernel_size, self.motion_blur_angle)

            # Apply random jittering with probability
            if random.random() < self.jitter_prob:
                image = random_jitter(image, self.jitter_max)

        return image, label

    def zoom(self, img, scale_factor):
        # Calculate new dimensions
        _, h, w = img.shape
        new_h, new_w = int(h * scale_factor), int(w * scale_factor)

        # Resize and center-crop back to the original size
        img = TF.resize(img, [new_h, new_w])
        img = TF.center_crop(img, [h, w])
        return img


# Assuming these are your Unet and convBlock definitions from before
class convBlock(nn.Module):
    def __init__(self, inCh, nhid, nOp, twod=True, pool=True,
                    ker=3, padding=1, pooling=2):
        super(convBlock, self).__init__()

        if twod:
            self.enc1 = nn.Conv2d(inCh, nhid, kernel_size=ker, padding=1)
            self.enc2 = nn.Conv2d(nhid, nOp, kernel_size=ker, padding=1)
            self.bn = nn.BatchNorm2d(inCh)

            if pool:
                self.scale = nn.AvgPool2d(kernel_size=pooling)
            else:
                self.scale = nn.Upsample(scale_factor=pooling)
        else:
            self.enc1 = nn.Conv3d(inCh, nhid, kernel_size=ker, padding=1)
            self.enc2 = nn.Conv3d(nhid, nOp, kernel_size=ker, padding=1)
            self.bn = nn.BatchNorm3d(inCh)

            if pool:
                self.scale = nn.AvgPool3d(kernel_size=pooling)
            else:
                self.scale = nn.Upsample(scale_factor=pooling)

        self.pool = pool
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.scale(x)
        x = self.bn(x)
        x = self.act(self.enc1(x))
        x = self.act(self.enc2(x))
        return x


class Unet(nn.Module):
    def __init__(self, nhid=8, ker=3, inCh=1, h=224, w=224):
        super(Unet, self).__init__()
        ### U-net Encoder with 3 downsampling layers
        self.uEnc11 = nn.Conv2d(inCh, nhid, kernel_size=ker, padding=1)
        self.uEnc12 = nn.Conv2d(nhid, nhid, kernel_size=ker, padding=1)

        self.uEnc2 = convBlock(nhid, 2 * nhid, 2 * nhid, pool=True)
        self.uEnc3 = convBlock(2 * nhid, 4 * nhid, 4 * nhid, pool=True)
        self.uEnc4 = convBlock(4 * nhid, 8 * nhid, 8 * nhid, pool=True)  # unhashed
        self.uEnc5 = convBlock(8 * nhid, 16 * nhid, 16 * nhid, pool=True)  # unhashed

        ### U-net decoder 
        self.dec5 = convBlock(16 * nhid, 8 * nhid, 8 * nhid, pool=False)  # unhashed
        self.dec4 = convBlock(16 * nhid, 4 * nhid, 4 * nhid, pool=False)  # unhashed
        self.dec3 = convBlock(4 * nhid, 2 * nhid, 2 * nhid, pool=False, pooling=2)
        self.dec2 = convBlock(4 * nhid, nhid, nhid, pool=False, pooling=2)

        self.dec11 = nn.Conv2d(2 * nhid, nhid, kernel_size=ker, padding=1)
        self.dec12 = nn.Conv2d(nhid, 1, kernel_size=ker, padding=1)

        self.act = nn.ReLU()

        self.h = h
        self.w = w

    def encoder(self, x_in):
        ### Unet Encoder
        x = []
        x.append(self.act(self.uEnc12(self.act(self.uEnc11(x_in)))))
        x.append(self.uEnc2(x[-1]))
        x.append(self.uEnc3(x[-1]))
        return x

    def decoder(self, x_enc):
        x = self.dec3(x_enc[-1])
        x = torch.cat((x, x_enc[-2]), dim=1)
        x = self.dec2(x)
        x = torch.cat((x, x_enc[-3]), dim=1)
        fmap = self.act(self.dec11(x))
        x = self.dec12(fmap)
        return x, fmap

    def forward(self, x):
        # Unet encoder result
        x_enc = self.encoder(x)
        # Outputs for MSE
        xHat, fmap = self.decoder(x_enc)
        return xHat, fmap


# Ensemble Model
class UNetEnsemble(nn.Module):
    def __init__(self, num_models=5, nhid=8, ker=3, inCh=1, h=224, w=224):
        super(UNetEnsemble, self).__init__()
        self.models = nn.ModuleList([Unet(nhid=nhid, ker=ker, inCh=inCh, h=h, w=w).to(device) for _ in range(num_models)])

    def forward(self, x):
        predictions = []
        feature_maps = []

        # Forward pass through all models in the ensemble
        for model in self.models:
            pred, fmap = model(x)  # Assuming each model returns (prediction, feature_maps)
            predictions.append(pred)
            feature_maps.append(fmap)

        # Average the predictions from all models
        ensemble_prediction = torch.mean(torch.stack(predictions), dim=0)

        return ensemble_prediction, feature_maps



### Make dataset
dataset = FaceMapDataset(transform="test")

x = dataset[0][0]
dim = x.shape[-1]
print(f"Using {dim} size of images")
N = len(dataset)

# Randomization
indices = np.random.permutation(N)
train_indices = indices[:int(0.6 * N)]
valid_indices = indices[int(0.6 * N):int(0.8 * N)]
test_indices = indices[int(0.8 * N):]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
test_sampler = SubsetRandomSampler(test_indices)


batch_size = 4
# Initialize loss and metrics
loss_fun = torch.nn.MSELoss(reduction="sum")

# Initiliaze input dimensions
num_train = len(train_sampler)
num_valid = len(valid_sampler)
num_test = len(test_sampler)
print(f"Num. train = {num_train}, Num. val = {num_valid}, Num. test = {num_test}")

# Initialize dataloaders
loader_train = DataLoader(
    dataset=dataset,
    drop_last=False,
    num_workers=0,
    batch_size=batch_size,
    pin_memory=True,
    sampler=train_sampler,
)
loader_valid = DataLoader(
    dataset=dataset,
    drop_last=True,
    num_workers=0,
    batch_size=batch_size,
    pin_memory=True,
    sampler=valid_sampler,
)
loader_test = DataLoader(
    dataset=dataset,
    drop_last=True,
    num_workers=0,
    batch_size=1,
    pin_memory=True,
    sampler=test_sampler,
)

nValid = len(loader_valid)
nTrain = len(loader_train)
nTest = len(loader_test)

# Training Loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ensemble_model = UNetEnsemble(num_models=5, nhid=8, ker=3, inCh=1, h=224, w=224).to(device)
optimizer = torch.optim.Adam(ensemble_model.parameters(), lr=0.001)
loss_fun = nn.MSELoss()

# Assuming `loader_train` is the DataLoader for training
num_epochs = 10  # Set the number of epochs
for epoch in range(num_epochs):
    ensemble_model.train()
    tr_loss = 0
    for i, (inputs, labels) in enumerate(loader_train):  # Assuming loader_train is your DataLoader
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass through the ensemble model
        ensemble_prediction, feature_maps = ensemble_model(inputs)

        # Compute loss
        loss = loss_fun(ensemble_prediction, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {tr_loss / (i + 1):.4f}")


# Inference and Feature Map Visualization
ensemble_model.eval()
with torch.no_grad():
    val_loss = 0
    for i, (inputs, labels) in enumerate(loader_test):  # Assuming loader_test is your test DataLoader
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass through the ensemble model
        ensemble_prediction, feature_maps = ensemble_model(inputs)

        # Compute loss
        loss = loss_fun(ensemble_prediction, labels)
        val_loss += loss.item()

        # Convert outputs to numpy for plotting
        img = inputs.squeeze().detach().cpu().numpy()
        pred = ensemble_prediction.squeeze().detach().cpu().numpy()
        labels_np = labels.squeeze().cpu().numpy()

        # Visualize feature maps from all models in the ensemble
        for j, fmap in enumerate(feature_maps):
            fmap_mean = fmap.mean(1).squeeze().cpu().numpy()  # Take the mean over the channels (axis=1)

            # Plotting
            plt.clf()
            plt.figure(figsize=(12, 9))

            plt.subplot(3, 4, 1)
            plt.imshow(img, cmap="gray")
            plt.title("Input Image")

            plt.subplot(3, 4, 2)
            plt.imshow(labels_np)
            plt.title("Ground Truth")

            plt.subplot(3, 4, 3)
            plt.imshow(pred)
            plt.title("Ensemble Prediction")

            plt.subplot(3, 4, 4)
            plt.imshow(fmap_mean)
            plt.title(f"Feature Map Mean (Model {j+1})")

            plt.tight_layout()
            plt.savefig(f"preds/test_{i:03d}_model_{j+1}.jpg")
            plt.close()

        gc.collect()

    val_loss = val_loss / (i + 1)
    print(f"Test loss: {val_loss:.4f}")


Using 224 size of images
Num. train = 1722, Num. val = 574, Num. test = 574
Epoch [1/10], Train Loss: 0.0000
Epoch [2/10], Train Loss: 0.0000
Epoch [3/10], Train Loss: 0.0000
Epoch [4/10], Train Loss: 0.0000
Epoch [5/10], Train Loss: 0.0000
Epoch [6/10], Train Loss: 0.0000
Epoch [7/10], Train Loss: 0.0000
Epoch [8/10], Train Loss: 0.0000
Epoch [9/10], Train Loss: 0.0000
Epoch [10/10], Train Loss: 0.0000
Test loss: 0.0000


<Figure size 640x480 with 0 Axes>