In [10]:
import argparse
import logging
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import wandb
import torchvision
import torchvision.transforms as transforms
import albumentations as A
from evaluation_GT_SR import *
from morphological_transforms import *

import torch
torch.cuda.empty_cache()

In [11]:
from utils.dice_score import dice_loss
from evaluate import evaluate
from unet import UNet

In [12]:
from wholeslidedata.iterators import create_batch_iterator
import numpy as np
from matplotlib import pyplot as plt
from plot_utils import init_plot, plot_batch, show_plot
from shapely.prepared import prep

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
dir_checkpoint = Path('./checkpoints/')

cuda


In [14]:
net = UNet(n_channels=3, n_classes=2, bilinear=True)
net.to(device=device)

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [15]:
aug_transforms = A.Compose(
    [A.HorizontalFlip(p=0.5),
     A.VerticalFlip(p=0.5),
     A.RandomBrightnessContrast(p=0.5),
     A.OneOf([A.Blur(blur_limit=5, p=0.5),
             A.ColorJitter(p=0.5),
             ],p=1.0),
])

In [16]:
def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 1e-5,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False):
    '''
    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                  val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
                                  amp=amp))

    logging.info(Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Images scaling:  {img_scale}
        Mixed Precision: {amp})
'''
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
                                  val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
                                  amp=amp))
    user_config = 'user_config.yml'
    if torch.cuda.is_available():
        try:
            train_loader = create_batch_iterator(user_config=user_config, number_of_batches=batch_size, mode='training', cpus=1)
            val_loader = create_batch_iterator(user_config=user_config, number_of_batches=batch_size, mode='validation', cpus=1)
        except:
            print('Exception!')
            sys.exit()
        # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
        optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
        grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
        criterion = nn.CrossEntropyLoss()
        global_step = 0
        #image_transforms = torchvision.transforms.Compose([ transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5],
                                 #[0.5, 0.5, 0.5])])
        n_train = 1
        # 5. Begin training
        for epoch in range(epochs):
            net.train()
            epoch_loss = 0
            acc = 0.
            SE = 0.
            SP = 0.
            PC = 0.
            F1 = 0.
            JS = 0.
            DC = 0.
            length = 0
            with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
                for batch in train_loader:
                    images = batch[0]
                    true_masks = batch[1]
                    #adding augmentations
                    for aug_batch in range(4):
                        image_aug = images[aug_batch,:,:,:]
                        mask_aug = true_masks[aug_batch,:,:]
                        transformed = aug_transforms(image=np.zeros((10, 10, 3), dtype=np.uint8),
                                                                    target_image=image_aug, target_mask=mask_aug)

                        images[aug_batch,:,:,:] =  transformed[target_image]
                        true_masks[aug_batch,:,:] = transformed[target_mask]

                    images = images.astype(np.float64)
                    true_masks = true_masks.astype(np.float64)
                    images = torch.Tensor(images)
                    true_masks = torch.Tensor(true_masks)

                    images = torch.permute(images,(0,3,1,2))

                    assert images.shape[1] == net.n_channels, \
                        f'Network has been defined with {net.n_channels} input channels, ' \
                        f'but loaded images have {images.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                    images = images.to(device=device, dtype=torch.float32)/255
                    true_masks = true_masks.to(device=device, dtype=torch.long)


                    with torch.cuda.amp.autocast(enabled=amp):
                        masks_pred = net(images)
                        loss = criterion(masks_pred, true_masks) \
                               + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                           F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                           multiclass=False)
                    optimizer.zero_grad(set_to_none=True)
                    grad_scaler.scale(loss).backward()
                    grad_scaler.step(optimizer)
                    grad_scaler.update()
                    loss = loss.detach().cpu()
                    pbar.update(images.shape[0])
                    global_step += 1
                    epoch_loss += loss.item()
                    experiment.log({
                        'train loss': loss.item(),
                        'step': global_step,
                        'epoch': epoch
                    })
                    pbar.set_postfix(**{'loss (batch)': loss.item()})

                    # Evaluation round
                    #division_step = (n_train // (10 * batch_size))
                    #if division_step > 0:
                    #    if global_step % division_step == 0:
                    histograms = {}
                    for tag, value in net.named_parameters():
                        tag = tag.replace('/', '.')
                        histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                        histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                    val_score = evaluate(net, val_loader, device)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
                    experiment.log({
                        'learning rate': optimizer.param_groups[0]['lr'],
                        'validation Dice': val_score,
                        'images': wandb.Image(images[0].cpu()),
                        'masks': {
                            'true': wandb.Image(true_masks[0].float().cpu()),
                            'pred': wandb.Image(torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()),
                        },
                        'step': global_step,
                        'epoch': epoch,
                        **histograms
                    })
                    images = images.detach().cpu()
                    true_masks = true_masks.detach().cpu()
                    del images
                    del true_masks
                    del loss
                    torch.cuda.empty_cache()
            print("Epoch: "+str(epoch))
            print("Loss:" +str(epoch_loss))
            if save_checkpoint:
                Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
                torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
                logging.info(f'Checkpoint {epoch + 1} saved!')

In [17]:
train_net(net=net,device=device,epochs=30,batch_size=1,learning_rate=0.01,
              val_percent=10 / 100,amp=False)

Problem at: <ipython-input-16-55cb41ab04ba> 44 train_net


Traceback (most recent call last):
  File "/home/mainuser/.local/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 995, in init
    run = wi.init()
  File "/home/mainuser/.local/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 531, in init
    backend.ensure_launched()
  File "/home/mainuser/.local/lib/python3.8/site-packages/wandb/sdk/backend/backend.py", line 186, in ensure_launched
    self.record_q = self._multiprocessing.Queue()
  File "/usr/lib/python3.8/multiprocessing/context.py", line 103, in Queue
    return Queue(maxsize, ctx=self.get_context())
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 42, in __init__
    self._rlock = ctx.Lock()
  File "/usr/lib/python3.8/multiprocessing/context.py", line 68, in Lock
    return Lock(ctx=self.get_context())
  File "/usr/lib/python3.8/multiprocessing/synchronize.py", line 162, in __init__
    SemLock.__init__(self, SEMAPHORE, 1, 1, ctx=ctx)
  File "/usr/lib/python3.8/multiprocessing/synchronize.py", line 

Exception: problem

In [6]:
x, y, info = next(inf_loader)
print(info)
image = x.astype(np.uint8)[0]
plt.imshow(image)
plt.show()
mask = y.astype(np.uint8)[0]
plt.imshow(mask)
plt.show()

NameError: name 'inf_loader' is not defined