In [1]:
# Codeblock 1
import torch
import torch.nn as nn

In [2]:
# Codeblock 2
BATCH_SIZE   = 1      #(1) Number of samples in each demo batch.

IMAGE_SIZE   = 384    #(2) Height and width (pixels) of the square input images.
IN_CHANNELS  = 3      #(3) Number of channels per input image.

SEQ_LENGTH   = 30     #(4) Number of tokens in each caption sequence.
VOCAB_SIZE   = 10000  #(5) Number of distinct token ids (embedding rows / output logits).

EMBED_DIM          = 768  #(6) Feature width shared by the encoder and the decoder.
PATCH_SIZE         = 16   #(7) Side length (pixels) of each square image patch.
NUM_PATCHES        = (IMAGE_SIZE//PATCH_SIZE) ** 2  #(8) Patches per image: (384 // 16) ** 2 = 576.
NUM_ENCODER_BLOCKS = 12   #(9) Number of stacked encoder blocks.
NUM_DECODER_BLOCKS = 4    #(10) Number of stacked decoder blocks.
NUM_HEADS          = 12   #(11) Attention heads in every multi-head attention layer.
HIDDEN_DIM         = EMBED_DIM * 4  #(12) Hidden width of each feed-forward network (3072).
DROP_PROB          = 0.1  #(13) Dropout probability used in attention and feed-forward layers.

In [5]:
# Codeblock 3
class Patcher(nn.Module):
    """Cut an image batch into non-overlapping patches and embed each patch.

    Input:  images of shape (batch, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE).
    Output: patch embeddings of shape (batch, NUM_PATCHES, EMBED_DIM).
    """

    def __init__(self):
        super().__init__()

        # Kernel size == stride, so the sliding windows never overlap.
        self.unfold = nn.Unfold(kernel_size=PATCH_SIZE, stride=PATCH_SIZE)

        # One flattened patch carries every channel of a PATCH_SIZE x PATCH_SIZE window.
        flattened_patch_dim = IN_CHANNELS * PATCH_SIZE * PATCH_SIZE
        self.linear_projection = nn.Linear(in_features=flattened_patch_dim,
                                           out_features=EMBED_DIM)

    def forward(self, images):
        """Return one EMBED_DIM-sized vector per image patch."""
        # nn.Unfold yields (batch, flattened_patch_dim, num_patches).
        patch_columns = self.unfold(images)

        # Put the patch axis second so the linear layer acts on each patch vector.
        patch_sequence = patch_columns.permute(0, 2, 1)

        return self.linear_projection(patch_sequence)

In [4]:
# Codeblock 4
# Shape check: push one random image batch through the patcher.
patcher  = Patcher()

# Random stand-in for a batch of images: (batch, channels, height, width).
images   = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
features = patcher(images)

images		: torch.Size([1, 3, 384, 384])
after unfold	: torch.Size([1, 768, 576])
after permute	: torch.Size([1, 576, 768])
after lin proj	: torch.Size([1, 576, 768])


In [8]:
# Codeblock 5
class LearnableEmbedding(nn.Module):
    """Trainable positional embedding with one EMBED_DIM vector per image patch."""

    def __init__(self):
        super().__init__()
        # Start from standard-normal noise; nn.Parameter enables gradients by default.
        initial_values = torch.randn(NUM_PATCHES, EMBED_DIM)
        self.learnable_embedding = nn.Parameter(initial_values)

    def forward(self):
        """Return the (NUM_PATCHES, EMBED_DIM) embedding table; takes no input."""
        return self.learnable_embedding

In [7]:
# Codeblock 6
# Shape check: the module takes no input and simply returns its parameter table.
learnable_embedding = LearnableEmbedding()

pos_embed = learnable_embedding()

learnable embedding	: torch.Size([576, 768])


In [11]:
# Codeblock 7a
class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder block.

    Applies multi-head self-attention and a two-layer GELU feed-forward
    network. Each sub-layer is wrapped in a residual connection followed by
    layer normalisation. Input and output both have shape
    (batch, sequence, EMBED_DIM).
    """

    def __init__(self):
        super().__init__()

        # batch_first=True makes the layer accept (batch, sequence, embed) tensors.
        self.self_attention = nn.MultiheadAttention(embed_dim=EMBED_DIM,
                                                    num_heads=NUM_HEADS,
                                                    batch_first=True,
                                                    dropout=DROP_PROB)

        # Normalises the output of the attention sub-layer.
        self.layer_norm_0 = nn.LayerNorm(EMBED_DIM)

        # Position-wise feed-forward network: expand to HIDDEN_DIM, project back.
        self.ffn = nn.Sequential(
            nn.Linear(in_features=EMBED_DIM, out_features=HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(p=DROP_PROB),
            nn.Linear(in_features=HIDDEN_DIM, out_features=EMBED_DIM),
        )

        # Normalises the output of the feed-forward sub-layer.
        self.layer_norm_1 = nn.LayerNorm(EMBED_DIM)

    def forward(self, features):
        """Run self-attention and the feed-forward network over `features`.

        Args:
            features: tensor of shape (batch, sequence, EMBED_DIM).

        Returns:
            Tensor with the same shape as `features`.
        """
        # --- Self-attention sub-layer ---
        residual = features

        # The attention map is never used here, so skip it with
        # need_weights=False; the second return value is then None.
        features, _ = self.self_attention(query=features,
                                          key=features,
                                          value=features,
                                          need_weights=False)
        features = self.layer_norm_0(features + residual)

        # --- Feed-forward sub-layer ---
        residual = features
        features = self.ffn(features)
        features = self.layer_norm_1(features + residual)

        return features

In [10]:
# Codeblock 8
# Shape check: run random patch features through a single encoder block.
encoder_block = EncoderBlock()

# Random stand-in for embedded patches; `features` is then rebound to the block output.
features = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)
features = encoder_block(features)

features & residual	: torch.Size([1, 576, 768])
after self attention	: torch.Size([1, 576, 768])
self attn weights	: torch.Size([1, 576, 576])
after norm		: torch.Size([1, 576, 768])

features & residual	: torch.Size([1, 576, 768])
after ffn		: torch.Size([1, 576, 768])
after norm		: torch.Size([1, 576, 768])


In [14]:
# Codeblock 9
class Encoder(nn.Module):
    """Image encoder: patch embedding, learned positions, then stacked encoder blocks.

    Maps images of shape (batch, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE) to
    features of shape (batch, NUM_PATCHES, EMBED_DIM).
    """

    def __init__(self):
        super().__init__()
        self.patcher = Patcher()
        self.learnable_embedding = LearnableEmbedding()

        # NUM_ENCODER_BLOCKS independent blocks, registered via ModuleList.
        blocks = [EncoderBlock() for _ in range(NUM_ENCODER_BLOCKS)]
        self.encoder_blocks = nn.ModuleList(blocks)

    def forward(self, images):
        """Encode `images` into one feature vector per patch."""
        patch_tokens = self.patcher(images)

        # The (NUM_PATCHES, EMBED_DIM) table broadcasts across the batch axis.
        hidden = patch_tokens + self.learnable_embedding()

        # Each block consumes the previous block's output.
        for block in self.encoder_blocks:
            hidden = block(hidden)

        return hidden

In [13]:
# Codeblock 10
# Shape check: run one random image batch through the full hand-written encoder.
encoder = Encoder()

images = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
features = encoder(images)

images			: torch.Size([1, 3, 384, 384])
after patcher		: torch.Size([1, 576, 768])
after learn embed	: torch.Size([1, 576, 768])
after encoder block #0	: torch.Size([1, 576, 768])
after encoder block #1	: torch.Size([1, 576, 768])
after encoder block #2	: torch.Size([1, 576, 768])
after encoder block #3	: torch.Size([1, 576, 768])
after encoder block #4	: torch.Size([1, 576, 768])
after encoder block #5	: torch.Size([1, 576, 768])
after encoder block #6	: torch.Size([1, 576, 768])
after encoder block #7	: torch.Size([1, 576, 768])
after encoder block #8	: torch.Size([1, 576, 768])
after encoder block #9	: torch.Size([1, 576, 768])
after encoder block #10	: torch.Size([1, 576, 768])
after encoder block #11	: torch.Size([1, 576, 768])


In [29]:
# Codeblock 11
class EncoderTorch(nn.Module):
    """Image encoder built on PyTorch's stock Transformer encoder layers.

    Intended as a drop-in alternative to `Encoder`: same patch embedding and
    learned positional embedding, with the hand-written blocks replaced by
    `nn.TransformerEncoder`. Maps images of shape
    (batch, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE) to features of shape
    (batch, NUM_PATCHES, EMBED_DIM).
    """

    def __init__(self):
        super().__init__()
        self.patcher = Patcher()
        self.learnable_embedding = LearnableEmbedding()

        # Template layer. activation="gelu" matches the GELU feed-forward of
        # the hand-written EncoderBlock; the library default would be ReLU.
        encoder_block = nn.TransformerEncoderLayer(d_model=EMBED_DIM,
                                                   nhead=NUM_HEADS,
                                                   dim_feedforward=HIDDEN_DIM,
                                                   dropout=DROP_PROB,
                                                   activation="gelu",
                                                   batch_first=True)

        # nn.TransformerEncoder stacks NUM_ENCODER_BLOCKS copies of the template.
        self.encoder_blocks = nn.TransformerEncoder(encoder_layer=encoder_block,
                                                    num_layers=NUM_ENCODER_BLOCKS)

    def forward(self, images):
        """Encode `images` into one feature vector per patch."""
        features = self.patcher(images)

        # The (NUM_PATCHES, EMBED_DIM) table broadcasts across the batch axis.
        features = features + self.learnable_embedding()

        # A single call runs the whole stack of encoder layers.
        features = self.encoder_blocks(features)

        return features

In [16]:
# Codeblock 12
# Shape check: same random-image test as above, using the nn.TransformerEncoder variant.
encoder_torch = EncoderTorch()

images = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
features = encoder_torch(images)

images			: torch.Size([1, 3, 384, 384])
after patcher		: torch.Size([1, 576, 768])
after learn embed	: torch.Size([1, 576, 768])
after encoder blocks	: torch.Size([1, 576, 768])


In [19]:
# Codeblock 13
class SinusoidalEmbedding(nn.Module):
    """Fixed (non-trainable) sine/cosine positional encoding.

    Even feature indices hold sin(pos / 10000^(i / embed_dim)) and odd indices
    hold the matching cosine, where `i` is the even index of the pair.
    """

    def forward(self, seq_length=None, embed_dim=None, device=None):
        """Build the positional-encoding table.

        Args:
            seq_length: number of positions; defaults to SEQ_LENGTH.
            embed_dim: width of each encoding vector; defaults to EMBED_DIM.
            device: device to create the table on; defaults to PyTorch's
                default device.

        Returns:
            Tensor of shape (seq_length, embed_dim).
        """
        # Read the notebook-level configuration only when no override is given.
        if seq_length is None:
            seq_length = SEQ_LENGTH
        if embed_dim is None:
            embed_dim = EMBED_DIM

        # Column vector of positions: (seq_length, 1).
        pos = torch.arange(seq_length, device=device).reshape(seq_length, 1)

        # One wavelength per sin/cos pair, indexed by the even feature index.
        i = torch.arange(0, embed_dim, 2, device=device)
        denominator = torch.pow(10000, i / embed_dim)

        # Broadcasting gives (seq_length, number_of_pairs).
        angles = pos / denominator
        even_pos_embed = torch.sin(angles)
        odd_pos_embed = torch.cos(angles)

        # Stack on a new last axis, then flatten to interleave sin, cos, sin, cos, ...
        stacked = torch.stack([even_pos_embed, odd_pos_embed], dim=2)
        pos_embed = torch.flatten(stacked, start_dim=1, end_dim=2)

        # An odd embed_dim produces one surplus column; trim it (no-op when even).
        return pos_embed[:, :embed_dim]

In [18]:
# Codeblock 14
# Shape check: build the fixed positional-encoding table (the call takes no arguments).
sinusoidal_embedding = SinusoidalEmbedding()
pos_embed = sinusoidal_embedding()

pos		: torch.Size([30, 1])
denominator	: torch.Size([384])
even_pos_embed	: torch.Size([30, 384])
stacked		: torch.Size([30, 384, 2])
pos_embed	: torch.Size([30, 768])


In [20]:
# Codeblock 15
def create_mask(seq_length):
    """Build an additive look-ahead (causal) attention mask.

    Args:
        seq_length: number of positions in the sequence.

    Returns:
        A (seq_length, seq_length) float tensor holding 0 on and below the
        main diagonal and -inf strictly above it.
    """
    # Start from an all -inf matrix and keep only the part above the diagonal;
    # torch.triu fills everything else with 0.
    blocked = torch.full((seq_length, seq_length), -float('inf'))
    return torch.triu(blocked, diagonal=1)

In [21]:
# Codeblock 16
# Display a small 7x7 mask; the bare last expression is rendered as the cell output.
mask_example = create_mask(seq_length=7)
mask_example

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0.]])

In [24]:
# Codeblock 17a
class DecoderBlock(nn.Module):
    """One post-norm Transformer decoder block.

    Applies, in order: masked self-attention over the captions,
    cross-attention from the captions to the image features, and a two-layer
    GELU feed-forward network. Each sub-layer is wrapped in a residual
    connection followed by layer normalisation.
    """

    def __init__(self):
        super().__init__()

        # Self-attention over the caption tokens (masked in forward()).
        self.self_attention = nn.MultiheadAttention(embed_dim=EMBED_DIM,
                                                    num_heads=NUM_HEADS,
                                                    batch_first=True,
                                                    dropout=DROP_PROB)
        # Normalises the output of the self-attention sub-layer.
        self.layer_norm_0 = nn.LayerNorm(EMBED_DIM)

        # Cross-attention: caption queries attend to image-feature keys/values.
        self.cross_attention = nn.MultiheadAttention(embed_dim=EMBED_DIM,
                                                     num_heads=NUM_HEADS,
                                                     batch_first=True,
                                                     dropout=DROP_PROB)

        # Normalises the output of the cross-attention sub-layer.
        self.layer_norm_1 = nn.LayerNorm(EMBED_DIM)

        # Position-wise feed-forward network: expand to HIDDEN_DIM, project back.
        self.ffn = nn.Sequential(
            nn.Linear(in_features=EMBED_DIM, out_features=HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(p=DROP_PROB),
            nn.Linear(in_features=HIDDEN_DIM, out_features=EMBED_DIM),
        )

        # Normalises the output of the feed-forward sub-layer.
        self.layer_norm_2 = nn.LayerNorm(EMBED_DIM)

    def forward(self, features, captions, attn_mask):
        """Update the caption representations using the image features.

        Args:
            features: encoder output, shape (batch, num_patches, EMBED_DIM).
            captions: caption embeddings, shape (batch, seq, EMBED_DIM).
            attn_mask: additive (seq, seq) mask for the self-attention step,
                e.g. the look-ahead mask from `create_mask`.

        Returns:
            Tensor with the same shape as `captions`.
        """
        # --- Masked self-attention sub-layer ---
        residual = captions

        # Neither attention map is used by this block, so skip both with
        # need_weights=False; the second return value is then None.
        captions, _ = self.self_attention(query=captions,
                                          key=captions,
                                          value=captions,
                                          attn_mask=attn_mask,
                                          need_weights=False)
        captions = self.layer_norm_0(captions + residual)

        # --- Cross-attention sub-layer (no mask: every patch is visible) ---
        residual = captions
        captions, _ = self.cross_attention(query=captions,
                                           key=features,
                                           value=features,
                                           need_weights=False)
        captions = self.layer_norm_1(captions + residual)

        # --- Feed-forward sub-layer ---
        residual = captions
        captions = self.ffn(captions)
        captions = self.layer_norm_2(captions + residual)

        return captions

In [23]:
# Codeblock 18
# Shape check: run random inputs through a single decoder block.
decoder_block = DecoderBlock()

features = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)  #(1) Random stand-in for encoder output.
captions = torch.randn(BATCH_SIZE, SEQ_LENGTH, EMBED_DIM)   #(2) Random stand-in for embedded captions.
look_ahead_mask = create_mask(seq_length=SEQ_LENGTH)  #(3) Causal mask; reused by later demo cells.

# `captions` is rebound to the block output.
captions = decoder_block(features, captions, look_ahead_mask)

attn_mask		: torch.Size([30, 30])
captions & residual	: torch.Size([1, 30, 768])
after self attention	: torch.Size([1, 30, 768])
self attn weights	: torch.Size([1, 30, 30])
after norm		: torch.Size([1, 30, 768])

features		: torch.Size([1, 576, 768])
captions & residual	: torch.Size([1, 30, 768])
after cross attention	: torch.Size([1, 30, 768])
cross attn weights	: torch.Size([1, 30, 576])
after norm		: torch.Size([1, 30, 768])

captions & residual	: torch.Size([1, 30, 768])
after ffn		: torch.Size([1, 30, 768])
after norm		: torch.Size([1, 30, 768])


In [30]:
# Codeblock 19a
class Decoder(nn.Module):
    """Caption decoder: token embedding, sinusoidal positions, decoder blocks, vocab head.

    Maps token ids of shape (batch, seq) plus encoder features of shape
    (batch, num_patches, EMBED_DIM) to logits of shape (batch, seq, VOCAB_SIZE).
    """

    def __init__(self):
        super().__init__()

        # Lookup table turning token ids into EMBED_DIM vectors.
        self.embedding = nn.Embedding(num_embeddings=VOCAB_SIZE,
                                      embedding_dim=EMBED_DIM)

        # Fixed (non-trainable) positional encoding.
        self.sinusoidal_embedding = SinusoidalEmbedding()

        # NUM_DECODER_BLOCKS independent blocks, registered via ModuleList.
        self.decoder_blocks = nn.ModuleList(DecoderBlock() for _ in range(NUM_DECODER_BLOCKS))

        # Projects each position onto the vocabulary.
        self.linear = nn.Linear(in_features=EMBED_DIM,
                                out_features=VOCAB_SIZE)

    def _positional_encoding(self, embedded):
        """Return the positional rows matching `embedded`'s sequence length, device and dtype."""
        # size(-2) is the sequence axis for batched and unbatched input alike.
        seq_length = embedded.size(-2)
        pos_embed = self.sinusoidal_embedding()

        if seq_length > pos_embed.size(0):
            raise ValueError(
                f"caption length {seq_length} exceeds the positional table "
                f"length {pos_embed.size(0)}"
            )

        # Slice explicitly: adding the full table would silently broadcast a
        # length-1 caption to every position and fail for other short lengths.
        pos_embed = pos_embed[:seq_length]

        # The table is created on the default device, so follow the embeddings.
        return pos_embed.to(device=embedded.device, dtype=embedded.dtype)

    def forward(self, features, captions, attn_mask):
        """Predict vocabulary logits for every caption position.

        Args:
            features: encoder output, shape (batch, num_patches, EMBED_DIM).
            captions: integer token ids, shape (batch, seq); seq must not
                exceed the positional table length.
            attn_mask: additive (seq, seq) look-ahead mask for self-attention.

        Returns:
            Logits of shape (batch, seq, VOCAB_SIZE).

        Raises:
            ValueError: if the caption is longer than the positional table.
        """
        captions = self.embedding(captions)
        captions = captions + self._positional_encoding(captions)

        # Every block attends to the same encoder features.
        for decoder_block in self.decoder_blocks:
            captions = decoder_block(features, captions, attn_mask)

        captions = self.linear(captions)

        return captions

In [26]:
# Codeblock 20
# Shape check: run random features and random token ids through the hand-written decoder.
decoder = Decoder()

features = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)
captions = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))  #(1) Random integer token ids.

# NOTE: `look_ahead_mask` is defined in the DecoderBlock demo cell above, so that
# cell must run first. `captions` is rebound from token ids to the decoder output.
captions = decoder(features, captions, look_ahead_mask)

features		: torch.Size([1, 576, 768])
captions		: torch.Size([1, 30])
after embedding		: torch.Size([1, 30, 768])
after sin embed		: torch.Size([1, 30, 768])
after decoder block #0	: torch.Size([1, 30, 768])
after decoder block #1	: torch.Size([1, 30, 768])
after decoder block #2	: torch.Size([1, 30, 768])
after decoder block #3	: torch.Size([1, 30, 768])
after linear		: torch.Size([1, 30, 10000])


In [31]:
# Codeblock 21
class DecoderTorch(nn.Module):
    """Caption decoder built on PyTorch's stock Transformer decoder layers.

    Intended as a drop-in alternative to `Decoder`: same token embedding,
    sinusoidal positions and vocabulary head, with the hand-written blocks
    replaced by `nn.TransformerDecoder`. Maps token ids of shape (batch, seq)
    plus encoder features of shape (batch, num_patches, EMBED_DIM) to logits
    of shape (batch, seq, VOCAB_SIZE).
    """

    def __init__(self):
        super().__init__()
        # Lookup table turning token ids into EMBED_DIM vectors.
        self.embedding = nn.Embedding(num_embeddings=VOCAB_SIZE,
                                      embedding_dim=EMBED_DIM)

        # Fixed (non-trainable) positional encoding.
        self.sinusoidal_embedding = SinusoidalEmbedding()

        # Template layer. activation="gelu" matches the GELU feed-forward of
        # the hand-written DecoderBlock; the library default would be ReLU.
        decoder_block = nn.TransformerDecoderLayer(d_model=EMBED_DIM,
                                                   nhead=NUM_HEADS,
                                                   dim_feedforward=HIDDEN_DIM,
                                                   dropout=DROP_PROB,
                                                   activation="gelu",
                                                   batch_first=True)

        # nn.TransformerDecoder stacks NUM_DECODER_BLOCKS copies of the template.
        self.decoder_blocks = nn.TransformerDecoder(decoder_layer=decoder_block,
                                                    num_layers=NUM_DECODER_BLOCKS)

        # Projects each position onto the vocabulary.
        self.linear = nn.Linear(in_features=EMBED_DIM,
                                out_features=VOCAB_SIZE)

    def _positional_encoding(self, embedded):
        """Return the positional rows matching `embedded`'s sequence length, device and dtype."""
        # size(-2) is the sequence axis for batched and unbatched input alike.
        seq_length = embedded.size(-2)
        pos_embed = self.sinusoidal_embedding()

        if seq_length > pos_embed.size(0):
            raise ValueError(
                f"caption length {seq_length} exceeds the positional table "
                f"length {pos_embed.size(0)}"
            )

        # Slice explicitly: adding the full table would silently broadcast a
        # length-1 caption to every position and fail for other short lengths.
        pos_embed = pos_embed[:seq_length]

        # The table is created on the default device, so follow the embeddings.
        return pos_embed.to(device=embedded.device, dtype=embedded.dtype)

    def forward(self, features, captions, tgt_mask):
        """Predict vocabulary logits for every caption position.

        Args:
            features: encoder output, shape (batch, num_patches, EMBED_DIM);
                passed to the decoder stack as `memory`.
            captions: integer token ids, shape (batch, seq); seq must not
                exceed the positional table length.
            tgt_mask: additive (seq, seq) look-ahead mask for self-attention.

        Returns:
            Logits of shape (batch, seq, VOCAB_SIZE).

        Raises:
            ValueError: if the caption is longer than the positional table.
        """
        captions = self.embedding(captions)
        captions = captions + self._positional_encoding(captions)

        # A single call runs the whole stack of decoder layers.
        captions = self.decoder_blocks(tgt=captions,
                                       memory=features,
                                       tgt_mask=tgt_mask)

        captions = self.linear(captions)

        return captions

In [28]:
# Codeblock 22
# Shape check: same random-input test as above, using the nn.TransformerDecoder variant.
decoder_torch = DecoderTorch()

features = torch.randn(BATCH_SIZE, NUM_PATCHES, EMBED_DIM)
captions = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))

# NOTE: `look_ahead_mask` is defined in the DecoderBlock demo cell above, so that
# cell must run first. `captions` is rebound from token ids to the decoder output.
captions = decoder_torch(features, captions, look_ahead_mask)

features		: torch.Size([1, 576, 768])
captions		: torch.Size([1, 30])
after embedding		: torch.Size([1, 30, 768])
after sin embed		: torch.Size([1, 30, 768])
after decoder blocks	: torch.Size([1, 30, 768])
after linear		: torch.Size([1, 30, 10000])


In [32]:
# Codeblock 23
class EncoderDecoder(nn.Module):
    """Full captioning model: image encoder followed by caption decoder.

    forward() prints the tensor shape at every stage, so each call traces
    how the data flows through the model.
    """

    def __init__(self):
        super().__init__()
        # EncoderTorch() / DecoderTorch() can be swapped in here to use the
        # built-in PyTorch Transformer layers instead.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, images, captions, look_ahead_mask):
        """Encode `images`, then decode `captions` against the encoded features.

        Args:
            images: image batch, shape (batch, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE).
            captions: integer token ids, shape (batch, seq).
            look_ahead_mask: additive (seq, seq) mask for decoder self-attention.

        Returns:
            Logits of shape (batch, seq, VOCAB_SIZE).
        """
        print(f"images\t\t\t: {images.shape}")
        print(f"captions\t\t: {captions.shape}")

        encoded = self.encoder(images)
        print(f"after encoder\t\t: {encoded.shape}")

        # The mask is passed positionally because the two decoder variants
        # name that parameter differently.
        logits = self.decoder(encoded, captions, look_ahead_mask)
        print(f"after decoder\t\t: {logits.shape}")

        return logits

In [33]:
# Codeblock 24
# End-to-end shape check: random images and random token ids through the whole model.
encoder_decoder = EncoderDecoder()

images = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)  #(1) Random stand-in for an image batch.
captions = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))  #(2) Random integer token ids.

# NOTE: `look_ahead_mask` is defined in the DecoderBlock demo cell above, so that
# cell must run first. `captions` is rebound from token ids to the model output.
captions = encoder_decoder(images, captions, look_ahead_mask)

images			: torch.Size([1, 3, 384, 384])
captions		: torch.Size([1, 30])
after encoder		: torch.Size([1, 576, 768])
after decoder		: torch.Size([1, 30, 10000])
