In [21]:
# 数据处理库
import numpy as np
import torch
# 模型代码
from model import spsnet
# 数据集代码
from datasets import WHUBuilding
from torch.utils.data import DataLoader
# 信息输出代码
from tqdm import tqdm
# 图片处理代码
from PIL import Image
import imgviz
# 其他
from pathlib import Path
from torch.nn import functional as F


In [2]:
# 加载模型
model = spsnet.__dict__["SPSNet"](in_channel=3, out_channel=1, spc=3).cuda()
# 加载预训练参数
checkpoint = torch.load(r'checkpoints/sps-unet-whu/best_model.pt')
model.load_state_dict(checkpoint["state_dict"], strict=False)
del checkpoint

# 加载数据集
dataset_val = WHUBuilding(root=Path('/mnt/d/dataset/WHU/'), mode="test")
val_dataloader = DataLoader(dataset_val, batch_size=2, num_workers=2, shuffle=False)

In [3]:
def save_colored_mask(mask, save_path):
    mask = Image.fromarray(mask.astype(np.uint8), mode="P")
    colormap = imgviz.label_colormap()
    mask.putpalette(colormap.flatten())
    mask.save(save_path)

In [4]:
ori_dir = Path('/mnt/d/dataset/WHU/test/ori_label')
label_list = list(ori_dir.glob('*.tif'))

target_dir = Path('/mnt/d/dataset/WHU/test/label')
for lab in tqdm(label_list):
    data = np.array(Image.open(lab))
    new_data = data.copy()
    new_data[new_data == 255] = 1
    save_colored_mask(new_data, target_dir.joinpath(lab.stem + '.png'))

100%|██████████| 3848/3848 [01:50<00:00, 34.93it/s]


In [10]:
from torch import nn
net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=(3, 3), padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU6(inplace=True),
            nn.Conv2d(64, 9, kernel_size=(3, 3), padding=1, bias=True),
            nn.Softmax(dim=1),
        )

In [11]:
net

Sequential(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU6(inplace=True)
  (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU6(inplace=True)
  (6): Conv2d(64, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): Softmax(dim=1)
)

In [12]:
from thop import profile


input = torch.randn(1, 3, 256, 256) 
Flops, params = profile(net, inputs=(input,)) # macs
print('Flops: % .4fG'%(Flops / 1000000000)) # 计算量
print('params参数量: % .4fM'% (params / 1000000))   #参数量：等价与上面的summary输出的Total params值

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU6'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Flops:  1.6312G
params参数量:  0.0247M


In [48]:
t = torch.randn(4, 9, 256, 384)

In [None]:
t = t.softmax(dim=1).argmax(dim=1)
t = t.permute(1, 2, 0)
t = F.one_hot(t, 9)
t = t.permute(2, 3, 0, 1)

In [15]:
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm

In [9]:
t = torch.randn(1, 3, 4, 4)
mask = torch.randint(0, 3, (1, 4, 4))

func = torch.nn.CrossEntropyLoss()



In [10]:
loss = func(t, mask)

In [14]:
mask_list = list(Path('/home/eveleaf/MyDataset/yqdataset/deepglobe/train/mask').glob('*.png'))

In [27]:
lab = 0
for mask in tqdm(mask_list):
    lab = max(np.array(Image.open(mask)).max(), lab)

100%|██████████| 22325/22325 [01:21<00:00, 274.90it/s]


In [28]:
lab


9

In [25]:
t = data.max()

In [26]:
t

9