In [2]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt 
import pandas as pd 

## Settings

In [3]:
data_dir = "./data/cat-dog"
batch_size = 4

## Utils

In [4]:
import random

def flip_lr(x): 
# horizontal flip transformation
    if random.random()>0.5:
        x = np.fliplr(x).copy()
    return x

## Dataset

In [5]:
from torch.utils.data import Dataset
from skimage import io, img_as_float32
import os

class CatDogDataset(Dataset):
    def __init__(self, data_dir, is_train=False, transform=None):
        self.data_dir = data_dir
        self.is_train = is_train
        if is_train:
            self.data_dir = os.path.join(data_dir, "train")
        else:
            self.data_dir = os.path.join(data_dir, "test")
        self.img_ids = os.listdir(self.data_dir)
        self.transform = transform
            
    def __len__(self):
        return len(self.img_ids)
    
    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_p = os.path.join(self.data_dir, img_id)
        img = img_as_float32(io.imread(img_p)).transpose((2,0,1))
        if self.transform is not None:
            img = self.transform(img)
        
        out = {}
        out['name'] = img_id
        if self.is_train:
            out['img'] = img
            out['label'] = 0 if img_id.split("_")[1]=='cat' else 1
        else:
            out['img'] = img
            out['label'] = 0 if img_id.split("_")[1]=='cat' else 1
        
        return out

In [6]:
from torch.utils.data import DataLoader
dataset = CatDogDataset(data_dir, is_train=True, transform=flip_lr)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
for x_i, x in enumerate(dataloader):
    print(x['name'])
    print(x['label'])
    print(x['img'].shape)
    break

['flickr_dog_000013.jpg', 'flickr_cat_000054.jpg', 'flickr_dog_000044.jpg', 'flickr_dog_000002.jpg']
tensor([1, 0, 1, 1])
torch.Size([4, 3, 512, 512])


## Model

In [7]:
class ClassifyCatDog(nn.Module):
    def __init__(self):
        super(ClassifyCatDog, self).__init__()
        # Encoder
        self.encoder = nn.ModuleList()
        self.encoder.append(nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1))
        self.encoder.append(nn.ReLU())
        self.encoder.append(nn.MaxPool2d(8, 8))
        self.encoder.append(nn.Conv2d(4, 4, 3, stride=1, padding=1))
        self.encoder.append(nn.ReLU())
        self.encoder.append(nn.MaxPool2d(8, 8))
        # Decoder 1
        self.decoder_1 = nn.ModuleList()
        self.decoder_1.append(nn.Conv2d(4, 3, 3, stride=1, padding=1))
        self.decoder_1.append(nn.ReLU())
        self.decoder_1.append(nn.MaxPool2d(4, 4))
        # Decoder 2
        self.decoder_2 = nn.ModuleList()
        self.decoder_2.append(nn.Conv2d(4, 3, 3, stride=1, padding=1))
        self.decoder_2.append(nn.ReLU())
        self.decoder_2.append(nn.MaxPool2d(4, 4))
        # Classifier
        self.classifier = nn.ModuleList()
        self.classifier.append(nn.Linear(3*2*2, 4))
        self.classifier.append(nn.ReLU())
        self.classifier.append(nn.Linear(4, 2))
    def forward(self, x, d_idx):
        # n*512*512*3
        # Encoder
        for layer in self.encoder:
            x = layer(x)
        # Decoder
        if d_idx == 1:
            for layer in self.decoder_1:
                x = layer(x)
        elif d_idx == 2:
            for layer in self.decoder_2:
                x = layer(x)
        # Classifier
        x = x.view(-1, 3*2*2)
        for layer in self.classifier:
            x = layer(x)
        return x

discriminator = ClassifyCatDog()
if torch.cuda.is_available():
    discriminator = discriminator.cuda()

import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer_all = optim.Adam(discriminator.parameters(), lr=0.1, betas=(0.9, 0.999))

def printParam(m):
    for name, param in m.named_parameters():
        print(name, param.data)

for x_i, x in enumerate(dataloader):
    imgs   = x['img']
    labels = x['label']
    imgs   = imgs.cuda()
    labels = labels.cuda()
    
    y  = discriminator(imgs, x_i%2+1)
    loss = criterion(y, labels)
    loss.backward()
    
    print(loss)
    if x_i%2+1 == 1: # path go through decoder_1
#         print("decoder 1 pre:")
#         printParam(discriminator.decoder_2)
        optimizer_all.step()
        optimizer_all.zero_grad()
#         print("decoder 1 after:")
#         printParam(discriminator.decoder_2)
    else: # path go through decoder_2
        print("decoder 1 0:")
        printParam(discriminator.decoder_1)
        optimizer_all.step()
#         optimizer_all.zero_grad()
        print("decoder 1 1:")
        printParam(discriminator.decoder_1)
        optimizer_all.step()
        optimizer_all.zero_grad()
        print("decoder 1 2:")
        printParam(discriminator.decoder_1)
    
    optimizer_ec.step()
    optimizer_ec.zero_grad()
    
    if x_i == 1:
        break

tensor(0.7150, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(0.8057, device='cuda:0', grad_fn=<NllLossBackward>)
decoder 1 0:
0.weight tensor([[[[-0.1051,  0.1584,  0.0214],
          [-0.0344,  0.1017,  0.1629],
          [ 0.0131, -0.1128,  0.1452]],

         [[-0.0218, -0.1395, -0.0552],
          [-0.0219,  0.1536, -0.1387],
          [ 0.1242, -0.0277,  0.0546]],

         [[-0.0730, -0.1994,  0.1896],
          [-0.0542, -0.0989, -0.0320],
          [-0.0675,  0.0180, -0.1300]],

         [[-0.0088,  0.0611,  0.2087],
          [ 0.0054,  0.0527,  0.1134],
          [-0.2596, -0.2417, -0.0826]]],


        [[[-0.0318, -0.1031,  0.0638],
          [ 0.0963, -0.1222, -0.1077],
          [ 0.0555, -0.1407, -0.0565]],

         [[ 0.0491, -0.0867,  0.1140],
          [ 0.1107, -0.1234,  0.1190],
          [-0.1128,  0.1495,  0.0830]],

         [[ 0.1259, -0.0208, -0.1920],
          [ 0.2220,  0.2317,  0.0898],
          [-0.0436, -0.2050,  0.2264]],

         [[-0.1584, -0.18

In [8]:
print(discriminator.parameters())

<generator object Module.parameters at 0x7f1a26f59cd0>


## Train loop

In [137]:
for epoch in range(10):
    for x_i, x in enumerate(dataloader):
        if x_i%10==9:
            print(x_i)

9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
