In [None]:
from __future__ import print_function, division
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
import time
import json
from tensorboardX import SummaryWriter
from datasets import __datasets__
from models import __models__, model_loss_train, model_loss_test
from utils import *
from torch.utils.data import DataLoader
import gc
from dataclasses import dataclass
import matplotlib.pyplot as plt
from pathlib import Path

# plt.style.use('ggplot')

In [2]:
# Machine-specific directories are kept in a JSON file two levels above this notebook.
config = json.loads(Path('../../config.json').read_text())

# Keys for the Windows machine, kept so the notebook can be switched over quickly:
# BASE_DIR = config["BASE_DIR_WIN"]
# CKPT_DIR = config["BASE_DIR_CGI_CKPT_WIN"]
BASE_DIR, CKPT_DIR = config["BASE_DIR"], config["BASE_DIR_CGI_CKPT_MAC"]

In [3]:
# Display the installed PyTorch version so it is recorded with the notebook outputs.
torch.__version__


'2.4.0'

In [4]:
# Restrict the process to GPU 0 *before* the first CUDA query. CUDA_VISIBLE_DEVICES
# is read when the CUDA runtime initialises, and `torch.cuda.is_available()` below
# can trigger that initialisation, so setting the variable any later has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Use the first visible GPU when there is one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("mps")
device

device(type='cpu')

In [5]:
# Let cuDNN benchmark the available convolution algorithms and reuse the fastest one.
cudnn.benchmark = True
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
# first initialised. An earlier cell already calls `torch.cuda.is_available()`, so
# this assignment probably does not change which GPUs are visible - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [6]:
# os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'

In [7]:
@dataclass
class Args:
    """Notebook stand-in for command-line arguments: one field per training option.

    The defaults are evaluated once, when the class body runs, so ``BASE_DIR`` and
    ``CKPT_DIR`` must already be defined at that point.
    """
    # Key into `__models__`; selects the network class to build.
    model: str           = 'CGI_Stereo'
    # Given to the model constructor and used as the exclusive upper bound of the
    # valid-disparity masks.
    maxdisp: int         = 192
    # Key into `__datasets__`; selects the dataset class.
    dataset: str         = 'scape_pipes'
    # Dataset root and the train/test list files handed to the dataset class.
    datapath: Path       = Path(BASE_DIR)
    trainlist: Path      = Path("dataset_paths/half_aug_combined_rectified_scape_dataset_train.txt")
    testlist: Path       = Path("dataset_paths/half_aug_combined_rectified_scape_dataset_test.txt")
    # Base learning rate for Adam; also passed to `adjust_learning_rate` every epoch.
    lr: float            = 0.001
    batch_size: int      = 1
    test_batch_size: int = 1
    # Exclusive upper bound of the epoch loop.
    epochs: int          = 100
    # Learning-rate schedule string, interpreted by `adjust_learning_rate`.
    lrepochs: str        = "10,14,16,18:2"
    # Directory that receives both the summary events and the checkpoints.
    logdir: str          = os.path.join(CKPT_DIR, "ScapeCombinedAugmented_2")
    # loadckpt: str        = os.path.join(CKPT_DIR, "Sceneflow", "checkpoint_000098.ckpt")
    # Path of a checkpoint to initialise the weights from, or False to start from
    # scratch. The annotation is quoted so it is never evaluated and stays valid on
    # interpreters without the `X | Y` union syntax.
    loadckpt: "str | bool" = False
    # When True, continue from the newest checkpoint found in `logdir`.
    resume: bool         = False
    seed: int            = 42
    # Only referenced from commented-out code at the moment.
    summary_freq: int    = 1
    # A checkpoint is written every `save_freq` epochs.
    save_freq: int       = 1
args = Args()

In [8]:
# Seed PyTorch's CPU and CUDA random number generators with the configured seed.
# NOTE(review): NumPy and the stdlib `random` module are not seeded here; if the
# dataset's augmentation draws from them, runs will not be reproducible - confirm.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# Create the log/checkpoint directory; an already existing directory is fine.
os.makedirs(args.logdir, exist_ok=True)

In [9]:
# create summary logger
# The tensorboardX writer is pointed at args.logdir, the same directory the
# training loop saves its checkpoints into.
print("creating new summary file")
logger = SummaryWriter(args.logdir)

creating new summary file


In [10]:
# dataset, dataloader
# Look up the dataset class registered under args.dataset and build one instance
# per split from the corresponding list file.
# NOTE(review): the third positional argument is presumably a "training mode" flag
# (True for the train split, False for the test split) - confirm against the
# dataset class.
StereoDataset = __datasets__[args.dataset]
train_dataset = StereoDataset(args.datapath, args.trainlist, True)
test_dataset = StereoDataset(args.datapath, args.testlist, False)

In [11]:
# Training batches are reshuffled every epoch and an incomplete last batch is dropped.
TrainImgLoader = DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)
# The test loader keeps the dataset order and yields every sample.
TestImgLoader = DataLoader(
    dataset=test_dataset,
    batch_size=args.test_batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=False,
)

In [12]:
# model, optimizer
# Build the network registered under args.model, wrap it in DataParallel and move
# it to the selected device (`Module.to` returns the module itself).
model = nn.DataParallel(__models__[args.model](args.maxdisp))
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=args.lr, betas=(0.9, 0.999))

# Store the run configuration as text entries in the summary log.
for _tag, _described in (('args', args), ('model', model), ('optimiser', optimizer)):
    logger.add_text(_tag, str(_described))

In [13]:
# load parameters
# Two mutually exclusive ways to initialise the run:
#   * args.resume   - continue from the newest checkpoint in args.logdir, restoring
#                     model weights, optimizer state and the epoch counter;
#   * args.loadckpt - initialise the weights only, from an explicit checkpoint file.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
start_epoch = 0
if args.resume:

    # find all checkpoint files and sort them by the epoch id embedded in the name
    all_saved_ckpts = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")]
    if not all_saved_ckpts:
        # Fail with an explicit message instead of an IndexError further down.
        raise FileNotFoundError(
            "args.resume is set but no .ckpt file was found in {}".format(args.logdir))
    all_saved_ckpts = sorted(all_saved_ckpts, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    # use the latest checkpoint file
    loadckpt = os.path.join(os.path.abspath(args.logdir), all_saved_ckpts[-1])
    print("loading the lastest model in logdir: {}".format(loadckpt))
    # map_location lets a checkpoint written on a GPU machine load on this device.
    state_dict = torch.load(loadckpt, map_location=device)
    model.load_state_dict(state_dict['model'])
    optimizer.load_state_dict(state_dict['optimizer'])
    # The checkpoint stores the epoch that finished, so continue with the next one.
    start_epoch = state_dict['epoch'] + 1
elif args.loadckpt:
    # load the checkpoint file specified by args.loadckpt
    print("loading model {}".format(args.loadckpt))
    state_dict = torch.load(args.loadckpt, map_location=device)
    model_dict = model.state_dict()
    # Keep only the entries whose key exists in the current model; every other
    # parameter keeps its freshly initialised value.
    pre_dict = {k: v for k, v in state_dict['model'].items() if k in model_dict}
    model_dict.update(pre_dict)
    # model.load_state_dict(state_dict['model'])
    model.load_state_dict(model_dict)
print("start at epoch {}".format(start_epoch))


start at epoch 0


In [14]:
# train one sample
def train_sample(sample, compute_metrics=False):
    """Run one optimisation step on a single batch.

    Args:
        sample: mapping with the entries ``'left'``, ``'right'``, ``'disparity'``
            and ``'disparity_low'``.
        compute_metrics: when true, EPE and D1 of the first model output are
            added to the returned scalars.

    Returns:
        ``(tensor2float(loss), tensor2float(scalar_outputs))`` where
        ``scalar_outputs`` holds ``"loss"`` and, if requested, ``"EPE"`` and
        ``"D1"`` as one-element lists.
    """
    model.train()

    # Disparities are made non-negative first, then everything goes to the device.
    left = sample['left'].to(device)
    right = sample['right'].to(device)
    gt_full = torch.abs(sample['disparity']).to(device)
    gt_low = torch.abs(sample['disparity_low']).to(device)

    optimizer.zero_grad()
    predictions = model(left, right)

    # Only pixels with a ground truth strictly between 0 and maxdisp are supervised.
    valid_full = (gt_full > 0) & (gt_full < args.maxdisp)
    valid_low = (gt_low > 0) & (gt_low < args.maxdisp)
    loss = model_loss_train(predictions, [gt_full, gt_low], [valid_full, valid_low])

    # Metrics are reported for the first model output only.
    final_estimates = [predictions[0]]

    scalar_outputs = {"loss": loss}
    if compute_metrics:
        with torch.no_grad():
            scalar_outputs["EPE"] = [EPE_metric(est, gt_full, valid_full) for est in final_estimates]
            scalar_outputs["D1"] = [D1_metric(est, gt_full, valid_full) for est in final_estimates]

    loss.backward()
    optimizer.step()

    return tensor2float(loss), tensor2float(scalar_outputs)

In [15]:
# test one sample
@make_nograd_func
def test_sample(sample, compute_metrics=True):
    """Evaluate the model on a single batch.

    Args:
        sample: mapping with the entries ``'left'``, ``'right'`` and
            ``'disparity'``.
        compute_metrics: only controls whether ``image_outputs`` receives an
            ``"errormap"`` entry; the scalar metrics are always computed.

    Returns:
        ``(tensor2float(loss), tensor2float(scalar_outputs), image_outputs)``.
        ``scalar_outputs`` holds ``"loss"``, ``"D1"``, ``"EPE"`` and
        ``"Thres1"``..``"Thres3"``; ``image_outputs`` holds the raw estimates,
        ground truth and both input images.
    """
    model.eval()

    left = sample['left'].to(device)
    right = sample['right'].to(device)
    gt = torch.abs(sample['disparity']).to(device)

    disp_ests = model(left, right)

    # Only pixels with a ground truth strictly between 0 and maxdisp are scored.
    valid = (gt > 0) & (gt < args.maxdisp)
    loss = model_loss_test(disp_ests, [gt], [valid])

    scalar_outputs = {"loss": loss}
    image_outputs = {"disp_est": disp_ests, "disp_gt": gt, "imgL": left, "imgR": right}

    scalar_outputs["D1"] = [D1_metric(est, gt, valid) for est in disp_ests]
    scalar_outputs["EPE"] = [EPE_metric(est, gt, valid) for est in disp_ests]
    for key, threshold in (("Thres1", 1.0), ("Thres2", 2.0), ("Thres3", 3.0)):
        scalar_outputs[key] = [Thres_metric(est, gt, valid, threshold) for est in disp_ests]

    if compute_metrics:
        image_outputs["errormap"] = [disp_error_image_func.apply(est, gt) for est in disp_ests]

    return tensor2float(loss), tensor2float(scalar_outputs), image_outputs

In [16]:
def plot_error_histogram(errormap):
    """Show a 24-bin histogram of all values in ``errormap`` over the range [0, 1].

    Args:
        errormap: array-like object that provides ``flatten()``. Values outside
            [0, 1] fall outside the fixed histogram range and are not counted.
    """
    values = errormap.flatten()
    _, axis = plt.subplots(figsize=(12, 6))  # Adjust the size as needed
    axis.hist(values, bins=24, range=(0, 1), align='mid')
    axis.set_xlabel('Normalized Error Value')
    axis.set_ylabel('Frequency')
    axis.set_title('Histogram of Normalized Error Values')
    plt.show()

def generate_error_label_counts(errormap):
    """Count how often each distinct value occurs in ``errormap``.

    Args:
        errormap: NumPy array, torch tensor (any device) or nested sequence of
            numbers. The input is treated as flat; its shape does not matter.

    Returns:
        dict mapping each distinct value to its number of occurrences, in
        ascending order of value. An empty input gives an empty dict.
    """
    if torch.is_tensor(errormap):
        # np.unique cannot read CUDA or autograd-tracked tensors directly.
        errormap = errormap.detach().cpu().numpy()
    # np.unique flattens its input, so no explicit flatten() is needed.
    labels, counts = np.unique(np.asarray(errormap), return_counts=True)
    return dict(zip(labels, counts))

def local_plot_error_histogram(image_outputs):
    """Show the prediction, ground truth, left image and error map of one sample.

    Despite its name this function draws images, not a histogram: first a row of
    three panels (first disparity estimate, ground-truth disparity, first channel
    of the left image), then the error map in a separate, larger figure.

    Args:
        image_outputs: mapping with the entries ``"disp_est"``, ``"disp_gt"``,
            ``"imgL"`` and ``"errormap"``; it is passed through ``tensor2numpy``
            before plotting.
    """
    image_outputs = tensor2numpy(image_outputs)

    fig, ax = plt.subplots(1, 3, figsize=(15, 5), dpi =300)
    panels = (
        ("disp_est", image_outputs["disp_est"][0][0], {}),
        ("disp_gt", image_outputs["disp_gt"][0], {}),
        ("imgL", image_outputs["imgL"][0][0], {'cmap': 'gray'}),
    )
    for axis, (title, image, imshow_kwargs) in zip(ax, panels):
        axis.imshow(image, **imshow_kwargs)
        axis.set_title(title)
    plt.tight_layout()
    plt.show()
    # plt.savefig(f"{args.logdir}/test_{global_step:0>6}.png")

    # Move the leading axis to the end, i.e. to (H, W, C), which is what imshow expects.
    errormap = np.transpose(image_outputs["errormap"][0][0], (1, 2, 0))
    plt.figure(figsize=(24, 12))
    plt.imshow(errormap, cmap='RdYlBu_r')
    plt.show()

In [17]:
def train():
    """Run the training loop: one training pass and one evaluation pass per epoch.

    For every epoch from ``start_epoch`` up to ``args.epochs`` this
      1. adjusts the learning rate via ``adjust_learning_rate``,
      2. trains on every batch of ``TrainImgLoader``,
      3. saves a checkpoint every ``args.save_freq`` epochs,
      4. evaluates on every batch of ``TestImgLoader``, and
      5. keeps track of the epoch with the lowest average test ``D1``.

    Scalars are written to ``logger`` and progress is printed per batch.
    Returns nothing.
    """
    # Best evaluation result so far. These are initialised once, outside the epoch
    # loop, so that each epoch is compared against all previous ones.
    # NOTE(review): the best result is not stored in the checkpoint, so a resumed
    # run starts this tracking from scratch.
    bestepoch = 0
    error = 100
    for epoch_idx in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch_idx, args.lr, args.lrepochs)

        # training
        avg_train_scalars = AverageMeterDict()
        for batch_idx, sample in enumerate(TrainImgLoader):
            global_step = len(TrainImgLoader) * epoch_idx + batch_idx
            start_time = time.time()
            do_summary = True
            # do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs = train_sample(sample, compute_metrics=do_summary)
            if do_summary:
                save_scalars(logger, 'train', scalar_outputs, global_step)
                # save_images(logger, 'train', image_outputs, global_step)
            avg_train_scalars.update(scalar_outputs)
            del scalar_outputs
            print(f'Epoch {epoch_idx+1}/{args.epochs}, Iter {batch_idx+1}/{len(TrainImgLoader)}, train loss = {loss:.3f}, time = {time.time() - start_time:.3f}')
        avg_train_scalars = avg_train_scalars.mean()
        # NOTE(review): the epoch averages are logged under the same 'train' mode as
        # the per-iteration scalars above but use the epoch number as step;
        # presumably both end up in the same series - confirm in save_scalars.
        save_scalars(logger, 'train', avg_train_scalars, epoch_idx + 1)

        # saving checkpoints
        if (epoch_idx + 1) % args.save_freq == 0:
            checkpoint_data = {'epoch': epoch_idx, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
            torch.save(checkpoint_data, f"{args.logdir}/checkpoint_{epoch_idx:0>6}.ckpt")
        gc.collect()

        # testing
        avg_test_scalars = AverageMeterDict()
        for batch_idx, sample in enumerate(TestImgLoader):
            global_step = (len(TestImgLoader) * epoch_idx + batch_idx) + 1
            start_time = time.time()
            # Decide this before the forward pass so the evaluation never depends on
            # a value left over from the training loop.
            do_summary = True
            # do_summary = global_step % args.summary_freq == 0
            loss, scalar_outputs, image_outputs = test_sample(sample, compute_metrics=do_summary)

            # save image outputs
            if do_summary:
                save_scalars(logger, 'test', scalar_outputs, global_step)
                # save_images(logger, 'test', image_outputs, epoch_idx +1 )

            # Flip to True to inspect predictions and error maps inline while debugging.
            local_visualise = False
            if local_visualise:
                errormap = image_outputs["errormap"][0][0]
                local_plot_error_histogram(image_outputs)
                plot_error_histogram(errormap)
                error_label_counts = generate_error_label_counts(errormap)
                print(error_label_counts)

            avg_test_scalars.update(scalar_outputs)
            del scalar_outputs
            print(f'Epoch {epoch_idx+1}/{args.epochs}, Iter {batch_idx+1}/{len(TestImgLoader)},  test loss = {loss:.3f}, time = {time.time() - start_time:.3f}')

        avg_test_scalars = avg_test_scalars.mean()
        # The first D1 entry is the selection criterion: lower is better.
        nowerror = avg_test_scalars["D1"][0]
        if nowerror < error:
            bestepoch = epoch_idx
            error = nowerror
        save_scalars(logger, 'validation', avg_test_scalars, epoch_idx + 1)
        print("avg_test_scalars", avg_test_scalars)
        print('MAX epoch %d total test error = %.5f' % (bestepoch, error))
        gc.collect()
    print('MAX epoch %d total test error = %.5f' % (bestepoch, error))

In [None]:
# Inside a notebook kernel __name__ is '__main__', so executing this cell starts
# training; the guard keeps an import of the exported code from doing the same.
if __name__ == '__main__':
    train()