In [1]:
import torch
import torch.nn as nn

In [None]:

class EncoderLayer(nn.Module):
    """Transformer-style encoder layer: self-attention followed by an MLP.

    Each sub-layer is wrapped in a residual connection and followed by layer
    normalisation (post-norm).

    Args:
        embed_dim: Size of the last (feature) dimension of the input.
        num_heads: Number of attention heads passed to
            ``nn.MultiheadAttention``.

    Note:
        ``nn.MultiheadAttention`` is constructed with its defaults
        (``batch_first=False``), so batched input is laid out as
        ``(seq_len, batch, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()

        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

        # Position-wise feed-forward network with a 4x hidden expansion.
        self.feed = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*4),
            nn.SiLU(inplace=True),
            nn.Linear(embed_dim*4, embed_dim)
        )

        # NOTE(review): one LayerNorm (one set of affine parameters) is shared
        # by both sub-layers; the usual design has one per sub-layer. Kept
        # as-is so parameter names do not change - confirm this is intended.
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply self-attention and the feed-forward network to ``x``.

        Args:
            x: Input tensor whose last dimension is ``embed_dim``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # MultiheadAttention returns (attn_output, attn_weights); only the
        # output tensor takes part in the residual sum. The weights are not
        # used, so skip computing them.
        attn_out, _ = self.attention(x, x, x, need_weights=False)
        z = self.norm(x + attn_out)

        w = self.feed(z)
        out = self.norm(w + z)
        return out

class Encoder(nn.Module):
    """Stack of ``N`` ``EncoderLayer`` modules applied one after another.

    Args:
        embed_dim: Feature dimension forwarded to every ``EncoderLayer``.
        num_heads: Number of attention heads forwarded to every
            ``EncoderLayer``.
        N: Number of layers in the stack.
    """

    def __init__(self, embed_dim, num_heads, N):
        super().__init__()
        stack = nn.ModuleList()
        for _ in range(N):
            stack.append(EncoderLayer(embed_dim, num_heads))
        self.layers = stack

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden


class DecoderLayer(nn.Module):
    """Transformer-style decoder layer over object queries and image features.

    Three sub-layers are applied in order, each wrapped in a residual
    connection followed by layer normalisation (post-norm):

    1. self-attention over ``object``,
    2. cross-attention with ``object`` as query and ``image`` as key/value,
    3. a position-wise feed-forward network.

    Args:
        embed_dim: Size of the last (feature) dimension of both inputs.
        num_heads: Number of attention heads used by both attention modules.

    Note:
        Both ``nn.MultiheadAttention`` modules are constructed with their
        defaults (``batch_first=False``), so batched inputs are laid out as
        ``(seq_len, batch, embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()

        self.self_attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.cross_attention = nn.MultiheadAttention(embed_dim, num_heads)

        # Position-wise feed-forward network with a 4x hidden expansion.
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*4),
            nn.SiLU(inplace=True),
            nn.Linear(embed_dim*4, embed_dim)
        )

        # NOTE(review): one LayerNorm (one set of affine parameters) is shared
        # by all three sub-layers; the usual design has one per sub-layer.
        # Kept as-is so parameter names do not change - confirm intent.
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, object, image):
        """Refine ``object`` using self-attention and attention over ``image``.

        Args:
            object: Query tensor whose last dimension is ``embed_dim``. (The
                name shadows the builtin but is kept for keyword callers.)
            image: Tensor attended to by the cross-attention step; it is used
                as both key and value, and its last dimension is
                ``embed_dim``.

        Returns:
            A tensor with the same shape as ``object``.
        """
        # MultiheadAttention returns (attn_output, attn_weights); only the
        # output tensor takes part in the residual sums below.

        # Self attention
        z, _ = self.self_attention(object, object, object, need_weights=False)
        object = self.norm(z + object)

        # Cross attention: query, key, value
        z, _ = self.cross_attention(object, image, image, need_weights=False)
        object = self.norm(z + object)

        # Feed-forward
        z = self.feed_forward(object)
        object = self.norm(z + object)

        return object


class Decoder(nn.Module):
    """Stack of ``N`` ``DecoderLayer`` modules applied one after another.

    Args:
        embed_dim: Feature dimension forwarded to every ``DecoderLayer``.
        num_heads: Number of attention heads forwarded to every
            ``DecoderLayer``.
        N: Number of layers in the stack.
    """

    def __init__(self, embed_dim, num_heads, N):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(embed_dim, num_heads) for _ in range(N)])

    def forward(self, x, image=None):
        """Run ``x`` through every layer, passing ``image`` to each one.

        Args:
            x: Tensor fed to the first layer; each layer's output becomes the
                next layer's first argument.
            image: Second argument handed unchanged to every layer, called as
                ``layer(x, image)``. It defaults to ``None`` only so an empty
                stack can still be called with ``x`` alone.

        Returns:
            The output of the last layer, or ``x`` itself for an empty stack.

        Raises:
            TypeError: If the stack is non-empty and ``image`` is not given.
        """
        # Every layer's forward takes (object, image); fail early with a clear
        # message instead of a missing-argument error from inside the loop.
        if image is None and len(self.layers) > 0:
            raise TypeError(
                "Decoder.forward() missing required argument: 'image'"
            )
        for layer in self.layers:
            x = layer(x, image)
        return x
