# Spatial Transformer
In this lab, you will finish an implementation of a module that lets networks apply spatial transformations to both input images and feature maps. For details, refer to [the paper](https://arxiv.org/abs/1506.02025).


## Data Preparation
For training, we are going to use the MNIST dataset.

In [None]:
!pip install lightning

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl

In [None]:
BATCH_SIZE = 32
NUM_CLASSES = 10
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 28
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),  # our input is an image
        torchvision.transforms.RandomAffine(
            45,
            (0.25, 0.25),
            scale=(0.5, 1.0),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        ),
        torchvision.transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    ]
)

test_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),  # our input is an image
        torchvision.transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
    ]
)


DOWNLOAD_PATH = "~/torch_datasets/MNIST"
DATASET_TRAIN = torchvision.datasets.MNIST(
    root=DOWNLOAD_PATH, train=True, transform=train_transforms, download=True
)
DATASET_TEST = torchvision.datasets.MNIST(
    root=DOWNLOAD_PATH, train=False, transform=test_transforms, download=True
)

TRAIN_LOADER = torch.utils.data.DataLoader(
    DATASET_TRAIN, batch_size=BATCH_SIZE, shuffle=True
)
TEST_LOADER = torch.utils.data.DataLoader(
    DATASET_TEST, batch_size=BATCH_SIZE, shuffle=False
)

In [None]:
def show_images(images, num_rows=1):
    """
    Given a tensor of shape [BATCH, C, H, W]
    prints BATCH images splitting them evenly among num_rows
    """
    assert len(images.shape) == 4
    num_images = images.shape[0]
    row_len = num_images // num_rows
    assert row_len * num_rows == num_images
    _, axes = plt.subplots(num_rows, row_len, figsize=(12, 12))
    # Move to the CPU first so that tensors living on a GPU can be plotted too
    images = images.permute(0, 2, 3, 1).detach().cpu().numpy()

    def handle_img(img, axe):
        axe.axis("off")
        img = np.clip(img, 0, 1)

        axe.imshow(img)

    if num_images == 1:
        handle_img(images[0], axes)
    else:
        for i, img in enumerate(images):
            if num_rows == 1:
                handle_img(img, axes[i])
            else:
                r = i // row_len
                c = i % row_len
                handle_img(img, axes[r, c])

In [None]:
SAMPLE_TRAIN, _ = next(iter(TRAIN_LOADER))
SAMPLE_TRAIN = SAMPLE_TRAIN[:16]
show_images(SAMPLE_TRAIN, 4)

In [None]:
SAMPLE_IMGS, _ = next(iter(TEST_LOADER))
SAMPLE_IMGS = SAMPLE_IMGS[:16]
show_images(SAMPLE_IMGS, 4)

## Spatial Tools
Our goal is to train a model on the MNIST dataset that is invariant to transformations such as scale changes and shifts.  
To achieve this, we are going to use a spatial transformer.  

Briefly speaking, given a source image $I$ (or feature map), we want to produce a transformed image $I'$ (rotated, flipped, cropped, etc.) for further processing (for example, by a CNN). Moreover, we want the transformation parameters to depend on $I$: we will produce them with a neural network that takes $I$ as input.  
To produce such a transformed image, for each pixel at coordinates $x'$, $y'$ in $I'$ we compute coordinates $x$, $y$ in $I$
such that the value of the pixel $x', y'$ in $I'$ is the value of the pixel $x, y$ in $I$.
Since this procedure can produce coordinates that fall between pixels, where no value is stored, we use bilinear interpolation (in fact, there is one more reason for this choice).
To map $x', y'$ to $x, y$, we restrict ourselves to the family of affine transformations.

To be more precise, let $x', y'$ be the coordinates of some pixel in the output image, scaled so that
all coordinates are between $-1$ and $1$ (inclusive).  
To get the value of the pixel $x', y'$ in $I'$ we first calculate:

$$
\left(\begin{array}{c} 
x\\
y
\end{array}\right)
=
\left(\begin{array}{ccc} 
\theta_1 & \theta_2 & \theta_3\\ 
\theta_4 & \theta_5 & \theta_6
\end{array}\right)
\left(\begin{array}{c} 
x'\\ 
y'\\
1
\end{array}\right)
$$ 
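
As a small illustration of the formula above, the following sketch applies one hand-picked $\theta$ to a single output coordinate. The values of `theta` and `out_xy` are made up for this example and are not part of the lab code; the full exercise still requires doing this for every output pixel and every batch element.

```python
import torch

# Hypothetical parameters: scale both axes by 0.5 and shift x by 0.25
theta = torch.tensor([[0.5, 0.0, 0.25],
                      [0.0, 0.5, 0.0]])

# Output coordinate (x', y') = (1, -1) in homogeneous form (x', y', 1)
out_xy = torch.tensor([1.0, -1.0, 1.0])

# Source coordinate (x, y) in the input image: x = 0.75, y = -0.5
src_xy = theta @ out_xy
print(src_xy)
```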

Then we treat $x, y$ as the coordinates of a point in $I$, with coordinates in $\{-1, 1\}\times\{-1, 1\}$ corresponding to the corner pixels.  
Finally, we set the value of the pixel $x', y'$ in $I'$ to either:
* $0$ if we landed outside of the image
* the result of the bilinear interpolation between the values of the pixels in $I$ at coordinates (with $x, y$ first rescaled from $[-1, 1]$ to pixel indices)
    + $\lfloor x \rfloor, \lfloor y \rfloor$
    + $\lfloor x \rfloor, \lceil y \rceil$
    + $\lceil x \rceil, \lfloor y \rfloor$
    + $\lceil x \rceil, \lceil y \rceil$


If the image (or feature map) consists of many channels (colors), we apply the same transformation parameters to all channels.   
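
The sampling rule described above can be sketched for a single point of a single-channel image. This is only an unvectorized illustration, not the intended `GridSampler` solution. It assumes that the first coordinate indexes rows and the second indexes columns, and it uses `floor + 1` instead of the ceiling so that the weights stay well defined when a coordinate falls exactly on a pixel. The helper name `bilinear_sample_point` is hypothetical.

```python
import math

import torch


def bilinear_sample_point(image, x, y):
    """Sample a [H, W] tensor at normalized coordinates (x, y) in [-1, 1]."""
    height, width = image.shape
    # Points outside of the image have value 0
    if not (-1.0 <= x <= 1.0 and -1.0 <= y <= 1.0):
        return 0.0

    # Rescale [-1, 1] to fractional pixel indices; -1 and 1 hit the border pixels
    row = (x + 1) / 2 * (height - 1)
    col = (y + 1) / 2 * (width - 1)

    def pixel(r, c):
        # Guard for the neighbor just past the last row or column (its weight is 0)
        if 0 <= r < height and 0 <= c < width:
            return image[r, c].item()
        return 0.0

    r0, c0 = math.floor(row), math.floor(col)
    r1, c1 = r0 + 1, c0 + 1
    wr, wc = row - r0, col - c0  # distances from the lower neighbors

    return (
        (1 - wr) * (1 - wc) * pixel(r0, c0)
        + (1 - wr) * wc * pixel(r0, c1)
        + wr * (1 - wc) * pixel(r1, c0)
        + wr * wc * pixel(r1, c1)
    )


image = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
center = bilinear_sample_point(image, 0.0, 0.0)   # average of all four pixels
corner = bilinear_sample_point(image, -1.0, 1.0)  # exactly the pixel image[0, 1]
outside = bilinear_sample_point(image, 1.5, 0.0)  # outside of the image
print(center, corner, outside)
```

Note that the result depends continuously on `x` and `y`, which is worth keeping in mind for the question below.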


Before you start the implementation, one more question.  
**Why do we use bilinear interpolation here instead of just picking the nearest pixel?**

Your task below is to finish the implementation of `SamplingGridGenerator`, which for each pixel $p$ in the output image generates the coordinates in the input image from which the value of $p$ will be sampled.   
Assume that each coordinate in the input image is in $[-1, 1]$.  
Note that the produced coordinates can lie outside the image (we will take care of it later).  
**Please don't use built-in functions like `affine_grid` or `grid_sample`.**
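
One way to picture the normalized coordinates: along an axis with `n` pixels, the pixel positions can be spread evenly between $-1$ and $1$, with the first pixel at $-1$ and the last at $1$. A minimal sketch of that idea, assuming this corner-aligned convention (the helper name `normalized_axis` is hypothetical and not part of the lab code):

```python
import torch


def normalized_axis(num_pixels):
    # Evenly spaced coordinates; the first pixel sits at -1 and the last at 1
    return torch.linspace(-1.0, 1.0, num_pixels)


axis = normalized_axis(4)
print(axis)  # four values: -1, -1/3, 1/3, 1
```

For four pixels these are the same values that appear in the expected output for `t1` in the test cell below.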


In [None]:
class SamplingGridGenerator(torch.nn.Module):
    def __init__(self, output_height, output_width):
        super().__init__()

        self.output_height = output_height
        self.output_width = output_width

    def forward(self, theta):
        """
        Given parameters of the transformation theta of shape [BATCH, 2, 3]
        returns a sampling_grid of shape [BATCH, output_height, output_width, 2]
        such that at sampling_grid[b, x', y'] are the coordinates of the pixel
        in the source image from which the value of pixel x', y' in
        the transformed image will be sampled.
        Assumes that each coordinate in the source image lies in the range [-1, 1].
        Note that transformation can point to pixels outside of the source image
        and this module does not clip values to lie inside the source image.
        """

        assert len(theta.shape) == 3  # [BATCH, 2, 3]
        assert theta.shape[1] == 2
        assert theta.shape[2] == 3

        # TODO {

        # }

        assert sampling_grid.shape == (
            theta.shape[0],
            self.output_height,
            self.output_width,
            2,
        )  # [BATCH, H', W', 2]
        return sampling_grid

In [None]:
# Here we can test the implementation
sgg = SamplingGridGenerator(4, 4)

# t1 should give something like
# tensor([[-1.0000-1.0000j, -1.0000-0.3333j, -1.0000+0.3333j, -1.0000+1.0000j],
#         [-0.3333-1.0000j, -0.3333-0.3333j, -0.3333+0.3333j, -0.3333+1.0000j],
#         [ 0.3333-1.0000j,  0.3333-0.3333j,  0.3333+0.3333j,  0.3333+1.0000j],
#         [ 1.0000-1.0000j,  1.0000-0.3333j,  1.0000+0.3333j,  1.0000+1.0000j]])
t1 = sgg(torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]))[0]
print("t1")
print(torch.view_as_complex(t1))

t2 = sgg(torch.tensor([[[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]]))[0]
print("t2 = t1 with x, y swapped")
print(torch.view_as_complex(t2))
assert (t2[..., [1, 0]] == t1).all()


t3 = sgg(torch.tensor([[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]]))[0]
print("t3 = t1 +1")
print(torch.view_as_complex(t3))
assert (t3 == t1 + 1).all()

Now implement `GridSampler`, which, given the sampling grid created by `SamplingGridGenerator` and a source image, creates the transformed image.  

In [None]:
class GridSampler(torch.nn.Module):
    """
    Given the source_image of shape [B, C, H, W]
    and sampling_grid of shape [B, H', W', 2]
    generates the transformed_image of shape
    [B, C, H', W'].
    transformed_image[b, c, x', y'] is the value of pixel
    at coordinates sampling_grid[b, x', y']
    in the source_image[b, c].
    Values {-1, 1}x{-1, 1} in sampling_grid correspond to corner pixels.
    Pixels outside of the source_image are assumed to have
    value 0.
    Values of unknown pixels inside are obtained
    using bilinear interpolation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, source_image, sampling_grid):
        assert len(source_image.shape) == 4  # [B, C, H, W]
        batch_size = source_image.shape[0]
        input_height = source_image.shape[-2]
        input_width = source_image.shape[-1]

        # [B, H', W', 2]
        assert len(sampling_grid.shape) == 4
        assert sampling_grid.shape[0] == batch_size
        assert sampling_grid.shape[-1] == 2

        output_height = sampling_grid.shape[1]
        output_width = sampling_grid.shape[2]

        # TODO {

        # }

        assert transformed_image.shape[0] == source_image.shape[0]
        assert transformed_image.shape[1] == source_image.shape[1]
        assert transformed_image.shape[2] == output_height
        assert transformed_image.shape[3] == output_width
        return transformed_image
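
Although the built-in functions must not be used inside your implementation, they can serve as an optional cross-check afterwards. The sketch below wraps `torch.nn.functional.grid_sample` so that it accepts grids in the convention assumed in this lab, where the last dimension stores (row, column) coordinates, while `grid_sample` expects (column, row). The helper name `reference_grid_sample` is hypothetical, and results may differ slightly from your module right at the image border, depending on how you treat points that land just outside.

```python
import torch
import torch.nn.functional as F


def reference_grid_sample(source_image, sampling_grid):
    # Swap (row, column) to the (column, row) order expected by F.grid_sample
    swapped = sampling_grid[..., [1, 0]]
    return F.grid_sample(
        source_image,
        swapped,
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,  # -1 and 1 correspond to the corner pixels
    )


# Identity grid for a non-square 3x5 image, built by hand for this example
rows = torch.linspace(-1.0, 1.0, 3)
cols = torch.linspace(-1.0, 1.0, 5)
grid = torch.stack(torch.meshgrid(rows, cols, indexing="ij"), dim=-1)[None]  # [1, 3, 5, 2]
image = torch.arange(15.0).reshape(1, 1, 3, 5)

resampled = reference_grid_sample(image, grid)
print(torch.allclose(resampled, image, atol=1e-5))
```

Once `GridSampler` is finished, its output on the same inputs can be compared against `reference_grid_sample` with `torch.allclose`.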

Let's check whether the modules we created work as expected.

In [None]:
def visualize_transformation(transform_matrices, imgs):
    sgg = SamplingGridGenerator(IMAGE_HEIGHT, IMAGE_WIDTH)
    gs = GridSampler()
    sampling_grid = sgg(transform_matrices)
    res = gs(imgs, sampling_grid)
    show_images(res, 4)


identity = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]] * SAMPLE_IMGS.shape[0])
flip_dim1 = torch.tensor([[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]] * SAMPLE_IMGS.shape[0])
flip_dim2 = torch.tensor([[[1.0, 0.0, 0.0], [0.0, -1.0, 0.0]]] * SAMPLE_IMGS.shape[0])
rotate_90 = torch.tensor([[[0.0, -1.0, 0.0], [1.0, 0.0, 0.0]]] * SAMPLE_IMGS.shape[0])
translate = torch.tensor([[[1.0, 0.0, 1.0], [0.0, 1.0, 0.5]]] * SAMPLE_IMGS.shape[0])
other = torch.tensor([[[1.5, 0.0, 0.75], [0.0, 1.5, 0.75]]] * SAMPLE_IMGS.shape[0])


visualize_transformation(identity, SAMPLE_IMGS)
visualize_transformation(flip_dim1, SAMPLE_IMGS)
visualize_transformation(flip_dim2, SAMPLE_IMGS)
visualize_transformation(rotate_90, SAMPLE_IMGS)
# visualize_transformation(translate, SAMPLE_IMGS)
visualize_transformation(other, SAMPLE_IMGS)

## Network

Finish the implementation of `NetWithSpatialTransformer`.
To get the parameters of the transformation, use the `LocNet` module, which is already implemented.
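
`LocNet`, defined below, ends with a zero-initialized weight matrix and a bias equal to the flattened identity transformation. The following sketch mirrors only that last step, with random stand-in `features` in place of the real convolutional features, to show that every image is mapped to the identity $\theta$ before training starts. The local variable names are chosen for this example.

```python
import torch

features = torch.randn(5, 32)     # stand-in for the 32 features of 5 images
last_matrix = torch.zeros(32, 6)  # zero weights: the features have no effect yet
last_bias = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])  # flattened identity

theta = (features @ last_matrix + last_bias).reshape(-1, 2, 3)
print(theta[0])  # the 2x3 identity transformation
```

Starting from the identity means that the network initially sees the unmodified images, and the transformation only changes as `last_matrix` and `last_bias` are updated during training.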

In [None]:
class LocNet(torch.nn.Module):
    """
    Localisation network
    given the image generates parameters for the transformation
    """

    def __init__(self, in_channels=1):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels, out_channels=4, kernel_size=3),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Dropout(p=0.2),
            torch.nn.Linear(200, 32),
            torch.nn.ReLU(),
        )

        # we start with the identity
        self.last_matrix = torch.nn.Parameter(torch.zeros((32, 6)))
        self.last_bias = torch.nn.Parameter(
            torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
        )

    def forward(self, images):
        assert len(images.shape) == 4  # [B, C, H, W]
        res = self.layers(images)
        res = res @ self.last_matrix + self.last_bias
        res = res.reshape(images.shape[0], 2, 3)
        return res


class NetWithSpatialTransformer(torch.nn.Module):
    """ 
    Given the images first uses LocNet module to get
    transformation parameters.
    Then uses grid_generator to generate a sampling grid for each image 
    and grid_sampler to transform each image.
    Then processes the transformed images with layers.
    """
    def __init__(
        self, in_channels=1, image_height=IMAGE_HEIGHT, image_width=IMAGE_WIDTH
    ):
        super().__init__()
        self.loc_net = LocNet(in_channels=in_channels)
        self.grid_generator = SamplingGridGenerator(
            output_height=image_height, output_width=image_width
        )
        self.grid_sampler = GridSampler()
        # MNIST is a very simple dataset.
        # To make the benefit of the spatial transformer visible, we use a simple model here
        self.layers = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Dropout(),
            torch.nn.Linear(784, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 10),
        )

    def forward(self, images):
        assert len(images.shape) == 4  # [B, C, H, W]
        # TODO {

        # }
        return self.layers(sampled)

In [None]:
net = NetWithSpatialTransformer()
net(SAMPLE_IMGS).shape

## Training
The training loop is already implemented.

In [None]:
class PLSpatialTransformer(pl.LightningModule):
    def __init__(self, model: torch.nn.Module):
        super().__init__()

        self.model = model
        self.criterion = torch.nn.CrossEntropyLoss()

    def train_dataloader(self):
        return TRAIN_LOADER

    def val_dataloader(self):
        return TEST_LOADER

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        chosen = torch.argmax(logits, dim=-1)
        acc = (chosen == y).type(torch.float32).mean()

        # on-line metrics
        self.log("train/loss", loss.detach())
        self.log("train/acc", acc.detach())

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        chosen = torch.argmax(logits, dim=-1)
        acc = (chosen == y).type(torch.float32).mean()

        self.log("test/acc", acc.detach(), on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters())
        return optimizer

In [None]:
%load_ext tensorboard
!mkdir -p tb_logs

In [None]:
%tensorboard --logdir tb_logs

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger

# Fall back to the CPU when no CUDA device is available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = NetWithSpatialTransformer()
net.to(DEVICE)

plST = PLSpatialTransformer(model=net)

trainer = pl.Trainer(
    logger=TensorBoardLogger("tb_logs", name="my_model"),
    accelerator="auto",  # uses a GPU when one is available
    max_epochs=2,
    check_val_every_n_epoch=1,
)

trainer.fit(plST)

## Inspection
Let's check what our spatial transformer learned.


In [None]:
net.to(DEVICE)
net.eval()  # disable dropout so that the predicted transformations are deterministic
# First, let's check unaltered images
imgs = SAMPLE_IMGS
show_images(imgs, 4)
visualize_transformation(net.loc_net(imgs.to(DEVICE)).cpu(), imgs)

In [None]:
# Now let's check what the spatial transformer does with altered images from the train set
imgs = SAMPLE_TRAIN
show_images(imgs, 4)
visualize_transformation(net.loc_net(imgs.to(DEVICE)).cpu(), imgs)