In [18]:
from util import create_dir, replace_dir, Clock, compare_dir, split_parameters
from transform_img import flatten_onehot, Diff_size_collect, get_transform, norm_black_color
from loss import Soft_dice_loss, Focal_loss, SSIM, activate
from plot import plot_grad_flow, Progress_writer, onehot_gird, Loss_record, Acc_record, Loss_writer,LayerActivations
from dataset.dataset import Image_Dataset, Zip_dataset, get_data_files
from dataset.tarpath import Tar_path
import accuracy as acc
from torch.utils.data import DataLoader, Subset
from torchsummary import summary

from net.unet import U_Net
from net.nested_unet import NestedUNet
from net.regseg import RegSeg
from net.regseg_p import RegSeg_dp

import tarfile
from pathlib import Path
from itertools import islice
import copy
import time
import cv2
from matplotlib import pyplot as plt
import torch.nn.functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torch
import numpy as np
import matplotlib
from itertools import product

from dataset.lmdb_format import Lmdb_dataset
from dataset.supervisely_format import get_super_dataset
from transform_img import get_transform, onehot_seq_torch, eq_proportion_resize, color_means, color_stds,transform_label

# Check CUDA. GPU use is opt-in: set USE_CUDA = True to run on the GPU when one is present.
USE_CUDA = False
CPU_THREADS = 12  # intra-op thread count used when running on the CPU

train_on_gpu = USE_CUDA and torch.cuda.is_available()

if train_on_gpu:
    print('CUDA is available. Training on GPU')
else:
    # distinguish "switched off here" from "really unavailable" so the log is truthful
    if USE_CUDA:
        print('CUDA is not available. Training on CPU')
    else:
        print('CUDA disabled (USE_CUDA = False). Training on CPU')
    torch.set_num_threads(CPU_THREADS)

device = torch.device("cuda:0" if train_on_gpu else "cpu")


CUDA is not available. Training on CPU


In [19]:
# Size handed to eq_proportion_resize (exact semantics live in transform_img — TODO confirm which edge it bounds).
IMG_SIZE = 512

# Data transforms for input images:
#   1. resize with bicubic interpolation via eq_proportion_resize,
#   2. reorder OpenCV's BGR channel order to RGB,
#   3. convert to a CHW float tensor with ToTensor (assumes an HWC array from cv2 — TODO confirm),
#   4. normalize in place with the shared colour statistics.
tr_list = [
    torchvision.transforms.Lambda(
        lambda img:eq_proportion_resize(img, float(IMG_SIZE), cv2.INTER_CUBIC)),
    torchvision.transforms.Lambda(
        lambda img:cv2.cvtColor(img, cv2.COLOR_BGR2RGB)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=color_means,
        std=color_stds, inplace=True)
]
transform = torchvision.transforms.Compose(tr_list)


def get_collect_fn(channels):
    """Build a collate function for batches of differently sized samples.

    The returned callable forwards the samples to ``Diff_size_collect.collect_fn``.
    It passes the constant 4 and a ``black`` mapping. Sample element 0 (the image)
    gets ``norm_black_color``. Sample element 1 (the target) gets a
    ``(channels, 1, 1)`` tensor. That tensor is all zeros, except channel 0 is 1
    when ``channels > 1``. These are presumably the padding fill values; the
    meaning of the constant 4 is defined by Diff_size_collect (TODO confirm).

    Args:
        channels: number of channels of the target tensor.

    Returns:
        A callable usable as ``DataLoader(collate_fn=...)``.
    """
    target_fill = torch.zeros((channels, 1, 1))
    if channels > 1:
        # one-hot vector for class 0 (presumably background)
        target_fill[0] = 1

    def collate(imgs):
        fill_values = {
            0: torch.tensor(norm_black_color).reshape((3, 1, 1)),
            1: target_fill,
        }
        return Diff_size_collect.collect_fn(imgs, 4, black=fill_values)

    return collate


def get_data_loader_para(channels):
    """Keyword arguments shared by every evaluation DataLoader in this notebook.

    Batches hold one sample, are not shuffled and are loaded by one worker.
    Memory is pinned only when running on the GPU. Batches are collated with
    ``get_collect_fn(channels)``.

    Args:
        channels: number of target channels, forwarded to ``get_collect_fn``.

    Returns:
        A dict to unpack into ``DataLoader(dataset, **params)``.
    """
    params = dict(
        batch_size=1,
        shuffle=False,
        pin_memory=train_on_gpu,
        num_workers=1,
        collate_fn=get_collect_fn(channels),
    )
    return params


def test_dataset(dataset):
    for x, y in islice(dataset, 4):
        plt.figure()
        plt.subplot(211)
        plt.imshow(x.permute((1, 2, 0)))
        plt.subplot(212)
        plt.imshow(onehot_gird(y).permute((1, 2, 0)))
        plt.show()


# Load datasets
# Supervisely Person dataset. Inputs use `transform`. Targets are resized with
# nearest-neighbour interpolation (so label values are not blended), converted by
# ToTensor (which divides uint8 input by 255) and multiplied by 255 to restore the
# original label values.
# NOTE(review): the *255 result is a float tensor whose values are only nominally integral.
persons_labeled = get_super_dataset(
    "../data/graduate/Supervisely Person Dataset", transform,
    torchvision.transforms.Compose([
        torchvision.transforms.Lambda(
            lambda img:eq_proportion_resize(img, float(IMG_SIZE), cv2.INTER_NEAREST)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda img:img*255),
    ]))
# targets of this dataset have a single channel
persons_labeled_loader = DataLoader(
    persons_labeled, **get_data_loader_para(1)
)

# ATR
# ATR label values to drop (treated as background) before binarizing the label maps.
unuse_label = []


def delete_label(img: np.ndarray, labels: list[int]):
    """Binarize a label map in place after dropping unwanted labels.

    Pixels whose value appears in ``labels`` are set to 0. Every pixel that is
    still positive is then set to 1.

    Args:
        img: label map; modified in place.
        labels: label values to treat as background.

    Returns:
        The same ``img`` object.
    """
    dropped = np.isin(img, labels)
    img[dropped] = 0
    np.putmask(img, img > 0, 1)
    return img


# ATR human-parsing images paired with their label maps.
# Label maps are resized with nearest-neighbour interpolation. Labels in
# `unuse_label` are zeroed and every other positive label becomes 1. The map then
# goes through ToTensor (which divides uint8 input by 255) and *255, so the target
# is a single-channel 0/1 mask.
# The trailing cv2.IMREAD_GRAYSCALE is presumably the imread flag used by Image_Dataset — TODO confirm.
atr_img = Image_Dataset(get_data_files(
    "../data/graduate/LIP/ATR/humanparsing/JPEGImages"), transform)
atr_tr = Image_Dataset(get_data_files(
    "../data/graduate/LIP/ATR/humanparsing/SegmentationClassAug"),
    torchvision.transforms.Compose([
        torchvision.transforms.Lambda(
            lambda img:eq_proportion_resize(img, float(IMG_SIZE), cv2.INTER_NEAREST)),
        torchvision.transforms.Lambda(
            lambda img:delete_label(img, unuse_label)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda img:img*255),
    ]),
    cv2.IMREAD_GRAYSCALE)
atr_dataset = Zip_dataset(atr_img, atr_tr)
atr_loader = DataLoader(atr_dataset, **get_data_loader_para(1))

# LIP validation split stored in an LMDB database: "validation" holds the images,
# "validation_seg" the label maps.
# Images are only converted with ToTensor and normalized.
# NOTE(review): unlike `transform`, no resize or BGR->RGB conversion is applied here;
# confirm the LMDB stores already-preprocessed RGB images.
val_img_dataset = Lmdb_dataset(
    "../data/graduate/lip_c5_db", "validation",
    torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=color_means,
            std=color_stds, inplace=True)
    ])
)
# Label maps get a leading channel axis (float32). They are then expanded to
# 6 channels by onehot_seq_torch (presumably one-hot over 6 classes — TODO confirm).
val_target_img_dataset = Lmdb_dataset(
    "../data/graduate/lip_c5_db", "validation_seg",
    torchvision.transforms.Compose([
        torchvision.transforms.Lambda(
            lambda img:torch.tensor(np.expand_dims(img, 0), dtype=torch.float32)),
        torchvision.transforms.Lambda(
            lambda img: onehot_seq_torch(img, 6, torch.float32)),
    ])
)

val_dataset = Zip_dataset(val_img_dataset, val_target_img_dataset)
val_loader = DataLoader(val_dataset, **get_data_loader_para(6))


In [20]:

# 加载模型
from net.gate_t0 import RegSeg_gate0


class RegSeg_gate0_6(RegSeg_gate0):
    """RegSeg_gate0 whose output channels are regrouped into coarse labels.

    ``C12_TO_C5`` maps each coarse label (1-5) to the fine output channels of
    the parent network that are merged into it. The merging itself is done by
    ``transform_label``. Fine channel 0 is not listed; presumably it becomes the
    coarse background (TODO confirm in transform_label).
    """

    # coarse label -> fine output channels merged into it
    C12_TO_C5 = {
        1: [1, 2],
        2: [4, 12],
        3: [5, 6, 9, 10],
        4: [7, 8, 11],
        5: [3],
    }

    def forward(self, x: torch.Tensor):
        """Run the parent network, then regroup its channels with ``transform_label``."""
        fine_output = super().forward(x)
        return transform_label(fine_output, self.C12_TO_C5)


# Root directory of the trained checkpoints compared below.
MODEL_DIR = Path("../model/complete")


def _load_checkpoint(model: torch.nn.Module, checkpoint: str) -> torch.nn.Module:
    """Load the state_dict at ``MODEL_DIR / checkpoint`` into ``model``, move it to ``device`` and return it."""
    model.load_state_dict(torch.load(MODEL_DIR / checkpoint, map_location=device))
    model.to(device)
    return model


# Construction/loading order is kept identical to the original cell.
regseg_gate0 = _load_checkpoint(
    RegSeg_gate0_6(3, 13),
    "RegSeg_gatet0_3to13_e50_b8_s512/model/model_e45.pth")

regseg_align = _load_checkpoint(
    RegSeg_dp(3, 6, 0),
    "RegSeg_align_3to6_e50_b8_s512/model/model_e29.pth")

# NOTE(review): built as RegSeg(2, 6) although the checkpoint directory says "3to6";
# confirm what RegSeg's first argument means.
regseg = _load_checkpoint(
    RegSeg(2, 6),
    "RegSeg_3to6_e200_b8_s512/model/model_e196.pth")

regseg_gate = _load_checkpoint(
    RegSeg_dp(3, 6, 1),
    "RegSeg_gate_3to6_e50_b8_s512/model/model_e40.pth")


In [21]:
def test(model: torch.nn.Module, dataloader: DataLoader, calculate_acc):

    use_time = 0.

    model.eval()
    with torch.no_grad():
        for batch_index, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)

            t = time.time()
            prediction: torch.Tensor = model(x)
            use_time += time.time() - t

            prediction = activate(prediction)

            if prediction.size(1) != y.size(1):
                if prediction.size(1) > 1 and y.size(1) == 1:
                    prediction = torch.argmax(
                        prediction, dim=1, keepdim=True) > 0
                    prediction = prediction.to(torch.int32)
                    y = y.to(torch.int32)
                elif prediction.size(1) == 1:
                    prediction = prediction > 0.5
                    prediction = prediction.to(torch.int32)

                    if y.size(1) != 1:
                        y = torch.argmax(y, dim=1, keepdim=True) > 0
                        y = y.to(torch.int32)
            else:
                prediction = torch.argmax(prediction, dim=1, keepdim=True)
                y = torch.argmax(y, dim=1, keepdim=True)

            calculate_acc(prediction, y, batch_index)

    return use_time


def test_model(model: torch.nn.Module, data: DataLoader, class_):
    """Evaluate ``model`` on ``data`` and print timing and segmentation metrics.

    Args:
        model: network to evaluate.
        data: DataLoader whose ``dataset`` length is used to size the record
            and to average the forward time.
        class_: number of target classes; Acc_record is given at least 2.

    Returns:
        The filled ``Acc_record``.

    Raises:
        ValueError: if ``data.dataset`` is empty, which would otherwise divide by
            zero only after the whole evaluation ran.
    """
    n_samples = len(data.dataset)
    if n_samples == 0:
        raise ValueError("test_model: data.dataset is empty")

    acc_r = Acc_record(n_samples, max(class_, 2))
    # named `elapsed` so the local does not shadow the `time` module name
    elapsed = test(model, data,
                   lambda p, y, b: acc_r.calculate(p, y))

    print("times: {:.4f}".format(elapsed / n_samples))

    print("class iou:" +
          "".join("\n\t{:.2%}".format(c.mean()) for c in acc_r.class_iou))

    print("mean iou: {:.2%}".format(acc_r.mean_iou.mean()))
    print("front iou: {:.2%}".format(acc_r.front_iou.mean()))

    print("mean pa: {:.2%}".format(acc_r.cpa.mean(0).mean()))
    print("pa: {:.2%}".format(acc_r.pa.mean()))
    return acc_r


# Models and datasets to evaluate; every model is run on every dataset.
# Uncomment entries to include them.
eval_models = (
    # ("regseg", regseg),
    # ("align", regseg_align),
    # ("gate", regseg_gate),
    ("gate0", regseg_gate0),
)
eval_datasets = (
    ("Person labeled", persons_labeled_loader, 1),
    # ("ATR", atr_loader, 1),
    # ("VAL", val_loader, 6),
)

for (model_name, model), (name, d, n_class) in product(eval_models, eval_datasets):
    print("test {}\ndata {}".format(model_name, name))
    test_model(model, d, n_class)


test gate0
data Person labeled




times: 0.1218
class iou:
	92.61%
	77.16%
mean iou: 84.88%
front iou: 77.16%
mean pa: 91.02%
pa: 94.98%


In [17]:
# Per-channel normalization statistics reshaped to (1, 3, 1, 1) so they broadcast
# against (N, 3, H, W) tensors; used by `unnormalize`.
STDS_ = torch.tensor(color_stds).view(1, 3, 1, 1).to(device)
MEANS_ = torch.tensor(color_means).view(1, 3, 1, 1).to(device)


def unnormalize(img: torch.Tensor):
    """Invert the colour normalization: return ``img * STDS_ + MEANS_``.

    The statistics are moved to ``img``'s device first. A new tensor is
    returned; ``img`` is not modified.
    """
    std = STDS_.to(img.device)
    mean = MEANS_.to(img.device)
    restored = img * std + mean
    return restored


def zoom(img: torch.Tensor, t_min=0., t_max=1.):
    """Linearly rescale ``img`` so its values span ``[t_min, t_max]``.

    The minimum of ``img`` maps to ``t_min`` and the maximum to ``t_max``.
    A constant tensor (max == min) has no range to stretch. It is mapped to
    ``t_min`` everywhere instead of dividing by zero.

    Args:
        img: tensor to rescale (must be non-empty).
        t_min: value assigned to the minimum.
        t_max: value assigned to the maximum.

    Returns:
        A new floating-point tensor with the same shape as ``img``.
    """
    min_ = img.min()
    max_ = img.max()
    span = max_ - min_
    if span == 0:
        dtype = img.dtype if img.is_floating_point() else torch.get_default_dtype()
        return torch.full_like(img, t_min, dtype=dtype)
    # subtract the minimum first; dividing the raw values only scales them and leaves an offset
    return (img - min_) / span * (t_max - t_min) + t_min


def predict(model: torch.nn.Module, x: torch.Tensor):
    """Run one inference pass without autograd and return the activated output.

    ``model`` is switched to eval mode, ``x`` is moved to ``device``, and the raw
    output is passed through ``activate``.
    """
    model.eval()
    with torch.no_grad():
        raw_output = model(x.to(device))
        return activate(raw_output)


def show(prediction: torch.Tensor, x: torch.Tensor):
    """Draw a segmentation overlay on top of the unnormalized input batch.

    Args:
        prediction: activated model output. Multi-channel outputs are reduced
            with argmax over dim 1; single-channel outputs are thresholded at 0.5.
        x: normalized input batch, assumed (N, 3, H, W) to match the
            (1, 3, 1, 1) statistics used by `unnormalize` — TODO confirm.

    Returns:
        A new tensor shaped like ``unnormalize(x)``, modified as follows:
        - background pixels (label 0) are scaled to 10% and then raised by 0.2;
        - foreground pixels are raised by 0.2 on every channel except channel 0;
        - pixels where either 2x2 diagonal difference of the label map is
          positive are set to 1 on all channels, giving an outline.
        The mask and kernels are built on the global `device`, so
        ``prediction`` and ``x`` are assumed to live there.
    """

    # reduce to a single-channel label map
    if prediction.size(1) > 1:
        prediction = torch.argmax(prediction, dim=1, keepdim=True)
    else:
        prediction = prediction > 0.5
        prediction = prediction.to(torch.int32)

    xp = unnormalize(x)
    # foreground mask broadcast over the colour channels
    idx = (prediction > 0).expand_as(xp)

    # dim the background
    xp[~idx] *= 0.1
    # expand_as returns a broadcast view whose channels share memory; clone before the in-place &=
    idx = idx.clone()
    # keep only channel 0 of the foreground so the += below skips it
    idx &= torch.tensor([1, 0, 0], dtype=torch.bool,
                        device=device).view(1, 3, 1, 1)
    xp[~idx] += 0.2

    # conv2d needs a floating-point input
    prediction = prediction.to(torch.float32)

    # first diagonal difference of the label map; positive responses are painted 1
    w = torch.tensor([[-1, 0], [0, 1]],
                     dtype=torch.float32).view(1, 1, 2, 2).to(device)
    edge = F.conv2d(prediction, w, padding="same")
    idx = (edge > 0).expand_as(xp)
    xp[idx] = 1

    # the other diagonal difference
    w = torch.tensor([[0, -1], [1, 0]],
                     dtype=torch.float32).view(1, 1, 2, 2).to(device)
    edge = F.conv2d(prediction, w, padding="same")
    idx = (edge > 0).expand_as(xp)
    xp[idx] = 1

    return xp


# Capture the outputs of the two flow-generation modules in regseg_align's decoder.
# LayerActivations exposes the captured tensor as `.features` (read by show_flow);
# presumably it is filled by a forward hook on every forward pass — TODO confirm.
align16_layer = LayerActivations(regseg_align.decoder.up16.flow_make)
align8_layer = LayerActivations(regseg_align.decoder.up8.flow_make)


def show_flow():
    """Colour-code the flow fields captured from regseg_align.

    Reads the latest ``.features`` of ``align16_layer`` and ``align8_layer``, so
    a forward pass of ``regseg_align`` must have run first. Each one is
    converted with ``torchvision.utils.flow_to_image``.

    Returns:
        ``(flow16, flow8)`` images.
    """
    return tuple(
        torchvision.utils.flow_to_image(layer.features)
        for layer in (align16_layer, align8_layer)
    )


# Capture the raw `h_conv` outputs of the two fusion stages in regseg_gate0's decoder;
# show_gate applies the sigmoid and channel average to these `.features` tensors.
gate16_layer = LayerActivations(regseg_gate0.decoder.fusion_8_16.h_conv)
gate8_layer = LayerActivations(regseg_gate0.decoder.fusion_4_8.h_conv)


def show_gate():
    """Return channel-averaged gate maps from the latest regseg_gate0 forward pass.

    The hooked ``h_conv`` outputs of ``gate16_layer`` and ``gate8_layer`` are
    passed through a sigmoid and averaged over dim 1 (kept as a size-1 dim).

    Returns:
        ``(gate16, gate8)`` single-channel gate maps.
    """
    # out-of-place sigmoid: the in-place sigmoid_ overwrote the stored activations,
    # so a second call without a new forward pass applied the sigmoid twice
    gate16 = torch.sigmoid(gate16_layer.features).mean(1, keepdim=True)
    gate8 = torch.sigmoid(gate8_layer.features).mean(1, keepdim=True)
    return gate16, gate8


# Output folders for the qualitative comparison images.
SAVE_PATH = Path("../tmp")

p = SAVE_PATH / "regseg_test_pl"
compare_img_save_path, gate_img_save_path, flow_img_save_path = (
    p / sub_dir for sub_dir in ("compare", "gate", "flow"))

for path in (p, compare_img_save_path, gate_img_save_path, flow_img_save_path):
    create_dir(path)

# For every batch of the Person-labeled loader, save three image strips named "<i>.jpg":
#   compare/ : input plus regseg, regseg_align and regseg_gate0 overlays (see `show`)
#   gate/    : input plus the regseg_gate0 gate maps
#   flow/    : input plus the regseg_align flow images
# Statement order matters: show_gate()/show_flow() read activations captured during
# the preceding predict() calls of regseg_gate0 / regseg_align.
# NOTE(review): iterates over the whole loader; `y` is not used here.
for i, (x, y) in enumerate(persons_labeled_loader):
    x = x.to(device)
    seg_p = predict(regseg, x)
    seg_effect = show(seg_p, x)

    align_p = predict(regseg_align, x)
    align_effect = show(align_p, x)

    gate_p = predict(regseg_gate0, x)
    gate_effect = show(gate_p, x)

    # resize the single-channel gate maps to the input's spatial size
    # (F.interpolate default mode) and broadcast them to x's shape
    gate16, gate8 = show_gate()
    gate16 = F.interpolate(gate16, x.size()[-2:]).expand_as(x)
    gate8 = F.interpolate(gate8, x.size()[-2:]).expand_as(x)

    # resize the flow images to the input's spatial size
    flow16, flow8 = show_flow()
    flow16 = F.interpolate(flow16, x.size()[-2:])
    flow8 = F.interpolate(flow8, x.size()[-2:])

    # back to image space for saving
    x = unnormalize(x)

    for imgs, path in (
        ([x, seg_effect, align_effect, gate_effect], compare_img_save_path),
        ([x, gate8, gate16], gate_img_save_path),
        ([x, flow8, flow16], flow_img_save_path),
    ):
        # squeeze(0) drops the batch dim (batch_size is 1 in get_data_loader_para);
        # normalize + scale_each rescale every tile independently before saving
        torchvision.utils.save_image(
            [img.squeeze(0) for img in imgs],
            path/"{}.jpg".format(i),
            pad_value=0.5, normalize=True, scale_each=True)


Successfully created '../tmp/regseg_test_pl' 
Successfully created '../tmp/regseg_test_pl/compare' 
Successfully created '../tmp/regseg_test_pl/gate' 
Successfully created '../tmp/regseg_test_pl/flow' 


