<a href="https://colab.research.google.com/github/DavoodSZ1993/Dive_into_Deep_Learning/blob/main/14_6_the_object_detection_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install d2l==1.0.0-alpha1.post0 --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.0/93.0 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.9/121.9 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.9/84.9 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m40.4 MB/s[0m eta [36m0:00:00[0m
[?25h

## 14.6 The Object Detection Dataset

### 14.6.1 Downloading the Dataset

In [2]:
%matplotlib inline
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l

In [3]:
# Register the banana-detection archive under the key that
# `d2l.download_extract('banana-detection')` is later called with.
# The entry is a (URL, hash string) pair; the 40-hex-digit string is presumably
# a SHA-1 checksum d2l uses to validate the download — TODO confirm.
d2l.DATA_HUB['banana-detection'] = (
    d2l.DATA_URL + 'banana-detection.zip',
    '5de26c8fce5ccdea9f91267273464dc968d20d72')

### 14.6.2 Reading the Dataset

In [4]:
def read_data_bananas(is_train=True):
  """Read the banana detection dataset images and labels.

  Downloads and extracts the 'banana-detection' archive via d2l, then loads
  every image listed in the chosen split's `label.csv`.

  Args:
    is_train: If true, read from the `bananas_train` folder; otherwise read
      from `bananas_val`.

  Returns:
    A tuple `(images, labels)`. `images` is a list with one tensor per CSV
    row, as returned by `torchvision.io.read_image`. `labels` is a tensor of
    shape `(num_rows, 1, num_label_columns)` holding every non-index CSV
    column divided by 256.
  """
  root_dir = d2l.download_extract('banana-detection')
  # Both the label file and the images live under the same split folder.
  split_dir = os.path.join(
      root_dir, 'bananas_train' if is_train else 'bananas_val')
  annotations = pd.read_csv(
      os.path.join(split_dir, 'label.csv')).set_index('img_name')

  images, targets = [], []
  for file_name, row in annotations.iterrows():
    image_path = os.path.join(split_dir, 'images', f'{file_name}')
    images.append(torchvision.io.read_image(image_path))
    targets.append(list(row))

  # unsqueeze(1) inserts a singleton axis between the row axis and the label
  # values. The division applies to every label column; presumably 256 is the
  # image edge length, so coordinates end up in [0, 1] — TODO confirm.
  labels = torch.tensor(targets).unsqueeze(1) / 256
  return images, labels

In [5]:
class BananasDataset(torch.utils.data.Dataset):
  """A customized dataset to load the banana detection dataset.

  All images and labels of the requested split are read eagerly in
  `__init__` through `read_data_bananas` and kept in memory.

  Attributes:
    features: Sequence of image tensors, one per example.
    labels: Label container indexed in parallel with `features`.
  """
  def __init__(self, is_train):
    """Load one split and report how many examples were read.

    Args:
      is_train: If true, load the training split; otherwise the validation
        split.
    """
    self.features, self.labels = read_data_bananas(is_train)
    split_name = 'training' if is_train else 'validation'
    # The message previously ran 'read' and the count together
    # ('read1000 training examples'); keep a space between them.
    print(f'read {len(self.features)} {split_name} examples')

  def __getitem__(self, idx):
    """Return the `(image, label)` pair at `idx`, with the image cast to float."""
    return (self.features[idx].float(), self.labels[idx])

  def __len__(self):
    """Return the number of examples in this split."""
    return len(self.features)

In [6]:
def load_data_bananas(batch_size):
  """Build mini-batch iterators over the banana detection dataset.

  Args:
    batch_size: Number of examples per mini-batch, used for both loaders.

  Returns:
    A `(train_iter, val_iter)` pair of `torch.utils.data.DataLoader` objects.
    Only the training loader is created with `shuffle=True`.
  """
  train_set = BananasDataset(is_train=True)
  train_iter = torch.utils.data.DataLoader(
      train_set, batch_size=batch_size, shuffle=True)
  val_set = BananasDataset(is_train=False)
  val_iter = torch.utils.data.DataLoader(val_set, batch_size=batch_size)
  return train_iter, val_iter

In [7]:
# Mini-batch size used for the loaders below.
batch_size = 32
# Not used in this cell; presumably the image edge length in pixels — TODO confirm.
edge_size = 256
# Only the training iterator is needed here; the validation one is discarded.
train_iter, _ = load_data_bananas(batch_size)
# Pull one mini-batch and display the shapes of its image and label tensors.
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape

Downloading ../data/banana-detection.zip from http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip...
read1000 training examples
read100 validation examples


(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))