In [21]:
import torch
import torchvision
import cv2
from torch.utils import data
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import random

# Record the exact library versions this notebook was run with.
for lib in (torch, torchvision, cv2, np):
    print(lib.__version__)

1.6.0+cu101
0.7.0+cu101
4.3.0
1.19.1


In [2]:
# Directories holding the DUTS-TR training images and their corresponding masks.
# Both paths are relative to the notebook's working directory and end with "/".
path_image = "./DUTS/DUTS-TR/DUTS-TR-Image/"
path_mask = "./DUTS/DUTS-TR/DUTS-TR-Mask/"

In [8]:
# Quick sanity check: read one sample image with OpenCV and display its array shape.
img = cv2.imread("./image/1.jpg")
img.shape

(400, 388, 3)

In [32]:
class LoadData(data.Dataset):
    """Dataset of (image, mask) pairs read from two parallel directories.

    Every file found in ``img_path`` is paired with the file in ``mask_path``
    that has the same base name and a ``.png`` extension.

    Args:
        img_path: Directory containing the input images.
        mask_path: Directory containing the masks.
        target_size: ``(width, height)`` tuple handed to ``cv2.resize`` as
            ``dsize``; every sample is resized to this size.
    """

    def __init__(self, img_path, mask_path, target_size):
        self.img_path = img_path
        self.mask_path = mask_path
        self.target_size = target_size

        # List the directory once and sort it: os.listdir returns entries in
        # arbitrary order, and sorting keeps index -> sample stable across runs.
        self.image = sorted(os.listdir(img_path))
        # splitext copes with extensions of any length (".jpg", ".jpeg", ...),
        # unlike slicing off a fixed three characters.
        self.mask = [os.path.splitext(name)[0] + ".png" for name in self.image]

    def __getitem__(self, index):
        """Return the ``index``-th sample as ``(image, mask)`` float32 arrays.

        Returns:
            image: Array of shape ``(3, H, W)`` in RGB channel order, with the
                pixel values as read from disk (not rescaled).
            mask: Array of shape ``(1, H, W)`` with values divided by 255.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or mask file.
        """
        image_file = os.path.join(self.img_path, self.image[index])
        mask_file = os.path.join(self.mask_path, self.mask[index])

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(image_file)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {image_file}")
        # OpenCV loads images as BGR; swap to RGB channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_file)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_file}")
        # Keep a single channel when the mask was decoded with several.
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        mask = mask / 255.

        image = cv2.resize(image, self.target_size, interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour interpolation avoids inventing in-between mask values.
        mask = cv2.resize(mask, self.target_size, interpolation=cv2.INTER_NEAREST)

        # HWC -> CHW for the image.
        image = np.array(image, dtype=np.float32).transpose((2, 0, 1))
        # Prepend the channel axis instead of reshaping with target_size:
        # cv2.resize takes (width, height) but returns (height, width), so a
        # reshape to (1, target_size[0], target_size[1]) scrambles non-square masks.
        mask = np.array(mask, dtype=np.float32)[np.newaxis, :, :]
        return image, mask

    def __len__(self):
        """Return the number of image files found in ``img_path``."""
        return len(self.image)

In [33]:
# Batch the DUTS-TR samples (resized to 256x256) three at a time, in directory order.
data_loader = torch.utils.data.DataLoader(
    dataset=LoadData(path_image, path_mask, (256, 256)),
    batch_size=3,
    shuffle=False,
)

In [34]:
# Pull a single batch to sanity-check the tensor shapes, then stop iterating.
for img, msk in data_loader:
    print(img.shape, msk.shape, sep="\n")
    break

torch.Size([3, 3, 256, 256])
torch.Size([3, 1, 256, 256])


In [35]:
from model import VGG16

In [36]:
# Build the network from the local `model` module with its default constructor arguments.
vgg = VGG16()

In [26]:
# Read the saved parameter dictionary from disk (presumably VGG16 backbone weights
# without the classifier head, judging by the file name — confirm).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load("./model/vgg16_no_top.pth")

In [29]:
# Inspect which parameter names the checkpoint provides.
state_dict.keys()

odict_keys(['conv1_1.weight', 'conv1_1.bias', 'conv1_2.weight', 'conv1_2.bias', 'conv2_1.weight', 'conv2_1.bias', 'conv2_2.weight', 'conv2_2.bias', 'conv3_1.weight', 'conv3_1.bias', 'conv3_2.weight', 'conv3_2.bias', 'conv3_3.weight', 'conv3_3.bias', 'conv4_1.weight', 'conv4_1.bias', 'conv4_2.weight', 'conv4_2.bias', 'conv4_3.weight', 'conv4_3.bias', 'conv5_1.weight', 'conv5_1.bias', 'conv5_2.weight', 'conv5_2.bias', 'conv5_3.weight', 'conv5_3.bias'])

In [28]:
# strict=False loads the entries whose names match the model and reports the
# mismatches in the returned value instead of raising on missing keys.
vgg.load_state_dict(state_dict, strict=False)

_IncompatibleKeys(missing_keys=['mfe1.conv11.weight', 'mfe1.conv11.bias', 'mfe1.conv33.weight', 'mfe1.conv33.bias', 'mfe1.conv55.weight', 'mfe1.conv55.bias', 'mfe2.conv11.weight', 'mfe2.conv11.bias', 'mfe2.conv33.weight', 'mfe2.conv33.bias', 'mfe2.conv55.weight', 'mfe2.conv55.bias', 'mfe3.conv11.weight', 'mfe3.conv11.bias', 'mfe3.conv33.weight', 'mfe3.conv33.bias', 'mfe3.conv55.weight', 'mfe3.conv55.bias', 'mfe4.conv11.weight', 'mfe4.conv11.bias', 'mfe4.conv33.weight', 'mfe4.conv33.bias', 'mfe4.conv55.weight', 'mfe4.conv55.bias', 'mfe5.conv11.weight', 'mfe5.conv11.bias', 'mfe5.conv33.weight', 'mfe5.conv33.bias', 'mfe5.conv55.weight', 'mfe5.conv55.bias', 'mfe6.conv11.weight', 'mfe6.conv11.bias', 'mfe6.conv33.weight', 'mfe6.conv33.bias', 'mfe6.conv55.weight', 'mfe6.conv55.bias', 'mfe7.conv11.weight', 'mfe7.conv11.bias', 'mfe7.conv33.weight', 'mfe7.conv33.bias', 'mfe7.conv55.weight', 'mfe7.conv55.bias', 'mfe8.conv11.weight', 'mfe8.conv11.bias', 'mfe8.conv33.weight', 'mfe8.conv33.bias', 'm