In [99]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tvtf
import torchvision.transforms.functional as TF

import copy
from tqdm import tqdm
import matplotlib.pyplot as plt
import math
from datasets import MNIST_rot
from train_vit import VisionTransformer

import models
import g_selfatt.groups as groups

In [100]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [101]:
data_mean = (0.1307,)
data_stddev = (0.3081,)

transform_train = tvtf.Compose([
    tvtf.RandomRotation(degrees=(-180, 180)),  # random rotation
    tvtf.ToTensor(),
    tvtf.Normalize(data_mean, data_stddev)
])
transform_test = tvtf.Compose(
    [
        tvtf.ToTensor(),
        tvtf.Normalize(data_mean, data_stddev),
    ]
)

train_set = MNIST_rot(root="../data", stage="train", download=True, transform=transform_train, data_fraction=1, only_3_and_8=False)
validation_set = MNIST_rot(root="../data", stage="validation", download=True, transform=transform_test, data_fraction=1, only_3_and_8=False)
test_set = MNIST_rot(root="../data", stage="test", download=True, transform=transform_test, data_fraction=1, only_3_and_8=False)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)
val_loader = torch.utils.data.DataLoader(
    validation_set,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=128,
    shuffle=False,
    num_workers=4,
)
img_loader = torch.utils.data.DataLoader(  # single element for visualization purposes
    test_set,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

In [102]:
data, target = next(iter(train_loader))

OSError: [Errno 12] Cannot allocate memory

In [None]:
def get_transforms(images, n_rotations=4, flips=True):
    """ Returns all transformations of the input images """

    B, C, H, W = images.shape
    T = 2*n_rotations if flips else n_rotations  # number of transformations

    # initialize empty transforms tensor
    transforms = torch.empty(size=(B, T, C, H, W))
    transforms[:, 0,...] = images
    idx = 1

    # remember all orientations that need to be flipped
    orientations = [images] if flips else []

    # rotations
    for i in range(1, n_rotations):
        angle = i * (360 / n_rotations)
        rotated_images = TF.rotate(images, angle)  # B, C, H, W
        transforms[:, idx,...] = rotated_images
        idx += 1

        if flips:
            orientations.append(rotated_images)

    # flips
    for transform in orientations:
        flipped_image = TF.hflip(transform)
        transforms[:, idx, ...] = flipped_image
        idx += 1

    return transforms  # B, T, C, H, W

In [None]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Args:
        x: Tensor representing the image of shape [B, C, H, W]
        patch_size: Number of pixels per dimension of the patches (integer)
        flatten_channels: If True, the patches will be returned in a flattened format
                           as a feature vector instead of a image grid.
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]
    return x


class EquivariantViT(nn.Module):
    def __init__(self, patch_size=7, num_patches=16, num_channels=1, n_rotations=4, flips=True, n_embd=1):
        super().__init__()
        self.patch_size = patch_size
        self.num_channels = num_channels  
        self.n_rotations = n_rotations
        self.flips = flips
        self.n_embd = n_embd
        self.num_patches_x = int(math.sqrt(num_patches))
        # below can be more intricate, but for now we just use a linear layer
        self.project = nn.Linear(num_channels*patch_size**2, n_embd)  # to project the patches to their embedding space
        self.gevit = models.GroupTransformer(
            group=groups.SE2(num_elements=8),
            in_channels=1,
            num_channels=20,
            block_sizes=[2, 3],
            expansion_per_block=1,
            crop_per_layer=[0, 0, 0, 0, 0],
            image_size=self.num_patches_x,
            num_classes=10,
            dropout_rate_after_maxpooling=0.0,
            maxpool_after_last_block=False,
            normalize_between_layers=False,
            patch_size=None,
            num_heads=9,
            norm_type="LayerNorm",
            activation_function="Swish",
            attention_dropout_rate=0.0,
            value_dropout_rate=0.01,
            whitening_scale=1.41421356,
        )

    def forward(self, x):
        # get the patches
        x = img_to_patch(x, self.patch_size, flatten_channels=False)  # B, num_patches, C, patch_size, patch_size
        B, num_patches, C, patch_size, _ = x.shape

        # get all transformations for the patches
        x = x.view(B*num_patches, C, patch_size, patch_size)
        x = get_transforms(x, n_rotations=self.n_rotations, flips=self.flips)

        T = x.shape[1]  # number of transformations

        # flatten and project all patches
        x = x.view(B*num_patches*T, C*patch_size*patch_size)
        x = self.project(x)

        # combine the transformations for the patches to make it invariant
        x = x.view(B, num_patches, T, self.n_embd)  # TODO check this
        x = x.mean(dim=2)
        
        # reshape to image grid
        x = x.view(B, self.num_patches_x, self.num_patches_x, self.n_embd).permute(0, 3, 1, 2)

        # print(x.shape)
        # pass through the GEViT to get predictions
        x = self.gevit(x)

        return x

In [None]:
ms = 4
gevit = models.GroupTransformer(
    group=groups.SE2(num_elements=8),
    in_channels=1,
    num_channels=20,
    block_sizes=[2, 3],
    expansion_per_block=1,
    crop_per_layer=[0, 0, 0, 0, 0],
    image_size=ms,
    num_classes=10,
    dropout_rate_after_maxpooling=0.0,
    maxpool_after_last_block=False,
    normalize_between_layers=False,
    patch_size=3,
    num_heads=9,
    norm_type="LayerNorm",
    activation_function="Swish",
    attention_dropout_rate=0.0,
    value_dropout_rate=0.01,
    whitening_scale=1.41421356,
)
gevit.eval()
gevit(get_transforms(torch.randn((1, 1, ms, ms)).float(), n_rotations=4, flips=False).squeeze(0))

tensor([[  0.9188,  -1.2069,  12.8977, -12.0133,  26.3354,  13.2836,   5.6998,
          19.4729,  -2.1033,  -3.6100],
        [  0.9188,  -1.2069,  12.8978, -12.0133,  26.3354,  13.2837,   5.6998,
          19.4729,  -2.1033,  -3.6100],
        [  0.9188,  -1.2069,  12.8978, -12.0133,  26.3354,  13.2837,   5.6998,
          19.4729,  -2.1033,  -3.6100],
        [  0.9188,  -1.2069,  12.8977, -12.0133,  26.3354,  13.2836,   5.6998,
          19.4729,  -2.1033,  -3.6100]], grad_fn=<ViewBackward0>)

In [None]:
n_data = get_transforms(data, n_rotations=4, flips=False)
n_data.shape, data.shape


(torch.Size([1, 4, 1, 28, 28]), torch.Size([1, 1, 28, 28]))

tensor([[ 32.9556,  47.2671,  46.2578,  10.6949,  32.2529,  63.4244, -36.7843,
          42.2023, -32.5435,   6.3517],
        [ 38.2196,  47.0096,  52.6853,  12.7539,  33.3660,  65.1099, -36.1047,
          41.1652, -29.1102,   8.9967],
        [ 43.5203,  44.1402,  55.9121,  17.1286,  34.9502,  70.4931, -35.9536,
          35.6416, -28.5321,  12.1842],
        [ 37.3506,  46.3114,  51.6953,  14.4644,  34.9085,  68.2118, -37.1745,
          38.4139, -32.6094,   9.5347]], grad_fn=<ViewBackward0>)

In [None]:
model = EquivariantViT(n_rotations=8)

In [None]:
n_data.shape

torch.Size([1, 4, 1, 28, 28])

In [None]:
n_data.squeeze(0).shape

torch.Size([4, 1, 28, 28])

In [None]:
model.eval()
model(get_transforms(torch.randn((1, 1, 28, 28)), n_rotations=8, flips=False).squeeze(0))

tensor([[ -5.2963,  -4.4130,  24.1830,  19.6565,  -4.5242,   0.1885,   0.7335,
         -12.1216,  20.6063, -46.3707],
        [ -5.3243,  -3.4945,  21.3550,  18.0383,  -5.8006,  -0.6816,  -0.1165,
         -11.7457,  19.1885, -47.7863],
        [ -5.2606,  -4.4213,  24.2381,  19.6619,  -4.5029,   0.1828,   0.6442,
         -12.1724,  20.6091, -46.3625],
        [ -5.3128,  -3.4906,  21.3652,  18.0476,  -5.8155,  -0.6852,  -0.1165,
         -11.7535,  19.1261, -47.7776],
        [ -5.3017,  -4.4111,  24.1761,  19.6539,  -4.5245,   0.1856,   0.7149,
         -12.0858,  20.6099, -46.3162],
        [ -5.3103,  -3.4895,  21.3567,  18.0471,  -5.8096,  -0.6905,  -0.1123,
         -11.7461,  19.1613, -47.7241],
        [ -5.3368,  -4.4045,  24.1211,  19.6481,  -4.5456,   0.1920,   0.8062,
         -12.0287,  20.6096, -46.3183],
        [ -5.3231,  -3.4937,  21.3463,  18.0368,  -5.7955,  -0.6865,  -0.1129,
         -11.7390,  19.2220, -47.7317]], grad_fn=<ViewBackward0>)

In [None]:
gevit.eval()
gevit(get_transforms(torch.randn((1, 1, 7, 7)), n_rotations=4, flips=False).squeeze(0))

tensor([[-57.7284, -14.8040,  45.7791, -59.9496, -27.1457,  12.3237,   8.8680,
          96.6409,   9.8924, -76.2570],
        [-55.4390, -16.7818,  48.2747, -57.6844, -25.0714,  11.1756,   7.9872,
          91.1289,   7.6447, -75.3762],
        [-63.3238, -19.3804,  58.6454, -69.4326, -28.6111,  10.5632,   3.4687,
         102.9273,  11.9772, -87.4186],
        [-65.4590, -19.8870,  51.3665, -65.6245, -32.2740,  12.9232,   8.0115,
         103.0499,  13.4209, -84.9431]], grad_fn=<ViewBackward0>)

In [None]:
group_transformer.eval()
model.eval()
model(n_data.squeeze(0))

tensor([[ 3.9370e-01,  1.4722e-01,  1.0389e-01,  6.5286e-02,  6.8986e-01,
          1.3243e-01, -1.5173e-01, -3.2953e-01,  5.7994e-01, -1.1655e-02],
        [ 1.8799e+00,  1.7593e-01, -1.3142e-01,  2.8236e-01,  2.5801e+00,
         -3.2251e+00, -5.0657e-01, -1.3560e+00,  1.3613e+00, -7.9269e-01],
        [-3.2089e-03,  1.4081e-01,  1.8851e-02,  6.0775e-02,  4.8662e-01,
          1.8629e-01, -3.5152e-02, -2.4624e-01,  3.8574e-01, -3.6133e-02],
        [ 1.7409e+00, -2.9496e-02, -3.4795e-02,  4.4164e-01,  2.7304e+00,
         -2.2282e+00, -6.4203e-01, -1.2745e+00,  1.3482e+00, -4.6910e-01]],
       grad_fn=<ViewBackward0>)

In [None]:
n_data.squeeze()[1,:,...]

tensor([[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242

In [None]:
TF.rotate(torch.randn(1, 1,  28, 28), 90).shape

torch.Size([1, 1, 28, 28])

In [None]:
class EquivariantViT(nn.Module):
    def __init__(self, patch_size=7, num_patches=16, num_channels=1, n_rotations=4, flips=True, n_embd=1):
        super().__init__()
        self.patch_size = patch_size
        self.num_channels = num_channels  
        self.n_rotations = n_rotations
        self.flips = flips
        self.n_embd = n_embd
        self.num_patches_x = int(math.sqrt(num_patches))
        # below can be more intricate, but for now we just use a linear layer
        self.project = nn.Linear(num_channels*patch_size**2, n_embd)  # to project the patches to their embedding space
        self.gevit = models.GroupTransformer(
            group=groups.SE2(num_elements=8),
            in_channels=1,
            num_channels=20,
            block_sizes=[2, 3],
            expansion_per_block=1,
            crop_per_layer=[0, 0, 0, 0, 0],
            image_size=self.num_patches_x,
            num_classes=10,
            dropout_rate_after_maxpooling=0.0,
            maxpool_after_last_block=False,
            normalize_between_layers=False,
            patch_size=None,
            num_heads=9,
            norm_type="LayerNorm",
            activation_function="Swish",
            attention_dropout_rate=0.0,
            value_dropout_rate=0.01,
            whitening_scale=1.41421356,
        )

    def forward(self, x):
        # get the patches
        x = img_to_patch(x, self.patch_size, flatten_channels=False)  # B, num_patches, C, patch_size, patch_size
        B, num_patches, C, patch_size, _ = x.shape

        # get all transformations for the patches
        x = x.view(B*num_patches, C, patch_size, patch_size)
        x = get_transforms(x, n_rotations=self.n_rotations, flips=self.flips)

        T = x.shape[1]  # number of transformations

        # flatten and project all patches
        x = x.view(B*num_patches*T, C*patch_size*patch_size)
        x = self.project(x)

        # combine the transformations for the patches to make it invariant
        x = x.view(B, num_patches, T, self.n_embd)  # TODO check this
        x = x.mean(dim=2)
        
        # reshape to image grid
        x = x.view(B, self.num_patches_x, self.num_patches_x, self.n_embd).permute(0, 3, 1, 2)

        # print(x.shape)
        # pass through the GEViT to get predictions
        x = self.gevit(x)

        return x

In [None]:
def train(model, n_epochs=5):
    # model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    criterion = nn.CrossEntropyLoss()
    # best_val_acc = evaluate(model)
    print(type(model).__name__)
    # print(f"Starting validaitons accuracy: {best_val_acc}")
    best_model_state = None

    for epoch in tqdm(range(n_epochs)):
        epoch_losses = []
        for images, targets in train_loader:
            # images = images.to(device)
            # targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            print(loss.item())
            optimizer.step()
            epoch_losses.append(loss.item())

        # validate and store best model state
        # val_acc = evaluate(model)
        # if val_acc > best_val_acc:
        #     best_val_acc = val_acc
        #     best_model_state = copy.deepcopy(model.state_dict())

        # log epoch loss
        # print(f"Epoch {epoch+1}: loss {sum(epoch_losses)/len(epoch_losses):.4f}, validation accuracy {val_acc}")

    # Load best model state into the original model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

train(model)

EquivariantViT


  0%|          | 0/5 [00:00<?, ?it/s]


OSError: [Errno 12] Cannot allocate memory