# Imports

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np

plt.rcParams['savefig.bbox'] = 'tight'

import albumentations as A
import graphviz
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageDraw
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.utils import make_grid
from tqdm import tqdm

%pip install torchview
from torchview import draw_graph

graphviz.set_jupyter_format('png');

In [None]:
# Seed every random number generator used in this notebook for reproducibility
seed = 0
os.environ['PYTHONHASHSEED'] = str(seed)
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)
# Ask cuDNN to pick deterministic algorithms
torch.backends.cudnn.deterministic = True

# U-Net Model Architecture

Here is some useful information about the U-Net model.
* [U-Net: Convolutional Networks for Biomedical Image Segmentation, Ronneberger et al., 2015](https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28)
* [3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation](https://link.springer.com/chapter/10.1007/978-3-319-46723-8_49)
* [The U-Net (actually) explained in 10 minutes](https://www.youtube.com/watch?v=NhdzGfB1q74)

In [None]:
# Code inspired by: https://towardsdatascience.com/cook-your-first-u-net-in-pytorch-b3297a844cf3


class UNet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Every stage is a pair of 'same'-padded convolutions, each followed by a
    ReLU. The encoder halves the resolution with 2x2 max pooling after each
    stage; the decoder doubles it again with stride-2 transposed convolutions
    and concatenates the matching encoder output (skip connection) on the
    channel axis. Because the resolution is halved four times, the input
    height and width must be divisible by 16 for the skip connections to
    line up.

    Args:
        n_in: Number of input channels.
        n_out: Number of output channels.
        n_feat: Number of feature channels in the first stage; it doubles at
            every deeper level (up to ``16 * n_feat`` in the bottleneck).
        kernel_size: Kernel size of the encoder and decoder convolutions.
            The bottleneck always uses 3x3 kernels.

    The forward pass returns raw logits (no sigmoid is applied).
    """

    def __init__(self, n_in=3, n_out=1, n_feat=8, kernel_size=3):
        super().__init__()

        # Channel widths per level: n_feat, 2x, 4x, 8x, and 16x at the bottleneck
        w1, w2, w4, w8, w16 = (n_feat * factor for factor in (1, 2, 4, 8, 16))

        # Encoder
        self.encoder_1 = self._double_conv(n_in, w1, kernel_size)
        self.downsample_1 = self._halve()

        self.encoder_2 = self._double_conv(w1, w2, kernel_size)
        self.downsample_2 = self._halve()

        self.encoder_3 = self._double_conv(w2, w4, kernel_size)
        self.downsample_3 = self._halve()

        self.encoder_4 = self._double_conv(w4, w8, kernel_size)
        self.downsample_4 = self._halve()

        # The bottleneck kernel size is fixed at 3, independently of `kernel_size`
        self.bottleneck = self._double_conv(w8, w16, 3)

        # Decoder: each decoder stage receives twice the upsampled channel
        # count because the skip connection is concatenated before it.
        self.upsample_1 = self._double_size(w16, w8)
        self.decoder_1 = self._double_conv(w16, w8, kernel_size)

        self.upsample_2 = self._double_size(w8, w4)
        self.decoder_2 = self._double_conv(w8, w4, kernel_size)

        self.upsample_3 = self._double_size(w4, w2)
        self.decoder_3 = self._double_conv(w4, w2, kernel_size)

        self.upsample_4 = self._double_size(w2, w1)
        self.decoder_4 = self._double_conv(w2, w1, kernel_size)

        # Output layer: 1x1 convolution mapping the features to n_out channels
        self.output = nn.Conv2d(w1, n_out, stride=1, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels, kernel_size):
        """Build Conv-ReLU-Conv-ReLU; the convolutions sit at indices 0 and 2."""
        layers = []
        for conv_in in (in_channels, out_channels):
            layers.append(
                nn.Conv2d(
                    in_channels=conv_in,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding='same',
                )
            )
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    @staticmethod
    def _halve():
        """Build the 2x2, stride-2 max pooling used between encoder stages."""
        return nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _double_size(in_channels, out_channels):
        """Build the transposed convolution that doubles height and width."""
        return nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections
        skip_1 = self.encoder_1(x)
        skip_2 = self.encoder_2(self.downsample_1(skip_1))
        skip_3 = self.encoder_3(self.downsample_2(skip_2))
        skip_4 = self.encoder_4(self.downsample_3(skip_3))

        features = self.bottleneck(self.downsample_4(skip_4))

        # Decoder: upsample, concatenate the skip on the channel axis, convolve
        features = self.decoder_1(torch.cat([skip_4, self.upsample_1(features)], dim=1))
        features = self.decoder_2(torch.cat([skip_3, self.upsample_2(features)], dim=1))
        features = self.decoder_3(torch.cat([skip_2, self.upsample_3(features)], dim=1))
        features = self.decoder_4(torch.cat([skip_1, self.upsample_4(features)], dim=1))

        # Output layer
        return self.output(features)

In [None]:
model = UNet()
# Trace on the GPU when there is one, otherwise on the CPU, so that this cell
# also runs on machines without CUDA.
graph_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_graph = draw_graph(model, input_size=(1, 3, 128, 128), device=graph_device, graph_dir='TB')
model_graph.resize_graph(1.5)
model_graph.visual_graph

# Data, Dataset and DataLoader

## Download and unzip data

In [None]:
if os.getenv("COLAB_RELEASE_TAG"):
   !gdown 1uzaiglG7aOJae9EFy6Q0SSnWaNBoLNt5
else:
   !wget -O shapes_dataset.zip https://drive.usercontent.google.com/download?id=1uzaiglG7aOJae9EFy6Q0SSnWaNBoLNt5&export=download

In [None]:
# Extract the downloaded archive into the working directory (-q: quiet)
!unzip -q shapes_dataset.zip

In [None]:
# Delete any extracted entry whose name contains __MACOSX
# (presumably macOS archive metadata that is not part of the dataset)
!rm -rf *__MACOSX*

## Custom Dataset for Images and Masks

In [None]:
class ImagePairPreloadedDataset(Dataset):
    """Dataset over image/mask pairs that are already loaded in memory.

    Args:
        images_np: Indexable collection of image arrays.
        masks_np: Indexable collection of mask arrays, aligned with
            ``images_np`` by index.
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a mapping with
            the keys ``'image'`` and ``'mask'``.
    """

    def __init__(self, images_np, masks_np, transforms=None):
        self.images_np = images_np
        self.masks_np = masks_np
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index`` with mask value 255 mapped to 1."""
        x = self.images_np[index]
        # Work on a copy: rewriting the stored array in place would silently
        # change the preloaded masks for every other user of `masks_np`.
        y = np.array(self.masks_np[index])
        y[y == 255] = 1.0
        if self.transforms is not None:
            transformed = self.transforms(image=x, mask=y)
            x = transformed['image']
            y = transformed['mask']
        return x, y

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images_np)

## Training Data

### Preload data for faster processing

In [None]:
train_image_path = './shapes_dataset/train/images'
train_mask_path = './shapes_dataset/train/masks'

# Keep only the PNG files, sorted so that images and masks line up by index
train_list = os.listdir(train_image_path)
train_list_chosen = sorted(file_name for file_name in train_list if file_name.endswith('png'))

train_image_list_paths = []
train_mask_list_paths = []
for file_name in train_list_chosen:
    train_image_list_paths.append(os.path.join(train_image_path, file_name))
    # A mask carries its image's file name with 'image' replaced by 'mask'
    train_mask_list_paths.append(os.path.join(train_mask_path, file_name.replace('image', 'mask')))

# read_image output is transposed from (C, H, W) to (H, W, C) and stored as float32
train_images_np = []
for image_file in train_image_list_paths:
    image_chw = read_image(image_file).numpy()
    train_images_np.append(image_chw.transpose((1, 2, 0)).astype(np.float32))

train_masks_np = []
for mask_file in train_mask_list_paths:
    mask_chw = read_image(mask_file).numpy()
    train_masks_np.append(mask_chw.transpose((1, 2, 0)).astype(np.float32))

print('Total training images:', len(train_images_np))
print('Total training masks:', len(train_masks_np))

### Select data augmentations

In [None]:
# Random augmentations come first, the conversion steps last
train_augmentations = [
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.RandomRotate90(),
    A.ChannelShuffle(),
]
train_conversions = [
    A.ToFloat(max_value=255),
    ToTensorV2(transpose_mask=True),
]
train_transform = A.Compose(train_augmentations + train_conversions)

train_seg_dataset = ImagePairPreloadedDataset(
    images_np=train_images_np,
    masks_np=train_masks_np,
    transforms=train_transform,
)
# Training batches are reshuffled every epoch
train_seg_dataloader = DataLoader(
    train_seg_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

In [None]:
fig, axs = plt.subplots(4, 4, figsize=(10, 10))
fig.subplots_adjust(hspace=0.1, wspace=0.05)

# One row per sample: image, augmented image, mask, augmented mask.
# Column titles are shown on the first row only.
for row in range(4):
    transformed_images = train_transform(image=train_images_np[row], mask=train_masks_np[row])

    panels = (
        ('Image', train_images_np[row] / 255.0, {}),
        ('Image Augmented', transformed_images['image'].numpy().transpose((1, 2, 0)), {}),
        ('Mask', train_masks_np[row] / 255.0, {'cmap': 'gray'}),
        ('Mask Augmented', transformed_images['mask'].numpy().transpose((1, 2, 0)), {'cmap': 'gray'}),
    )
    for ax, (title, picture, imshow_kwargs) in zip(axs[row], panels):
        ax.imshow(picture, **imshow_kwargs)
        ax.set_title(title if row == 0 else '')
        ax.set_xticks([])
        ax.set_yticks([])

## Validation Data

In [None]:
val_image_path = './shapes_dataset/val/images'
val_mask_path = './shapes_dataset/val/masks'

# Keep only the PNG files, sorted so that images and masks line up by index
val_list = os.listdir(val_image_path)
val_list_chosen = sorted(file_name for file_name in val_list if file_name.endswith('png'))

val_image_list_paths = []
val_mask_list_paths = []
for file_name in val_list_chosen:
    val_image_list_paths.append(os.path.join(val_image_path, file_name))
    # A mask carries its image's file name with 'image' replaced by 'mask'
    val_mask_list_paths.append(os.path.join(val_mask_path, file_name.replace('image', 'mask')))

# read_image output is transposed from (C, H, W) to (H, W, C) and stored as float32
val_images_np = []
for image_file in val_image_list_paths:
    image_chw = read_image(image_file).numpy()
    val_images_np.append(image_chw.transpose((1, 2, 0)).astype(np.float32))

val_masks_np = []
for mask_file in val_mask_list_paths:
    mask_chw = read_image(mask_file).numpy()
    val_masks_np.append(mask_chw.transpose((1, 2, 0)).astype(np.float32))

print('Total validation images:', len(val_images_np))
print('Total validation masks:', len(val_masks_np))

# Validation uses the conversion steps only, with no random augmentation
val_conversions = [
    A.ToFloat(max_value=255),
    ToTensorV2(transpose_mask=True),
]
val_transform = A.Compose(val_conversions)

val_seg_dataset = ImagePairPreloadedDataset(
    images_np=val_images_np,
    masks_np=val_masks_np,
    transforms=val_transform,
)
val_seg_dataloader = DataLoader(
    val_seg_dataset,
    batch_size=64,
    num_workers=2,
    pin_memory=True,
)

# Loss function

## Binary Cross Entropy Loss
<img src='https://miro.medium.com/v2/resize:fit:4800/format:webp/1*QohPFy6wBfbjK5VUWxlyoA.png' alt='Binary Cross Entropy Loss' width='400'/>

## Dice Score Coefficient
<img src='https://miro.medium.com/v2/resize:fit:720/format:webp/1*tSqwQ9tvLmeO9raDqg3i-w.png' alt='Dice Score Coefficient' width='400'/>

In [None]:
class BCEDiceLoss(nn.Module):
    """Weighted sum of binary cross-entropy and soft Dice loss on logits.

    Args:
        eps: Constant added to the numerator and denominator of the Dice
            score so that it stays defined when both masks are empty.
        weight_dice: Multiplier of the Dice loss term.
        weight_bce: Multiplier of the binary cross-entropy term.
    """

    def __init__(self, eps=1e-9, weight_dice=1.0, weight_bce=1.0):
        super().__init__()
        self.eps = eps
        self.weight_dice = weight_dice
        self.weight_bce = weight_bce
        # Keep per-pixel values; the averaging is done explicitly in forward()
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, y_pred, y_true):
        """Return ``(weighted total, Dice loss, BCE loss)`` for a batch.

        ``y_pred`` holds raw logits and ``y_true`` the target mask; both are
        reduced over their last two dimensions first and then averaged.
        """
        spatial_dims = (-1, -2)

        # Binary cross-entropy: mean per image, then mean over the batch
        per_pixel_bce = self.bce_loss(y_pred, y_true)
        bce = per_pixel_bce.mean(dim=spatial_dims).mean()

        # The sigmoid turns the logits into probabilities in [0, 1]
        probabilities = torch.sigmoid(y_pred)

        # Soft intersection |Y_pred ∩ Y_true| per image
        overlap_per_image = (probabilities * y_true).sum(dim=spatial_dims)

        # Denominator |Y_pred| + |Y_true| per image
        total_per_image = probabilities.sum(dim=spatial_dims) + y_true.sum(dim=spatial_dims)

        # Dice score per image, converted to a loss and averaged over the batch
        dice_score_per_image = (2 * overlap_per_image + self.eps) / (total_per_image + self.eps)
        dice = (1 - dice_score_per_image).mean()

        return bce * self.weight_bce + dice * self.weight_dice, dice, bce

# Training loop

In [None]:
def plot_batch(image, mask, output, save_path):
    """Plot a batch of inputs, real masks and predicted masks, and save the figure.

    Each batch is tiled into a grid with 16 images per row and drawn in its
    own subplot (top to bottom: input, real mask, predicted mask). The
    figure is written to ``save_path`` as a PNG and then shown.

    NOTE(review): callers in this file pass CPU tensors — confirm before
    calling it with tensors on another device.
    """
    fig, axs = plt.subplots(3, 1, figsize=(16, 8))
    fig.subplots_adjust(hspace=0.2, wspace=0.075)

    # (title, batch, grid padding value) for each of the three rows
    panels = (
        ('Input', image, 0),
        ('Real Mask', mask, 1),
        ('Predicted Mask', output, 1),
    )
    for ax, (title, batch, pad_value) in zip(axs.ravel(), panels):
        grid = make_grid(batch, nrow=16, pad_value=pad_value)
        # The grid is channel-first; imshow expects the channels last
        ax.imshow(np.transpose(grid, (1, 2, 0)))
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.05, dpi=100, format='png')
    plt.show()

In [None]:
# Train the U-Net for `epochs` epochs; after each epoch run validation, step the
# learning-rate scheduler, save a checkpoint and plot the last validation batch.
epochs = 12
width = 16  # passed to UNet as n_feat

# Checkpoints and sample plots of this run are written to this directory
checkpoint_dir = f'checkpoints_seg_width_{width}_epochs_{epochs}'
os.makedirs(checkpoint_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(n_feat=width, kernel_size=3)
model.to(device)

criterion = BCEDiceLoss(weight_dice=1.0, weight_bce=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# The scheduler is stepped with the mean validation loss at the end of each epoch
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    train_running_loss = 0.0

    train_progress_bar = tqdm(train_seg_dataloader, desc=f'Epoch {epoch+1}/{epochs}', unit='batch')

    for x, y in train_progress_bar:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        # The model returns logits; the criterion applies the sigmoid itself
        out = model(x)

        # criterion returns (weighted total, Dice loss, BCE loss)
        loss, dice, bce = criterion(out, y)
        loss.backward()
        optimizer.step()

        # Sum the per-batch losses; divided by the number of batches below
        train_running_loss += loss.item()
        train_progress_bar.set_postfix(loss=loss.item(), dice_loss=dice.item(), bce_loss=bce.item())

    # ---- Validation pass ----
    val_running_loss = 0.0
    model.eval()

    val_progress_bar = tqdm(val_seg_dataloader, desc=f'Epoch {epoch+1}/{epochs}', unit='batch')

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vx, vy in val_progress_bar:
            vx = vx.to(device)
            vy = vy.to(device)

            vout = model(vx)

            vloss, vdice, vbce = criterion(vout, vy)
            val_running_loss += vloss.item()
            val_progress_bar.set_postfix(loss=vloss.item(), dice_loss=vdice.item(), bce_loss=vbce.item())

    # Report the mean loss per batch for both passes
    print(
        '\nDice+BCE Loss train: {}, validation: {}'.format(
            train_running_loss / len(train_seg_dataloader),
            val_running_loss / len(val_seg_dataloader),
        )
    )

    scheduler.step(val_running_loss / len(val_seg_dataloader))

    # The checkpoint holds the model, optimizer and scheduler objects
    # themselves (not their state_dicts), so loading it requires the same
    # class definitions to be available.
    checkpoint = {
        'epoch': epoch,
        'model': model,
        'optimizer': optimizer,
        'lr_sched': scheduler,
    }

    # NOTE(review): with epochs = 12 this condition is true for every epoch; the
    # 45/15 thresholds look like leftovers from a longer schedule — confirm.
    if epoch + 1 > 45 or epoch < 15:
        torch.save(checkpoint, os.path.join(checkpoint_dir, f'checkpoint_{epoch+1}.pth'))

    # Plot the last validation batch of this epoch (vx, vy and vout still hold
    # it). Note that sample files are numbered from 0, checkpoints from 1.
    plot_batch(
        vx.detach().cpu(),
        vy.detach().cpu(),
        torch.sigmoid(vout.detach().cpu()),
        os.path.join(checkpoint_dir, f'sample_{epoch}.png'),
    )

# Test with custom image

In [None]:
def draw_circle(draw_image, centre, radius, colour):
    """Draw a filled circle of ``radius`` centred on ``centre`` (x, y)."""
    centre_x, centre_y = centre[0], centre[1]
    # Bounding box as (left, top, right, bottom)
    bounding_box = (
        centre_x - radius,
        centre_y - radius,
        centre_x + radius,
        centre_y + radius,
    )
    draw_image.ellipse(bounding_box, fill=colour)


def draw_ellipse(draw_image, centre, radius_1, radius_2, colour):
    """Draw a filled axis-aligned ellipse centred on ``centre`` (x, y).

    ``radius_1`` is the horizontal and ``radius_2`` the vertical half-axis.
    """
    centre_x, centre_y = centre[0], centre[1]
    # Bounding box as (left, top, right, bottom)
    bounding_box = (
        centre_x - radius_1,
        centre_y - radius_2,
        centre_x + radius_1,
        centre_y + radius_2,
    )
    draw_image.ellipse(bounding_box, fill=colour)


def draw_rectangle(draw_image, top_left, bottom_right, colour):
    """Draw a filled rectangle spanned by the ``top_left`` and ``bottom_right`` corners."""
    corners = [top_left, bottom_right]
    draw_image.rectangle(corners, fill=colour)


def draw_square(draw_image, centre, radius, rotation, colour):
    """Draw a filled square inscribed in the circle given by ``centre`` and ``radius``."""
    bounding_circle = (centre, radius)
    draw_image.regular_polygon(bounding_circle, n_sides=4, rotation=rotation, fill=colour)


def draw_pentagon(draw_image, centre, radius, rotation, colour):
    """Draw a filled regular pentagon inscribed in the circle given by ``centre`` and ``radius``."""
    bounding_circle = (centre, radius)
    draw_image.regular_polygon(bounding_circle, n_sides=5, rotation=rotation, fill=colour)


def draw_hexagon(draw_image, centre, radius, rotation, colour):
    """Draw a filled regular hexagon inscribed in the circle given by ``centre`` and ``radius``."""
    bounding_circle = (centre, radius)
    draw_image.regular_polygon(bounding_circle, n_sides=6, rotation=rotation, fill=colour)


def draw_triangle(draw_image, draw_mask, centre, radius, rotation, colour):
    """Draw the same filled triangle on the image (in ``colour``) and on the mask (value 1)."""
    bounding_circle = (centre, radius)
    # The image gets the requested colour, the mask gets the label value 1
    for canvas, fill_value in ((draw_image, colour), (draw_mask, 1)):
        canvas.regular_polygon(bounding_circle, n_sides=3, rotation=rotation, fill=fill_value)

In [None]:
def plot_two_images(image_1, image_2, title_1, title_2, cmap_1, cmap_2):
    """Show two images side by side, each with its own title and colour map."""
    fig, axs = plt.subplots(1, 2, figsize=(7, 5))

    panels = (
        (image_1, title_1, cmap_1),
        (image_2, title_2, cmap_2),
    )
    for ax, (picture, title, cmap) in zip(axs.ravel(), panels):
        ax.imshow(picture, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

## Create a custom image

In [None]:
# Build a synthetic test picture: several overlapping shapes on a white
# background, plus a mask in which only the triangle is labelled.
image_size = (128, 128)

# White RGB canvas for the picture, all-zero single-channel canvas for the mask
image = Image.new('RGB', image_size, (255, 255, 255))
mask = Image.new('L', image_size, 0)

draw_image = ImageDraw.Draw(image)
draw_mask = ImageDraw.Draw(mask)

# Draw shapes
# These are drawn on the image only, so they stay 0 in the mask
draw_circle(draw_image, centre=(40, 40), radius=40, colour=(200, 1, 1))
draw_ellipse(draw_image, centre=(90, 100), radius_1=20, radius_2=10, colour=(1, 200, 1))
draw_rectangle(draw_image, top_left=(10, 10), bottom_right=(90, 50), colour=(1, 1, 200))
draw_square(draw_image, centre=(67, 67), radius=50, rotation=30, colour=(200, 200, 1))
draw_pentagon(draw_image, centre=(27, 87), radius=20, rotation=30, colour=(100, 200, 1))
draw_hexagon(draw_image, centre=(27, 97), radius=20, rotation=30, colour=(200, 100, 1))

# Draw triangle
# Drawn last and on both canvases; draw_triangle fills the mask with the value 1
draw_triangle(
    draw_image,
    draw_mask,
    centre=(70, 70),
    radius=60,
    rotation=95,
    colour=(100, 100, 200),
)

# Convert both canvases to NumPy arrays for plotting and for the model input
image_np = np.array(image)
mask_np = np.array(mask)

plot_two_images(image_np, mask_np, 'Image', 'Mask', None, 'gray')

## Test performance

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The value stored under 'model' is used directly as a module below, i.e. the
# checkpoint contains a pickled model object rather than a state_dict:
#   * weights_only=False is needed to unpickle it (recent PyTorch versions
#     default to weights_only=True and refuse arbitrary objects);
#   * map_location lets a checkpoint written on a GPU load on a CPU-only machine.
# NOTE: unpickling can execute arbitrary code — only load checkpoints you trust.
checkpoint_path = os.path.join('checkpoints_seg_width_16_epochs_12', 'checkpoint_12.pth')
model = torch.load(checkpoint_path, map_location=device, weights_only=False)['model']
model.to(device)
model.eval()

# (H, W, C) uint8 image -> (1, C, H, W) float tensor scaled to [0, 1]
image_pt = torch.tensor(image_np.transpose((2, 0, 1)) / 255.0, dtype=torch.float32).to(device).unsqueeze(0)
# Gradients are not needed for inference
with torch.no_grad():
    predicted_mask = model(image_pt)
# The model outputs logits; the sigmoid turns them into probabilities
predicted_mask_np = torch.sigmoid(predicted_mask.cpu().detach()).numpy().squeeze()

plot_two_images(mask_np, predicted_mask_np, 'Mask', 'Predicted', 'gray', 'gray')

# Soft Dice score between the 0/1 mask and the predicted probabilities
dice = 2 * np.sum(mask_np * predicted_mask_np) / (np.sum(mask_np) + np.sum(predicted_mask_np))
print('Dice Score Coefficient', dice)
print('Dice Score Coefficient', dice)

# Visualising features and filters

In [None]:
# Outputs captured by the forward hooks, keyed by the name given to get_activation
activation = {}


def get_activation(name):
    """Return a forward hook that stores the hooked module's output in ``activation[name]``."""

    def hook(module, inputs, output):
        # Store a detached tensor so the capture does not hold on to the
        # autograd graph. The hook returns None, so the output is left unchanged.
        activation.update({name: output.detach()})

    return hook

In [None]:
# Attach a capturing hook to every encoder/decoder stage, the bottleneck and
# the output layer; each capture is stored under the layer's attribute name.
hooked_layers = (
    'encoder_1',
    'encoder_2',
    'encoder_3',
    'encoder_4',
    'bottleneck',
    'decoder_1',
    'decoder_2',
    'decoder_3',
    'decoder_4',
    'output',
)
for layer_name in hooked_layers:
    getattr(model, layer_name).register_forward_hook(get_activation(layer_name))

# One forward pass fills `activation` with the feature maps of the custom image
predicted_mask = model(image_pt)

In [None]:
def plot_features(features, layer_title, columns=8, cmap='Reds'):
    """Plot the feature maps of one layer as a grid of small images.

    Args:
        features: Tensor of feature maps indexed by channel on the first
            dimension. A single 2D map is accepted and treated as one channel.
        layer_title: Layer name shown in the figure title.
        columns: Maximum number of maps per row. It is reduced to the number
            of maps when there are fewer maps than columns.
        cmap: Matplotlib colour map used for every map.

    Raises:
        ValueError: If ``features`` contains no maps.
    """
    # A lone 2D map (e.g. after squeezing a single-channel output) becomes one channel
    if features.dim() == 2:
        features = features.unsqueeze(0)

    n_maps = features.size(0)
    if n_maps == 0:
        raise ValueError('plot_features needs at least one feature map')

    # Never use more columns than maps, and round the row count up so that a
    # channel count that is not a multiple of `columns` is still fully shown.
    columns = max(1, min(columns, n_maps))
    rows = (n_maps + columns - 1) // columns

    width = columns * 2
    height = rows * 2.22

    # squeeze=False always yields a 2D array of axes, also for a 1x1 grid
    fig, axs = plt.subplots(rows, columns, figsize=(width, height), squeeze=False)
    fig.subplots_adjust(hspace=0.01, wspace=0.01)

    for idx, ax in enumerate(axs.ravel()):
        if idx < n_maps:
            ax.imshow(features[idx].cpu().detach().numpy(), cmap=cmap)
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # Unused cells of the last row are left blank
            ax.axis('off')
    fig.suptitle(
        'Block: ' + layer_title + ' | Feature dimensions: ' + str(features.cpu().detach().numpy().shape),
        fontsize=2 * (columns) if columns > 1 else 12,
    )
    fig.tight_layout()
    plt.show()

In [None]:
def plot_filters(filters, layer_title, inner_layer_title, cmap='Reds', has_been_transposed=False):
    """Plot every 2D kernel of a convolution weight tensor as a grid.

    Column ``c``, row ``r`` of the grid shows ``filters[c][r]``, so the first
    dimension of ``filters`` runs along the columns and the second along the
    rows.

    Args:
        filters: Weight tensor with at least four dimensions; the last two
            are drawn as images.
        layer_title: Block name shown in the figure title.
        inner_layer_title: Name of the layer inside the block, shown in the title.
        cmap: Matplotlib colour map used for every kernel.
        has_been_transposed: Set to True when the caller swapped the first two
            dimensions before the call; the title then reports the shape with
            those two dimensions swapped back.
    """
    rows = filters.size(1)
    columns = filters.size(0)

    width = columns
    height = rows

    # squeeze=False always yields a 2D array of axes, so single-row,
    # single-column and 1x1 grids are all indexed the same way.
    fig, axs = plt.subplots(rows, columns, figsize=(width, height), squeeze=False)
    fig.subplots_adjust(hspace=0.01, wspace=0.01)

    # Convert once instead of once per kernel
    weights = filters.cpu().detach().numpy()

    for r in range(rows):
        for c in range(columns):
            curr_axs = axs[r, c]
            curr_axs.imshow(weights[c][r], cmap=cmap)
            curr_axs.set_xticks([])
            curr_axs.set_yticks([])
    filter_dims = str(weights.shape) if not has_been_transposed else str(weights.swapaxes(0, 1).shape)
    fig.suptitle(
        'Block: ' + layer_title + ' | Inner layer: ' + inner_layer_title + ' | Filter dimensions: ' + filter_dims,
        fontsize=2 * (columns) if columns < 16 else columns,
        y=0.99,
    )
    fig.tight_layout()
    plt.show()

## Encoder features

In [None]:
layer = 'encoder_1'
# Drop only the batch dimension; a bare squeeze() would also remove any
# channel or spatial dimension of size 1.
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=8, cmap='Reds')

In [None]:
layer = 'encoder_2'
# Drop only the batch dimension; a bare squeeze() would also remove any
# channel or spatial dimension of size 1.
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=8, cmap='Reds')

In [None]:
layer = 'encoder_3'
# Drop only the batch dimension; a bare squeeze() would also remove any
# channel or spatial dimension of size 1.
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=16, cmap='Reds')

In [None]:
layer = 'encoder_4'
# Drop only the batch dimension; a bare squeeze() would also remove any
# channel or spatial dimension of size 1.
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=16, cmap='Reds')

In [None]:
layer = 'bottleneck'
# Drop only the batch dimension; a bare squeeze() would also remove any
# channel or spatial dimension of size 1.
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=32, cmap='Reds')

## Decoder features

In [None]:
layer = 'decoder_1'
# Drop only the batch dimension; a bare squeeze() would also remove any
# channel or spatial dimension of size 1.
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=16, cmap='Blues')

In [None]:
layer = 'decoder_2'
# Drop only the batch dimension; a bare squeeze() would also remove any
# channel or spatial dimension of size 1.
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=16, cmap='Blues')

In [None]:
layer = 'decoder_3'
# Drop only the batch dimension; a bare squeeze() would also remove any
# channel or spatial dimension of size 1.
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=8, cmap='Blues')

In [None]:
layer = 'decoder_4'
# Drop only the batch dimension; a bare squeeze() would also remove any
# channel or spatial dimension of size 1.
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=8, cmap='Blues')

In [None]:
# Raw output of the final 1x1 convolution (logits, before the sigmoid)
layer = 'output'
# squeeze(0) removes only the batch dimension, keeping the channel dimension
# that plot_features iterates over
features = activation[layer].squeeze(0)
plot_features(features, layer_title=layer, columns=1, cmap='Blues')

In [None]:
# Same output layer, after the sigmoid has been applied to the captured logits
layer = 'output'
features = torch.sigmoid(activation[layer].squeeze(0))
plot_features(features, layer_title=layer + ' + sigmoid', columns=1, cmap='Blues')

## Filters

In [None]:
# In every encoder/decoder/bottleneck block the two convolutions sit at
# indices 0 and 2 (the ReLUs are at 1 and 3). The weights are detached
# because they are only used for plotting.
conv_positions = (0, 2)

w_encoder_1_1, w_encoder_1_2 = (model.encoder_1[pos].weight.detach() for pos in conv_positions)
w_encoder_2_1, w_encoder_2_2 = (model.encoder_2[pos].weight.detach() for pos in conv_positions)
w_encoder_3_1, w_encoder_3_2 = (model.encoder_3[pos].weight.detach() for pos in conv_positions)
w_encoder_4_1, w_encoder_4_2 = (model.encoder_4[pos].weight.detach() for pos in conv_positions)
w_bottleneck_1, w_bottleneck_2 = (model.bottleneck[pos].weight.detach() for pos in conv_positions)

w_decoder_1_1, w_decoder_1_2 = (model.decoder_1[pos].weight.detach() for pos in conv_positions)
w_decoder_2_1, w_decoder_2_2 = (model.decoder_2[pos].weight.detach() for pos in conv_positions)
w_decoder_3_1, w_decoder_3_2 = (model.decoder_3[pos].weight.detach() for pos in conv_positions)
w_decoder_4_1, w_decoder_4_2 = (model.decoder_4[pos].weight.detach() for pos in conv_positions)
# The output layer is a single convolution, not a Sequential
w_output = model.output.weight.detach()

In [None]:
# First convolution of encoder_1: one column per output channel, one row per input channel
plot_filters(w_encoder_1_1, 'encoder_1', 'conv_1', cmap='gray')

In [None]:
# Second convolution of encoder_1: one column per output channel, one row per input channel
plot_filters(w_encoder_1_2, 'encoder_1', 'conv_2', cmap='gray')

In [None]:
# First convolution of encoder_2: one column per output channel, one row per input channel
plot_filters(w_encoder_2_1, 'encoder_2', 'conv_1', cmap='gray')

In [None]:
# Second convolution of decoder_4, the last block before the output layer
plot_filters(w_decoder_4_2, 'decoder_4', 'conv_2', cmap='gray')

In [None]:
# The first two dimensions are swapped so the output layer's input channels run
# along the columns; has_been_transposed=True makes the title report the
# weight's original shape.
plot_filters(w_output.transpose(1, 0), 'output', 'conv', cmap='gray', has_been_transposed=True)