In [None]:
class ImageClassificationModel(torch.nn.Module):
    """Encoder/decoder model that maps an image to a sequence of token ids.

    The image is passed through ``EncoderCNN`` and the resulting features,
    together with the tokens produced so far, are passed to ``DecoderRNN``.
    The decoder output is indexed as ``[batch, sequence, vocab]`` and the
    scores at the last sequence position are used to pick the next token.

    None of the generation helpers has a notion of an end-of-sequence token:
    each of them always emits exactly ``max_seq_length`` tokens after the
    start token, so the returned tensor has shape
    ``[batch, max_seq_length + 1]``.
    """

    # Token id every generated sequence starts with.
    START_TOKEN_ID = 1

    def __init__(self, cnn_model, image_size, vocab_size, max_seq_length, dropout=0.1, embed_size=512):
        """Build the encoder and decoder sub-modules.

        Args:
            cnn_model: Backbone handed to both ``EncoderCNN`` and ``DecoderRNN``.
            image_size: Forwarded to ``DecoderRNN``.
            vocab_size: Forwarded to ``DecoderRNN``.
            max_seq_length: Forwarded to ``DecoderRNN``.
            dropout: Forwarded to ``DecoderRNN`` as ``dropout``.
            embed_size: Forwarded to ``DecoderRNN``.
        """
        super().__init__()
        self.encoder = EncoderCNN(cnn_model)
        self.decoder = DecoderRNN(cnn_model, image_size, vocab_size, embed_size, max_seq_length, dropout=dropout)

    def forward(self, image, captions):
        """Encode ``image`` and return the decoder output for ``captions``."""
        image_features = self.encoder(image)
        return self.decoder(image_features, captions)

    def _start_tokens(self, batch_size, device):
        """Return a ``[batch_size, 1]`` long tensor filled with the start token."""
        return torch.full((batch_size, 1), self.START_TOKEN_ID, dtype=torch.long, device=device)

    def _greedy_decode(self, image, max_seq_length):
        """Shared greedy loop: append the highest-scoring token at every step."""
        # Only integer token ids are returned, so no autograd graph is needed.
        with torch.no_grad():
            image_features = self.encoder(image)
            tgt = self._start_tokens(image_features.shape[0], image.device)
            for _ in range(max_seq_length):
                step_scores = self.decoder(image_features, tgt)[:, -1, :]
                next_token = step_scores.argmax(dim=-1, keepdim=True)
                tgt = torch.cat((tgt, next_token), dim=1)
        return tgt

    def generate_caption(self, image, max_seq_length):
        """Generate token ids with greedy decoding.

        Args:
            image: Batch of images accepted by the encoder.
            max_seq_length: Number of tokens to generate after the start token.

        Returns:
            Long tensor of shape ``[batch, max_seq_length + 1]``.
        """
        return self._greedy_decode(image, max_seq_length)

    def generate_caption_beam_search(self, image, max_seq_length, beam_size=3):
        """Generate token ids with beam search.

        ``beam_size`` hypotheses are kept per image; at every step each one is
        extended by every vocabulary entry and the ``beam_size`` candidates
        with the highest cumulative log-probability survive. The best
        hypothesis per image is returned. ``beam_size=1`` is equivalent to
        greedy decoding.

        Args:
            image: Batch of images accepted by the encoder.
            max_seq_length: Number of tokens to generate after the start token.
            beam_size: Number of hypotheses kept per image (must be >= 1).

        Returns:
            Long tensor of shape ``[batch, max_seq_length + 1]``.

        Raises:
            ValueError: If ``beam_size`` is smaller than 1.
        """
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")

        with torch.no_grad():
            image_features = self.encoder(image)
            batch_size = image_features.shape[0]
            device = image.device

            # One copy of the image features per hypothesis: rows
            # [b * beam_size, (b + 1) * beam_size) belong to image b.
            beam_features = image_features.repeat_interleave(beam_size, dim=0)
            sequences = self._start_tokens(batch_size * beam_size, device)

            # All hypotheses start out identical, so only the first one per
            # image is "alive"; otherwise the first step would select the
            # same token beam_size times.
            scores = torch.full((batch_size, beam_size), float("-inf"), device=device)
            scores[:, 0] = 0.0
            row_offsets = torch.arange(batch_size, device=device).unsqueeze(1) * beam_size

            for _ in range(max_seq_length):
                step_scores = self.decoder(beam_features, sequences)[:, -1, :]
                # log_softmax normalises logits and leaves log-probabilities
                # unchanged (assumes the decoder emits one of the two).
                log_probs = torch.log_softmax(step_scores.float(), dim=-1)
                vocab_size = log_probs.shape[-1]

                candidates = (scores.view(-1, 1) + log_probs).view(batch_size, beam_size * vocab_size)
                # topk returns candidates sorted best-first.
                scores, flat_index = candidates.topk(beam_size, dim=1)
                parent_beam = torch.div(flat_index, vocab_size, rounding_mode="floor")
                next_token = flat_index % vocab_size

                parent_rows = (parent_beam + row_offsets).view(-1)
                sequences = torch.cat((sequences[parent_rows], next_token.view(-1, 1)), dim=1)

            # Hypotheses are kept sorted by score, so beam 0 is the best one.
            return sequences.view(batch_size, beam_size, sequences.shape[1])[:, 0, :]

    def generate_caption_greedy_search(self, image, max_seq_length):
        """Generate token ids with greedy decoding (same as ``generate_caption``).

        Args:
            image: Batch of images accepted by the encoder.
            max_seq_length: Number of tokens to generate after the start token.

        Returns:
            Long tensor of shape ``[batch, max_seq_length + 1]``.
        """
        return self._greedy_decode(image, max_seq_length)

    def generate_caption_sampling(self, image, max_seq_length, temperature=1.0):
        """Generate token ids by sampling from the predicted distribution.

        At every step the scores of the last position are divided by
        ``temperature``, normalised with softmax and one token is drawn per
        image. The result is random; seed with ``torch.manual_seed`` for
        reproducibility.

        Args:
            image: Batch of images accepted by the encoder.
            max_seq_length: Number of tokens to generate after the start token.
            temperature: Positive scaling factor; values below 1 sharpen the
                distribution, values above 1 flatten it.

        Returns:
            Long tensor of shape ``[batch, max_seq_length + 1]``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        with torch.no_grad():
            image_features = self.encoder(image)
            tgt = self._start_tokens(image_features.shape[0], image.device)
            for _ in range(max_seq_length):
                step_scores = self.decoder(image_features, tgt)[:, -1, :]
                # Assumes the decoder emits logits or log-probabilities.
                probs = torch.softmax(step_scores.float() / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                tgt = torch.cat((tgt, next_token.to(tgt.device)), dim=1)
        return tgt
    
    
    # class DecoderRNN(torch.nn.Module):
#     def __init__(self, cnn_model, image_size, vocab_size, embed_size, max_seq_length, dropout=0.1, num_layers=6):
#         super(DecoderRNN, self).__init__()
#         self.token_embed = torch.nn.Embedding(vocab_size, embed_size) # Token embedding
#         self.position_embed = PositionalEncoding(embed_size, dropout=dropout, max_len=max_seq_length)
#         self.decoder = torch.nn.Transformer(d_model=embed_size, nhead=8, num_encoder_layers=num_layers, num_decoder_layers=num_layers, dropout=dropout)
#         self.fc = torch.nn.Linear(embed_size, vocab_size) # Fully connected layer to output the predicted token
#         self.dropout = torch.nn.Dropout(dropout)
#         
#         # Image features from CNN Encoder
#         self.encoder = cnn_model
#         
#         
#     def forward(self, image, captions, tgt_mask=None):
#         # image = self.embed(image)
#         batch_size, seq_len = captions.shape[0], captions.shape[1]
#         
#         # Get Image Features from Encoder
#         with torch.no_grad():
#             image_features = self.encoder(image)
#         image_features = image_features.view(image_features.shape[0], -1) # [batch_size, flattened_features]
#         
#         # Decoder
#         tgt = self.token_embed(captions) * torch.sqrt(torch.tensor([self.token_embed.embedding_dim]).to(image.device))
#         # Add Positional Encoding to the token embeddings
#         tgt = self.position_embed(tgt)
#         tgt = tgt.permute(1, 0, 2) # [seq_len, batch_size, embed_size]
#         encoder_memory = self.encoder(image)
#         tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
#         output = self.decoder(tgt, encoder_memory, tgt_mask=tgt_mask)
#         output = output.permute(1, 0, 2) # [batch_size, seq_len, embed_size]
#         output = self.fc(output)
#         return output, encoder_memory
    