In [25]:
import torch
import torch.nn as nn
from torchsummary import summary
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import torchvision.models as models
import os
import scipy.io as spio

CUDA = torch.cuda.is_available()

device = torch.device("cuda" if CUDA else 'cpu')

In [32]:
model_vgg = models.vgg16(pretrained=True)

print(model_vgg)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [14]:
class FCN16(nn.Module):
    def __init__(self,model,hidden =64,kernel_size=3,padding = 1, num_classes = 21):
        super(FCN16,self).__init__()
        
        self.block1 = nn.Sequential(*list(model.features)[:5])
        self.block2 = nn.Sequential(*list(model.features)[5:10])
        self.block3 = nn.Sequential(*list(model.features)[10:17])
        self.block4 = nn.Sequential(*list(model.features)[17:24])
        self.block5 = nn.Sequential(*list(model.features)[24:31])
        
        model.classifier[0] = nn.Conv2d(512,4096,7)
        model.classifier[3] = nn.Conv2d(4096,4096,1)
        model.classifier[6] = nn.Conv2d(4096,num_classes,1)
        
        self.fc6 = nn.Sequential(*list(model.classifier)[0:3])
        self.fc7 = nn.Sequential(*list(model.classifier)[3:6])
        self.block_score = model.classifier[6]
        
        self.score_pool4 = nn.Conv2d(8*hidden,num_classes, kernel_size=1)
        
        self.upscore2 = nn.ConvTranspose2d(num_classes,num_classes,4,stride=2,bias=False)
        self.upscore16 = nn.ConvTranspose2d(num_classes,num_classes,32,stride = 16, bias=False)
        
    def forward(self,x):
        pred1 = self.block1(x)
        pred2 = self.block2(pred1)
        pred3 = self.block3(pred2)
        pred4 = self.block4(pred3)
        pred5 = self.block5(pred4)
        
        pred6 = self.fc6(pred5)
        pred7 = self.fc7(pred6)
        score = self.block_score(pred7)
        upscore2 = self.upscore2(score)
        
        pred4_1 = self.score_pool4(pred4)
        
        upscore16 = self.upscore16(upscore2+pred4_1)
        return upscore16

In [None]:
model = models.vgg16(pretrained=True)

fcn_model = FCN16(model)

print(fcn_model)

In [26]:
class CenterCrop():
    """
    Crops the center of the image and its dense labels in a sample.
    Note that PIL Image instances are casted to numpy ndarrays in this step.
    
    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
    
    def __call__(self, sample):
        img, target = sample
        # convert PIL Image to numpy.ndarray
        img = np.array(img)
        target = np.array(target)
        new_h, new_w = self.output_size
        
        # zero-pad if the width or height is less than the output_size
        if img.shape[0] < new_h:
            # zero-pad vertically
            pad_width = new_h - img.shape[0]
            up_pad = pad_width // 2
            bottom_pad = pad_width - up_pad
            img = np.pad(img, ((up_pad, bottom_pad), (0, 0), (0, 0)), 'constant', constant_values=(0, 0))
            target = np.pad(target, ((up_pad, bottom_pad), (0, 0)), 'constant', constant_values=(0, 0))
            
        if img.shape[1] < new_w:
            # zero-pad horizontally
            pad_width = new_w - img.shape[1]
            left_pad = pad_width // 2
            right_pad = pad_width - left_pad
            img = np.pad(img, ((0, 0), (left_pad, right_pad), (0, 0)), 'constant', constant_values=(0, 0))
            target = np.pad(target, ((0, 0), (left_pad, right_pad)), 'constant', constant_values=(0, 0))

            
        h, w = img.shape[:2]
        
        top = (h - new_h) // 2
        left = (w - new_w) // 2
        
        img = img[top: top + new_h, left: left + new_w]
        target = target[top: top + new_h, left: left + new_w]
        assert img.shape[:2] == target.shape[:2]
        
        return img, target

In [27]:
def inference_dense_label(model, base_path, img_name, center_crop=(224, 224)):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 20))
    voc_colors = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

    voc_classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
                   'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
                   'diningtable', 'dog', 'horse', 'motorbike', 'person',
                   'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']
    
    img_path = os.path.join(base_path + 'img', img_name + '.jpg')
    gt_path = os.path.join(base_path + 'cls', img_name + '.mat')
    ind2color = {}
    legend_elements = []
    for ind, color in enumerate(voc_colors):
        ind2color[ind] = (color, voc_classes[ind])
        legend_elements.append(Line2D([0], [0], color=np.array(color)/255, lw=6))
    
    img = Image.open(img_path).convert('RGB')
    img = np.array(img)
    
    gt_mat = spio.loadmat(gt_path, mat_dtype=True, squeeze_me=True, struct_as_record=False)
    gt_mat = gt_mat['GTcls'].Segmentation
    
    
    if center_crop:
        crop = CenterCrop(center_crop)
        img, gt_mat = crop((img, gt_mat))
        
    ax1.set_title('original image')
    ax1.imshow(img)
    
    img = np.transpose(img, [2, 0, 1])
    img = torch.from_numpy(img).to(torch.float32)
    img = transforms.functional.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    img = torch.unsqueeze(img, dim=0)
    model.eval()
    with torch.no_grad():
        output = model(img)
        pred = torch.argmax(output, dim=1)
    
    pred = torch.squeeze(pred, dim=0)
    
    
    pred = pred.numpy()
    height, width = pred.shape
    colored_pred = np.zeros((height, width, 3), dtype=np.int32)
    for h in range(height):
        for w in range(width):
            colored_pred[h, w] = ind2color[pred[h, w]][0]
    
    ax2.set_title('predicted feature map')
    ax2.imshow(colored_pred)
    
    height, width = gt_mat.shape
    colored_gt_mat = np.zeros((height, width, 3), dtype=np.uint8)
    for h in range(height):
        for w in range(width):
            colored_gt_mat[h, w] = ind2color[gt_mat[h, w]][0]
            
    ax3.set_title('ground true dense label')
    ax3.legend(handles=legend_elements, labels=voc_classes, loc='upper center', bbox_to_anchor=(1.5, 1.2))
    ax3.imshow(colored_gt_mat)

In [28]:
def visualize_inference(model, base_path, img_name, center_crop=(512, 512)):
    """
    Plot the original image, predicted dense labels and the ground-true labels.
    Parameters:
    - model: PyTorch model
    - base_path: path to the augmented Pascal VOC dataset
    - img_name: image file name without format extension
                e.g. 2008_000073 is the img_path for 2008_000073.jpg and 2008_000073.mat
    - center_crop: whether crop to certain height, width for better visualization
    """
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 20))
    voc_colors = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

    voc_classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
                   'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
                   'diningtable', 'dog', 'horse', 'motorbike', 'person',
                   'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']
    
    img_path = os.path.join(base_path + 'img', img_name + '.jpg')
    gt_path = os.path.join(base_path + 'cls', img_name + '.mat')
    ind2color = {}
    legend_elements = []
    for ind, color in enumerate(voc_colors):
        ind2color[ind] = (color, voc_classes[ind])
        legend_elements.append(Line2D([0], [0], color=np.array(color)/255, lw=6))
        
    img = Image.open(img_path).convert('RGB')
    img = np.array(img)
    gt_mat = spio.loadmat(gt_path, mat_dtype=True, squeeze_me=True, struct_as_record=False)
    gt_mat = gt_mat['GTcls'].Segmentation
    
    
    if center_crop:
        crop = CenterCrop(center_crop)
        img, gt_mat = crop((img, gt_mat))
    
    ax1.set_title('original image')
    ax1.imshow(img)
    
    img = np.transpose(img, [2, 0, 1])
    img = torch.from_numpy(img).to(torch.float32)
    img = torch.div(img, 255)
    img = transforms.functional.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # make img a 4-dimension tensor NCHW, N == 1
    img = torch.unsqueeze(img, dim=0)

    model.eval()
    with torch.no_grad():
        output = model(img)
        pred = torch.argmax(output, dim=1)
    
    # squeeze the dense label back to a 3-dimension tensor, CHW
    pred = torch.squeeze(pred, dim=0)
    
    pred = pred.numpy()
    height, width = pred.shape
    colored_pred = np.zeros((height, width, 3), dtype=np.uint8)
    for h in range(height):
        for w in range(width):
            colored_pred[h, w] = ind2color[pred[h, w]][0]
    
    ax2.set_title('predicted feature map')
    ax2.imshow(colored_pred)
    
    height, width = gt_mat.shape
    colored_gt_mat = np.zeros((height, width, 3), dtype=np.uint8)
    for h in range(height):
        for w in range(width):
            colored_gt_mat[h, w] = ind2color[gt_mat[h, w]][0]
            
    ax3.set_title('ground true dense label')
    ax3.legend(handles=legend_elements, labels=voc_classes, loc='upper center', bbox_to_anchor=(1.5, 1.2))
    ax3.imshow(colored_gt_mat)

In [31]:
model_state_dict_path = 'C:/Users/준영/.cache/torch/hub/checkpoints/vgg16-397923af.pth'
model = models.vgg16(pretrained=False)
device_gpu = torch.device("cuda")
fcn_model = FCN16(model)
fcn_model.load_state_dict(torch.load(model_state_dict_path, map_location = device))
base_path = 'D:\VOCtrainval_25-May-2011\TrainVal\VOCdevkit\VOC2011\JPEGImages/'
inference_dense_label(fcn_model, base_path,'2007_000042')

RuntimeError: Error(s) in loading state_dict for FCN16:
	Missing key(s) in state_dict: "block1.0.weight", "block1.0.bias", "block1.2.weight", "block1.2.bias", "block2.0.weight", "block2.0.bias", "block2.2.weight", "block2.2.bias", "block3.0.weight", "block3.0.bias", "block3.2.weight", "block3.2.bias", "block3.4.weight", "block3.4.bias", "block4.0.weight", "block4.0.bias", "block4.2.weight", "block4.2.bias", "block4.4.weight", "block4.4.bias", "block5.0.weight", "block5.0.bias", "block5.2.weight", "block5.2.bias", "block5.4.weight", "block5.4.bias", "fc6.0.weight", "fc6.0.bias", "fc7.0.weight", "fc7.0.bias", "block_score.weight", "block_score.bias", "score_pool4.weight", "score_pool4.bias", "upscore2.weight", "upscore16.weight". 
	Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias", "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias". 