In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset,random_split
import torchvision.io as tio
from torchvision.io import ImageReadMode
import torchvision.transforms as Transforms

import cv2
import pathlib
import train.dice_score as dice_score
import os

def imagelist(data_path : str, fmt : str):
    # 判断目录是否存在
    if not os.path.exists(data_path):
        raise FileNotFoundError('文件夹不存在')
    # 读取目录下的所有jpg图像与png图像
    image_paths = [os.path.join(data_path, f) for f in os.listdir(data_path) if f.endswith(fmt)] # '.jpg'
    # 用torchvision库读取image_paths里所有的图像
    # images = [tio.read_image(f) for f in image_paths]
    return image_paths
    # 将Tensor作为图像展示
    '''
        print(images[0].shape)
        cv2.imshow('image', images[0].permute(1,2,0).numpy())
        cv2.waitKey(0)
    '''


# 自定义数据集类DataSet
class SegDataSet(Dataset):
    def __init__(self, root_dir, transform=None, target_transform=None):
        self.img_raw = imagelist(root_dir, '.jpg')
        self.img_mask = imagelist(root_dir, '.png')
        self.transform = transform
        self.target_transform = target_transform
    
    def __len__(self):
        return len(self.img_raw)

    def __getitem__(self, idx):
        img_raw = tio.read_image(self.img_raw[idx], mode = ImageReadMode.GRAY)
        if((img_raw > 1).any()):
            # bias是float的，这里也要转成float
            img_raw = img_raw / 255.0
        if self.transform:
            #print('transform')
            img_raw = self.transform(img_raw)
        
        img_mask = tio.read_image(self.img_mask[idx], mode = ImageReadMode.GRAY)
        if self.target_transform:
            #print('target_transform')
            img_mask = self.target_transform(img_mask).squeeze(0)

        return {
            'image' : img_raw, 
            'mask' : img_mask
            }

# 全局变量
paths = {
    'train' : 'F:/model_tuning/data/viod/train',
    'test'  : 'F:/model_tuning/data/viod/test',
    'checkpoint' : 'F:/model_tuning/data/checkpoint/'
}
image_size=(256,256)
batch_size = 1

dataset = SegDataSet(paths['train'], 
        transform=Transforms.Compose([Transforms.Resize(image_size)]),
        target_transform=Transforms.Compose([Transforms.Resize(image_size)]))

        #dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions
val_percent = 0.1
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

# 3. Create data loaders
# 遇到了多线程问题，将num_workers设为0后不报错了
loader_args = dict(batch_size=batch_size, num_workers=0, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True,**loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True,**loader_args)



In [3]:
img = next(iter(train_loader))['image']
img

tensor([[[[0.3507, 0.3519, 0.3684,  ..., 0.3960, 0.3791, 0.3663],
          [0.3809, 0.3735, 0.3942,  ..., 0.3958, 0.3838, 0.3753],
          [0.4030, 0.4267, 0.4446,  ..., 0.3962, 0.3905, 0.3835],
          ...,
          [0.4745, 0.5062, 0.5327,  ..., 0.4365, 0.4268, 0.4240],
          [0.4133, 0.4391, 0.4786,  ..., 0.4324, 0.4270, 0.4255],
          [0.4126, 0.4206, 0.4393,  ..., 0.4275, 0.4210, 0.4219]]]])

torch  遵循的是NCHW
opencv 遵循的是 HWC
所以要交换顺序permute(1，2，0)

In [4]:
import torch.nn.functional as F
import torch

# 测试 conv2D batchNorm Maxpool
input_ch = 1
output_ch = 1
input_img = next(iter(train_loader))['image']
print('input shape:',input_img.shape)
conv = nn.Conv2d(input_ch, output_ch, kernel_size=3, stride=1, padding=0)
conv_re = conv(input_img)
print('conv shape:',conv_re.shape)
batchNorm = nn.BatchNorm2d(output_ch, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
norm_re = batchNorm(conv_re)
print('norm shape:',norm_re.shape)

# 测试 transposeConv2D maxpool upSample
tran2D = nn.ConvTranspose2d(1,1,kernel_size=3, stride=1, padding=0)
tran_re = tran2D(norm_re)
print('tran shape:',tran_re.shape)

maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
max_re = maxpool(norm_re)
print('maxp shape:',max_re.shape)

upSample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
up_re = upSample(norm_re)
print('upsa shape:',up_re.shape)

# 测试 pad cat 
pad_re = F.pad(norm_re, (1,1,1,1), mode='constant', value=0)
print('padd shape:',pad_re.shape)
cat_re = torch.cat((pad_re, input_img), dim=1)
print('cat shape:',cat_re.shape)

# 测试ReLu
relu_re = F.relu(norm_re)
print('relu shape:',relu_re.shape)

# 测试flatten
flatten_re = relu_re.flatten(0,1)
print('flatten shape:',flatten_re.shape)

# 显示图像
image_show = (relu_re[0].permute(1,2,0)).detach().numpy()
cv2.imshow('img',image_show)
cv2.waitKey(0)



input shape: torch.Size([1, 1, 256, 256])
conv shape: torch.Size([1, 1, 254, 254])
norm shape: torch.Size([1, 1, 254, 254])
tran shape: torch.Size([1, 1, 256, 256])
maxp shape: torch.Size([1, 1, 127, 127])
upsa shape: torch.Size([1, 1, 508, 508])
padd shape: torch.Size([1, 1, 256, 256])
cat shape: torch.Size([1, 2, 256, 256])
relu shape: torch.Size([1, 1, 254, 254])
flatten shape: torch.Size([1, 254, 254])


-1

In [26]:
# 测试 mask
import numpy as np
mask = next(iter(train_loader))['mask']
print('mask shape: ',mask.shape)

unique_re = np.unique(mask)
print('unique:',unique_re)
if (mask != 0).any():
    mask[mask != 0] = 1

# squeeze 移除大小为1的维度，如果指定的维度大小不为1，张量尺寸不变 sqeeze(1):移除第二个维度
mask_sq = mask.squeeze(0)# 等同torch.squeeze(mask,1)
print('mask_sq shape:',mask_sq.shape)
# 判断移除张量维度是否成功
if ((mask_sq.shape == mask.shape)):
    print('Sqeeze failed')

# flatten 把start至end(包含end)维度合并到end上
flat_re = torch.flatten(mask,0,1) # (tensor,start = 0,end = -1)
print('flat_re shape:',flat_re.shape)

# float 转为float32
mask_float = mask.float()
print('mask_float dtype:',mask.dtype)

image_show = (mask.permute(1,2,0)).detach().numpy()
cv2.imshow('img',image_show)
cv2.waitKey(0)

mask

mask shape:  torch.Size([1, 256, 256])
unique: [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  27  30  31  32  34  35  36  37  38  39
  40  41  43  46  47  49  50  51  52  54  55  60  61  64  65  66  68  69
  70  72  73  74  76  79  80  84  85  87  88  89  92  96  97  99 100 106
 107 108 110 111 114 133 145 147 148 150 151 157 163 164 165 170 171 174
 179 182 183 184 191 192 200 204 207 211 215 216 218 221 222 228 230 233
 234 254]
mask shape:  torch.Size([1, 256, 256])
mask_sq shape: torch.Size([256, 256])
flat_re shape: torch.Size([256, 256])
mask_float dtype: torch.uint8


tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8)