In [None]:
# Mount Google Drive: the images, captions file and checkpoints used below
# all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# cuDNN autotuning is only relevant when running on a CUDA device.
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True

Using device: cpu


In [None]:
import os  # when loading file paths

import matplotlib.pyplot as plt  # display sample images next to predicted captions
import pandas as pd  # for lookup in annotation file
import torch
import torchvision.transforms as transforms
from PIL import Image  # Load img
from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset

In [None]:
# Install spaCy
!pip install spacy
# Download the English language model
!python -m spacy download en_core_web_sm
# Import spaCy and load the language model
import spacy

Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m78.5 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [None]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from caption text.

    Indices 0-3 are reserved for "<PAD>", "<SOS>", "<EOS>" and "<UNK>". A word
    receives the next free index at the moment its count reaches
    ``freq_threshold`` while scanning the sentences in order.
    """

    # spaCy English pipeline, loaded lazily on first use and shared by every
    # instance. Loading it is expensive, so it must not happen per sentence.
    _spacy_eng = None

    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Return the lower-cased spaCy tokens of ``text``."""
        if Vocabulary._spacy_eng is None:
            import spacy

            Vocabulary._spacy_eng = spacy.load("en_core_web_sm")
        return [tok.text.lower() for tok in Vocabulary._spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every word occurring at least ``freq_threshold`` times in ``sentence_list``.

        Args:
            sentence_list: iterable of raw caption strings.
        """
        frequencies = {}
        # Continue after the current largest index so a repeated call extends the
        # vocabulary instead of overwriting existing entries.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                if frequencies[word] == self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Map ``text`` to token ids; out-of-vocabulary words become "<UNK>"."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]


class FlickrDataset(Dataset):
    """Flickr image/caption pairs plus a Vocabulary built from all captions.

    ``captions_file`` is read with ``pd.read_csv`` and must have "image" and
    "caption" columns; each row is one sample. ``__getitem__`` returns
    ``(image, token_ids)`` where the image is RGB (transformed if ``transform``
    is given) and ``token_ids`` is the caption wrapped in <SOS> ... <EOS>.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Get img, caption columns
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Initialize vocabulary and build vocab
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Positional access, independent of the DataFrame's index labels.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        # convert() returns a new in-memory image, so the file can be closed
        # deterministically instead of waiting for garbage collection.
        with Image.open(os.path.join(self.root_dir, img_id)) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = (
            [self.vocab.stoi["<SOS>"]]
            + self.vocab.numericalize(caption)
            + [self.vocab.stoi["<EOS>"]]
        )

        return img, torch.tensor(numericalized_caption)


class MyCollate:
    """Collate ``(image, caption_ids)`` samples into one batch.

    Images are stacked along a new leading batch dimension. Captions are padded
    with ``pad_idx`` into a sequence-first ``(max_len, batch)`` tensor.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        images, captions = zip(*batch)
        stacked_images = torch.stack(images, dim=0)
        padded_captions = pad_sequence(
            list(captions), batch_first=False, padding_value=self.pad_idx
        )
        return stacked_images, padded_captions


def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=256,
    num_workers=2,
    shuffle=True,
    pin_memory=True,
    freq_threshold=5,
):
    """Build a FlickrDataset and a DataLoader that batches it with MyCollate.

    Args:
        root_folder: directory containing the image files.
        annotation_file: CSV with "image" and "caption" columns.
        transform: image transform applied by the dataset.
        batch_size, num_workers, shuffle, pin_memory: forwarded to DataLoader.
        freq_threshold: minimum word count for a word to enter the vocabulary.

    Returns:
        ``(loader, dataset)``. Captions in each batch are padded with the
        vocabulary's "<PAD>" index.
    """
    dataset = FlickrDataset(
        root_folder, annotation_file, transform=transform, freq_threshold=freq_threshold
    )

    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset

In [None]:
# Preprocessing for the VGG19 encoder: square resize, center crop to 224x224,
# tensor conversion, then per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Training DataLoader plus the FlickrDataset that owns the vocabulary.
loader, dataset = get_loader(
    root_folder="/content/drive/MyDrive/Images",
    annotation_file="/content/drive/MyDrive/captions.txt",
    transform=transform,
)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class EncoderCNN(nn.Module):
    """VGG19 image encoder projecting each image to a single ``embed_size`` vector.

    The last layer of VGG19's classifier is dropped, so the backbone emits its
    4096-d penultimate activations; a trainable Linear + ReLU + Dropout maps
    them to ``embed_size``. The VGG19 weights are trainable only if
    ``train_CNN`` is true; the new projection layer is always trainable.
    """

    def __init__(self, embed_size, train_CNN=False):
        super().__init__()
        self.train_CNN = train_CNN
        self.vgg19 = models.vgg19(pretrained=True)

        # Keep every classifier layer except the final one.
        self.vgg19.classifier = nn.Sequential(
            *list(self.vgg19.classifier.children())[:-1]
        )

        # Trainable projection head on top of the 4096-d backbone output.
        self.fc = nn.Linear(4096, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

        # Backbone follows train_CNN; the projection head always learns.
        for param_name, param in self.named_parameters():
            param.requires_grad = train_CNN or param_name.startswith("fc.")

    def forward(self, images):
        pooled = self.vgg19(images).flatten(start_dim=1)
        return self.dropout(self.relu(self.fc(pooled)))

class Attention(nn.Module):
    """Additive attention over a set of image feature vectors.

    Scores each region with ``full_att(relu(feature_att(f) + hidden_att(h)))``,
    softmaxes the scores over regions, and returns the weighted sum of features.
    """

    def __init__(self, feature_dim, hidden_dim, attention_dim):
        super(Attention, self).__init__()
        self.feature_att = nn.Linear(feature_dim, attention_dim)  # Linear layer to transform image features
        self.hidden_att = nn.Linear(hidden_dim, attention_dim)  # Linear layer to transform hidden state
        self.full_att = nn.Linear(attention_dim, 1)  # Linear layer to compute attention weights
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, features, hidden):
        """
        Args:
            features: ``(batch, num_pixels, feature_dim)``. A 2-D
                ``(batch, feature_dim)`` tensor (one vector per image, as
                EncoderCNN produces) is treated as a single region.
            hidden: ``(batch, hidden_dim)`` decoder hidden state.

        Returns:
            ``(context, alpha)`` where ``context`` is ``(batch, feature_dim)`` and
            ``alpha`` is ``(batch, num_pixels)``, summing to 1 over regions.
        """
        if features.dim() == 2:
            # Without this, att1 + att2.unsqueeze(1) would broadcast to
            # (batch, batch, attention_dim) and mix different images.
            features = features.unsqueeze(1)
        att1 = self.feature_att(features)  # (batch_size, num_pixels, attention_dim)
        att2 = self.hidden_att(hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)  # (batch_size, num_pixels)
        attention_weighted_encoding = (features * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, feature_dim)
        return attention_weighted_encoding, alpha

class DecoderRNN(nn.Module):
    """LSTM caption decoder with attention over image features (teacher forcing).

    At every step the LSTM input is the current token embedding concatenated
    with an attention-weighted summary of the image features.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, attention_dim=256):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Image features have width embed_size (EncoderCNN projects to it).
        self.attention = Attention(embed_size, hidden_size, attention_dim)
        # Input per step is [token embedding ; attention context] -> 2 * embed_size.
        self.lstm = nn.LSTM(embed_size + embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.2)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, features, captions):
        """
        Args:
            features: ``(batch, embed_size)`` or ``(batch, num_pixels, embed_size)``;
                a 2-D tensor is treated as a single region.
            captions: ``(batch, seq_len)`` token ids, batch-first.

        Returns:
            ``(predictions, alphas)``: logits ``(batch, seq_len, vocab_size)`` and
            attention weights ``(batch, seq_len, num_pixels)``.
        """
        if features.dim() == 2:
            features = features.unsqueeze(1)
        batch_size, num_pixels, _ = features.shape
        seq_length = captions.size(1)
        embeddings = self.dropout(self.embed(captions))  # (batch, seq_len, embed_size)

        h, c = self.init_hidden_state(features)  # each (batch, hidden_size)
        # nn.LSTM takes states shaped (num_layers, batch, hidden) even when batch_first=True.
        h = h.unsqueeze(0).repeat(self.num_layers, 1, 1)
        c = c.unsqueeze(0).repeat(self.num_layers, 1, 1)

        predictions = features.new_zeros(batch_size, seq_length, self.vocab_size)
        alphas = features.new_zeros(batch_size, seq_length, num_pixels)

        for t in range(seq_length):
            # Attend with the top LSTM layer's hidden state.
            attention_weighted_encoding, alpha = self.attention(features, h[-1])
            lstm_input = torch.cat((embeddings[:, t, :], attention_weighted_encoding), dim=1)
            # nn.LSTM returns (output, (h_n, c_n)).
            output, (h, c) = self.lstm(lstm_input.unsqueeze(1), (h, c))
            predictions[:, t, :] = self.linear(output.squeeze(1))
            alphas[:, t, :] = alpha

        return predictions, alphas

    def init_hidden_state(self, features):
        """Zero initial ``(h, c)``, each ``(batch, hidden_size)``.

        ``features`` only supplies the batch size and device.
        """
        mean_features = features.mean(dim=1)
        h = self.init_hidden(mean_features)
        c = self.init_hidden(mean_features)
        return h, c

    def init_hidden(self, input):
        """Zeros of shape ``(input.size(0), hidden_size)`` on ``input``'s device."""
        return torch.zeros(input.size(0), self.hidden_size).to(input.device)

class CNNtoRNN(nn.Module):
    """Image captioner: EncoderCNN features decoded by the attention DecoderRNN."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Teacher-forced forward pass.

        ``captions`` are token ids laid out as DecoderRNN.forward indexes them
        (batch-first). Returns ``(outputs, alphas)`` from the decoder.
        """
        features = self.encoderCNN(images)
        outputs, alphas = self.decoderRNN(features, captions)
        return outputs, alphas

    def caption_image(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for a single preprocessed image.

        Decoding mirrors training: it starts from the "<SOS>" embedding and, at
        each step, attends over the image features with the top LSTM layer's
        hidden state.

        Args:
            image: a batch containing one image, as accepted by ``encoderCNN``.
            vocabulary: object exposing ``stoi`` / ``itos`` maps (e.g. Vocabulary).
            max_length: maximum number of tokens to generate.

        Returns:
            List of predicted token strings. "<SOS>" is not included; "<EOS>" is
            the last element if it was produced within ``max_length`` steps.
        """
        decoder = self.decoderRNN
        result_caption = []
        was_training = self.training
        self.eval()  # disable dropout for inference; previous mode restored below
        try:
            with torch.no_grad():
                features = self.encoderCNN(image)
                if features.dim() == 2:
                    features = features.unsqueeze(1)  # (batch, num_pixels=1, embed_size)

                h, c = decoder.init_hidden_state(features)  # each (batch, hidden_size)
                num_layers = decoder.lstm.num_layers
                h = h.unsqueeze(0).repeat(num_layers, 1, 1)
                c = c.unsqueeze(0).repeat(num_layers, 1, 1)

                token = torch.full(
                    (features.size(0),),
                    vocabulary.stoi["<SOS>"],
                    dtype=torch.long,
                    device=features.device,
                )
                for _ in range(max_length):
                    context, _alpha = decoder.attention(features, h[-1])
                    lstm_input = torch.cat((decoder.embed(token), context), dim=1).unsqueeze(1)
                    output, (h, c) = decoder.lstm(lstm_input, (h, c))
                    predicted = decoder.linear(output.squeeze(1)).argmax(1)
                    result_caption.append(predicted.item())

                    if vocabulary.itos[predicted.item()] == "<EOS>":
                        break
                    token = predicted
        finally:
            self.train(was_training)

        return [vocabulary.itos[idx] for idx in result_caption]


In [None]:
def save_checkpoint(state, filename="/content/drive/MyDrive/my_checkpoint.pth.tar"):
    """Serialize ``state`` (e.g. model/optimizer state dicts and step) to ``filename``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from a checkpoint dict.

    Args:
        checkpoint: dict with "state_dict", "optimizer" and "step" entries.
        model: module receiving ``checkpoint["state_dict"]``.
        optimizer: optimizer receiving ``checkpoint["optimizer"]``.

    Returns:
        The stored training step.
    """
    print("=> Loading checkpoint")
    for target, key in ((model, "state_dict"), (optimizer, "optimizer")):
        target.load_state_dict(checkpoint[key])
    return checkpoint["step"]

In [None]:
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

def train(data_loader=None):
    """Train a CNNtoRNN captioning model with teacher forcing and return it.

    Args:
        data_loader: DataLoader yielding ``(imgs, captions)`` batches whose
            ``captions`` are sequence-first ``(seq_len, batch)``, as produced by
            MyCollate. Defaults to the notebook-level ``loader`` from
            get_loader. Its ``.dataset`` must expose ``vocab``.

    Returns:
        The trained model.
    """
    if data_loader is None:
        data_loader = loader  # notebook-level DataLoader built by get_loader
    # Take the vocab from the loader's own dataset: the global name `dataset`
    # is rebound to a Flickr_Dataset (which has no vocab) later in the notebook.
    dataset = data_loader.dataset

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    # Same file save_checkpoint writes by default, so resuming finds it.
    checkpoint_path = "/content/drive/MyDrive/my_checkpoint.pth.tar"

    # Hyperparameters
    embed_size = 256
    hidden_size = 256
    vocab_size = len(dataset.vocab)
    num_layers = 1
    learning_rate = 3e-4
    num_epochs = 10

    # for tensorboard
    writer = SummaryWriter("runs/flickr")
    step = 0

    # initialize model, loss etc
    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    # Padded positions do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if load_model:
        step = load_checkpoint(
            torch.load(checkpoint_path, map_location=device), model, optimizer
        )

    model.train()
    try:
        for epoch in range(num_epochs):
            for imgs, captions in tqdm(
                data_loader,
                total=len(data_loader),
                leave=False,
                desc=f"epoch {epoch + 1}/{num_epochs}",
            ):
                imgs = imgs.to(device)
                captions = captions.to(device)  # (seq_len, batch) from MyCollate

                # Teacher forcing: feed <SOS> w_1 .. w_n and predict w_1 .. w_n <EOS>.
                # DecoderRNN indexes captions batch-first, so transpose the
                # sequence-first batch for both the model input and the targets.
                inputs = captions[:-1].transpose(0, 1)   # (batch, seq_len - 1)
                targets = captions[1:].transpose(0, 1)   # (batch, seq_len - 1)
                outputs, _ = model(imgs, inputs)         # logits (batch, seq_len - 1, vocab)
                loss = criterion(
                    outputs.reshape(-1, outputs.shape[2]), targets.reshape(-1)
                )

                optimizer.zero_grad()
                # No gradient argument: passing `loss` here would scale every
                # gradient by the loss value.
                loss.backward()
                optimizer.step()

                writer.add_scalar("Training loss", loss.item(), global_step=step)
                step += 1

            # Save after each epoch so the checkpoint holds the weights just trained.
            if save_model:
                checkpoint = {
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                }
                save_checkpoint(checkpoint, filename=checkpoint_path)
    finally:
        writer.close()

    return model

In [None]:
# Run the training loop defined above. train() writes checkpoints to Drive via
# save_checkpoint; its return value is discarded here.
train()

In [None]:
class Flickr_Dataset():
    """Raw Flickr ``(image, caption_string)`` pairs, without a vocabulary.

    Used for qualitative inspection. ``captions_file`` is read with
    ``pd.read_csv`` and must have "image" and "caption" columns. Images are
    returned as RGB PIL images, or transformed if ``transform`` is given.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Get img, caption columns
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Positional access, independent of the DataFrame's index labels.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        # Force 3-channel RGB, as FlickrDataset does, so grayscale/RGBA files
        # work with a 3-channel Normalize; convert() also lets the file close now.
        with Image.open(os.path.join(self.root_dir, img_id)) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, caption

In [None]:
import random

# Untransformed (PIL image, caption string) pairs for eyeballing predictions.
# NOTE(review): this rebinds `dataset`, shadowing the FlickrDataset that
# get_loader returned earlier (that one is still reachable as loader.dataset).
dataset = Flickr_Dataset(
    captions_file="/content/drive/MyDrive/captions.txt",
    root_dir="/content/drive/MyDrive/Images",
)

In [None]:
# Draw a reproducible sample of examples to caption. A dedicated RNG keeps the
# global `random` state untouched, and min() avoids a ValueError from
# random.sample when the dataset has fewer rows than NUM_EXAMPLES.
NUM_EXAMPLES = 20
SAMPLE_SEED = 0
random_indices = random.Random(SAMPLE_SEED).sample(
    range(len(dataset)), min(NUM_EXAMPLES, len(dataset))
)

# Each element is the (image, caption) pair returned by dataset[index].
images_and_captions = [dataset[index] for index in random_indices]

In [None]:
# train() keeps its model local, so rebuild the captioner from the checkpoint
# it writes. These hyperparameters must match the ones used in train().
vocab = loader.dataset.vocab  # FlickrDataset vocab; `dataset` is now a Flickr_Dataset
model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=len(vocab), num_layers=1).to(device)
checkpoint = torch.load("/content/drive/MyDrive/my_checkpoint.pth.tar", map_location=device)
model.load_state_dict(checkpoint["state_dict"])
model.eval()  # disable dropout for inference

# Reuse exactly the preprocessing the model was trained with.
transform = loader.dataset.transform

# Tokens with no caption content, hidden from the printed prediction.
# (caption_image does not emit <SOS>, so slicing [1:-1] would drop a real word.)
SPECIAL_TOKENS = {"<PAD>", "<SOS>", "<EOS>"}

for img, caption in images_and_captions:
    # Flickr_Dataset was built without a transform, so img is normally a PIL image.
    img_pil = transforms.ToPILImage()(img.cpu()) if torch.is_tensor(img) else img
    img_pil = img_pil.convert("RGB")

    img_tensor = transform(img_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        predicted_tokens = model.caption_image(img_tensor, vocab)
    predicted_caption = " ".join(tok for tok in predicted_tokens if tok not in SPECIAL_TOKENS)

    fig, ax = plt.subplots()
    ax.imshow(img_pil)
    ax.axis("off")
    print(f'True Caption: {caption}\nPredicted Caption: ')
    print(predicted_caption)
    plt.show()
