# Load Pretrained CNN Encoder

In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# CNN encoder: frozen pretrained ResNet-50 backbone + trainable projection to the embedding size
class EncoderCNN(nn.Module):
    """Image encoder: frozen ImageNet ResNet-50 backbone plus a trainable projection.

    The backbone's final classification layer is removed. Its output is
    flattened per sample, projected to ``embed_size`` by ``self.fc`` and
    normalized by ``self.bn``. Only ``self.fc`` and ``self.bn`` are trainable.

    Args:
        embed_size: dimensionality of the returned image embedding.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # torchvision >= 0.13 replaced `pretrained=True` with the `weights=` API
        # (the old flag is deprecated). IMAGENET1K_V1 is the checkpoint that
        # `pretrained=True` selected, so both branches load the same weights.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)  # Freeze ResNet weights

        # Keep every backbone layer except the last one (the FC classifier).
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        # Freezing parameters does not freeze BatchNorm running statistics, so
        # the backbone must also stay in eval mode (see `train`).
        self.resnet.eval()
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)  # Trainable projection
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def train(self, mode=True):
        """Set the training mode, keeping the frozen backbone in eval mode.

        ``requires_grad_(False)`` only stops gradient updates. BatchNorm layers
        inside the backbone would still normalize with per-batch statistics and
        overwrite their pretrained running mean/variance whenever the module is
        in training mode. Pinning the backbone to eval mode keeps it a fixed
        feature extractor; ``self.fc`` and ``self.bn`` follow ``mode`` as usual.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super(EncoderCNN, self).train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: batch passed straight to the ResNet backbone — presumably
                (batch, 3, H, W) normalized with ImageNet statistics; TODO confirm.

        Returns:
            Tensor of shape (batch, embed_size).

        Note:
            In training mode ``self.bn`` (BatchNorm1d) needs more than one
            sample per batch.
        """
        features = self.resnet(images)
        features = features.flatten(1)  # flatten everything after the batch dim
        features = self.fc(features)
        return self.bn(features)


# Define RNN Decoder

In [2]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image embedding.

    The image embedding is fed as the first time step, followed by the embedded
    caption tokens. Each LSTM output is mapped to vocabulary logits.

    Args:
        embed_size: word-embedding size; must equal the size of the image
            features passed to ``forward``, since both are concatenated along time.
        hidden_size: LSTM hidden-state size.
        vocab_size: number of tokens in the vocabulary.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        """Compute per-step vocabulary logits for a batch of captions.

        Args:
            features: (batch, embed_size) image embeddings.
            captions: (batch, seq_len) token indices.

        Returns:
            (batch, seq_len, vocab_size) logits — one step for the image
            feature plus one per caption token except the last.
        """
        # Drop the last token. With the image feature prepended as step 0, the
        # input length stays seq_len, so the outputs align with `captions`.
        embeddings = self.embed(captions[:, :-1])
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        lstm_out, _ = self.lstm(embeddings)
        # Apply the dropout declared in __init__ before the vocabulary projection.
        # It is only active in training mode.
        return self.fc(self.dropout(lstm_out))


# Step 2.1: Load `features.pkl` into PyTorch

In [7]:
import pickle
from pathlib import Path

import torch

# Location of the precomputed image features.
# This is a machine-specific absolute path; adjust it for your environment.
FEATURES_PATH = Path("E:\\Project\\features.pkl")

# Load precomputed features.
# NOTE(security): pickle.load can execute arbitrary code — only load files you produced or trust.
with open(FEATURES_PATH, "rb") as f:
    features = pickle.load(f)  # Dictionary: {image_name: feature_vector}

# Convert features to PyTorch tensors.
# NOTE(review): torch.tensor keeps the source dtype. If the vectors are float64
# NumPy arrays, they will not match float32 model weights — confirm the dtype.
features_torch = {name: torch.tensor(vec) for name, vec in features.items()}

if not features_torch:
    raise ValueError(f"No features found in {FEATURES_PATH}")

# Check one example: take the first key without materializing the whole key list.
sample_key = next(iter(features_torch))
print(f"Image: {sample_key}")
print(f"Feature Shape: {features_torch[sample_key].shape}")


Image: 1000268201_693b08cb0e.jpg
Feature Shape: torch.Size([2048])
