In [8]:
import argparse
import logging
import pdb
import sys
from pathlib import Path
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb

from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss
from evaluate import evaluate
from unet_line import UNet_line
import matplotlib.pyplot as plt
import shutil
import os

In [9]:
# Directory that receives the rolling / final model weights file.
final_model_path = Path('./')
# Directory that receives the periodic training checkpoints.
dir_checkpoint = Path('./checkpoint/real/')
# Paired dataset directories handed to the dataset constructors:
# network inputs ("low") and their regression targets ("high").
dir_low, dir_high = (
    Path('./data/data_double_cart_64/real/low/'),
    Path('./data/data_double_cart_64/real/high/'),
)

In [10]:
def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 1e-5,
              val_percent: float = 0.2,
              save_checkpoint: bool = True,
              low_scale: float = 1.0,
              amp: bool = False):
    """Train ``net`` with an L1 loss on (low, high) pairs and validate after every epoch.

    The dataset is built from the module-level ``dir_low`` / ``dir_high`` directories and
    split into train / validation parts with a fixed seed. Each batch is expected to be a
    dict with the keys ``'low'``, ``'high'`` and ``'idx'``.

    Args:
        net: Model to optimise; must expose ``n_channels``.
        device: ``torch.device`` the batches are moved to.
        epochs: Number of passes over the training split.
        batch_size: Batch size for both data loaders.
        learning_rate: RMSprop learning rate (kept constant; no scheduler is stepped).
        val_percent: Fraction (0-1) of the dataset held out for validation.
        save_checkpoint: When true, write the rolling model every 5 epochs and a
            numbered checkpoint every 50 epochs.
        low_scale: Scale factor forwarded to the dataset constructor.
        amp: Enable CUDA automatic mixed precision.

    Side effects:
        * Logs metrics to a Weights & Biases run.
        * Every 10th epoch, saves the first prediction of each training batch as ``.npy``.
        * After every epoch, saves the per-epoch last-batch train / validation losses
          to ``./loss/``.
        * After the last epoch, always writes the final weights to
          ``final_model_path / 'MODEL_edge_real.pth'`` and plots both loss curves.
    """
    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_low, dir_high, low_scale)
    except (AssertionError, RuntimeError):
        dataset = BasicDataset(dir_low, dir_high, low_scale)

    # 2. Split into train / validation partitions (fixed seed => reproducible split)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                  val_percent=val_percent, save_checkpoint=save_checkpoint, low_scale=low_scale,
                                  amp=amp))

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {low_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.L1Loss()
    global_step = 0

    # Per-epoch loss of the last training / validation batch (NaN if a loader was empty).
    p_loss = np.zeros(epochs)
    v_loss = np.zeros(epochs)

    result_dir = Path('./data/data_double_cart_64/real/result')
    loss_dir = Path('./loss')
    final_model_file = str(final_model_path / 'MODEL_edge_real.pth')

    def _prepare(batch):
        """Check the channel count of one batch and move its tensors to ``device``."""
        lows = batch['low']
        assert lows.shape[1] == net.n_channels, \
            f'Network has been defined with {net.n_channels} input channels, ' \
            f'but loaded images have {lows.shape[1]} channels. Please check that ' \
            'the images are loaded correctly.'
        lows = lows.to(device=device, dtype=torch.float32)
        # NOTE(review): targets are cast to an integer dtype although the loss is L1;
        # if the targets are real-valued this truncates them - confirm float32 is not wanted.
        true_highs = batch['high'].to(device=device, dtype=torch.long)
        return lows, true_highs

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        net.train()
        epoch_loss = 0
        last_train_loss = float('nan')

        # Predictions are dumped to disk only every 10th epoch.
        save_preds = epoch % 10 == 0
        if save_preds:
            result_dir.mkdir(parents=True, exist_ok=True)

        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                lows, true_highs = _prepare(batch)

                with torch.cuda.amp.autocast(enabled=amp):
                    highs_pred = net(lows)
                    loss = criterion(highs_pred, true_highs)

                if save_preds:
                    # Only the first sample / first channel of the batch is stored. The file
                    # name is the 5th '/'-separated component of the sample id, without extension.
                    pred = highs_pred.detach().cpu().numpy()[0, 0, :, :]
                    name = batch['idx'][0].split('/')[4].split('.')[0]
                    np.save('./data/data_double_cart_64/real/result/%s.npy' % name, pred)

                optimizer.zero_grad()
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                last_train_loss = loss.item()
                pbar.update(lows.shape[0])
                global_step += 1
                epoch_loss += last_train_loss
                experiment.log({
                    'train loss': last_train_loss,
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': last_train_loss})

        # Validation: inference mode and no autograd, so normalisation / dropout layers
        # behave deterministically and no graph is built for the held-out data.
        net.eval()
        last_val_loss = float('nan')
        with torch.no_grad(), tqdm(total=n_val, desc=f'Test {epoch}/{epochs}', unit='img') as pbar:
            for batch in val_loader:
                lows, true_highs = _prepare(batch)

                with torch.cuda.amp.autocast(enabled=amp):
                    highs_pred = net(lows)
                    val_loss = criterion(highs_pred, true_highs)

                last_val_loss = val_loss.item()
                pbar.update(lows.shape[0])
                experiment.log({
                    'validation loss': last_val_loss,
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': last_val_loss})
        net.train()

        experiment.log({
            'epoch loss': epoch_loss
        })
        p_loss[epoch - 1] = last_train_loss
        v_loss[epoch - 1] = last_val_loss
        loss_dir.mkdir(parents=True, exist_ok=True)
        np.save('./loss/p_loss_real.npy', p_loss)
        np.save('./loss/v_loss_real.npy', v_loss)

        is_last_epoch = epoch == epochs

        # Rolling model file: refreshed every 5 epochs, and always after the last epoch.
        if is_last_epoch or (save_checkpoint and epoch % 5 == 0):
            torch.save(net.state_dict(), final_model_file)
            logging.info('Final model is saved!' if is_last_epoch else 'P model is saved!')

        # Numbered checkpoint every 50 epochs (file index is epoch // 50).
        if save_checkpoint and epoch % 50 == 0:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch // 50)))
            logging.info(f'Checkpoint {epoch} saved!')

        if is_last_epoch:
            for curve, title in ((p_loss, 'Training loss (last batch of each epoch)'),
                                 (v_loss, 'Validation loss (last batch of each epoch)')):
                plt.figure(figsize=(20, 10))
                plt.plot(curve)
                plt.title(title)
                plt.xlabel('epoch')
                plt.ylabel('L1 loss')

In [11]:
# Default number of training epochs used when --epochs is not given.
looping_epochs = 300


def get_args(argv=None):
    """Parse the training options.

    Args:
        argv: Optional list of argument strings. When ``None`` the process arguments
            are used, except inside a Jupyter/IPython kernel: there ``sys.argv`` holds
            the kernel's own ``-f <connection file>`` option, which would be mistaken
            for ``--load/-f``, so an empty argument list (all defaults) is parsed.

    Returns:
        argparse.Namespace with ``epochs``, ``batch_size``, ``lr``, ``load``, ``scale``,
        ``val`` (a percentage, 0-100), ``amp``, ``bilinear`` and ``classes``.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=looping_epochs, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=1.0, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=1.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=1, help='Number of classes')

    if argv is None and 'ipykernel' in sys.modules:
        # Running in a notebook kernel: ignore the kernel launcher's command line.
        argv = []
    return parser.parse_args(argv)

In [None]:
if __name__ == '__main__':
    # Entry point: build the network, optionally resume from saved weights, then train.
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net = UNet_line(n_channels=1, n_classes=args.classes, bilinear=args.bilinear)

    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    # Always resume from the rolling model file; this deliberately replaces whatever
    # --load held. Set it to '' to start from the network's initial weights.
    args.load = 'MODEL_edge_real.pth'
    if args.load and Path(args.load).is_file():
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')
    elif args.load:
        # First run on a clean checkout: nothing to resume from yet, so do not crash.
        logging.warning(f'No weights found at {args.load}; training from the initial weights')
    net.to(device=device)
    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batch_size,
                  learning_rate=args.lr,
                  device=device,
                  low_scale=args.scale,
                  val_percent=args.val / 100,  # CLI value is a percentage
                  amp=args.amp)
    except KeyboardInterrupt:
        # Keep the weights reached so far before propagating the interrupt.
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        raise

INFO: Using device cuda
INFO: Network:
	1 input channels
	1 output channels (classes)
	Transposed conv upscaling
INFO: Model loaded from MODEL_edge_real.pth
INFO: Creating dataset with 3000 examples


INFO: Starting training:
        Epochs:          300
        Batch size:      1
        Learning rate:   1e-05
        Training size:   2970
        Validation size: 30
        Checkpoints:     True
        Device:          cuda
        Images scaling:  1.0
        Mixed Precision: False
    
Epoch 1/300: 100%|██████████| 2970/2970 [01:49<00:00, 27.14img/s, loss (batch)=3.86]
Test 1/300: 100%|██████████| 30/30 [00:00<00:00, 42.34img/s, loss (batch)=2.92]
Epoch 2/300: 100%|██████████| 2970/2970 [01:44<00:00, 28.48img/s, loss (batch)=3.07]
Test 2/300: 100%|██████████| 30/30 [00:00<00:00, 46.77img/s, loss (batch)=2.85]
Epoch 3/300: 100%|██████████| 2970/2970 [01:43<00:00, 28.78img/s, loss (batch)=2.38]
Test 3/300: 100%|██████████| 30/30 [00:00<00:00, 49.76img/s, loss (batch)=2.86]
Epoch 4/300: 100%|██████████| 2970/2970 [01:44<00:00, 28.41img/s, loss (batch)=3.1] 
Test 4/300: 100%|██████████| 30/30 [00:00<00:00, 46.99img/s, loss (batch)=2.82]
Epoch 5/300: 100%|██████████| 2970/2970 [01:4