<a href="https://colab.research.google.com/github/ChetanKrishnaPeela/CodSoft/blob/main/CodSoft_AI_Task3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence
from PIL import Image
import numpy as np

class Vocabulary:
    """Bidirectional word <-> index table with four reserved special tokens.

    Indices are handed out densely in insertion order, so the reserved tokens
    always get '<pad>'=0, '<start>'=1, '<end>'=2 and '<unk>'=3.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next index to hand out
        for special in ('<pad>', '<start>', '<end>', '<unk>'):
            self.add_word(special)

    def add_word(self, word):
        """Register `word` under the next free index; do nothing if already present."""
        if word in self.word2idx:
            return
        slot = self.idx
        self.word2idx[word] = slot
        self.idx2word[slot] = word
        self.idx = slot + 1

    def __call__(self, word):
        """Map `word` to its index, falling back to the '<unk>' index for unknown words."""
        fallback = self.word2idx['<unk>']
        return self.word2idx.get(word, fallback)

    def __len__(self):
        """Number of distinct registered words (special tokens included)."""
        return len(self.word2idx)

class ImageCaptioningModel(nn.Module):
    """Image captioner: frozen ResNet-50 encoder followed by a Transformer decoder.

    The pretrained ResNet-50 backbone is frozen. Its final layer is replaced by
    a trainable projection to ``embed_size``, and the resulting single vector
    per image is the decoder's cross-attention memory. Caption tokens are
    embedded, given sinusoidal positional encodings, and decoded under a causal
    mask, so each position only sees itself and earlier tokens.

    Args:
        embed_size: Width of token embeddings, image features and decoder
            layers. Must be divisible by 8 because the decoder uses ``nhead=8``.
        hidden_size: Not used by this model; kept for interface compatibility.
        vocab_size: Number of embedding rows / output logits.
        num_layers: Number of stacked ``TransformerDecoderLayer`` modules.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(ImageCaptioningModel, self).__init__()
        self.cnn = models.resnet50(pretrained=True)
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, embed_size)
        # Freeze only the pretrained backbone. The replacement ``fc`` projection
        # is randomly initialised, so it has to stay trainable; otherwise the
        # decoder would only ever see a fixed random projection of the image.
        for name, param in self.cnn.named_parameters():
            param.requires_grad = name.startswith('fc.')

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=embed_size, nhead=8), num_layers
        )
        self.fc = nn.Linear(embed_size, vocab_size)
        self.embed_size = embed_size

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen backbone in eval mode.

        Freezing weights via ``requires_grad`` does not stop BatchNorm layers
        from updating their running statistics in train mode. Pinning the
        backbone to eval keeps the pretrained statistics intact.
        """
        super().train(mode)
        self.cnn.eval()
        return self

    @staticmethod
    def _sinusoidal_positions(seq_len, dim, device, dtype=torch.float32):
        """Return the (seq_len, dim) sinusoidal positional-encoding table.

        Sine goes in even columns and cosine in odd columns, as in the original
        Transformer. It is computed on the fly, so it adds no parameters.
        """
        import math

        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2, device=device, dtype=torch.float32) * (-math.log(10000.0) / dim)
        )
        angles = position * div_term  # (seq_len, ceil(dim / 2))
        table = torch.zeros(seq_len, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)[:, : dim // 2]  # odd dim has one fewer cos column
        return table.to(dtype)

    def _decode(self, features, tokens):
        """Decode `tokens` conditioned on image `features`.

        Args:
            features: (batch, embed_size) image embeddings from ``self.cnn``.
            tokens: (batch, seq_len) token indices.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        seq_len = tokens.size(1)
        embeddings = self.embedding(tokens)
        embeddings = embeddings + self._sinusoidal_positions(
            seq_len, self.embed_size, tokens.device, embeddings.dtype
        )
        # The decoder layers are sequence-first (batch_first=False).
        tgt = embeddings.transpose(0, 1)
        memory = features.unsqueeze(0)  # (1, batch, embed_size): one memory slot per image
        # Additive causal mask: -inf above the diagonal stops position t from
        # attending to later tokens (which, under teacher forcing, include its target).
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=tokens.device, dtype=embeddings.dtype),
            diagonal=1,
        )
        outputs = self.transformer_decoder(tgt, memory, tgt_mask=causal_mask)
        return self.fc(outputs.transpose(0, 1))

    def forward(self, images, captions, lengths):
        """Compute next-token logits for teacher-forced training.

        Args:
            images: Batch of normalised images accepted by ``self.cnn``
                (presumably (batch, 3, H, W)).
            captions: (batch, seq_len) input token indices.
            lengths: Unused. Padding appended after a caption cannot influence
                earlier positions under the causal mask, and padded targets are
                expected to be ignored by the loss.

        Returns:
            Logits of shape (batch, seq_len, vocab_size). Position ``t`` depends
            only on ``captions[:, :t + 1]`` and the image.
        """
        features = self.cnn(images)
        return self._decode(features, captions)

    def caption_image(self, image, vocab, max_length=20):
        """Greedily generate a caption for a single PIL image.

        The image is resized to 224x224 and normalised with ImageNet mean/std.
        Starting from '<start>', the full prefix generated so far is decoded at
        each step and the arg-max token at the last position is appended. Decoding
        stops at '<end>' or after `max_length` tokens. The model is left in eval mode.

        Args:
            image: Input image in a form ``transforms.Resize``/``ToTensor`` accept.
            vocab: Object callable as ``vocab(word) -> index`` that also exposes
                an ``idx2word`` mapping.
            max_length: Maximum number of tokens to generate.

        Returns:
            List of generated words, with '<start>' and '<end>' removed.
        """
        self.eval()
        # Run on the model's own device so a GPU-resident model works too.
        device = next(self.parameters()).device
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        start_idx = vocab('<start>')
        end_idx = vocab('<end>')

        with torch.no_grad():
            pixels = transform(image).unsqueeze(0).to(device)
            features = self.cnn(pixels)

            tokens = [start_idx]
            caption = []
            for _ in range(max_length):
                logits = self._decode(features, torch.tensor([tokens], device=device))
                predicted = logits[0, -1].argmax().item()
                caption.append(predicted)
                if predicted == end_idx:
                    break
                # Feed the whole prefix back in, not just the latest token.
                tokens.append(predicted)

        return [vocab.idx2word[idx] for idx in caption if idx not in (start_idx, end_idx)]

class MockDataset(torch.utils.data.Dataset):
    """Tiny synthetic captioning dataset for smoke-testing the pipeline.

    It holds 10 random (3, 224, 224) image tensors paired with two alternating
    captions. Constructing it registers every caption word in `vocab`.
    Items are ``(image, caption_token_tensor, caption_length)``.
    """

    def __init__(self, vocab):
        self.vocab = vocab
        self.images = [torch.randn(3, 224, 224) for _ in range(10)]
        dog_caption = ['<start>', 'a', 'dog', 'is', 'running', '<end>']
        cat_caption = ['<start>', 'a', 'cat', 'is', 'sleeping', '<end>']
        self.captions = [dog_caption, cat_caption] * 5
        for token in (w for sentence in self.captions for w in sentence):
            vocab.add_word(token)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        token_ids = list(map(self.vocab, self.captions[idx]))
        return self.images[idx], torch.tensor(token_ids), len(token_ids)

def train_model(model, dataset, num_epochs=5, batch_size=2, lr=0.001):
    """Train `model` on `dataset` with teacher forcing and Adam.

    Each dataset item is ``(image, caption_ids, length)``. Within a batch,
    captions are right-padded with ``dataset.vocab('<pad>')`` so that
    variable-length captions can be batched. Padded target positions are
    ignored by the loss.

    Args:
        model: Module called as ``model(images, captions, lengths)`` that
            returns logits of shape (batch, seq_len, vocab_size).
        dataset: Map-style dataset with a ``vocab`` attribute callable on a
            token string.
        num_epochs: Number of passes over the dataset.
        batch_size: Mini-batch size.
        lr: Adam learning rate.

    Returns:
        List holding the mean training loss of each epoch.

    Raises:
        ValueError: If `dataset` is empty.
    """
    if len(dataset) == 0:
        raise ValueError('train_model requires a non-empty dataset')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    pad_idx = dataset.vocab('<pad>')
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    # Frozen parameters never receive gradients; only hand trainable ones to Adam.
    optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)

    def _collate(batch):
        # Right-pad captions to the batch's longest caption. The default
        # collate cannot stack captions of different lengths.
        images, captions, lengths = zip(*batch)
        padded = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=pad_idx)
        return torch.stack(images), padded, torch.tensor(lengths)

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=_collate)

    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for images, captions, lengths in loader:
            images, captions = images.to(device), captions.to(device)
            optimizer.zero_grad()
            # Teacher forcing: feed tokens [0, L-1) and predict tokens [1, L).
            outputs = model(images, captions[:, :-1], lengths)
            targets = captions[:, 1:]
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        epoch_loss = total_loss / num_batches  # non-empty dataset => at least one batch
        epoch_losses.append(epoch_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
    return epoch_losses

if __name__ == '__main__':
    vocab = Vocabulary()
    # Building the dataset registers every caption word in `vocab`, so the
    # model is created afterwards to pick up the final vocabulary size.
    dataset = MockDataset(vocab)

    model = ImageCaptioningModel(embed_size=256, hidden_size=512, vocab_size=len(vocab), num_layers=3)

    train_model(model, dataset)

    # Random RGB noise image in HWC layout. torch.rand is uniform in [0, 1), so
    # scaling by 255 stays inside the uint8 range. Normally-distributed values
    # would go negative or above 255, and casting those to uint8 is undefined.
    noise = (torch.rand(224, 224, 3) * 255).numpy().astype(np.uint8)
    sample_image = Image.fromarray(noise)
    caption = model.caption_image(sample_image, vocab)
    print('Generated Caption:', ' '.join(caption))

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 89.7MB/s]


Epoch [1/5], Loss: 0.4145
Epoch [2/5], Loss: 0.0748
Epoch [3/5], Loss: 0.0239
Epoch [4/5], Loss: 0.0055
Epoch [5/5], Loss: 0.0015
Generated Caption: a cat is sleeping
