In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# ResNet-50 and RNN

In [5]:
import os
import ssl
import nltk
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from collections import Counter
import requests
from io import BytesIO

# ---------------------------------------------------
# Optional: disable SSL verification if SSL errors occur
ssl._create_default_https_context = ssl._create_unverified_context

# Download NLTK data (for tokenization)
nltk.download('punkt')

# ---------------------------------------------------
# 1. Vocabulary and Tokenization
# ---------------------------------------------------
class Vocabulary:
    """A simple vocabulary wrapper."""
    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        # Keep punctuation tokens if they appear in the dataset
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}
    
    def __len__(self):
        return len(self.itos)
    
    @staticmethod
    def tokenizer(text):
        # We'll NOT remove punctuation here so that the model can learn punctuation tokens
        text = text.lower().strip()
        return nltk.tokenize.word_tokenize(text)  # keeps punctuation as separate tokens
    
    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        idx = 4  # starting index after special tokens
        for sentence in sentence_list:
            tokens = self.tokenizer(sentence)
            frequencies.update(tokens)
            for token in tokens:
                # Only add a token to the vocab if its frequency == freq_threshold
                if frequencies[token] == self.freq_threshold:
                    # if it is not already in stoi
                    if token not in self.stoi:
                        self.stoi[token] = idx
                        self.itos[idx] = token
                        idx += 1
    
    def numericalize(self, text):
        tokenized_text = self.tokenizer(text)
        return [self.stoi.get(token, self.stoi["<unk>"]) for token in tokenized_text]

# ---------------------------------------------------
# 2. Hyperparameters and Transforms
# ---------------------------------------------------
embed_size    = 256
hidden_size   = 512
num_layers    = 1
learning_rate = 1e-3
# We'll train longer for a better model
num_epochs    = 10

batch_size    = 16
# Raise the frequency threshold to reduce <unk>
freq_threshold = 5

max_seq_length = 30  # allowing a bit longer sentences

# ImageNet-based transformations
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std =[0.229, 0.224, 0.225]
    )
])

# ---------------------------------------------------
# 3. Loading the CSV into a DataFrame
# ---------------------------------------------------
csv_file = "/kaggle/input/flickr8k/captions.txt"  # your CSV with columns: image, caption
img_dir  = "/kaggle/input/flickr8k/Images"

df = pd.read_csv(csv_file)
print("Number of captions:", len(df))

# ---------------------------------------------------
# 4. Dataset and DataLoader (with input/target split)
# ---------------------------------------------------
class CaptionDatasetFromDF(Dataset):
    def __init__(self, dataframe, img_dir, vocabulary, transform=None, max_seq_length=30):
        self.df = dataframe
        self.img_dir = img_dir
        self.transform = transform
        self.vocab = vocabulary
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['image']
        caption  = self.df.iloc[idx]['caption']

        # Load and transform image
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        # Numericalize the caption
        numericalized_caption = [self.vocab.stoi["<start>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<end>"])

        # Truncate if over max length
        if len(numericalized_caption) > self.max_seq_length:
            numericalized_caption = numericalized_caption[:self.max_seq_length]
            numericalized_caption[-1] = self.vocab.stoi["<end>"]

        # Split into input & target
        input_caption  = numericalized_caption[:-1]  # all but last
        target_caption = numericalized_caption[1:]   # all but first

        return image, torch.tensor(input_caption), torch.tensor(target_caption)

def collate_fn(data):
    """
    Creates mini-batches by padding input and target captions to
    the length of the longest caption in the batch.
    """
    images, input_captions, target_captions = zip(*data)

    # Stack images
    images = torch.stack(images, 0)

    # Pad input and target
    padded_input  = pad_sequence(input_captions, batch_first=True, padding_value=0)
    padded_target = pad_sequence(target_captions, batch_first=True, padding_value=0)

    # The lengths are for reference only (not strictly needed with current code)
    lengths = [len(cap) for cap in input_captions]

    return images, padded_input, padded_target, lengths

# ---------------------------------------------------
# 5. Model Definitions
# ---------------------------------------------------
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        # Load a ResNet50
        resnet = models.resnet50(pretrained=True)
        modules = list(resnet.children())[:-1]  # Remove final FC
        self.resnet = nn.Sequential(*modules)

        # Optionally freeze the CNN
        for param in self.resnet.parameters():
            param.requires_grad = train_CNN

        # Map 2048 -> embed_size
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn     = nn.BatchNorm1d(embed_size, momentum=0.01)
        
    def forward(self, images):
        features = self.resnet(images)                # (batch, 2048, 1, 1)
        features = features.view(features.size(0), -1)  # (batch, 2048)
        features = self.linear(features)                 # (batch, embed_size)
        features = self.bn(features)
        return features

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, max_seq_length=30):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm  = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length

        self.vocab_size = vocab_size

    def forward(self, features, captions):
        # features: (batch, embed_size)
        # captions: (batch, seq_len)
        embeddings = self.embed(captions)  # (batch, seq_len, embed_size)
        # Insert the image features at t=0
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)  # (batch, seq_len+1, embed_size)

        outputs, _ = self.lstm(embeddings)      # (batch, seq_len+1, hidden_size)
        outputs = self.linear(outputs)          # (batch, seq_len+1, vocab_size)
        outputs = outputs[:, 1:, :]            # drop the first time step => (batch, seq_len, vocab_size)
        return outputs

    def sample_greedy(self, features):
        """(Optional) Greedy search as an alternative to beam search."""
        sampled_ids = []
        inputs = features.unsqueeze(1)  # (batch, 1, embed_size)

        for _ in range(self.max_seq_length):
            hiddens, states = self.lstm(inputs)
            outputs = self.linear(hiddens.squeeze(1))  # (batch, vocab_size)
            predicted = outputs.argmax(dim=1)          # (batch)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)

        sampled_ids = torch.stack(sampled_ids, 1)
        return sampled_ids

    def sample_beam_search(self, features, vocab, beam_size=3):
        """
        Beam search for improved decoding.
        Returns the best predicted sequence of token IDs (excluding <start>).
        """
        device = features.device
        start_id = vocab.stoi["<start>"]
        end_id   = vocab.stoi["<end>"]

        # Each element: (sequence_of_ids, log_prob, hidden_state, cell_state)
        sequences = [([start_id], 0.0, None, None)]

        for _ in range(self.max_seq_length):
            all_candidates = []
            for seq, log_prob, hidden, cell in sequences:
                # If the last token was <end>, just keep adding it as is
                if seq[-1] == end_id:
                    all_candidates.append((seq, log_prob, hidden, cell))
                    continue

                # The last token
                last_token = torch.tensor([seq[-1]]).to(device)
                # embed that token
                token_embed = self.embed(last_token).unsqueeze(1)  # shape: (1,1,embed_size)

                if hidden is None and cell is None:
                    # We'll treat the image features as the initial hidden
                    # so we pass the image feature + token as the first input
                    # We can do a single-step LSTM pass
                    init_input = torch.cat((features.unsqueeze(1), token_embed), dim=1)  # shape (1,2,embed_size)
                    lstm_out, (hidden_next, cell_next) = self.lstm(init_input)
                    # last step output
                    lstm_out = lstm_out[:, -1, :]
                else:
                    lstm_out, (hidden_next, cell_next) = self.lstm(token_embed, (hidden, cell))

                scores = self.linear(lstm_out.squeeze(1))   # shape (1, vocab_size)
                log_probs = torch.log_softmax(scores, dim=1)

                # get top beam_size expansions
                topk_log_probs, topk_ids = torch.topk(log_probs, beam_size, dim=1)
                for i in range(beam_size):
                    candidate_seq = seq + [topk_ids[0, i].item()]
                    candidate_log_prob = log_prob + topk_log_probs[0, i].item()
                    all_candidates.append((candidate_seq, candidate_log_prob, hidden_next, cell_next))

            # sort all candidates in descending order of log_prob
            ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
            # take top beam_size
            sequences = ordered[:beam_size]

        # choose best sequence
        best_seq, best_logprob, _, _ = sequences[0]
        # remove the <start> token
        return best_seq[1:]  # these are token IDs

# Full model wrapper
class ImageCaptioningModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(ImageCaptioningModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

# ---------------------------------------------------
# 6. Training function with .reshape
# ---------------------------------------------------
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0

    for images, input_captions, target_captions, _ in dataloader:
        images          = images.to(device)
        input_captions  = input_captions.to(device)
        target_captions = target_captions.to(device)

        optimizer.zero_grad()

        # Forward pass
        features = model.encoder(images)
        outputs  = model.decoder(features, input_captions)
        # outputs:  (batch, seq_len, vocab_size)
        # targets:  (batch, seq_len)

        # Flatten
        outputs_reshaped = outputs.reshape(-1, outputs.size(2))
        targets_reshaped = target_captions.reshape(-1)

        loss = criterion(outputs_reshaped, targets_reshaped)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

# ---------------------------------------------------
# 7. Multi-Sentence Postprocessing
# ---------------------------------------------------
def tokens_to_multisentence(token_ids, vocab, max_sentences=3):
    """
    Convert a list of token IDs to up to 'max_sentences' sentences, 
    splitting on '.' or <end> token. Returns a string with multiple sentences.
    """
    sentences = []
    current_sentence = []

    end_token_id  = vocab.stoi["<end>"]

    for token_id in token_ids:
        if token_id == end_token_id:
            # end token => finish up
            if current_sentence:
                sentences.append(" ".join(current_sentence))
            break

        word = vocab.itos.get(token_id, "<unk>")
        if word == "<unk>":
            # Optionally skip or replace <unk>
            word = "something"

        current_sentence.append(word)

        # If we see a '.' we treat it as a sentence boundary
        if word == ".":
            sentences.append(" ".join(current_sentence))
            current_sentence = []

        # If we already have max_sentences
        if len(sentences) >= max_sentences:
            break

    # If there's leftover text that didn't end with '.', we can finalize it
    if current_sentence and len(sentences) < max_sentences:
        sentences.append(" ".join(current_sentence))

    # Join with a period + space or newlines, whichever you prefer
    final_output = ". ".join(sentences)
    # Optionally ensure punctuation at the end
    if not final_output.endswith("."):
        final_output += "."

    return final_output

# ---------------------------------------------------
# 8. Prediction with Beam Search
# ---------------------------------------------------
def predict_caption_from_url(model, url, transform, vocab, device, beam_size=3, max_sentences=3):
    model.eval()
    response = requests.get(url)
    image = Image.open(BytesIO(response.content)).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        features = model.encoder(image)
        # We'll do beam search
        token_ids = model.decoder.sample_beam_search(features, vocab, beam_size=beam_size)

    # Convert token IDs to multi-sentence text
    caption_text = tokens_to_multisentence(token_ids, vocab, max_sentences=max_sentences)
    return caption_text

# ---------------------------------------------------
# 9. Execution: Build, Train, Predict
# ---------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 9.1: Build the vocabulary
captions_list = df['caption'].tolist()
vocab = Vocabulary(freq_threshold)
vocab.build_vocabulary(captions_list)
print("Vocabulary size:", len(vocab))

# 9.2: Dataset & DataLoader
dataset = CaptionDatasetFromDF(df, img_dir, vocab, transform=transform, max_seq_length=max_seq_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# 9.3: Model Instantiation
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers, max_seq_length).to(device)
model   = ImageCaptioningModel(encoder, decoder).to(device)

# 9.4: Loss + Optimizer
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 9.5: Training
print("Starting training ...")
for epoch in range(num_epochs):
    epoch_loss = train(model, dataloader, criterion, optimizer, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# 9.6: Save the model
torch.save(model.state_dict(), "image_captioning_model.pth")
print("Model saved!")

# 9.7: Prediction from URL
test_url = input("Enter an image URL for caption prediction: ").strip()
predicted_caption = predict_caption_from_url(
    model, 
    test_url, 
    transform, 
    vocab, 
    device,
    beam_size=3,      # you can experiment with beam sizes
    max_sentences=3   # we want up to 3 sentences
)
print("Predicted Multi-Sentence Caption:", predicted_caption)


[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Number of captions: 40455
Vocabulary size: 2984


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 240MB/s]


Starting training ...
Epoch [1/10], Loss: 3.0775
Epoch [2/10], Loss: 2.4914
Epoch [3/10], Loss: 2.2401
Epoch [4/10], Loss: 2.0401
Epoch [5/10], Loss: 1.8669
Epoch [6/10], Loss: 1.7148
Epoch [7/10], Loss: 1.5836
Epoch [8/10], Loss: 1.4687
Epoch [9/10], Loss: 1.3687
Epoch [10/10], Loss: 1.2843
Model saved!


Enter an image URL for caption prediction:  https://cdn.prod.website-files.com/5aba1faad4eb88cb5f8c0b57/5acd3f675e259d1e064464ae_imgonline-com-ua-CompressToSize-OOr5aNN1GIUCTB2.jpg


Predicted Multi-Sentence Caption: three people are standing on a grassy hillside .


In [9]:
test_url = input("Enter an image URL for caption prediction: ").strip()
predicted_caption = predict_caption_from_url(
    model, 
    test_url, 
    transform, 
    vocab, 
    device,
    beam_size=3,      # you can experiment with beam sizes
    max_sentences=3   # we want up to 3 sentences
)
print("Predicted Multi-Sentence Caption:", predicted_caption)

Enter an image URL for caption prediction:  https://hips.hearstapps.com/hmg-prod/images/gettyimages-180680638-676f621f720bc.jpg?crop=0.8888888888888888xw:1xh;center,top&resize=1200:*


Predicted Multi-Sentence Caption: three dogs run through the grass .


In [10]:
# https://hips.hearstapps.com/hmg-prod/images/gettyimages-180680638-676f621f720bc.jpg?crop=0.8888888888888888xw:1xh;center,top&resize=1200:*

# MobileNet-V3 and GRU

In [2]:
import os
import ssl
import nltk
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from collections import Counter
import requests
from io import BytesIO

# ---------------------------------------------------
# Optional: disable SSL verification if SSL errors occur
ssl._create_default_https_context = ssl._create_unverified_context

# Download NLTK data (for tokenization)
nltk.download('punkt')

# ---------------------------------------------------
# 1. Vocabulary and Tokenization
# ---------------------------------------------------
class Vocabulary:
    """A simple vocabulary wrapper."""
    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}
    
    def __len__(self):
        return len(self.itos)
    
    @staticmethod
    def tokenizer(text):
        # Keep punctuation so the model can learn multi-sentence outputs
        text = text.lower().strip()
        return nltk.tokenize.word_tokenize(text)
    
    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        idx = 4  # starting index after special tokens
        for sentence in sentence_list:
            tokens = self.tokenizer(sentence)
            frequencies.update(tokens)
            for token in tokens:
                if frequencies[token] == self.freq_threshold:
                    if token not in self.stoi:
                        self.stoi[token] = idx
                        self.itos[idx] = token
                        idx += 1
    
    def numericalize(self, text):
        tokenized_text = self.tokenizer(text)
        return [self.stoi.get(token, self.stoi["<unk>"]) for token in tokenized_text]

# ---------------------------------------------------
# 2. Hyperparameters and Transforms
# ---------------------------------------------------
embed_size     = 256
hidden_size    = 512
num_layers     = 1
learning_rate  = 1e-3
num_epochs     = 10  # train longer for better results
batch_size     = 16
freq_threshold = 5   # raise threshold to reduce <unk>
max_seq_length = 30  # allow slightly longer sequences

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std =[0.229, 0.224, 0.225]
    )
])

# ---------------------------------------------------
# 3. Load CSV with image, caption columns
# ---------------------------------------------------
csv_file = "/kaggle/input/flickr8k/captions.txt"
img_dir  = "/kaggle/input/flickr8k/Images"

df = pd.read_csv(csv_file)
print("Number of captions:", len(df))

# ---------------------------------------------------
# 4. Dataset and DataLoader
# ---------------------------------------------------
class CaptionDatasetFromDF(Dataset):
    def __init__(self, dataframe, img_dir, vocabulary, transform=None, max_seq_length=30):
        self.df = dataframe
        self.img_dir = img_dir
        self.transform = transform
        self.vocab = vocabulary
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['image']
        caption  = self.df.iloc[idx]['caption']

        # Load and transform the image
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        # Numericalize the caption
        numericalized_caption = [self.vocab.stoi["<start>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<end>"])

        # Truncate if needed
        if len(numericalized_caption) > self.max_seq_length:
            numericalized_caption = numericalized_caption[:self.max_seq_length]
            numericalized_caption[-1] = self.vocab.stoi["<end>"]

        # Create input & target
        input_caption  = numericalized_caption[:-1]
        target_caption = numericalized_caption[1:]

        return image, torch.tensor(input_caption), torch.tensor(target_caption)

def collate_fn(data):
    images, input_captions, target_captions = zip(*data)

    images = torch.stack(images, 0)
    padded_input  = pad_sequence(input_captions, batch_first=True, padding_value=0)
    padded_target = pad_sequence(target_captions, batch_first=True, padding_value=0)

    lengths = [len(cap) for cap in input_captions]  # optional

    return images, padded_input, padded_target, lengths

# ---------------------------------------------------
# 5. Model Definitions
# ---------------------------------------------------
#
# -- New CNN: MobileNetV3 Large --
#
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        # Use MobileNet V3 Large as a feature extractor
        backbone = models.mobilenet_v3_large(pretrained=True)
        
        # Remove the final classifier; keep only the feature layers
        self.backbone = backbone.features  # (batch, 960, H', W')
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))  # convert H'xW' -> 1x1

        # Optionally freeze layers
        for param in self.backbone.parameters():
            param.requires_grad = train_CNN

        # Linear to map from 960 -> embed_size
        self.linear = nn.Linear(960, embed_size)
        self.bn     = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        x = self.backbone(images)           # shape: (batch, 960, H', W')
        x = self.global_pool(x)             # (batch, 960, 1, 1)
        x = x.view(x.size(0), -1)           # (batch, 960)
        x = self.linear(x)                  # (batch, embed_size)
        x = self.bn(x)
        return x

#
# -- New RNN: GRU for text decoding --
#
class DecoderGRU(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, max_seq_length=30):
        super(DecoderGRU, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.gru   = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length

        self.vocab_size = vocab_size

    def forward(self, features, captions):
        """
        features: (batch, embed_size)
        captions: (batch, seq_len)
        """
        embeddings = self.embed(captions)  # (batch, seq_len, embed_size)
        # Insert image features at t=0
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)  
        # (batch, seq_len+1, embed_size)

        outputs, _ = self.gru(embeddings)    # (batch, seq_len+1, hidden_size)
        outputs = self.linear(outputs)       # (batch, seq_len+1, vocab_size)
        outputs = outputs[:, 1:, :]         # drop the first time step
        return outputs

    # Optional: A simpler greedy approach (if you want)
    def sample_greedy(self, features):
        sampled_ids = []
        inputs = features.unsqueeze(1)  # (batch, 1, embed_size)
        h = None

        for _ in range(self.max_seq_length):
            out, h = self.gru(inputs, h)      # (batch, 1, hidden_size)
            out = self.linear(out.squeeze(1)) # (batch, vocab_size)
            predicted = out.argmax(dim=1)     # (batch)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)

        sampled_ids = torch.stack(sampled_ids, 1)
        return sampled_ids

    def sample_beam_search(self, features, vocab, beam_size=3):
        """
        Beam Search for better decoding.
        Returns the best predicted sequence of token IDs (excluding <start>).
        """
        device = features.device
        start_id = vocab.stoi["<start>"]
        end_id   = vocab.stoi["<end>"]

        # Each item is (seq, log_prob, hidden)
        sequences = [([start_id], 0.0, None)]

        for _ in range(self.max_seq_length):
            all_candidates = []
            for seq, log_prob, hidden in sequences:
                if seq[-1] == end_id:
                    # If already ended, keep as-is
                    all_candidates.append((seq, log_prob, hidden))
                    continue

                last_token = torch.tensor([seq[-1]]).to(device)
                token_embed = self.embed(last_token).unsqueeze(1)  # (1,1,embed_size)

                if hidden is None:
                    # treat image features as initial hidden; let's do 1 step
                    # We can do a dummy forward pass with: 
                    #   out, new_hidden = self.gru(torch.cat((features.unsqueeze(1), token_embed), 1))
                    # but that implies 2 time steps at once. Alternatively, we do:
                    out, new_hidden = self.gru(token_embed, None)
                    # Combine with features as if it's the first token if desired.
                    # But simpler is to just pass token_embed alone each step.
                else:
                    out, new_hidden = self.gru(token_embed, hidden)

                out = self.linear(out.squeeze(1))  # (1, vocab_size)
                log_probs = torch.log_softmax(out, dim=1)

                topk_log_probs, topk_ids = torch.topk(log_probs, beam_size, dim=1)
                for i in range(beam_size):
                    candidate_seq = seq + [topk_ids[0, i].item()]
                    candidate_log_prob = log_prob + topk_log_probs[0, i].item()
                    all_candidates.append((candidate_seq, candidate_log_prob, new_hidden))

            # Re-sort
            ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
            sequences = ordered[:beam_size]

        best_seq, best_logprob, _ = sequences[0]
        return best_seq[1:]  # remove the <start> token

class ImageCaptioningModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(ImageCaptioningModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

# ---------------------------------------------------
# 6. Training function
# ---------------------------------------------------
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    for images, input_captions, target_captions, _ in dataloader:
        images = images.to(device)
        input_captions = input_captions.to(device)
        target_captions = target_captions.to(device)

        optimizer.zero_grad()

        features = model.encoder(images)
        outputs  = model.decoder(features, input_captions)
        # (batch, seq_len, vocab_size) vs (batch, seq_len)

        # Flatten
        outputs_reshaped = outputs.reshape(-1, outputs.size(2))
        targets_reshaped = target_captions.reshape(-1)

        loss = criterion(outputs_reshaped, targets_reshaped)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

# ---------------------------------------------------
# 7. Multi-Sentence Postprocessing
# ---------------------------------------------------
def tokens_to_multisentence(token_ids, vocab, max_sentences=3):
    """
    Converts token IDs to multiple sentences, splitting on '.' or <end>.
    """
    sentences = []
    current_sentence = []
    end_token_id = vocab.stoi["<end>"]

    for token_id in token_ids:
        if token_id == end_token_id:
            if current_sentence:
                sentences.append(" ".join(current_sentence))
            break

        word = vocab.itos.get(token_id, "<unk>")
        if word == "<unk>":
            # skip or replace <unk>
            word = "something"

        current_sentence.append(word)

        if word == ".":
            sentences.append(" ".join(current_sentence))
            current_sentence = []

        if len(sentences) >= max_sentences:
            break

    if current_sentence and len(sentences) < max_sentences:
        sentences.append(" ".join(current_sentence))

    final_output = ". ".join(sentences)
    if not final_output.endswith("."):
        final_output += "."

    return final_output

# ---------------------------------------------------
# 8. Prediction with Beam Search
# ---------------------------------------------------
def predict_caption_from_url(model, url, transform, vocab, device, beam_size=3, max_sentences=3):
    model.eval()
    response = requests.get(url)
    image = Image.open(BytesIO(response.content)).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        features = model.encoder(image)
        token_ids = model.decoder.sample_beam_search(features, vocab, beam_size=beam_size)

    # Convert token IDs to a multi-sentence string
    caption_text = tokens_to_multisentence(token_ids, vocab, max_sentences=max_sentences)
    return caption_text

# ---------------------------------------------------
# 9. Execution: Build, Train, Predict
# ---------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 9.1: Build the vocabulary
captions_list = df['caption'].tolist()
vocab = Vocabulary(freq_threshold)
vocab.build_vocabulary(captions_list)
print("Vocabulary size:", len(vocab))

# 9.2: Dataset & DataLoader
dataset = CaptionDatasetFromDF(df, img_dir, vocab, transform=transform, max_seq_length=max_seq_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# 9.3: Instantiate the MobileNetV3 + GRU model
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderGRU(embed_size, hidden_size, len(vocab), num_layers, max_seq_length).to(device)
model   = ImageCaptioningModel(encoder, decoder).to(device)

# 9.4: Loss + Optimizer
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 9.5: Training
print("Starting training ...")
for epoch in range(num_epochs):
    epoch_loss = train(model, dataloader, criterion, optimizer, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# 9.6: Save the model
torch.save(model.state_dict(), "image_captioning_mobilenet_gru.pth")
print("Model saved!")

# 9.7: Prediction from URL
test_url = input("Enter an image URL for caption prediction: ").strip()
predicted_caption = predict_caption_from_url(
    model, 
    test_url, 
    transform, 
    vocab, 
    device,
    beam_size=3,      # can be tuned
    max_sentences=3   # request up to 3 sentences
)
print("Predicted Multi-Sentence Caption:", predicted_caption)


[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Number of captions: 40455


Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth


Vocabulary size: 2984


100%|██████████| 21.1M/21.1M [00:00<00:00, 169MB/s]


Starting training ...
Epoch [1/10], Loss: 2.9724
Epoch [2/10], Loss: 2.4165
Epoch [3/10], Loss: 2.1717
Epoch [4/10], Loss: 1.9876
Epoch [5/10], Loss: 1.8380
Epoch [6/10], Loss: 1.7148
Epoch [7/10], Loss: 1.6178
Epoch [8/10], Loss: 1.5368
Epoch [9/10], Loss: 1.4673
Epoch [10/10], Loss: 1.4130
Model saved!


Enter an image URL for caption prediction:  https://hips.hearstapps.com/hmg-prod/images/gettyimages-180680638-676f621f720bc.jpg?crop=0.8888888888888888xw:1xh;center,top&resize=1200:*


Predicted Multi-Sentence Caption: a man and a woman are sitting on a bench outside .


In [6]:
test_url = input("Enter an image URL for caption prediction: ").strip()
predicted_caption = predict_caption_from_url(
    model, 
    test_url, 
    transform, 
    vocab, 
    device,
    beam_size=3,      # can be tuned
    max_sentences=3   # request up to 3 sentences
)
print("Predicted Multi-Sentence Caption:", predicted_caption)

Enter an image URL for caption prediction:  https://images.pexels.com/photos/1054655/pexels-photo-1054655.jpeg?cs=srgb&dl=pexels-hsapir-1054655.jpg&fm=jpg


Predicted Multi-Sentence Caption: a man and a woman are sitting on a bench outside .


In [7]:
test_url = input("Enter an image URL for caption prediction: ").strip()
predicted_caption = predict_caption_from_url(
    model, 
    test_url, 
    transform, 
    vocab, 
    device,
    beam_size=3,      # can be tuned
    max_sentences=3   # request up to 3 sentences
)
print("Predicted Multi-Sentence Caption:", predicted_caption)

Enter an image URL for caption prediction:  https://plus.unsplash.com/premium_photo-1664474619075-644dd191935f?fm=jpg&q=60&w=3000&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8MXx8aW1hZ2V8ZW58MHx8MHx8fDA%3D


Predicted Multi-Sentence Caption: a man and a woman are sitting on a bench outside .


# Better version

In [8]:
import os
import ssl
import nltk
import torch
import random
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from collections import Counter
import requests
from io import BytesIO

# -------------------------------------------------------------------------
# 0. (Optional) Fix random seeds for reproducibility
# -------------------------------------------------------------------------
def set_random_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # For cudnn reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_random_seed(42)

# -------------------------------------------------------------------------
# Optional: disable SSL verification if SSL errors occur
# -------------------------------------------------------------------------
ssl._create_default_https_context = ssl._create_unverified_context

# Download NLTK data (for tokenization)
nltk.download('punkt')

# -------------------------------------------------------------------------
# 1. Vocabulary and Tokenization
# -------------------------------------------------------------------------
class Vocabulary:
    """A simple vocabulary wrapper."""
    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        # Keep punctuation tokens
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}
    
    def __len__(self):
        return len(self.itos)
    
    @staticmethod
    def tokenizer(text):
        text = text.lower().strip()
        return nltk.tokenize.word_tokenize(text)
    
    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        idx = 4  # starting index after special tokens
        for sentence in sentence_list:
            tokens = self.tokenizer(sentence)
            frequencies.update(tokens)
            for token in tokens:
                if frequencies[token] == self.freq_threshold:
                    if token not in self.stoi:
                        self.stoi[token] = idx
                        self.itos[idx] = token
                        idx += 1
    
    def numericalize(self, text):
        tokenized_text = self.tokenizer(text)
        return [self.stoi.get(token, self.stoi["<unk>"]) for token in tokenized_text]

# -------------------------------------------------------------------------
# 2. Hyperparameters and Data Augmentations
# -------------------------------------------------------------------------
embed_size     = 256
hidden_size    = 512
num_layers     = 1
learning_rate  = 1e-3
num_epochs     = 15  # increased for better training
batch_size     = 16
freq_threshold = 5
max_seq_length = 30

# Data Augmentation + Normalization
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop((224,224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# For inference, no augmentations
inference_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# -------------------------------------------------------------------------
# 3. Load CSV with image, caption columns
# -------------------------------------------------------------------------
csv_file = "/kaggle/input/flickr8k/captions.txt"
img_dir  = "/kaggle/input/flickr8k/Images"

df = pd.read_csv(csv_file)
print("Number of captions:", len(df))

# -------------------------------------------------------------------------
# 4. Dataset and DataLoader
# -------------------------------------------------------------------------
class CaptionDatasetFromDF(Dataset):
    def __init__(self, dataframe, img_dir, vocabulary, transform=None, max_seq_length=30):
        self.df = dataframe
        self.img_dir = img_dir
        self.transform = transform
        self.vocab = vocabulary
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['image']
        caption  = self.df.iloc[idx]['caption']

        # Load image
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Numericalize
        numericalized_caption = [self.vocab.stoi["<start>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<end>"])

        # Truncate if needed
        if len(numericalized_caption) > self.max_seq_length:
            numericalized_caption = numericalized_caption[:self.max_seq_length]
            numericalized_caption[-1] = self.vocab.stoi["<end>"]

        input_caption  = numericalized_caption[:-1]
        target_caption = numericalized_caption[1:]

        return image, torch.tensor(input_caption), torch.tensor(target_caption)

def collate_fn(data):
    images, input_captions, target_captions = zip(*data)

    images = torch.stack(images, 0)
    padded_input  = pad_sequence(input_captions, batch_first=True, padding_value=0)
    padded_target = pad_sequence(target_captions, batch_first=True, padding_value=0)

    lengths = [len(cap) for cap in input_captions]
    return images, padded_input, padded_target, lengths

# -------------------------------------------------------------------------
# 5. Model Definitions (MobileNetV3 + GRU with partial fine-tuning)
# -------------------------------------------------------------------------
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        backbone = models.mobilenet_v3_large(pretrained=True)
        # Keep the feature extraction layers
        self.backbone = backbone.features  
        self.global_pool = nn.AdaptiveAvgPool2d((1,1))
        # Unfreeze only the last few layers for partial fine-tuning:
        for name, param in self.backbone.named_parameters():
            # Example: unfreeze last stage
            if "12" in name or "13" in name or "14" in name:  # mobilenet_v3_large has final layers around 14
                param.requires_grad = True
            else:
                param.requires_grad = False

        # Map from 960 -> embed_size
        self.linear = nn.Linear(960, embed_size)
        self.bn     = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        x = self.backbone(images)     # (batch, 960, H, W)
        x = self.global_pool(x)       # (batch, 960, 1, 1)
        x = x.view(x.size(0), -1)     # (batch, 960)
        x = self.linear(x)            # (batch, embed_size)
        x = self.bn(x)
        return x

class DecoderGRU(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, dropout=0.3, num_layers=1, max_seq_length=30):
        super(DecoderGRU, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Add dropout to reduce overfitting
        self.gru   = nn.GRU(embed_size, hidden_size, num_layers, batch_first=True, dropout=dropout if num_layers>1 else 0)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        # Insert the image features at t=0
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        outputs, _ = self.gru(embeddings)
        outputs = self.linear(outputs)
        outputs = outputs[:, 1:, :]  # drop first time step
        return outputs

    def sample_beam_search(self, features, vocab, beam_size=3):
        """
        Beam Search decoding.
        """
        device = features.device
        start_id = vocab.stoi["<start>"]
        end_id   = vocab.stoi["<end>"]
        sequences = [([start_id], 0.0, None)]

        for _ in range(self.max_seq_length):
            all_candidates = []
            for seq, log_prob, hidden in sequences:
                if seq[-1] == end_id:
                    all_candidates.append((seq, log_prob, hidden))
                    continue

                last_token = torch.tensor([seq[-1]]).to(device)
                token_embed = self.embed(last_token).unsqueeze(1)
                
                if hidden is None:
                    out, new_hidden = self.gru(token_embed, None)
                else:
                    out, new_hidden = self.gru(token_embed, hidden)

                out = self.linear(out.squeeze(1))  
                log_probs = torch.log_softmax(out, dim=1)

                topk_log_probs, topk_ids = torch.topk(log_probs, beam_size, dim=1)
                for i in range(beam_size):
                    candidate_seq = seq + [topk_ids[0, i].item()]
                    candidate_log_prob = log_prob + topk_log_probs[0, i].item()
                    all_candidates.append((candidate_seq, candidate_log_prob, new_hidden))

            ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
            sequences = ordered[:beam_size]

        best_seq, best_logprob, _ = sequences[0]
        return best_seq[1:]  # remove <start>

    # Optional: top-k sampling for variety
    def sample_topk(self, features, vocab, k=5, temperature=1.0):
        device = features.device
        start_id = vocab.stoi["<start>"]
        end_id   = vocab.stoi["<end>"]
        
        sampled_ids = []
        hidden = None

        token_id = torch.tensor([start_id], device=device)
        inputs   = self.embed(token_id).unsqueeze(1)
        
        for _ in range(self.max_seq_length):
            if hidden is None:
                out, hidden = self.gru(inputs)
            else:
                out, hidden = self.gru(inputs, hidden)

            out = self.linear(out.squeeze(1))  # (1, vocab_size)
            # Scale by temperature
            out = out / temperature
            probs = torch.softmax(out, dim=1)
            # top-k sampling
            topk_probs, topk_ids = torch.topk(probs, k, dim=1)
            topk_probs = topk_probs.squeeze(0)
            topk_ids   = topk_ids.squeeze(0)

            # pick 1 from topk by random
            chosen_idx = torch.multinomial(topk_probs, 1)
            chosen_id  = topk_ids[chosen_idx]

            sampled_ids.append(chosen_id.item())
            if chosen_id.item() == end_id:
                break

            # next input
            inputs = self.embed(chosen_id).unsqueeze(1)

        return sampled_ids

class ImageCaptioningModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(ImageCaptioningModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

# -------------------------------------------------------------------------
# 6. Training loop
# -------------------------------------------------------------------------
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    for images, input_captions, target_captions, _ in dataloader:
        images          = images.to(device)
        input_captions  = input_captions.to(device)
        target_captions = target_captions.to(device)

        optimizer.zero_grad()
        features = model.encoder(images)
        outputs  = model.decoder(features, input_captions)
        # (batch, seq_len, vocab_size) vs. (batch, seq_len)

        outputs_reshaped = outputs.reshape(-1, outputs.size(2))
        targets_reshaped = target_captions.reshape(-1)

        loss = criterion(outputs_reshaped, targets_reshaped)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

# -------------------------------------------------------------------------
# 7. Multi-Sentence Postprocessing
# -------------------------------------------------------------------------
def tokens_to_multisentence(token_ids, vocab, max_sentences=3):
    sentences = []
    current_sentence = []
    end_token_id = vocab.stoi["<end>"]

    for token_id in token_ids:
        if token_id == end_token_id:
            if current_sentence:
                sentences.append(" ".join(current_sentence))
            break

        word = vocab.itos.get(token_id, "<unk>")
        if word == "<unk>":
            word = "something"

        current_sentence.append(word)

        if word == ".":
            sentences.append(" ".join(current_sentence))
            current_sentence = []

        if len(sentences) >= max_sentences:
            break

    if current_sentence and len(sentences) < max_sentences:
        sentences.append(" ".join(current_sentence))

    final_output = ". ".join(sentences)
    if not final_output.endswith("."):
        final_output += "."

    return final_output

# -------------------------------------------------------------------------
# 8. Prediction from URL (Beam Search by default)
# -------------------------------------------------------------------------
def predict_caption_from_url(model, url, vocab, device, max_sentences=3, beam_size=3, 
                             decode_method="beam_search"):
    model.eval()

    # Download image
    response = requests.get(url)
    image = Image.open(BytesIO(response.content)).convert("RGB")
    # Use the inference transform (no augmentation)
    image = inference_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        features = model.encoder(image)

        if decode_method == "beam_search":
            token_ids = model.decoder.sample_beam_search(features, vocab, beam_size=beam_size)
        else:
            # top-k sampling (k=5, temperature=1.0) as example
            token_ids = model.decoder.sample_topk(features, vocab, k=5, temperature=1.0)

    caption_text = tokens_to_multisentence(token_ids, vocab, max_sentences=max_sentences)
    return caption_text

# -------------------------------------------------------------------------
# 9. Execution: Build, Train, Predict
# -------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 9.1: Build the vocabulary
captions_list = df['caption'].tolist()
vocab = Vocabulary(freq_threshold)
vocab.build_vocabulary(captions_list)
print("Vocabulary size:", len(vocab))

# 9.2: Dataset & DataLoader
#     We'll use data augmentations for the training transform
dataset = CaptionDatasetFromDF(df, img_dir, vocab, transform=train_transform, max_seq_length=max_seq_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# 9.3: Instantiate model
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderGRU(embed_size, hidden_size, len(vocab), dropout=0.3, num_layers=num_layers, max_seq_length=max_seq_length).to(device)
model   = ImageCaptioningModel(encoder, decoder).to(device)

# 9.4: Loss & Optimizer
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 9.5: Training
print("Starting training ...")
for epoch in range(num_epochs):
    epoch_loss = train(model, dataloader, criterion, optimizer, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# 9.6: Save the model
torch.save(model.state_dict(), "image_captioning_mobilenetv3_gru.pth")
print("Model saved!")

# 9.7: Inference
test_url = input("Enter an image URL for caption prediction: ").strip()
predicted_caption = predict_caption_from_url(
    model, 
    test_url, 
    vocab, 
    device, 
    max_sentences=3, 
    beam_size=3, 
    decode_method="beam_search"  # or "topk" for top-k sampling
)
print("Predicted Multi-Sentence Caption:", predicted_caption)


[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Number of captions: 40455
Vocabulary size: 2984
Starting training ...




Epoch [1/15], Loss: 2.9992
Epoch [2/15], Loss: 2.4396
Epoch [3/15], Loss: 2.1936
Epoch [4/15], Loss: 2.0047
Epoch [5/15], Loss: 1.8498


KeyboardInterrupt: 

# Vit Gpt2

In [3]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("comp646")

In [4]:
!pip uninstall -y peft

Found existing installation: peft 0.14.0
Uninstalling peft-0.14.0:
  Successfully uninstalled peft-0.14.0


In [5]:
!pip uninstall -y transformers
!pip install --no-cache-dir transformers==4.31.0 accelerate

Found existing installation: transformers 4.47.0
Uninstalling transformers-4.47.0:
  Successfully uninstalled transformers-4.47.0
Collecting transformers==4.31.0
  Downloading transformers-4.31.0-py3-none-any.whl.metadata (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.9/116.9 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.31.0)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m98.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m250.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, t

In [2]:
#####################################
# 0. Clean Up & Install (Optional)
#####################################
# If you installed "peft" or older versions that conflict, remove them:
# !pip uninstall -y peft transformers
# Reinstall Transformers 4.31.0 (or similar) + Accelerate
# !pip install --no-cache-dir transformers==4.31.0 accelerate

#####################################
# 1. Imports and Setup
#####################################
import os
import ssl
import requests
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from io import BytesIO

# Optional if you run into SSL issues
ssl._create_default_https_context = ssl._create_unverified_context

# Hugging Face Transformers
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    default_data_collator,
    set_seed,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(42)

#####################################
# 2. Config & Paths
#####################################
csv_file = "/kaggle/input/flickr8k/captions.txt"
img_dir  = "/kaggle/input/flickr8k/Images"

encoder_model_name = "google/vit-base-patch16-224-in21k"
decoder_model_name = "gpt2"

# Updated hyperparameters
EPOCHS              = 3
BATCH_SIZE          = 2  # reduce from 8 to 2
LEARNING_RATE       = 5e-5
MAX_SEQ_LEN         = 32
FREEZE_ENCODER      = True   # freeze ViT
USE_GRAD_CHECKPOINT = False  # set True if you need more memory savings
USE_FP16            = True   # half precision -> significantly lowers memory usage

#####################################
# 3. Load the Captions CSV
#####################################
df = pd.read_csv(csv_file)
print("Number of captions:", len(df))
print(df.head())

#####################################
# 4. Image Processor & Tokenizer
#####################################
feature_extractor = ViTImageProcessor.from_pretrained(encoder_model_name)
tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

#####################################
# 5. Dataset
#####################################
class ImageCaptionDataset(Dataset):
    def __init__(self, dataframe, img_dir, feature_extractor, tokenizer, max_target_length=32, transforms=None):
        self.df = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = row["image"]
        caption  = str(row["caption"])

        # load image
        path = os.path.join(self.img_dir, img_name)
        with Image.open(path).convert("RGB") as image:
            if self.transforms is not None:
                image = self.transforms(image)

        pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values.squeeze()

        # tokenize caption
        labels = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_target_length
        ).input_ids

        # replace pad_token_id with -100
        labels = [(lbl if lbl != self.tokenizer.pad_token_id else -100) for lbl in labels]
        
        return {
            "pixel_values": pixel_values,
            "labels": torch.tensor(labels, dtype=torch.long),
        }

# Basic transform: resize to 224x224
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
])

# create dataset
dataset = ImageCaptionDataset(
    df,
    img_dir,
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    max_target_length=MAX_SEQ_LEN,
    transforms=train_transforms
)

# small train-val split
train_size = int(0.95 * len(dataset))
val_size   = len(dataset) - train_size
train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])
print("Train size:", len(train_ds), "Val size:", len(val_ds))

#####################################
# 6. ViT + GPT-2 Model
#####################################
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_model_name, 
    decoder_model_name
)
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id           = tokenizer.eos_token_id
model.config.pad_token_id           = tokenizer.pad_token_id

# Generation settings
model.config.max_length           = MAX_SEQ_LEN
model.config.early_stopping       = True
model.config.no_repeat_ngram_size = 2

# (Optional) Freeze the entire ViT encoder to save memory
if FREEZE_ENCODER:
    for param in model.encoder.parameters():
        param.requires_grad = False

# (Optional) Gradient checkpointing
model.decoder.gradient_checkpointing = USE_GRAD_CHECKPOINT

model.to(device)

#####################################
# 7. Seq2SeqTrainer Setup
#####################################
def collate_fn(batch):
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels       = torch.stack([item["labels"] for item in batch])
    return {"pixel_values": pixel_values, "labels": labels}

training_args = Seq2SeqTrainingArguments(
    output_dir="vit-gpt2-checkpoints",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Turn off evaluation to prevent OOM mid-training
    evaluation_strategy="no",
    # Turn off checkpoint saving to reduce overhead
    save_strategy="no",
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    logging_steps=100,
    report_to="none",
    push_to_hub=False,
    fp16=USE_FP16,                        # mixed precision
    gradient_checkpointing=USE_GRAD_CHECKPOINT,  # or keep the code here as well
)

# minimal dummy metric
def compute_metrics(eval_pred):
    return {"dummy_metric": 0.0}

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,  # won't run since evaluation_strategy="no"
    data_collator=collate_fn,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

#####################################
# 8. Train
#####################################
trainer.train()

#####################################
# 9. Inference from a URL
#####################################
def predict_caption_from_url(url, model, feature_extractor, tokenizer, device,
                             max_length=32, num_beams=4):
    model.eval()
    response = requests.get(url)
    image = Image.open(BytesIO(response.content)).convert("RGB")

    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )

    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption

# Example inference
test_url = input("Enter an image URL for caption prediction: ").strip()
predicted_caption = predict_caption_from_url(test_url, model, feature_extractor, tokenizer, device)
print("Predicted Caption:", predicted_caption)


Number of captions: 40455
                       image  \
0  1000268201_693b08cb0e.jpg   
1  1000268201_693b08cb0e.jpg   
2  1000268201_693b08cb0e.jpg   
3  1000268201_693b08cb0e.jpg   
4  1000268201_693b08cb0e.jpg   

                                             caption  
0  A child in a pink dress is climbing up a set o...  
1              A girl going into a wooden building .  
2   A little girl climbing into a wooden playhouse .  
3  A little girl climbing the stairs to her playh...  
4  A little girl in a pink dress going into a woo...  


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Train size: 38432 Val size: 2023


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

  trainer = Seq2SeqTrainer(


Step,Training Loss
100,3.7676
200,3.4083
300,3.3101
400,3.2811
500,3.2652
600,3.3334
700,3.2786
800,3.235
900,3.0401
1000,3.1602


Enter an image URL for caption prediction:  https://static.vecteezy.com/system/resources/thumbnails/036/324/708/small/ai-generated-picture-of-a-tiger-walking-in-the-forest-photo.jpg


UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x7b987ccba660>

In [3]:
test_url = input("Enter an image URL for caption prediction: ").strip()
predicted_caption = predict_caption_from_url(test_url, model, feature_extractor, tokenizer, device)
print("Predicted Caption:", predicted_caption)

Enter an image URL for caption prediction:  https://images.pexels.com/photos/1054655/pexels-photo-1054655.jpeg?cs=srgb&dl=pexels-hsapir-1054655.jpg&fm=jpg


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Predicted Caption: A silhouette of a person in a body of water . The sun is casting a shadow on the ground . Another silhouette is in the distance .


In [4]:
test_url = input("Enter an image URL for caption prediction: ").strip()
predicted_caption = predict_caption_from_url(test_url, model, feature_extractor, tokenizer, device)
print("Predicted Caption:", predicted_caption)

Enter an image URL for caption prediction:  https://iso.500px.com/wp-content/uploads/2018/05/Blog-marketplace-getty500px-48429366-nologo-3000x2000.png


Predicted Caption: A man in a red shirt is standing in front of a large green field of green grass . Another man is in the background . Sooners is taking a


In [None]:
https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg

# Flickr30k

In [2]:
!pip install --quiet evaluate nltk
!pip install rouge_score

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24935 sha256=abc27734835cb179ea32dcd737cbc3dda8791a6869a691bf2bb86dda7e34c2ce
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [3]:
!pip install --quiet git+https://github.com/salaniz/pycocoevalcap.git

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pycocoevalcap (setup.py) ... [?25l[?25hdone


In [6]:
# # # 0. Clean Up & Install (Optional)
# #####################################
# # If you installed "peft" or older versions that conflict, remove them:
# # !pip uninstall -y peft transformers
# # Reinstall Transformers 4.31.0 (or similar) + Accelerate
# # !pip install --no-cache-dir transformers==4.31.0 accelerate

# #####################################
# # 1. Imports and Setup
# #####################################
# import os
# import ssl
# import requests
# import pandas as pd
# import torch
# from torch.utils.data import Dataset, DataLoader
# from torchvision import transforms
# from PIL import Image
# from io import BytesIO

# # Optional if you run into SSL issues
# ssl._create_default_https_context = ssl._create_unverified_context

# # Hugging Face Transformers
# from transformers import (
#     VisionEncoderDecoderModel,
#     ViTImageProcessor,
#     AutoTokenizer,
#     Seq2SeqTrainingArguments,
#     Seq2SeqTrainer,
#     default_data_collator,
#     set_seed,
# )

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# set_seed(42)

# #####################################
# # 2. Config & Paths
# #####################################
# csv_file = "/kaggle/input/flickr30k/captions.txt"
# img_dir  = "/kaggle/input/flickr30k/flickr30k_images"

# encoder_model_name = "google/vit-base-patch16-224-in21k"
# decoder_model_name = "gpt2"

# # Updated hyperparameters
# EPOCHS              = 3
# BATCH_SIZE          = 2  # reduce from 8 to 2
# LEARNING_RATE       = 5e-5
# MAX_SEQ_LEN         = 32
# FREEZE_ENCODER      = True   # freeze ViT
# USE_GRAD_CHECKPOINT = False  # set True if you need more memory savings
# USE_FP16            = True   # half precision -> significantly lowers memory usage

# #####################################
# # 3. Load the Captions CSV
# #####################################
# df = pd.read_csv(csv_file)
# print("Number of captions:", len(df))
# print(df.head())

# #####################################
# # 4. Image Processor & Tokenizer
# #####################################
# feature_extractor = ViTImageProcessor.from_pretrained(encoder_model_name)
# tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "right"

# #####################################
# # 5. Dataset
# #####################################
# class ImageCaptionDataset(Dataset):
#     def __init__(self, dataframe, img_dir, feature_extractor, tokenizer, max_target_length=32, transforms=None):
#         self.df = dataframe.reset_index(drop=True)
#         self.img_dir = img_dir
#         self.feature_extractor = feature_extractor
#         self.tokenizer = tokenizer
#         self.max_target_length = max_target_length
#         self.transforms = transforms

#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):
#         row = self.df.iloc[idx]
#         img_name = row["image_name"]
#         caption  = str(row["comment"])

#         # load image
#         path = os.path.join(self.img_dir, img_name)
#         with Image.open(path).convert("RGB") as image:
#             if self.transforms is not None:
#                 image = self.transforms(image)

#         pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values.squeeze()

#         # tokenize caption
#         labels = self.tokenizer(
#             caption,
#             padding="max_length",
#             truncation=True,
#             max_length=self.max_target_length
#         ).input_ids

#         # replace pad_token_id with -100
#         labels = [(lbl if lbl != self.tokenizer.pad_token_id else -100) for lbl in labels]
        
#         return {
#             "pixel_values": pixel_values,
#             "labels": torch.tensor(labels, dtype=torch.long),
#         }

# # Basic transform: resize to 224x224
# train_transforms = transforms.Compose([
#     transforms.Resize((224, 224)),
# ])

# # create dataset
# dataset = ImageCaptionDataset(
#     df,
#     img_dir,
#     feature_extractor=feature_extractor,
#     tokenizer=tokenizer,
#     max_target_length=MAX_SEQ_LEN,
#     transforms=train_transforms
# )

# # small train-val split
# train_size = int(0.95 * len(dataset))
# val_size   = len(dataset) - train_size
# train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])
# print("Train size:", len(train_ds), "Val size:", len(val_ds))

# #####################################
# # 6. ViT + GPT-2 Model
# #####################################
# model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
#     encoder_model_name, 
#     decoder_model_name
# )
# model.config.decoder_start_token_id = tokenizer.bos_token_id
# model.config.eos_token_id           = tokenizer.eos_token_id
# model.config.pad_token_id           = tokenizer.pad_token_id

# # Generation settings
# model.config.max_length           = MAX_SEQ_LEN
# model.config.early_stopping       = True
# model.config.no_repeat_ngram_size = 2

# # (Optional) Freeze the entire ViT encoder to save memory
# if FREEZE_ENCODER:
#     for param in model.encoder.parameters():
#         param.requires_grad = False

# # (Optional) Gradient checkpointing
# model.decoder.gradient_checkpointing = USE_GRAD_CHECKPOINT

# model.to(device)

# #####################################
# # 7. Seq2SeqTrainer Setup
# #####################################
# def collate_fn(batch):
#     pixel_values = torch.stack([item["pixel_values"] for item in batch])
#     labels       = torch.stack([item["labels"] for item in batch])
#     return {"pixel_values": pixel_values, "labels": labels}

# training_args = Seq2SeqTrainingArguments(
#     output_dir="vit-gpt2-checkpoints",
#     per_device_train_batch_size=BATCH_SIZE,
#     per_device_eval_batch_size=BATCH_SIZE,
#     # Turn off evaluation to prevent OOM mid-training
#     evaluation_strategy="no",
#     # Turn off checkpoint saving to reduce overhead
#     save_strategy="no",
#     num_train_epochs=EPOCHS,
#     learning_rate=LEARNING_RATE,
#     logging_steps=100,
#     report_to="none",
#     push_to_hub=False,
#     fp16=USE_FP16,                        # mixed precision
#     gradient_checkpointing=USE_GRAD_CHECKPOINT,  # or keep the code here as well
# )

# # minimal dummy metric
# def compute_metrics(eval_pred):
#     return {"dummy_metric": 0.0}

# trainer = Seq2SeqTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_ds,
#     eval_dataset=val_ds,  # won't run since evaluation_strategy="no"
#     data_collator=collate_fn,
#     tokenizer=feature_extractor,
#     compute_metrics=compute_metrics,
# )

# #####################################
# # 8. Train
# #####################################
# trainer.train()

# #####################################
# # 9.  Quality metrics (BLEU‑4, ROUGE‑L, CIDEr)
# #####################################
# import evaluate
# import matplotlib.pyplot as plt
# import tqdm
# import nltk
# from torch.utils.data import DataLoader

# # 1) load BLEU & ROUGE
# bleu_metric  = evaluate.load("bleu")
# rouge_metric = evaluate.load("rouge")

# # 2) load or fallback CIDER
# try:
#     cider_metric = evaluate.load("cider")
#     use_hf_cider = True
# except Exception:
#     from pycocoevalcap.cider.cider import Cider
#     cider_scorer  = Cider()
#     use_hf_cider  = False

# # 3) helper to generate preds & refs
# def generate_captions(dataloader):
#     preds, refs = [], []
#     model.eval()
#     with torch.no_grad():
#         for batch in tqdm.tqdm(dataloader, desc="Generating captions"):
#             pv        = batch["pixel_values"].to(device)
#             labels_id = batch["labels"]

#             # collect references
#             for ids in labels_id:
#                 seq = [i for i in ids.tolist() if i != -100]
#                 refs.append(tokenizer.decode(seq, skip_special_tokens=True))

#             # generate
#             gen_ids = model.generate(
#                 pv,
#                 max_length=MAX_SEQ_LEN,
#                 num_beams=4,
#                 eos_token_id=tokenizer.eos_token_id,
#                 pad_token_id=tokenizer.pad_token_id,
#             )
#             preds.extend(tokenizer.batch_decode(gen_ids, skip_special_tokens=True))
#     return preds, refs

# # 4) DataLoader for val split
# val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)
# preds, refs = generate_captions(val_loader)

# # 5) compute BLEU‑4
# bleu_score = bleu_metric.compute(
#     predictions=[nltk.word_tokenize(p.lower()) for p in preds],
#     references=[[nltk.word_tokenize(r.lower())] for r in refs],
# )["bleu"]

# # 6) compute ROUGE‑L
# rouge_score = rouge_metric.compute(predictions=preds, references=refs)["rougeL"]

# # 7) compute CIDEr
# if use_hf_cider:
#     cider_score = cider_metric.compute(predictions=preds, references=refs)["cider"]
# else:
#     gts = {i: [refs[i]] for i in range(len(refs))}
#     res = {i: [preds[i]] for i in range(len(preds))}
#     cider_score, _ = cider_scorer.compute_score(gts, res)

# # 8) print & store
# print(f"\nBLEU‑4  : {bleu_score:.4f}")
# print(f"ROUGE‑L : {rouge_score:.4f}")
# print(f"CIDEr   : {cider_score:.4f}")

# # variables you can reuse downstream
# _bleu_score  = bleu_score
# _rouge_score = rouge_score
# _cider_score = cider_score

# # 9) simple bar chart
# scores = {"BLEU‑4": bleu_score, "ROUGE‑L": rouge_score, "CIDEr": cider_score}
# plt.figure(figsize=(6,4))
# plt.bar(scores.keys(), scores.values())
# plt.ylim(0, max(scores.values())*1.1)
# plt.title("Validation Caption Quality")
# plt.ylabel("Score")
# plt.show()

# #####################################
# # 10. Inference from a URL
# #####################################
# def predict_caption_from_url(url, model, feature_extractor, tokenizer, device,
#                              max_length=32, num_beams=4):
#     model.eval()
#     response = requests.get(url)
#     image = Image.open(BytesIO(response.content)).convert("RGB")
#     pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
#     with torch.no_grad():
#         output_ids = model.generate(
#             pixel_values,
#             max_length=max_length,
#             num_beams=num_beams,
#             eos_token_id=tokenizer.eos_token_id,
#             pad_token_id=tokenizer.pad_token_id
#         )
#     return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# # Example usage:
# test_url = input("Enter an image URL for caption prediction: ").strip()
# print("Predicted Caption:", predict_caption_from_url(
#     test_url, model, feature_extractor, tokenizer, device,
#     max_length=MAX_SEQ_LEN, num_beams=4
# ))

In [10]:
# ==============================================================
# 0. (Optional) clean-up + installs
# --------------------------------------------------------------
# !pip uninstall -y peft transformers
# !pip install --no-cache-dir transformers==4.31.0 accelerate evaluate nltk matplotlib

# ==============================================================
# 1. imports & setup
# --------------------------------------------------------------
import os, ssl, requests, tqdm, nltk, evaluate, torch, pandas as pd
from PIL import Image
from io  import BytesIO
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import (
    VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer,
    Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainerCallback, set_seed
)

ssl._create_default_https_context = ssl._create_unverified_context
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(42)
nltk.download("punkt", quiet=True)

# ==============================================================
# 2. paths & h-params
# --------------------------------------------------------------
csv_file, img_dir = "/kaggle/input/flickr30k/captions.txt", "/kaggle/input/flickr30k/flickr30k_images"

encoder_model_name, decoder_model_name = "google/vit-base-patch16-224-in21k", "gpt2"

EPOCHS, BATCH_SIZE      = 3, 2
LR, MAX_SEQ_LEN         = 5e-5, 32
FREEZE_ENCODER          = True
USE_GRAD_CHECKPOINT, FP16 = False, True

# ==============================================================
# 3. data
# --------------------------------------------------------------
df = pd.read_csv(csv_file)
feature_extractor = ViTImageProcessor.from_pretrained(encoder_model_name)
tokenizer         = AutoTokenizer.from_pretrained(decoder_model_name)
tokenizer.pad_token, tokenizer.padding_side = tokenizer.eos_token, "right"

class FlickrDataset(Dataset):
    def __init__(self, frame, img_root, fe, tok, max_len=32, tfm=None):
        self.df, self.img_root, self.fe, self.tok, self.max_len, self.tfm = (
            frame.reset_index(drop=True), img_root, fe, tok, max_len, tfm
        )
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row  = self.df.iloc[idx]
        img  = Image.open(os.path.join(self.img_root, row["image_name"])).convert("RGB")
        if self.tfm: img = self.tfm(img)
        px    = self.fe(img, return_tensors="pt").pixel_values.squeeze()
        lbls  = self.tok(
            str(row["comment"]), truncation=True, padding="max_length", max_length=self.max_len
        ).input_ids
        lbls  = [(x if x != self.tok.pad_token_id else -100) for x in lbls]
        return {"pixel_values": px, "labels": torch.tensor(lbls)}

resize   = transforms.Compose([transforms.Resize((224,224))])
full_ds  = FlickrDataset(df, img_dir, feature_extractor, tokenizer, MAX_SEQ_LEN, resize)
train_ds, val_ds = random_split(full_ds, [int(0.95*len(full_ds)), len(full_ds)-int(0.95*len(full_ds))])

def collate(batch):
    return {
        "pixel_values": torch.stack([b["pixel_values"] for b in batch]),
        "labels"      : torch.stack([b["labels"]       for b in batch])
    }

# ==============================================================
# 4. model
# --------------------------------------------------------------
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_model_name, decoder_model_name
)
# --- FIX for the TypeError: set attrs directly -------------
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id           = tokenizer.eos_token_id
model.config.pad_token_id           = tokenizer.pad_token_id
# ------------------------------------------------------------
model.config.max_length, model.config.no_repeat_ngram_size, model.config.early_stopping = (
    MAX_SEQ_LEN, 2, True
)
if FREEZE_ENCODER:
    for p in model.encoder.parameters(): p.requires_grad = False
model.decoder.gradient_checkpointing = USE_GRAD_CHECKPOINT
model.to(device)

# ==============================================================
# 5. metrics helpers
# --------------------------------------------------------------
bleu_metric  = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")
try:
    cider_metric, USE_HF_CIDER = evaluate.load("cider"), True
except Exception:
    from pycocoevalcap.cider.cider import Cider
    cider_metric, USE_HF_CIDER = Cider(), False

def _compute_caption_metrics(pred_txt, ref_txt):
    bleu_dict = bleu_metric.compute(
        predictions=[nltk.word_tokenize(p.lower()) for p in pred_txt],
        references=[[nltk.word_tokenize(r.lower())] for r in ref_txt]
    )
    precision  = bleu_dict["precisions"][0]        # unigram precision
    bleu       = bleu_dict["bleu"]
    rouge      = rouge_metric.compute(predictions=pred_txt, references=ref_txt)["rougeL"]
    if USE_HF_CIDER:
        cider   = cider_metric.compute(predictions=pred_txt, references=ref_txt)["cider"]
    else:
        gts  = {i:[ref_txt[i]] for i in range(len(ref_txt))}
        res  = {i:[pred_txt[i]] for i in range(len(pred_txt))}
        cider, _ = cider_metric.compute_score(gts, res)
    return {"precision": precision, "bleu": bleu, "rougeL": rouge, "cider": cider}

def compute_metrics(eval_pred):
    preds, labels = eval_pred
    pred_txt = tokenizer.batch_decode(preds, skip_special_tokens=True)
    ref_txt  = []
    for lbl in labels:
        ids = [i for i in lbl if i != -100 and i != tokenizer.pad_token_id]
        ref_txt.append(tokenizer.decode(ids, skip_special_tokens=True))
    return _compute_caption_metrics(pred_txt, ref_txt)

# ==============================================================
# 6. custom callback to print / update the table
# --------------------------------------------------------------
class LiveTableCallback(TrainerCallback):
    def __init__(self): self.rows, self.last_train_loss = [], None
    def on_log(self, args, state, control, logs=None, **kwargs):
        # capture running train loss so we can print it when eval ends
        if logs and "loss" in logs and "epoch" in logs: 
            self.last_train_loss = logs["loss"]
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        ep = int(state.epoch)
        self.rows.append({
            "epoch"      : ep,
            "train_loss" : self.last_train_loss,
            "val_loss"   : metrics.get("eval_loss"),
            "precision"  : metrics.get("eval_precision"),
            "BLEU-4"     : metrics.get("eval_bleu"),
            "ROUGE-L"    : metrics.get("eval_rougeL"),
            "CIDEr"      : metrics.get("eval_cider")
        })
        df = pd.DataFrame(self.rows)
        print("\n######### Epoch summary #########")
        print(df.to_string(index=False, float_format="%.4f"))
        print("#################################\n")

table_cb = LiveTableCallback()

# ==============================================================
# 7. trainer
# --------------------------------------------------------------
args = Seq2SeqTrainingArguments(
    output_dir                   = "vit-gpt2-checkpoints",
    num_train_epochs             = EPOCHS,
    learning_rate                = LR,
    per_device_train_batch_size  = BATCH_SIZE,
    per_device_eval_batch_size   = BATCH_SIZE,
    logging_steps                = 100,          # <- keep your 100-step logs
    logging_strategy             = "steps",
    evaluation_strategy          = "epoch",      # <- add val pass every epoch
    predict_with_generate        = True,
    generation_max_length        = MAX_SEQ_LEN,
    generation_num_beams         = 4,
    save_strategy                = "no",
    report_to                    = "none",
    fp16                         = FP16,
    gradient_checkpointing       = USE_GRAD_CHECKPOINT
)

trainer = Seq2SeqTrainer(
    model           = model,
    args            = args,
    train_dataset   = train_ds,
    eval_dataset    = val_ds,
    data_collator   = collate,
    tokenizer       = tokenizer,          # IMPORTANT: text tokenizer
    compute_metrics = compute_metrics,
    callbacks       = [table_cb],         # live table printer
)

# ==============================================================
# 8. train
# --------------------------------------------------------------
trainer.train()

# (The table is printed automatically after each epoch)
# ==============================================================
# 9. inference helper (unchanged)
# --------------------------------------------------------------
def caption_from_url(url, model=model, fe=feature_extractor, tok=tokenizer,
                     device=device, max_len=32, num_beams=4):
    model.eval()
    img = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
    px  = fe(img, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        out_ids = model.generate(px, max_length=max_len, num_beams=num_beams,
                                 eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id)
    return tok.decode(out_ids[0], skip_special_tokens=True)

# Example:
# url = "https://farm3.staticflickr.com/2029/2212815924_3f6a805ec7_z.jpg"
# print(caption_from_url(url))

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

Epoch,Training Loss,Validation Loss


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instea

ValueError: Predictions and/or references don't match the expected format.
Expected format:
Feature option 0: {'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}
Feature option 1: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')},
Input predictions: ['a', 'man', 'and', ..., '.', 'and', 'the'],
Input references: [['a', 'group', 'of', 'people', 'in', 'formal', 'attire', '.']]

In [None]:
# 1. Choose a directory to store your model
output_dir = "vit_gpt2_flickr30k__model"

# 2. Save the model weights & config
model.save_pretrained(output_dir)

# 3. Save the feature extractor and tokenizer
feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"Model, feature_extractor and tokenizer saved to ./{output_dir}")