In [1]:
import nltk
nltk.download('punkt_tab')


[nltk_data] Downloading package punkt_tab to /usr/share/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vit_b_16
from PIL import Image
import os
from nltk.tokenize import word_tokenize
from collections import Counter
import numpy as np
from torch.nn.utils.rnn import pad_sequence


class Config:
    def __init__(self):
        self.image_dir = '/kaggle/input/newdata1/Flicker8k_Dataset'
        self.captions_file = '/kaggle/input/newdata/Flicker8k-captions/Flickr8k.token.txt'

        self.checkpoint_dir = '/kaggle/working/checkpoints'


        self.embed_size = 256
        self.hidden_size = 512
        self.num_layers = 6
        self.num_heads = 8
        self.num_epochs = 10
        self.batch_size = 32
        self.lr = 0.001
        self.max_seq_length = 30
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


        if not os.path.exists(self.image_dir):
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
        if not os.path.exists(self.captions_file):
            raise FileNotFoundError(f"Captions file not found: {self.captions_file}")
        os.makedirs(self.checkpoint_dir, exist_ok=True)


class ImageCaptionDataset(Dataset):
    def __init__(self, image_dir, captions_file, vocab, transform=None):
        self.image_dir = image_dir
        self.captions_file = captions_file
        self.vocab = vocab
        self.transform = transform
        self.imgs, self.captions = self.load_data()

    def load_data(self):
        imgs = []
        captions = []
        with open(self.captions_file, 'r') as file:
            lines = file.readlines()
            for line in lines:
                img_name, caption = line.strip().split('\t')
                img_name = img_name.split('#')[0]
                imgs.append(img_name)
                caption_tokens = word_tokenize(caption.lower())
                captions.append(
                    [self.vocab['<start>']] +
                    [self.vocab[word] for word in caption_tokens if word in self.vocab] +
                    [self.vocab['<end>']]
                )
        return imgs, captions

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        caption = self.captions[idx]
        img_path = os.path.join(self.image_dir, img_name)

        try:
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, torch.tensor(caption)
        except FileNotFoundError:
            new_idx = np.random.randint(0, len(self))
            return self.__getitem__(new_idx)


def collate_fn(batch):
    imgs, captions = zip(*batch)
    imgs = torch.stack(imgs, 0)
    captions = pad_sequence(captions, batch_first=True, padding_value=2)
    return imgs, captions



    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        caption = self.captions[idx]

        img_path = os.path.join(self.image_dir, img_name)

        try:
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, torch.tensor(caption)

        except FileNotFoundError:

            new_idx = np.random.randint(0, len(self))
            return self.__getitem__(new_idx)

    from torch.nn.utils.rnn import pad_sequence



def collate_fn(batch):
    imgs, captions = zip(*batch)
    imgs = torch.stack(imgs, 0)

    captions = pad_sequence(captions, batch_first=True, padding_value=2)

    return imgs, captions




class EncoderViT(nn.Module):
    def __init__(self, embed_size):
        super(EncoderViT, self).__init__()

        self.vit = vit_b_16(weights='DEFAULT')


        self.vit.heads = nn.Linear(self.vit.hidden_dim, embed_size)


        self.embed_size = embed_size

    def forward(self, images):

        features = self.vit(images)
        return features


class DecoderTransformer(nn.Module):
    def __init__(self, embed_size, hidden_size, num_layers, num_heads, vocab_size):
        super(DecoderTransformer, self).__init__()
        self.embed_size = embed_size


        self.embedding = nn.Embedding(vocab_size, embed_size)


        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=hidden_size,
            dropout=0.1
        )


        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer,
            num_layers=num_layers
        )


        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, features, captions):

        embeddings = self.embedding(captions)


        tgt_mask = self.generate_square_subsequent_mask(captions.size(1)).to(features.device)


        embeddings = embeddings.permute(1, 0, 2)


        memory = features.unsqueeze(0).repeat(embeddings.size(0), 1, 1)


        output = self.transformer_decoder(
            tgt=embeddings,
            memory=memory,
            tgt_mask=tgt_mask
        )


        output = output.permute(1, 0, 2)


        output = self.fc_out(output)
        return output

    def generate_square_subsequent_mask(self, size):
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask


def build_vocab(captions_file, threshold):
    vocab = {'<start>': 0, '<end>': 1, '<pad>': 2}
    word_count = Counter()

    with open(captions_file, 'r') as file:
        lines = file.readlines()
        for line in lines:
            caption = line.strip().split('\t')[1]
            tokens = word_tokenize(caption.lower())
            word_count.update(tokens)

    idx = 3
    for word, count in word_count.items():
        if count >= threshold:
            vocab[word] = idx
            idx += 1

    return vocab


def train():
    config = Config()


    vocab = build_vocab(config.captions_file, threshold=5)
    vocab_size = len(vocab)


    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


    dataset = ImageCaptionDataset(config.image_dir, config.captions_file, vocab, transform)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)



    encoder = EncoderViT(config.embed_size).to(config.device)
    decoder = DecoderTransformer(config.embed_size, config.hidden_size, config.num_layers, config.num_heads, vocab_size).to(config.device)


    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=config.lr)


    for epoch in range(config.num_epochs):
        for imgs, captions in dataloader:
            imgs, captions = imgs.to(config.device), captions.to(config.device)


            features = encoder(imgs)
            outputs = decoder(features, captions[:, :-1])


            loss = criterion(outputs.view(-1, vocab_size), captions[:, 1:].contiguous().view(-1))


            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch [{epoch+1}/{config.num_epochs}], Loss: {loss.item():.4f}")


        if (epoch+1) % 5 == 0:
            torch.save({'epoch': epoch+1, 'encoder': encoder.state_dict(), 'decoder': decoder.state_dict(), 'optimizer': optimizer.state_dict()},
                       os.path.join(config.checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth"))

if __name__ == '__main__':
    train()


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:04<00:00, 82.3MB/s] 


Epoch [1/10], Loss: 3.6211
Epoch [2/10], Loss: 3.0636
Epoch [3/10], Loss: 3.9261
Epoch [4/10], Loss: 3.8701
Epoch [5/10], Loss: 3.8345
Epoch [6/10], Loss: 3.7778
Epoch [7/10], Loss: 2.9669
Epoch [8/10], Loss: 3.5022
Epoch [9/10], Loss: 2.9013
Epoch [10/10], Loss: 3.5171
