In [1]:
import numpy as np
import cv2
from matplotlib import pyplot as plt
import h5py
import os.path

import image_proc
import gt_io

import torch
import torchvision
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from craft.craft import CRAFT

import time

In [2]:
class SynthCharMapDataset(Dataset):
    """SynthText Dataset + Heatmap + Direction Ground Truths"""

    def __init__(self, gt_path, img_dir, color_flag=1, hard_examples=False, #transform=None,
                 character_map=True, affinity_map=True, direction_map=True, word_map=True,
                 begin=0, cuda=True):
        """
        Args:
            gt_path (string): Path to gt.mat file (GT file)
            img_dir (string): Path to directory of {i}/....jpg (folders of images)

            color_flag {1,0,-1}: Colored (1), Grayscale (0), or Unchanged (-1)
        """
        super(SynthCharMapDataset).__init__()

        # paths
        self.gt_path = gt_path
        self.img_dir = img_dir
        
        # flags
        self.color_flag = color_flag
        self.hard_examples = hard_examples
        self.character_map = character_map
        self.affinity_map = affinity_map
        self.word_map = word_map
        self.direction_map = direction_map
        
        # templates
        if self.character_map:
            self.character_template = image_proc.genCharMapTemplate()
        if self.direction_map:
            self.direction_template = image_proc.genDirectionMapTemplate()

        self.f = h5py.File(gt_path, 'r')
        self.length = len(self.f['imnames'])
        
        self.begin = 0
        if begin > 0:
            self.begin = begin
            self.length -= begin
        
        if cuda:
            self.dtype = torch.cuda.FloatTensor
        else:
            self.dtype = torch.FloatTensor

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # if changing starting index
        idx += self.begin

        f = self.f
        imgname = image_proc.u2ToStr(f[f['imnames'][idx][0]])
        charBBs = f[f['charBB'][idx][0]].value
        wordBBs = f[f['wordBB'][idx][0]].value
        txts    = f[f['txt'][idx][0]]

        imgpath = os.path.join(self.img_dir, imgname)
        synthetic_image = cv2.imread(imgpath, self.color_flag).transpose(2,0,1)# HWC to CHW
        image_shape = synthetic_image.shape[-2:]

        char_map, aff_map = image_proc.genPseudoGT(charBBs, txts, image_shape,
                                                   generate_affinity=self.affinity_map)
        if self.word_map:
            word_map = image_proc.genWordGT(wordBBs, image_shape)
        if self.direction_map:
            cos_map, sin_map  = image_proc.genDirectionGT(charBBs, image_shape, normalize=True,
                                                            template=self.direction_template)

        # combine gts to a single tensor
        gt = None
        if self.character_map:
            gt = char_map[None,...]
        if self.affinity_map:
            affinity_map = aff_map[None,...]
            if gt is None:
                gt = affinity_map
            else:
                gt = np.concatenate((gt, affinity_map))
        if self.word_map:
            word_map = word_map[None,...]
            if gt is None:
                gt = word_map
            else:
                gt = np.concatenate((gt, word_map))
        if self.direction_map:
            dir_map = np.concatenate((cos_map[None,...], sin_map[None,...]))
            if gt is None:
                gt = dir_map
            else:
                gt = np.concatenate((gt, dir_map))

        # get hard examples + corresponding gts
        if self.hard_examples:
            hard_img, hard_gt = image_proc.hard_example_mining(synthetic_image, gt, wordBBs)
            # hard_img: NCHW
            # hard_gt: NCHW -> NHWC
            
            hard_gt = torch.from_numpy(hard_gt).type(self.dtype)
            hard_gt_resized = F.interpolate(hard_gt, scale_factor=0.5).permute(0,2,3,1)
            
            hard_img = torch.from_numpy(hard_img).type(self.dtype) / 255.0

        # resize to match feature map size
        # to match expectations of F.interpolate, we reshape to NCHW
        gt = torch.from_numpy(gt[None,...]).type(self.dtype)
        gt_resized = F.interpolate(gt, scale_factor=0.5)[0].permute(1,2,0)# HWC
        
        #synthetic_image = self.transform(synthetic_image)
        synthetic_image = torch.from_numpy(synthetic_image).type(self.dtype)# CHW
        synthetic_image = synthetic_image / 255.0
        
        if self.hard_examples:
#             print(f"hard_gt_resized.shape = {hard_gt_resized.shape}")
            return synthetic_image, gt_resized, hard_img, hard_gt_resized
        else:
            return synthetic_image, gt_resized

In [3]:
gt_path = "/home/eee198/Downloads/SynthText/gt_v7.3.mat"#"/media/aerjay/Acer/Users/Aerjay/Downloads/SynthText/gt_v7.3.mat"
img_dir = "/home/eee198/Downloads/SynthText/images"#"/media/aerjay/Acer/Users/Aerjay/Downloads/SynthText/SynthText"

begin = 0

# remember requires_grad=True
dataset = SynthCharMapDataset(gt_path, img_dir, affinity_map=False, direction_map=True, word_map=True,
                              begin=begin, cuda=True, hard_examples=True)
train, test = torch.utils.data.random_split(dataset, [800000,58750])
dataloader = DataLoader(train, batch_size=4, shuffle=True, collate_fn=image_proc.collate)

In [4]:
for img,gt,hard_img,hard_gt in dataloader:
    print(f"img.shape = {img.shape}")
    print(f"gt.shape = {gt.shape}")
    print(f"hard_img.shape = {hard_img.shape}")
    print(f"hard_gt.shape = {hard_gt.shape}")
    break

img.shape = torch.Size([4, 3, 600, 600])
gt.shape = torch.Size([4, 300, 300, 4])
hard_img.shape = torch.Size([16, 3, 300, 202])
hard_gt.shape = torch.Size([16, 150, 101, 4])


In [18]:
# input: NCHW
# output: NHWC
model = CRAFT(pretrained=True, num_class=4).cuda()

# weight_path = "/home/eee198/Downloads/SynthText/weights/w_7683_interrupt.pth"
# model.load_state_dict(torch.load(weight_path))
# model.eval()

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

In [None]:
T_save = 10000
T = 100
epochs = 1
start = time.time()

for epoch in range(epochs):
    running_loss = 0.0

    while(True):
        try:
            for i, (img, target, hard_img, hard_target) in enumerate(dataloader):
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                output, _ = model(img.cuda())
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % T == T-1:    # print every 2000 mini-batches
                    print('[%d, %5d] loss: %f' % (epoch + 1, i + 1, running_loss/T))
                    running_loss = 0.0

                ## Hard Example Training
                #for hard_img, hard_target in zip(hard_imgs, hard_targets):
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                output, _ = model(hard_img)
                loss = criterion(output, hard_target)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % T == T-1:    # print every 2000 mini-batches
                    print('\t[%d, %5d] loss: %f' % (epoch + 1, i + 1, running_loss/T))
                    running_loss = 0.0



                if i % T_save == T_save-1:
                    print(f"\nsaving at {i}-th batch'\n")
                    torch.save(model.state_dict(), f"/home/eee198/Downloads/SynthText/weights/w_{i}_p2.pth")
                    end = time.time()
                    print(f"\nElapsed time: {end-start}")
            break
        except MemoryError:
            print("Memory Error")
            continue
        except AttributeError:
            print("Attribute Error")
        except KeyboardInterrupt:
            print(f"\nsaving at {i}-th batch'\n")
            torch.save(model.state_dict(), f"/home/eee198/Downloads/SynthText/weights/w_{i}_interrupt.pth")
            end = time.time()
            print(f"\nElapsed time: {end-start}")
            
            break
        except:
            continue


print("Finished training.")

end = time.time()
print(f"\nTotal elapsed time: {end-start}")
#24700

[1,   100] loss: 0.023371
	[1,   100] loss: 0.000093
[1,   200] loss: 0.024285
	[1,   200] loss: 0.000126
[1,   300] loss: 0.024499
	[1,   300] loss: 0.000099
[1,   400] loss: 0.023533
	[1,   400] loss: 0.000098
[1,   500] loss: 0.023204
	[1,   500] loss: 0.000099
[1,   600] loss: 0.022326
	[1,   600] loss: 0.000122
[1,   700] loss: 0.023648
	[1,   700] loss: 0.000145
[1,   800] loss: 0.024264
	[1,   800] loss: 0.000128
[1,   900] loss: 0.023030
	[1,   900] loss: 0.000123
[1,  1000] loss: 0.024602
	[1,  1000] loss: 0.000229
[1,  1100] loss: 0.023116
	[1,  1100] loss: 0.000133
[1,  1200] loss: 0.022853
	[1,  1200] loss: 0.000100
[1,  1300] loss: 0.022840
	[1,  1300] loss: 0.000119
[1,  1400] loss: 0.024464
	[1,  1400] loss: 0.000116
[1,  1500] loss: 0.023738
	[1,  1500] loss: 0.000199
[1,  1600] loss: 0.022689
	[1,  1600] loss: 0.000179
[1,  1700] loss: 0.023016
	[1,  1700] loss: 0.000323
[1,  1800] loss: 0.025043
	[1,  1800] loss: 0.000073
[1,  1900] loss: 0.023208
	[1,  1900] loss: 0.

In [9]:
torch.save(model.state_dict(), f"/home/eee198/Downloads/SynthText/weights/w_{i}_AttributeError_NoneType.shape.pth")

In [7]:
model.load_state_dict(torch.load(f"/home/eee198/Downloads/SynthText/weights/w_74229_AttributeError_NoneType.shape.pth"))
model.eval()

CRAFT(
  (basenet): vgg16_bn(
    (slice1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (slice2): Sequential(
      (12): ReLU(inplace)
      (13): MaxPool2d(kernel_size=2, stride=2, 