In [1]:
import argparse
import logging
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm

from eval import eval_net
from unet import UNet

from torch.utils.tensorboard import SummaryWriter
from torch.nn import functional as F
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split

from IPython.display import clear_output

In [2]:
# Restrict this process to the GPU with index 1. The variable has to be set
# before CUDA is first initialised in this kernel for it to take effect.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [3]:
# Folders with the training images and their segmentation masks.
# NOTE(review): these are absolute, machine-specific paths; adjust them for
# the environment the notebook is run in.
dir_train_img = '/home/natasha/unet4/axial_data/train/Fifth_3/pictures/'
dir_train_mask = '/home/natasha/unet4/axial_data/train/Fifth_3/masks/'

# Folders with the validation images and masks.
dir_val_img = '/home/natasha/unet4/axial_data/val/Fifth_2/pictures/'
dir_val_mask = '/home/natasha/unet4/axial_data/val/Fifth_2/masks/'

# Where model checkpoints are written (relative to the working directory).
dir_checkpoint = 'ckpts_dir/axial_ckpts/'

try:
    # The checkpoint path is nested, so use makedirs: plain os.mkdir raises
    # FileNotFoundError when the parent 'ckpts_dir/' does not exist yet.
    os.makedirs(dir_checkpoint)
    print("Directory " , dir_checkpoint , " Created ") 
except FileExistsError:
    print("Directory " , dir_checkpoint , " already exists")

Directory  ckpts_dir/axial_ckpts/  already exists


In [4]:
# Image scaling factor. In this notebook it is only written into the
# TensorBoard run name and the start-up log message of train_net.
# NOTE(review): it is not passed to BasicDataset here, so it presumably does
# not affect how images are loaded - confirm against utils.dataset.
img_scale: int = 1

In [5]:
def _log_parameter_histograms(writer, net, global_step):
    """Write weight (and, when available, gradient) histograms for every parameter of ``net``."""
    for tag, value in net.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
        # Parameters that are frozen or took no part in the last backward pass
        # have ``grad is None``; skip them instead of crashing on ``None.data``.
        if value.grad is not None:
            writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)


def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              epoch_bias=0):
    """Train ``net`` on the training folders and periodically validate it with ``eval_net``.

    Data comes from the notebook-level ``dir_train_img`` / ``dir_train_mask`` and
    ``dir_val_img`` / ``dir_val_mask`` folders. Scalars, histograms and (for
    single-class networks) mask images are written to a TensorBoard
    ``SummaryWriter``.

    Args:
        net: Model exposing ``n_channels`` and ``n_classes`` and returning logits.
        device: ``torch.device`` the batches are moved to.
        epochs: Number of passes over the training set.
        batch_size: Batch size for both the training and the validation loader.
        lr: Initial learning rate for RMSprop.
        val_percent: Unused; kept so existing callers keep working.
        save_cp: When true, evaluate after every epoch and save the weights to
            ``dir_checkpoint`` whenever the validation score exceeds the best
            score seen so far.
        epoch_bias: Offset added to the epoch number in checkpoint file names
            and log messages (useful when resuming a previous run).

    Raises:
        AssertionError: If a batch's channel count differs from ``net.n_channels``.
    """
    train_dataset = BasicDataset(dir_train_img, dir_train_mask)
    val_dataset = BasicDataset(dir_val_img, dir_val_mask)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)

    global_step = 0
    val_score_max = 0
    n_train = len(train_dataset)
    n_val = len(val_dataset)
    # Number of optimisation steps between two mid-epoch validations. Clamped
    # to 1 so that small datasets / large batches do not cause a modulo by zero.
    eval_every = max(1, (n_train + n_val) // (2 * batch_size))

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 2 else 'max', patience=2)
    if net.n_classes > 2:
        print("Using CrossEntropyLoss")
        criterion = nn.CrossEntropyLoss()
    else:
        print("Using BCEWithLogitsLoss")
        criterion = nn.BCEWithLogitsLoss()

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    # try/finally so the event file is flushed and closed even when training
    # is interrupted (e.g. KeyboardInterrupt) or raises.
    try:
        for epoch in range(epochs):
            net.train()

            with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image']
                    true_masks = batch['mask']

                    assert imgs.shape[1] == net.n_channels, \
                        f'Network has been defined with {net.n_channels} input channels, ' \
                        f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                    imgs = imgs.to(device=device, dtype=torch.float32)
                    # NOTE(review): masks stay float32 for every network with up to
                    # 10 classes; the loss call below casts them to the prediction
                    # dtype anyway, so the target is always floating point here.
                    mask_type = torch.float32 if net.n_classes <= 10 else torch.long
                    true_masks = true_masks.to(device=device, dtype=mask_type)

                    masks_pred = net(imgs)

                    loss = criterion(masks_pred, true_masks.type_as(masks_pred))
                    # Read the scalar once and reuse it for every consumer.
                    loss_value = loss.item()
                    writer.add_scalar('Loss/train', loss_value, global_step)

                    pbar.set_postfix(**{'loss (batch)': loss_value})

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(net.parameters(), 0.1)
                    optimizer.step()

                    pbar.update(imgs.shape[0])
                    global_step += 1
                    if global_step % eval_every == 0:
                        _log_parameter_histograms(writer, net, global_step)

                        # Validation needs no autograd graph.
                        with torch.no_grad():
                            val_score = eval_net(net, val_loader, device)
                        # eval_net is defined elsewhere; in case it leaves the
                        # network in eval mode, switch back before training on.
                        net.train()
                        scheduler.step(val_score)
                        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                        if net.n_classes > 2:
                            logging.info('Validation cross entropy: {}'.format(val_score))
                            writer.add_scalar('Loss/test', val_score, global_step)
                        else:
                            logging.info('Validation Dice Coeff: {}'.format(val_score))
                            writer.add_scalar('Dice/test', val_score, global_step)

                        writer.add_images('images', imgs, global_step)
                        if net.n_classes == 1:
                            writer.add_images('masks/true', true_masks, global_step)
                            writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

            if save_cp:
                # makedirs also creates missing parents; a bare mkdir would fail
                # on a nested checkpoint path and torch.save would fail later.
                if not os.path.isdir(dir_checkpoint):
                    os.makedirs(dir_checkpoint, exist_ok=True)
                    logging.info('Created checkpoint directory')

                with torch.no_grad():
                    val_score = eval_net(net, val_loader, device)

                # NOTE(review): checkpoints are kept when the score goes UP, while
                # the scheduler above runs in 'min' mode for n_classes > 2 (and the
                # log calls the score a cross entropy). One of the two directions
                # is presumably wrong for multi-class runs - confirm what eval_net
                # returns and align them.
                if val_score_max < val_score:
                    val_score_max = val_score
                    torch.save(net.state_dict(),
                               os.path.join(dir_checkpoint, f'best_epoch_{epoch + epoch_bias + 1}.pth'))
                    logging.info(f'Checkpoint {epoch + epoch_bias + 1} saved !')
                    logging.info(f'Current max val_score = {val_score_max}')
    finally:
        writer.close()

In [6]:
# Emit INFO-level (and above) records in the form "<LEVEL>: <message>".
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)

# Train on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logging.info(f'Using device {device}')

INFO: Using device cuda


In [7]:
# Single-channel input, 7 output classes, bilinear upsampling in the decoder.
net = UNet(n_channels=1, n_classes=7, bilinear=True)

# Summarise the network configuration in one multi-line log record.
logging.info('\n'.join([
    'Network:',
    f'\t{net.n_channels} input channels',
    f'\t{net.n_classes} output channels (classes)',
    f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling',
]))

INFO: Network:
	1 input channels
	7 output channels (classes)
	Bilinear upscaling


In [9]:
epochs = 3
batch_size = 10

# To resume a previous run, point `load` at a checkpoint file and set
# `epoch_bias` to the offset that should be added to the epoch number in
# checkpoint file names, e.g.:
#   load = dir_checkpoint+'CP_epoch400.pth'
#   epoch_bias = 400
load = False
epoch_bias = 0

if load:
    net.load_state_dict(
        torch.load(load, map_location=device)
    )
    logging.info(f'Model loaded from {load}')

net.to(device=device)

try:
    train_net(net=net,
              epochs=epochs,
              batch_size=batch_size,
              device=device,
              epoch_bias=epoch_bias)
except KeyboardInterrupt:
    # Keep the partially trained weights. The cell then simply ends: calling
    # sys.exit()/os._exit() here would hard-kill the notebook kernel without
    # flushing buffers or running cleanup handlers.
    torch.save(net.state_dict(), 'INTERRUPTED.pth')
    logging.info('Saved interrupt')

INFO: Creating dataset with 100 examples
INFO: Creating dataset with 100 examples
INFO: Starting training:
        Epochs:          3
        Batch size:      10
        Learning rate:   0.001
        Training size:   100
        Validation size: 100
        Checkpoints:     True
        Device:          cuda
        Images scaling:  1
    
Epoch 1/3:   0%|          | 0/100 [00:00<?, ?img/s]

Using CrossEntropyLoss
image shape = (308, 320)image shape = (308, 320)image shape = (308, 320)image shape = (308, 320)image shape = (308, 320)


image shape = (308, 320)image shape = (308, 320)image shape = (308, 320)

image shape = (308, 320)
image shape = (308, 320)


image shape = (308, 320)
image shape = (308, 320)image shape = (308, 320)

image shape = (300, 320)image shape = (308, 320)image shape = (308, 320)

image shape = (308, 320)
image shape = (308, 320)image shape = (308, 320)


image shape = (308, 320)image shape = (308, 320)
image shape = (308, 320)
