In [None]:
pip install torch torchvision spacy pillow




In [None]:
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m71.7 MB/s[0m eta [36m0:00:00[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [2]:
!unzip /content/Flickr8k_text.zip

Archive:  /content/Flickr8k_text.zip
  inflating: CrowdFlowerAnnotations.txt  
  inflating: ExpertAnnotations.txt   
  inflating: Flickr8k.lemma.token.txt  
   creating: __MACOSX/
  inflating: __MACOSX/._Flickr8k.lemma.token.txt  
  inflating: Flickr8k.token.txt      
  inflating: Flickr_8k.devImages.txt  
  inflating: Flickr_8k.testImages.txt  
  inflating: Flickr_8k.trainImages.txt  
  inflating: readme.txt              


In [3]:
import pandas as pd

def format_real_flickr_data():
    input_file = "Flickr8k.token.txt"
    output_file = "captions.txt"

    imgs = []
    caps = []

    print("Reading real Flickr8k data...")
    try:
        with open(input_file, "r") as f:
            for line in f:

                tokens = line.split("\t")
                if len(tokens) < 2:
                    continue


                img_id = tokens[0].split("#")[0]
                caption = tokens[1].strip()

                imgs.append(img_id)
                caps.append(caption)

        # Create Dataframe and save
        df = pd.DataFrame({"image": imgs, "caption": caps})
        df.to_csv(output_file, index=False)
        print(f"Success! Saved {len(df)} captions to {output_file}.")


    except FileNotFoundError:
        print(f"Error: Could not find {input_file}. Please download the dataset first.")

if __name__ == "__main__":
    format_real_flickr_data()


Reading real Flickr8k data...
Success! Saved 40460 captions to captions.txt.


In [4]:
import os
import spacy
import torch
import pandas as pd
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from typing import List, Tuple, Any


spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:

    def __init__(self, freq_threshold: int):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self) -> int:
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text: str) -> List[str]:
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list: List[str]):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text: str) -> List[int]:
        tokenized_text = self.tokenizer_eng(text)
        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]

class FlickrDataset(Dataset):

    def __init__(self, root_dir: str, captions_file: str, transform: Any = None, freq_threshold: int = 5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df["image"]
        self.captions = self.df["caption"]


        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        caption = self.captions[index]
        img_id = self.imgs[index]
        img_path = os.path.join(self.root_dir, img_id)


        img = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)


        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption += [self.vocab.stoi["<EOS>"]]

        return img, torch.tensor(numericalized_caption)

class MyCollate:

    def __init__(self, pad_idx: int):
        self.pad_idx = pad_idx

    def __call__(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)

        targets = [item[1] for item in batch]

        targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)

        return imgs, targets

def get_loader(
    root_folder: str,
    annotation_file: str,
    transform: Any,
    batch_size: int = 32,
    num_workers: int = 8,
    shuffle: bool = True,
    pin_memory: bool = True,
) -> Tuple[DataLoader, FlickrDataset]:

    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset


In [5]:
import torch
import torch.nn as nn
import torchvision.models as models
import math

class EncoderCNN(nn.Module):

    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)

        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        self.embed = nn.Linear(2048, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()
        self.train_CNN = train_CNN

    def forward(self, images):
        with torch.set_grad_enabled(self.train_CNN):
            features = self.resnet(images)


        features = features.permute(0, 2, 3, 1)
        features = features.view(features.size(0), -1, features.size(3))


        features = self.embed(features)
        features = self.relu(features)

        return features

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1) # [Max_Len, 1, D_Model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class DecoderTransformer(nn.Module):
    def __init__(self, embed_size, vocab_size, num_heads=4, num_layers=2, max_len=100):
        super(DecoderTransformer, self).__init__()

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, max_len)


        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        self.linear = nn.Linear(embed_size, vocab_size)
        self.max_len = max_len

    def forward(self, features, captions):

        # Embed captions
        embeddings = self.embed(captions) * math.sqrt(features.size(-1))


        embeddings = embeddings.permute(1, 0, 2)
        features = features.permute(1, 0, 2)

        # positional encoding
        embeddings = self.pos_encoder(embeddings)

        tgt_mask = self.generate_square_subsequent_mask(embeddings.size(0)).to(features.device)

        outputs = self.transformer_decoder(tgt=embeddings, memory=features, tgt_mask=tgt_mask)

        outputs = self.linear(outputs)

        return outputs.permute(1, 0, 2)

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def caption_image(self, image_features, vocabulary, max_length=20):

        batch_size = image_features.size(0)
        start_token = vocabulary.stoi["<SOS>"]

        generated = torch.tensor([start_token]).unsqueeze(1).to(image_features.device)

        image_features = image_features.permute(1, 0, 2)

        result_caption = []

        for _ in range(max_length):

            tgt_emb = self.embed(generated) * math.sqrt(image_features.size(-1))
            tgt_emb = tgt_emb.permute(1, 0, 2)
            tgt_emb = self.pos_encoder(tgt_emb)

            # Mask
            mask = self.generate_square_subsequent_mask(tgt_emb.size(0)).to(image_features.device)

            # Decode
            out = self.transformer_decoder(tgt=tgt_emb, memory=image_features, tgt_mask=mask)

            # Get last token output
            last_token_out = out[-1, :, :]

            # Predict
            logits = self.linear(last_token_out)
            predicted_id = logits.argmax(1).item()

            result_caption.append(vocabulary.itos[predicted_id])

            if vocabulary.itos[predicted_id] == "<EOS>":
                break


            next_token = torch.tensor([predicted_id]).unsqueeze(1).to(image_features.device)
            generated = torch.cat((generated, next_token), dim=1)

        return result_caption

class ImageCaptionModel(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_heads, num_layers, train_CNN=False):
        super(ImageCaptionModel, self).__init__()
        self.encoder = EncoderCNN(embed_size, train_CNN)
        self.decoder = DecoderTransformer(embed_size, vocab_size, num_heads, num_layers)

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=20):
        self.eval()
        with torch.no_grad():
            features = self.encoder(image.unsqueeze(0))
            return self.decoder.caption_image(features, vocabulary, max_length)

# Test
if __name__ == "__main__":
    embed_size = 256
    vocab_size = 1000
    model = ImageCaptionModel(embed_size, 256, vocab_size, num_heads=4, num_layers=2)

    img = torch.randn(2, 3, 224, 224)
    caps = torch.randint(0, 1000, (2, 20))

    out = model(img, caps)
    print(f"Transformer Output Shape: {out.shape}")


Downloading: "https://download.pytorch.org/models/resnet101-cd907fc2.pth" to /root/.cache/torch/hub/checkpoints/resnet101-cd907fc2.pth


100%|██████████| 171M/171M [00:00<00:00, 201MB/s]


Transformer Output Shape: torch.Size([2, 20, 1000])


In [6]:
import torch
import torchvision.transforms as transforms
from PIL import Image

def print_examples(model, device, dataset):
    """
    Helper to print predicted captions for a few images during training
    to manually verify progress.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    model.eval()


    test_img_dir = "images/"

    import os
    try:
        test_images = [f for f in os.listdir(test_img_dir) if f.endswith('.jpg')][:2]
    except FileNotFoundError:
        print("Image directory not found, skipping examples.")
        return

    for img_name in test_images:
        image = Image.open(os.path.join(test_img_dir, img_name)).convert("RGB")
        image_tensor = transform(image).unsqueeze(0).to(device)

        # Generate caption
        with torch.no_grad():
            caption = model.caption_image(image_tensor.squeeze(0), dataset.vocab)

        print(f"Image: {img_name}")
        print(f"Prediction: {' '.join(caption)}")
        print("-" * 20)

    model.train()

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


from data_loader import get_loader
from model import ImageCaptionModel
from utils import save_checkpoint, load_checkpoint, print_examples

#  Hyperparameters
embed_size = 256
hidden_size = 256
vocab_size = -1
num_heads = 4
num_layers = 2
learning_rate = 3e-4
num_epochs = 100
batch_size = 32
num_workers = 2
load_model = False
save_model = True
train_CNN = False
def train():

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")


    transform = transforms.Compose(
        [
            transforms.Resize((232, 232)),
            transforms.RandomCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )


    print("Loading Data...")
    loader, dataset = get_loader(
        root_folder="images/",
        annotation_file="captions.txt",
        transform=transform,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    vocab_size = len(dataset.vocab)
    print(f"Vocabulary Size: {vocab_size}")

    model = ImageCaptionModel(
        embed_size=embed_size,
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        num_heads=num_heads,
        num_layers=num_layers,
        train_CNN=train_CNN
    ).to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])


    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if load_model:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)

    model.train()

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        if save_model and epoch % 5 == 0:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint)



        loop = tqdm(loader, leave=True)

        for idx, (imgs, captions) in enumerate(loop):
            imgs = imgs.to(device)
            captions = captions.to(device)


            outputs = model(imgs, captions[:, :-1])


            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]),
                captions[:, 1:].reshape(-1)
            )

            optimizer.zero_grad()
            loss.backward()


            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            loop.set_postfix(loss=loss.item())

if __name__ == "__main__":
    train()


ModuleNotFoundError: No module named 'data_loader'

# Task
Refactor the existing image captioning notebook by organizing the `Vocabulary`, `FlickrDataset`, `MyCollate`, and `get_loader` into a data loader cell; `EncoderCNN`, `PositionalEncoding`, `DecoderTransformer`, and `ImageCaptionModel` into a model cell; and `print_examples`, `save_checkpoint`, and `load_checkpoint` into a utilities cell. Then, update the training cell by removing explicit imports for these components and execute all relevant cells.

## Move Data Loader Code

### Subtask:
Create a new code cell above the training cell and move the `Vocabulary`, `FlickrDataset`, `MyCollate`, and `get_loader` definitions into it. This cell will effectively act as the 'data_loader'.


**Reasoning**:
To fulfill the subtask, I need to create a new code cell directly above the training cell (`FHlswN3E6d3R`) and populate it with the `Vocabulary`, `FlickrDataset`, `MyCollate`, and `get_loader` definitions from cell `oUdVWXpEvyzf`.



In [10]:
import os
import spacy
import torch
import pandas as pd
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from typing import List, Tuple, Any


spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:

    def __init__(self, freq_threshold: int):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self) -> int:
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text: str) -> List[str]:
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list: List[str]):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text: str) -> List[int]:
        tokenized_text = self.tokenizer_eng(text)
        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]

class FlickrDataset(Dataset):

    def __init__(self, root_dir: str, captions_file: str, transform: Any = None, freq_threshold: int = 5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df["image"]
        self.captions = self.df["caption"]


        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        caption = self.captions[index]
        img_id = self.imgs[index]
        img_path = os.path.join(self.root_dir, img_id)


        img = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)


        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption += [self.vocab.stoi["<EOS>"]]

        return img, torch.tensor(numericalized_caption)

class MyCollate:

    def __init__(self, pad_idx: int):
        self.pad_idx = pad_idx

    def __call__(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)

        targets = [item[1] for item in batch]

        targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)

        return imgs, targets

def get_loader(
    root_folder: str,
    annotation_file: str,
    transform: Any,
    batch_size: int = 32,
    num_workers: int = 8,
    shuffle: bool = True,
    pin_memory: bool = True,
) -> Tuple[DataLoader, FlickrDataset]:

    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset

In [11]:
import torch
import torch.nn as nn
import torchvision.models as models
import math

class EncoderCNN:

    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)

        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        self.embed = nn.Linear(2048, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()
        self.train_CNN = train_CNN

    def forward(self, images):
        with torch.set_grad_enabled(self.train_CNN):
            features = self.resnet(images)


        features = features.permute(0, 2, 3, 1)
        features = features.view(features.size(0), -1, features.size(3))


        features = self.embed(features)
        features = self.relu(features)

        return features

class PositionalEncoding:

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1) # [Max_Len, 1, D_Model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class DecoderTransformer:
    def __init__(self, embed_size, vocab_size, num_heads=4, num_layers=2, max_len=100):
        super(DecoderTransformer, self).__init__()

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, max_len)


        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        self.linear = nn.Linear(embed_size, vocab_size)
        self.max_len = max_len

    def forward(self, features, captions):

        # Embed captions
        embeddings = self.embed(captions) * math.sqrt(features.size(-1))


        embeddings = embeddings.permute(1, 0, 2)
        features = features.permute(1, 0, 2)

        # positional encoding
        embeddings = self.pos_encoder(embeddings)

        tgt_mask = self.generate_square_subsequent_mask(embeddings.size(0)).to(features.device)

        outputs = self.transformer_decoder(tgt=embeddings, memory=features, tgt_mask=tgt_mask)

        outputs = self.linear(outputs)

        return outputs.permute(1, 0, 2)

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def caption_image(self, image_features, vocabulary, max_length=20):

        batch_size = image_features.size(0)
        start_token = vocabulary.stoi["<SOS>"]

        generated = torch.tensor([start_token]).unsqueeze(1).to(image_features.device)

        image_features = image_features.permute(1, 0, 2)

        result_caption = []

        for _ in range(max_length):

            tgt_emb = self.embed(generated) * math.sqrt(image_features.size(-1))
            tgt_emb = tgt_emb.permute(1, 0, 2)
            tgt_emb = self.pos_encoder(tgt_emb)

            # Mask
            mask = self.generate_square_subsequent_mask(tgt_emb.size(0)).to(image_features.device)

            # Decode
            out = self.transformer_decoder(tgt=tgt_emb, memory=image_features, tgt_mask=mask)

            # Get last token output
            last_token_out = out[-1, :, :]

            # Predict
            logits = self.linear(last_token_out)
            predicted_id = logits.argmax(1).item()

            result_caption.append(vocabulary.itos[predicted_id])

            if vocabulary.itos[predicted_id] == "<EOS>":
                break


            next_token = torch.tensor([predicted_id]).unsqueeze(1).to(image_features.device)
            generated = torch.cat((generated, next_token), dim=1)

        return result_caption

class ImageCaptionModel:
    def __init__(self, embed_size, hidden_size, vocab_size, num_heads, num_layers, train_CNN=False):
        super(ImageCaptionModel, self).__init__()
        self.encoder = EncoderCNN(embed_size, train_CNN)
        self.decoder = DecoderTransformer(embed_size, vocab_size, num_heads, num_layers)

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=20):
        self.eval()
        with torch.no_grad():
            features = self.encoder(image.unsqueeze(0))
            return self.decoder.caption_image(features, vocabulary, max_length)

In [12]:
import torch
import torchvision.transforms as transforms
from PIL import Image

def print_examples(model, device, dataset):
    """
    Helper to print predicted captions for a few images during training
    to manually verify progress.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    model.eval()


    test_img_dir = "images/"

    import os
    try:
        test_images = [f for f in os.listdir(test_img_dir) if f.endswith('.jpg')][:2]
    except FileNotFoundError:
        print("Image directory not found, skipping examples.")
        return

    for img_name in test_images:
        image = Image.open(os.path.join(test_img_dir, img_name)).convert("RGB")
        image_tensor = transform(image).unsqueeze(0).to(device)

        # Generate caption
        with torch.no_grad():
            caption = model.caption_image(image_tensor.squeeze(0), dataset.vocab)

        print(f"Image: {img_name}")
        print(f"Prediction: {' '.join(caption)}")
        print("-" * 20)

    model.train()

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

In [14]:
import torch
import torch.nn as nn
import torchvision.models as models
import math

class EncoderCNN(nn.Module):

    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)

        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        self.embed = nn.Linear(2048, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()
        self.train_CNN = train_CNN

    def forward(self, images):
        with torch.set_grad_enabled(self.train_CNN):
            features = self.resnet(images)


        features = features.permute(0, 2, 3, 1)
        features = features.view(features.size(0), -1, features.size(3))


        features = self.embed(features)
        features = self.relu(features)

        return features

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1) # [Max_Len, 1, D_Model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class DecoderTransformer(nn.Module):
    def __init__(self, embed_size, vocab_size, num_heads=4, num_layers=2, max_len=100):
        super(DecoderTransformer, self).__init__()

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, max_len)


        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        self.linear = nn.Linear(embed_size, vocab_size)
        self.max_len = max_len

    def forward(self, features, captions):

        # Embed captions
        embeddings = self.embed(captions) * math.sqrt(features.size(-1))


        embeddings = embeddings.permute(1, 0, 2)
        features = features.permute(1, 0, 2)

        # positional encoding
        embeddings = self.pos_encoder(embeddings)

        tgt_mask = self.generate_square_subsequent_mask(embeddings.size(0)).to(features.device)

        outputs = self.transformer_decoder(tgt=embeddings, memory=features, tgt_mask=tgt_mask)

        outputs = self.linear(outputs)

        return outputs.permute(1, 0, 2)

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def caption_image(self, image_features, vocabulary, max_length=20):

        batch_size = image_features.size(0)
        start_token = vocabulary.stoi["<SOS>"]

        generated = torch.tensor([start_token]).unsqueeze(1).to(image_features.device)

        image_features = image_features.permute(1, 0, 2)

        result_caption = []

        for _ in range(max_length):

            tgt_emb = self.embed(generated) * math.sqrt(image_features.size(-1))
            tgt_emb = tgt_emb.permute(1, 0, 2)
            tgt_emb = self.pos_encoder(tgt_emb)

            # Mask
            mask = self.generate_square_subsequent_mask(tgt_emb.size(0)).to(image_features.device)

            # Decode
            out = self.transformer_decoder(tgt=tgt_emb, memory=image_features, tgt_mask=mask)

            # Get last token output
            last_token_out = out[-1, :, :]

            # Predict
            logits = self.linear(last_token_out)
            predicted_id = logits.argmax(1).item()

            result_caption.append(vocabulary.itos[predicted_id])

            if vocabulary.itos[predicted_id] == "<EOS>":
                break


            next_token = torch.tensor([predicted_id]).unsqueeze(1).to(image_features.device)
            generated = torch.cat((generated, next_token), dim=1)

        return result_caption

class ImageCaptionModel(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_heads, num_layers, train_CNN=False):
        super(ImageCaptionModel, self).__init__()
        self.encoder = EncoderCNN(embed_size, train_CNN)
        self.decoder = DecoderTransformer(embed_size, vocab_size, num_heads, num_layers)

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=20):
        self.eval()
        with torch.no_grad():
            features = self.encoder(image.unsqueeze(0))
            return self.decoder.caption_image(features, vocabulary, max_length)

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


# from data_loader import get_loader
# from model import ImageCaptionModel
# from utils import save_checkpoint, load_checkpoint, print_examples

#  Hyperparameters
embed_size = 256
hidden_size = 256
vocab_size = -1
num_heads = 4
num_layers = 2
learning_rate = 3e-4
num_epochs = 100
batch_size = 32
num_workers = 2
load_model = False
save_model = True
train_CNN = False
def train():

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")


    transform = transforms.Compose(
        [
            transforms.Resize((232, 232)),
            transforms.RandomCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )


    print("Loading Data...")
    loader, dataset = get_loader(
        root_folder="images/",
        annotation_file="captions.txt",
        transform=transform,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    vocab_size = len(dataset.vocab)
    print(f"Vocabulary Size: {vocab_size}")

    model = ImageCaptionModel(
        embed_size=embed_size,
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        num_heads=num_heads,
        num_layers=num_layers,
        train_CNN=train_CNN
    ).to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])


    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if load_model:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)

    model.train()

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        if save_model and epoch % 5 == 0:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint)



        loop = tqdm(loader, leave=True)

        for idx, (imgs, captions) in enumerate(loop):
            imgs = imgs.to(device)
            captions = captions.to(device)


            outputs = model(imgs, captions[:, :-1])


            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]),
                captions[:, 1:].reshape(-1)
            )

            optimizer.zero_grad()
            loss.backward()


            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            loop.set_postfix(loss=loss.item())

if __name__ == "__main__":
    train()

Using device: cuda
Loading Data...
Vocabulary Size: 12

--- Epoch 1/100 ---
=> Saving checkpoint


Epoch [1/100]: 100%|██████████| 1/1 [00:00<00:00,  1.96it/s, loss=2.34]



--- Epoch 2/100 ---


Epoch [2/100]: 100%|██████████| 1/1 [00:00<00:00,  3.59it/s, loss=1.27]



--- Epoch 3/100 ---


Epoch [3/100]: 100%|██████████| 1/1 [00:00<00:00,  3.66it/s, loss=0.683]



--- Epoch 4/100 ---


Epoch [4/100]: 100%|██████████| 1/1 [00:00<00:00,  3.61it/s, loss=0.375]



--- Epoch 5/100 ---


Epoch [5/100]: 100%|██████████| 1/1 [00:00<00:00,  3.70it/s, loss=0.224]



--- Epoch 6/100 ---
=> Saving checkpoint


Epoch [6/100]: 100%|██████████| 1/1 [00:00<00:00,  3.31it/s, loss=0.169]



--- Epoch 7/100 ---


Epoch [7/100]: 100%|██████████| 1/1 [00:00<00:00,  3.57it/s, loss=0.143]



--- Epoch 8/100 ---


Epoch [8/100]: 100%|██████████| 1/1 [00:00<00:00,  3.69it/s, loss=0.138]



--- Epoch 9/100 ---


Epoch [9/100]: 100%|██████████| 1/1 [00:00<00:00,  3.54it/s, loss=0.127]



--- Epoch 10/100 ---


Epoch [10/100]: 100%|██████████| 1/1 [00:00<00:00,  3.49it/s, loss=0.129]



--- Epoch 11/100 ---
=> Saving checkpoint


Epoch [11/100]: 100%|██████████| 1/1 [00:00<00:00,  3.54it/s, loss=0.128]



--- Epoch 12/100 ---


Epoch [12/100]: 100%|██████████| 1/1 [00:00<00:00,  3.52it/s, loss=0.134]



--- Epoch 13/100 ---


Epoch [13/100]: 100%|██████████| 1/1 [00:00<00:00,  3.74it/s, loss=0.122]



--- Epoch 14/100 ---


Epoch [14/100]: 100%|██████████| 1/1 [00:00<00:00,  3.59it/s, loss=0.14]



--- Epoch 15/100 ---


Epoch [15/100]: 100%|██████████| 1/1 [00:00<00:00,  3.69it/s, loss=0.117]



--- Epoch 16/100 ---
=> Saving checkpoint


Epoch [16/100]: 100%|██████████| 1/1 [00:00<00:00,  2.58it/s, loss=0.141]



--- Epoch 17/100 ---


Epoch [17/100]: 100%|██████████| 1/1 [00:00<00:00,  2.82it/s, loss=0.114]



--- Epoch 18/100 ---


Epoch [18/100]: 100%|██████████| 1/1 [00:00<00:00,  2.64it/s, loss=0.11]



--- Epoch 19/100 ---


Epoch [19/100]: 100%|██████████| 1/1 [00:00<00:00,  3.03it/s, loss=0.122]



--- Epoch 20/100 ---


Epoch [20/100]: 100%|██████████| 1/1 [00:00<00:00,  3.68it/s, loss=0.138]



--- Epoch 21/100 ---
=> Saving checkpoint


Epoch [21/100]: 100%|██████████| 1/1 [00:00<00:00,  3.65it/s, loss=0.106]



--- Epoch 22/100 ---


Epoch [22/100]: 100%|██████████| 1/1 [00:00<00:00,  3.66it/s, loss=0.127]



--- Epoch 23/100 ---


Epoch [23/100]: 100%|██████████| 1/1 [00:00<00:00,  3.70it/s, loss=0.116]



--- Epoch 24/100 ---


Epoch [24/100]: 100%|██████████| 1/1 [00:00<00:00,  3.61it/s, loss=0.119]



--- Epoch 25/100 ---


Epoch [25/100]: 100%|██████████| 1/1 [00:00<00:00,  3.72it/s, loss=0.11]



--- Epoch 26/100 ---
=> Saving checkpoint


Epoch [26/100]: 100%|██████████| 1/1 [00:00<00:00,  3.48it/s, loss=0.132]



--- Epoch 27/100 ---


Epoch [27/100]: 100%|██████████| 1/1 [00:00<00:00,  3.54it/s, loss=0.112]



--- Epoch 28/100 ---


Epoch [28/100]: 100%|██████████| 1/1 [00:00<00:00,  3.42it/s, loss=0.131]



--- Epoch 29/100 ---


Epoch [29/100]: 100%|██████████| 1/1 [00:00<00:00,  3.76it/s, loss=0.129]



--- Epoch 30/100 ---


Epoch [30/100]: 100%|██████████| 1/1 [00:00<00:00,  3.78it/s, loss=0.107]



--- Epoch 31/100 ---
=> Saving checkpoint


Epoch [31/100]: 100%|██████████| 1/1 [00:00<00:00,  3.52it/s, loss=0.138]



--- Epoch 32/100 ---


Epoch [32/100]: 100%|██████████| 1/1 [00:00<00:00,  3.74it/s, loss=0.124]



--- Epoch 33/100 ---


Epoch [33/100]: 100%|██████████| 1/1 [00:00<00:00,  3.70it/s, loss=0.121]



--- Epoch 34/100 ---


Epoch [34/100]: 100%|██████████| 1/1 [00:00<00:00,  3.33it/s, loss=0.123]



--- Epoch 35/100 ---


Epoch [35/100]: 100%|██████████| 1/1 [00:00<00:00,  3.78it/s, loss=0.126]



--- Epoch 36/100 ---
=> Saving checkpoint


Epoch [36/100]: 100%|██████████| 1/1 [00:00<00:00,  3.55it/s, loss=0.111]



--- Epoch 37/100 ---


Epoch [37/100]: 100%|██████████| 1/1 [00:00<00:00,  3.45it/s, loss=0.115]



--- Epoch 38/100 ---


Epoch [38/100]: 100%|██████████| 1/1 [00:00<00:00,  3.65it/s, loss=0.12]



--- Epoch 39/100 ---


Epoch [39/100]: 100%|██████████| 1/1 [00:00<00:00,  3.61it/s, loss=0.118]



--- Epoch 40/100 ---


Epoch [40/100]: 100%|██████████| 1/1 [00:00<00:00,  3.07it/s, loss=0.135]



--- Epoch 41/100 ---
=> Saving checkpoint


Epoch [41/100]: 100%|██████████| 1/1 [00:00<00:00,  3.32it/s, loss=0.133]



--- Epoch 42/100 ---


Epoch [42/100]: 100%|██████████| 1/1 [00:00<00:00,  3.33it/s, loss=0.119]



--- Epoch 43/100 ---


Epoch [43/100]: 100%|██████████| 1/1 [00:00<00:00,  3.65it/s, loss=0.123]



--- Epoch 44/100 ---


Epoch [44/100]: 100%|██████████| 1/1 [00:00<00:00,  3.64it/s, loss=0.117]



--- Epoch 45/100 ---


Epoch [45/100]: 100%|██████████| 1/1 [00:00<00:00,  3.56it/s, loss=0.133]



--- Epoch 46/100 ---
=> Saving checkpoint


Epoch [46/100]: 100%|██████████| 1/1 [00:00<00:00,  3.51it/s, loss=0.113]



--- Epoch 47/100 ---


Epoch [47/100]: 100%|██████████| 1/1 [00:00<00:00,  3.44it/s, loss=0.132]



--- Epoch 48/100 ---


Epoch [48/100]: 100%|██████████| 1/1 [00:00<00:00,  3.74it/s, loss=0.11]



--- Epoch 49/100 ---


Epoch [49/100]: 100%|██████████| 1/1 [00:00<00:00,  3.73it/s, loss=0.116]



--- Epoch 50/100 ---


Epoch [50/100]: 100%|██████████| 1/1 [00:00<00:00,  3.71it/s, loss=0.122]



--- Epoch 51/100 ---
=> Saving checkpoint


Epoch [51/100]: 100%|██████████| 1/1 [00:00<00:00,  3.42it/s, loss=0.113]



--- Epoch 52/100 ---


Epoch [52/100]: 100%|██████████| 1/1 [00:00<00:00,  3.68it/s, loss=0.139]



--- Epoch 53/100 ---


Epoch [53/100]: 100%|██████████| 1/1 [00:00<00:00,  3.34it/s, loss=0.118]



--- Epoch 54/100 ---


Epoch [54/100]: 100%|██████████| 1/1 [00:00<00:00,  3.72it/s, loss=0.108]



--- Epoch 55/100 ---


Epoch [55/100]: 100%|██████████| 1/1 [00:00<00:00,  3.66it/s, loss=0.131]



--- Epoch 56/100 ---
=> Saving checkpoint


Epoch [56/100]: 100%|██████████| 1/1 [00:00<00:00,  3.43it/s, loss=0.114]



--- Epoch 57/100 ---


Epoch [57/100]: 100%|██████████| 1/1 [00:00<00:00,  3.63it/s, loss=0.117]



--- Epoch 58/100 ---


Epoch [58/100]: 100%|██████████| 1/1 [00:00<00:00,  3.66it/s, loss=0.113]



--- Epoch 59/100 ---


Epoch [59/100]: 100%|██████████| 1/1 [00:00<00:00,  3.58it/s, loss=0.122]



--- Epoch 60/100 ---


Epoch [60/100]: 100%|██████████| 1/1 [00:00<00:00,  3.66it/s, loss=0.116]



--- Epoch 61/100 ---
=> Saving checkpoint


Epoch [61/100]: 100%|██████████| 1/1 [00:00<00:00,  3.44it/s, loss=0.123]



--- Epoch 62/100 ---


Epoch [62/100]: 100%|██████████| 1/1 [00:00<00:00,  3.47it/s, loss=0.126]



--- Epoch 63/100 ---


Epoch [63/100]: 100%|██████████| 1/1 [00:00<00:00,  3.59it/s, loss=0.124]



--- Epoch 64/100 ---


Epoch [64/100]: 100%|██████████| 1/1 [00:00<00:00,  3.55it/s, loss=0.132]



--- Epoch 65/100 ---


Epoch [65/100]: 100%|██████████| 1/1 [00:00<00:00,  3.68it/s, loss=0.122]



--- Epoch 66/100 ---
=> Saving checkpoint


Epoch [66/100]: 100%|██████████| 1/1 [00:00<00:00,  3.31it/s, loss=0.125]



--- Epoch 67/100 ---


Epoch [67/100]: 100%|██████████| 1/1 [00:00<00:00,  3.54it/s, loss=0.128]



--- Epoch 68/100 ---


Epoch [68/100]: 100%|██████████| 1/1 [00:00<00:00,  3.60it/s, loss=0.128]



--- Epoch 69/100 ---


Epoch [69/100]: 100%|██████████| 1/1 [00:00<00:00,  3.70it/s, loss=0.119]



--- Epoch 70/100 ---


Epoch [70/100]: 100%|██████████| 1/1 [00:00<00:00,  3.52it/s, loss=0.105]



--- Epoch 71/100 ---
=> Saving checkpoint


Epoch [71/100]: 100%|██████████| 1/1 [00:00<00:00,  3.41it/s, loss=0.115]



--- Epoch 72/100 ---


Epoch [72/100]: 100%|██████████| 1/1 [00:00<00:00,  3.62it/s, loss=0.115]



--- Epoch 73/100 ---


Epoch [73/100]: 100%|██████████| 1/1 [00:00<00:00,  3.68it/s, loss=0.107]



--- Epoch 74/100 ---


Epoch [74/100]: 100%|██████████| 1/1 [00:00<00:00,  3.54it/s, loss=0.114]



--- Epoch 75/100 ---


Epoch [75/100]: 100%|██████████| 1/1 [00:00<00:00,  3.73it/s, loss=0.123]



--- Epoch 76/100 ---
=> Saving checkpoint


Epoch [76/100]: 100%|██████████| 1/1 [00:00<00:00,  3.39it/s, loss=0.12]



--- Epoch 77/100 ---


Epoch [77/100]: 100%|██████████| 1/1 [00:00<00:00,  3.55it/s, loss=0.116]



--- Epoch 78/100 ---


Epoch [78/100]: 100%|██████████| 1/1 [00:00<00:00,  3.64it/s, loss=0.117]



--- Epoch 79/100 ---


Epoch [79/100]: 100%|██████████| 1/1 [00:00<00:00,  3.64it/s, loss=0.114]



--- Epoch 80/100 ---


Epoch [80/100]: 100%|██████████| 1/1 [00:00<00:00,  3.53it/s, loss=0.115]



--- Epoch 81/100 ---
=> Saving checkpoint


Epoch [81/100]: 100%|██████████| 1/1 [00:00<00:00,  3.43it/s, loss=0.124]



--- Epoch 82/100 ---


Epoch [82/100]: 100%|██████████| 1/1 [00:00<00:00,  3.53it/s, loss=0.126]



--- Epoch 83/100 ---


Epoch [83/100]: 100%|██████████| 1/1 [00:00<00:00,  3.67it/s, loss=0.135]



--- Epoch 84/100 ---


Epoch [84/100]: 100%|██████████| 1/1 [00:00<00:00,  3.66it/s, loss=0.106]



--- Epoch 85/100 ---


Epoch [85/100]: 100%|██████████| 1/1 [00:00<00:00,  3.60it/s, loss=0.121]



--- Epoch 86/100 ---
=> Saving checkpoint


Epoch [86/100]: 100%|██████████| 1/1 [00:00<00:00,  3.40it/s, loss=0.117]



--- Epoch 87/100 ---


Epoch [87/100]: 100%|██████████| 1/1 [00:00<00:00,  3.59it/s, loss=0.116]



--- Epoch 88/100 ---


Epoch [88/100]: 100%|██████████| 1/1 [00:00<00:00,  3.58it/s, loss=0.12]



--- Epoch 89/100 ---


Epoch [89/100]: 100%|██████████| 1/1 [00:00<00:00,  3.54it/s, loss=0.119]



--- Epoch 90/100 ---


Epoch [90/100]: 100%|██████████| 1/1 [00:00<00:00,  3.75it/s, loss=0.115]



--- Epoch 91/100 ---
=> Saving checkpoint


Epoch [91/100]: 100%|██████████| 1/1 [00:00<00:00,  3.10it/s, loss=0.113]



--- Epoch 92/100 ---


Epoch [92/100]: 100%|██████████| 1/1 [00:00<00:00,  3.62it/s, loss=0.114]



--- Epoch 93/100 ---


Epoch [93/100]: 100%|██████████| 1/1 [00:00<00:00,  3.65it/s, loss=0.129]



--- Epoch 94/100 ---


Epoch [94/100]: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s, loss=0.112]



--- Epoch 95/100 ---


Epoch [95/100]: 100%|██████████| 1/1 [00:00<00:00,  3.64it/s, loss=0.111]



--- Epoch 96/100 ---
=> Saving checkpoint


Epoch [96/100]: 100%|██████████| 1/1 [00:00<00:00,  3.37it/s, loss=0.117]



--- Epoch 97/100 ---


Epoch [97/100]: 100%|██████████| 1/1 [00:00<00:00,  3.54it/s, loss=0.127]



--- Epoch 98/100 ---


Epoch [98/100]: 100%|██████████| 1/1 [00:00<00:00,  3.59it/s, loss=0.121]



--- Epoch 99/100 ---


Epoch [99/100]: 100%|██████████| 1/1 [00:00<00:00,  3.63it/s, loss=0.12]



--- Epoch 100/100 ---


Epoch [100/100]: 100%|██████████| 1/1 [00:00<00:00,  2.80it/s, loss=0.113]


In [18]:

torch.save(trained_model.state_dict(), "my_best_model.pth")
print("Model saved to disk!")


Model saved to disk!


In [None]:

trained_model = ImageCaptionModel(embed_size, vocab_size, num_heads, num_layers).to(device)

trained_model.load_state_dict(torch.load("my_best_model.pth"))
print("Model loaded!")


#deployment

In [23]:
print("Installing deployment tools...")
!pip install -q fastapi uvicorn pyngrok python-multipart nest-asyncio prometheus-fastapi-instrumentator

import uvicorn
import nest_asyncio
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from prometheus_fastapi_instrumentator import Instrumentator
from pyngrok import ngrok
from PIL import Image
import io
import torch
import torchvision.transforms as transforms

# Import necessary components from previously defined cells
# (Assuming get_loader, Vocabulary, ImageCaptionModel are defined in prior executed cells)
# To make dataset globally available for prediction

# Initialize dataset and vocab (assuming get_loader is available from an executed cell)
# Re-using transform definition from training for consistency
transform_deploy = transforms.Compose(
    [
        transforms.Resize((232, 232)),
        transforms.RandomCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
_, dataset = get_loader(
    root_folder="images/", # Make sure 'images/' directory and 'captions.txt' are correctly set up
    annotation_file="captions.txt",
    transform=transform_deploy,
    batch_size=1, # Batch size 1 for single image prediction
    shuffle=False,
    num_workers=0 # No workers needed for single image prediction
)

app = FastAPI(
    title="Image Captioning API",
    description="Karpathy-style Image Captioning deployed from Colab",
    version="1.0"
)


Instrumentator().instrument(app).expose(app)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


deploy_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

@app.get("/")
def home():
    return {"status": "online", "model": "Transformer-ResNet101"}

@app.post("/predict")
async def predict(file: UploadFile = File(...)):

    try:

        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")


        img_tensor = deploy_transform(image).unsqueeze(0).to(device)


        caption_list = trained_model.caption_image(img_tensor.squeeze(0), dataset.vocab)


        caption_text = " ".join(caption_list)


        caption_text = caption_text.replace("<SOS>", "").replace("<EOS>", "").strip()

        return {
            "filename": file.filename,
            "caption": caption_text
        }

    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


NGROK_AUTH_TOKEN = "37nEPCc9Sa1UOCH5a6tA91p7e8H_2V6u4Cje5VLoHZyejxGWf"
if NGROK_AUTH_TOKEN != "37nEPCc9Sa1UOCH5a6tA91p7e8H_2V6u4Cje5VLoHZyejxGWf":
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)


    public_url = ngrok.connect(8000).public_url
    print(f"\n API IS LIVE! Public URL: {public_url}")
    print(f" Monitoring Metrics: {public_url}/metrics")
    print(f" Interactive Docs: {public_url}/docs")


    nest_asyncio.apply()
    uvicorn.run(app, port=8000)
else:
    print("⚠️ Please paste your ngrok token in the code above to generate a public URL.")

Installing deployment tools...
⚠️ Please paste your ngrok token in the code above to generate a public URL.
