In [None]:
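# One-time environment setup (uncomment when running in a fresh environment):
#!git clone https://github.com/milesial/Pytorch-UNet.git
# If CUDA fails to initialise, install PyTorch builds compiled against CUDA 11.1
# (use %pip rather than !pip so the packages land in the kernel's environment):
#%pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html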

In [None]:
import argparse
import logging
import sys
from pathlib import Path
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from utils.clothing_dataset import BasicDataset, ClothingDataset
from utils.dice_score import dice_loss
from utils.evaluate import evaluate
from unet import UNet
import gc
import matplotlib.pyplot as plt
from datetime import datetime

In [None]:
# Report whether PyTorch can see a CUDA device, then choose the compute device
# (GPU when available, CPU otherwise) and display it as the cell output.
import torch.cuda

print(torch.cuda.is_available())
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

# Set up Weights & Biases (wandb) logging

In [None]:
# Open the wandb run for this notebook (team "endev", project "unet"); the run
# object is shown as the cell output.
wandb.init(entity='endev', project='unet')

# Dataset Path Variables

In [None]:
# Directory that receives the per-epoch training checkpoints.
dir_checkpoint = Path('./training/')
# Local data root (appears unused in this notebook).
root = "./data/"

In [None]:
# Training split: input images and their segmentation masks.
train_dir_img, train_dir_mask = (
    Path('/data/datasets/clothing-size/train/imgs/'),
    Path('/data/datasets/clothing-size/train/mask/'),
)

In [None]:
# Validation split: input images and their segmentation masks.
val_dir_img, val_dir_mask = (
    Path('/data/datasets/clothing-size/val/imgs/'),
    Path('/data/datasets/clothing-size/val/mask/'),
)

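# Check data

List the validation image and mask folders to confirm the dataset is where the data loaders expect it.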

In [None]:
ls /data/datasets/clothing-size/val/imgs/

In [None]:
ls /data/datasets/clothing-size/val/mask/

In [None]:
def _weight_and_grad_histograms(net):
    """Return wandb histograms of ``net``'s parameters and their gradients.

    Keys are ``'Weights/<name>'`` and ``'Gradients/<name>'``, with ``/`` in the
    parameter name replaced by ``.``. A gradient that is ``None`` (e.g. the
    parameter did not take part in the last backward pass) is skipped, and so
    is any tensor containing inf/NaN (possible on AMP overflow steps), since
    numpy-based histogramming rejects a non-finite value range.
    """
    histograms = {}
    for tag, value in net.named_parameters():
        tag = tag.replace('/', '.')
        if torch.isfinite(value.data).all():
            histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
        grad = value.grad
        if grad is not None and torch.isfinite(grad.data).all():
            histograms['Gradients/' + tag] = wandb.Histogram(grad.data.cpu())
    return histograms


def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 0.001,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False,
              weight_decay: float = 1e-8,
              momentum: float = 0.9):
    """Train ``net`` on the clothing-size dataset, validating and logging to wandb.

    Datasets are built from the module-level ``train_dir_img``/``train_dir_mask``
    and ``val_dir_img``/``val_dir_mask`` paths. Checkpoints are written to
    ``dir_checkpoint``.

    Args:
        net: segmentation network exposing ``n_channels`` and ``n_classes``.
        device: torch device each batch is moved to (the network is presumably
            already on this device — TODO confirm at call site).
        epochs: number of passes over the training set.
        batch_size: images per step for both the train and validation loaders.
        learning_rate: initial RMSprop learning rate (reduced on Dice plateau).
        save_checkpoint: if True, save ``net.state_dict()`` after every epoch.
        img_scale: scale factor forwarded to ``ClothingDataset``.
        amp: enable CUDA automatic mixed precision.
        weight_decay: RMSprop weight decay.
        momentum: RMSprop momentum.
    """
    # 1. Create datasets
    train_dataset = ClothingDataset(train_dir_img, train_dir_mask, img_scale)
    val_dataset = ClothingDataset(val_dir_img, val_dir_mask, img_scale)

    # 2. Totals
    n_train = len(train_dataset)
    n_val = len(val_dataset)

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=1, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    # drop_last=True: a validation set smaller than batch_size yields no batches.
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    # In notebooks a second wandb.init() by default finishes the run opened
    # earlier and starts a new one. That would separate these metrics from the
    # hyperparameters already written to wandb.config. Reuse the active run when
    # there is one, and only start a fresh run otherwise.
    experiment = wandb.run
    if experiment is None:
        experiment = wandb.init(project='U-Net-mask', resume='allow', anonymous='must')
    # allow_val_change: the notebook may already have stored some of these keys.
    experiment.config.update(dict(epochs=epochs,
                                  batch_size=batch_size,
                                  learning_rate=learning_rate,
                                  save_checkpoint=save_checkpoint,
                                  img_scale=img_scale,
                                  amp=amp,
                                  weight_decay=weight_decay,
                                  momentum=momentum),
                             allow_val_change=True)
    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate,
                              weight_decay=weight_decay, momentum=momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0
    # Validate roughly twice per epoch. With fewer than 2 * batch_size training
    # images the floor division is 0, which would make the modulo below raise
    # ZeroDivisionError, so validate every step in that case.
    eval_every = max(1, n_train // (2 * batch_size))

    # 5. Begin training
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0.0
        n_batches = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']

                assert images.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    # Cross-entropy plus multiclass Dice loss on softmax probabilities
                    # against one-hot targets (permuted to channel-first).
                    loss = criterion(masks_pred, true_masks) \
                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                       multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                # Read the scalar once rather than synchronising on every use.
                batch_loss = loss.item()
                pbar.update(images.shape[0])
                global_step += 1
                n_batches += 1
                epoch_loss += batch_loss
                experiment.log({
                    'train loss': batch_loss,
                    'step': global_step,
                    'epoch': epoch
                })

                pbar.set_postfix(**{'loss (batch)': batch_loss})

                # Evaluation round
                if global_step % eval_every == 0:
                    histograms = _weight_and_grad_histograms(net)

                    val_score = evaluate(net, val_loader, device)
                    # evaluate() may leave the network in eval mode (TODO confirm in
                    # utils/evaluate.py); make sure training resumes in train mode.
                    net.train()
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
                    experiment.log({
                        'learning rate': optimizer.param_groups[0]['lr'],
                        'validation Dice': val_score,
                        'images': wandb.Image(images[0].cpu()),
                        'masks': {
                            'true': wandb.Image(true_masks[0].float().cpu()),
                            'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
                        },
                        'step': global_step,
                        'epoch': epoch,
                        **histograms
                    })

        # max(1, ...) guards the mean when the training loader is empty.
        logging.info(f'Epoch {epoch + 1} mean training loss: {epoch_loss / max(1, n_batches)}')

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            # Take one timestamp so the file name and the log line agree, and number
            # the checkpoint from 1 to match the "Checkpoint N saved" message.
            timestamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
            checkpoint_path = Path(dir_checkpoint) / 'checkpoint_epoch{}_{}.pth'.format(epoch + 1, timestamp)
            torch.save(net.state_dict(), str(checkpoint_path))
            logging.info(f'Checkpoint {epoch + 1} saved at {timestamp}!')

In [None]:
# Training hyperparameters
epochs = 50            # number of epochs
batch_size = 1         # batch size
lr = 0.0001            # learning rate
load = False           # load model from a .pth file
scale = 1              # downscaling factor of the image
amp = False            # use mixed precision
# NOTE(review): weight_decay and momentum are recorded in wandb below, but the
# training call cell does not pass them to train_net.
weight_decay = 1e-8
momentum = 0.9

# Mirror the hyperparameters into the wandb run config. Tuple assignment sets
# the attributes left to right, in the same order as listed.
config = wandb.config
(config.epochs, config.batch_size, config.lr, config.load, config.scale,
 config.amp, config.weight_decay, config.momentum, config.model) = (
    epochs, batch_size, lr, load, scale, amp, weight_decay, momentum, "Unet-mask")

In [None]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

In [None]:
# Build the U-Net (3 input channels, 2 output classes, bilinear upsampling).
net = UNet(n_channels=3, n_classes=2, bilinear=True)
# Honour the `load` option from the config cell. When set, it is expected to be
# the path of a saved state_dict (.pth). map_location lets a GPU-saved
# checkpoint load on a CPU-only machine.
if load:
    net.load_state_dict(torch.load(load, map_location=device))
    logging.info(f'Model loaded from {load}')
# Use the device chosen earlier instead of hard-coding "cuda". Otherwise this
# cell fails on machines where the CPU fallback was selected.
net.to(device=device)

In [None]:
# Log a short architecture summary: channel counts and the upsampling mode.
logging.info(
    f'Network:\n'
    f'\t{net.n_channels} input channels\n'
    f'\t{net.n_classes} output channels (classes)\n'
    f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling'
)

In [None]:
# Collect unreachable Python objects first (including any stale tensors), then
# ask PyTorch's caching allocator to release its unused cached CUDA memory.
# The order matters: memory freed by the collection can only be released after it.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Launch training with the hyperparameters from the config cell. If the kernel
# is interrupted manually, keep the weights learned so far before stopping.
try:
    train_net(
        net=net,
        device=device,
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=lr,
        img_scale=scale,
        amp=amp,
    )
except KeyboardInterrupt:
    torch.save(net.state_dict(), 'INTERRUPTED.pth')
    logging.info('Saved interrupt')
    sys.exit(0)