# Packages import, dir creation, log functions

In [None]:
# import packages
import os
import csv
import random
import tarfile
import multiprocessing as mp

import tqdm
import requests

import numpy as np
import sklearn.model_selection as skms

import torch
import torch.utils.data as td
import torch.nn.functional as F

import torchvision as tv
import torchvision.transforms.functional as TF

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


# define constants
# DEVICE is 'cuda' when torch reports an available CUDA device, else 'cpu'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)
# directory that is created below for output files
OUT_DIR = 'results'
# fixed seed value (not applied to any random generator in this cell)
RANDOM_SEED = 42

# create an output folder (no error if it already exists)
os.makedirs(OUT_DIR, exist_ok=True)


def get_model_desc(pretrained=False, num_classes=200, use_attention=False):
    """
    Builds a dash-separated label describing a model setup.

    The label starts with 'Transfer' when `pretrained` is truthy and with
    'Baseline' otherwise. 'Multitask' is appended when `num_classes` equals
    204 and 'Attention' is appended when `use_attention` is truthy.
    """
    # optional tags, listed in the order they appear in the label
    optional_tags = (
        (num_classes == 204, 'Multitask'),
        (use_attention, 'Attention'),
    )

    tags = ['Transfer' if pretrained else 'Baseline']
    tags += [tag for enabled, tag in optional_tags if enabled]

    return '-'.join(tags)


def log_accuracy(path_to_csv, desc, acc, sep='\t', newline='\n'):
    """
    Appends one accuracy record to a delimited text file.

    A header row ('setup', 'accuracy') is written first when the file does
    not exist yet or exists but is empty.

    Args:
        path_to_csv: path of the log file; it is created if missing.
        desc: setup description written to the first column.
        acc: accuracy value written to the second column.
        sep: column separator.
        newline: string that terminates each row.
    """
    # an existing but empty file still needs the header row
    needs_header = (
        not os.path.exists(path_to_csv)
        or os.path.getsize(path_to_csv) == 0
    )

    # plain append mode creates a missing file; the handle is not named
    # `csv` so that the imported `csv` module is not shadowed
    with open(path_to_csv, 'a') as log_file:
        if needs_header:
            log_file.write(f'setup{sep}accuracy{newline}')

        log_file.write(f'{desc}{sep}{acc}{newline}')

# Dataset import

In [4]:
class GoogleDriveDownloader(object):
    """
    Downloading a file stored on Google Drive by its URL.
    If the link is pointing to another resource, the redirect chain is being expanded.
    `download()` returns the output path.
    """

    # endpoint queried with the file id (and the confirmation token)
    base_url = 'https://docs.google.com/uc?export=download'
    # number of bytes requested per streamed chunk
    chunk_size = 32768

    def __init__(self, url, out_dir):
        """
        Args:
            url: link to the file; its last path component becomes the
                output file name, and the link is requested once to
                resolve redirects.
            out_dir: directory the file is stored in.
        """
        super().__init__()

        self.out_name = url.rsplit('/', 1)[-1]
        self.url = self._get_redirect_url(url)
        self.out_dir = out_dir

    @staticmethod
    def _get_redirect_url(url):
        """Requests `url` and returns the final URL of the response (or `url` itself)."""
        response = requests.get(url)
        # falls back to the input when the response carries no URL
        return response.url or url

    @staticmethod
    def _get_confirm_token(response):
        """Returns the value of the first 'download_warning*' cookie, or None."""
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value
        return None

    def _save_response_content(self, response):
        """
        Streams the response body to `self.fpath`, showing a progress bar.

        The data goes to a temporary '.part' file that is renamed only after
        the stream completes, so an interrupted download never leaves a
        truncated file that `download()` would later take for a finished one.
        """
        part_path = self.fpath + '.part'
        try:
            # the context manager closes the bar even when streaming fails
            with open(part_path, 'wb') as f, tqdm.tqdm(total=None) as bar:
                for chunk in response.iter_content(self.chunk_size):
                    if chunk:
                        f.write(chunk)
                        bar.update(len(chunk))
            os.replace(part_path, self.fpath)
        finally:
            # only present here when the download did not finish
            if os.path.exists(part_path):
                os.remove(part_path)

    @property
    def file_id(self):
        """Second-to-last path component of the URL (query string ignored)."""
        return self.url.split('?')[0].split('/')[-2]

    @property
    def fpath(self):
        """Full path of the output file."""
        return os.path.join(self.out_dir, self.out_name)

    def download(self):
        """
        Downloads the file unless it is already present and returns its path.

        Raises:
            RuntimeError: when the first response carries no
                'download_warning' cookie to confirm the download with.
        """
        os.makedirs(self.out_dir, exist_ok=True)

        if os.path.isfile(self.fpath):
            print('File is downloaded yet:', self.fpath)
        else:
            session = requests.Session()
            response = session.get(self.base_url, params={'id': self.file_id}, stream=True)
            token = self._get_confirm_token(response)

            if not token:
                raise RuntimeError(
                    'No download confirmation token received for file id '
                    '{!r} (resolved URL: {})'.format(self.file_id, self.url)
                )

            response = session.get(self.base_url, params={'id': self.file_id, 'confirm': token}, stream=True)
            self._save_response_content(response)

        return self.fpath


# download an archive containing the dataset and store it into the output directory
# (the downloader requests this URL once in its constructor to resolve redirects;
# the archive name is taken from the last path component of the URL)
url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
dl = GoogleDriveDownloader(url, 'data')
dl.download()

1150585339it [00:06, 168350795.91it/s]


'data/CUB_200_2011.tgz'

In [5]:
def extract_tgz(from_path, to_path=None, img_extention='.jpg'):
    """
    Extracts data from '.tgz' file and displays data statistics.

    Args:
        from_path: path to the gzip-compressed tar archive.
        to_path: directory to extract into; defaults to the directory that
            contains the archive.
        img_extention: file name suffix used to count images.

    Returns:
        The output directory name: `to_path` joined with the archive file
        name without its extension. Extraction is skipped when this
        directory already exists.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)
    archive_stem = os.path.splitext(os.path.basename(from_path))[0]
    out_dir = os.path.join(to_path, archive_stem)

    with tarfile.open(from_path, 'r:gz') as tar:
        already_extracted = os.path.isdir(out_dir)
        if already_extracted:
            print('Files are extracted yet.')
        else:
            print('Extracting files...')

        members = tar.getmembers()
        imgs = [m for m in members if m.name.endswith(img_extention)]
        # every distinct parent directory of an image counts as one class
        classes = {os.path.dirname(m.name) for m in imgs}
        print('\tClasses: {}\n\tImages: {}'.format(len(classes), len(imgs)))

        if not already_extracted:
            # NOTE: extractall trusts the member paths stored in the archive;
            # use it only with archives from a trusted source.
            tar.extractall(to_path, members=members)

    return out_dir


# extract the downloaded archive & assess data statistics
# (in_dir_data receives the directory name returned by extract_tgz)
in_dir_data = extract_tgz(from_path=dl.fpath)

Extracting files...
	Classes: 200
	Images: 11788


In [8]:
# 'images' subfolder of the extracted dataset; used below as the image root
in_dir_img = os.path.join(in_dir_data, 'images')

# Dataset exploration

## Image corruption check

In [9]:
def get_filepaths(path_to_data, fileformat='.jpg'):
    """
    Returns paths to files of the specified format.

    The directory tree under `path_to_data` is walked recursively and every
    file whose name ends with `fileformat` is reported, joined with the
    directory it was found in.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _, names in os.walk(path_to_data)
        for name in names
        if name.endswith(fileformat)
    ]


def cleaning_worker(path_to_img):
    """
    Checks one image for missing data.

    The image is read and the standard deviation of all of its values is
    computed; the image counts as OK unless that deviation is close to zero.
    Returns a pair `(img_ok, path_to_img)`.
    """
    pixels = mpimg.imread(path_to_img)
    has_variation = not np.isclose(np.std(pixels), 0.0)

    return has_variation, path_to_img


# check every image in parallel (one worker process per CPU core) and
# collect the paths of those whose pixel standard deviation is close to zero
imgs_corrupted = list()
with mp.Pool(processes=mp.cpu_count()) as pool:
    # results arrive in completion order, not in input order
    for img_ok, fn in pool.imap_unordered(cleaning_worker, get_filepaths(in_dir_img)):
        if not img_ok:
            imgs_corrupted.append(fn)

# report how many corrupted images (missing data) were found
print('Corrupted images #:', len(imgs_corrupted))

# clean up the images that aren't OK (left disabled)
# for fn in imgs_corrupted:
#    os.remove(fn)

Corrupted images #: 0


## Similar species

In [None]:
def plot_simmilar_species(s_name):
    """
    Plots sample images of classes whose folder name contains `s_name`.

    Up to five matching class folders under `in_dir_img` are shown, one per
    column, with five images per column. When a class has fewer than five
    images they are repeated cyclically. The number of matching classes and
    the selected folder names are printed. Nothing is plotted when no class
    matches.

    Args:
        s_name: substring looked up in the lower-cased class folder names.
    """
    max_species = 5
    row_count = 5

    # all class folders whose lower-cased name contains the requested substring
    species_total = [k for k in os.listdir(in_dir_img) if s_name in k.lower()]
    print("{} species of {} in dataset".format(len(species_total), s_name))

    species = species_total[:max_species]

    # file names found under each selected class folder
    imgs_by_species = dict()
    for dirname in species:
        imgs = list()
        for _, _, fn in os.walk(os.path.join(in_dir_img, dirname)):
            imgs.extend(fn)
        imgs_by_species[dirname] = imgs
    print(species)

    # a grid with zero columns cannot be created
    if not species:
        return

    # squeeze=False keeps `ax` two-dimensional even for a single column
    f, ax = plt.subplots(row_count, len(species), figsize=(20, 12), squeeze=False)
    f.patch.set_facecolor('white')

    for i, cls_name in enumerate(species):
        img_names = imgs_by_species[cls_name]
        ax[0, i].set_title(cls_name.split('.')[-1].replace('_', ' '), fontsize=15)
        for j in range(row_count):
            if not img_names:
                # empty class folder: leave the cell blank
                ax[j, i].axis('off')
                continue
            img_name = img_names[j % len(img_names)]
            path_img = os.path.join(in_dir_img, cls_name, img_name)
            ax[j, i].imshow(mpimg.imread(path_img))

    # layout is computed once, after all cells are filled
    plt.tight_layout()
    plt.show()

# alternative species groups to inspect (disabled; enable one at a time)
# plot_simmilar_species('sparrow')
# plot_simmilar_species('auklet')
# plot_simmilar_species('blackbird')
plot_simmilar_species('hummingbird')

## Size of images

In [None]:
# calculate image statistics (takes some time to complete)
# every dataset item is loaded to read its height and width
ds = tv.datasets.ImageFolder(in_dir_img)
shapes = [(img.height, img.width) for img, _ in ds]
heights, widths = [[h for h,_ in shapes], [w for _,w in shapes]]
# NOTE(review): the label says 'Average' but the printed values are the
# medians of the heights and of the widths
print('Average sizes:', *map(np.median, zip(*shapes)))

# visualize the distribution of the size of images
fig = plt.figure()
ax = fig.add_subplot(111)

# one box per dimension; the returned `bp` is not used afterwards in this cell
bp = ax.boxplot([heights, widths], patch_artist=True)

ax.set_xticklabels(['height', 'width'])
ax.set_xlabel('image sizes')
ax.set_ylabel('pixels')

plt.show()

## Average image

In [None]:
def pad(img, fill=0, size_max=500):
    """
    Pads an image up to `size_max` pixels in height and width.

    `img.shape[1]` is read as the height and `img.shape[2]` as the width.
    The missing pixels are split between the two opposite sides; when the
    amount is odd, the extra pixel goes to the bottom / right side. A
    dimension that is already `size_max` or larger is left unchanged.
    """
    missing_rows = max(0, size_max - img.shape[1])
    missing_cols = max(0, size_max - img.shape[2])

    top, left = missing_rows // 2, missing_cols // 2

    # padding order: left, top, right, bottom
    padding = (left, top, missing_cols - left, missing_rows - top)

    return TF.pad(img, padding, fill=fill)


# reload the dataset so that every image is returned as a tensor
ds = tv.datasets.ImageFolder(in_dir_img, transform=tv.transforms.ToTensor())

# average image: sum all padded images, then divide by the number of images
# NOTE(review): the accumulator is fixed at (3, 500, 500) and pad() never
# crops, so this assumes no image exceeds 500 pixels per side - TODO confirm
img_mean = np.zeros((3, 500, 500))
for img, _ in tqdm.tqdm(ds):
    img = pad(img)
    img_mean += img.numpy()

img_mean = img_mean / len(ds)

# plot average image (channel axis moved to the end for imshow)
plt.imshow(np.moveaxis(img_mean, 0, 2))
plt.show()

# Testing

In [None]:
# list the entries of the image folder in sorted order
sorted(os.listdir(in_dir_img))