In [1]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt 
import pandas as pd 

## Settings

In [2]:
data_dir = "./data/cat-dog"
batch_size = 4
img_edge = 512

## Utils

In [3]:
import random

def flip_lr(x): 
# horizontal flip transformation
    if random.random()>0.5:
        x = np.fliplr(x).copy()
#     return torch.from_numpy(x)
    return x

## Dataset

In [28]:
from torch.utils.data import Dataset
from skimage import io, img_as_float32
import os

class CatDogDataset(Dataset):
    def __init__(self, data_dir, is_train=False, transform=None):
        self.data_dir = data_dir
        self.is_train = is_train
        if is_train:
            self.data_dir = os.path.join(data_dir, "train")
        else:
            self.data_dir = os.path.join(data_dir, "test")
        self.img_ids = os.listdir(self.data_dir)
        self.transform = transform
            
    def __len__(self):
        return len(self.img_ids)
    
    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_p = os.path.join(self.data_dir, img_id)
        img = img_as_float32(io.imread(img_p)).transpose((2,0,1))
        if self.transform is not None:
            img = self.transform(img)
        
        out = {}
        out['name'] = img_id
        if self.is_train:
            out['img'] = img
            out['label'] = 0 if img_id.split("_")[1]=='cat' else 1
        else:
            out['img'] = img
            out['label'] = 0 if img_id.split("_")[1]=='cat' else 1
        
        return out

In [29]:
from torch.utils.data import DataLoader
dataset = CatDogDataset(data_dir, is_train=True, transform=flip_lr)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
for x_i, x in enumerate(dataloader):
    print(x['name'])
    print(x['label'])
    print(x['img'].shape)
    break

['flickr_dog_000042.jpg', 'flickr_dog_000061.jpg', 'flickr_dog_000010.jpg', 'flickr_dog_000039.jpg']
tensor([1, 1, 1, 1])
torch.Size([4, 3, 512, 512])


## Model

In [44]:
class ClassifyCatDog(nn.Module):
    def __init__(self):
        super(ClassifyCatDog, self).__init__()
        # Encoder
        self.encoder = nn.ModuleList()
        self.encoder.append(nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1))
        self.encoder.append(nn.ReLU())
        self.encoder.append(nn.MaxPool2d(8, 8))
        self.encoder.append(nn.Conv2d(4, 4, 3, stride=1, padding=1))
        self.encoder.append(nn.ReLU())
        self.encoder.append(nn.MaxPool2d(8, 8))
        # Decoder 1
        self.decoder_1 = nn.ModuleList()
        self.decoder_1.append(nn.Conv2d(4, 3, 3, stride=1, padding=1))
        self.decoder_1.append(nn.ReLU())
        self.decoder_1.append(nn.MaxPool2d(4, 4))
        # Decoder 2
        self.decoder_2 = nn.ModuleList()
        self.decoder_2.append(nn.Conv2d(4, 3, 3, stride=1, padding=1))
        self.decoder_2.append(nn.ReLU())
        self.decoder_2.append(nn.MaxPool2d(4, 4))
        # Classifier
        self.classifier = nn.ModuleList()
        self.classifier.append(nn.Linear(3*2*2, 4))
        self.classifier.append(nn.ReLU())
        self.classifier.append(nn.Linear(4, 2))
    def forward(self, x, d_idx):
        # n*512*512*3
        # Encoder
        for layer in self.encoder:
            x = layer(x)
        # Decoder
        if d_idx == 1:
            for layer in self.decoder_1:
                x = layer(x)
        elif d_idx == 2:
            for layer in self.decoder_2:
                x = layer(x)
        # Classifier
        x = x.view(-1, 3*2*2)
        for layer in self.classifier:
            x = layer(x)
        return x

discriminator = ClassifyCatDog()
if torch.cuda.is_available():
    discriminator = discriminator.cuda()

import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer_ec = optim.Adam(discriminator.encoder.parameters(), lr=0.1, betas=(0.9, 0.999))
optimizer_dc_1 = optim.Adam(discriminator.decoder_1.parameters(), lr=0.1, betas=(0.9, 0.999))
optimizer_dc_2 = optim.Adam(discriminator.decoder_2.parameters(), lr=0.1, betas=(0.9, 0.999))
optimizer_cl = optim.Adam(discriminator.classifier.parameters(), lr=0.1, betas=(0.9, 0.999))
optimizer_all = optim.Adam(discriminator.parameters(), lr=0.1, betas=(0.9, 0.999))

def printParam(m):
    for name, param in m.named_parameters():
        print(name, param.data)

for x_i, x in enumerate(dataloader):
    imgs   = x['img']
    labels = x['label']
    imgs   = imgs.cuda()
    labels = labels.cuda()
    
    y  = discriminator(imgs, x_i%2+1)
    loss = criterion(y, labels)
    loss.backward()
    
    print(loss)
    if x_i%2+1 == 1:
        print("decoder 1 pre:")
        printParam(discriminator.decoder_2)
        optimizer_all.step()
        optimizer_all.zero_grad()
        print("decoder 1 after:")
        printParam(discriminator.decoder_2)
        print("decoder 1 end.")
    else:
        print("decoder 2 pre:")
        printParam(discriminator.decoder_1)
        optimizer_all.step()
        optimizer_all.zero_grad()
        print("decoder 2 after:")
        printParam(discriminator.decoder_1)
        print("decoder 2 end.")
    
    optimizer_ec.step()
    optimizer_ec.zero_grad()
    
    if x_i == 1:
        break

tensor(0.6996, device='cuda:0', grad_fn=<NllLossBackward>)
decoder 1 pre:
0.weight tensor([[[[ 0.0018,  0.1102, -0.1208],
          [-0.1125, -0.1494, -0.0801],
          [ 0.1488, -0.0597,  0.0584]],

         [[-0.1351, -0.0158,  0.1244],
          [ 0.0504,  0.1350, -0.0552],
          [ 0.1050, -0.0951, -0.0045]],

         [[ 0.0207, -0.0632,  0.1141],
          [ 0.0381, -0.0514, -0.1076],
          [ 0.1202,  0.0731,  0.0537]],

         [[-0.0660,  0.1004,  0.1133],
          [ 0.0226,  0.1333, -0.0049],
          [ 0.0266,  0.0897, -0.0640]]],


        [[[-0.0937,  0.0133,  0.0261],
          [-0.1330,  0.1154, -0.1417],
          [-0.0073,  0.1626, -0.0954]],

         [[ 0.1268,  0.0268, -0.0202],
          [-0.0229,  0.0610, -0.1383],
          [ 0.1023, -0.0029,  0.0923]],

         [[ 0.0034,  0.1171, -0.1101],
          [ 0.0209,  0.1633,  0.0505],
          [ 0.0991, -0.0437,  0.0201]],

         [[ 0.1666,  0.0018, -0.1574],
          [ 0.1601, -0.1250, -0.1046],
    

In [16]:
print(discriminator.parameters())

<generator object Module.parameters at 0x7f07e7932950>


## Train loop

In [137]:
for epoch in range(10):
    for x_i, x in enumerate(dataloader):
        if x_i%10==9:
            print(x_i)

9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
9
19
29
39
