In [19]:
from torchvision.models import resnet50, ResNet50_Weights
from torchinfo import summary
import torch


# Reference architecture: print a layer-by-layer summary of a stock ResNet-50.
# `weights` must be a ResNet50_Weights *member* (or None / a member-name string);
# passing the enum class itself is rejected by torchvision's weight verification.
# DEFAULT selects torchvision's recommended ImageNet weights (downloaded and cached
# on first use).
model = resnet50(weights=ResNet50_Weights.DEFAULT)

summary(model, input_size=(1, 3, 224, 224))




Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│ 

In [20]:
import sys
from pathlib import Path

# Make the repository root importable so the `src.*` imports below resolve.
# Assumes the notebook runs from a directory one level below the project root
# (e.g. `<root>/notebooks/`) — TODO confirm if the notebook is moved.
# `.parent` (unlike `.parents[0]`) does not raise when cwd is the filesystem root.
project_root = Path.cwd().resolve().parent

# Guard keeps the cell idempotent: re-running it must not stack duplicate entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [21]:
from src.models.components.resnet_encoder import ResNet50Encoder

from src.models.components.lstm_decoder import LSTMDecoder

In [22]:
# Smoke test: push one random batch through the encoder/decoder pair and check
# the output shapes. Inputs are random, so only the shapes are meaningful.
torch.manual_seed(0)  # reproducible random inputs across Restart & Run All

Batch = 4
max_length = 10
vocab_size = 1000
embed_size = 512  # encoder output width; must equal the decoder's embed_size

encoder = ResNet50Encoder(embed_size=embed_size)
decoder = LSTMDecoder(
    embed_size=embed_size,
    hidden_size=512,
    vocab_size=vocab_size,
    num_layers=1,
)

images = torch.randn(Batch, 3, 224, 224)
captions = torch.randint(0, vocab_size, (Batch, max_length))

# Every caption is treated as full length (no padding in this test batch).
lengths = torch.full((Batch,), max_length, dtype=torch.long)

# Shape check only: skip building an autograd graph over ResNet-50.
with torch.no_grad():
    features = encoder(images)
    outputs = decoder(features, captions, lengths)

print("features:", features.shape)
print("outputs:", outputs.shape)

assert features.shape == (Batch, embed_size)
# Recorded output shows max_length + 1 time steps; presumably the image feature is
# fed to the decoder as an extra first step — TODO confirm in LSTMDecoder.
assert outputs.shape == (Batch, max_length + 1, vocab_size)


features: torch.Size([4, 512])
outputs: torch.Size([4, 11, 1000])
