In [1]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.transforms import v2
import cv2
from torchvision.models import (densenet121, DenseNet121_Weights,
                                densenet161, DenseNet161_Weights,
                                resnet50, ResNet50_Weights,
                                resnet152, ResNet152_Weights, 
                                vgg19, VGG19_Weights)
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import CocoCaptions

  check_for_updates()


In [2]:
# Show the kernel's current working directory.
os.getcwd()

'c:\\Users\\Srijan\\Desktop\\Srijan\\seq2seq-demo\\image_captioning\\cnn_lstm_attention'

In [3]:
# Root folder of the local COCO 2014 download. Set the COCO2014_ROOT environment
# variable before running this cell to point somewhere else; the fallback is the
# location this notebook was originally written against, so on that machine the
# four paths below resolve to exactly the same strings as before.
COCO2014_ROOT = os.environ.get(
    "COCO2014_ROOT",
    "C:\\Users\\Srijan\\Desktop\\Srijan\\seq2seq-demo\\image_captioning\\COCO2014",
)
_COCO_ANN_DIR = os.path.join(COCO2014_ROOT, "annotations_trainval2014", "annotations")

# Image folders of the train / val splits.
train_root_img = os.path.join(COCO2014_ROOT, "train2014")
val_root_img = os.path.join(COCO2014_ROOT, "val2014")
# Caption annotation files of the train / val splits.
train_captions = os.path.join(_COCO_ANN_DIR, "captions_train2014.json")
val_captions = os.path.join(_COCO_ANN_DIR, "captions_val2014.json")

In [4]:
# Preprocessing constants shared by every pipeline in this cell, so the input
# size and the normalisation statistics are defined once.
IMG_SIZE = 224
# Presumably the ImageNet channel statistics expected by the pretrained
# torchvision backbones - confirm against the chosen weights' documented transforms.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

# Albumentations pipelines. They are called with a keyword argument
# (`pipeline(image=array)`) and return a dict whose "image" entry is the tensor.
# NOTE(review): the 90-degree rotations and flips in "train" can contradict
# captions that mention orientation - confirm they are wanted for captioning.
trans_album = {
    "train": A.Compose(
        [
            A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_AREA),
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Normalize(mean=NORM_MEAN, std=NORM_STD),
            ToTensorV2(),
        ],
        p=1.0,
    ),
    "test": A.Compose(
        [
            A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_AREA),
            A.Normalize(mean=NORM_MEAN, std=NORM_STD),
            ToTensorV2(),
        ],
        p=1.0,
    ),
}

# torchvision v2 pipeline without augmentation: resize, convert to a float32
# image tensor scaled by ToDtype(scale=True), then normalise.
trans_v2 = v2.Compose([
    v2.Resize((IMG_SIZE, IMG_SIZE)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

In [5]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [6]:
# Load one sample training image as RGB.
# The listing is sorted so the same file is picked on every run (os.listdir
# returns entries in arbitrary order), and the `with` block closes the file
# handle once convert() has produced the in-memory copy.
first_train_file = sorted(os.listdir(train_root_img))[0]
with Image.open(os.path.join(train_root_img, first_train_file)) as img_file:
    img0 = img_file.convert("RGB")

In [7]:
# Pretrained ResNet-152 used as a feature extractor.
resnet152_net = resnet152(weights=ResNet152_Weights.DEFAULT)
# Drop the last two child modules so the network returns a spatial feature map
# instead of class scores ([1, 2048, 7, 7] for one 224x224 image in this notebook).
# eval() switches BatchNorm to its stored running statistics, so the extracted
# features do not depend on the batch composition and the forward passes below
# do not modify the pretrained statistics.
resnet152_net = nn.Sequential(*list(resnet152_net.children())[:-2]).to(device).eval()
# Channel count of the feature map produced above.
resnet152_dim = 2048

In [8]:
# Push the sample image through both preprocessing pipelines and prepend a
# batch axis so the results can be fed to the encoder.
img0_np = np.array(img0, dtype=np.float32)
img0_trans_album = trans_album["train"](image=img0_np)["image"].to(device).unsqueeze(dim=0)
img0_trans_v2 = trans_v2(img0).to(device).unsqueeze(dim=0)
img0_trans_album.shape, img0_trans_v2.shape

(torch.Size([1, 3, 224, 224]), torch.Size([1, 3, 224, 224]))

In [9]:
# Encode the augmented sample image. Nothing here is back-propagated, so
# autograd bookkeeping is switched off to avoid keeping intermediate activations.
with torch.no_grad():
    img0_res152 = resnet152_net(img0_trans_album)
img0_res152.size()

torch.Size([1, 2048, 7, 7])

In [10]:
# Move the channel axis to the end: (batch, C, H, W) -> (batch, H, W, C).
# The name is rebound in place, so running this cell twice permutes twice.
img0_res152 = torch.permute(img0_res152, (0, 2, 3, 1))
img0_res152.shape

torch.Size([1, 7, 7, 2048])

In [11]:
# Merge the two spatial axes into one, keeping batch first and channels last:
# (batch, H, W, C) -> (batch, H*W, C).
img0_res152 = img0_res152.view(img0_res152.shape[0], -1, img0_res152.shape[-1])
img0_res152.shape

torch.Size([1, 49, 2048])

In [12]:
# Show the concrete classes of the albumentations and torchvision v2 pipelines.
type(trans_album["train"]), type(trans_v2)

(albumentations.core.composition.Compose,
 torchvision.transforms.v2._container.Compose)

In [13]:
# Check which library each pipeline comes from. isinstance against the public
# Compose classes avoids reaching into private module paths and also accepts subclasses.
isinstance(trans_album["train"], A.Compose), isinstance(trans_v2, v2.Compose)

(True, True)

In [14]:
def get_coco_dataloader(
    transform,
    root: str,
    annFile: str,
    batch_size: int = 32,
    num_workers: int = 4
):
    """
    Create a shuffled DataLoader for COCO Captions using torchvision's built-in dataset.

    Args:
        transform: Callable handed to ``CocoCaptions`` as its ``transform`` and
            applied to every loaded image. Pass ``None`` to use the default
            pipeline (resize to 224x224, convert to a float32 tensor, normalise).
            An albumentations pipeline cannot be passed as-is: elsewhere in this
            notebook it is invoked as ``pipeline(image=array)["image"]``, which
            is a different calling convention.
        root: Path to the COCO images directory
        annFile: Path to the annotations json file
        batch_size: Number of samples per batch
        num_workers: Number of worker processes for data loading

    Returns:
        A ``DataLoader`` yielding ``(images, captions)`` batches. With the
        default collate function the captions are grouped by caption slot, not
        by image: in this notebook a batch of 32 images gives
        ``len(captions) == 5`` and ``len(captions[0]) == 32``.

    NOTE(review): default collation presumably fails when the images of one
    batch have different numbers of captions - confirm, and supply a custom
    ``collate_fn`` if the full dataset is iterated.
    """
    # Fall back to the default preprocessing only when the caller supplied none;
    # previously the argument was always overwritten and therefore ignored.
    if transform is None:
        transform = v2.Compose([
            v2.Resize((224, 224)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    # Create dataset
    dataset = CocoCaptions(
        root=root,
        annFile=annFile,
        transform=transform
    )

    # Create dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )

    return dataloader

In [15]:
# Display the validation caption-annotation path used by the next cell.
val_captions

'C:\\Users\\Srijan\\Desktop\\Srijan\\seq2seq-demo\\image_captioning\\COCO2014\\annotations_trainval2014\\annotations\\captions_val2014.json'

In [16]:
root = val_root_img
annFile = val_captions
dataloader = get_coco_dataloader(root=root, annFile=annFile, transform=trans_v2)

# Pull a single batch for inspection.
# images:   tensor of shape [batch_size, 3, 224, 224]
# captions: grouped by caption slot, not by image - the outer sequence has one
#           entry per caption slot (5 here) and each entry holds batch_size
#           strings, so captions[k][i] is the k-th caption of the i-th image.
images, captions = next(iter(dataloader))
print(images.size())
print(captions)


loading annotations into memory...
Done (t=0.27s)
creating index...
index created!
torch.Size([32, 3, 224, 224])
[['Room with a couch, tv, dining table surrounded by chairs and two doors ', 'A player runs for the ball during a tennis match.', 'A doughnut shop sign hanging off the side of a building.', 'There are several multicolored sun umbrellas and this boy is holding one', 'A man riding a motorcycle approaching a man wearing camouflage clothing.', 'A man and two girls sitting on a couch with a dog.', 'A policeman, cameraman, and reporter stand near a police checkpoint.', 'People are horseback riding as a man is taking a picture.', 'A Penn tennis bill resting on a tennis racquet', 'The top of the head of a man sitting in front of disorganized computer desk', 'a stove top with a tea kettle with steam pouring out of it.', 'a girl smiling sitting at a table in front of several display items.', 'A group of men play frisbee in a field.', 'The young girl is sitting at the table eating a pi

In [17]:
# Number of caption groups in the batch, and number of captions inside the first group.
len(captions), len(captions[0])

(5, 32)

In [18]:
# List every caption stored in the third caption group of the batch.
list(captions[2])

['The living room is clean and empty of people.',
 'A woman standing on a tennis court with a racket in her hand.',
 'Two different types of signs hanging off a building.',
 'A man in a yellow shirt takes down colorful umbrellas outside.',
 'A man riding an old motorcycle beside an Army worker',
 'Some kids are relaxing with their dad on a coach',
 'A bunch of police officers on a city street corner,',
 'a lady on a horse and people taking a photo',
 'A blue tennis racket has a yellow tennis ball on it.',
 'A man with blonde hair in front of his computer.',
 'Smoke is rising from a pan on the back of the stove top.',
 'The woman is sitting at the table with the deserts. ',
 'A group of men on a field playing frisbee.',
 'A little girl eating a piece of cake. ',
 'A train track junction with a train on one of the tracks.',
 'A man who is attempting to hit a tennis ball.',
 'A cold dog curled up and going to sleep.',
 'Still shots of a man trying to hit a shuttlecock.',
 'a train is pass

In [19]:
# Collect entry 0 of each of the five caption groups, i.e. all captions of the
# first image in the batch.
tuple(captions[slot][0] for slot in range(5))

('Room with a couch, tv, dining table surrounded by chairs and two doors ',
 'A couch, table, chairs and a tv are featured in a den with blue carpeting.',
 'The living room is clean and empty of people.',
 'Someone recently redid their front room in blues and browns',
 'An empty living room with many pieces of furniture. ')

In [20]:
# import matplotlib.pyplot as plt

In [21]:
# image = images[0]
# print(type(image))
# mean = torch.tensor([0.485, 0.456, 0.406])
# std = torch.tensor([0.229, 0.224, 0.225])
# image = (image.permute(1, 2, 0) * std.view(1, 1, 3) + mean.view(1, 1, 3)).clamp(0, 1).byte().numpy()
# image = Image.fromarray(image, mode="RGB")

In [22]:
# Image.fromarray(images[0].permute(1, 2, 0).numpy(), mode='RGB').show()
# print("COCO Caption: " + ", ".join(captions[0]))


In [23]:
# Encode the whole batch. Nothing here is back-propagated, so autograd
# bookkeeping is switched off to avoid holding the activations of all images in memory.
with torch.no_grad():
    batch0_enc = resnet152_net(images.to(device))
batch0_enc.size()

torch.Size([32, 2048, 7, 7])

In [24]:
# Move the channel axis to the end: (batch, C, H, W) -> (batch, H, W, C).
# The name is rebound in place, so running this cell twice permutes twice.
batch0_enc = torch.permute(batch0_enc, (0, 2, 3, 1))
batch0_enc.shape

torch.Size([32, 7, 7, 2048])

In [25]:
# Merge the two spatial axes into one, keeping batch first and channels last:
# (batch, H, W, C) -> (batch, H*W, C).
batch0_enc = batch0_enc.view(batch0_enc.shape[0], -1, batch0_enc.shape[-1])
batch0_enc.shape

torch.Size([32, 49, 2048])

In [26]:
# Length, in characters (not tokens), of the longest caption anywhere in the batch.
len(max((text for group in captions for text in group), key=len))

89

In [27]:
# Pair every caption with its (group, position) index, then keep the pair whose
# caption has the most characters; on ties the first one in iteration order wins.
indexed_captions = [
    ((group_idx, pos), text)
    for group_idx, group in enumerate(captions)
    for pos, text in enumerate(group)
]
max_indices, max_len_caption = max(indexed_captions, key=lambda pair: len(pair[1]))
len(max_len_caption), max_len_caption, max_indices

(89,
 'A bus with passengers who are getting out of bus with their luggage at their destination.',
 (0, 30))

In [28]:
# Longest caption in the first caption group, in characters, minus one.
max(map(len, captions[0])) - 1

88

In [30]:
# Number of axes of the encoded batch.
batch0_enc.ndim

3