In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import models
from torchvision.models.vgg import VGG
from sklearn.metrics import confusion_matrix
import pandas as pd
import scipy.misc
import random
import sys

if '/opt/ros/kinetic/lib/python2.7/dist-packages' in sys.path:
    sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import cv2

from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

from matplotlib import pyplot as plt
import numpy as np
import time
import os

## Define FCN16s model for deconvolution layers

In [2]:
class FCN16s(nn.Module):

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu    = nn.ReLU(inplace = True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1     = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2     = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3     = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4     = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5     = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net(x)
        x5 = output['x5']  # size=(N, 512, x.H/32, x.W/32)
        x4 = output['x4']  # size=(N, 512, x.H/16, x.W/16)

        score = self.relu(self.deconv1(x5))               # size=(N, 512, x.H/16, x.W/16)
        score = self.bn1(score + x4)                      # element-wise add, size=(N, 512, x.H/16, x.W/16)
        score = self.bn2(self.relu(self.deconv2(score)))  # size=(N, 256, x.H/8, x.W/8)
        score = self.bn3(self.relu(self.deconv3(score)))  # size=(N, 128, x.H/4, x.W/4)
        score = self.bn4(self.relu(self.deconv4(score)))  # size=(N, 64, x.H/2, x.W/2)
        score = self.bn5(self.relu(self.deconv5(score)))  # size=(N, 32, x.H, x.W)
        score = self.classifier(score)                    # size=(N, n_class, x.H/1, x.W/1)
        
        return score

## Define VGG16 for convolution layers

In [3]:
class VGGNet(VGG):
    def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]

        if pretrained:
            exec("self.load_state_dict(models.%s(pretrained=True).state_dict())" % model)

        if not requires_grad:
            for param in super().parameters():
                param.requires_grad = False

        if remove_fc:  # delete redundant fully-connected layer params, can save memory
            del self.classifier

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        output = {}

        # get the output of each maxpooling layer (5 maxpool in VGG net)
        for idx in range(len(self.ranges)):
            for layer in range(self.ranges[idx][0], self.ranges[idx][1]):      
                x = self.features[layer](x)
            output["x%d"%(idx+1)] = x
        return output

In [4]:
ranges = {
    'vgg11': ((0, 3), (3, 6),  (6, 11),  (11, 16), (16, 21)),
    'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
    'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
    'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
}

# cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
cfg = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

## Define training parameter

In [5]:
batch_size = 6
epochs     = 50  #500
lr         = 1e-4
momentum   = 0
w_decay    = 1e-5
step_size  = 50
gamma      = 0.5
model_use  = "text_spotting" # "products_20" "mini_competition"
n_class = 16

## Define path, directory trainning environment

In [6]:
# get data
#data_dir  = os.path.join("data", model_use)
data_dir = "data"
if not os.path.exists(data_dir):
    print("Data not found!")
# create dir for model
model_dir = os.path.join("models", model_use)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
# create dir for score
score_dir = os.path.join("scores", model_use)
if not os.path.exists(score_dir):
    os.makedirs(score_dir)

use_gpu = torch.cuda.is_available()
num_gpu = list(range(torch.cuda.device_count()))

vgg_model = VGGNet(requires_grad=True, remove_fc=True)
fcn_model = FCN16s(pretrained_net=vgg_model, n_class=n_class)
#use_gpu = False
if use_gpu:
    ts = time.time()
    vgg_model = vgg_model.cuda()
    fcn_model = fcn_model.cuda()
    fcn_model = nn.DataParallel(fcn_model, device_ids=num_gpu)
    print("Finish cuda loading, time elapsed {}".format(time.time() - ts))

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.torch/models/vgg16-397923af.pth
100%|██████████| 553433881/553433881 [00:48<00:00, 11355035.96it/s]


Finish cuda loading, time elapsed 8.682100772857666


## Visualize model

In [7]:
print(fcn_model)
params = list(fcn_model.parameters())

DataParallel(
  (module): FCN16s(
    (pretrained_net): VGGNet(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  

Dataset class
-------------

``torch.utils.data.Dataset`` is an abstract class representing a
dataset.
Your custom dataset should inherit ``Dataset`` and override the following
methods:

-  ``__len__`` so that ``len(dataset)`` returns the size of the dataset.
-  ``__getitem__`` to support the indexing such that ``dataset[i]`` can
   be used to get $i$\ th sample

Let's create a dataset class for our face landmarks dataset. We will
read the csv in ``__init__`` but leave the reading of images to
``__getitem__``. This is memory efficient because all the images are not
stored in the memory at once but read as required.



In [13]:
means     = np.array([103.939, 116.779, 123.68]) / 255. # mean of three channels in the order of BGR
h, w      = 480, 640
val_h     = h
val_w     = w
class product_dataset(Dataset):

    def __init__(self, csv_file, phase, n_class=n_class, flip_rate=0.):
        self.data      = pd.read_csv(csv_file)
        self.means     = means
        self.n_class   = n_class
        self.flip_rate = flip_rate
        if phase == 'train':
            self.flip_rate = 0.5

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name   = self.data.iloc[idx, 0]
        img        = cv2.imread(img_name,cv2.IMREAD_UNCHANGED)
        # img        = cv2.imread(os.path.join(data_dir, img_name),cv2.IMREAD_UNCHANGED)
        label_name = self.data.iloc[idx, 1]
        # label      = cv2.imread(os.path.join(data_dir, label_name), cv2.IMREAD_GRAYSCALE)
        label      = cv2.imread(label_name, cv2.IMREAD_GRAYSCALE)
        origin_img = img

        if random.random() < self.flip_rate:
            img   = np.fliplr(img)
            label = np.fliplr(label)

        # reduce mean
        img = img[:, :, ::-1]  # switch to BGR
        
        img = np.transpose(img, (2, 0, 1)) / 255.
        img[0] -= self.means[0]
        img[1] -= self.means[1]
        img[2] -= self.means[2]

        # convert to tensor
        img = torch.from_numpy(img.copy()).float()
        label = torch.from_numpy(label.copy()).long()

        # create one-hot encoding
        h, w = label.size()
        target = torch.zeros(self.n_class, h, w)
        #print(np.unique(label))
        for i in range(self.n_class):
            target[i][label == i] = 1

        sample = {'X': img, 'Y': target, 'l': label, 'origin': origin_img}

        return sample

## Define dataloader and optimizer

In [14]:
# initial dataloader for trainning and validation
# train_file = os.path.join(data_dir, "train.csv")
# val_file   = os.path.join(data_dir, "val.csv")
train_file = "train.csv"
val_file = "val.csv"
train_data = product_dataset(csv_file = train_file, phase = 'train')
val_data   = product_dataset(csv_file = val_file, phase = 'val', flip_rate = 0)
dataloader = DataLoader(train_data, batch_size = batch_size, shuffle=True, num_workers = 0)
val_loader = DataLoader(val_data, batch_size = 4, num_workers = 0)

dataiter = iter(val_loader)

# define loss function
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.RMSprop(fcn_model.parameters(), lr = lr, momentum = momentum, weight_decay = w_decay)
# decay LR by a factor of 0.5 every step_size = 50 epochs
scheduler = lr_scheduler.StepLR(optimizer, step_size = step_size, gamma = gamma)  

## Train

In [11]:
def train():
    for epoch in range(epochs):
        fcn_model.train()
        scheduler.step()
        configs    = "FCNs_{}_batch{}_epoch{}_RMSprop_lr{}"\
            .format(model_use, batch_size, epoch, lr)
        model_path = os.path.join(model_dir, configs)
        
        ts = time.time()
        for iter, batch in enumerate(dataloader):
            optimizer.zero_grad()

            if use_gpu:
                inputs = Variable(batch['X'].cuda())
                labels = Variable(batch['Y'].cuda())
            else:
                inputs, labels = Variable(batch['X']), Variable(batch['Y'])

            outputs = fcn_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if iter % 10 == 0:
                print("epoch{}, iter{}, loss: {}".format(epoch+1, iter, loss.data[0]))
        
        print("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))
        torch.save(fcn_model.state_dict(),model_path + '.pkl')

        val(epoch)

## Evaluation

In [10]:
def val(epoch):
    fcn_model.eval()
    TP = np.zeros(n_class-1)
    FN = np.zeros(n_class-1)
    FP = np.zeros(n_class-1)
    total_ious = []
    pixel_accs = []
    for iter, batch in enumerate(val_loader):
        if use_gpu:
            inputs = Variable(batch['X'].cuda())
        else:
            inputs = Variable(batch['X'])

        output = fcn_model(inputs)
        output = output.data.cpu().numpy()

        N, _, h, w = output.shape
        pred = output.transpose(0, 2, 3, 1).reshape(-1, n_class).argmax(axis=1).reshape(N, h, w)

        target = batch['l'].cpu().numpy().reshape(N, h, w)
        for p, t in zip(pred, target):
            pixel_accs.append(pixel_acc(p, t))
            _TP, _FN, _FP =  analysis(p, t, h, w)
            TP += _TP[1:n_class]
            FN += _FN[1:n_class]
            FP += _FP[1:n_class]
            
    recall = TP / (TP + FN)
    precision = TP / (TP + FP)
    ious = TP / (TP + FN + FP)
    fscore = 2*TP / (2*TP + FN + FP)
    total_ious = np.array(total_ious).T  # n_class * val_len
    pixel_accs = np.array(pixel_accs).mean()
    
    print("epoch{}, pix_acc: {}, meanIoU: {}, IoUs: {}, recall: {}, precision: {}, fscore: {}"\
          .format(epoch, pixel_accs, np.nanmean(ious), ious, recall, precision, fscore))
    
    f1 = open(score_dir + "/cls_acc_log.txt","a+")
    f1.write('epoch:'+ str(epoch) + ', pix_acc: ' + str(pixel_accs) + '\n' )
    f2 = open(score_dir + "/cls_iou_log.txt","a+")
    f2.write('epoch:'+ str(epoch) + ', class ious: ' + str(ious) + '\n' )
    f3 = open(score_dir + "/mean_iou_log.txt","a+")
    f3.write('epoch:'+ str(epoch) + ', mean IoU: ' + str(np.nanmean(ious)) + '\n' ) 
    f4 = open(score_dir + "/recall_log.txt","a+")
    f4.write('epoch:'+ str(epoch) + ', class recall: ' + str(recall) + '\n' )
    f5 = open(score_dir + "/precision_log.txt","a+")
    f5.write('epoch:'+ str(epoch) + ', class precision: ' + str(precision) + '\n' )    
    f6 = open(score_dir + "/fscore_log.txt","a+")
    f6.write('epoch:'+ str(epoch) + ', class fscore: ' + str(fscore) + '\n' )  
    

def analysis(pred, target, h, w):
    # TP, FN, FP, TN
    TP = np.zeros(n_class)
    FN = np.zeros(n_class)
    FP = np.zeros(n_class)

    target = target.reshape(h * w)
    pred = pred.reshape(h * w)

    con_matrix = confusion_matrix(target, pred,labels = np.arange(0,n_class,1))
    con_matrix[0][0] = 0
    for i in range(0, n_class):
        for j in range(0, n_class):
            if i == j:
                TP[i] += con_matrix[i][j]
            if i != j:
                FP[j] += con_matrix[i][j]
                FN[i] += con_matrix[i][j]
    return TP, FN, FP
                
def pixel_acc(pred, target):
    correct = (pred == target).sum()
    total   = (target == target).sum()
    return correct / total

In [12]:
train()



epoch1, iter0, loss: 0.7167602181434631
epoch1, iter10, loss: 0.7020950317382812
epoch1, iter20, loss: 0.6988784670829773
epoch1, iter30, loss: 0.6961067914962769
epoch1, iter40, loss: 0.6937254071235657
epoch1, iter50, loss: 0.6912814974784851
epoch1, iter60, loss: 0.6889186501502991
epoch1, iter70, loss: 0.6865101456642151
epoch1, iter80, loss: 0.684277355670929
epoch1, iter90, loss: 0.6820200681686401
epoch1, iter100, loss: 0.6793283820152283
epoch1, iter110, loss: 0.6771360039710999
epoch1, iter120, loss: 0.6742465496063232
epoch1, iter130, loss: 0.6721674203872681
epoch1, iter140, loss: 0.669731616973877
epoch1, iter150, loss: 0.6674121022224426
epoch1, iter160, loss: 0.6649880409240723
epoch1, iter170, loss: 0.6626982688903809
epoch1, iter180, loss: 0.6600777506828308
epoch1, iter190, loss: 0.6581557989120483
epoch1, iter200, loss: 0.6555794477462769
epoch1, iter210, loss: 0.6530529856681824
epoch1, iter220, loss: 0.6508887410163879
epoch1, iter230, loss: 0.6481037735939026
epoch



epoch0, pix_acc: 0.9865132446823053, meanIoU: 0.012144899491278487, IoUs: [2.78746566e-02 0.00000000e+00 9.57727330e-03 0.00000000e+00
 0.00000000e+00 0.00000000e+00 5.18086348e-03 6.27967145e-06
 4.90763375e-02 2.00387695e-02 6.94243215e-02 0.00000000e+00
 9.64232641e-04 3.07581105e-05 0.00000000e+00], recall: [5.76575872e-02 0.00000000e+00 2.51613089e-02 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.58620812e-03 6.28824036e-06
 1.54347720e-01 1.09232227e-01 6.41930457e-01 0.00000000e+00
 1.01745230e-03 3.08401631e-05 0.00000000e+00], precision: [0.05120037 0.         0.01522759 0.         0.         0.
 0.02370478 0.00458716 0.06712517 0.02395298 0.07222108        nan
 0.01810051 0.01142857        nan], fscore: [5.42374626e-02 0.00000000e+00 1.89728385e-02 0.00000000e+00
 0.00000000e+00 0.00000000e+00 1.03083210e-02 1.25592640e-05
 9.35610418e-02 3.92902115e-02 1.29834940e-01 0.00000000e+00
 1.92660758e-03 6.15143290e-05 0.00000000e+00]
epoch2, iter0, loss: 0.5503596067428589
epoch

epoch4, iter130, loss: 0.16940419375896454
epoch4, iter140, loss: 0.16711577773094177
epoch4, iter150, loss: 0.16524916887283325
epoch4, iter160, loss: 0.16273225843906403
epoch4, iter170, loss: 0.16129182279109955
epoch4, iter180, loss: 0.1588410884141922
epoch4, iter190, loss: 0.15713854134082794
epoch4, iter200, loss: 0.15557821094989777
epoch4, iter210, loss: 0.152631476521492
epoch4, iter220, loss: 0.15085956454277039
epoch4, iter230, loss: 0.14890319108963013
epoch4, iter240, loss: 0.14653557538986206
epoch4, iter250, loss: 0.1443842053413391
epoch4, iter260, loss: 0.14332422614097595
epoch4, iter270, loss: 0.14126484096050262
epoch4, iter280, loss: 0.13960671424865723
epoch4, iter290, loss: 0.1370023787021637
epoch4, iter300, loss: 0.13535211980342865
epoch4, iter310, loss: 0.1334863156080246
epoch4, iter320, loss: 0.13186125457286835
epoch4, iter330, loss: 0.1300349235534668
epoch4, iter340, loss: 0.12891532480716705
epoch4, iter350, loss: 0.12683629989624023
epoch4, iter360, l

epoch6, iter410, loss: 0.019461149349808693
epoch6, iter420, loss: 0.018967486917972565
epoch6, iter430, loss: 0.018522363156080246
epoch6, iter440, loss: 0.018827110528945923
epoch6, iter450, loss: 0.018631720915436745
epoch6, iter460, loss: 0.01879206672310829
epoch6, iter470, loss: 0.01801261305809021
epoch6, iter480, loss: 0.017434246838092804
epoch6, iter490, loss: 0.018009226769208908
epoch6, iter500, loss: 0.01711990311741829
epoch6, iter510, loss: 0.01684899814426899
epoch6, iter520, loss: 0.01579112373292446
epoch6, iter530, loss: 0.01591603085398674
epoch6, iter540, loss: 0.0163346566259861
epoch6, iter550, loss: 0.015019946731626987
epoch6, iter560, loss: 0.014721651561558247
epoch6, iter570, loss: 0.014234433881938457
Finish epoch 5, time elapsed 473.2727279663086
epoch5, pix_acc: 0.9892696115622265, meanIoU: 0.055916887285863595, IoUs: [2.50509350e-02 0.00000000e+00 0.00000000e+00 7.56372099e-02
 1.72391794e-01 7.54603079e-05 5.67428294e-02 2.08099139e-01
 1.21573157e-05 5

epoch9, iter0, loss: 0.0035058376379311085
epoch9, iter10, loss: 0.003734368132427335
epoch9, iter20, loss: 0.0026997539680451155
epoch9, iter30, loss: 0.0032223269809037447
epoch9, iter40, loss: 0.00355695397593081
epoch9, iter50, loss: 0.0036898490507155657
epoch9, iter60, loss: 0.0042193252593278885
epoch9, iter70, loss: 0.002787773497402668
epoch9, iter80, loss: 0.00315161794424057
epoch9, iter90, loss: 0.003585493192076683
epoch9, iter100, loss: 0.0028173199389129877
epoch9, iter110, loss: 0.0039655957370996475
epoch9, iter120, loss: 0.00400557229295373
epoch9, iter130, loss: 0.0027472914662212133
epoch9, iter140, loss: 0.004112934228032827
epoch9, iter150, loss: 0.005771559197455645
epoch9, iter160, loss: 0.0030579669401049614
epoch9, iter170, loss: 0.0031248468440026045
epoch9, iter180, loss: 0.004022412002086639
epoch9, iter190, loss: 0.0044196536764502525
epoch9, iter200, loss: 0.00411915173754096
epoch9, iter210, loss: 0.0034037926234304905
epoch9, iter220, loss: 0.0042296191

epoch11, iter260, loss: 0.002550797536969185
epoch11, iter270, loss: 0.0033462585415691137
epoch11, iter280, loss: 0.003126739524304867
epoch11, iter290, loss: 0.0019060777267441154
epoch11, iter300, loss: 0.0026237780693918467
epoch11, iter310, loss: 0.0016630932223051786
epoch11, iter320, loss: 0.0026407691184431314
epoch11, iter330, loss: 0.002752267522737384
epoch11, iter340, loss: 0.0026343578938394785
epoch11, iter350, loss: 0.0026494916528463364
epoch11, iter360, loss: 0.002161892596632242
epoch11, iter370, loss: 0.002735109068453312
epoch11, iter380, loss: 0.0026098210364580154
epoch11, iter390, loss: 0.0027811876498162746
epoch11, iter400, loss: 0.0018773498013615608
epoch11, iter410, loss: 0.0023577904794365168
epoch11, iter420, loss: 0.0021903947927057743
epoch11, iter430, loss: 0.0031032925471663475
epoch11, iter440, loss: 0.002285656286403537
epoch11, iter450, loss: 0.0018244268139824271
epoch11, iter460, loss: 0.003323663491755724
epoch11, iter470, loss: 0.003730676369741

epoch13, iter490, loss: 0.002371035283431411
epoch13, iter500, loss: 0.0031927453819662333
epoch13, iter510, loss: 0.0015784146962687373
epoch13, iter520, loss: 0.0025474289432168007
epoch13, iter530, loss: 0.0015748406294733286
epoch13, iter540, loss: 0.003237919183447957
epoch13, iter550, loss: 0.0019600463565438986
epoch13, iter560, loss: 0.002612767741084099
epoch13, iter570, loss: 0.0028590497095137835
Finish epoch 12, time elapsed 467.3672788143158
epoch12, pix_acc: 0.9924025317147855, meanIoU: 0.21707654673581306, IoUs: [0.35155308 0.         0.00946801 0.06467832 0.23242388 0.35914284
 0.24441296 0.59389282 0.         0.09898895 0.36967943 0.18670735
 0.10842015 0.25095254 0.38582788], recall: [0.41929735 0.         0.00993778 0.11403244 0.48164177 0.50876757
 0.47903687 0.85389902 0.         0.24338777 0.53667599 0.20442196
 0.11622896 0.3585169  0.6454799 ], precision: [0.68512937        nan 0.1668682  0.13001032 0.30995718 0.54979088
 0.332899   0.66106662        nan 0.14299

epoch16, iter10, loss: 0.001744970679283142
epoch16, iter20, loss: 0.002566617913544178
epoch16, iter30, loss: 0.0017364107770845294
epoch16, iter40, loss: 0.0032035557087510824
epoch16, iter50, loss: 0.001551165129058063
epoch16, iter60, loss: 0.003179781837388873
epoch16, iter70, loss: 0.0014138215919956565
epoch16, iter80, loss: 0.003070734441280365
epoch16, iter90, loss: 0.0019069996196776628
epoch16, iter100, loss: 0.0015966162318363786
epoch16, iter110, loss: 0.0016297537367790937
epoch16, iter120, loss: 0.001505227410234511
epoch16, iter130, loss: 0.0016374537954106927
epoch16, iter140, loss: 0.002284959191456437
epoch16, iter150, loss: 0.0026843941304832697
epoch16, iter160, loss: 0.0017559916013851762
epoch16, iter170, loss: 0.0016233667265623808
epoch16, iter180, loss: 0.0021013475488871336
epoch16, iter190, loss: 0.0021452854853123426
epoch16, iter200, loss: 0.0019087835680693388
epoch16, iter210, loss: 0.0018393059726804495
epoch16, iter220, loss: 0.0013805236667394638
epoc

epoch18, iter300, loss: 0.002024620771408081
epoch18, iter310, loss: 0.0012897331034764647
epoch18, iter320, loss: 0.0019501970382407308
epoch18, iter330, loss: 0.001206813263706863
epoch18, iter340, loss: 0.0008011674508452415
epoch18, iter350, loss: 0.0014674653066322207
epoch18, iter360, loss: 0.001604455872438848
epoch18, iter370, loss: 0.0016924869269132614
epoch18, iter380, loss: 0.0015500375302508473
epoch18, iter390, loss: 0.002281557535752654
epoch18, iter400, loss: 0.001656858716160059
epoch18, iter410, loss: 0.002015091013163328
epoch18, iter420, loss: 0.001807622262276709
epoch18, iter430, loss: 0.001823708415031433
epoch18, iter440, loss: 0.002333893906325102
epoch18, iter450, loss: 0.002415452850982547
epoch18, iter460, loss: 0.001679245033301413
epoch18, iter470, loss: 0.0021191248670220375
epoch18, iter480, loss: 0.002481533447280526
epoch18, iter490, loss: 0.0021062714513391256
epoch18, iter500, loss: 0.0022660670801997185
epoch18, iter510, loss: 0.0013539048377424479


Finish epoch 19, time elapsed 468.63059306144714
epoch19, pix_acc: 0.9955853308863737, meanIoU: 0.47514988545774534, IoUs: [0.30602626 0.         0.00341043 0.30916485 0.44863205 0.86939285
 0.49035799 0.88844079 0.         0.17415595 0.4880181  0.81805771
 0.68739804 0.79355174 0.85064154], recall: [0.31109461 0.         0.00342781 0.50359751 0.77561441 0.92695135
 0.92036028 0.92675458 0.         0.30285366 0.9063258  0.82804513
 0.85742501 0.85983146 0.87256512], precision: [0.94945375        nan 0.40202703 0.4446801  0.51554425 0.93333841
 0.51208642 0.95553596        nan 0.29069301 0.51394108 0.98547022
 0.77610922 0.91146182 0.97131036], fscore: [0.46863723 0.         0.00679767 0.47230851 0.61938716 0.93013392
 0.65804054 0.94092523 0.         0.29664876 0.65593032 0.89992491
 0.8147432  0.88489417 0.91929369]
epoch21, iter0, loss: 0.0011382984230294824
epoch21, iter10, loss: 0.0011797632323578
epoch21, iter20, loss: 0.0011459790403023362
epoch21, iter30, loss: 0.002203752286732

epoch23, iter100, loss: 0.001836395706050098
epoch23, iter110, loss: 0.0012695633340626955
epoch23, iter120, loss: 0.0026613676454871893
epoch23, iter130, loss: 0.0012498284922912717
epoch23, iter140, loss: 0.0006894238176755607
epoch23, iter150, loss: 0.0009167907410301268
epoch23, iter160, loss: 0.0009551002294756472
epoch23, iter170, loss: 0.0013220927212387323
epoch23, iter180, loss: 0.0012842569267377257
epoch23, iter190, loss: 0.000758138601668179
epoch23, iter200, loss: 0.0026310463435947895
epoch23, iter210, loss: 0.0016096682520583272
epoch23, iter220, loss: 0.0004962706589139998
epoch23, iter230, loss: 0.0025482408236712217
epoch23, iter240, loss: 0.0012959461892023683
epoch23, iter250, loss: 0.0011486379662528634
epoch23, iter260, loss: 0.002025189809501171
epoch23, iter270, loss: 0.002440507523715496
epoch23, iter280, loss: 0.0013122690143063664
epoch23, iter290, loss: 0.0006925793131813407
epoch23, iter300, loss: 0.0019583026878535748
epoch23, iter310, loss: 0.000784920121

epoch25, iter350, loss: 0.0005066283047199249
epoch25, iter360, loss: 0.0016960370121523738
epoch25, iter370, loss: 0.0026198062114417553
epoch25, iter380, loss: 0.0007596356444992125
epoch25, iter390, loss: 0.0011828355491161346
epoch25, iter400, loss: 0.000757817062549293
epoch25, iter410, loss: 0.0016217428492382169
epoch25, iter420, loss: 0.0012596173910424113
epoch25, iter430, loss: 0.001028581173159182
epoch25, iter440, loss: 0.0005943814176134765
epoch25, iter450, loss: 0.0009483222966082394
epoch25, iter460, loss: 0.002448661020025611
epoch25, iter470, loss: 0.0015826633898541331
epoch25, iter480, loss: 0.0013442636700347066
epoch25, iter490, loss: 0.0013323132880032063
epoch25, iter500, loss: 0.00044615697697736323
epoch25, iter510, loss: 0.0008266856311820447
epoch25, iter520, loss: 0.003095319028943777
epoch25, iter530, loss: 0.0013653092319145799
epoch25, iter540, loss: 0.002072584116831422
epoch25, iter550, loss: 0.0012891015503555536
epoch25, iter560, loss: 0.000548796029

epoch27, iter570, loss: 0.0011250822572037578
Finish epoch 26, time elapsed 467.90291929244995
epoch26, pix_acc: 0.9921778710766624, meanIoU: 0.25162544382772917, IoUs: [2.87336197e-01 8.22165584e-05 1.85638394e-01 1.99037993e-01
 4.26705182e-01 1.11197669e-02 7.42389729e-03 3.99023776e-01
 0.00000000e+00 3.25661350e-01 3.83888257e-01 7.41538298e-01
 4.74323566e-01 3.10233628e-01 2.23691339e-02], recall: [2.88600571e-01 8.24198467e-05 7.06619426e-01 4.57734646e-01
 4.27104571e-01 1.12648649e-02 1.70173377e-02 6.57988895e-01
 0.00000000e+00 3.83324501e-01 9.19629679e-01 7.49815528e-01
 5.15991956e-01 3.54546226e-01 2.23863592e-02], precision: [0.98498186 0.03225806 0.20114137 0.26045095 0.99781333 0.46331703
 0.01299772 0.50344005 0.         0.68403255 0.39721466 0.98533171
 0.8545177  0.71282417 0.96674584], fscore: [4.46404285e-01 1.64419599e-04 3.13145045e-01 3.31996141e-01
 5.98168686e-01 2.19949551e-02 1.47383784e-02 5.70431730e-01
 0.00000000e+00 4.91319069e-01 5.54796610e-01 8.51

epoch30, iter50, loss: 0.001648113364353776
epoch30, iter60, loss: 0.0006913807592354715
epoch30, iter70, loss: 0.0005476996884681284
epoch30, iter80, loss: 0.0007048517582006752
epoch30, iter90, loss: 0.0011741186026483774
epoch30, iter100, loss: 0.0012828052276745439
epoch30, iter110, loss: 0.0005353896412998438
epoch30, iter120, loss: 0.00031458979356102645
epoch30, iter130, loss: 0.00032261034357361495
epoch30, iter140, loss: 0.0011334101436659694
epoch30, iter150, loss: 0.0006418131524696946
epoch30, iter160, loss: 0.0013984997058287263
epoch30, iter170, loss: 0.00048081131535582244
epoch30, iter180, loss: 0.0009929310763254762
epoch30, iter190, loss: 0.0011356064351275563
epoch30, iter200, loss: 0.0004694392264354974
epoch30, iter210, loss: 0.00029159229598008096
epoch30, iter220, loss: 0.0012242996599525213
epoch30, iter230, loss: 0.000488383520860225
epoch30, iter240, loss: 0.0005313659785315394
epoch30, iter250, loss: 0.0004211149935144931
epoch30, iter260, loss: 0.00041010847

epoch32, iter310, loss: 0.00032525917049497366
epoch32, iter320, loss: 0.0012158729368820786
epoch32, iter330, loss: 0.000822222966235131
epoch32, iter340, loss: 0.00038647576002404094
epoch32, iter350, loss: 0.00039965249015949667
epoch32, iter360, loss: 0.0023628470953553915
epoch32, iter370, loss: 0.000810066529083997
epoch32, iter380, loss: 0.00032449825084768236
epoch32, iter390, loss: 0.0020259530283510685
epoch32, iter400, loss: 0.0010616066865622997
epoch32, iter410, loss: 0.0007784474291838706
epoch32, iter420, loss: 0.0004626554436981678
epoch32, iter430, loss: 0.0005309082334861159
epoch32, iter440, loss: 0.0003464099718257785
epoch32, iter450, loss: 0.0005512433126568794
epoch32, iter460, loss: 0.0017215281259268522
epoch32, iter470, loss: 0.0004036313621327281
epoch32, iter480, loss: 0.00029921173700131476
epoch32, iter490, loss: 0.0004270763020031154
epoch32, iter500, loss: 0.000400099263060838
epoch32, iter510, loss: 0.0003930891689378768
epoch32, iter520, loss: 0.000651

Finish epoch 33, time elapsed 467.3270568847656
epoch33, pix_acc: 0.9951001766868984, meanIoU: 0.48036080621170896, IoUs: [0.54539611 0.01891706 0.50031001 0.49170821 0.61941151 0.61128167
 0.37500881 0.78283309 0.32432432 0.4505301  0.43920747 0.86613538
 0.50364208 0.56773442 0.10897186], recall: [0.55040056 0.0203577  0.86000691 0.69589195 0.95023539 0.66297297
 0.75769874 0.80691958 0.37895966 0.5670858  0.61659218 0.92152891
 0.78625123 0.74023331 0.10971595], precision: [0.98360221 0.21093083 0.54466843 0.62628319 0.64017831 0.88687868
 0.42610883 0.9632699  0.69226667 0.68671606 0.60422567 0.93510312
 0.58353988 0.70898779 0.94141046], fscore: [0.70583342 0.03713169 0.66694218 0.65925522 0.76498346 0.7587521
 0.54546387 0.87818999 0.48979592 0.62119373 0.61034629 0.92826639
 0.66989623 0.72427372 0.19652772]
epoch35, iter0, loss: 0.001981401117518544
epoch35, iter10, loss: 0.0008446598658338189
epoch35, iter20, loss: 0.0004956851480528712
epoch35, iter30, loss: 0.000517227803356

epoch37, iter90, loss: 0.00047994175110943615
epoch37, iter100, loss: 0.0003028297214768827
epoch37, iter110, loss: 0.00031313273939304054
epoch37, iter120, loss: 0.00041159812826663256
epoch37, iter130, loss: 0.0005569664645008743
epoch37, iter140, loss: 0.0003861689765471965
epoch37, iter150, loss: 0.0003674222680274397
epoch37, iter160, loss: 0.00022135741892270744
epoch37, iter170, loss: 0.0007317912532016635
epoch37, iter180, loss: 0.0005170621443539858
epoch37, iter190, loss: 0.001287892460823059
epoch37, iter200, loss: 0.0007128985016606748
epoch37, iter210, loss: 0.0007926722755655646
epoch37, iter220, loss: 0.0009178640902973711
epoch37, iter230, loss: 0.0006499467417597771
epoch37, iter240, loss: 0.0006370653281919658
epoch37, iter250, loss: 0.0005116953980177641
epoch37, iter260, loss: 0.000256402330705896
epoch37, iter270, loss: 0.0009284631814807653
epoch37, iter280, loss: 0.0007871852722018957
epoch37, iter290, loss: 0.0008354434976354241
epoch37, iter300, loss: 0.0003540

epoch39, iter360, loss: 0.0005316830356605351
epoch39, iter370, loss: 0.0002564397582318634
epoch39, iter380, loss: 0.000912246061488986
epoch39, iter390, loss: 0.0002502426505088806
epoch39, iter400, loss: 0.0007195628131739795
epoch39, iter410, loss: 0.000394779781345278
epoch39, iter420, loss: 0.000419734074966982
epoch39, iter430, loss: 0.0006750518805347383
epoch39, iter440, loss: 0.00031218797084875405
epoch39, iter450, loss: 0.0006233045132830739
epoch39, iter460, loss: 0.00037252993206493556
epoch39, iter470, loss: 0.0007517659105360508
epoch39, iter480, loss: 0.0004142926773056388
epoch39, iter490, loss: 0.00043158591142855585
epoch39, iter500, loss: 0.0005779266939498484
epoch39, iter510, loss: 0.00028915810980834067
epoch39, iter520, loss: 0.0006904349429532886
epoch39, iter530, loss: 0.0003956687287427485
epoch39, iter540, loss: 0.0005911141051910818
epoch39, iter550, loss: 0.0006304103299044073
epoch39, iter560, loss: 0.0006646651891060174
epoch39, iter570, loss: 0.0014676

epoch42, iter0, loss: 0.0005192614626139402
epoch42, iter10, loss: 0.000699241238180548
epoch42, iter20, loss: 0.0006182183278724551
epoch42, iter30, loss: 0.0011980768758803606
epoch42, iter40, loss: 0.0003379835980013013
epoch42, iter50, loss: 0.0004704413586296141
epoch42, iter60, loss: 0.0003290085296612233
epoch42, iter70, loss: 0.00022683134011458606
epoch42, iter80, loss: 0.0021529512014240026
epoch42, iter90, loss: 0.0005021213437430561
epoch42, iter100, loss: 0.0003972142585553229
epoch42, iter110, loss: 0.0015654682647436857
epoch42, iter120, loss: 0.0003379923873580992
epoch42, iter130, loss: 0.0006637996411882341
epoch42, iter140, loss: 0.0005168399075046182
epoch42, iter150, loss: 0.0005229018861427903
epoch42, iter160, loss: 0.0006854877574369311
epoch42, iter170, loss: 0.000547312549315393
epoch42, iter180, loss: 0.00036058161640539765
epoch42, iter190, loss: 0.00028478424064815044
epoch42, iter200, loss: 0.0009154640138149261
epoch42, iter210, loss: 0.000751136976759880

epoch44, iter270, loss: 0.00024494034005329013
epoch44, iter280, loss: 0.0004903999506495893
epoch44, iter290, loss: 0.0005583124002441764
epoch44, iter300, loss: 0.0010135836200788617
epoch44, iter310, loss: 0.0006872545345686376
epoch44, iter320, loss: 0.0007236694218590856
epoch44, iter330, loss: 0.00024968895013444126
epoch44, iter340, loss: 0.00024263160594273359
epoch44, iter350, loss: 0.0002225116768386215
epoch44, iter360, loss: 0.001956937601789832
epoch44, iter370, loss: 0.00047400270705111325
epoch44, iter380, loss: 0.0004949181457050145
epoch44, iter390, loss: 0.00022879692551214248
epoch44, iter400, loss: 0.000630201306194067
epoch44, iter410, loss: 0.0003274483315180987
epoch44, iter420, loss: 0.0002623645414132625
epoch44, iter430, loss: 0.00035568472230806947
epoch44, iter440, loss: 0.000584438384976238
epoch44, iter450, loss: 0.0005844512488692999
epoch44, iter460, loss: 0.0007631722837686539
epoch44, iter470, loss: 0.0004918791237287223
epoch44, iter480, loss: 0.00036

epoch46, iter540, loss: 0.0007879445911385119
epoch46, iter550, loss: 0.0003557199379429221
epoch46, iter560, loss: 0.0010820544557645917
epoch46, iter570, loss: 0.002000128384679556
Finish epoch 45, time elapsed 468.1316463947296
epoch45, pix_acc: 0.998472803204287, meanIoU: 0.8458348139941937, IoUs: [0.87147591 0.51187242 0.86377836 0.87024193 0.91669891 0.89668715
 0.85865721 0.8984228  0.7509796  0.86535234 0.8703397  0.93604947
 0.71481413 0.92803119 0.93412108], recall: [0.91880105 0.66005934 0.96769501 0.91358437 0.94404493 0.95275676
 0.98240118 0.95495733 0.87895966 0.98541857 0.94906543 0.96262259
 0.96923703 0.97630704 0.96530861], precision: [0.94419456 0.69512195 0.88942575 0.94830227 0.96936882 0.93841174
 0.87207146 0.93817917 0.83760129 0.87657672 0.91298477 0.97135392
 0.73140813 0.94941331 0.96656937], fscore: [0.93132474 0.67713706 0.92691103 0.93061963 0.95653929 0.94552984
 0.92395435 0.9464939  0.85778224 0.9278165  0.93067553 0.96696855
 0.83369284 0.96267238 0.9

epoch49, iter50, loss: 0.0006602155626751482
epoch49, iter60, loss: 0.0009200692875310779
epoch49, iter70, loss: 0.00019030450494028628
epoch49, iter80, loss: 0.0012625225353986025
epoch49, iter90, loss: 0.0006762693519704044
epoch49, iter100, loss: 0.0001964512630365789
epoch49, iter110, loss: 0.00019570357108023018
epoch49, iter120, loss: 0.0003392819780856371
epoch49, iter130, loss: 0.0002581305743660778
epoch49, iter140, loss: 0.0002632957184687257
epoch49, iter150, loss: 0.0002587687922641635
epoch49, iter160, loss: 0.0002910625480581075
epoch49, iter170, loss: 0.00034566267277114093
epoch49, iter180, loss: 0.00026008946588262916
epoch49, iter190, loss: 0.00027421925915405154
epoch49, iter200, loss: 0.0006897112471051514
epoch49, iter210, loss: 0.001301199896261096
epoch49, iter220, loss: 0.000480084796436131
epoch49, iter230, loss: 0.00036562979221343994
epoch49, iter240, loss: 0.0002798319037538022
epoch49, iter250, loss: 0.00036981605808250606
epoch49, iter260, loss: 0.00036364

## Prediction Result 

In [8]:
def prediction(model_name):
    
    # load pretrain models
    state_dict = torch.load(os.path.join(model_dir, model_name))
    fcn_model.load_state_dict(state_dict)
    
    batch = dataiter.next()
    if use_gpu:
        inputs = Variable(batch['X'].cuda())
    else:
        inputs = Variable(batch['X'])
    img    = batch['origin'] 
    label  = batch['l']
    output = fcn_model(inputs)
    output = output.data.cpu().numpy()

    N, _, h, w = output.shape
    pred = output.transpose(0, 2, 3, 1).reshape(-1, n_class).argmax(axis = 1).reshape(N, h, w)

    # show images
    plt.figure(figsize = (10, 12))
    img = img.numpy()
    for i in range(N):
        img[i] = cv2.cvtColor(img[i], cv2.COLOR_BGR2RGB)
        plt.subplot(N, 3, i*3 + 1)
        plt.title("origin_img")
        plt.imshow(img[i])
        print(np.unique(label[i]))

        plt.subplot(N, 3, i*3 + 2)
        plt.title("label_img")
        plt.imshow(label[i],cmap = "gray",vmin = 0, vmax = n_class - 1)
        print(np.unique(pred[i]))
        plt.subplot(N, 3, i*3 + 3)
        plt.title("prediction")
        plt.imshow(pred[i],cmap = "gray",vmin = 0, vmax = n_class - 1)

    plt.show()

In [None]:
prediction("FCNs_text_spotting_batch6_epoch49_RMSprop_lr0.0001.pkl")