In [1]:
import argparse
import logging
import pdb
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss
from evaluate import evaluate
from unet_linear import UNet_line
import matplotlib.pyplot as plt

In [3]:
# Filesystem locations, all relative to the notebook's working directory.
_data_root = Path('./data_polar_comb')

# Input ("low") and target ("high") folders handed to the dataset classes.
dir_low = _data_root / 'low'
dir_high = _data_root / 'high'

# Checkpoint folder (not referenced by the visible training code).
dir_checkpoint = Path('./checkpoints_edge')

# Folder that receives the trained model weights (MODEL_polar.pth).
final_model_path = Path('.')

In [4]:
def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 1e-5,
              val_percent: float = 0.2,
              save_checkpoint: bool = True,
              low_scale: float = 1.0,
              amp: bool = False):
    """Train ``net`` with an L1 loss on the low/high pairs and track the loss per epoch.

    The dataset is built from the module-level ``dir_low`` / ``dir_high`` folders,
    split into train/validation partitions with a fixed seed, and optimised with
    RMSprop. Every batch is logged to Weights & Biases.

    Side effects:
        * every 10th epoch the first prediction of each training batch is written
          to ``./data_polar_comb/result/<idx>.npy`` (channels-last);
        * after every epoch the mean training / validation loss curves are written
          to ``./loss/p_loss_polar.npy`` and ``./loss/v_loss_polar.npy``;
        * the weights are written to ``final_model_path / 'MODEL_polar.pth'`` after
          the last epoch and, when ``save_checkpoint`` is true, every 5th epoch.

    Args:
        net: model to train; must expose ``n_channels`` and be callable on a batch.
        device: torch device the batches are moved to.
        epochs: number of passes over the training partition.
        batch_size: batch size for both data loaders.
        learning_rate: RMSprop learning rate.
        val_percent: fraction (0-1) of the dataset held out for validation.
        save_checkpoint: whether to write the periodic (every 5 epochs) weights.
        low_scale: scale factor forwarded to the dataset constructor.
        amp: enable CUDA automatic mixed precision.

    Returns:
        None.
    """
    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_low, dir_high, low_scale)
    except (AssertionError, RuntimeError):
        dataset = BasicDataset(dir_low, dir_high, low_scale)

    # 2. Split into train / validation partitions (fixed seed => reproducible split)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # Make sure the output folders exist up front, so a missing directory does
    # not abort the run only after a full epoch has already been trained.
    Path('./data_polar_comb/result').mkdir(parents=True, exist_ok=True)
    Path('./loss').mkdir(parents=True, exist_ok=True)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                  val_percent=val_percent, save_checkpoint=save_checkpoint, low_scale=low_scale,
                                  amp=amp))

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {low_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss and the loss scaling for AMP
    # NOTE(review): a ReduceLROnPlateau scheduler used to be created here in 'max'
    # mode but was never stepped, so it had no effect; it was removed as dead code.
    # If LR decay is wanted, create it in 'min' mode and step it with the mean
    # validation loss once per epoch.
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.L1Loss()
    global_step = 0

    # Per-epoch mean losses (training / validation), persisted after every epoch.
    p_loss = np.zeros(epochs)
    v_loss = np.zeros(epochs)

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        net.train()
        epoch_loss = 0
        n_train_batches = 0
        # Predictions are only dumped to disk on every 10th epoch.
        save_predictions = epoch % 10 == 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                lows = batch['low']
                true_highs = batch['high']

                assert lows.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {lows.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                lows = lows.to(device=device, dtype=torch.float32)
                # NOTE(review): targets are cast to integers although the criterion is
                # an L1 regression loss; any fractional target values are truncated.
                # Kept as-is to preserve training behaviour - confirm this is intended.
                true_highs = true_highs.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    highs_pred = net(lows)
                    loss = criterion(highs_pred, true_highs)

                if save_predictions:
                    # Only the first sample of the batch is saved, as channels-last.
                    # The GPU->CPU copy is done here (not on every step) because it
                    # forces a device synchronisation.
                    pred = highs_pred.detach().cpu().numpy()[0, :, :, :]
                    pred = pred.transpose((1, 2, 0))
                    # presumably 'idx' holds paths shaped like '<a>/<b>/<name>.<ext>';
                    # verify against the dataset class.
                    idx = batch['idx'][0].split('/')[2].split('.')[0]
                    np.save('./data_polar_comb/result/%s.npy' % idx, pred)

                optimizer.zero_grad()
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                batch_loss = loss.item()  # single host sync per step
                pbar.update(lows.shape[0])
                global_step += 1
                epoch_loss += batch_loss
                n_train_batches += 1
                experiment.log({
                    'train loss': batch_loss,
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': batch_loss})
        experiment.log({
            'epoch loss': epoch_loss
        })

        # Validation: eval mode and no autograd, so normalisation/dropout layers
        # behave as at inference time and no graph is built.
        net.eval()
        val_loss_sum = 0.0
        n_val_batches = 0
        with tqdm(total=n_val, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            with torch.no_grad():
                for batch in val_loader:
                    lows = batch['low']
                    true_highs = batch['high']
                    lows = lows.to(device=device, dtype=torch.float32)
                    true_highs = true_highs.to(device=device, dtype=torch.long)

                    with torch.cuda.amp.autocast(enabled=amp):
                        highs_pred = net(lows)
                        v_l = criterion(highs_pred, true_highs)

                    batch_val_loss = v_l.item()
                    val_loss_sum += batch_val_loss
                    n_val_batches += 1
                    pbar.update(lows.shape[0])
                    # Log the validation loss itself (previously the last training
                    # loss was logged under this key).
                    experiment.log({
                        'validation loss': batch_val_loss,
                        'step': global_step,
                        'epoch': epoch
                    })
                    pbar.set_postfix(**{'loss (batch)': batch_val_loss})
        # Leave the network in training mode, as callers found it before.
        net.train()

        # Record epoch means; NaN marks an epoch without any batches (e.g. an
        # empty validation partition) instead of failing on an undefined variable.
        p_loss[epoch - 1] = epoch_loss / n_train_batches if n_train_batches else float('nan')
        v_loss[epoch - 1] = val_loss_sum / n_val_batches if n_val_batches else float('nan')
        np.save('./loss/p_loss_polar.npy', p_loss)
        np.save('./loss/v_loss_polar.npy', v_loss)

        if epoch == epochs:
            torch.save(net.state_dict(), str(final_model_path / 'MODEL_polar.pth'))
            logging.info('Final model is saved!')
            plt.figure(figsize=(20, 10))
            plt.plot(p_loss)
            plt.title('Mean training loss per epoch')
            plt.xlabel('epoch')
            plt.ylabel('L1 loss')
            plt.figure(figsize=(20, 10))
            plt.plot(v_loss)
            plt.title('Mean validation loss per epoch')
            plt.xlabel('epoch')
            plt.ylabel('L1 loss')
        elif save_checkpoint and epoch % 5 == 0:
            # NOTE(review): periodic checkpoints overwrite the same file as the
            # final model; dir_checkpoint is not used here.
            torch.save(net.state_dict(), str(final_model_path / 'MODEL_polar.pth'))
            logging.info('P model is saved!')

In [5]:
# Default number of training epochs used when --epochs is not given.
looping_epochs = 500


def get_args(argv=None):
    """Parse the training options.

    Args:
        argv: optional list of argument strings. When omitted, ``sys.argv`` is
            parsed as usual, except inside a Jupyter kernel: there ``sys.argv``
            carries the kernel's own ``-f <connection file>`` option, which would
            otherwise be captured by ``--load/-f``. In that case the defaults are
            returned instead.

    Returns:
        argparse.Namespace with ``epochs``, ``batch_size``, ``lr``, ``load``,
        ``scale``, ``val`` (percent, 0-100), ``amp``, ``bilinear`` and ``classes``.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=looping_epochs, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=1.0, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    if argv is None and 'ipykernel' in sys.modules:
        # Running inside a notebook kernel: ignore the kernel's own command line.
        argv = []
    return parser.parse_args(argv)
    

初始化:清除所有已存在的文件数据 (Initialization: clear all previously generated files)

In [6]:
import shutil
import os
# shutil.rmtree('./Model_edge')
# os.mkdir('./Model_edge')
# shutil.rmtree('./data_edge/result')
# os.mkdir('./data_edge/result')

In [7]:
# --- Runtime setup: options, logging, compute device ------------------------
args = get_args()
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logging.info(f'Using device {device}')

# --- Model: two input channels, output channels / upsampling from the options
net = UNet_line(n_channels=2, n_classes=args.classes, bilinear=args.bilinear)
_upscaling = 'Bilinear' if net.bilinear else 'Transposed conv'
logging.info('\n'.join([
    'Network:',
    f'\t{net.n_channels} input channels',
    f'\t{net.n_classes} output channels (classes)',
    f'\t{_upscaling} upscaling',
]))

# --- Optional warm start -----------------------------------------------------
# The checkpoint path is cleared unconditionally, so the load branch below is
# skipped unless this line is edited, e.g.
#     args.load = './Completed Model/MODEL_cart.pth'
# NOTE(review): presumably needed because a Jupyter kernel's own
# `-f <connection file>` argument is picked up by `--load/-f` - confirm.
args.load = ''
if args.load:
    _state = torch.load(args.load, map_location=device)
    net.load_state_dict(_state)
    logging.info(f'Model loaded from {args.load}')
net.to(device=device)

# --- Train; on a manual interrupt keep the current weights and re-raise -------
try:
    _train_kwargs = dict(
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
        low_scale=args.scale,
        val_percent=args.val / 100,  # option is given in percent
        amp=args.amp,
    )
    train_net(net=net, device=device, **_train_kwargs)
except KeyboardInterrupt:
    torch.save(net.state_dict(), 'INTERRUPTED.pth')
    logging.info('Saved interrupt')
    raise

INFO: Using device cuda
INFO: Network:
	2 input channels
	2 output channels (classes)
	Transposed conv upscaling
INFO: Creating dataset with 2000 examples
ERROR: Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: anony-mouse-286188. Use `wandb login --relogin` to force relogin


INFO: Starting training:
        Epochs:          500
        Batch size:      1
        Learning rate:   1e-05
        Training size:   1800
        Validation size: 200
        Checkpoints:     True
        Device:          cuda
        Images scaling:  1.0
        Mixed Precision: False
    
Epoch 1/500: 100%|██████████| 1800/1800 [01:02<00:00, 28.89img/s, loss (batch)=29.9]
Epoch 1/500: 100%|██████████| 200/200 [00:03<00:00, 62.54img/s, loss (batch)=16.7]
Epoch 2/500: 100%|██████████| 1800/1800 [00:57<00:00, 31.24img/s, loss (batch)=25.8]
Epoch 2/500: 100%|██████████| 200/200 [00:03<00:00, 64.09img/s, loss (batch)=15.4]
Epoch 3/500: 100%|██████████| 1800/1800 [00:57<00:00, 31.16img/s, loss (batch)=23.2]
Epoch 3/500: 100%|██████████| 200/200 [00:03<00:00, 62.97img/s, loss (batch)=14.9]
Epoch 4/500: 100%|██████████| 1800/1800 [00:57<00:00, 31.22img/s, loss (batch)=22.7]
Epoch 4/500: 100%|██████████| 200/200 [00:03<00:00, 64.68img/s, loss (batch)=15.2]
Epoch 5/500: 100%|██████████| 18

Epoch 46/500: 100%|██████████| 1800/1800 [00:58<00:00, 30.79img/s, loss (batch)=20.6]
Epoch 46/500: 100%|██████████| 200/200 [00:03<00:00, 65.26img/s, loss (batch)=14.6]
Epoch 47/500: 100%|██████████| 1800/1800 [00:58<00:00, 30.99img/s, loss (batch)=14.3]
Epoch 47/500: 100%|██████████| 200/200 [00:03<00:00, 63.81img/s, loss (batch)=16]  
Epoch 48/500: 100%|██████████| 1800/1800 [00:58<00:00, 30.64img/s, loss (batch)=16.5]
Epoch 48/500: 100%|██████████| 200/200 [00:03<00:00, 62.72img/s, loss (batch)=14.7]
Epoch 49/500: 100%|██████████| 1800/1800 [00:58<00:00, 30.82img/s, loss (batch)=24.3]
Epoch 49/500: 100%|██████████| 200/200 [00:03<00:00, 63.25img/s, loss (batch)=15.1]
Epoch 50/500: 100%|██████████| 1800/1800 [01:08<00:00, 26.32img/s, loss (batch)=18.7]
Epoch 50/500: 100%|██████████| 200/200 [00:03<00:00, 65.53img/s, loss (batch)=15.4]
INFO: P model is saved!
Epoch 51/500: 100%|██████████| 1800/1800 [00:58<00:00, 30.83img/s, loss (batch)=13.5]
Epoch 51/500: 100%|██████████| 200/200 [

Epoch 92/500: 100%|██████████| 200/200 [00:03<00:00, 66.02img/s, loss (batch)=15.3]
Epoch 93/500: 100%|██████████| 1800/1800 [00:57<00:00, 31.22img/s, loss (batch)=20]  
Epoch 93/500: 100%|██████████| 200/200 [00:03<00:00, 63.98img/s, loss (batch)=15.4]
Epoch 94/500: 100%|██████████| 1800/1800 [00:58<00:00, 30.88img/s, loss (batch)=30]  
Epoch 94/500: 100%|██████████| 200/200 [00:03<00:00, 62.25img/s, loss (batch)=14.7]
Epoch 95/500: 100%|██████████| 1800/1800 [00:58<00:00, 30.78img/s, loss (batch)=13.4]
Epoch 95/500: 100%|██████████| 200/200 [00:03<00:00, 62.85img/s, loss (batch)=14.2]
INFO: P model is saved!
Epoch 96/500: 100%|██████████| 1800/1800 [00:58<00:00, 30.76img/s, loss (batch)=24.2]
Epoch 96/500: 100%|██████████| 200/200 [00:03<00:00, 62.86img/s, loss (batch)=14.7]
Epoch 97/500: 100%|██████████| 1800/1800 [00:58<00:00, 30.94img/s, loss (batch)=23.1]
Epoch 97/500: 100%|██████████| 200/200 [00:03<00:00, 61.84img/s, loss (batch)=14.4]
Epoch 98/500: 100%|██████████| 1800/1800 [

KeyboardInterrupt: 