In [137]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
from torchvision.datasets import MNIST

In [149]:
class Attention(nn.Module):
    """Soft attention over a set of encoder positions.

    Encoder features and the decoder hidden state are each projected into a
    shared ``attention_dim`` space, summed, passed through a ReLU and reduced
    to one score per position. The scores are normalised with a softmax over
    the position axis and used to average the encoder features.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # Projections of the two inputs into the shared scoring space.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Reduces each combined vector to a single scalar score.
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        # dim=1 is the position axis of the (batch, positions) score matrix.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Return the attention-weighted encoding and the attention weights.

        Args:
            encoder_out: encoder features, expected as
                (batch, positions, encoder_dim).
            decoder_hidden: decoder state, expected as (batch, decoder_dim).

        Returns:
            A pair ``(context, alpha)`` where ``context`` is
            (batch, encoder_dim) and ``alpha`` is (batch, positions), each row
            of ``alpha`` summing to one.
        """
        projected_positions = self.encoder_att(encoder_out)
        # Add a singleton position axis so the state broadcasts over positions.
        projected_state = self.decoder_att(decoder_hidden).unsqueeze(1)
        combined = self.relu(projected_positions + projected_state)
        scores = self.full_att(combined).squeeze(2)
        alpha = self.softmax(scores)
        weighted_features = encoder_out * alpha.unsqueeze(2)
        context = weighted_features.sum(dim=1)
        return context, alpha

In [156]:
class DecoderWithAttention(nn.Module):
    """LSTM caption decoder that re-attends over the encoder output at every step.

    At step ``t`` the decoder embeds token ``t`` of each caption, concatenates
    it with the attention-weighted encoder features and feeds the result to an
    ``LSTMCell``; the new hidden state is projected to vocabulary scores.

    Args:
        attention_dim: size of the attention scoring space.
        embed_dim: size of the token embeddings.
        decoder_dim: size of the LSTM hidden and cell states.
        vocab_size: number of tokens; also the width of the output scores.
        encoder_dim: feature size of each encoder position.
        dropout: dropout probability applied to the hidden state before the
            output projection.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim, dropout):
        super(DecoderWithAttention, self).__init__()
        self.encoder_dim = encoder_dim
        # Stored on the instance so forward() never has to read a notebook-level
        # global to size the recurrent state.
        self.decoder_dim = decoder_dim
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
        self.fc = nn.Linear(decoder_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """Decode a batch of captions with teacher forcing.

        The batch is reordered by decreasing caption length so that, at each
        step, the captions still being decoded form a prefix of the batch.
        Every returned tensor is in that sorted order; use ``sort_ind`` to
        reorder anything that must line up with ``predictions``.

        Args:
            encoder_out: encoder features whose first axis is the batch and
                last axis is the feature size; middle axes are flattened.
            encoded_captions: integer token ids, (batch, caption_length).
            caption_lengths: 1-D integer tensor with one length per caption.

        Returns:
            ``(predictions, encoded_captions, caption_lengths, sort_ind)``.
            ``predictions`` is (batch, max(caption_lengths), vocab_size) and
            stays zero at steps beyond a caption's length.
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.fc.out_features
        # reshape (not view) so non-contiguous encoder outputs are accepted too.
        encoder_out = encoder_out.reshape(batch_size, -1, encoder_dim)

        caption_lengths, sort_ind = caption_lengths.sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        max_length = int(caption_lengths.max())
        # new_zeros places the state and the output buffer on the same device
        # (and dtype) as the encoder features.
        h = encoder_out.new_zeros(batch_size, self.decoder_dim)
        c = encoder_out.new_zeros(batch_size, self.decoder_dim)
        predictions = encoder_out.new_zeros(batch_size, max_length, vocab_size)

        for t in range(max_length):
            # Lengths are sorted descending, so the captions that still have a
            # token at step t are exactly the first batch_size_t rows.
            batch_size_t = int((caption_lengths > t).sum())
            attention_weighted_encoding, alpha = self.attention(
                encoder_out[:batch_size_t], h[:batch_size_t]
            )

            # Token ids -> embedding vectors for the current step.
            embedded = self.embedding(encoded_captions[:batch_size_t, t])
            lstm_input = torch.cat([embedded, attention_weighted_encoding], dim=1)

            h, c = self.lstm(lstm_input, (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(self.dropout(h))
            predictions[:batch_size_t, t, :] = preds
        return predictions, encoded_captions, caption_lengths, sort_ind
    

In [157]:
class MNISTCaptioningModel(nn.Module):
    """Image-captioning model: a small CNN encoder plus an attention decoder.

    The encoder takes single-channel images through three conv/ReLU stages
    (the first two followed by 2x2 max pooling) and ends with global average
    pooling, so each image is summarised by one ``encoder_dim`` vector.
    """

    def __init__(self, embed_dim, decoder_dim, attention_dim, vocab_size, encoder_dim=512, dropout=0.5):
        super().__init__()
        # Layer order is significant: it fixes the indices used in state_dict keys.
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, encoder_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # Global average pooling: collapses the feature map to 1x1.
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = DecoderWithAttention(
            attention_dim=attention_dim,
            embed_dim=embed_dim,
            decoder_dim=decoder_dim,
            vocab_size=vocab_size,
            encoder_dim=encoder_dim,
            dropout=dropout,
        )

    def forward(self, images, encoded_captions, caption_lengths):
        """Encode ``images`` and decode captions for them.

        Returns the decoder's four outputs unchanged:
        ``(predictions, encoded_captions, caption_lengths, sort_ind)``.
        """
        feature_map = self.encoder(images)
        n_images, n_channels = feature_map.size(0), feature_map.size(1)
        # (batch, encoder_dim, 1, 1) -> (batch, 1, encoder_dim): one position per image.
        encoder_out = feature_map.view(n_images, -1, n_channels)
        decoder_outputs = self.decoder(encoder_out, encoded_captions, caption_lengths)
        predictions, sorted_captions, sorted_lengths, sort_ind = decoder_outputs
        return predictions, sorted_captions, sorted_lengths, sort_ind

In [158]:
# Image preprocessing: force a 28x28 size, convert to a tensor, then
# normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])


# Caption vocabulary: the two special tokens, the ten digit tokens, then the
# words of the fixed sentence template. List position is the token id.
vocab = ["<start>", "<end>", *(str(digit) for digit in range(10)), "This", "is", "the", "digit"]
vocab_size = len(vocab)
# Lookup tables in both directions between tokens and their ids.
word_to_idx = {token: position for position, token in enumerate(vocab)}
idx_to_word = dict(enumerate(vocab))

In [159]:
def encode_caption(caption):
    """Convert a whitespace-separated caption into a list of token ids.

    Each word is looked up in ``word_to_idx``; a word that is not in the
    vocabulary raises ``KeyError``.
    """
    token_ids = []
    for word in caption.split():
        token_ids.append(word_to_idx[word])
    return token_ids

In [160]:
# One reference caption per digit, in the order the notebook originally listed them.
captions = [f"This is the digit {digit}" for digit in (1, 2, 3, 4, 6, 7, 8, 5, 0, 9)]

# Token ids for every caption, and the longest caption's length in tokens.
encoded_captions = list(map(encode_caption, captions))
max_caption_length = max(map(len, encoded_captions))
# Right-pad shorter captions with the <end> id so the rows form a rectangle.
padded_captions = [
    token_ids + [word_to_idx["<end>"]] * (max_caption_length - len(token_ids))
    for token_ids in encoded_captions
]
encoded_captions = torch.tensor(padded_captions)

In [161]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model hyperparameters.
embed_dim, decoder_dim, attention_dim = 256, 512, 512

model = MNISTCaptioningModel(
    embed_dim=embed_dim,
    decoder_dim=decoder_dim,
    attention_dim=attention_dim,
    vocab_size=vocab_size,
)
model = model.to(device)

# <end> doubles as the padding token, so target positions holding it add no loss.
criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx["<end>"])
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [162]:
# Number of images per training batch.
batch_size = 32

# MNIST training split, downloaded into ./data on first use and preprocessed
# with `transform`.
train_dataset = MNIST("./data", train=True, download=True, transform=transform)

# Reshuffled on every pass over the data.
dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [125]:
"""train_dataset = datasets.MNIST(root ='./data',train = True,download = True, transform = transform)
test_dataset = datasets.MNIST(root ='./data',train = False,download = True, transform = transform)

train_loader = DataLoader(dataset = train_dataset, batch_size = 128, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size = 128, shuffle = False)

device = torch.device("cude" if torch.cuda.is_available() else "cpu")"""


'train_dataset = datasets.MNIST(root =\'./data\',train = True,download = True, transform = transform)\ntest_dataset = datasets.MNIST(root =\'./data\',train = False,download = True, transform = transform)\n\ntrain_loader = DataLoader(dataset = train_dataset, batch_size = 128, shuffle = True)\ntest_loader = DataLoader(dataset = test_dataset, batch_size = 128, shuffle = False)\n\ndevice = torch.device("cude" if torch.cuda.is_available() else "cpu")'

In [None]:
# Train the captioning model with teacher forcing.
#
# The decoder input at step t is the token *before* target t: the input
# sequence is <start> followed by the caption minus its last token, and the
# target is the caption itself. Feeding the caption as both input and target
# would only teach the model to copy its input, and it would never see the
# <start> token that inference begins with.
num_epochs = 3
start_idx = word_to_idx["<start>"]
end_idx = word_to_idx["<end>"]
for epoch in range(num_epochs):
    model.train()
    for images, labels in dataloader:
        images = images.to(device)
        # Size everything from the batch actually received; the last batch of
        # an epoch can be smaller than the configured batch_size.
        current_batch_size = images.size(0)

        # Use labels as captions (e.g., "This is the digit 5"), right-padded
        # with <end> up to max_caption_length.
        target_captions = [f"This is the digit {label}" for label in labels.tolist()]
        encoded_targets = [encode_caption(caption) for caption in target_captions]
        padded_targets = [caption + [end_idx] * (max_caption_length - len(caption)) for caption in encoded_targets]
        encoded_targets = torch.tensor(padded_targets, device=device)

        # Shift right by one position and put <start> in front.
        start_column = torch.full(
            (current_batch_size, 1), start_idx, dtype=encoded_targets.dtype, device=device
        )
        decoder_inputs = torch.cat([start_column, encoded_targets[:, :-1]], dim=1)
        caption_lengths = torch.full(
            (current_batch_size,), max_caption_length, dtype=torch.long, device=device
        )

        predictions, _, _, sort_ind = model(images, decoder_inputs, caption_lengths)
        # The decoder reorders the batch by caption length, and torch.sort does
        # not promise to keep equal lengths in their original order, so the
        # targets must be reordered the same way before computing the loss.
        targets = encoded_targets[sort_ind]

        loss = criterion(predictions.reshape(-1, vocab_size), targets.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reports the loss of the final batch of the epoch, not an epoch average.
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

In [39]:
def predict_caption(model, image):
    """Greedily decode a caption for a single image.

    Decoding starts from ``<start>`` and appends the highest-scoring next
    token until ``<end>`` is produced or ``max_caption_length`` tokens have
    been generated. On every step the decoder is re-run on the whole prefix
    generated so far, so no recurrent state is carried between steps here.

    Args:
        model: object exposing ``eval()``, ``encoder`` and ``decoder`` as the
            captioning model in this notebook does. It is left in eval mode.
        image: one image tensor without a batch axis.

    Returns:
        The generated words joined by single spaces, with the ``<start>`` and
        ``<end>`` tokens removed.
    """
    start_idx = word_to_idx["<start>"]
    end_idx = word_to_idx["<end>"]

    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)
        encoder_out = model.encoder(image)
        # (1, channels, ...) -> (1, positions, channels), as the decoder expects.
        encoder_out = encoder_out.view(1, -1, encoder_out.size(1))

        caption = [start_idx]
        for _ in range(max_caption_length):
            encoded_caption = torch.tensor(caption).unsqueeze(0).to(device)
            prefix_length = torch.tensor([len(caption)]).to(device)
            predictions, _, _, _ = model.decoder(encoder_out, encoded_caption, prefix_length)
            # Scores at the last position are the prediction for the next token.
            next_word_idx = predictions.argmax(2)[:, -1].item()
            caption.append(next_word_idx)
            if next_word_idx == end_idx:
                break

    special_tokens = {start_idx, end_idx}
    return " ".join(idx_to_word[idx] for idx in caption if idx not in special_tokens)


In [None]:
# Caption one sample and compare it with its label. The notebook only defines
# `train_dataset`, so the sample is taken from there.
# NOTE(review): this is a training image; use a held-out split for a real evaluation.
test_image, test_label = train_dataset[0]
predicted_caption = predict_caption(model, test_image)
print(f"Predicted Caption: {predicted_caption}")
print(f"Actual Label: {test_label}")