In [1]:
import sys
import os
import logging
from IPython.display import clear_output
sys.path.append('../src')

# local imports.
from utils import ROOT_DIR
from data_loader import FarsightDataset, ToTensor
import visualize as viz
import model
from utils import DATA_DIR, get_depth_dir, get_img_dir, get_dev
from other_models.tiny_unet import UNet



import torch
import torch.nn as nn
# import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Subset, random_split
from torch.utils.tensorboard import SummaryWriter
# task = Task.init(project_name='mde', task_name='test loop')
# logger = task.get_logger()

import matplotlib.pyplot as plt
from trains import Task
from tqdm import tqdm


%load_ext autoreload
%autoreload 2

In [2]:
def weight_init(m):
    """
    Kaiming-normal (ReLU gain) initialisation for convolution weights.

    Intended to be handed to ``net.apply``. Only ``nn.Conv2d`` and
    ``nn.ConvTranspose2d`` modules are touched, and only their ``weight``
    tensor is re-initialised in place; every other module is left as is.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if not isinstance(m, conv_types):
        return
    nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')

In [5]:

def _log_training_snapshot(writer, net, batch, pred_depth, epoch_loss, epoch):
    """
    Write one epoch's diagnostics to TensorBoard.

    Logs the mean training loss, a histogram of every parameter (and of its
    gradient when one exists), a histogram of the predicted values, the figure
    produced by ``viz.show_batch`` and the raw image / depth / prediction
    batches. Everything is logged under the same global step ``epoch``.
    """
    pred = pred_depth.detach()
    writer.add_scalar('Loss/train', epoch_loss, epoch)

    for tag, value in net.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram('weights/' + tag, value.detach().cpu().numpy(), epoch)
        # Parameters that took no part in the backward pass have no gradient.
        if value.grad is not None:
            writer.add_histogram('grads/' + tag, value.grad.detach().cpu().numpy(), epoch)
    writer.add_histogram('values', pred.cpu().numpy(), epoch)

    fig = viz.show_batch({**batch, 'pred': pred})
    fig.suptitle(f'epoch {epoch}', fontsize='xx-large')
    print('adding plot')
    writer.add_figure(tag='epoch/end', figure=fig, global_step=epoch)
    writer.add_images('images', batch['image'], epoch)
    writer.add_images('masks/true', batch['depth'].unsqueeze(1), epoch)
    writer.add_images('masks/pred', pred.unsqueeze(1), epoch)


def train(epochs=2,
          verbose=False,
          batch_size=2,
          val_percent=0.25,
          n_samples=2,
          log_every=10):
    """
    Main training loop: fit a UNet with MSE loss on a small subset of
    FarsightDataset and log progress to TensorBoard.

    Args:
        epochs: number of passes over the training split.
        verbose: currently unused; kept so existing callers keep working.
        batch_size: batch size of the training loader.
        val_percent: fraction of the subset held out from training.
        n_samples: how many leading samples of the dataset form the subset
            that is split into train / validation parts.
        log_every: write TensorBoard diagnostics at the end of every
            ``log_every``-th epoch (starting with the first). A falsy value
            disables logging.

    Returns:
        ``(gt_depths, pred_depth)`` of the last processed batch. The
        prediction is returned exactly as produced by the network (it is not
        detached). Both are ``None`` when ``epochs`` is 0.

    Raises:
        ValueError: if the train split would be empty.
    """
    print('started')
    dev = get_dev()
    # create dataset
    ds = FarsightDataset(img_dir=get_img_dir(),
                         depth_dir=get_depth_dir(),
                         transform=ToTensor())
    minids = Subset(ds, range(n_samples))
    n_val = int(len(minids) * val_percent)
    n_train = len(minids) - n_val
    if n_train == 0:
        raise ValueError('train split is empty: n_samples={}, val_percent={}'.format(
            n_samples, val_percent))
    # Fixed seed keeps the split reproducible between runs.
    train_ds, _val_ds = random_split(minids,
                                     [n_train, n_val],
                                     generator=torch.Generator().manual_seed(42))
    # Train on the train split only, so held-out samples never reach the
    # optimizer and the progress-bar total (n_train) matches what is iterated.
    train_loader = DataLoader(train_ds,
                              shuffle=False,
                              batch_size=batch_size,
                              num_workers=0)
    # TODO: build a loader for _val_ds and report a validation score per logged epoch.
    # TODO: fix weird float32 requirement in conv2d to work with uint8. Quantization?
    net = UNet()
    net.to(device=dev)
    print('using ', dev)
    net.apply(weight_init)
    num_batches = len(train_loader)
    print('num_batches: ', num_batches)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters())

    # Defined up front so the return below is valid even when epochs == 0.
    gt_depths = pred_depth = None
    writer = SummaryWriter()
    try:
        # main training loop.
        for epoch in range(epochs):  # loop over the dataset multiple times
            net.train()
            running_loss = 0.0
            with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
                for data in train_loader:
                    # data is a dict holding the input images and the depth maps.
                    # .to() is a no-op when the tensors already live on `dev`.
                    imgs = data['image'].to(dev)
                    gt_depths = data['depth'].to(dev)
                    optimizer.zero_grad()
                    pred_depth = net(imgs)
                    loss = criterion(pred_depth, gt_depths)
                    loss.backward()
                    optimizer.step()

                    batch_loss = loss.item()
                    running_loss += batch_loss
                    pbar.set_postfix(**{'loss (batch)': batch_loss})
                    pbar.update(imgs.shape[0])

                # Log once per epoch, after the last batch, so the reported
                # loss really is the mean over all batches of the epoch.
                if log_every and epoch % log_every == 0:
                    epoch_loss = running_loss / num_batches
                    print(epoch_loss)
                    _log_training_snapshot(writer, net, data, pred_depth, epoch_loss, epoch)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()
    print('Finished Training')
    return gt_depths, pred_depth

In [7]:
# Run 50 epochs and keep the last batch's ground-truth depths and predictions.
gt, pred = train(epochs=50)

Epoch 1/50: 100%|██████████| 2/2 [00:00<00:00, 10.76img/s, loss (batch)=0.243]

started
using  cuda
num_batches:  1
0.24255569279193878
batch size:  2
adding plot


Epoch 1/50: 100%|██████████| 2/2 [00:01<00:00,  1.25img/s, loss (batch)=0.243]
Epoch 2/50: 100%|██████████| 2/2 [00:00<00:00,  5.88img/s, loss (batch)=0.206]
Epoch 3/50: 100%|██████████| 2/2 [00:00<00:00, 10.69img/s, loss (batch)=0.186]
Epoch 4/50: 100%|██████████| 2/2 [00:00<00:00, 10.60img/s, loss (batch)=0.17]
Epoch 5/50: 100%|██████████| 2/2 [00:00<00:00, 10.87img/s, loss (batch)=0.159]
Epoch 6/50: 100%|██████████| 2/2 [00:00<00:00, 10.69img/s, loss (batch)=0.151]
Epoch 7/50: 100%|██████████| 2/2 [00:00<00:00, 10.64img/s, loss (batch)=0.145]
Epoch 8/50: 100%|██████████| 2/2 [00:00<00:00, 10.66img/s, loss (batch)=0.14]
Epoch 9/50: 100%|██████████| 2/2 [00:00<00:00, 10.76img/s, loss (batch)=0.134]
Epoch 10/50: 100%|██████████| 2/2 [00:00<00:00, 10.33img/s, loss (batch)=0.128]
Epoch 11/50: 100%|██████████| 2/2 [00:00<00:00, 10.83img/s, loss (batch)=0.119]

0.11948443949222565
batch size:  2
adding plot


Epoch 11/50: 100%|██████████| 2/2 [00:01<00:00,  1.33img/s, loss (batch)=0.119]
Epoch 12/50: 100%|██████████| 2/2 [00:00<00:00,  7.55img/s, loss (batch)=0.11]
Epoch 13/50: 100%|██████████| 2/2 [00:00<00:00, 10.67img/s, loss (batch)=0.101]
Epoch 14/50: 100%|██████████| 2/2 [00:00<00:00, 10.86img/s, loss (batch)=0.0946]
Epoch 15/50: 100%|██████████| 2/2 [00:00<00:00, 10.23img/s, loss (batch)=0.0897]
Epoch 16/50: 100%|██████████| 2/2 [00:00<00:00, 10.04img/s, loss (batch)=0.0848]
Epoch 17/50: 100%|██████████| 2/2 [00:00<00:00, 10.09img/s, loss (batch)=0.0799]
Epoch 18/50: 100%|██████████| 2/2 [00:00<00:00, 10.13img/s, loss (batch)=0.0753]
Epoch 19/50: 100%|██████████| 2/2 [00:00<00:00,  9.86img/s, loss (batch)=0.0714]
Epoch 20/50: 100%|██████████| 2/2 [00:00<00:00, 10.11img/s, loss (batch)=0.0678]
Epoch 21/50: 100%|██████████| 2/2 [00:00<00:00, 10.26img/s, loss (batch)=0.0645]

0.06452243030071259
batch size:  2
adding plot


Epoch 21/50: 100%|██████████| 2/2 [00:01<00:00,  1.12img/s, loss (batch)=0.0645]
Epoch 22/50: 100%|██████████| 2/2 [00:00<00:00,  8.07img/s, loss (batch)=0.0617]
Epoch 23/50: 100%|██████████| 2/2 [00:00<00:00, 10.12img/s, loss (batch)=0.0589]
Epoch 24/50: 100%|██████████| 2/2 [00:00<00:00, 10.42img/s, loss (batch)=0.0565]
Epoch 25/50: 100%|██████████| 2/2 [00:00<00:00, 10.46img/s, loss (batch)=0.0543]
Epoch 26/50: 100%|██████████| 2/2 [00:00<00:00, 10.46img/s, loss (batch)=0.0522]
Epoch 27/50: 100%|██████████| 2/2 [00:00<00:00, 10.60img/s, loss (batch)=0.0502]
Epoch 28/50: 100%|██████████| 2/2 [00:00<00:00, 10.02img/s, loss (batch)=0.0484]
Epoch 29/50: 100%|██████████| 2/2 [00:00<00:00, 10.18img/s, loss (batch)=0.0468]
Epoch 30/50: 100%|██████████| 2/2 [00:00<00:00, 10.47img/s, loss (batch)=0.0453]
Epoch 31/50: 100%|██████████| 2/2 [00:00<00:00, 10.41img/s, loss (batch)=0.0439]

0.043906282633543015
batch size:  2
adding plot


Epoch 31/50: 100%|██████████| 2/2 [00:01<00:00,  1.32img/s, loss (batch)=0.0439]
Epoch 32/50: 100%|██████████| 2/2 [00:00<00:00,  9.21img/s, loss (batch)=0.0427]
Epoch 33/50: 100%|██████████| 2/2 [00:00<00:00, 10.45img/s, loss (batch)=0.0415]
Epoch 34/50: 100%|██████████| 2/2 [00:00<00:00, 10.00img/s, loss (batch)=0.0403]
Epoch 35/50: 100%|██████████| 2/2 [00:00<00:00, 10.50img/s, loss (batch)=0.0393]
Epoch 36/50: 100%|██████████| 2/2 [00:00<00:00, 10.36img/s, loss (batch)=0.0383]
Epoch 37/50: 100%|██████████| 2/2 [00:00<00:00, 10.46img/s, loss (batch)=0.0374]
Epoch 38/50: 100%|██████████| 2/2 [00:00<00:00, 10.37img/s, loss (batch)=0.0366]
Epoch 39/50: 100%|██████████| 2/2 [00:00<00:00, 10.28img/s, loss (batch)=0.0359]
Epoch 40/50: 100%|██████████| 2/2 [00:00<00:00, 10.62img/s, loss (batch)=0.0352]
Epoch 41/50: 100%|██████████| 2/2 [00:00<00:00, 10.84img/s, loss (batch)=0.0346]

0.0345703586935997
batch size:  2
adding plot


Epoch 41/50: 100%|██████████| 2/2 [00:01<00:00,  1.41img/s, loss (batch)=0.0346]
Epoch 42/50: 100%|██████████| 2/2 [00:00<00:00, 10.18img/s, loss (batch)=0.034]
Epoch 43/50: 100%|██████████| 2/2 [00:00<00:00,  9.91img/s, loss (batch)=0.0336]
Epoch 44/50: 100%|██████████| 2/2 [00:00<00:00, 10.46img/s, loss (batch)=0.0331]
Epoch 45/50: 100%|██████████| 2/2 [00:00<00:00, 10.54img/s, loss (batch)=0.0326]
Epoch 46/50: 100%|██████████| 2/2 [00:00<00:00, 10.58img/s, loss (batch)=0.0322]
Epoch 47/50: 100%|██████████| 2/2 [00:00<00:00, 10.18img/s, loss (batch)=0.0319]
Epoch 48/50: 100%|██████████| 2/2 [00:00<00:00, 10.52img/s, loss (batch)=0.0316]
Epoch 49/50: 100%|██████████| 2/2 [00:00<00:00, 10.42img/s, loss (batch)=0.0312]
Epoch 50/50: 100%|██████████| 2/2 [00:00<00:00, 10.43img/s, loss (batch)=0.0309]

Finished Training





tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0080, 0.0080, 0.0080,  ..., 0.0120, 0.0120, 0.0120],
         [0.0080, 0.0080, 0.0080,  ..., 0.0120, 0.0120, 0.0120],
         [0.0080, 0.0080, 0.0080,  ..., 0.0120, 0.0120, 0.0080]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0080, 0.0080, 0.0080,  ..., 0.0120, 0.0120, 0.0120],
         [0.0080, 0.0080, 0.0080,  ..., 0.0120, 0.0120, 0.0120],
         [0.0080, 0.0080, 0.0080,  ..., 0.0120, 0.0120, 0.0120]]],
       device='cuda:0')