In [1]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
import os
import numpy as np
import pandas as pd
import cv2
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, models
import torch.optim as optim

In [2]:
data = pd.read_csv('data/results.csv', sep='|')
data['captions'] = data[' comment']
data.drop(labels=[' comment_number', ' comment'], inplace=True, axis=1)
data.dropna(inplace=True)

dictio_map = {}
for n,i in enumerate(set(data.image_name)):
    dictio_map[str(i)] = n

data['mapper_img_cap'] = data['image_name'].map(dictio_map)

data.to_csv('data/captions.csv', index=False)

In [8]:
import glob
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.image as mpimg
import pandas as pd
import cv2
from PIL import Image



class ImageCaptionDataset(Dataset):
    """Image Caption dataset."""

    def __init__(self, csv_file, root_dir, transform_img=None, transform_cap=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.captions = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform_img = transform_img
        self.transform_cap = transform_cap

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        
        # take the image name contained in the csv file
        
        list_image_names = list(self.captions[self.captions.iloc[:,2] == idx].iloc[:, 0])
        
        if len(list_image_names) > 0:
            image_name = os.path.join(self.root_dir, list_image_names[0])
            
            # read the true image based on that name
            #image = mpimg.imread(image_name) #Numpy
            image = Image.open(image_name) #PIL                
        
        else:
            
            image = Image.fromarray(np.zeros((224, 224)))
        
        
        if self.transform_img:
                image = self.transform_img(image)
        
        # read captions
        captions = self.captions[self.captions.iloc[:, 2] == idx].iloc[:, 1].values
        
        if len(captions) > 0:
            captions = captions
            
            if self.transform_cap:
                captions = self.transform_cap(captions)
        else:
            captions = 'NULL'
        
   
        
        
        sample = {'image': image, 'captions': captions}

        return sample

In [9]:
batch_size = 10
num_workers = 4
csv_file = 'data/captions.csv'
root_dir = 'data/flickr30k_images'

transform_img = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

In [10]:
valid_size = 0.3

def train_valid_split(training_set, validation_size):
    """ Function that split our dataset into train and validation
        given in parameter the training set and the % of sample for validation"""
    
    # obtain training indices that will be used for validation
    num_train = len(training_set)
    indices = list(range(num_train))
    np.random.shuffle(indices)
    split = int(np.floor(validation_size * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]

    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    
    return train_sampler, valid_sampler




train_set = ImageCaptionDataset(csv_file=csv_file,
                                root_dir=root_dir,
                                transform_img=transform_img)

train_sampler, valid_sampler = train_valid_split(train_set, valid_size)


train_loader = DataLoader(train_set,
                          batch_size=batch_size,
                          sampler=train_sampler,
                          num_workers=num_workers)

valid_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=batch_size,
                                           sampler=valid_sampler,
                                           num_workers=num_workers)

In [11]:
batch = next(iter(train_loader))
len(batch)

TypeError: Traceback (most recent call last):
  File "/home/mdhvince/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/mdhvince/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 229, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
  File "/home/mdhvince/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 229, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
  File "/home/mdhvince/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 234, in default_collate
    raise TypeError((error_msg.format(type(batch[0]))))
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'PIL.Image.Image'>


In [None]:
plt.figure(figsize=(15,15))

# obtain one batch of training images
batch = next(iter(train_loader))

# display 10 images
for i in np.arange(10):
    
    images, labels = batch['image'], batch['captions']
    
    #unormalize images
    image = images[i].numpy()
    image = np.transpose(image, (1, 2, 0))
    
    plt.subplot(2,5,i+1)
    plt.imshow(np.squeeze(image), cmap='gray')
    plt.title(labels)