In [1]:
from torch.utils.data import Dataset
from torch.autograd import Variable
from torchvision import transforms
from zutil.config import Config
from PIL import Image
import torchvision.models as models
import numpy as np
import torch.nn as nn
import torch
import os.path
import random
import json

In [6]:
class ArchDataset(Dataset):
    """Multi-label image dataset backed by files under ``static/dataset/``.

    Labels are read from ``handcraft_labels.txt`` (first space-separated field
    of every non-blank line).  Samples are read from ``data_pair.json``, one
    JSON object per non-blank line; this class reads the ``'poster'``,
    ``'title'`` and ``'images'`` keys of each object.  A label applies to a
    sample when it occurs as a substring of the sample's title; a sample that
    matches no label is assigned one extra class, ``len(self.all_labels)``.
    """

    def __init__(self, config):
        """Load label names and sample records and build the image transform.

        Args:
            config: object whose ``batch_size``, ``split_rate`` and ``gpu``
                attributes are read by the batching methods.
        """
        self.config = config
        fn_labels = 'static/dataset/handcraft_labels.txt'
        fn_datapair = 'static/dataset/data_pair.json'
        # Context managers close the files; blank lines are skipped so a
        # trailing newline cannot break parsing.
        with open(fn_labels, 'r') as label_file:
            rows = [line.strip().split(' ') for line in label_file if line.strip()]
        # Only the first field (the label text) is used; later fields are ignored.
        self.all_labels = [self._to_unicode(fields[0]) for fields in rows]
        self.n_classes = len(self.all_labels) + 1  # extra class
        with open(fn_datapair, 'r') as pair_file:
            self.data_pair = [json.loads(line) for line in pair_file if line.strip()]
        # `Scale` was renamed `Resize` in later torchvision releases; use
        # whichever one the installed version provides.
        resize = getattr(transforms, 'Resize', None) or getattr(transforms, 'Scale')
        self.img_transform = transforms.Compose([
            resize(100),
            transforms.CenterCrop(100),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _to_unicode(text):
        """Return ``text`` as unicode, decoding UTF-8 byte strings (Python 2 ``str``)."""
        return text.decode('utf8') if isinstance(text, bytes) else text

    def _load_image(self, poster, image_name):
        """Open ``static/dataset/<poster>/<image_name>`` as RGB and apply the transform."""
        img_path = 'static/dataset/' + os.path.join(poster, image_name)
        return self.img_transform(Image.open(img_path).convert('RGB'))

    def _stack_images(self, images, requires_grad):
        """Stack image tensors along a new first axis and wrap them in a Variable.

        The result is moved to GPU 0 when ``self.config.gpu`` is true.
        """
        batch = torch.cat([image.unsqueeze(0) for image in images], 0)
        batch = Variable(batch, requires_grad=requires_grad)
        if self.config.gpu:
            batch = batch.cuda(0)
        return batch

    def __len__(self):
        """Number of sample records."""
        return len(self.data_pair)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the record at ``index``.

        One of the record's images is picked at random on every call.
        ``target`` is a list of exactly ``self.n_classes`` ints: the indices of
        the matching labels followed by ``-1`` padding (the ``-1``-terminated
        layout that ``nn.MultiLabelMarginLoss`` targets use).
        """
        item = self.data_pair[index]
        # get image
        img_name = random.choice(item['images'])['image_name']
        image = self._load_image(item['poster'], img_name)
        # get target
        target = [i for i, label in enumerate(self.all_labels)
                  if label in item['title']]
        if not target:
            target.append(len(self.all_labels))  # extra class
        target.extend([-1] * (self.n_classes - len(target)))
        return image, target

    def __iter__(self):
        """Yield ``self[i]`` for every record, in order."""
        for i in range(len(self.data_pair)):
            yield self[i]

    def next_batch(self, mode='train'):
        """Endlessly yield random ``(images, targets)`` batches from one split.

        Records before ``int(len(self) * config.split_rate)`` form the
        ``'train'`` split, the remaining ones the ``'valid'`` split.  Indices
        are drawn with replacement, so the generator never terminates on its
        own; the caller decides when to stop.

        Raises:
            Exception: if ``mode`` is neither ``'train'`` nor ``'valid'``.
            ValueError: if the requested split contains no records.
        """
        total = len(self)
        cut_point = int(total * self.config.split_rate)
        if mode == 'train':
            low, high = 0, cut_point - 1
        elif mode == 'valid':
            low, high = cut_point, total - 1
        else:
            raise Exception('no such mode!')
        if low > high:
            raise ValueError('the %r split is empty (%d records, cut point %d)'
                             % (mode, total, cut_point))
        while True:
            batch_data = [self[random.randint(low, high)]
                          for _ in range(self.config.batch_size)]
            yield self._split_batch(batch_data)

    def _split_batch(self, batch_data, requires_grad=True):
        """Turn a list of ``(image, target)`` pairs into two batched Variables.

        Args:
            batch_data: list of ``(image_tensor, target_list)`` pairs.
            requires_grad: applied to the image batch only.  Targets are
                integer indices and never require gradients (recent PyTorch
                versions reject ``requires_grad=True`` on integer tensors).

        Returns:
            ``(images, targets)``, both on GPU 0 when ``config.gpu`` is true.
        """
        images = self._stack_images([image for image, _ in batch_data],
                                    requires_grad)
        targets = Variable(torch.LongTensor([target for _, target in batch_data]))
        if self.config.gpu:
            targets = targets.cuda(0)
        return images, targets

    def raw_image(self):
        """Yield every transformed image of the first ten records, in order."""
        for data in self.data_pair[0:10]:
            for image in data['images']:
                yield self._load_image(data['poster'], image['image_name'])

    def raw_batch(self):
        """Yield batches of the images produced by ``raw_image``.

        Each batch is an image Variable (no targets, ``requires_grad=False``)
        holding up to ``config.batch_size`` images; the final batch may be
        smaller.  The earlier implementation fed bare images to
        ``_split_batch``, which expects ``(image, target)`` pairs and therefore
        failed on the first batch.
        """
        batch_data = []
        for image in self.raw_image():
            batch_data.append(image)
            if len(batch_data) == self.config.batch_size:
                yield self._stack_images(batch_data, False)
                batch_data = []
        if len(batch_data) > 0:
            yield self._stack_images(batch_data, False)

# Hyper-parameters shared by the dataset, the model and the training loop.
config = Config(
    batch_size=50,
    gpu=True,
    split_rate=0.8,
    learning_rate=1e-4,
    save_freq=10,
    max_epoch=20,
)
dataset = ArchDataset(config)
# Single-argument print calls produce the same output on Python 2 and 3.
print(config)
print(dataset)

{'save_freq': 10, 'max_epoch': 20, 'learning_rate': 0.0001, 'batch_size': 50, 'split_rate': 0.8, 'gpu': True}
<__main__.ArchDataset object at 0x7f703cd6f750>


In [3]:
from zutil.convblock import ConvBlockModule
class CNN(nn.Module):
    """Convolutional feature extractor followed by a linear multi-label head.

    Sub-modules are registered in the order ``conv``, ``fc``, ``classifier``;
    code that slices ``model.children()`` depends on that order.
    """

    def __init__(self, config, n_classes=150):
        """Build the network.

        Args:
            config: stored on the module as ``self.config``.
            n_classes: number of output scores produced by the classifier
                (defaults to the previously hard-coded 150).
        """
        super(CNN, self).__init__()
        self.config = config
        self.conv = ConvBlockModule(dims=[3, 16, 32, 64, 64])
        # NOTE(review): 1024 must equal the flattened size of the conv output
        # for the input resolution in use -- confirm if the crop size changes.
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(512, n_classes)

    def forward(self, images):
        """Return one row of ``n_classes`` scores per image in the batch."""
        output = self.conv(images)
        # Flatten using the actual batch dimension rather than
        # config.batch_size, so batches of any size work.
        output = output.view(output.size(0), -1)
        output = self.fc(output)
        output = self.classifier(output)
        return output
  

In [9]:
# Build the network, moving it to GPU 0 when the configuration asks for it.
model = CNN(config)
model = model.cuda(0) if config.gpu else model
print(model)

# Alternatives considered: MultiLabelSoftMarginLoss, KLDivLoss.
criterion = nn.MultiLabelMarginLoss()
print(criterion)

optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

CNN (
  (conv): ConvBlockModule (
    (basic_1_conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
    (basic_1_batchnorm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
    (basic_1_relu): ReLU ()
    (basic_1_maxpool): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (basic_2_conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (basic_2_batchnorm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
    (basic_2_relu): ReLU ()
    (basic_2_maxpool): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (basic_3_conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (basic_3_batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (basic_3_relu): ReLU ()
    (basic_3_maxpool): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (basic_4_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (basic_4_batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (basic_4_relu): ReLU ()
    (basic_4_maxpool):

In [10]:
def train(epoch, n_batches=12):
    """Run a fixed number of optimisation steps and return the mean loss.

    Uses the module-level ``dataset``, ``model``, ``criterion`` and
    ``optimizer``.  ``dataset.next_batch`` never ends by itself, so the number
    of batches is what bounds one "epoch".

    Args:
        epoch: index of the current epoch (not used in the computation).
        n_batches: number of training batches to process; the default of 12
            matches the previously hard-coded cut-off.  At least one batch is
            always processed.

    Returns:
        Mean of the per-batch training losses.
    """
    epoch_train_loss = []
    for batchid, (images, targets) in enumerate(dataset.next_batch(mode='train')):
        outputs = model(images)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # float() yields a plain Python number whether .mean() returns a float
        # or a 0-dim tensor, so np.array below always gets numbers.
        epoch_train_loss.append(float(loss.data.mean()))

        if batchid + 1 >= n_batches:
            break
    return np.array(epoch_train_loss).mean()

def validate(epoch, n_batches=5):
    """Evaluate the model on validation batches and return the mean loss.

    Uses the module-level ``dataset``, ``model`` and ``criterion``.  No
    parameters are updated here.

    Args:
        epoch: index of the current epoch (not used in the computation).
        n_batches: number of validation batches to evaluate; the default of 5
            matches the previously hard-coded cut-off.  At least one batch is
            always evaluated.

    Returns:
        Mean of the per-batch validation losses.
    """
    epoch_valid_loss = []
    for batchid, (images, targets) in enumerate(dataset.next_batch(mode='valid')):
        outputs = model(images)
        loss = criterion(outputs, targets)
        # float() yields a plain Python number whether .mean() returns a float
        # or a 0-dim tensor, so np.array below always gets numbers.
        epoch_valid_loss.append(float(loss.data.mean()))

        if batchid + 1 >= n_batches:
            break
    return np.array(epoch_valid_loss).mean()

def do_loop():
    """Alternate training and validation for ``config.max_epoch`` epochs.

    Prints one summary line per epoch and saves the whole model to ``cnn.pt``
    every ``config.save_freq`` epochs.
    """
    report = 'epoch = %d, train_loss = %.3f, valid_loss = %.3f'
    for epoch in range(config.max_epoch):
        # Switch to training mode before the optimisation steps ...
        model.train()
        train_loss = train(epoch)
        # ... and to evaluation mode before measuring the validation loss.
        model.eval()
        valid_loss = validate(epoch)
        print(report % (epoch, train_loss, valid_loss))

        checkpoint_due = (epoch + 1) % config.save_freq == 0
        if checkpoint_due:
            torch.save(model, 'cnn.pt')

# Run the train/validate loop; one summary line is printed per epoch.
do_loop()

epoch = 0, train_loss = 2.020, valid_loss = 2.023
epoch = 1, train_loss = 1.807, valid_loss = 1.932
epoch = 2, train_loss = 1.636, valid_loss = 1.675
epoch = 3, train_loss = 1.400, valid_loss = 1.509
epoch = 4, train_loss = 1.258, valid_loss = 1.363
epoch = 5, train_loss = 1.202, valid_loss = 1.230
epoch = 6, train_loss = 1.110, valid_loss = 1.158
epoch = 7, train_loss = 1.071, valid_loss = 1.053
epoch = 8, train_loss = 0.959, valid_loss = 1.038
epoch = 9, train_loss = 0.957, valid_loss = 0.990
epoch = 10, train_loss = 0.895, valid_loss = 0.982
epoch = 11, train_loss = 0.897, valid_loss = 0.999
epoch = 12, train_loss = 0.849, valid_loss = 0.866
epoch = 13, train_loss = 0.881, valid_loss = 0.892
epoch = 14, train_loss = 0.862, valid_loss = 0.965
epoch = 15, train_loss = 0.844, valid_loss = 0.909
epoch = 16, train_loss = 0.847, valid_loss = 0.926
epoch = 17, train_loss = 0.827, valid_loss = 0.744
epoch = 18, train_loss = 0.836, valid_loss = 0.851
epoch = 19, train_loss = 0.881, valid_los

KeyboardInterrupt: 

In [4]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
# The last child module of the loaded network is dropped and the remaining
# ones are chained in a Sequential.
model = nn.Sequential(*list(torch.load('resnet.pt').children())[:-1])
# No explicit .cuda() call is made here.
print(model)

Sequential (
  (0): ConvBlockModule (
    (basic_1_conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
    (basic_1_batchnorm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
    (basic_1_relu): ReLU ()
    (basic_1_maxpool): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (basic_2_conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (basic_2_batchnorm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
    (basic_2_relu): ReLU ()
    (basic_2_maxpool): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (basic_3_conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (basic_3_batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (basic_3_relu): ReLU ()
    (basic_3_maxpool): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (basic_4_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (basic_4_batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (basic_4_relu): ReLU ()
    (basic_4_maxpo

In [7]:
# dataset.next_batch() never terminates, so the loop is bounded explicitly:
# roughly one pass over the training split (at least one batch).  Previously
# the cell ran, and `outputs` grew, until the kernel was interrupted.
n_feature_batches = max(1, int(len(dataset) * config.split_rate) // config.batch_size)

# NOTE(review): model[1] contains BatchNorm1d; confirm whether model.eval()
# should be called before extracting features.
outputs = []
for batchid, (images, targets) in enumerate(dataset.next_batch(mode='train')):
    features = model[0](images)
    # Flatten with the actual batch size instead of a hard-coded 50.
    output = model[1](features.view(features.size(0), -1))
    output = output.data.cpu()
    outputs.append(output)
    # Single formatted argument: same text on Python 2 and 3.
    print('%s %s %s' % (batchid, output.size(), output.mean()))
    if batchid + 1 >= n_feature_batches:
        break

0 torch.Size([50, 512]) 0.207776700285
1 torch.Size([50, 512]) 0.204154449434
2 torch.Size([50, 512]) 0.20391350588
3 torch.Size([50, 512]) 0.210842076445
4 torch.Size([50, 512]) 0.209852579185
5 torch.Size([50, 512]) 0.206684084371
6 torch.Size([50, 512]) 0.203037867238
7 torch.Size([50, 512]) 0.212064444954
8 torch.Size([50, 512]) 0.204439554277
9 torch.Size([50, 512]) 0.204937258174
10 torch.Size([50, 512]) 0.206845387965
11 torch.Size([50, 512]) 0.21078286467
12 torch.Size([50, 512]) 0.201878118938
13 torch.Size([50, 512]) 0.204403856673
14 torch.Size([50, 512]) 0.212976426154
15 torch.Size([50, 512]) 0.210995226955
16 torch.Size([50, 512]) 0.202853717839
17 torch.Size([50, 512]) 0.204087724868
18 torch.Size([50, 512]) 0.205268760841
19 torch.Size([50, 512]) 0.213421289502
20 torch.Size([50, 512]) 0.209668488791
21 torch.Size([50, 512]) 0.211878580473
22 torch.Size([50, 512]) 0.206474140396
23 torch.Size([50, 512]) 0.209425104698
24 torch.Size([50, 512]) 0.210581827986
25 torch.Siz

KeyboardInterrupt: 