# In [None]:
"""
main.py
"""
import cv2
import os
import glob
import random

import torch
from torch.utils.data import Dataset
from torch import optim
import torch.nn as nn
import numpy as np

import sys
sys.path.append('/content/drive/My Drive/Colab Notebooks/unet_lane_net')
sys.path.append('/content/drive/My Drive/Colab Notebooks/unet_lane_net/')
import model
import data
from model.unet_model import UNet



class data_loader(Dataset):
    """Paired image/label dataset rooted at ``data_path``.

    Every ``*.png`` under ``<data_path>/image`` is a sample. Its mask is found at
    the same path relative to ``data_path`` with ``image`` replaced by ``label``
    (e.g. ``image/0.png`` -> ``label/0.png``). Both files are loaded as
    single-channel grayscale arrays of shape ``(1, H, W)``. The label is divided
    by 255, so a 255-valued foreground pixel becomes 1.0.
    """

    def __init__(self, data_path):
        """Collect all image paths below ``data_path``.

        Args:
            data_path: Dataset root that contains ``image/`` (and ``label/``).
        """
        self.data_path = data_path
        # Sorted so the index -> file mapping is deterministic; raw glob order
        # depends on the filesystem.
        self.imgs_path = sorted(glob.glob(os.path.join(data_path, 'image/*.png')))

    def _label_path(self, image_path):
        """Map an image path to its label path, touching only the part below data_path."""
        # os.path.relpath rejects an empty start, so fall back to the cwd.
        root = self.data_path or os.curdir
        rel_path = os.path.relpath(image_path, root)
        # Rewriting only the relative part keeps a root directory that happens
        # to contain "image" intact.
        return os.path.join(self.data_path, rel_path.replace('image', 'label'))

    @staticmethod
    def _read_gray(path):
        """Read ``path`` with OpenCV and convert it from BGR to one grayscale channel."""
        img = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None.
        if img is None:
            raise FileNotFoundError(f"could not read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        Returns:
            image: array of shape ``(1, H, W)`` holding the grayscale image.
            label: array of shape ``(1, H, W)`` holding the grayscale mask / 255.

        Raises:
            FileNotFoundError: if the image or its label cannot be read.
        """
        image_path = self.imgs_path[index]
        image = self._read_gray(image_path)
        label = self._read_gray(self._label_path(image_path))
        # Add a leading channel axis: (H, W) -> (1, H, W).
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        # Scale so a 255 foreground pixel becomes 1.0.
        label = label / 255
        return image, label

    def __len__(self):
        """Return the number of image files found."""
        return len(self.imgs_path)

    
def train_net(net, device, data_path, epochs=400, batch_size=1, lr=1e-5,
              save_path='best_model.pth'):
    """Train ``net`` on the dataset at ``data_path`` and checkpoint the best weights.

    Args:
        net: Segmentation network. Its output is compared with the label batch by
            ``BCEWithLogitsLoss``. It should already be on ``device``.
        device: torch device every batch is moved to.
        data_path: Dataset root passed to ``data_loader``.
        epochs: Number of passes over the dataset.
        batch_size: DataLoader batch size.
        lr: RMSprop learning rate.
        save_path: File that receives the ``state_dict`` with the lowest
            single-batch training loss seen so far.
    """
    # Load the training set.
    dataset = data_loader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

    # Optimizer and loss function.
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    criterion = nn.BCEWithLogitsLoss()

    best_loss = float('inf')
    for epoch in range(epochs):
        net.train()
        for image, label in train_loader:
            optimizer.zero_grad()
            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)
            pred = net(image)
            loss = criterion(pred, label)
            loss_value = loss.item()
            print("no.", epoch, ' Loss/train = ', loss_value)

            # Save only when the loss improves. The file is meant to hold the
            # *best* model, and writing a checkpoint every batch is wasted I/O.
            # Saving before optimizer.step() keeps the stored weights in sync
            # with the loss that was just measured. Only parameters are saved.
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(net.state_dict(), save_path)

            # Update the parameters.
            loss.backward()
            optimizer.step()

def train():
    """Build a 1-channel / 1-class UNet and train it on the Drive training set.

    Picks CUDA when available (CPU otherwise), prints the chosen device, and
    hands the network to ``train_net`` with fixed hyper-parameters.
    """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    print(device)

    # Grayscale input (one channel) and a single output map.
    net = UNet(n_channels=1, n_classes=1)
    net.to(device=device)

    train_root = "/content/drive/MyDrive/Colab Notebooks/unet_lane_net/data/train/"
    hyper_params = dict(epochs=400, batch_size=1, lr=1e-5)
    train_net(net, device, train_root, **hyper_params)
def test(model_path='/content/best_model.pth',
         test_dir='/content/drive/MyDrive/Colab Notebooks/unet_lane_net/data/test/image'):
    """Run the trained UNet on every test image and save an overlay and a mask.

    For each ``<name>.png`` in ``test_dir`` two files are written next to it:
      * ``<name>_res.png``: the input with predicted pixels painted BGR ``[0, 0, 255]``.
      * ``<name>_mask.png``: the binary prediction as a uint8 image (0 / 255).
    A running count of processed images is printed.

    Args:
        model_path: ``state_dict`` file to load into the network.
        test_dir: Directory scanned for ``*.png`` inputs.

    Raises:
        FileNotFoundError: if a test image cannot be read.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    # Single input channel, single output class, matching train().
    net = UNet(n_channels=1, n_classes=1)
    net.to(device=device)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()

    # Results are written into test_dir itself, so skip them. Otherwise a
    # second run would treat the previous outputs as new inputs.
    tests_path = [
        path for path in sorted(glob.glob(os.path.join(test_dir, '*.png')))
        if not path.endswith(('_res.png', '_mask.png'))
    ]

    for test_num, test_path in enumerate(tests_path, start=1):
        # splitext removes only the extension; split('.') broke on dotted dirs.
        stem = os.path.splitext(test_path)[0]

        result = cv2.imread(test_path)
        if result is None:
            raise FileNotFoundError(f"could not read image: {test_path}")
        # imread returns BGR; use the same conversion as data_loader so the
        # network sees the grayscale values it was trained on.
        gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)
        img_tensor = torch.from_numpy(gray.reshape(1, 1, gray.shape[0], gray.shape[1]))
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            pred = net(img_tensor)
        pred = pred[0, 0].cpu().numpy()

        # NOTE(review): train_net optimizes BCEWithLogitsLoss, which suggests the
        # net outputs logits; if so, 0.5 here is sigmoid ~0.62, not 0.5. Confirm
        # against UNet's output layer before changing.
        mask = np.where(pred >= 0.5, 255, 0).astype(np.uint8)

        # Paint every predicted pixel red (BGR) on the original image.
        result[mask == 255] = [0, 0, 255]

        cv2.imwrite(stem + '_res.png', result)
        cv2.imwrite(stem + '_mask.png', mask)
        print(test_num)

if __name__ == "__main__":
    # Run training first, then inference on the test images.
    for _stage in (train, test):
        _stage()


