In [0]:
# Environment setup (shell commands run through IPython `!`).
# PyTorch is pinned to the 0.3.0 CUDA 8 / CPython 3.6 wheel; later cells use 0.3-era
# idioms such as Variable(..., volatile=True) and loss.data[0].
!pip3 install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl 
# NOTE(review): torchvision is not pinned; pip may select a release that pulls in a
# different torch than the wheel above -- confirm and pin a matching version.
!pip3 install torchvision
# !pip3 install --no-cache-dir -I pillow
# Fetch and build warp-ctc plus its PyTorch binding.
# NOTE(review): none of the visible cells import warp-ctc -- confirm it is still required.
!git clone https://github.com/SeanNaren/warp-ctc.git
# NOTE(review): no `-y` flag, so apt may stop at a confirmation prompt -- verify in the runtime.
!apt install cmake
# `mkdir build` is followed by `;` rather than `&&`, so the build proceeds even when
# the directory already exists (e.g. on a re-run of this cell).
!cd warp-ctc && mkdir build; cd build && cmake .. && make
!cd warp-ctc/pytorch_binding && python setup.py install

In [0]:
from google.colab import files
import os.path

# Ask for the training archive only when ./train is missing, then unpack it.
# files.upload() is interactive; the tar command assumes the uploaded file is
# named train.tar.gz and unpacks into a `train` directory (TODO confirm archive layout).
if not os.path.exists('train'):
    files.upload()
    !tar xzvf train.tar.gz
    
# Same procedure for the test archive / ./test directory.
if not os.path.exists('test'):
    files.upload()
    !tar xzvf test.tar.gz

In [0]:
import math
import torch


def meshgrid(x, y):
    """Enumerate every (column, row) pair of an x-wide, y-tall grid.

    Returns a tensor with x * y rows and 2 columns. Rows are ordered row by
    row: the column index runs fastest, the row index changes every x entries.
    """
    cols = torch.arange(0, x)
    rows = torch.arange(0, y)
    # Broadcast both index vectors to a (y, x) grid, then flatten row-major.
    grid_x = cols.unsqueeze(0).expand(y, x).contiguous().view(-1, 1)
    grid_y = rows.unsqueeze(1).expand(y, x).contiguous().view(-1, 1)
    return torch.cat((grid_x, grid_y), 1)


def xywh2xyxy(boxes):
    """Convert (x_center, y_center, width, height) boxes to (xmin, ymin, xmax, ymax).

    The last dimension of `boxes` holds the four coordinates; any leading
    dimensions are preserved.
    """
    center = boxes[..., :2]
    half_size = boxes[..., 2:] / 2
    return torch.cat((center - half_size, center + half_size), -1)


def xyxy2xywh(boxes):
    """Convert [N, 4] (xmin, ymin, xmax, ymax) boxes to (x_center, y_center, width, height).

    Width and height are computed as `max - min + 1`, i.e. both corners are
    counted as part of the box. Note that xywh2xyxy does not subtract that
    extra 1, so the two conversions are not exact inverses of each other.
    """
    top_left = boxes[:, :2]
    bottom_right = boxes[:, 2:]
    centers = (top_left + bottom_right) / 2
    sizes = bottom_right - top_left + 1
    return torch.cat((centers, sizes), 1)


def box_iou(box1, box2):
    """Pairwise intersection-over-union between two sets of corner-format boxes.

    Args:
        box1: tensor [N, 4] of (xmin, ymin, xmax, ymax).
        box2: tensor [M, 4] of (xmin, ymin, xmax, ymax).
    Returns:
        tensor [N, M]; entry (i, j) is the IoU of box1[i] and box2[j].

    Widths and heights use the inclusive `max - min + 1` convention.
    """
    def _area(boxes):
        return (boxes[..., 2] - boxes[..., 0] + 1) * (boxes[..., 3] - boxes[..., 1] + 1)

    # Broadcasting box1 against box2 yields one overlap rectangle per pair.
    top_left = torch.max(box1[..., None, :2], box2[:, :2])      # N, M, 2
    bottom_right = torch.min(box1[..., None, 2:], box2[:, 2:])  # N, M, 2

    # Disjoint pairs produce a negative extent, which is clamped to zero.
    overlap_wh = (bottom_right - top_left + 1).clamp(min=0)
    inter = overlap_wh[:, :, 0] * overlap_wh[:, :, 1]            # N, M
    union = _area(box1)[:, None] + _area(box2) - inter
    return inter / union


def box_nms(bboxes, scores, thres=0.5):
    """Greedy non-maximum suppression.

    Args:
        bboxes: tensor [N, 4] of (xmin, ymin, xmax, ymax).
        scores: tensor [N] of confidence scores.
        thres: a box is dropped when its IoU with an already kept,
            higher-scoring box is strictly greater than this value.
    Returns:
        LongTensor with the indices (into `bboxes`) of the kept boxes,
        ordered by decreasing score.
    """
    x1 = bboxes[:, 0]
    y1 = bboxes[:, 1]
    x2 = bboxes[:, 2]
    y2 = bboxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    _, order = scores.sort(0, descending=True)

    keep = []
    while order.numel() > 0:
        # Store a plain int so `keep` converts cleanly to a LongTensor whether
        # indexing returns a Python number or a 0-dim tensor.
        i = int(order[0])
        keep.append(i)
        if order.numel() == 1:
            break
        rest = order[1:]

        # Intersection of the current best box with every remaining candidate.
        xx1 = x1[rest].clamp(min=float(x1[i]))
        yy1 = y1[rest].clamp(min=float(y1[i]))
        xx2 = x2[rest].clamp(max=float(x2[i]))
        yy2 = y2[rest].clamp(max=float(y2[i]))

        w = (xx2 - xx1 + 1).clamp(min=0)
        h = (yy2 - yy1 + 1).clamp(min=0)
        inter = w * h

        ovr = inter / (areas[i] + areas[rest] - inter)
        ids = (ovr <= thres).nonzero()
        if ids.numel() == 0:
            break
        # Flatten with view(-1) instead of squeeze(): squeeze() collapses a
        # single surviving index to a 0-dim tensor, after which `order` is
        # 0-dim too and `order[0]` fails on the next iteration.
        order = rest[ids.view(-1)]
    return torch.LongTensor(keep)


class DataEncoder():
    """Encode ground-truth boxes into per-anchor targets and decode predictions.

    Anchors live on a single feature map whose width and height are the input
    size divided by 4 (rounded up). Every cell carries one anchor per
    (area, aspect ratio) combination, and every anchor is described by five
    values: [confidence, x, y, w, h].
    """

    def __init__(self):
        # Anchor areas in squared pixels, all placed on the one feature map.
        self.anchor_areas = [16*16.0, 32*32.0, 64*64.]
        # Width / height ratios, so every anchor is wider than it is tall.
        self.aspect_ratios = [2/1., 8/1., 16/1.]
        self.anchor_wh = self._get_anchor_wh()

    def _get_anchor_wh(self):
        """Return a [#anchor, 2] tensor of (width, height), one row per anchor."""
        anchor_wh = []
        for s in self.anchor_areas:
            for ar in self.aspect_ratios:
                h = math.sqrt(s / ar)
                w = ar * h
                anchor_wh.append([w, h])
        # -1 rather than a literal 9 keeps this valid if the lists above change.
        return torch.Tensor(anchor_wh).view(-1, 2)

    def _get_anchor_boxes(self, input_size):
        """Return every anchor as (x_center, y_center, w, h).

        Args:
            input_size: tensor holding (W, H) of the network input.
        Returns:
            tensor [fm_h * fm_w * #anchor, 4], ordered row, column, anchor.
        """
        num_anchors = self.anchor_wh.size(0)
        fm_size = (input_size / pow(2, 2)).ceil()

        grid_size = input_size / fm_size
        fm_w, fm_h = int(fm_size[0]), int(fm_size[1])
        # +0.5 puts each anchor at the centre of its feature-map cell.
        xy = meshgrid(fm_w, fm_h) + 0.5
        xy = (xy * grid_size).view(fm_h, fm_w, 1, 2).expand(fm_h, fm_w, num_anchors, 2)
        wh = self.anchor_wh.view(1, 1, num_anchors, 2).expand(fm_h, fm_w, num_anchors, 2)
        box = torch.cat([xy, wh], 3)  # [x, y, w, h]
        return box.view(-1, 4)

    def encode(self, boxes, input_size):
        '''
        Args:
            boxes: tensor [#box, [xmin, ymin, xmax, ymax]]
            input_size: (W, H)
        Returns:
            loc_targets: tensor [#anchor(9) * [confidence, xcenter, ycenter, width, height], FH, FW]
        '''
        num_anchors = self.anchor_wh.size(0)
        fm_size = [math.ceil(i / pow(2, 2)) for i in input_size]
        input_size = torch.Tensor(input_size)
        anchor_boxes = self._get_anchor_boxes(input_size)

        ious = box_iou(xywh2xyxy(anchor_boxes), boxes)
        boxes = xyxy2xywh(boxes)

        # Every anchor is paired with the ground-truth box it overlaps most.
        max_ious, max_ids = ious.max(1)
        boxes = boxes[max_ids]

        # Offsets are expressed relative to the anchor size; sizes as log-ratios.
        loc_xy = (boxes[:, :2] - anchor_boxes[:, :2]) / anchor_boxes[:, 2:]
        loc_wh = torch.log(boxes[:, 2:] / anchor_boxes[:, 2:])
        loc_targets = torch.cat([loc_xy, loc_wh], 1)

        # Confidence target: 1 where the best IoU reaches 0.5, otherwise 0.
        masks = torch.ones(max_ids.size())
        masks[max_ious < 0.5] = 0
        # masks[(max_ious > 0.3) & (max_ious < 0.7)] = -1

        loc_targets = loc_targets.contiguous().view(fm_size[1], fm_size[0], num_anchors, 4)
        masks = masks.contiguous().view(fm_size[1], fm_size[0], num_anchors, 1)
        return torch.cat((masks, loc_targets), 3).view(fm_size[1], fm_size[0], num_anchors * 5).permute(2, 0, 1)

    def decode(self, loc_preds, input_size, conf_thres=0.5, nms_thres=0.5):
        """Turn raw network output for one image into final boxes.

        Args:
            loc_preds: tensor [#anchor * 5, FH, FW] of raw (pre-sigmoid) values.
            input_size: (W, H) of the network input.
            conf_thres: keep anchors whose sigmoid confidence exceeds this.
            nms_thres: IoU threshold forwarded to box_nms.
        Returns:
            tensor [#kept, 4] of (xmin, ymin, xmax, ymax), or None when no
            anchor passes `conf_thres`.
        """
        input_size = torch.Tensor(input_size)
        anchor_boxes = self._get_anchor_boxes(input_size)

        loc_preds = loc_preds.permute(1, 2, 0).contiguous().view(-1, 5)

        conf_preds = loc_preds[:, 0]
        # Tensor method instead of F.sigmoid: this class should not depend on
        # an alias that is only imported by a later notebook cell.
        loc_xy = loc_preds[:, 1:3].sigmoid()
        loc_wh = loc_preds[:, 3:]

        xy = loc_xy * anchor_boxes[:, 2:] + anchor_boxes[:, :2]
        wh = loc_wh.exp() * anchor_boxes[:, 2:]
        boxes = torch.cat([xy - wh / 2, xy + wh / 2], 1)

        score = conf_preds.sigmoid()
        ids = (score > conf_thres).nonzero()
        if ids.numel() == 0:
            return None
        # view(-1) keeps `ids` one-dimensional even for a single detection;
        # squeeze() would make it 0-dim, which breaks len() and the indexing below.
        ids = ids.view(-1)
        keep = box_nms(boxes[ids], score[ids], thres=nms_thres)
        return boxes[ids][keep]

import os
import json

import torch
import torch.utils.data as data
from PIL import Image


class ListDataset(data.Dataset):
    """Detection dataset described by numbered JSON files (0.json, 1.json, ...).

    Each JSON file provides a 'file' entry (image name relative to `root`) and
    a 'boxes' list whose items have 'left', 'top', 'width' and 'height'.
    """

    def __init__(self, root, transform, input_size=(300, 200)):
        """
        Args:
            root: directory containing the JSON files and the images.
            transform: callable applied to each PIL image in __getitem__.
            input_size: (W, H) that every transformed image is expected to
                have; defaults to the previously hard-coded 300 x 200.
        """
        self.root = root
        self.transform = transform
        self.encoder = DataEncoder()
        self.input_size = list(input_size)  # W, H

        self.fnames = []
        self.boxes = []

        # Files are read in order until the first missing index.
        i = 0
        while True:
            path = os.path.join(self.root, f'{i}.json')
            if not os.path.isfile(path):
                break
            with open(path, 'r') as fp:
                info = json.load(fp)
            self.fnames.append(info['file'])
            self.boxes.append(torch.Tensor(self._corner_boxes(info['boxes'])))
            i += 1

    @staticmethod
    def _corner_boxes(raw_boxes):
        """Convert left/top/width/height dicts to [xmin, ymin, xmax, ymax] lists."""
        corners = []
        for b in raw_boxes:
            xmin = float(b['left'])
            ymin = float(b['top'])
            corners.append([xmin, ymin, xmin + float(b['width']), ymin + float(b['height'])])
        return corners

    def __getitem__(self, idx):
        """Return (transformed RGB image, [#box, 4] corner boxes) for sample `idx`."""
        fname = self.fnames[idx]
        # convert() forces the pixel data to be read, so the file handle can
        # be released as soon as the `with` block ends.
        with Image.open(os.path.join(self.root, fname)) as raw:
            img = raw.convert('RGB')

        boxes = self.boxes[idx].clone()
        img = self.transform(img)
        return img, boxes

    def collate_fn(self, batch):
        """Stack images and encode each sample's boxes into anchor targets.

        Returns:
            inputs: tensor [#batch, 3, H, W].
            targets: tensor stacking the encoder output of every sample.
        """
        imgs = [sample[0] for sample in batch]
        boxes = [sample[1] for sample in batch]

        w, h = self.input_size
        n_imgs = len(imgs)
        inputs = torch.zeros(n_imgs, 3, h, w)
        loc_targets = []
        for i in range(n_imgs):
            inputs[i] = imgs[i]
            loc_targets.append(self.encoder.encode(boxes[i], self.input_size))

        return inputs, torch.stack(loc_targets)

    def __len__(self):
        return len(self.fnames)


In [0]:
import os

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

%matplotlib inline

# Sanity check: encode the ground truth of the first test sample into anchor
# targets, decode those targets again and draw the resulting boxes.
print("Encode and Decode test")

# ToTensor followed by per-channel normalisation with fixed mean / std values.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])
testset = ListDataset(root='test', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=0, collate_fn=testset.collate_fn)

# The decoder is fed the *targets*, not network output. decode() passes the
# confidence channel through a sigmoid, so a target of 1 becomes ~0.73 (kept)
# and a target of 0 becomes exactly 0.5 (not above the 0.5 threshold).
img, loc_targets = next(iter(testloader))
boxes = testset.encoder.decode(loc_targets[0], [300, 200])

# NOTE(review): decode() returns None when no anchor passes the threshold, in
# which case the loop below raises; this also assumes sample 0 is test/0.png.
img = Image.open("test/0.png")
draw = ImageDraw.Draw(img)
for box in boxes:
    draw.rectangle(list(box), outline='red')
img.show()
plt.imshow(img)
# English for the message below: "The result is slightly off because sigmoid is
# applied to xy, but it should be enough to confirm the boxes are roughly right."
print("sigmoid が xy にかかるからおかしくなるけど、大体あっていることは確認できるはず")

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 conv, dilated 3x3 conv, 1x1 conv.

    The block outputs `2 * planes` channels. The shortcut is the identity
    unless the stride or the channel count changes, in which case it is a
    strided 1x1 convolution followed by batch normalisation.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = 2 * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # dilation=2 with padding=2 keeps the spatial size when stride is 1.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=2, dilation=2, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        shortcut = []
        if not (stride == 1 and in_planes == out_planes):
            shortcut = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        # An empty Sequential passes its input through unchanged.
        self.downsample = nn.Sequential(*shortcut)

    def forward(self, x):
        residual = self.downsample(x)

        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = F.relu(self.bn2(out))
        out = self.bn3(self.conv3(out))

        out += residual
        return F.relu(out)


class FeatureExtractNet(nn.Module):
    """Backbone with a small top-down pathway that returns one feature map.

    forward() yields 128 channels at one quarter of the input resolution
    (one stride-2 max-pool plus one stride-2 stage ahead of the returned level).
    """

    def __init__(self):
        super().__init__()
        # Channel count entering the next stage; advanced by _make_layer.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Bottom-up stages; each Bottleneck outputs twice its `planes`.
        self.layer1 = self._make_layer(64, 1, stride=1)
        self.layer2 = self._make_layer(128, 1, stride=2)
        self.layer3 = self._make_layer(256, 1, stride=2)
        self.layer4 = self._make_layer(512, 1, stride=2)

        # Lateral 1x1 convolutions bring each stage to the width of the
        # top-down feature it is added to.
        self.latlayer1 = nn.Conv2d(2 * 512, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(1 * 512, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0)

        # 3x3 convolutions applied after each merge.
        self.toplayer1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.toplayer2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

    def _make_layer(self, planes, block, stride):
        """Stack `block` Bottlenecks; only the first one uses `stride`."""
        stage = []
        for block_stride in [stride] + [1] * (block - 1):
            stage.append(Bottleneck(self.in_planes, planes, block_stride))
            self.in_planes = planes * 2
        return nn.Sequential(*stage)

    def _upsample_add(self, x, y):
        """Bilinearly resize `x` to the spatial size of `y` and add them."""
        height, width = y.size(2), y.size(3)
        resized = F.upsample(x, size=(height, width), mode='bilinear')
        return resized + y

    def forward(self, x):
        # Bottom-up pathway.
        stem = F.relu(self.bn1(self.conv1(x)))
        stem = F.max_pool2d(stem, kernel_size=3, stride=2, padding=1)
        c2 = self.layer1(stem)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Top-down pathway with lateral connections.
        p5 = self.latlayer1(c5)
        p4 = self.toplayer1(self._upsample_add(p5, self.latlayer2(c4)))
        p3 = self.toplayer2(self._upsample_add(p4, self.latlayer3(c3)))
        return p3

In [0]:
class PositionPredictionHead(nn.Module):
    """1x1 convolution that predicts five values per anchor at every location.

    Args:
        in_channels: channels of the incoming feature map (default 128).
        num_anchors: anchors per feature-map cell (default 9); the output has
            `num_anchors * 5` channels.
    """

    def __init__(self, in_channels=128, num_anchors=9):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_anchors * 5, kernel_size=1, stride=1)

    def forward(self, x):
        return self.conv(x)

In [0]:
def weighted_binary_cross_entropy(output, target, weight, eps=1e-7):
    """Class-weighted binary cross entropy, averaged over all elements.

    Args:
        output: predicted probabilities in [0, 1].
        target: 0/1 targets with the same shape as `output`.
        weight: indexable pair; weight[0] scales the negative (target 0) term
            and weight[1] the positive (target 1) term.
        eps: probabilities are clamped to [eps, 1 - eps] before the logarithm.
    Returns:
        The mean weighted loss.
    """
    # A saturated sigmoid yields exactly 0 or 1; log() then returns -inf and
    # the product with a zero target term becomes NaN. Clamping avoids that.
    output = output.clamp(min=eps, max=1 - eps)
    loss = weight[1] * (target * torch.log(output)) + \
           weight[0] * ((1 - target) * torch.log(1 - output))

    return torch.neg(torch.mean(loss))

def loss_positions(loc_preds, loc_targets):
    '''
    Confidence loss (weighted BCE over all anchors) plus twice the localisation
    loss (squared error over positive anchors only).

    Args:
        loc_preds: tensor [#batch, (#anchor * [p, x, y, w, h]), h, w]
        loc_targets: tensor [#batch, (#anchor * [p, x, y, w, h]), h, w]
    Returns:
        loss_conf + 2 * loss_loc
    '''

    # One row per anchor: [p, x, y, w, h].
    loc_preds = loc_preds.permute(0, 2, 3, 1).contiguous().view(-1, 5)
    loc_targets = loc_targets.permute(0, 2, 3, 1).contiguous().view(-1, 5)

    conf_preds = loc_preds[..., 0]
    conf_targets = loc_targets[..., 0]
    # Positive anchors are those whose confidence target is 1.
    mask = (conf_targets > 0.9).float()

    # x / y offsets go through a sigmoid; w / h stay raw log-ratios.
    xy_preds = loc_preds[..., 1:3].sigmoid()
    wh_preds = loc_preds[..., 3:5]
    box_preds = torch.cat([xy_preds, wh_preds], 1)
    box_targets = loc_targets[..., 1:]

    loss_conf = weighted_binary_cross_entropy(conf_preds.sigmoid(), conf_targets, weight=torch.Tensor([1, 9]))

    # Element-wise squared error, written out so it does not depend on the
    # version-specific reduction arguments of F.mse_loss.
    squared_error = (box_preds - box_targets) ** 2
    # With no positive anchor the numerator is 0; clamping the count to 1
    # gives a loss of 0 for that term instead of 0 / 0 = NaN.
    num_pos = mask.sum().clamp(min=1)
    loss_loc = (squared_error * mask.unsqueeze(1)).sum() / num_pos
    return loss_conf + 2 * loss_loc

In [0]:
!pip install -U -q PyDrive

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# 1. Authenticate and create the PyDrive client.
# authenticate_user() is interactive; the resulting application-default
# credentials are handed to PyDrive. The `drive` client created here is used
# by the training loop to upload checkpoints.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# PyDrive reference:
# https://googledrive.github.io/PyDrive/docs/build/html/index.html



In [0]:
import os

import torch
import torch.optim as optim
from torch.autograd import Variable

import torchvision.transforms as transforms
from itertools import chain
import gc

# Drop garbage left behind by earlier cells before the networks are created.
gc.collect()

# Use the GPU when one is present instead of assuming it: a hard-coded True
# makes every .cuda() call below fail on a CPU-only runtime.
CUDA = torch.cuda.is_available()


# ToTensor followed by per-channel normalisation with fixed mean / std values.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])
trainset = ListDataset(root='train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=0, collate_fn=trainset.collate_fn)
testset = ListDataset(root='test', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False, num_workers=0, collate_fn=testset.collate_fn)

fnet = FeatureExtractNet()
pnet = PositionPredictionHead()

if CUDA:
    fnet.cuda()
    pnet.cuda()

# A single Adam optimiser (default settings) updates both networks.
optimizer = optim.Adam(chain(fnet.parameters(), pnet.parameters()))


def train(epoch):
    """Run one optimisation pass over `trainloader` and print the mean batch loss.

    Uses the notebook globals fnet, pnet, optimizer, trainloader and CUDA.
    `epoch` is only used for the printed header.
    """
    print(f'\nEpoch: {epoch}')
    fnet.train()
    pnet.train()
    train_loss = 0
    n_batches = 0
    for inputs, loc_targets in trainloader:
        if CUDA:
            inputs = inputs.cuda()
            loc_targets = loc_targets.cuda()

        inputs = Variable(inputs)
        loc_targets = Variable(loc_targets)

        optimizer.zero_grad()
        x = fnet(inputs)
        loc_preds = pnet(x)
        loss = loss_positions(loc_preds, loc_targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.data[0]
        n_batches += 1

    # Counting batches explicitly avoids the NameError the enumerate() index
    # caused when the loader produced nothing (e.g. an empty train directory).
    if n_batches == 0:
        print('train_loss: no batches (trainloader is empty)')
        return
    print(f'train_loss: average: {train_loss / n_batches}')

def test(epoch):
    print('\nTest')
    fnet.eval()
    pnet.eval()
    test_loss = 0
    for batch_idx, (inputs, loc_targets) in enumerate(testloader):
        if CUDA:
            inputs = inputs.cuda()
            loc_targets = loc_targets.cuda()
            
        inputs = Variable(inputs)
        loc_targets = Variable(loc_targets)

        x = fnet(inputs)
        loc_preds = pnet(x)
        loss = loss_positions(loc_preds, loc_targets)
        test_loss += loss.data[0]
    print('test_loss: average: %.3f' % (test_loss/(batch_idx+1)))

# Train for 50 epochs, evaluating after each one.
start_epoch = 0
for epoch in range(start_epoch, start_epoch+50):
    train(epoch)
    test(epoch)
    # The local checkpoint files are overwritten every epoch.
    torch.save(fnet.state_dict(), 'fnet.pth')
    torch.save(pnet.state_dict(), 'pnet.pth')
    
    # Every 10th epoch the current local files are uploaded to Google Drive.
    # Only the Drive title carries the epoch number (fnet10.pth, ...); no
    # epoch-numbered file is written locally.
    if (epoch + 1) % 10 == 0:
        for filename in ['fnet', 'pnet']:
            uploaded = drive.CreateFile({'title': '%s%d.pth' % (filename, epoch + 1)})
            uploaded.SetContentFile('%s.pth' % filename)
            uploaded.Upload()

In [0]:
import os

import numpy as np
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

%matplotlib inline

def test(epoch):
    """Run the detector on one random test image and draw the predicted boxes.

    Loads `fnet{epoch}.pth` and `pnet{epoch}.pth` from the working directory,
    picks test/<n>.png with n drawn from [0, 100), prints the shape of the
    decoded boxes (or False when there are none) and the image path, and shows
    the image with the boxes outlined in red.

    Note: this definition replaces the evaluation function of the same name
    defined in the training cell.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
    ])
    # Only the anchor decoder is needed, so it is built directly rather than
    # through a dataset and a data loader that were never iterated.
    encoder = DataEncoder()

    n = np.random.randint(0, 100)
    # Convert to RGB as the dataset class does, so grayscale or RGBA files
    # still produce the three channels the network expects.
    img = Image.open('test/{}.png'.format(n)).convert('RGB')
    x = transform(img).unsqueeze(0)
    x = Variable(x, volatile=True)

    fnet = FeatureExtractNet()
    pnet = PositionPredictionHead()
    # Checkpoints may have been saved from GPU tensors; map them to CPU
    # storage so loading also works without CUDA.
    to_cpu = lambda storage, location: storage
    fnet.load_state_dict(torch.load(f'fnet{epoch}.pth', map_location=to_cpu))
    pnet.load_state_dict(torch.load(f'pnet{epoch}.pth', map_location=to_cpu))
    # Inference mode: BatchNorm must use its running statistics, not the
    # statistics of this single image.
    fnet.eval()
    pnet.eval()

    loc_preds = pnet(fnet(x))
    boxes = encoder.decode(loc_preds.data.squeeze(0), [300, 200], conf_thres=0.5, nms_thres=0.3)
    print(boxes is not None and boxes.size())
    print(f'test/{n}.png')
    if boxes is not None:
        draw = ImageDraw.Draw(img)
        for box in boxes:
            # tolist() hands PIL plain Python floats.
            draw.rectangle(box.tolist(), outline='red')
        plt.imshow(img)


# Expects fnet50.pth and pnet50.pth in the working directory.
# NOTE(review): the training cell saves locally as fnet.pth / pnet.pth and uses
# the epoch-numbered names only as Google Drive titles, so the numbered files
# have to be downloaded (or the local ones renamed) before this call works.
test(50)