In [1]:
import argparse
import csv
import logging
import os
import random
import sys
import time
from pathlib import Path
from types import SimpleNamespace

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Project layout. ROOT is the parent of the current working directory, resolved
# to an absolute path; every other location below is derived from it.
ROOT = Path('..').resolve()
DATA_ROOT = ROOT / 'data' / 'instereo2k_sample'
DATASETS_PY = ROOT / 'datasets'
MODELS_DIR = ROOT / 'models'
LIGHT_WEIGHT = MODELS_DIR / 'light_weight'
WEIGHTS_PATH = MODELS_DIR / 'sceneflow.pth'
CORE_DIR = LIGHT_WEIGHT / 'core'

# Checkpoints, the CSV log and TensorBoard events are written here.
OUTPUT_DIR = ROOT / 'results' / 'train'

# Echo the resolved paths so a wrong working directory is visible immediately.
print("ROOT:", ROOT)
print("DATA_ROOT:", DATA_ROOT)
print("LIGHT_WEIGHT:", LIGHT_WEIGHT)
print("WEIGHTS_PATH:", WEIGHTS_PATH)
print("OUTPUT_DIR:", OUTPUT_DIR)
print("DATASETS_PY", DATASETS_PY)

# Move the working directory up one level (the same directory ROOT was resolved
# to above) before the project imports that follow.
# NOTE: both '..' lookups are relative, so re-running this cell in the same
# kernel climbs one more level each time - restart the kernel before re-running.
os.chdir('..')
# Put the model package directory first on the import path.
sys.path.insert(0, str(LIGHT_WEIGHT))

from models.light_weight.core_rt.rt_igev_stereo import IGEVStereo
from models.light_weight.evaluate_stereo_rt import *
import models.light_weight.core_rt.stereo_datasets as datasets
from models.light_weight.core_rt.utils.utils import InputPadder as Padder
from datasets.instereo2k import InStereo2KDataset, pil_to_tensor, load_disp_png, resize_pair_and_disp

ROOT: D:\DepthAnalitycs
DATA_ROOT: D:\DepthAnalitycs\data\instereo2k_sample
LIGHT_WEIGHT: D:\DepthAnalitycs\models\light_weight
WEIGHTS_PATH: D:\DepthAnalitycs\models\sceneflow.pth
OUTPUT_DIR: D:\DepthAnalitycs\results\train
DATASETS_PY D:\DepthAnalitycs\datasets


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when torch can see one, otherwise stay on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print("Using device:", device)

Using device: cpu


### Конфиг модели

In [3]:
# All run settings live in one namespace. The same object is later handed to
# the model constructor and to the training / validation code.
cfg = SimpleNamespace(
    # Run identity and where results are written.
    name='light-igev-train',
    logdir=str(OUTPUT_DIR),
    device=device,
    # Numeric precision.
    mixed_precision=True,
    precision_dtype='float16',
    # Optimisation schedule.
    lr=2e-4,
    wdecay=1e-5,
    batch_size=2,
    num_epochs=2,
    num_steps=20000,
    # Iteration counts passed to the model as `iters` during training / validation.
    train_iters=12,
    valid_iters=12,
    # Input size as [height, width]; the datasets resize to this.
    img_size=[320, 768],
    # Architecture settings read by the model (semantics defined by the model code).
    max_disp=192,
    n_gru_layers=1,
    corr_levels=2,
    corr_radius=4,
    n_downsample=2,
    hidden_dim=96,
)

In [4]:
# Seed every random number generator used in this notebook with the same value
# so that shuffling and initialisation are repeatable.
SEED = 666
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

### Немного упрощенный sequence loss из репозитория модели

In [5]:
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of `model`.

    Parameters with `requires_grad == False` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def sequence_loss(agg_pred, iter_preds, disp_gt, valid, loss_gamma=0.9, max_disp=192):
    """Loss over the initial prediction and the sequence of refined predictions.

    Args:
        agg_pred: initial disparity prediction; must have the same shape as
            `disp_gt` because both are indexed with the same mask.
        iter_preds: non-empty sequence of refined predictions, same shape as
            `disp_gt`; later entries receive a larger weight.
        disp_gt: ground-truth disparity with a channel dimension at dim 1.
        valid: validity map without the channel dimension; values >= 0.5 mark
            pixels that take part in the loss.
        loss_gamma: base of the exponential weighting of `iter_preds`.
        max_disp: pixels whose ground-truth magnitude is not below this value
            are excluded (previously hard-coded to 192).

    Returns:
        (loss, metrics): `loss` is a scalar tensor; `metrics` holds python
        floats 'epe' (mean end-point error of the last prediction) and
        '1px' / '3px' (fraction of valid pixels with error below 1 / 3).
        When no pixel is valid the loss is a zero that is still attached to
        `agg_pred`, so `backward()` works, and every metric is NaN.
    """
    n_predictions = len(iter_preds)
    assert n_predictions >= 1

    mag = torch.sum(disp_gt**2, dim=1).sqrt()
    # The comparison already yields a bool tensor; unsqueeze restores the channel dim.
    valid_loc = ((valid >= 0.5) & (mag < max_disp)).unsqueeze(1)

    if not valid_loc.any():
        # A mean over zero elements is NaN and would poison every weight after
        # the optimizer step, so contribute a zero loss instead.
        nan = float('nan')
        return agg_pred.sum() * 0.0, {'epe': nan, '1px': nan, '3px': nan}

    disp_loss = 1.0 * F.smooth_l1_loss(agg_pred[valid_loc], disp_gt[valid_loc], reduction='mean')

    # The decay is rescaled by the number of predictions; it does not depend on
    # the loop index, so it is computed once.
    adjusted_loss_gamma = loss_gamma**(15 / (n_predictions - 1)) if n_predictions > 1 else 1.0
    for i, pred in enumerate(iter_preds):
        i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
        i_loss = (pred - disp_gt).abs()
        disp_loss = disp_loss + i_weight * i_loss[valid_loc].mean()

    epe = torch.sum((iter_preds[-1] - disp_gt)**2, dim=1).sqrt()
    epe = epe.reshape(-1)[valid_loc.reshape(-1)]
    metrics = {
        'epe': epe.mean().item(),
        '1px': (epe < 1).float().mean().item(),
        '3px': (epe < 3).float().mean().item(),
    }
    return disp_loss, metrics

### Даталоадеры

In [6]:
def _make_instereo2k(split):
    """Build the InStereo2K dataset for `split` ('train' or 'val').

    Both splits share every other setting so they stay consistent with each other.
    """
    return InStereo2KDataset(
        root_dir=str(DATA_ROOT),
        split=split,
        val_ratio=0.1,
        load_disp=True,
        disp_side='left',
        disp_divisor=100.0,
        resize_hw=(cfg.img_size[0], cfg.img_size[1]))


train_ds = _make_instereo2k('train')
val_ds = _make_instereo2k('val')

# Pinned host memory only helps host-to-GPU copies; on a CPU-only machine it
# just triggers a warning.
_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, pin_memory=_pin_memory)
# Validation metrics are averaged per batch, so the batch composition must not
# change between epochs: keep the validation order fixed.
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, pin_memory=_pin_memory)

### Импорт модели

In [7]:
def strip_prefix(state_dict, prefix='module.'):
    """Remove a leading `prefix` from the keys of `state_dict`.

    If no key starts with `prefix`, the very same object is returned.
    Otherwise a new plain dict is built in which prefixed keys lose the prefix
    and all other keys are copied unchanged.
    """
    if not any(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict

    cut = len(prefix)
    stripped = {}
    for key, value in state_dict.items():
        new_key = key[cut:] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped

ckpt_path = WEIGHTS_PATH
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code, and it silences the
# FutureWarning torch.load emits when the argument is left out.
ckpt = torch.load(str(ckpt_path), map_location='cpu', weights_only=True)

# Accept both a bare state dict and a wrapper such as
# {'epoch': ..., 'state_dict': ..., 'optimizer': ...}, which is the layout this
# notebook itself writes after every epoch.
state = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
# Drop the 'module.' key prefix if the checkpoint carries one.
state = strip_prefix(state, 'module.')

  ckpt = torch.load(str(ckpt_path), map_location='cpu')


In [8]:
# The model constructor reads its settings from the config namespace.
args = cfg

model = IGEVStereo(args)
print("Model instantiated.")
print("Total params (M): %.2f" % (count_parameters(model) / 1e6))

# strict=False tolerates key mismatches, so report them: otherwise a checkpoint
# for a different architecture would leave the model randomly initialised
# without any visible sign.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Checkpoint mismatch: %d missing, %d unexpected keys"
          % (len(load_result.missing_keys), len(load_result.unexpected_keys)))
    print("  missing (first 5):", load_result.missing_keys[:5])
    print("  unexpected (first 5):", load_result.unexpected_keys[:5])

model = model.to(device)
model.train()

Model instantiated.
Total params (M): 4.17


IGEVStereo(
  (update_block): BasicUpdateBlock(
    (encoder): BasicMotionEncoder(
      (convc1): Conv2d(144, 64, kernel_size=(1, 1), stride=(1, 1))
      (convc2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (convd1): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
      (convd2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv): Conv2d(128, 95, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (gru): ConvGRU(
      (convz): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (convr): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (convq): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (disp_head): DispHead(
      (conv1): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
    )
    (mask_feat_4)

### Оптимизатор и логгер

In [9]:
optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wdecay, eps=1e-8)

# OneCycleLR must know the exact number of optimizer steps up front:
# one step per training batch, for every epoch.
_steps_per_epoch = len(train_loader) if train_loader is not None else 1
scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
                                          max_lr=cfg.lr,
                                          total_steps=max(1, cfg.num_epochs * _steps_per_epoch),
                                          pct_start=0.01,
                                          cycle_momentum=False)


class DummyScaler:
    """No-op stand-in with the GradScaler methods the training loop calls."""

    def scale(self, x):
        return x

    def unscale_(self, opt):
        pass

    def step(self, opt):
        opt.step()

    def update(self):
        pass


# fp16 autocast on CUDA needs real loss scaling, otherwise small gradients
# underflow to zero. Without CUDA autocast is inactive, so the no-op scaler is
# the right choice there.
_use_grad_scaling = bool(cfg.mixed_precision) and torch.device(cfg.device).type == 'cuda'
scaler = torch.cuda.amp.GradScaler() if _use_grad_scaling else DummyScaler()

In [10]:
class TrainLogger:
    """Write training statistics to a CSV file and to TensorBoard.

    Creating the logger creates `out_dir` if needed and starts a fresh
    `train_log.csv` (an existing file is overwritten). Each `log_step` call
    appends one CSV row and emits the same values as TensorBoard scalars.
    """

    # Column order of the CSV file; shared by the header and every row.
    FIELDNAMES = ['step', 'epoch', 'loss', 'epe', '1px', '3px', 'lr']
    # Metric names taken from the `metrics` dict passed to `log_step`.
    METRIC_KEYS = ['epe', '1px', '3px']

    def __init__(self, out_dir: Path):
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.csv_path = self.out_dir / 'train_log.csv'
        self.writer = SummaryWriter(log_dir=str(self.out_dir))
        # Counts log_step calls; used as the x-axis value for every scalar.
        self.step = 0
        with open(self.csv_path, 'w', newline='') as f:
            csv.DictWriter(f, fieldnames=self.FIELDNAMES).writeheader()

    def log_step(self, epoch, loss, metrics, lr):
        """Record one logging step.

        Args:
            epoch: epoch number stored in the CSV row.
            loss: loss value, converted with `float()`.
            metrics: dict that may contain 'epe', '1px' and '3px'. A missing
                entry is recorded as NaN in BOTH outputs; logging 0.0 to
                TensorBoard would read as a perfect score.
            lr: current learning rate.
        """
        loss = float(loss)
        values = {key: float(metrics.get(key, float('nan'))) for key in self.METRIC_KEYS}

        row = {'step': self.step, 'epoch': epoch, 'loss': loss}
        row.update(values)
        row['lr'] = lr
        with open(self.csv_path, 'a', newline='') as f:
            csv.DictWriter(f, fieldnames=self.FIELDNAMES).writerow(row)

        self.writer.add_scalar('train/loss', loss, self.step)
        for key, value in values.items():
            self.writer.add_scalar('train/' + key, value, self.step)
        self.writer.add_scalar('train/lr', lr, self.step)
        self.step += 1

    def close(self):
        """Close the TensorBoard writer."""
        self.writer.close()

# Creates OUTPUT_DIR if needed, starts a fresh train_log.csv there (an existing
# file is overwritten) and opens a TensorBoard writer in the same directory.
logger = TrainLogger(OUTPUT_DIR)

### Обучение

In [11]:
def _unpack_stereo_batch(batch, device):
    """Return (image1, image2, disp_gt, valid) from one batch, moved to `device`.

    Accepts the same batch layouts as the training loop: a dict with keys
    'left' / 'right' and optional 'disp' / 'mask', a 4-tuple, or a 5-tuple
    whose first element is ignored. `disp_gt` and `valid` may come back as
    None; when present, a singleton channel dimension is squeezed away so both
    line up with a prediction that has no channel dimension.
    """
    if isinstance(batch, dict):
        image1, image2 = batch['left'], batch['right']
        disp_gt, valid = batch.get('disp'), batch.get('mask')
    elif len(batch) == 4:
        image1, image2, disp_gt, valid = batch
    else:
        _, image1, image2, disp_gt, valid = batch

    image1, image2 = image1.to(device), image2.to(device)
    if disp_gt is not None:
        disp_gt = disp_gt.to(device)
        if disp_gt.ndim == 4:
            disp_gt = disp_gt.squeeze(1)
    if valid is not None:
        valid = valid.to(device)
        if valid.ndim == 4:
            valid = valid.squeeze(1)
    return image1, image2, disp_gt, valid


def validate(model, dataloader, cfg, device):
    """Evaluate `model` on `dataloader` and return averaged disparity metrics.

    Args:
        model: network called as `model(left, right, iters=..., test_mode=True)`.
        dataloader: iterable of batches (layouts listed in `_unpack_stereo_batch`).
        cfg: config object; only `cfg.valid_iters` is read.
        device: device the batch tensors are moved to.

    Returns:
        dict with 'epe' (mean absolute disparity error), '1px' and '3px'
        (fraction of valid pixels with error below 1 / 3). Each value is the
        mean of the per-batch values, or NaN when no batch contributed.

    The model is put in eval mode for the evaluation and is always switched
    back to train mode afterwards, even if an error occurs.
    """
    model.eval()
    tot_epe = []
    tot_1px = []
    tot_3px = []
    skipped = 0
    try:
        with torch.no_grad():
            for batch in dataloader:
                try:
                    image1, image2, disp_gt, valid = _unpack_stereo_batch(batch, device)
                except (KeyError, TypeError, ValueError, AttributeError):
                    # Unknown batch layout: keep going, but report it below
                    # instead of silently returning NaN metrics.
                    skipped += 1
                    continue
                if disp_gt is None:
                    continue

                # Images are multiplied by 255 before they reach the model,
                # exactly as in the training loop.
                inp1 = image1 * 255.0
                inp2 = image2 * 255.0

                padder = Padder(inp1.shape, divis_by=32)
                inp1_p, inp2_p = padder.pad(inp1, inp2)
                out = model(inp1_p, inp2_p, iters=cfg.valid_iters, test_mode=True)
                pred = padder.unpad(out)
                if pred.ndim == 4:
                    pred = pred[:, 0, :, :]

                # pred, disp_gt and the mask now share one shape, so the
                # subtraction cannot broadcast across the batch dimension.
                if valid is None:
                    valid_mask = torch.ones_like(disp_gt, dtype=torch.bool)
                else:
                    valid_mask = valid >= 0.5
                epe = torch.abs(pred - disp_gt)[valid_mask]
                if epe.numel():
                    tot_epe.append(epe.mean().item())
                    tot_1px.append((epe < 1.0).float().mean().item())
                    tot_3px.append((epe < 3.0).float().mean().item())
    finally:
        model.train()

    if skipped:
        logging.warning("validate: skipped %d batch(es) with an unexpected structure", skipped)

    return {'epe': np.nanmean(tot_epe) if tot_epe else float('nan'),
            '1px': np.nanmean(tot_1px) if tot_1px else float('nan'),
            '3px': np.nanmean(tot_3px) if tot_3px else float('nan')}


In [None]:
best_val = float('inf')
global_step = 0
save_dir = Path(cfg.logdir)
save_dir.mkdir(parents=True, exist_ok=True)

# CUDA autocast does nothing without a GPU; enabling it there only warns.
use_amp = bool(cfg.mixed_precision) and device.type == 'cuda'
skipped_batches = 0
scheduler_exhausted = False

try:
    for epoch in range(1, cfg.num_epochs + 1):
        epoch_start = time.time()
        running_loss = 0.0
        running_metrics = {'epe': 0.0, '1px': 0.0, '3px': 0.0}
        n_batches = 0

        for batch in train_loader:
            # Accept a dict batch, a 4-tuple, or a 5-tuple whose first element
            # is ignored.
            try:
                if isinstance(batch, dict):
                    image1, image2 = batch['left'], batch['right']
                    disp_gt, valid = batch.get('disp', None), batch.get('mask', None)
                elif len(batch) == 4:
                    image1, image2, disp_gt, valid = batch
                else:
                    _, image1, image2, disp_gt, valid = batch
                image1 = image1.to(device)
                image2 = image2.to(device)
                disp_gt = disp_gt.to(device) if disp_gt is not None else None
                valid = valid.to(device) if valid is not None else None
            except (KeyError, TypeError, ValueError, AttributeError) as err:
                skipped_batches += 1
                if skipped_batches == 1:
                    print(f"Skipping batch with unexpected structure: {err!r}")
                continue

            if disp_gt is None:
                # Nothing to supervise against.
                skipped_batches += 1
                continue
            if disp_gt.ndim == 3:
                # The loss expects a channel dimension on the ground truth.
                disp_gt = disp_gt.unsqueeze(1)
            if valid is None:
                # No mask supplied: treat every pixel as valid.
                valid = torch.ones_like(disp_gt[:, 0])

            optimizer.zero_grad()

            inp1 = (image1 * 255.0)
            inp2 = (image2 * 255.0)

            padder = Padder(inp1.shape, divis_by=32)
            inp1_p, inp2_p = padder.pad(inp1, inp2)

            with torch.cuda.amp.autocast(enabled=use_amp):
                agg_pred, iter_preds = model(inp1_p, inp2_p, iters=cfg.train_iters)
                # Undo the padding so predictions match the ground truth even
                # when the image size is not a multiple of 32.
                agg_pred = padder.unpad(agg_pred)
                iter_preds = [padder.unpad(pred) for pred in iter_preds]
                loss, metrics = sequence_loss(agg_pred, iter_preds, disp_gt, valid)

            if not torch.isfinite(loss):
                # A NaN/inf loss would corrupt every weight on the next step.
                skipped_batches += 1
                continue

            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            if isinstance(scheduler, optim.lr_scheduler.OneCycleLR):
                try:
                    scheduler.step()
                except ValueError as err:
                    # OneCycleLR refuses to step past total_steps (e.g. this
                    # cell was re-run without rebuilding the scheduler).
                    # Training continues at the last learning rate; say so once.
                    if not scheduler_exhausted:
                        print(f"LR scheduler exhausted, keeping last lr: {err}")
                        scheduler_exhausted = True

            running_loss += loss.item()
            for key in running_metrics:
                value = metrics.get(key, 0.0)
                if value == value:  # skip NaN so one empty batch cannot poison the average
                    running_metrics[key] += value
            n_batches += 1
            global_step += 1

            if global_step % 20 == 0:
                avg_loss = running_loss / n_batches
                avg_metrics = {k: running_metrics[k]/n_batches for k in running_metrics}
                lr = optimizer.param_groups[0]['lr']
                logger.log_step(epoch, avg_loss, avg_metrics, lr)

        if n_batches == 0:
            raise RuntimeError(
                f"Epoch {epoch}: no usable training batch ({skipped_batches} skipped); "
                "check the dataset output format")

        val_metrics = {}
        if val_loader is not None:
            val_metrics = validate(model, val_loader, cfg, device)
            logger.writer.add_scalar('val/epe', val_metrics['epe'], epoch)
            logger.writer.add_scalar('val/1px', val_metrics['1px'], epoch)
            logger.writer.add_scalar('val/3px', val_metrics['3px'], epoch)

        # save
        ckpt_path = save_dir / f"{cfg.name}_epoch{epoch}.pth"
        to_save = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
        torch.save(to_save, str(ckpt_path))

        # Keep a copy of the checkpoint with the lowest validation EPE so far.
        # A NaN validation EPE never compares as smaller.
        val_epe = val_metrics.get('epe', float('nan'))
        if val_epe < best_val:
            best_val = val_epe
            torch.save(to_save, str(save_dir / f"{cfg.name}_best.pth"))

        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch} done in {epoch_time:.1f}s | loss={running_loss/n_batches:.4f} | train_epe={running_metrics['epe']/n_batches:.4f} | val_epe={val_epe:.4f}")
finally:
    # finish: flush and close the TensorBoard writer even after an error or interrupt
    logger.close()

if skipped_batches:
    print(f"Skipped {skipped_batches} training batch(es) in total")
print("Training finished. Checkpoints & logs at", save_dir)