#### 1. Import Libraries

In [5]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


import argparse
import os
import random

import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as utils

from torch.utils.data import Dataset, DataLoader
from PIL import Image

import requests, sys, cv2


Download the miniImageNet dataset.

In [2]:
import gdown, tarfile

# Google Drive location of the miniImageNet training archive and the local
# paths it is downloaded to / extracted into.
url = 'https://drive.google.com/uc?id=107FTosYIeBn5QbynR46YG91nHcJ70whs'
path = './Data/miniImageNet/'
output = './Data/miniImageNet/train.tar'

# makedirs (rather than mkdir) also creates the parent './Data' directory on a
# fresh checkout, and exist_ok makes the cell safe to re-run.
os.makedirs(path, exist_ok=True)

# The presence of the archive is the marker for "already downloaded and
# extracted": both steps are skipped when the tar file is already on disk.
if not os.path.exists(output):
    gdown.download(url, output, quiet=False)

    # NOTE(review): extractall trusts the member paths stored in the archive.
    # The source is assumed to be trusted here; on Python versions that support
    # it, consider passing filter='data' to reject unsafe members.
    with tarfile.open(output, 'r') as tar:
        tar.extractall(path)

Download the EuroSAT (RGB) dataset.

In [3]:
import zipfile

# Zenodo location of the EuroSAT RGB archive and the local path it is saved to.
link = 'https://zenodo.org/records/7711810/files/EuroSAT_RGB.zip?download=1'
destination = './Data/EuroSAT_RGB.zip'

# The presence of the zip file is the marker for "already downloaded and
# extracted": nothing is done when it is already on disk.
if not os.path.exists(destination):
    # Make sure the target directory exists before opening files inside it.
    os.makedirs(os.path.dirname(destination), exist_ok=True)

    # Stream the body so the archive is written chunk by chunk instead of being
    # held in memory as a single bytes object; the timeout bounds how long a
    # stalled connection can block the cell.
    with requests.get(link, stream=True, timeout=60) as response:
        # Check if the request was successful (status code 200)
        if response.status_code == 200:
            # Write to a temporary name and rename only once complete, so an
            # interrupted download is not mistaken for a finished one on re-run.
            partial_destination = destination + '.part'
            with open(partial_destination, 'wb') as file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    file.write(chunk)
            os.replace(partial_destination, destination)

            with zipfile.ZipFile(destination, 'r') as zip_ref:
                zip_ref.extractall('./Data')

            print(f"EuroSAT_RGB downloaded successfully to '{destination}'.")
        else:
            print(f"Failed to download EuroSAT_RGB. Status code: {response.status_code}")

#### 2. Create Datasets

In this section I will create DataLoaders for the EuroSAT_RGB and miniImageNet datasets.

In [6]:
import pathlib 
import math

train_data_dir_path = './Data/miniImageNet/train/'
miniImageNet_data_dir = pathlib.Path(train_data_dir_path)

# Only sub-directories are classes; stray files in the root are ignored.
# Sorting makes the class order (and therefore every index below) reproducible
# instead of depending on the order the file system happens to return.
class_dirs = sorted(entry for entry in miniImageNet_data_dir.iterdir() if entry.is_dir())

# Dataset uses the folder name as its labels
subfolder_names = [class_dir.name for class_dir in class_dirs]

# Positions into all_data / all_data_labels belonging to each split.
train_indices = []
valid_indices = []
test_indices = []

# Every image (scaled to [0, 1], channel order as returned by cv2.imread) and
# its class folder name, stored class after class.
all_data = []
all_data_labels = []

for subfolder in class_dirs:
    label_name = subfolder.name
    first_index = len(all_data)  # position of this class's first image

    for item in sorted(subfolder.iterdir()):
        pixels = cv2.imread(str(item))
        if pixels is None:
            # cv2.imread signals an unreadable / non-image file with None
            # instead of raising; skip it so it cannot poison the arrays.
            print(f"Skipping unreadable file: {item}")
            continue
        all_data.append(pixels / 255)
        all_data_labels.append(label_name)

    # Number of images actually loaded for this class.
    count = len(all_data) - first_index

    # 60 % train / 20 % validation / remainder test, taken per class. The three
    # ranges are contiguous (range ends are exclusive, so the next split starts
    # exactly at the previous end) and never extend past this class's images.
    n_train = math.ceil(count * 0.6)
    n_valid = min(math.ceil(count * 0.2), count - n_train)

    train_end = first_index + n_train
    valid_end = train_end + n_valid
    class_end = first_index + count

    train_indices.extend(range(first_index, train_end))
    valid_indices.extend(range(train_end, valid_end))
    test_indices.extend(range(valid_end, class_end))


class MiniImageNet(Dataset):
    """One split of the miniImageNet images held in memory.

    The samples come from the module-level ``all_data`` / ``all_data_labels``
    lists, selected by ``train_indices``, ``valid_indices`` or ``test_indices``.

    Args:
        phase: Which split to expose: ``'train'``, ``'valid'`` or ``'test'``.
        transform: Optional callable applied to the PIL image in
            ``__getitem__``.

    Raises:
        ValueError: If ``phase`` is not one of the three supported names.
    """

    def __init__(self, phase="train", transform=None):
        # Validate first so a typo fails immediately with a clear message,
        # before the (large) image list is converted to an array.
        if phase not in ('train', 'valid', 'test'):
            raise ValueError(
                f"wrong phase: {phase!r} (expected 'train', 'valid' or 'test')"
            )

        if phase == 'train':
            indices = train_indices
        elif phase == 'valid':
            indices = valid_indices
        else:
            indices = test_indices

        all_data_np = np.array(all_data)
        all_data_labels_np = np.array(all_data_labels)

        self.data = all_data_np[indices]
        self.labels = all_data_labels_np[indices]

        self.transform = transform

        self.label_names = subfolder_names

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)

    @staticmethod
    def _to_uint8_rgb(img):
        """Convert a stored image array into an 8-bit array PIL can display.

        Floating-point input is assumed to be scaled to [0, 1] (the loader
        divides the OpenCV pixels by 255) and is mapped back to 0-255; integer
        input is taken to already be in 0-255. Three-channel images have their
        last axis reversed because the pixels come from ``cv2.imread``, which
        yields BGR, while PIL expects RGB.
        """
        arr = np.asarray(img)
        if np.issubdtype(arr.dtype, np.floating):
            # rint (not truncation) undoes the /255 scaling exactly.
            arr = np.rint(np.clip(arr, 0.0, 1.0) * 255)
        arr = arr.astype(np.uint8)
        if arr.ndim == 3 and arr.shape[2] == 3:
            arr = arr[:, :, ::-1]
        # The reversed view has negative strides; hand PIL a plain buffer.
        return np.ascontiguousarray(arr)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        ``image`` is a PIL image, or whatever ``transform`` returns for it;
        ``label`` is the class folder name stored for that sample.
        """
        img, label = self.data[index], self.labels[index]

        # Let PIL infer the mode from the array: a 2-D mode such as 'L' cannot
        # hold the 3-channel images stored here.
        img = Image.fromarray(self._to_uint8_rgb(img))

        if self.transform is not None:
            img = self.transform(img)

        return img, label


In [7]:
# Batches of 64 samples, loaded by 2 worker processes per loader. Only the
# training set is shuffled; the validation and test loaders keep a fixed order
# so evaluation runs are repeatable and comparable.
data_transform = transforms.Compose([transforms.ToTensor()])

train_set = MiniImageNet(phase='train', transform=data_transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)

valid_set = MiniImageNet(phase='valid', transform=data_transform)
valid_loader = DataLoader(valid_set, batch_size=64, shuffle=False, num_workers=2)

test_set = MiniImageNet(phase='test', transform=data_transform)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)

KeyboardInterrupt: 