In [2]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt 
import pandas as pd 

## Settings

In [3]:
data_dir = "./data/cat-dog"
batch_size = 4
img_edge = 512

## Utils

In [4]:
import random

def flip_lr(x): 
# horizontal flip transformation
    if random.random()>0.5:
        x = np.fliplr(x).copy()
#     return torch.from_numpy(x)
    return x

## Dataset

In [5]:
from torch.utils.data import Dataset
from skimage import io, img_as_float32
import os

class CatDogDataset(Dataset):
    def __init__(self, data_dir, is_train=False, transform=None):
        self.data_dir = data_dir
        self.is_train = is_train
        if is_train:
            self.data_dir = os.path.join(data_dir, "train")
        else:
            self.data_dir = os.path.join(data_dir, "test")
        self.img_ids = os.listdir(self.data_dir)
        self.transform = transform
            
    def __len__(self):
        return len(self.img_ids)
    
    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_p = os.path.join(self.data_dir, img_id)
        img = img_as_float32(io.imread(img_p)).transpose((2,0,1))
        if self.transform is not None:
            img = self.transform(img)
        
        out = {}
        out['name'] = img_id
        if self.is_train:
            out['img'] = img
            out['label'] = 0 if img_id.split("_")[1]=='cat' else 1
        else:
            out['img'] = img
            out['label'] = 0 if img_id.split("_")[1]=='cat' else 1
        
        return out

In [6]:
from torch.utils.data import DataLoader
dataset = CatDogDataset(data_dir, is_train=True, transform=flip_lr)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
for x_i, x in enumerate(dataloader):
    print(x['name'])
    print(x['label'])
    print(x['img'].shape)
    break

['flickr_cat_000083.jpg', 'flickr_cat_000002.jpg', 'flickr_dog_000073.jpg', 'flickr_dog_000085.jpg']
tensor([0, 0, 1, 1])
torch.Size([4, 3, 512, 512])


## Model

In [21]:
class ClassifyCatDog(nn.Module):
    def __init__(self):
        super(ClassifyCatDog, self).__init__()
        # Encoder
        self.encoder = nn.ModuleList()
        self.encoder.append(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1))
        self.encoder.append(nn.ReLU())
        self.encoder.append(nn.Conv2d(64, 64, 3, stride=1, padding=1))
        self.encoder.append(nn.ReLU())
        self.encoder.append(nn.MaxPool2d(2, 2))
        # Decoder 1
        self.decoder_1 = nn.ModuleList()
        self.decoder_1.append(nn.Conv2d(64, 3, 3, stride=1, padding=1))
        self.decoder_1.append(nn.ReLU())
        self.decoder_1.append(nn.Conv2d(3, 3, 3, stride=1, padding=1))
        self.decoder_1.append(nn.ReLU())
        # Decoder 2
        self.decoder_2 = nn.ModuleList()
        self.decoder_2.append(nn.Conv2d(64, 8, 3, stride=1, padding=1))
        self.decoder_2.append(nn.ReLU())
        self.decoder_2.append(nn.Conv2d(8, 1, 3, stride=1, padding=1))
        self.decoder_2.append(nn.ReLU())
    def forward(self, x, d_idx):
    # n*512*512*3
        for layer in self.encoder:
            x = layer(x)
        if d_idx == 1:
            for layer in self.decoder_1:
                x = layer(x)
        elif d_idx == 2:
            for layer in self.decoder_2:
                x = layer(x)
        return x

discriminator = ClassifyCatDog()
if torch.cuda.is_available():
    discriminator.cuda()

for x_i, x in enumerate(dataloader):
    x.cuda()
    print(x['img'].shape, discriminator(x['img'], 1).shape)
    break

AttributeError: 'dict' object has no attribute 'cuda'

## Train loop

In [137]:
for epoch in range(10):
    for x_i, x in enumerate(dataloader):
        if x_i%10==9:
            print(x_i)

9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
