In [None]:
import os,sys
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import time
import random
import copy
from collections import defaultdict
import helper
import simulation

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torchvision.utils
from torchsummary import summary
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional
from loss import dice_loss

In [None]:
def load_data(training_data, validation_data, batch_size = 10):
    """Wrap the training and validation datasets in DataLoaders.

    Args:
        training_data: Dataset used for training (anything ``DataLoader``
            accepts and ``len()`` works on).
        validation_data: Dataset used for validation.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A tuple ``(image_datasets, dataloaders, dataset_sizes)``. Each element
        is a dict keyed by ``'train'`` and ``'val'`` holding, respectively,
        the dataset itself, its ``DataLoader`` and its length.
    """
    image_datasets = {'train': training_data, 'val': validation_data}

    dataloaders = {
        'train': DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=0),
        # Validation needs no shuffling; a fixed order keeps evaluation deterministic.
        'val': DataLoader(validation_data, batch_size=batch_size, shuffle=False, num_workers=0),
    }

    dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}

    # Hand back the loaders in the second slot (previously the datasets were
    # returned twice and the freshly built loaders were discarded).
    return image_datasets, dataloaders, dataset_sizes

In [None]:
# Image conversion: tensor -> displayable uint8 array
def reverse_transform(inp):
    """Convert an image tensor into a uint8 NumPy array with axis 0 moved last.

    Args:
        inp: 3-D ``torch.Tensor`` (presumably channel-first ``(C, H, W)`` with
            values in ``[0, 1]`` — values outside that range are clipped).

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` whose axes are the input's axes
        reordered as ``(1, 2, 0)`` and whose values are scaled to ``0..255``.
    """
    # detach()/cpu() make the conversion work for tensors that require grad or
    # live on a CUDA device; a bare .numpy() raises for both.
    image = inp.detach().cpu().numpy().transpose((1, 2, 0))
    image = np.clip(image, 0, 1)
    return (image * 255).astype(np.uint8)

In [None]:
def double_conv(in_channels, out_channels):
    """Build two stacked 3x3 convolutions (padding 1), each followed by ReLU.

    Args:
        in_channels: Channel count expected by the first convolution.
        out_channels: Channel count produced by both convolutions.

    Returns:
        ``nn.Sequential`` of ``Conv2d -> ReLU -> Conv2d -> ReLU``; the ReLUs
        operate in place.
    """
    # Convolutions are created in the same order as they are applied.
    first_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    second_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
    layers = [
        first_conv,
        nn.ReLU(inplace=True),
        second_conv,
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder built from ``double_conv`` blocks.

    The encoder applies three conv blocks, each followed by a 2x max-pool, and
    a fourth 512-channel block at the lowest resolution. The decoder upsamples
    by 2x (bilinear) three times, concatenating the matching encoder output on
    the channel axis before each conv block. A final 1x1 convolution maps the
    64 decoder channels to ``n_class`` output channels.
    """

    def __init__(self, n_class):
        """Create the layers.

        Args:
            n_class: Number of output channels of the final 1x1 convolution.
        """
        super().__init__()

        # Encoder: 3 input channels -> 64 -> 128 -> 256 -> 512.
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        # Shared, parameter-free resampling layers.
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder: input width = skip-connection channels + upsampled channels.
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        # Per-pixel projection to the output channels.
        self.conv_last = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        """Run the network and return raw (un-activated) output maps.

        Args:
            x: Input batch with 3 channels on axis 1. The skip concatenations
                need the pooled and re-upsampled sizes to match, so height and
                width should be divisible by 8.

        Returns:
            Tensor with ``n_class`` channels; no activation is applied here.
        """
        # Contracting path: keep each stage's output for the skip connections.
        skip1 = self.dconv_down1(x)
        skip2 = self.dconv_down2(self.maxpool(skip1))
        skip3 = self.dconv_down3(self.maxpool(skip2))
        features = self.dconv_down4(self.maxpool(skip3))

        # Expanding path: upsample, concatenate the skip on channels, convolve.
        features = torch.cat([self.upsample(features), skip3], dim=1)
        features = self.dconv_up3(features)

        features = torch.cat([self.upsample(features), skip2], dim=1)
        features = self.dconv_up2(features)

        features = torch.cat([self.upsample(features), skip1], dim=1)
        features = self.dconv_up1(features)

        return self.conv_last(features)

In [None]:
# Compute the individual loss components
def compute_loss(pred, t, results, bw = 0.5):
    """Return a weighted BCE + soft-Dice loss and accumulate running totals.

    Args:
        pred: Raw logits. They are reduced over axes 2 and 3, so a 4-D tensor
            is required (presumably ``(N, C, H, W)``).
        t: Float target tensor with the same shape as ``pred``.
        results: Mutable mapping with numeric ``'bce'``, ``'dice'`` and
            ``'loss'`` entries (e.g. ``defaultdict(float)``). Each entry is
            incremented by the batch value times ``t.size(0)`` so callers can
            later divide by the number of samples seen.
        bw: Weight of the BCE term; the Dice term is weighted ``1 - bw``.

    Returns:
        Scalar loss tensor, still attached to the autograd graph.
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits(pred, t)

    probs = torch.sigmoid(pred)
    spatial_axes = (2, 3)
    intersection = (probs * t).sum(dim=spatial_axes)
    total_mass = probs.sum(dim=spatial_axes) + t.sum(dim=spatial_axes)
    # The +1 smoothing keeps the ratio defined when both maps are empty.
    dice = (1 - (2. * intersection + 1.) / (total_mass + 1.)).mean()

    loss = bce * bw + dice * (1 - bw)

    # .item() yields plain Python floats (no autograd history, no NumPy
    # scalars), replacing the legacy `.data.cpu().numpy()` accumulation.
    batch_size = t.size(0)
    for name, value in (('bce', bce), ('dice', dice), ('loss', loss)):
        results[name] += value.item() * batch_size

    return loss

In [None]:
# Print loss information
def print_loss(metrics, epoch_samples, phase):
    """Print the per-sample average of every accumulated metric on one line.

    Args:
        metrics: Mapping from metric name to its accumulated total.
        epoch_samples: Number of samples the totals were accumulated over;
            every total is divided by it.
        phase: Label printed in front of the metrics (e.g. ``'train'``).
    """
    formatted = ", ".join(
        "{}: {:4f}".format(name, total / epoch_samples)
        for name, total in metrics.items()
    )
    print("{}: {}".format(phase, formatted))

In [None]:
# Train the model
def train_model(model, optimizer, scheduler, num_epochs=25, best_loss = 1e10,
                loaders=None, target_device=None):
    """Train ``model`` and return it with the best-validation weights loaded.

    Every epoch runs one pass over the ``'train'`` loader (one optimizer
    update per batch) followed by one gradient-free pass over the ``'val'``
    loader. Whenever the average validation loss drops below ``best_loss`` a
    copy of the weights is kept; that copy is loaded back into ``model``
    before returning.

    Args:
        model: ``nn.Module`` to optimise; it is modified in place.
        optimizer: Optimizer that owns ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch after the
            training pass.
        num_epochs: Number of epochs to run.
        best_loss: Validation loss that must be beaten before a snapshot of
            the weights is taken.
        loaders: Optional dict with ``'train'`` and ``'val'`` iterables that
            yield ``(inputs, labels)`` batches. Defaults to the notebook-level
            ``dataloaders`` variable.
        target_device: Optional device the batches are moved to. Defaults to
            the notebook-level ``device`` variable.

    Returns:
        ``model`` with the best weights seen during validation loaded.
    """
    # Fall back to the notebook-level globals so existing calls keep working.
    if loaders is None:
        loaders = dataloaders
    if target_device is None:
        target_device = device

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('----------------------')

        since = time.time()

        # training part
        # Report the learning rate that is in effect for this epoch.
        for param_group in optimizer.param_groups:
            print("LR", param_group['lr'])

        model.train()
        metrics = defaultdict(float)
        epoch_samples = 0
        for inputs, labels in loaders['train']:
            inputs = inputs.to(target_device)
            labels = labels.to(target_device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = compute_loss(outputs, labels, metrics)
            loss.backward()
            optimizer.step()
            epoch_samples += inputs.size(0)
        print_loss(metrics, epoch_samples, 'train')

        # Step the schedule only after the optimizer updates of this epoch;
        # stepping first (as before) shifts the schedule by one epoch.
        scheduler.step()

        # validation part
        model.eval()
        metrics = defaultdict(float)
        epoch_samples = 0
        for inputs, labels in loaders['val']:
            inputs = inputs.to(target_device)
            labels = labels.to(target_device)
            # No gradients are needed here, so nothing has to be zeroed.
            with torch.no_grad():
                outputs = model(inputs)
                compute_loss(outputs, labels, metrics)
            epoch_samples += inputs.size(0)
        print_loss(metrics, epoch_samples, 'val')
        epoch_loss = metrics['loss'] / epoch_samples

        # Keep a snapshot of the weights with the lowest validation loss.
        if epoch_loss < best_loss:
            print("saving best model")
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))
    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Build the UNet network structure
# Prefer the GPU when one is available (running on the CPU is best avoided).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(6).to(device)
summary(model, input_size=(3, 224, 224))

In [None]:
# Use our own data
# NOTE(review): training_data / validation_data are not defined in the visible
# cells — confirm they are created before this cell runs.
# train_model iterates the notebook-level `dataloaders`, so load_data's second
# return value is bound to that name instead of overwriting image_datasets.
image_datasets, dataloaders, dataset_sizes = load_data(training_data, validation_data)

In [None]:
# Clear cached GPU memory before training to guard against running out of memory
torch.cuda.empty_cache()

In [None]:
# Start training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_class = 6

# Build a fresh model for training. Construction records no autograd graph,
# so it does not need to be wrapped in torch.no_grad().
model = UNet(num_class).to(device)

# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)

# Multiply the learning rate by 0.1 every 25 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)

model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=1)