# Bird Classification using transfer learning

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np

In [5]:
# Root of the bird dataset; the cells below read its 'train', 'test' and
# 'valid' sub-directories. Set the BIRDS_DATASET_DIR environment variable to
# point elsewhere instead of editing this machine-specific default path.
DATASET_DIR = os.environ.get('BIRDS_DATASET_DIR',
                             '/media/grzetan/445C33B25C339E1C/datasets/birds')

class BirdDataset(Dataset):
    """Image-folder dataset with one sub-directory per bird class.

    Expects ``root_dir/<class_name>/<image_file>``. Class indices come from the
    *sorted* class-directory names, so a given class gets the same label in
    every split directory and on every filesystem. The sample order is
    shuffled once, reproducibly, from ``random_state``.

    Args:
        root_dir: Directory whose sub-directories are the classes.
        random_state: Seed for the one-off shuffle of the sample order.
        transform: Optional callable applied to each ``(image, label)`` tuple.
    """

    def __init__(self, root_dir, random_state=42, transform=None):
        self.root = root_dir
        # os.listdir order is arbitrary and filesystem-dependent; sorting makes
        # label indices reproducible and consistent across train/test/valid.
        # Non-directory entries at the top level (e.g. .DS_Store) are skipped.
        self.classes = sorted(
            entry for entry in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, entry))
        )
        paths = []
        labels = []
        for label, cls in enumerate(self.classes):
            cls_dir = os.path.join(self.root, cls)
            for fname in sorted(os.listdir(cls_dir)):
                paths.append(os.path.join(cls_dir, fname))
                labels.append(label)
        self.paths = np.array(paths, dtype=str)
        self.labels = np.array(labels, dtype=np.int64)
        # A private RandomState gives the same permutation as seeding numpy's
        # legacy global RNG, without mutating global random state.
        idx = np.random.RandomState(random_state).permutation(len(self.paths))
        self.paths = self.paths[idx]
        self.labels = self.labels[idx]
        self.transform = transform

    def __len__(self):
        """Return the number of samples (needed by DataLoader's default sampler)."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(RGB image array, label)`` and apply ``transform``."""
        label = self.labels[idx]
        # The context manager closes the file handle; convert('RGB') turns
        # grayscale/RGBA/palette images into 3 channels for the transforms.
        with Image.open(self.paths[idx]) as im:
            img = np.asarray(im.convert('RGB'))

        sample = (img, label)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

## Transforms

In [76]:
class ToTensor(object):
    """Convert an ``(image array, label)`` sample into ``(float32 tensor, label tensor)``.

    The image (``H x W x C``, or ``H x W`` for grayscale) is scaled from
    [0, 255] to [0, 1] and reordered to channels-first ``C x H x W``. It is
    produced as float32 to match PyTorch's default parameter dtype; dividing a
    uint8 array by 255 would otherwise yield float64 and a dtype mismatch in
    the model's forward pass. The label becomes a 0-d tensor.
    """

    def __call__(self, sample):
        img, label = sample
        img = np.asarray(img, dtype=np.float32) / 255
        if img.ndim == 2:
            # Grayscale H x W: add a trailing channel axis so the transpose works.
            img = img[:, :, None]
        # HWC -> CHW; make it contiguous rather than handing torch a strided view.
        img = np.ascontiguousarray(img.transpose((2, 0, 1)))
        return (torch.from_numpy(img), torch.tensor(label))
    
class Normalize(object):
    """Standardise an image tensor channel by channel, passing the label through.

    Args:
        mean: Per-channel values subtracted from the image.
        std: Per-channel values the centred image is divided by.
    """

    def __init__(self, mean, std):
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def __call__(self, sample):
        image, target = sample
        # Add two trailing singleton axes so the per-channel stats broadcast
        # over the spatial dimensions of a channels-first image.
        centre = self.mean[:, None, None]
        scale = self.std[:, None, None]
        normalised = (image - centre) / scale
        return normalised, target

In [77]:
# Channel-wise RGB statistics shared by every split
# (presumably measured on the training images — TODO confirm).
normalize = Normalize(mean=[0.4451, 0.4262, 0.3959], std=[0.2411, 0.2403, 0.2466])

# Build one dataset per split directory, in train/test/valid order, each with
# its own ToTensor -> Normalize pipeline.
train_set, test_set, val_set = (
    BirdDataset(os.path.join(DATASET_DIR, split),
                transform=transforms.Compose([ToTensor(), normalize]))
    for split in ('train', 'test', 'valid')
)