In [3]:

import torch
import torch.nn as nn

# Self-Attention Module
class SelfAttention(nn.Module):
    """SAGAN-style self-attention over the spatial positions of a feature map.

    Every spatial position of the input attends to every other position. The
    attended features are scaled by a learnable scalar ``gamma`` and added back
    to the input as a residual. ``gamma`` starts at 0, so a freshly built module
    returns its input unchanged.

    Args:
        in_dim: Number of channels ``C`` of the input feature map.
        reduction: Channel reduction factor for the query/key projections.
            The default of 8 matches the original fixed ``in_dim // 8``.
            The projection width is clamped to at least 1 channel.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, in_dim, reduction=8):
        super().__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # Clamp to one channel: with in_dim < reduction a plain floor-division
        # yields 0-channel projections, and forward() then fails in .view().
        qk_dim = max(1, in_dim // reduction)
        self.query = nn.Conv2d(in_dim, qk_dim, 1)
        self.key = nn.Conv2d(in_dim, qk_dim, 1)
        self.value = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, C, H, W).

        Returns a tensor of the same shape: ``gamma * attended + x``.
        """
        B, C, H, W = x.size()
        N = H * W  # number of spatial positions
        q = self.query(x).view(B, -1, N).permute(0, 2, 1)  # (B, N, C')
        k = self.key(x).view(B, -1, N)  # (B, C', N)
        # (B, N, N); row i holds the softmax weights of query position i.
        attn = torch.softmax(torch.bmm(q, k), dim=-1)
        v = self.value(x).view(B, -1, N)  # (B, C, N)
        # out[:, :, i] = sum_j v[:, :, j] * attn[:, i, j]
        out = torch.bmm(v, attn.permute(0, 2, 1)).view(B, C, H, W)
        return self.gamma * out + x

# Simple Generator using Attention
class Generator(nn.Module):
    """Small DCGAN-style generator with a self-attention layer.

    Pipeline: transposed conv -> BatchNorm -> ReLU -> SelfAttention ->
    transposed conv -> Tanh. For a 1x1 latent input the first transposed conv
    produces a 4x4 map and the second doubles it to 8x8. ``Tanh`` bounds the
    outputs to [-1, 1].

    Args:
        nz: Latent vector size, i.e. the number of input channels. Default 100.
        ngf: Feature channels after the first upsampling stage; this is also
            the width of the self-attention layer. Default 64.
        nc: Number of output image channels. Default 3.
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf, 4, 1, 0),  # (B, nz, 1, 1) -> (B, ngf, 4, 4)
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            SelfAttention(ngf),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1),  # (B, ngf, 4, 4) -> (B, nc, 8, 8)
            nn.Tanh(),
        )

    def forward(self, input):
        """Map latent noise of shape (B, nz, H, W) to generated images.

        With H = W = 1 the output has shape (B, nc, 8, 8).
        """
        return self.main(input)

# Test Generator output
# Smoke test: push a batch of 4 latent vectors through a freshly initialised
# Generator and print the resulting tensor shape.
# NOTE: no .eval() is called, so the module is in its default training mode.
# BatchNorm therefore normalises with this batch's statistics and updates its
# running stats during the call.
# Statement order matters for reproducibility: the noise is drawn from the
# global RNG before the Generator's weights are initialised.
noise = torch.randn(4, 100, 1, 1)  # (batch=4, 100 channels, 1, 1); 100 matches the Generator's first layer
G = Generator()
output = G(noise)  # builds an autograd graph; wrap in torch.no_grad() if only the shape is needed
print("Output shape:", output.shape)


Output shape: torch.Size([4, 3, 8, 8])
