## 说明
数据集：Stanford Background Dataset（Kaggle）：https://www.kaggle.com/balraj98/stanford-background-dataset

## 相关依赖

In [8]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import datasets, transforms
from PIL import Image
import matplotlib.image as IMG
import shutil


## 数据集划分


In [9]:
#   Build (image path, label path) pairs and split them into a training set
#   and a test set.
#   The listing is sorted because os.listdir() returns entries in arbitrary
#   order; without a stable order the seeded split below would differ between
#   machines/filesystems even though the generator seed is fixed.
files = sorted(os.listdir('../cropped_dataset/images/'))
#   os.path.splitext() strips only the final extension, so an image name that
#   contains extra dots still maps to the matching "<stem>.txt" label file.
img_and_lables = [
    ('../cropped_dataset/images/' + i,
     '../cropped_dataset/labels/' + os.path.splitext(i)[0] + '.txt')
    for i in files
]

#   random_split() raises ValueError unless the lengths sum to
#   len(img_and_lables), i.e. 650 + 65 = 715 images are expected on disk.
train_set, test_set = random_split(
    dataset=img_and_lables,
    lengths=[650, 65],
    generator=torch.Generator().manual_seed(0),
)

## Dataset & Dataloader

In [16]:
#   dataset
class DataSet(Dataset):
    """Dataset over a collection of (image path, label path) pairs.

    Indexing returns the image as a float tensor together with the raw text
    content of the corresponding label file.
    """

    def __init__(self,dataSet):
        # Indexable collection whose items hold the image path at position 0
        # and the label-file path at position 1.
        self.dataset = dataSet

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image_tensor, label_text)``."""
        sample = self.dataset[index]

        pixels = IMG.imread(sample[0])
        with open(sample[1], "r") as label_file:
            label_str = label_file.read()    # whole label file as one string

        # Move the last axis to the front (assumes imread returned an
        # (H, W, C) array -- TODO confirm for every image), then scale by 1/255.
        # NOTE(review): the division is unconditional; confirm the images are
        # stored as 8-bit data rather than already-normalised floats.
        image = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255

        return image, label_str



## 模型搭建


In [12]:
class UNet(nn.Module):
    """U-Net style encoder/decoder built from unpadded 3x3 convolutions.

    The input is reflection-padded by 30 pixels on every side, passed through
    four conv-conv-pool encoder stages and a bottleneck, then through four
    upconv-concat-conv-conv decoder stages. A final 1x1 convolution produces
    9 output channels.

    Because the 3x3 convolutions are unpadded, the output is spatially smaller
    than the input by 124 pixels per dimension when the input size is a
    multiple of 16 (e.g. 512x512 -> 388x388, 128x128 -> 4x4). Encoder feature
    maps are centre-cropped to the decoder's spatial size before being
    concatenated, so the network is not tied to one input resolution.
    """

    def __init__(self):
        super().__init__()
        # Reflection padding applied to the input before the first convolution.
        self.overlapTile = nn.ReflectionPad2d(30)

        # Encoder stage 1
        self.conv1_1 = nn.Conv2d(3,64,3,1)
        self.conv1_2 = nn.Conv2d(64,64,3,1)
        self.pool1 = nn.MaxPool2d(2,2)

        # Encoder stage 2
        self.conv2_1 = nn.Conv2d(64,128,3,1)
        self.conv2_2 = nn.Conv2d(128,128,3,1)
        self.pool2 = nn.MaxPool2d(2,2)

        # Encoder stage 3
        self.conv3_1 = nn.Conv2d(128,256,3,1)
        self.conv3_2 = nn.Conv2d(256,256,3,1)
        self.pool3 = nn.MaxPool2d(2,2)

        # Encoder stage 4
        self.conv4_1 = nn.Conv2d(256,512,3,1)
        self.conv4_2 = nn.Conv2d(512,512,3,1)
        self.pool4 = nn.MaxPool2d(2,2)

        # Bottleneck
        self.conv5_1 = nn.Conv2d(512,1024,3,1)
        self.conv5_2 = nn.Conv2d(1024,512,3,1)

        # Decoder stage 1 (input channels double after the skip concatenation)
        self.upconv1 = nn.ConvTranspose2d(512,512,2,2)
        self.conv6_1 = nn.Conv2d(1024,512,3,1)
        self.conv6_2 = nn.Conv2d(512,256,3,1)

        # Decoder stage 2
        self.upconv2 = nn.ConvTranspose2d(256,256,2,2)
        self.conv7_1 = nn.Conv2d(512,256,3,1)
        self.conv7_2 = nn.Conv2d(256,128,3,1)

        # Decoder stage 3
        self.upconv3 = nn.ConvTranspose2d(128,128,2,2)
        self.conv8_1 = nn.Conv2d(256,128,3,1)
        self.conv8_2 = nn.Conv2d(128,64,3,1)

        # Decoder stage 4 and the 1x1 classification head
        self.upconv4 = nn.ConvTranspose2d(64,64,2,2)
        self.conv9_1 = nn.Conv2d(128,64,3,1)
        self.conv9_2 = nn.Conv2d(64,64,3,1)
        self.conv9_3 = nn.Conv2d(64,9,1,1)

    @staticmethod
    def _center_crop(skip, target):
        """Crop the last two dims of ``skip`` to the spatial size of ``target``.

        The crop window is centred. For a 512x512 network input this yields
        exactly the slices 88:480, 40:240, 16:120 and 4:60 on the four
        encoder stages.
        """
        height, width = target.shape[-2:]
        top = (skip.shape[-2] - height) // 2
        left = (skip.shape[-1] - width) // 2
        return skip[..., top:top + height, left:left + width]

    def forward(self,x):
        """Run the network on a batch of images.

        Args:
            x: tensor of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, 9, H', W') holding unnormalised per-pixel
            scores; H' and W' are smaller than H and W (see class docstring).
        """
        x = self.overlapTile(x)

        # Encoder: keep each stage's full feature map for the skip connection.
        x = F.relu(self.conv1_1(x))
        skip1 = F.relu(self.conv1_2(x))
        x = self.pool1(skip1)

        x = F.relu(self.conv2_1(x))
        skip2 = F.relu(self.conv2_2(x))
        x = self.pool2(skip2)

        x = F.relu(self.conv3_1(x))
        skip3 = F.relu(self.conv3_2(x))
        x = self.pool3(skip3)

        x = F.relu(self.conv4_1(x))
        skip4 = F.relu(self.conv4_2(x))
        x = self.pool4(skip4)

        # Bottleneck
        x = F.relu(self.conv5_1(x))
        x = F.relu(self.conv5_2(x))

        # Decoder: upsample, concatenate the centre-cropped skip, convolve.
        x = self.upconv1(x)
        x = F.relu(self.conv6_1(torch.cat((self._center_crop(skip4, x), x), 1)))
        x = F.relu(self.conv6_2(x))

        x = self.upconv2(x)
        x = F.relu(self.conv7_1(torch.cat((self._center_crop(skip3, x), x), 1)))
        x = F.relu(self.conv7_2(x))

        x = self.upconv3(x)
        x = F.relu(self.conv8_1(torch.cat((self._center_crop(skip2, x), x), 1)))
        x = F.relu(self.conv8_2(x))

        x = self.upconv4(x)
        x = F.relu(self.conv9_1(torch.cat((self._center_crop(skip1, x), x), 1)))
        x = F.relu(self.conv9_2(x))
        x = self.conv9_3(x)

        return x

        

## 训练与评估

In [None]:
#   Hyper-parameters
EPOCH_NUM = 5
BATCH_SIZE = 1
NUM_WORKERS = 0
LEARNING_RATE = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#   Initialization of the model, datasets, DataLoaders and optimizer
net = UNet().to(device)

trainingSet = DataSet(train_set)
testSet = DataSet(test_set)

trainingSet_Loader = DataLoader(trainingSet,
                                batch_size = BATCH_SIZE,
                                num_workers = NUM_WORKERS,
                                shuffle=True)

testSet_Loader = DataLoader(testSet,
                            batch_size = 1,
                            num_workers = NUM_WORKERS,
                            shuffle=True)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)


def _labels_to_tensor(label, num_pixels):
    """Convert a batch of label strings into an (N, num_pixels) int64 tensor.

    Each string is read one character per pixel; only the first
    ``num_pixels`` characters are used and each must be a decimal digit
    (``int()`` raises ValueError otherwise).
    """
    return torch.tensor(
        [[int(ch) for ch in text[:num_pixels]] for text in label],
        dtype=torch.long,
    )


if __name__ == '__main__':

    for epoch in range(EPOCH_NUM):
        #   training phase
        print("trainin on Epoch:[{}/{}]".format(epoch+1,EPOCH_NUM))

        net.train()
        for uselessnum,data in enumerate(trainingSet_Loader,0):
            pic,label = data
            pic = pic.to(device)
            optimizer.zero_grad()

            # (N, 9, H, W) -> (N, 9, H*W); pixels are flattened row by row.
            logits = net(pic).flatten(2)
            num_pixels = logits.shape[-1]
            target = _labels_to_tensor(label, num_pixels).to(device)

            # CrossEntropyLoss averages over all N*H*W elements; multiplying
            # by the pixel count gives the sum over pixels of the batch-mean
            # loss, without a Python loop over every pixel.
            loss = criterion(logits, target) * num_pixels

            print("Optimizing!!")
            loss.backward()
            optimizer.step()
            print("{} pic done!! totalloss:{}".format(uselessnum+1,loss.item()))

        # NOTE(review): this pickles the whole module to a machine-specific
        # absolute path; the path also looks like it may be missing a
        # separator between "U-Net" and "unet_epoch" -- confirm.
        torch.save(net, "C:/Users/29147/source/repos/U-Netunet_epoch"+str(epoch)+".pth")

        #   testing phase (no gradients needed, so skip building the graph)
        net.eval()
        with torch.no_grad():
            for uselessnum,data in enumerate(testSet_Loader,0):
                pic,label = data
                pic = pic.to(device)

                logits = net(pic).flatten(2)
                target = _labels_to_tensor(label, logits.shape[-1]).to(device)

                # Fraction of pixels whose highest-scoring class matches the label.
                correct_count = (logits.argmax(dim=1) == target).sum().item()
                accuracy = correct_count / target.numel()
                print("testset acc:{}".format(accuracy))
        