Import modules

In [1]:
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image
import pickle
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence
import torch.nn as nn
import torch.optim as optim

Image encoder: pre-trained ResNet-50 backbone for feature extraction

In [2]:
class EncoderCNN(nn.Module):
    """Image encoder: a pretrained ResNet-50 backbone plus a linear projection.

    The backbone's final fully connected (ImageNet classification) layer is
    dropped. The remaining backbone output is flattened per image, projected
    to ``embed_size`` with a linear layer and normalised with ``BatchNorm1d``.

    Args:
        embed_size: Length of the feature vector produced for each image.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = self._load_pretrained_resnet50()
        # Keep every child module except the last one (the fully connected
        # classifier ``fc``); only the convolutional/pooling stack is reused.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        # ``fc.in_features`` is the width of the backbone output that ``fc``
        # used to consume, i.e. the size of the flattened features below.
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    @staticmethod
    def _load_pretrained_resnet50():
        """Return ResNet-50 with ImageNet weights without the deprecated ``pretrained`` flag."""
        try:
            from torchvision.models import ResNet50_Weights
        except ImportError:
            # torchvision < 0.13 has no weight enums; fall back to the old API.
            return resnet50(pretrained=True)
        # IMAGENET1K_V1 is exactly what ``pretrained=True`` loaded; using
        # ``ResNet50_Weights.DEFAULT`` would silently switch to the V2 weights.
        return resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    def forward(self, images):
        """Encode a batch of preprocessed images.

        Args:
            images: Batch of image tensors, presumably shaped (batch, 3, H, W)
                and ImageNet-normalised — TODO confirm for other callers.

        Returns:
            Tensor of shape (batch, embed_size).

        Note:
            In training mode ``BatchNorm1d`` needs more than one sample per
            batch, so call ``.eval()`` before encoding a single image.
        """
        features = self.resnet(images)
        # Collapse everything after the batch dimension into one feature axis.
        features = torch.flatten(features, 1)
        features = self.bn(self.linear(features))
        return features

Decoder: LSTM-based caption generator

In [3]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed to the LSTM as the first time step, followed by
    the embedded caption tokens. Every non-padded LSTM output is projected to
    vocabulary logits.

    Args:
        embed_size: Size of the word embeddings. It must equal the last
            dimension of ``features`` passed to ``forward``, because the two
            are concatenated along the time axis.
        hidden_size: Number of LSTM hidden units.
        vocab_size: Number of tokens in the vocabulary.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions, lengths):
        """Compute vocabulary logits for every valid time step.

        Args:
            features: Image features of shape (batch, embed_size).
            captions: Padded token ids of shape (batch, max_len).
            lengths: Number of valid time steps per sample, as a list of ints
                or a 1-D tensor (any device). The sequence includes the
                prepended image feature, so each length must be <= max_len + 1.

        Returns:
            Logits of shape (sum(lengths), vocab_size), taken from the packed
            sequence's ``data``. Because packing uses ``enforce_sorted=False``,
            rows are in packed (time-major, length-sorted) order. Targets must
            therefore be packed with the same lengths and ``enforce_sorted=False``
            for the rows to line up.
        """
        # Regularise the word embeddings (a no-op in eval mode); previously the
        # dropout layer was constructed but never applied.
        embeddings = self.dropout(self.embed(captions))
        # The image feature becomes time step 0: (batch, 1 + max_len, embed_size).
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        # pack_padded_sequence requires tensor lengths to live on the CPU.
        if torch.is_tensor(lengths):
            lengths = lengths.cpu()
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)
        hiddens, _ = self.lstm(packed)
        # ``.data`` holds only the non-padded outputs, flattened over time steps.
        outputs = self.linear(hiddens.data)
        return outputs


Load and preprocess an image

In [4]:
def load_image(image_path, transform=None):
    """Open an image file as RGB and optionally preprocess it.

    Args:
        image_path: Path to the image file.
        transform: Optional callable applied to the RGB PIL image. Its result
            gets a leading batch dimension via ``unsqueeze(0)``, so it must
            return a tensor.

    Returns:
        A tensor with a batch dimension of size 1 if ``transform`` is given,
        otherwise the RGB ``PIL.Image``.
    """
    # The context manager closes the underlying file handle promptly;
    # ``convert`` loads the pixels into a new, independent image first.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    if transform is not None:
        image = transform(image).unsqueeze(0)  # Add batch dimension
    return image

Preprocessing transforms (resize and ImageNet normalization)

In [5]:
# Preprocessing for the pretrained ResNet-50 encoder:
# - resize to a fixed 224x224 (aspect ratio is not preserved);
# - convert to a float tensor with values in [0, 1];
# - normalise each channel with the ImageNet mean/std that torchvision's
#   pretrained weights expect.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

Example usage: extract features from a single image

In [51]:
if __name__ == "__main__":
    embed_size = 256
    hidden_size = 512
    vocab_size = 5000  # Assume a vocabulary size
    num_layers = 1

    encoder = EncoderCNN(embed_size)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    # Load an example image.
    # NOTE(review): the double ".jpg" extension looks accidental — confirm the filename.
    image_path = "/content/Cat.jpg.jpg"  # Replace with your image path
    image = load_image(image_path, transform)

    # Eval mode is required: the encoder's BatchNorm1d cannot compute
    # training-mode batch statistics from a single image (batch of size 1).
    encoder.eval()

    # Extract features without building an autograd graph.
    with torch.no_grad():
        features = encoder(image)

    # Sanity check: one image in, one embed_size-dimensional feature vector out.
    assert features.shape == (1, embed_size), features.shape
    print("Extracted Features Shape:", features.shape)

Extracted Features Shape: torch.Size([1, 256])
