In [4]:
import torch
import numpy as np
import os
from torch import nn
from torch.utils.data import DataLoader
from preprocessing.dataset import CIFAR10_custom
from preprocessing.transforms import CompressedToTensor, ZigZagOrder, ChooseAC, FlattenZigZag
from torchvision.transforms import Compose

In [5]:
# Dataset root, relative to the notebook's working directory.
download_path = os.path.join("data", "cifar10")

In [6]:
# Per-sample transform chain built from project-local transforms. Presumably it turns
# the compressed representation into a tensor, reorders coefficients in zig-zag order
# and keeps 5 AC coefficients — TODO confirm against preprocessing.transforms.
transform = Compose([
    CompressedToTensor(),
    ZigZagOrder(),
    ChooseAC(5),
])

# Train and test splits share every setting (including the same transform object)
# except `train`.
shared_kwargs = dict(root=download_path, transform=transform, target_transform=None,
                     download=False, compression=None)
cifar_compressed = CIFAR10_custom(train=True, **shared_kwargs)
cifar_compressed_test = CIFAR10_custom(train=False, **shared_kwargs)

In [7]:
# Convert BOTH splits to YCbCr and compress them in place. The test split is built
# with the same transform pipeline (starting with CompressedToTensor), so it needs
# the same preprocessing as the train split.
# NOTE(review): these calls mutate the datasets in place; re-running this cell without
# first re-running the dataset-construction cell presumably applies them a second time.
for split in (cifar_compressed, cifar_compressed_test):
    split.to_ycbcr(in_place=True)
    split.compress(in_place=True)

In [8]:
batch_size = 16
# Seeded generator: makes the shuffle order (and therefore every "first batch"
# inspected below) reproducible across Restart & Run All.
loader_generator = torch.Generator().manual_seed(0)
train_loader = DataLoader(cifar_compressed, batch_size=batch_size, shuffle=True,
                          generator=loader_generator)

In [9]:
# Shape of the inputs in one shuffled batch (element 0 of the collated batch).
first_batch = next(iter(train_loader))
first_batch[0].shape

torch.Size([16, 3, 16, 6])

In [10]:
# Toy model: embed each 6-value token (presumably DC + 5 AC coefficients, matching
# ChooseAC(5)) into 248 dims, then apply a single transformer encoder layer.
net = nn.Sequential(
    nn.Linear(6, 248),
    # batch_first=True: inputs here are (batch, tokens, features). With the default
    # batch_first=False, dim 0 would be treated as the sequence, so attention would mix
    # different images and the tokens would be treated as the batch.
    nn.TransformerEncoderLayer(248, 8, dim_feedforward=1024, dropout=0, batch_first=True),
)

In [11]:
# Take the input tensor (element 0; the label is dropped) of the first two training
# samples and stack them into a float32 batch.
img1, img2 = (cifar_compressed[idx][0].unsqueeze(0) for idx in (0, 1))
img = torch.cat([img1, img2], dim=0).float()

In [12]:
print(img.shape)
# (B, C, P, D) -> (B, P, C, D) -> (B, P*C, D): one token per (patch, channel) pair,
# patch-major. Batch size and feature size are read from the tensor, not hard-coded.
# NOTE: this rebinds `img` to a 3-D tensor, so re-running this cell on its own fails;
# re-run the cell that builds `img` first.
img = img.permute(0, 2, 1, 3).reshape(img.shape[0], -1, img.shape[-1])
print(img.shape)

torch.Size([2, 3, 16, 6])
torch.Size([2, 48, 6])


In [13]:
# Display the (2, 48, 6) token tensor built above (the exported output is truncated).
img

tensor([[[ 37., -13.,  -7.,   3.,  -7.,  -4.],
         [ 50.,   3.,   3.,   0.,   1.,   1.],
         [ 69.,  -2.,  -2.,   0.,  -1.,  -1.],
         [ 42.,   4.,   7.,   0.,  -5.,   1.],
         [ 49.,   0.,   0.,   1.,   1.,   0.],
         [ 70.,   0.,   0.,   0.,  -1.,   0.],
         [ 44.,  -2.,   8.,   0.,  -2.,   4.],
         [ 50.,  -1.,  -1.,   1.,   1.,   0.],
         [ 70.,   0.,   0.,   0.,  -1.,   1.],
         [ 45.,   8.,   9.,   6.,  -5.,  -2.],
         [ 50.,  -1.,   0.,   0.,   0.,   1.],
         [ 70.,   1.,   0.,   0.,   0.,   0.],
         [ 50.,  12.,   3.,  -2.,  -5.,  -1.],
         [ 49.,  -1.,   0.,   0.,   1.,   0.],
         [ 70.,   0.,   0.,   0.,   0.,   0.],
         [ 55., -11., -23.,  -2.,  16.,   5.],
         [ 48.,   0.,   1.,   2.,   0.,   1.],
         [ 71.,   0.,   0.,  -1.,  -1.,  -1.],
         [ 72.,   8., -36.,   3.,   2.,  -1.],
         [ 48.,  -1.,   1.,   2.,  -1.,   0.],
         [ 71.,   2.,   3.,  -1.,   1.,   0.],
         [ 51

In [14]:
# Run the two-image token batch through the toy Linear + encoder-layer stack.
output = net(img)

In [15]:
# Expected (batch, tokens, d_model).
output.size()

torch.Size([2, 48, 248])

In [26]:
# Toy input: 3 channels x 4 patches x (2 x 2) values per patch = 48 elements,
# filled with 1..48 so the later reshapes are easy to follow by eye.
total_elements = 3 * 4 * 2 * 2
input_patches = (torch.arange(total_elements, dtype=torch.float32) + 1).reshape(3, 4, 2, 2)

In [27]:
# Display the (3, 4, 2, 2) toy tensor: channel -> patch -> 2x2 values.
input_patches

tensor([[[[ 1.,  2.],
          [ 3.,  4.]],

         [[ 5.,  6.],
          [ 7.,  8.]],

         [[ 9., 10.],
          [11., 12.]],

         [[13., 14.],
          [15., 16.]]],


        [[[17., 18.],
          [19., 20.]],

         [[21., 22.],
          [23., 24.]],

         [[25., 26.],
          [27., 28.]],

         [[29., 30.],
          [31., 32.]]],


        [[[33., 34.],
          [35., 36.]],

         [[37., 38.],
          [39., 40.]],

         [[41., 42.],
          [43., 44.]],

         [[45., 46.],
          [47., 48.]]]])

In [18]:
# Toy example using `input_patches` (values 1..48): 3 channels, 4 patches per channel,
# each patch 2x2.
#   flatten(start_dim=2): each 2x2 patch -> 4 values        -> (3, 4, 4)
#   permute(1, 0, 2):     bring the patch axis to the front -> (4, 3, 4)
#   reshape(P, -1):       concatenate all channels per patch -> (4, 12)
# Result: one row per patch holding 3 channels * 2*2 = 12 values.
flattened_patches = (
    input_patches.flatten(start_dim=2)
    .permute(1, 0, 2)
    .reshape(input_patches.shape[1], -1)
)
print(flattened_patches)
print(flattened_patches.shape)
# Output: torch.Size([4, 12])

tensor([[ 1.,  2.,  3.,  4., 17., 18., 19., 20., 33., 34., 35., 36.],
        [ 5.,  6.,  7.,  8., 21., 22., 23., 24., 37., 38., 39., 40.],
        [ 9., 10., 11., 12., 25., 26., 27., 28., 41., 42., 43., 44.],
        [13., 14., 15., 16., 29., 30., 31., 32., 45., 46., 47., 48.]])
torch.Size([4, 12])


In [19]:
class LinearProjection(nn.Module):
    """Concatenate each patch's coefficients across channels and project them to ``d_model``.

    Input layout: ``(*batch, channels, patch_num, ac + 1)``, with any number of leading
    batch dims (including none). Output: ``(*batch, patch_num, d_model)``.

    Args:
        ac: number of AC coefficients kept per patch; ``ac + 1`` values per
            patch/channel are expected (presumably DC + ``ac`` AC terms).
        channels: number of input channels.
        patch_num: number of patches (tokens) per sample.
        d_model: size of the output embedding.
    """

    def __init__(self, ac, channels, patch_num, d_model=248):
        super().__init__()
        self.dct = ac + 1  # values per patch per channel
        self.channels = channels
        self.d_model = d_model
        self.patch_num = patch_num

        self.projection = nn.Linear(self.dct * channels, d_model)

    def init_weights(self, init_fn):
        """Apply ``init_fn`` to the projection weight (the bias is left untouched)."""
        init_fn(self.projection.weight)

    def forward(self, X):
        """Project ``X`` of shape ``(*batch, C, P, D)`` to ``(*batch, P, d_model)``."""
        leading = X.shape[:-3]
        # Swap the channel and patch axes so all channels of a patch sit together:
        # (..., C, P, D) -> (..., P, C, D) -> (..., P, C*D). Using negative dims
        # works for any number of leading batch dims.
        tokens = X.transpose(-3, -2).reshape(*leading, self.patch_num, -1)
        return self.projection(tokens)

In [86]:
class VisionTransformer(nn.Module):
    """Token-sequence front end of a ViT over compressed patch coefficients.

    ``forward`` currently builds only the input sequence:
    linear patch projection -> optional [CLS] token -> optional learnable positional
    embedding. It returns ``(*batch, seq, d_model)``, and an unbatched input gets a
    leading batch dim of 1.

    NOTE(review): ``num_classes``, ``nhead``, ``dim_feedforward``, ``dropout``,
    ``ntransformers`` and (apart from being stored) ``activation`` are accepted but not
    used yet — no encoder stack or classification head is built (TODO).
    When ``learnable_positional`` is False, no positional information is added at all.
    """

    def __init__(self,
                 ac: int,
                 channels: int,
                 patch_num: int,
                 num_classes: int,
                 d_model: int = 248,
                 nhead: int = 8,
                 dim_feedforward: int = 1024,
                 dropout: float = 0.1,
                 activation=nn.ReLU,
                 ntransformers: int = 4,
                 add_cls_token: bool = True,
                 learnable_positional: bool = True):
        super().__init__()

        self.add_cls_token = add_cls_token
        self.learnable_positional = learnable_positional
        self.activation = activation

        self.linear_projection = LinearProjection(ac=ac, channels=channels, patch_num=patch_num, d_model=d_model)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if add_cls_token else None

        if learnable_positional:
            # One embedding per sequence position, including the [CLS] slot if present.
            seq_len = patch_num + 1 if add_cls_token else patch_num
            self.positional = nn.Parameter(torch.zeros(1, seq_len, d_model))
        else:
            self.positional = None

    def _concat_cls_token(self, X):
        """Prepend the [CLS] token along the sequence axis (second-to-last dim).

        Works for unbatched ``(seq, d)`` input and for any number of leading batch dims.
        """
        if not self.add_cls_token:
            return X

        d_model = self.cls_token.shape[-1]
        # (1, 1, d) -> (d,) -> (*batch, 1, d); expand creates a view, so gradients
        # still flow back to the single shared parameter.
        cls_token = self.cls_token.reshape(d_model).expand(*X.shape[:-2], 1, d_model)
        return torch.cat((cls_token, X), dim=-2)

    def _with_positional(self, X):
        """Add the learnable positional embedding (if enabled) and ensure a batch dim.

        A 2-D (unbatched) input comes back with a leading batch dim of 1, so a
        downstream transformer encoder never receives unbatched input.
        """
        if self.learnable_positional:
            # (1, seq, d) broadcasts over the leading dims; for an unbatched (seq, d)
            # input the broadcast already yields (1, seq, d).
            X = X + self.positional
        if X.dim() == 2:
            X = X.unsqueeze(0)
        return X

    def forward(self, X):
        X = self.linear_projection(X)
        X = self._concat_cls_token(X)
        return self._with_positional(X)

In [87]:
# One shuffled training batch, with inputs cast to float32 for the Linear projection.
batched, tgt = next(iter(train_loader))
batched = batched.float()
# A random single sample with no batch dim: (channels, patches, coefficients).
unbatched = torch.rand((3, 16, 6))

In [88]:
vit = VisionTransformer(
    # input geometry / task
    ac=5, channels=3, patch_num=16, num_classes=10,
    # model size
    d_model=248, nhead=8, dim_feedforward=1024, dropout=0.1,
    activation=nn.ReLU,
    # plain patch tokens: no [CLS] token, no positional embedding
    add_cls_token=False, learnable_positional=False,
)

In [89]:
# Batched input -> (batch, patches, d_model).
vit(batched).size()

torch.Size([16, 16, 248])

In [90]:
# The projection alone keeps the input unbatched: (patches, d_model).
vit.linear_projection(unbatched).size()

torch.Size([16, 248])

In [91]:
# The full front end adds a leading batch dim of 1 to unbatched input.
vit(unbatched).size()

torch.Size([1, 16, 248])