## Open notebook in:
| Colab | Gradient |
|:------|:---------|
| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nicolepcx/transformers-the-definitive-guide/blob/master/CH06/CH06_KV_compression.ipynb) | [![Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://console.paperspace.com//github.com/Nicolepcx/transformers-the-definitive-guide/blob/main/CH06/CH06_KV_compression.ipynb) |

# About this notebook


In this notebook, you explore key-value (KV) token compression, a technique that makes attention affordable when generating high-resolution images. Instead of attending over all N image tokens, each attention layer downsamples the keys and values by a factor R along both spatial axes and keeps every query token. This shrinks the attention matrix from N × N to N × N/R² and reduces the attention cost from O(N²) to O(N²/R²). The `AttentionKVCompress` class implements the downsampling step with four sampling strategies: depthwise convolution (`conv`), average pooling (`ave`), strided 2D token selection (`uniform`), and taking every R-th token of the flattened sequence (`uniform_every`). With R = 2, the three 2D strategies keep one quarter of the tokens and `uniform_every` keeps half, which reduces the number of tokens processed during training and inference. The method was introduced in the paper ["PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation"](https://arxiv.org/abs/2403.04692), where the authors use it to scale a diffusion transformer to 4K resolution with little loss in image quality.
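To see where the downsampling fits inside attention, here is a minimal sketch of a self-attention layer that compresses only the keys and values. It assumes a single attention head, a token grid whose height and width are divisible by `sr_ratio`, and the depthwise-convolution compression used by the `conv` method. The class name `KVCompressedSelfAttention` and its projection layers (`q`, `kv`, `proj`) are hypothetical and illustrative, not PixArt-Σ's own module. Queries keep all N tokens, so the output has the same length as the input, while the attention matrix shrinks from N × N to N × N/R².

```python
import torch
import torch.nn as nn


class KVCompressedSelfAttention(nn.Module):
    """Hypothetical single-head self-attention with spatially compressed keys/values."""

    def __init__(self, dim: int, sr_ratio: int):
        super().__init__()
        self.sr_ratio = sr_ratio
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)
        self.proj = nn.Linear(dim, dim)
        # Depthwise conv: each channel is downsampled independently by sr_ratio
        self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio)
        self.norm = nn.LayerNorm(dim)

    def compress(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, C = x.shape
        x = x.transpose(1, 2).reshape(B, C, H, W)  # tokens -> 2D feature map
        x = self.sr(x)                             # (B, C, H/R, W/R)
        x = x.flatten(2).transpose(1, 2)           # back to tokens: (B, N/R^2, C)
        return self.norm(x)

    def forward(self, x: torch.Tensor, H: int, W: int):
        q = self.q(x)                                        # queries keep all N tokens
        kv_in = self.compress(x, H, W) if self.sr_ratio > 1 else x
        k, v = self.kv(kv_in).chunk(2, dim=-1)               # keys/values use N/R^2 tokens
        scores = q @ k.transpose(1, 2) / q.shape[-1] ** 0.5  # (B, N, N/R^2)
        attn = scores.softmax(dim=-1)
        return self.proj(attn @ v), attn


torch.manual_seed(0)
B, H, W, C = 2, 16, 16, 8
x = torch.randn(B, H * W, C)
layer = KVCompressedSelfAttention(dim=C, sr_ratio=2)
out, attn = layer(x, H, W)
print(out.shape)   # torch.Size([2, 256, 8])
print(attn.shape)  # torch.Size([2, 256, 64])

full_scores = (H * W) ** 2                         # N x N entries without compression
compressed_scores = attn.shape[1] * attn.shape[2]  # N x N/R^2 entries with compression
print(full_scores, compressed_scores)              # 65536 16384
```

Because only the keys and values are compressed, a layer like this keeps the output shape of ordinary self-attention, and the number of score entries drops by exactly R² = 4 in this example.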

# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
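class AttentionKVCompress(nn.Module):
    """Downsamples a sequence of N = H * W tokens, as used for key/value compression."""

    def __init__(self, dim, sr_ratio):
        super().__init__()
        # Depthwise conv used by the 'conv' method; kernel size and stride are fixed to sr_ratio
        self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio)
        self.norm = nn.LayerNorm(dim)

    def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
        # tensor: (B, N, C) with N = H * W tokens in row-major order
        if sampling is None or scale_factor == 1:
            return tensor, tensor.shape[1]  # no compression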

        B, N, C = tensor.shape
        print(f"Original Tensor Shape: {tensor.shape}")

        if sampling == 'uniform_every':
            # Keep every scale_factor-th token of the flattened sequence (ignores the 2D layout)
            tensor = tensor[:, ::scale_factor]
            print(f"Shape after 'uniform_every': {tensor.shape}")
            return tensor, int(N // scale_factor)

        # Restore the 2D layout: (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
        tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
        print(f"Shape after Reshape and Permute: {tensor.shape}")
        new_H, new_W = int(H / scale_factor), int(W / scale_factor)
        new_N = new_H * new_W

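        if sampling == 'ave':
            # Average each scale_factor x scale_factor window
            # (F.interpolate with mode='nearest' would pick one token per window instead)
            tensor = F.avg_pool2d(
                tensor, kernel_size=scale_factor, stride=scale_factor
            ).permute(0, 2, 3, 1)
        elif sampling == 'uniform':
            # Keep the top-left token of each scale_factor x scale_factor window
            tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
            print(f"Shape after 'uniform' downsampling: {tensor.shape}")
        elif sampling == 'conv':
            # Learned depthwise conv; assumes scale_factor == sr_ratio so the output has new_N tokens
            tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
            tensor = self.norm(tensor)
            print(f"Shape after 'conv' downsampling: {tensor.shape}")
        else:
            raise ValueError(f"Unknown sampling method: {sampling!r}")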

        tensor = tensor.reshape(B, new_N, C).contiguous()
        print(f"Final Shape after Reshape: {tensor.shape}")
        return tensor, new_N


# Create example tensor
B, H, W, C = 1, 4, 4, 1  # Batch size, Height, Width, Channels
input_tensor = torch.arange(1, 17).view(B, H * W, C).float()  # Creating a tensor [1, 16, 1]

model = AttentionKVCompress(dim=C, sr_ratio=2)

# Uniform Every
print("=== Uniform Every Downsampling ===")
output_tensor, new_N = model.downsample_2d(input_tensor, H, W, scale_factor=2, sampling='uniform_every')
print(f"New number of tokens: {new_N}\n")

# Average Pooling
print("=== Average Pooling Downsampling ===")
output_tensor, new_N = model.downsample_2d(input_tensor, H, W, scale_factor=2, sampling='ave')
print(f"New number of tokens: {new_N}\n")

# Uniform
print("=== Uniform Downsampling ===")
output_tensor, new_N = model.downsample_2d(input_tensor, H, W, scale_factor=2, sampling='uniform')
print(f"New number of tokens: {new_N}\n")

# Convolution
print("=== Convolution Downsampling ===")
output_tensor, new_N = model.downsample_2d(input_tensor, H, W, scale_factor=2, sampling='conv')
print(f"New number of tokens: {new_N}\n")


=== Uniform Every Downsampling ===
Original Tensor Shape: torch.Size([1, 16, 1])
Shape after 'uniform_every': torch.Size([1, 8, 1])
New number of tokens: 8

=== Average Pooling Downsampling ===
Original Tensor Shape: torch.Size([1, 16, 1])
Shape after Reshape and Permute: torch.Size([1, 1, 4, 4])
Final Shape after Reshape: torch.Size([1, 4, 1])
New number of tokens: 4

=== Uniform Downsampling ===
Original Tensor Shape: torch.Size([1, 16, 1])
Shape after Reshape and Permute: torch.Size([1, 1, 4, 4])
Shape after 'uniform' downsampling: torch.Size([1, 2, 2, 1])
Final Shape after Reshape: torch.Size([1, 4, 1])
New number of tokens: 4

=== Convolution Downsampling ===
Original Tensor Shape: torch.Size([1, 16, 1])
Shape after Reshape and Permute: torch.Size([1, 1, 4, 4])
Shape after 'conv' downsampling: torch.Size([1, 4, 1])
Final Shape after Reshape: torch.Size([1, 4, 1])
New number of tokens: 4
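The printed shapes show how many tokens each method keeps, but not which ones. The following standalone sketch assumes the same 4 × 4 grid holding the values 1 to 16 and a factor of 2, and prints the values each strategy keeps. It reimplements the sampling rules with plain tensor operations rather than calling `AttentionKVCompress`. It also places `F.interpolate(..., mode='nearest')` next to true average pooling (`F.avg_pool2d`) to show that nearest-neighbor downscaling selects tokens rather than averaging them.

```python
import torch
import torch.nn.functional as F

H, W, R = 4, 4, 2
tokens = torch.arange(1, 17, dtype=torch.float32).view(1, H * W, 1)  # (B, N, C)
grid = tokens.reshape(1, H, W, 1).permute(0, 3, 1, 2)               # (B, C, H, W)

results = {
    # Every R-th token of the flattened sequence: keeps all rows, skips columns
    "uniform_every": tokens[:, ::R].flatten(),
    # Top-left token of each R x R window
    "uniform": grid[:, :, ::R, ::R].flatten(),
    # Nearest-neighbor downscaling also picks one token per window
    "nearest": F.interpolate(grid, scale_factor=1 / R, mode="nearest").flatten(),
    # True average pooling: mean of each R x R window
    "avg_pool": F.avg_pool2d(grid, kernel_size=R, stride=R).flatten(),
}

for name, values in results.items():
    print(f"{name:>13}: {values.tolist()}")
```

On this grid, `uniform` and `nearest` both keep the tokens 1, 3, 9, and 11, while `avg_pool` returns the window means 3.5, 5.5, 11.5, and 13.5. Only averaging mixes information from every token in a window; the `conv` method goes one step further by learning the weights of that mix.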

