Adapted from https://github.com/milesial/Pytorch-UNet and the PyTorch documentation.

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import gc, collections, resource

In [2]:
from unet_model import UNet
from dataset import PandoraImage
from eval import eval_net

In [3]:
#%load_ext tensorboard
!rm -rf ./runs/
#%tensorboard --logdir /runs

rm: cannot remove './runs/May24_15-45-32_x360LR_0.0001_BS_1': Directory not empty
rm: cannot remove './runs/May24_15-51-50_x360LR_0.0001_BS_1': Directory not empty
rm: cannot remove './runs/May24_15-52-24_x360LR_0.0001_BS_1': Directory not empty
rm: cannot remove './runs/May24_15-59-03_x360LR_0.0001_BS_1': Directory not empty
rm: cannot remove './runs/May24_16-14-04_x360LR_0.0001_BS_1': Directory not empty
rm: cannot remove './runs/May24_16-23-59_x360LR_0.0001_BS_1': Directory not empty


In [4]:
# https://forum.pyro.ai/t/a-clever-trick-to-debug-tensor-memory/556
def debug_memory():
    print('maxrss = {}'.format(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss))
    tensors = collections.Counter((str(o.device), o.dtype, tuple(o.shape))
                                  for o in gc.get_objects()
                                  if torch.is_tensor(o))
    for line in tensors.items():
        print('{}\t{}'.format(*line))

In [5]:
def NoBackgroundCrossEntropyLoss(masks_pred, true_masks):
    print("torch.shape(masks_pred)", masks_pred.size())
    print("torch.shape(true_masks)", true_masks.size())
    x = torch.sum(true_masks, dim=1)
    print("torch.shape(x)", x.size())
    x = x[:,None,:,:]
    x = x.repeat_interleave(3,1)
    background_mask = (x==0.0)
    print("torch.shape(background_mask)", background_mask.size())
    masks_pred[background_mask] = 0.0
    print("torch.shape(masks_pred)2", masks_pred.size())
    #return nn.CrossEntropyLoss()(masks_pred, true_masks).cuda()
    return nn.BCEWithLogitsLoss()(masks_pred, true_masks).cuda()

In [6]:
def train_net(net, device, dataPath, epochs=5, batch_size=1, lr=0.001, val_percent=0.1, save_cp=True):
    global_step = 0
    
    print("DataInput")
    dataset = PandoraImage(dataPath, imsize=384)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
    
    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}')
    
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)
    criterion = nn.CrossEntropyLoss(ignore_index=3)
    
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long) # Why long???

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)#NoBackgroundCrossEntropyLoss(masks_pred, true_masks)
                epoch_loss += loss.item()
                
                writer.add_scalar('Batch Loss (train)', loss.item(), global_step)
                
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                pbar.update(imgs.shape[0])
                optimizer.zero_grad()

                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1) # Necessary ???
                optimizer.step()

                global_step += 1
                if global_step % (len(dataset) // (10 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
                    writer.add_scalar('Loss/test', val_score, global_step)
                    writer.add_images('images', imgs, global_step)
                    try:
                        writer.add_images('masks/true', true_masks[0,:,:,:], global_step)
                        writer.add_images('masks/pred', masks_pred[0,:,:,:], global_step)
                    except:
                        print("Problem adding masks")
                        
            
        if save_cp:
            torch.save(net.state_dict(), 'checkpoints/' + f'CP_epoch{epoch + 1}.pth')
    writer.close()

# Training

In [7]:
#debug_memory()

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=1, n_classes=3, bilinear=False)
net.to(device=device)

path = "/home/philip/RemoteFNAL/PanLee_v8_00_00_13/OutTest/viewV_May13.bin"

train_net(net=net,
          dataPath = path,
          epochs=5,
          batch_size=1,
          lr=0.0001,
          device=device,
          val_percent=0.2)

DataInput


Epoch 1/5:  12%|█▏        | 351/2815 [04:29<30:24,  1.35img/s, loss (batch)=0.379]
Validation round:   0%|          | 0/703 [00:00<?, ?batch/s][A
Validation round:   0%|          | 1/703 [00:00<04:07,  2.84batch/s][A
Validation round:   0%|          | 2/703 [00:00<03:30,  3.33batch/s][A
Validation round:   0%|          | 3/703 [00:00<03:04,  3.80batch/s][A
Validation round:   1%|          | 4/703 [00:00<02:45,  4.22batch/s][A
Validation round:   1%|          | 5/703 [00:01<02:34,  4.52batch/s][A
Validation round:   1%|          | 6/703 [00:01<02:50,  4.09batch/s][A
Validation round:   1%|          | 7/703 [00:01<03:00,  3.86batch/s][A
Validation round:   1%|          | 8/703 [00:01<03:02,  3.80batch/s][A
Validation round:   1%|▏         | 9/703 [00:02<02:44,  4.23batch/s][A
Validation round:   1%|▏         | 10/703 [00:02<02:33,  4.51batch/s][A
Validation round:   2%|▏         | 11/703 [00:02<02:29,  4.63batch/s][A
Validation round:   2%|▏         | 12/703 [00:02<02:21,  4.

Validation round:  16%|█▌        | 111/703 [00:22<01:51,  5.31batch/s][A
Validation round:  16%|█▌        | 112/703 [00:22<01:51,  5.28batch/s][A
Validation round:  16%|█▌        | 113/703 [00:22<01:52,  5.23batch/s][A
Validation round:  16%|█▌        | 114/703 [00:22<01:52,  5.24batch/s][A
Validation round:  16%|█▋        | 115/703 [00:22<01:52,  5.24batch/s][A
Validation round:  17%|█▋        | 116/703 [00:23<01:52,  5.21batch/s][A
Validation round:  17%|█▋        | 117/703 [00:23<01:51,  5.24batch/s][A
Validation round:  17%|█▋        | 118/703 [00:23<01:50,  5.30batch/s][A
Validation round:  17%|█▋        | 119/703 [00:23<01:48,  5.37batch/s][A
Validation round:  17%|█▋        | 120/703 [00:23<01:47,  5.41batch/s][A
Validation round:  17%|█▋        | 121/703 [00:23<01:47,  5.39batch/s][A
Validation round:  17%|█▋        | 122/703 [00:24<01:47,  5.41batch/s][A
Validation round:  17%|█▋        | 123/703 [00:24<01:46,  5.44batch/s][A
Validation round:  18%|█▊        | 124

Validation round:  31%|███▏      | 221/703 [00:43<01:30,  5.34batch/s][A
Validation round:  32%|███▏      | 222/703 [00:43<01:30,  5.33batch/s][A
Validation round:  32%|███▏      | 223/703 [00:43<01:29,  5.36batch/s][A
Validation round:  32%|███▏      | 224/703 [00:43<01:29,  5.36batch/s][A
Validation round:  32%|███▏      | 225/703 [00:43<01:28,  5.41batch/s][A
Validation round:  32%|███▏      | 226/703 [00:43<01:28,  5.41batch/s][A
Validation round:  32%|███▏      | 227/703 [00:44<01:27,  5.45batch/s][A
Validation round:  32%|███▏      | 228/703 [00:44<01:26,  5.46batch/s][A
Validation round:  33%|███▎      | 229/703 [00:44<01:26,  5.46batch/s][A
Validation round:  33%|███▎      | 230/703 [00:44<01:26,  5.47batch/s][A
Validation round:  33%|███▎      | 231/703 [00:44<01:27,  5.42batch/s][A
Validation round:  33%|███▎      | 232/703 [00:45<01:27,  5.39batch/s][A
Validation round:  33%|███▎      | 233/703 [00:45<01:26,  5.42batch/s][A
Validation round:  33%|███▎      | 234

Validation round:  47%|████▋     | 331/703 [01:03<01:12,  5.14batch/s][A
Validation round:  47%|████▋     | 332/703 [01:04<01:11,  5.22batch/s][A
Validation round:  47%|████▋     | 333/703 [01:04<01:10,  5.26batch/s][A
Validation round:  48%|████▊     | 334/703 [01:04<01:09,  5.32batch/s][A
Validation round:  48%|████▊     | 335/703 [01:04<01:08,  5.38batch/s][A
Validation round:  48%|████▊     | 336/703 [01:04<01:07,  5.42batch/s][A
Validation round:  48%|████▊     | 337/703 [01:05<01:07,  5.44batch/s][A
Validation round:  48%|████▊     | 338/703 [01:05<01:06,  5.45batch/s][A
Validation round:  48%|████▊     | 339/703 [01:05<01:06,  5.46batch/s][A
Validation round:  48%|████▊     | 340/703 [01:05<01:06,  5.46batch/s][A
Validation round:  49%|████▊     | 341/703 [01:05<01:07,  5.34batch/s][A
Validation round:  49%|████▊     | 342/703 [01:05<01:08,  5.28batch/s][A
Validation round:  49%|████▉     | 343/703 [01:06<01:07,  5.30batch/s][A
Validation round:  49%|████▉     | 344

Validation round:  63%|██████▎   | 441/703 [01:24<00:49,  5.29batch/s][A
Validation round:  63%|██████▎   | 442/703 [01:24<00:49,  5.30batch/s][A
Validation round:  63%|██████▎   | 443/703 [01:24<00:49,  5.26batch/s][A
Validation round:  63%|██████▎   | 444/703 [01:25<00:48,  5.31batch/s][A
Validation round:  63%|██████▎   | 445/703 [01:25<00:48,  5.36batch/s][A
Validation round:  63%|██████▎   | 446/703 [01:25<00:47,  5.40batch/s][A
Validation round:  64%|██████▎   | 447/703 [01:25<00:47,  5.43batch/s][A
Validation round:  64%|██████▎   | 448/703 [01:25<00:46,  5.45batch/s][A
Validation round:  64%|██████▍   | 449/703 [01:25<00:46,  5.44batch/s][A
Validation round:  64%|██████▍   | 450/703 [01:26<00:46,  5.46batch/s][A
Validation round:  64%|██████▍   | 451/703 [01:26<00:46,  5.42batch/s][A
Validation round:  64%|██████▍   | 452/703 [01:26<00:46,  5.42batch/s][A
Validation round:  64%|██████▍   | 453/703 [01:26<00:45,  5.45batch/s][A
Validation round:  65%|██████▍   | 454

Validation round:  78%|███████▊  | 551/703 [01:45<00:28,  5.38batch/s][A
Validation round:  79%|███████▊  | 552/703 [01:45<00:28,  5.36batch/s][A
Validation round:  79%|███████▊  | 553/703 [01:45<00:27,  5.37batch/s][A
Validation round:  79%|███████▉  | 554/703 [01:46<00:27,  5.38batch/s][A
Validation round:  79%|███████▉  | 555/703 [01:46<00:28,  5.18batch/s][A
Validation round:  79%|███████▉  | 556/703 [01:46<00:27,  5.30batch/s][A
Validation round:  79%|███████▉  | 557/703 [01:46<00:27,  5.40batch/s][A
Validation round:  79%|███████▉  | 558/703 [01:46<00:26,  5.43batch/s][A
Validation round:  80%|███████▉  | 559/703 [01:47<00:26,  5.43batch/s][A
Validation round:  80%|███████▉  | 560/703 [01:47<00:26,  5.44batch/s][A
Validation round:  80%|███████▉  | 561/703 [01:47<00:26,  5.41batch/s][A
Validation round:  80%|███████▉  | 562/703 [01:47<00:26,  5.38batch/s][A
Validation round:  80%|████████  | 563/703 [01:47<00:25,  5.40batch/s][A
Validation round:  80%|████████  | 564

Validation round:  94%|█████████▍| 661/703 [02:07<00:07,  5.52batch/s][A
Validation round:  94%|█████████▍| 662/703 [02:07<00:07,  5.53batch/s][A
Validation round:  94%|█████████▍| 663/703 [02:07<00:07,  5.55batch/s][A
Validation round:  94%|█████████▍| 664/703 [02:08<00:07,  5.45batch/s][A
Validation round:  95%|█████████▍| 665/703 [02:08<00:06,  5.52batch/s][A
Validation round:  95%|█████████▍| 666/703 [02:08<00:06,  5.54batch/s][A
Validation round:  95%|█████████▍| 667/703 [02:08<00:06,  5.51batch/s][A
Validation round:  95%|█████████▌| 668/703 [02:08<00:06,  5.56batch/s][A
Validation round:  95%|█████████▌| 669/703 [02:08<00:06,  5.59batch/s][A
Validation round:  95%|█████████▌| 670/703 [02:09<00:05,  5.52batch/s][A
Validation round:  95%|█████████▌| 671/703 [02:09<00:05,  5.49batch/s][A
Validation round:  96%|█████████▌| 672/703 [02:09<00:05,  5.42batch/s][A
Validation round:  96%|█████████▌| 673/703 [02:09<00:05,  5.51batch/s][A
Validation round:  96%|█████████▌| 674

AssertionError: size of input tensor and input format are different.         tensor shape: (1, 1, 3, 384, 384), input_format: NCHW

In [None]:
#del net, device
#gc.collect()