## How to use a checkpoint

In [19]:
from unet import UNet
import torch
# Choose the compute device: CUDA if present, otherwise Apple's MPS backend, otherwise the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()  # Apple-silicon Macs
    else 'cpu'
)
print(device)

# Build the network and move its parameters onto the chosen device (Module.to returns the module itself).
# The constructor arguments have to match the checkpoint being loaded, otherwise load_state_dict
# later fails on missing/unexpected keys or mismatched shapes.
net = UNet(n_channels=3, n_classes=16, bilinear=False).to(device)

# Load the saved state dict; map_location places the stored tensors directly on `device`.
checkpoint_path = './checkpoints/checkpoint_epoch1.pth'
state_dict = torch.load(checkpoint_path, map_location=device)

# Show which entries the checkpoint contains and how many there are.
print(state_dict.keys())
print(len(state_dict))

mps
odict_keys(['inc.double_conv.0.weight', 'inc.double_conv.1.weight', 'inc.double_conv.1.bias', 'inc.double_conv.1.running_mean', 'inc.double_conv.1.running_var', 'inc.double_conv.1.num_batches_tracked', 'inc.double_conv.3.weight', 'inc.double_conv.4.weight', 'inc.double_conv.4.bias', 'inc.double_conv.4.running_mean', 'inc.double_conv.4.running_var', 'inc.double_conv.4.num_batches_tracked', 'down1.maxpool_conv.1.double_conv.0.weight', 'down1.maxpool_conv.1.double_conv.1.weight', 'down1.maxpool_conv.1.double_conv.1.bias', 'down1.maxpool_conv.1.double_conv.1.running_mean', 'down1.maxpool_conv.1.double_conv.1.running_var', 'down1.maxpool_conv.1.double_conv.1.num_batches_tracked', 'down1.maxpool_conv.1.double_conv.3.weight', 'down1.maxpool_conv.1.double_conv.4.weight', 'down1.maxpool_conv.1.double_conv.4.bias', 'down1.maxpool_conv.1.double_conv.4.running_mean', 'down1.maxpool_conv.1.double_conv.4.running_var', 'down1.maxpool_conv.1.double_conv.4.num_batches_tracked', 'down2.maxpool_conv.

In [20]:
# The checkpoint's state_dict carries an extra non-tensor entry, 'mask_values', next to the weights.
# It must be removed before net.load_state_dict(), which (strict by default) rejects unexpected keys.
# Keep the popped value rather than discarding it: predict_img returns argmax class indices, and this
# list presumably maps those indices back to the original mask pixel values (verify against the
# training/data-loading code). The [0, 1] fallback is only used when the key is absent.
# NOTE: re-running this cell without re-running the loading cell yields the fallback value.
mask_values = state_dict.pop('mask_values', [0, 1])
print(state_dict.keys())
print(len(state_dict.keys()))

odict_keys(['inc.double_conv.0.weight', 'inc.double_conv.1.weight', 'inc.double_conv.1.bias', 'inc.double_conv.1.running_mean', 'inc.double_conv.1.running_var', 'inc.double_conv.1.num_batches_tracked', 'inc.double_conv.3.weight', 'inc.double_conv.4.weight', 'inc.double_conv.4.bias', 'inc.double_conv.4.running_mean', 'inc.double_conv.4.running_var', 'inc.double_conv.4.num_batches_tracked', 'down1.maxpool_conv.1.double_conv.0.weight', 'down1.maxpool_conv.1.double_conv.1.weight', 'down1.maxpool_conv.1.double_conv.1.bias', 'down1.maxpool_conv.1.double_conv.1.running_mean', 'down1.maxpool_conv.1.double_conv.1.running_var', 'down1.maxpool_conv.1.double_conv.1.num_batches_tracked', 'down1.maxpool_conv.1.double_conv.3.weight', 'down1.maxpool_conv.1.double_conv.4.weight', 'down1.maxpool_conv.1.double_conv.4.bias', 'down1.maxpool_conv.1.double_conv.4.running_mean', 'down1.maxpool_conv.1.double_conv.4.running_var', 'down1.maxpool_conv.1.double_conv.4.num_batches_tracked', 'down2.maxpool_conv.1.do

In [21]:
from utils.data_loading import WSDataset
from unet import UNet
from utils.utils import plot_img_and_mask
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from torchvision import transforms

# The two preprocessing transforms used in utils.data_loading, re-implemented here as modules.
class ImageResize(torch.nn.Module):
    """Resize a PIL image and return the result as a NumPy array.

    Args:
        new_size: target size as a ``(width, height)`` pair, the order PIL's
            ``Image.resize`` expects. Values are truncated to ``int`` because PIL
            rejects non-integer sizes (e.g. ones produced by a fractional scale factor).
        interpolate_mode: PIL resampling filter. Use ``Image.NEAREST`` (the default)
            for label masks and ``Image.BICUBIC`` for ordinary images.

    Raises:
        ValueError: if either dimension of ``new_size`` is smaller than one pixel.
    """

    def __init__(self, new_size, interpolate_mode=Image.NEAREST):
        super().__init__()
        width, height = (int(dim) for dim in new_size)
        if width <= 0 or height <= 0:
            raise ValueError(f'ImageResize: target size must be positive, got {tuple(new_size)}')
        self.new_size = (width, height)
        self.interpolate_mode = interpolate_mode

    def forward(self, img):
        """Resize ``img`` (a PIL image) to ``self.new_size`` and return it via ``np.asarray``."""
        resized = img.resize(self.new_size, resample=self.interpolate_mode)
        return np.asarray(resized)


class ImageNormalization(torch.nn.Module):
    """Rescale raw pixel intensities (assumed 0-255, i.e. 8-bit) into the 0-1 range.

    Holds no parameters or buffers, so torch.nn.Module's constructor is used as-is.
    """

    def forward(self, img):
        """Return ``img / 255.0`` (works for NumPy arrays and tensors alike)."""
        scaled = img / 255.0
        return scaled

In [22]:
# Copy the checkpoint weights into the network. strict=True is PyTorch's default and is spelled
# out here: any missing or unexpected key raises instead of being silently ignored.
net.load_state_dict(state_dict, strict=True)
# Predict a segmentation mask using the loaded model parameters.
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Run ``net`` on one PIL image and return a per-pixel mask at the image's original size.

    The image is resized by ``scale_factor``, divided by 255, converted to a
    ``(1, C, H, W)`` float32 tensor on ``device`` and passed through ``net`` in eval
    mode without gradient tracking. The network output is bilinearly resized back to
    the original ``(height, width)`` before being turned into a mask.

    Args:
        net: segmentation model; it is switched to ``eval()`` mode as a side effect.
        full_img: input PIL image (``full_img.size`` is ``(width, height)``).
        device: device the input tensor is moved to; should match ``net``'s device.
        scale_factor: resize factor applied before inference; the scaled size is
            truncated to whole pixels.
        out_threshold: probability threshold, used only when the network produces a
            single output channel (binary segmentation).

    Returns:
        NumPy int64 array of shape ``(height, width)``. It holds class indices
        (``argmax`` over channels) for multi-channel output, or 0/1 values from
        ``sigmoid(output) > out_threshold`` for single-channel output.

    Raises:
        ValueError: if ``scale_factor`` shrinks the image below one pixel.
    """
    net.eval()
    orig_w, orig_h = full_img.size
    # PIL's resize only accepts integer sizes, so fractional scale factors are truncated.
    new_size = (int(orig_w * scale_factor), int(orig_h * scale_factor))
    if new_size[0] <= 0 or new_size[1] <= 0:
        raise ValueError(f'scale_factor={scale_factor} is too small for an image of size {full_img.size}')
    im_transform = transforms.Compose([  # chain the preprocessing transforms
        ImageResize(new_size, interpolate_mode=Image.BICUBIC),
        ImageNormalization(),
        # HWC ndarray -> CHW tensor; ToTensor only divides by 255 for uint8 input,
        # so the already-normalized float array is not rescaled again.
        transforms.ToTensor(),
    ])
    img = im_transform(full_img).unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img).cpu()
        # Back to the original resolution; F.interpolate takes the size as (height, width).
        output = F.interpolate(output, (orig_h, orig_w), mode='bilinear')
        if output.shape[1] > 1:
            # Multi-class: pick the highest-scoring channel per pixel.
            mask = output.argmax(dim=1)[0]
        else:
            # Single channel: argmax would always be 0, so threshold the probability instead.
            mask = torch.sigmoid(output[0, 0]) > out_threshold
    return mask.long().numpy()