# In [3]:
from pathlib import Path 
# Show the parent directory of the current working directory.
print(str(Path.cwd().parents[0]))

# Output: /home/amzad/Desktop/stable_diffusion


# In [10]:
import torch
from torch import nn
import math

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    A single fused Linear produces queries, keys and values; each head attends
    independently and the heads are merged back through an output Linear.

    Args:
        n_heads: number of attention heads.
        d_embed: feature size of each token; split across heads with integer
            division, so it should be a multiple of ``n_heads``.
        in_proj_bias: whether the fused q/k/v projection has a bias.
        out_proj_bias: whether the output projection has a bias.
    """

    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # One Linear yields q, k and v concatenated along the last dimension.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        """Let every token of ``x`` attend to the tokens of ``x``.

        Args:
            x: tensor of shape (batch, seq_len, d_embed).
            causal_mask: if True, token i may only attend to tokens j <= i.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, seq_len, _ = x.shape
        split_heads = (batch, seq_len, self.n_heads, self.d_head)

        # (B, S, 3*D) -> three (B, S, D) chunks -> each (B, H, S, D/H).
        projected = self.in_proj(x)
        queries, keys, values = (
            chunk.view(split_heads).transpose(1, 2)
            for chunk in projected.chunk(3, dim=-1)
        )

        scores = queries @ keys.transpose(-1, -2)  # (B, H, S, S)
        if causal_mask:
            # Entries strictly above the diagonal are "future" tokens.
            future = torch.ones_like(scores, dtype=torch.bool).triu(1)
            scores = scores.masked_fill(future, -torch.inf)
        scores = scores / math.sqrt(self.d_head)
        probs = nn.functional.softmax(scores, dim=-1)

        # (B, H, S, D/H) -> (B, S, H, D/H) -> (B, S, D)
        merged = (probs @ values).transpose(1, 2).reshape(x.shape)
        return self.out_proj(merged)


class VAE_attention(nn.Module):
    """
    VAE_attention is an attention mechanism used in Variational Autoencoders (VAEs).
    It applies Group Normalization followed by a multi-head attention mechanism.
    """

    def __init__(self, channels, num_heads):
        """
        Initializes the VAE_attention.

        Parameters:
        channels (int): Number of input and output channels.
        num_heads (int): Number of attention heads.
        """
        super(VAE_attention, self).__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.in_proj = nn.Linear(channels, 3 * channels, bias=True)
        self.out_proj = nn.Linear(channels, channels, bias=False)
        self.attention = nn.MultiheadAttention(channels, num_heads)

    def forward(self, x):
        """
        Forward pass of the VAE_attention.

        Parameters:
        x (torch.Tensor): Input tensor with shape (Batch_Size, Features, Height, Width).

        Returns:
        torch.Tensor: Output tensor after applying attention mechanism.
        """
        x = self.groupnorm(x)  # Apply Group Normalization

        n, c, h, w = x.shape  # Get the shape of the input tensor
        key, query, value = self.in_proj(x).chunk(3, dim=-1)  
        print("key.shape:", key.shape )
        print("query.shape:", query.shape )
        print("value.shape:", value.shape )

        key = key.view(n, h * w, c).transpose(1, 2)  # Reshape key tensor
        query = query.view(n, h * w, c).transpose(1, 2)
        value = value.view(n, h * w, c).transpose(1, 2)
        out, _ = self.attention(query, key, value)  
        out = out.transpose(1, 2).view(n, c, h, w)  
        out = self.out_proj(out)  
        return out 

def test_vae_attention():
    """Smoke test: VAE_attention must return a tensor shaped like its input."""
    torch.manual_seed(0)
    n_batch, n_channels, n_heads = 2, 64, 2
    n_rows, n_cols = 64, 4  # height, width
    sample = torch.randn(n_batch, n_channels, n_rows, n_cols)

    module = VAE_attention(n_channels, n_heads)
    result = module(sample)

    print("Output:\n", result)
    assert result.shape == sample.shape, "Output shape is incorrect!"


if __name__ == "__main__":
    test_vae_attention()



# Output: RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x4 and 64x192)