In [None]:
# 自定义数据集, 加载数据，定义标签，以及进行数据增强
from random import shuffle
import torch


class CustomDataset(torch.utils.data.Dataset):
    """Skeleton of a user-defined dataset.

    Every method is still a placeholder: fill in the TODO items below to load
    the data, define the labels and apply data augmentation.
    """

    def __init__(self):
        """Set up the dataset.

        TODO:
          1. Initialize file paths or a list of file names.
        """

    def __getitem__(self, index):
        """Return the sample at ``index``; the placeholder returns ``None``.

        TODO:
          1. Read one data item from file (e.g. using numpy.fromfile, PIL.Image.open).
          2. Preprocess the data (e.g. torchvision.Transform).
          3. Return a data pair (e.g. image and label).
        """
        return None

    def __len__(self):
        """Return the number of samples; the placeholder reports an empty dataset."""
        # TODO: replace 0 with the total size of your dataset.
        dataset_size = 0
        return dataset_size

# Instantiate the dataset and wrap it in a loader that serves shuffled
# mini-batches of 64 samples.
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)

In [23]:
from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch

def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001,
              save_path='best_model.pth'):
    """Train ``net`` on the dataset at ``data_path`` and checkpoint the best weights.

    The data is loaded through ``ISBI_Loader`` and served in shuffled batches.
    After every forward pass the batch loss is printed; whenever it is the
    lowest batch loss seen so far, the network's ``state_dict`` is written to
    ``save_path`` (i.e. the weights that produced that loss, before the
    optimizer step for the batch is applied).

    Args:
        net: The network to train; its parameters are updated in place.
        device: Device the image and label batches are moved to. The caller is
            responsible for having moved ``net`` to the same device.
        data_path: Location of the training data, passed to ``ISBI_Loader``.
        epochs: Number of passes over the training set.
        batch_size: Number of samples per batch.
        lr: Learning rate for the RMSprop optimizer.
        save_path: File the best ``state_dict`` is saved to.

    Returns:
        None.
    """
    # Load the training set.
    isbi_dataset = ISBI_Loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    # Optimizer: RMSprop with a small weight decay and momentum.
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # Loss: binary cross-entropy computed on raw logits (the sigmoid is applied
    # inside the loss, so the network output itself is not a probability).
    criterion = nn.BCEWithLogitsLoss()
    # Lowest batch loss seen so far; starts at +inf so the first batch always wins.
    best_loss = float('inf')
    for epoch in range(epochs):
        net.train()
        # One optimization step per batch.
        for image, label in train_loader:
            optimizer.zero_grad()
            # Move the batch to the target device as float32.
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            # Forward pass and loss.
            pred = net(image)
            loss = criterion(pred, label)
            # Read the scalar once: tracking a plain float (rather than the
            # loss tensor) keeps best_loss from holding on to the autograd
            # graph and avoids a second device-to-host sync for the comparison.
            loss_value = loss.item()
            print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, epochs, loss_value))
            # Checkpoint the weights that produced the lowest batch loss.
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(net.state_dict(), save_path)
            # Backpropagate and update the parameters.
            loss.backward()
            optimizer.step()


In [24]:
# Pick the compute device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Build the network: single-channel input images, one output class.
net = UNet(n_channels=1, n_classes=1)
# Move the network onto the selected device.
net.to(device)
# Point at the training-set directory and start training.
data_path = "data/train/"
train_net(net, device, data_path=data_path)

Epoch [1/40], Loss: 0.7404
Epoch [1/40], Loss: 0.6831
Epoch [1/40], Loss: 0.6405
Epoch [1/40], Loss: 0.6519
Epoch [1/40], Loss: 0.6894
Epoch [1/40], Loss: 0.6693
Epoch [1/40], Loss: 0.6686
Epoch [1/40], Loss: 0.6812
Epoch [1/40], Loss: 0.6484
Epoch [1/40], Loss: 0.6454
Epoch [1/40], Loss: 0.6407
Epoch [1/40], Loss: 0.6469
Epoch [1/40], Loss: 0.6386
Epoch [1/40], Loss: 0.6783
Epoch [1/40], Loss: 0.6384
Epoch [1/40], Loss: 0.6438
Epoch [1/40], Loss: 0.6289
Epoch [1/40], Loss: 0.6367
Epoch [1/40], Loss: 0.6294
Epoch [1/40], Loss: 0.6322
Epoch [1/40], Loss: 0.6326
Epoch [1/40], Loss: 0.6388
Epoch [1/40], Loss: 0.6431
Epoch [1/40], Loss: 0.6282
Epoch [1/40], Loss: 0.6338
Epoch [1/40], Loss: 0.6398
Epoch [1/40], Loss: 0.6417
Epoch [1/40], Loss: 0.6301
Epoch [1/40], Loss: 0.6318
Epoch [1/40], Loss: 0.6136
Epoch [2/40], Loss: 0.6423
Epoch [2/40], Loss: 0.6303
Epoch [2/40], Loss: 0.6328
Epoch [2/40], Loss: 0.6343
Epoch [2/40], Loss: 0.6306
Epoch [2/40], Loss: 0.6382
Epoch [2/40], Loss: 0.6422
E

KeyboardInterrupt: 

In [17]:
import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet

# Select the device: CUDA when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build the network: single-channel input images, one output class.
net = UNet(n_channels=1, n_classes=1)
# Move the network onto the device.
net.to(device=device)
# Load the saved weights, remapping them to the current device.
net.load_state_dict(torch.load('best_model.pth', map_location=device))
# Switch to evaluation mode.
net.eval()
# Collect the test images. Result files written by a previous run also match
# "*.png", so they are filtered out to keep re-runs from producing
# "*_res_res.png" files.
# NOTE(review): this reads from "data_bak/test/" while the training cell uses
# "data/train/" -- confirm the directory is the intended one.
tests_path = [path for path in sorted(glob.glob("data_bak/test/*.png"))
              if not path.endswith("_res.png")]
for test_path in tests_path:
    # Output path: same location and stem with a "_res" suffix. splitext only
    # strips the final extension, so dots elsewhere in the path are preserved.
    save_res_path = os.path.splitext(test_path)[0] + "_res.png"
    # Read the image; cv2.imread returns None instead of raising on failure.
    img = cv2.imread(test_path)
    if img is None:
        print('Skipping unreadable image: {}'.format(test_path))
        continue
    # Convert to grayscale.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Reshape to (batch=1, channel=1, height, width).
    img = img.reshape(1, 1, img.shape[0], img.shape[1])
    # Convert to a float32 tensor on the device.
    img_tensor = torch.from_numpy(img).to(device=device, dtype=torch.float32)
    # Predict without building an autograd graph. The training cell uses
    # BCEWithLogitsLoss, so the network emits raw logits; apply a sigmoid to
    # get probabilities before thresholding at 0.5.
    with torch.no_grad():
        prob = torch.sigmoid(net(img_tensor))
    # Take the single image / single channel as a 2-D NumPy array.
    prob = prob[0, 0].cpu().numpy()
    # Binarize to an 8-bit mask: 255 where probability >= 0.5, else 0.
    pred = np.where(prob >= 0.5, 255, 0).astype(np.uint8)
    # Save the mask.
    cv2.imwrite(save_res_path, pred)
