In [1]:
import torch
from data import LoadData
from model import VGG16
from torch.utils import data
from utils import *
from loss import EdgeSaliencyLoss
import os

torch.__version__

'1.6.0+cu101'

In [2]:
if torch.cuda.is_available():
    device = torch.device(device='cuda')
else:
    device = torch.device(device='cpu')
device

device(type='cuda')

In [3]:
path_image = "./DUTS/DUTS-TR/DUTS-TR-Image/"
path_mask = "./DUTS/DUTS-TR/DUTS-TR-Mask/"

In [4]:
len(os.listdir(path_image))

10553

In [5]:
batch_size = 4
learning_rate = 1e-3
target_size = 256
epochs = 1

In [6]:
total_batch = len(os.listdir(path_image)) // batch_size
total_batch

2638

In [7]:
data_loader = data.DataLoader(LoadData(path_image, path_mask, target_size),
                            batch_size=batch_size,
                            shuffle=True)

In [8]:
model = VGG16()

In [9]:
model.load_state_dict(torch.load("./model/model_1.pth"), strict=False)

<All keys matched successfully>

In [10]:
model.eval()

VGG16(
  (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpooling1): MaxPool2d(kernel_size=2, stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpooling2): MaxPool2d(kernel_size=2, stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpooling3): MaxPool2d(kernel_size=2, stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1,

In [11]:
criterion = EdgeSaliencyLoss(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [12]:
model.to(device)

VGG16(
  (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpooling1): MaxPool2d(kernel_size=2, stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpooling2): MaxPool2d(kernel_size=2, stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpooling3): MaxPool2d(kernel_size=2, stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1,

In [13]:
for epoch in range(epochs):
    model.train()
    for batch_n, (image, mask) in enumerate(data_loader, start=1):
        
        image = image.to(device)
        mask = mask.to(device)
        
        optimizer.zero_grad()
        predict = model(image)
        loss = criterion(predict, mask)
        
        loss.backward()
        optimizer.step()
        
        if batch_n % 100 == 0:
            acc = accuracy(predict, mask)
            pre = precision(predict, mask)
            rec = recall(predict, mask)
            f_score = (1 + 0.3) * pre * rec / (0.3 * pre + rec)
            print("Epoch:{} loss:{} Batch:{}/{}".format(epoch + 1, loss.item(), batch_n, total_batch), end="")
            print(" acc:{} pre:{} recall:{} F-measure:{}".format(acc, pre, rec, f_score))

Epoch:1 loss:0.34354615211486816 Batch:100/2638 acc:0.8707199096679688 pre:0.671257734298706 recall:0.79440838098526 F-measure:0.6961623430252075
Epoch:1 loss:0.3350546658039093 Batch:200/2638 acc:0.8781852722167969 pre:0.9239149689674377 recall:0.497866690158844 F-measure:0.7715491652488708
Epoch:1 loss:0.24220749735832214 Batch:300/2638 acc:0.8904380798339844 pre:0.6997556686401367 recall:0.9028701186180115 F-measure:0.7380727529525757
Epoch:1 loss:0.3608689606189728 Batch:400/2638 acc:0.8411788940429688 pre:0.9502624869346619 recall:0.7179902195930481 F-measure:0.884249210357666
Epoch:1 loss:0.2816926836967468 Batch:500/2638 acc:0.8811531066894531 pre:0.9407837390899658 recall:0.7951963543891907 F-measure:0.9026468396186829
Epoch:1 loss:0.3704065680503845 Batch:600/2638 acc:0.8642463684082031 pre:0.9407534003257751 recall:0.706129252910614 F-measure:0.8737561702728271
Epoch:1 loss:0.2642161548137665 Batch:700/2638 acc:0.9116859436035156 pre:0.955432116985321 recall:0.835782825946807

In [14]:
torch.save(model.state_dict(),"./model/model_2.pth")

In [15]:
import cv2
import numpy as np
from torchvision import transforms

In [16]:
def pad_resize_image(inp_img, out_img=None, target_size=None):
    h, w, c = inp_img.shape
    size = max(h, w)

    padding_h = (size - h) // 2
    padding_w = (size - w) // 2

    if out_img is None:
        # For inference
        temp_x = cv2.copyMakeBorder(inp_img, top=padding_h, bottom=padding_h, left=padding_w, right=padding_w,
                                    borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
        if target_size is not None:
            temp_x = cv2.resize(temp_x, (target_size, target_size), interpolation=cv2.INTER_AREA)
        return temp_x
    else:
        # For training and testing
        temp_x = cv2.copyMakeBorder(inp_img, top=padding_h, bottom=padding_h, left=padding_w, right=padding_w,
                                    borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
        temp_y = cv2.copyMakeBorder(out_img, top=padding_h, bottom=padding_h, left=padding_w, right=padding_w,
                                    borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
        # print(inp_img.shape, temp_x.shape, out_img.shape, temp_y.shape)

        if target_size is not None:
            temp_x = cv2.resize(temp_x, (target_size, target_size), interpolation=cv2.INTER_AREA)
            temp_y = cv2.resize(temp_y, (target_size, target_size), interpolation=cv2.INTER_AREA)
        return temp_x, temp_y


In [17]:
def getInput(img_path):
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype('float32')
    
    img = pad_resize_image(img, target_size=256)
    img /= 255.
    img = np.transpose(img, axes=(2, 0, 1))
    img = torch.from_numpy(img).float()
    
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225]) 
    
    img = normalize(img)
    return img

In [18]:
img = getInput("./image/1.jpg")

In [19]:
img.shape

torch.Size([3, 256, 256])

In [20]:
img = img.reshape(1, 3, 256, 256)
img = img.to(device)

In [21]:
predict = model(img)

In [22]:
predict.shape

torch.Size([1, 1, 256, 256])

In [23]:
msk = predict.reshape(256, 256, 1)
msk = msk.cpu().detach()

In [24]:
msk.shape

torch.Size([256, 256, 1])

In [25]:
cv2.imshow("test", np.array(msk))
cv2.waitKey(0)
cv2.destroyAllWindows()