In [None]:
# Imports

In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [2]:
# Paths
image_folder = "data/images/"
captions_file = "data/captions.txt"

# Read captions
captions_data = pd.read_csv(captions_file, delimiter='\t', names=['image', 'caption'])
captions_data.head()

Unnamed: 0,image,caption
0,"image,caption",
1,"1000268201_693b08cb0e.jpg,A child in a pink dr...",
2,"1000268201_693b08cb0e.jpg,A girl going into a ...",
3,"1000268201_693b08cb0e.jpg,A little girl climbi...",
4,"1000268201_693b08cb0e.jpg,A little girl climbi...",


In [3]:
import re
from collections import Counter

class Vocabulary:
    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0:"<PAD>", 1:"<SOS>", 2:"<EOS>", 3:"<UNK>"}
        self.stoi = {v:k for k,v in self.itos.items()}
    
    def tokenizer(self, text):
        return re.findall(r'\w+', text.lower())
    
    def build_vocab(self, sentence_list):
        frequencies = Counter()
        idx = 4
        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                frequencies[word] +=1
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer(text)
        return [self.stoi.get(token, self.stoi["<UNK>"]) for token in tokenized_text]

# Build vocabulary
vocab = Vocabulary(freq_threshold=5)
vocab.build_vocab(captions_data['caption'].tolist())

AttributeError: 'float' object has no attribute 'lower'

In [None]:
class FlickrDataset(Dataset):
    def __init__(self, dataframe, img_folder, vocab, transform=None):
        self.df = dataframe
        self.img_folder = img_folder
        self.vocab = vocab
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['image']
        caption = self.df.iloc[idx]['caption']
        
        image = Image.open(os.path.join(self.img_folder, img_name)).convert("RGB")
        if self.transform:
            image = self.transform(image)
        
        numericalized_caption = [vocab.stoi["<SOS>"]] + vocab.numericalize(caption) + [vocab.stoi["<EOS>"]]
        return image, torch.tensor(numericalized_caption)

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

dataset = FlickrDataset(captions_data, image_folder, vocab, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=lambda x: x)