In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import os
import csv
import random
from transformers import BertTokenizer

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Define transformations for image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

class ImageCaptioningDataset(Dataset):
    def __init__(self, image_dir, caption_file, transform=None, tokenizer=None, sample_size=None):
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.image_ids = []
        self.captions = []

        # Read CSV with comma separator
        with open(caption_file, 'r', encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader)  # skip header
            for row in reader:
                if len(row) != 2:
                    continue  # skip malformed lines
                image_id, caption = row
                self.image_ids.append(image_id)
                self.captions.append(caption)
        
        # If sample_size is provided, randomly sample from the data
        if sample_size:
            sampled_indices = random.sample(range(len(self.image_ids)), min(sample_size, len(self.image_ids)))
            self.image_ids = [self.image_ids[i] for i in sampled_indices]
            self.captions = [self.captions[i] for i in sampled_indices]
    
    def __len__(self):
        return len(self.image_ids)
    
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        caption = self.captions[idx]
        image_path = os.path.join(self.image_dir, image_id)
        
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image {image_path} not found.")
        
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        # Tokenize the caption using the BertTokenizer
        if self.tokenizer:
            tokenized_caption = self.tokenizer(
                caption,
                padding='max_length',
                truncation=True,
                max_length=50,
                return_tensors='pt'
            ).input_ids.squeeze(0)  # Convert to 1D tensor
        else:
            tokenized_caption = caption  # If no tokenizer is passed, return the caption as is
        
        return image, tokenized_caption

# Paths
image_dir = 'C:/Users/ajais/Downloads/LIS 640 Project/Dataset/Images'
caption_file = 'C:/Users/ajais/Downloads/LIS 640 Project/Dataset/captions.txt'

# Compute total number of lines (excluding header) to get total dataset size
with open(caption_file, 'r', encoding='utf-8') as f:
    total_lines = sum(1 for _ in f) - 1  # subtract 1 for header

# Use 10% of the dataset
sample_size = int(total_lines * 0.1)

# Initialize Dataset with tokenizer
dataset = ImageCaptioningDataset(
    image_dir=image_dir,
    caption_file=caption_file,
    transform=transform,
    tokenizer=tokenizer,  # Pass tokenizer here
    sample_size=sample_size
)

# Split dataset into training (70%), validation (15%), and test (15%)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Initialize DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, pin_memory=True)


In [2]:
import torchvision.models as models
import torch.nn as nn

class CNNEncoder(nn.Module):
    def __init__(self, embed_size):
        super(CNNEncoder, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False  # Freeze ResNet weights

        modules = list(resnet.children())[:-1]  # Remove the last FC layer
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        features = self.resnet(images)  # Output shape: (batch_size, 2048, 1, 1)
        features = features.view(features.size(0), -1)  # Shape: (batch_size, 2048)
        features = self.fc(features)  # Shape: (batch_size, embed_size)
        features = self.bn(features)
        return features


In [3]:
import torch
import torch.nn as nn

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim, num_layers, max_len=50):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = nn.Parameter(torch.randn(1, max_len, embed_size))  # Learnable positional encoding
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads, dim_feedforward=hidden_dim)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_size, vocab_size)
    
    def forward(self, tgt, memory):
        # tgt: (batch_size, seq_len)
        # memory: (batch_size, seq_len, embed_size)
        
        # Add positional encoding to the target embeddings
        tgt_emb = self.embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :]
        tgt_emb = tgt_emb.permute(1, 0, 2)  # (seq_len, batch_size, embed_size)
        
        # Ensure memory is in the correct shape
        memory = memory.permute(1, 0, 2)  # (seq_len, batch_size, embed_size)
        
        # Pass through the transformer decoder
        output = self.transformer_decoder(tgt_emb, memory)
        
        # Output layer to predict the next word in the sequence
        output = self.fc(output)
        
        # Return the output in shape (batch_size, seq_len, vocab_size)
        return output.permute(1, 0, 2)  # (batch_size, seq_len, vocab_size)


In [4]:
# Define model parameters
vocab_size = tokenizer.vocab_size
embed_size = 256
num_heads = 4
hidden_dim = 256
num_layers = 2

# Instantiate encoder and decoder
encoder = CNNEncoder(embed_size)
decoder = TransformerDecoder(vocab_size, embed_size, num_heads, hidden_dim, num_layers)



In [5]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics.functional import accuracy

class LitImageCaptioningModel(pl.LightningModule):
    def __init__(self, encoder, decoder, tokenizer, lr=1e-4):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tokenizer = tokenizer
        self.criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
        self.lr = lr

    def forward(self, images, captions):    
        with torch.no_grad():
            features = self.encoder(images)  # (batch_size, embed_size)
            seq_len = captions.shape[1]
            features = features.unsqueeze(1).repeat(1, seq_len, 1)  # (batch_size, seq_len, embed_size)
        
        outputs = self.decoder(captions, features)
        return outputs

    def training_step(self, batch, batch_idx):
        images, captions = batch
        outputs = self(images, captions[:, :-1])
        targets = captions[:, 1:]

        loss = self.criterion(
            outputs.reshape(-1, outputs.shape[-1]),
            targets.reshape(-1)
        )

        preds = outputs.argmax(dim=-1)
        mask = targets != self.tokenizer.pad_token_id
        correct = (preds == targets) & mask
        acc = correct.sum().float() / mask.sum().float()

        self.log("train_loss", loss, on_step=False, on_epoch=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, captions = batch
        inputs = captions[:, :-1]
        targets = captions[:, 1:]

        outputs = self(images, inputs)
        outputs = outputs[:, : targets.shape[1], :]

        loss = self.criterion(
            outputs.reshape(-1, outputs.shape[-1]),
            targets.reshape(-1)
        )

        preds = outputs.argmax(dim=-1)
        mask = targets != self.tokenizer.pad_token_id
        correct = (preds == targets) & mask
        acc = correct.sum().float() / mask.sum().float()

        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        images, captions = batch
        inputs = captions[:, :-1]
        targets = captions[:, 1:]

        outputs = self(images, inputs)
        outputs = outputs[:, : targets.shape[1], :]

        loss = self.criterion(
            outputs.reshape(-1, outputs.shape[-1]),
            targets.reshape(-1)
        )

        preds = outputs.argmax(dim=-1)
        mask = targets != self.tokenizer.pad_token_id
        correct = (preds == targets) & mask
        acc = correct.sum().float() / mask.sum().float()

        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.log("test_acc", acc, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

# Logger
logger = TensorBoardLogger("lightning_logs", name="image_captioning")

# Trainer
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cpu",
    logger=logger,
    log_every_n_steps=5,
    enable_progress_bar=True
)

# Instantiate the model
lit_model = LitImageCaptioningModel(encoder, decoder, tokenizer)

# Train
trainer.fit(lit_model, train_dataloader, val_dataloader)

# Test
trainer.test(lit_model, dataloaders=test_dataloader)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | encoder   | CNNEncoder         | 24.0 M | train
1 | decoder   | TransformerDecoder | 17.0 M | train
2 | criterion | CrossEntropyLoss   | 0      | train
---------------------------------------------------------
17.5 M    Trainable params
23.5 M    Non-trainable params
41.0 M    Total params
164.090   Total estimated model params size (MB)
187       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\ajais\anaconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
c:\Users\ajais\anaconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
c:\Users\ajais\anaconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 2.533141613006592, 'test_acc': 0.6731557846069336}]