# Predict

## Import Libraries

In [1]:
# torch
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import Dataset, DataLoader, TensorDataset

from tqdm import trange
import numpy as np
import os
from matplotlib import pyplot as plt
import yaml

# import custom modules
from UNet import UNet, Nested_UNet
from utils import load_zipped_pickle, save_zipped_pickle

# Select the compute device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Explicit CPU handle, used to move predictions back to host memory before converting to NumPy.
cpu_device = torch.device("cpu")
# Directory (relative to the working directory) from which the test data files are read.
DATA_DIR = "Data"

## Load Data

In [2]:
# Load the raw test set and its "crop" variant (test_crop.pkl) through the project helper.
test_data = load_zipped_pickle(os.path.join(DATA_DIR, "test.pkl"))
processed_test_data = load_zipped_pickle(os.path.join(DATA_DIR, "test_crop.pkl"))

# Load the per-video target boxes. The path goes through DATA_DIR (instead of a second
# hard-coded "./Data/") so that relocating the data only requires changing one constant.
# NOTE(review): yaml.FullLoader should only be used on trusted files; yaml.safe_load would
# be enough if box.yaml holds only plain lists/numbers -- confirm before switching.
with open(os.path.join(DATA_DIR, "box.yaml"), 'r') as f:
    box = yaml.load(f, yaml.FullLoader)

## Data Preprocessing

In [3]:
# Sampling and reconstruction helpers used when running the prediction models
def pad(img: np.ndarray, pad_size: tuple) -> np.ndarray:
    """
    Zero-pad a 2-D image on all four sides.

    img: 2-D array indexed as (height, width).
    pad_size: (pad_top, pad_bottom, pad_left, pad_right), in pixels.

    Returns a new array with the same dtype as `img`, with `img` embedded in zeros.
    """
    top, bottom = pad_size[0], pad_size[1]
    left, right = pad_size[2], pad_size[3]
    height, width = img.shape[0], img.shape[1]

    # Allocate the enlarged zero canvas, then copy the image into its interior.
    canvas = np.zeros((height + top + bottom, width + left + right), dtype=img.dtype)
    canvas[top:top + height, left:left + width] = img
    return canvas


def pad_to_window_size(img: np.ndarray, window_size: tuple = (112, 112)) -> (np.ndarray, tuple):
    """
    Zero-pad an image so that each side becomes a multiple of the window size.

    img: 2-D array indexed as (height, width).
    window_size: (window_height, window_width).

    Returns a pair (padded_img, pad_size) where pad_size is
    (pad_top, pad_bottom, pad_left, pad_right). The padding is split as evenly as
    possible between the two sides; when it is odd the extra pixel goes to the
    bottom/right.

    Note: the target size is (side // window + 1) * window, so a side that is already an
    exact multiple of the window still receives one full extra window of padding.
    """
    height, width = img.shape[0], img.shape[1]
    win_h, win_w = window_size[0], window_size[1]

    pad_height = (height // win_h + 1) * win_h - height
    pad_width = (width // win_w + 1) * win_w - width

    pad_top = pad_height // 2
    pad_left = pad_width // 2
    pad_size = (pad_top, pad_height - pad_top, pad_left, pad_width - pad_left)
    return pad(img, pad_size), pad_size


def sample_by_tailor(img: np.ndarray, window_size: tuple = (112, 112)) -> ([np.ndarray], tuple, tuple):
    """
    Cut an image into non-overlapping tiles of `window_size`.

    The image is first padded with pad_to_window_size so that it divides evenly.

    img: 2-D array indexed as (height, width).
    window_size: (window_height, window_width).

    Returns a triple (samples, pad_size, grid_shape):
      samples: list of tiles (views into the padded image) in row-major order.
      pad_size: (pad_top, pad_bottom, pad_left, pad_right) that was applied.
      grid_shape: (rows, cols) -- number of tiles along each axis.
    """
    pad_img, pad_size = pad_to_window_size(img, window_size)
    win_h, win_w = window_size[0], window_size[1]

    # Number of whole tiles that fit along each axis of the padded image.
    n_rows = pad_img.shape[0] // win_h
    n_cols = pad_img.shape[1] // win_w

    samples = [
        pad_img[r * win_h:(r + 1) * win_h, c * win_w:(c + 1) * win_w]
        for r in range(n_rows)
        for c in range(n_cols)
    ]
    return samples, pad_size, (n_rows, n_cols)


def traverse_by_tailor(imgs: [np.ndarray], pad_size: tuple, original_shape: tuple) -> np.ndarray:
    """
    Stitch non-overlapping tiles back into one image and strip the padding.

    imgs: tiles in row-major order.
    pad_size: (pad_top, pad_bottom, pad_left, pad_right) to remove from the stitched image.
    original_shape: (rows, cols) of the tile grid -- the number of tiles per axis, not a
        pixel shape.

    Returns the stitched image with the padding removed.
    """
    n_rows, n_cols = original_shape[0], original_shape[1]

    # Join each run of n_cols tiles side by side, then stack those strips vertically.
    strips = [
        np.concatenate(imgs[r * n_cols:(r + 1) * n_cols], axis=1)
        for r in range(n_rows)
    ]
    stitched = np.concatenate(strips, axis=0)

    row_stop = stitched.shape[0] - pad_size[1]
    col_stop = stitched.shape[1] - pad_size[3]
    return stitched[pad_size[0]:row_stop, pad_size[2]:col_stop]


def sample_by_window(img: np.ndarray, window_size: tuple = (112, 112), stride=16) -> ([np.ndarray], tuple, tuple):
    """
    Cut an image into overlapping tiles with a sliding window.

    The image is first padded with pad_to_window_size; the window then moves over the
    padded image in steps of `stride` pixels along both axes.

    img: 2-D array indexed as (height, width).
    window_size: (window_height, window_width).
    stride: step in pixels between neighbouring window positions.

    Returns a triple (samples, pad_size, grid_shape):
      samples: list of tiles (views into the padded image) in row-major window order.
      pad_size: (pad_top, pad_bottom, pad_left, pad_right) that was applied.
      grid_shape: (rows, cols) -- number of window positions along each axis.

    Note: traverse_by_window rebuilds a canvas of size window + stride * (count - 1), so
    (padded size - window size) should be a multiple of `stride` for the round trip to
    line up with the padded image.
    """
    pad_img, pad_size = pad_to_window_size(img, window_size)
    win_h, win_w = window_size[0], window_size[1]

    # Largest valid top-left offsets of a window inside the padded image.
    max_top = pad_img.shape[0] - win_h
    max_left = pad_img.shape[1] - win_w

    samples = [
        pad_img[top:top + win_h, left:left + win_w]
        for top in range(0, max_top + 1, stride)
        for left in range(0, max_left + 1, stride)
    ]
    return samples, pad_size, (max_top // stride + 1, max_left // stride + 1)


def traverse_by_window(imgs: [np.ndarray], pad_size: tuple, original_shape: tuple, stride=16) -> np.ndarray:
    """
    Rebuild a full image from overlapping sliding-window tiles by averaging the overlaps.

    imgs: tiles in row-major window order (the order produced by sample_by_window); every
        tile is expected to have the shape of imgs[0].
    pad_size: (pad_top, pad_bottom, pad_left, pad_right) to strip from the result.
    original_shape: (rows, cols) of the window grid -- the number of window positions per
        axis, not a pixel shape.
    stride: step in pixels between neighbouring windows; must match the value used when
        the tiles were sampled.

    Returns a float64 array: the per-pixel mean of all tiles covering each pixel, with the
    padding removed.
    """
    tile_h, tile_w = imgs[0].shape[0], imgs[0].shape[1]
    grid_rows, grid_cols = original_shape[0], original_shape[1]

    # Canvas that exactly spans every window position.
    accumulated = np.zeros((tile_h + stride * (grid_rows - 1), tile_w + stride * (grid_cols - 1)))
    # How many tiles contributed to each pixel of the canvas.
    coverage = np.zeros_like(accumulated)

    # One pass: add every tile at its window position and count the coverage alongside,
    # instead of replaying the same double loop a second time with an all-ones tile.
    for r in range(grid_rows):
        for c in range(grid_cols):
            rows = slice(stride * r, stride * r + tile_h)
            cols = slice(stride * c, stride * c + tile_w)
            accumulated[rows, cols] += imgs[r * grid_cols + c]
            coverage[rows, cols] += 1

    # Turn the sums into means, then cut away the padding added before sampling.
    accumulated /= coverage
    return accumulated[pad_size[0]:accumulated.shape[0] - pad_size[1], pad_size[2]:accumulated.shape[1] - pad_size[3]]

In [4]:
def predict_by_torch(x: torch.Tensor, model: nn.Module, batch_size: int = 64) -> torch.Tensor:
    """
    Run `model` over `x` in mini-batches and return sigmoid probabilities.

    x: input tensor whose first dimension indexes the samples.
    model: network producing logits; it is switched to eval mode and left in that mode.
    batch_size: number of samples per forward pass (defaults to the previous fixed 64).

    Returns the concatenated sigmoid outputs in input order. The result stays on the
    module-level `device`; the caller is responsible for moving it to the CPU.
    """
    # shuffle=False keeps the predictions aligned with the order of the inputs.
    loader = DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=False)
    model.eval()

    batch_preds = []
    with torch.no_grad():  # inference only, no autograd graph needed
        for (x_batch,) in loader:
            logits = model(x_batch.to(device))
            # The model emits raw logits; torch.sigmoid maps them to probabilities
            # (and avoids the deprecation warning F.sigmoid raises on older PyTorch).
            batch_preds.append(torch.sigmoid(logits))
    return torch.cat(batch_preds, dim=0)


def predict(x: np.ndarray, model: nn.Module):
    """
    Predict probabilities for a stack of single-channel samples given as a NumPy array.

    x: array of shape (sample_num, height, width).
    model: network passed through to predict_by_torch.

    Returns the output of predict_by_torch with its channel axis (dim 1) squeezed out,
    moved to the CPU and converted to a NumPy array.
    """
    # (sample_num, height, width) -> (sample_num, 1, height, width): add the channel axis.
    batch = torch.from_numpy(x).float()
    batch = batch.unsqueeze(1)

    # Drop the channel axis again so the output layout mirrors the input layout.
    probs = predict_by_torch(batch, model)
    probs = probs.squeeze(1)

    return probs.to(cpu_device).detach().numpy()

In [5]:
# Sliding-window prediction with the bilinear-upsampling UNet.

# Load the trained weights. map_location lets a checkpoint that was saved on a GPU be
# restored when this notebook has fallen back to the CPU.
model = UNet(1, 1, bilinear=True).to(device)
model.load_state_dict(torch.load("./Model/UNet_4_4_bilinear.pt", map_location=device))

# Create the output directory up front so the long prediction loop cannot fail at the final save.
os.makedirs("Prediction", exist_ok=True)

# Sampling and reconstruction must use the same window and stride, so define them once.
window_size = (112, 112)
stride = 16

pred_data = []
with trange(len(test_data), desc="Predicting: ") as t:
    postfix = dict()
    for video_i in t:
        video = test_data[video_i]
        video_name = video["name"]
        video_imgs = video["video"].transpose(2, 0, 1)  # reorder to (frames, height, width)
        video_pred = []
        for img_i, img in enumerate(video_imgs):
            # Cut the frame into overlapping windows, predict each one, then blend them back.
            img_samples, pad_size, sample_shape = sample_by_window(img, window_size=window_size, stride=stride)
            img_samples = np.array(img_samples)
            img_samples_pred = predict(img_samples, model)
            img_pred = traverse_by_window(img_samples_pred, pad_size, sample_shape, stride=stride)
            video_pred.append(img_pred)
            postfix["img_i"] = img_i
            t.set_postfix(postfix)
        video_pred = np.array(video_pred).transpose(1, 2, 0)  # back to (height, width, frames)
        pred_data.append({"name": video_name, "prediction": video_pred})
# The saved predictions are floating-point probabilities; they still need to be thresholded to 0/1.
save_zipped_pickle(pred_data, os.path.join("Prediction", "UNet_4_4_bilinear.pkl"))

Predicting: 100%|██████████| 20/20 [32:58<00:00, 98.94s/it, img_i=61]  


In [12]:
# Sliding-window prediction with the transposed-convolution UNet.

# Load the trained weights. map_location lets a checkpoint that was saved on a GPU be
# restored when this notebook has fallen back to the CPU.
model = UNet(1, 1, bilinear=False).to(device)
model.load_state_dict(torch.load("./Model/UNet_4_4_transpose_conv.pt", map_location=device))

# Create the output directory up front so the long prediction loop cannot fail at the final save.
os.makedirs("Prediction", exist_ok=True)

# Sampling and reconstruction must use the same window and stride, so define them once.
window_size = (112, 112)
stride = 16

pred_data = []
with trange(len(test_data), desc="Predicting: ") as t:
    postfix = dict()
    for video_i in t:
        video = test_data[video_i]
        video_name = video["name"]
        video_imgs = video["video"].transpose(2, 0, 1)  # reorder to (frames, height, width)
        video_pred = []
        for img_i, img in enumerate(video_imgs):
            # Cut the frame into overlapping windows, predict each one, then blend them back.
            img_samples, pad_size, sample_shape = sample_by_window(img, window_size=window_size, stride=stride)
            img_samples = np.array(img_samples)
            img_samples_pred = predict(img_samples, model)
            img_pred = traverse_by_window(img_samples_pred, pad_size, sample_shape, stride=stride)
            video_pred.append(img_pred)
            postfix["img_i"] = img_i
            t.set_postfix(postfix)
        video_pred = np.array(video_pred).transpose(1, 2, 0)  # back to (height, width, frames)
        pred_data.append({"name": video_name, "prediction": video_pred})
# The saved predictions are floating-point probabilities; they still need to be thresholded to 0/1.
save_zipped_pickle(pred_data, os.path.join("Prediction", "UNet_4_4_transpose_conv.pkl"))

Predicting: 100%|██████████| 20/20 [26:46<00:00, 80.33s/it, img_i=61] 


In [5]:
# Sliding-window prediction with the nested UNet (UNet++).

# Load the trained weights. map_location lets a checkpoint that was saved on a GPU be
# restored when this notebook has fallen back to the CPU.
model = Nested_UNet(1, 1).to(device)
model.load_state_dict(torch.load("./Model/NestedUNet_4_4.pt", map_location=device))

# Create the output directory up front so the long prediction loop cannot fail at the final save.
os.makedirs("Prediction", exist_ok=True)

# Sampling and reconstruction must use the same window and stride, so define them once.
window_size = (112, 112)
stride = 16

pred_data = []
with trange(len(test_data), desc="Predicting: ") as t:
    postfix = dict()
    for video_i in t:
        video = test_data[video_i]
        video_name = video["name"]
        video_imgs = video["video"].transpose(2, 0, 1)  # reorder to (frames, height, width)
        video_pred = []
        for img_i, img in enumerate(video_imgs):
            # Cut the frame into overlapping windows, predict each one, then blend them back.
            img_samples, pad_size, sample_shape = sample_by_window(img, window_size=window_size, stride=stride)
            img_samples = np.array(img_samples)
            img_samples_pred = predict(img_samples, model)
            img_pred = traverse_by_window(img_samples_pred, pad_size, sample_shape, stride=stride)
            video_pred.append(img_pred)
            postfix["img_i"] = img_i
            t.set_postfix(postfix)
        video_pred = np.array(video_pred).transpose(1, 2, 0)  # back to (height, width, frames)
        pred_data.append({"name": video_name, "prediction": video_pred})
# The saved predictions are floating-point probabilities; they still need to be thresholded to 0/1.
save_zipped_pickle(pred_data, os.path.join("Prediction", "Nested_UNet_4_4.pkl"))

Predicting: 100%|██████████| 20/20 [1:40:39<00:00, 301.97s/it, img_i=61]  


In [5]:
# Sliding-window prediction with the UNet trained on the processed (cropped) data.

# Load the trained weights. map_location lets a checkpoint that was saved on a GPU be
# restored when this notebook has fallen back to the CPU.
model = UNet(1, 1, bilinear=True).to(device)
model.load_state_dict(torch.load("./Model/UNet_4_4_bilinear_processed.pt", map_location=device))

# Create the output directory up front so the long prediction loop cannot fail at the final save.
os.makedirs("Prediction", exist_ok=True)

# Sampling and reconstruction must use the same window and stride, so define them once.
window_size = (256, 256)
stride = 32

pred_data = []
with trange(len(test_data), desc="Predicting: ") as t:
    postfix = dict()
    for video_i in t:
        original_video = test_data[video_i]
        video_name = original_video["name"]
        original_video_imgs = original_video["video"].transpose(2, 0, 1)  # reorder to (frames, height, width)
        tailored_video = processed_test_data[video_i]
        tailored_video_imgs = tailored_video["images"].transpose(2, 0, 1)
        # Crop bounds of the processed frames inside the original frame. Entries 2:3 index
        # rows and 0:1 index columns below.
        # NOTE(review): presumably (x_min, x_max, y_min, y_max) -- confirm against the cropping code.
        tailored_video_shapes = tailored_video["shapes"]
        video_pred = []
        for img_i, (original_img, tailored_img) in enumerate(zip(original_video_imgs, tailored_video_imgs)):
            img_samples, pad_size, sample_shape = sample_by_window(tailored_img, window_size=window_size, stride=stride)
            img_samples = np.array(img_samples)
            img_samples_pred = predict(img_samples, model)
            img_pred = traverse_by_window(img_samples_pred, pad_size, sample_shape, stride=stride)
            # Paste the cropped prediction back into a full-size canvas. The canvas must be
            # all zeros (probability 0 outside the crop) -- an all-ones canvas would mark
            # everything outside the crop as foreground -- and must use the prediction's
            # float dtype so probabilities are not truncated to the frame's dtype.
            img_pred_new = np.zeros(original_img.shape, dtype=img_pred.dtype)
            img_pred_new[tailored_video_shapes[2]:tailored_video_shapes[3], tailored_video_shapes[0]:tailored_video_shapes[1]] = img_pred
            img_pred = img_pred_new
            video_pred.append(img_pred)
            postfix["img_i"] = img_i
            t.set_postfix(postfix)
        video_pred = np.array(video_pred).transpose(1, 2, 0)  # back to (height, width, frames)
        pred_data.append({"name": video_name, "prediction": video_pred})
# The saved predictions are floating-point probabilities; they still need to be thresholded to 0/1.
save_zipped_pickle(pred_data, os.path.join("Prediction", "UNet_Processed_Data.pkl"))

Predicting: 100%|██████████| 20/20 [19:34<00:00, 58.70s/it, img_i=61] 


In [6]:
# Sliding-window prediction with the high-resolution UNet (256x256 windows).

# Load the trained weights. map_location lets a checkpoint that was saved on a GPU be
# restored when this notebook has fallen back to the CPU.
model = UNet(1, 1, bilinear=True).to(device)
model.load_state_dict(torch.load("./Model/UNet_4_4_bilinear_high_resolution.pt", map_location=device))

# Create the output directory up front so the long prediction loop cannot fail at the final save.
os.makedirs("Prediction", exist_ok=True)

# Sampling and reconstruction must use the same window and stride, so define them once.
window_size = (256, 256)
stride = 32

pred_data = []
with trange(len(test_data), desc="Predicting: ") as t:
    postfix = dict()
    for video_i in t:
        video = test_data[video_i]
        video_name = video["name"]
        video_imgs = video["video"].transpose(2, 0, 1)  # reorder to (frames, height, width)
        video_pred = []
        for img_i, img in enumerate(video_imgs):
            # Cut the frame into overlapping windows, predict each one, then blend them back.
            img_samples, pad_size, sample_shape = sample_by_window(img, window_size=window_size, stride=stride)
            img_samples = np.array(img_samples)
            img_samples_pred = predict(img_samples, model)
            img_pred = traverse_by_window(img_samples_pred, pad_size, sample_shape, stride=stride)
            video_pred.append(img_pred)
            postfix["img_i"] = img_i
            t.set_postfix(postfix)
        video_pred = np.array(video_pred).transpose(1, 2, 0)  # back to (height, width, frames)
        pred_data.append({"name": video_name, "prediction": video_pred})
# The saved predictions are floating-point probabilities; they still need to be thresholded to 0/1.
save_zipped_pickle(pred_data, os.path.join("Prediction", "UNet_High_Resolution.pkl"))

Predicting: 100%|██████████| 20/20 [26:04<00:00, 78.25s/it, img_i=61]  


In [7]:
# Reload each model's saved predictions from disk, so post-processing can run without
# repeating the prediction cells.
prediction_dir = "Prediction"
pred_data_unet_bilinear = load_zipped_pickle(os.path.join(prediction_dir, "UNet_4_4_bilinear.pkl"))
pred_data_unet_conv = load_zipped_pickle(os.path.join(prediction_dir, "UNet_4_4_transpose_conv.pkl"))
pred_data_unet_nested = load_zipped_pickle(os.path.join(prediction_dir, "Nested_UNet_4_4.pkl"))
pred_data_unet_processed = load_zipped_pickle(os.path.join(prediction_dir, "UNet_Processed_Data.pkl"))
pred_data_unet_high_resolution = load_zipped_pickle(os.path.join(prediction_dir, "UNet_High_Resolution.pkl"))

In [8]:
# Post-process the ensemble: average the four full-frame models, threshold, and mask to the box.
# NOTE(review): the next post-processing cell rebuilds `transformed_pred` from the
# processed-data model, so whichever of the two cells runs last decides what gets saved.
prob_threshold = 0.5

transformed_pred = []
for video_bilinear, video_conv, video_nested, video_high_resolution, video_box_arg in zip(
    pred_data_unet_bilinear,
    pred_data_unet_conv,
    pred_data_unet_nested,
    pred_data_unet_high_resolution,  # was missing from the zip, leaving video_high_resolution undefined
    box,
):
    video_transformed = {"name": video_bilinear["name"]}

    # Mean probability over the four models.
    video_pred = (
        video_bilinear["prediction"]
        + video_conv["prediction"]
        + video_nested["prediction"]
        + video_high_resolution["prediction"]
    ) / 4

    # Keep only pixels inside the target box: entries 0:1 bound the rows, 2:3 the columns.
    box_mask = np.zeros(video_pred.shape, dtype=bool)
    box_mask[video_box_arg[0]:video_box_arg[1], video_box_arg[2]:video_box_arg[3], :] = True

    # Boolean mask: probability above the threshold AND inside the box.
    video_pred = (video_pred > prob_threshold) & box_mask

    video_transformed["prediction"] = video_pred
    transformed_pred.append(video_transformed)

In [12]:
# Post-process the processed-data model's predictions: threshold and mask to the target box.
prob_threshold = 0.5

transformed_pred = []
for video, video_box_arg in zip(pred_data_unet_processed, box):
    video_pred = video["prediction"]

    # True inside the target box (entries 0:1 bound the rows, 2:3 the columns), False elsewhere.
    inside_box = np.zeros(video_pred.shape, dtype=bool)
    inside_box[video_box_arg[0]:video_box_arg[1], video_box_arg[2]:video_box_arg[3], :] = True

    # A pixel is foreground only if it exceeds the threshold and lies inside the box.
    video_pred = (video_pred > prob_threshold) & inside_box

    video_transformed = {"name": video["name"], "prediction": video_pred}
    transformed_pred.append(video_transformed)

In [13]:
# Save the post-processed predictions (`transformed_pred` as left by whichever
# post-processing cell ran last) as the final result file.
save_zipped_pickle(transformed_pred, "./Prediction/result.pkl")