In [1]:
import os, shutil

src = "/kaggle/input/stdc-repo/STDC-Seg"
dst = "/kaggle/working/STDC-Seg"

# ensure the parent exists
os.makedirs(os.path.dirname(dst), exist_ok=True)

# copy the whole tree
shutil.copytree(src, dst)

'/kaggle/working/STDC-Seg'

In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
sys.path.append('/kaggle/working/STDC-Seg')

import argparse
import logging
import time
import datetime
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from loveda import LoveDA
from models.model_stages import BiSeNet
from loss.loss import WeightedOhemCELoss, OhemCELoss
from loss.detail_loss import DetailAggregateLoss
from optimizer_loss import Optimizer
import torch.optim as optim
# Configuration
def parse_args():
    parser = argparse.ArgumentParser(description='STDC Segmentation Training')
    parser.add_argument('--epochs', type=int, default=20, help='number of training epochs')
    parser.add_argument('--batch_size', type=int, default=6, help='images per batch')
    parser.add_argument('--n_workers_train', type=int, default=4)
    parser.add_argument('--n_workers_val', type=int, default=0)
    parser.add_argument('--cropsize', type=int, nargs=2, default=[1024, 512])
    parser.add_argument('--randomscale', type=float, nargs='+',
                        default=[0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5])
    parser.add_argument('--backbone', type=str, default='STDCNet813')
    parser.add_argument('--pretrain_path', type=str, default='/kaggle/working/STDC-Seg/pretrained_models/STDCNet813M_73.91.tar')
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--respath', type=str, default='/kaggle/working/output')
    parser.add_argument('--use_conv_last', action='store_true')
    parser.add_argument('--use_boundary_2', action='store_true')
    parser.add_argument('--use_boundary_4', action='store_true')
    parser.add_argument('--use_boundary_8', action='store_true')
    parser.add_argument('--use_boundary_16', action='store_true')
    return parser.parse_args(args=[] if '__file__' not in globals() else None)

# Metrics
def compute_iou(pred, target, num_classes, ignore_index=255):
    ious = []
    pred = pred.view(-1)
    target = target.view(-1)
    for cls in range(num_classes):
        mask = target == cls
        if mask.sum().item() == 0:
            ious.append(float('nan'))
            continue
        inter = ((pred == cls) & mask).sum().item()
        union = ((pred == cls) | mask).sum().item()
        ious.append(inter / union if union > 0 else float('nan'))
    miou = np.nanmean(ious)
    return ious, miou

# Training loop
def train():
    args = parse_args()
    os.makedirs(args.respath, exist_ok=True)
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_classes = 7
    ignore_idx = 255

    # Dataset and DataLoader
    ds_train = LoveDA('/kaggle/input/loveda-splits', cropsize=args.cropsize,
                      mode='train', randomscale=tuple(args.randomscale), resolution=(512, 512))
    dl_train = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True,
                          num_workers=args.n_workers_train, pin_memory=True, drop_last=False)
    ds_val = LoveDA('/kaggle/input/loveda-splits', mode='val', randomscale=tuple(args.randomscale), resolution=(512, 512))
    dl_val = DataLoader(ds_val, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.n_workers_val, pin_memory=True, drop_last=False)

    logger.info(f"Train samples: {len(ds_train)}, Val samples: {len(ds_val)}")

    # Model
    net = BiSeNet(backbone=args.backbone, n_classes=n_classes,
                  pretrain_model=args.pretrain_path,
                  use_boundary_2=args.use_boundary_2,
                  use_boundary_4=args.use_boundary_4,
                  use_boundary_8=args.use_boundary_8,
                  use_boundary_16=args.use_boundary_16,
                  use_conv_last=args.use_conv_last)
    if args.ckpt:
        net.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    net = net.to(device)
    net = nn.DataParallel(net)

    # Losses and Optimizer
    n_min = args.batch_size * args.cropsize[0] * args.cropsize[1] // 16
    criteria = {
        'p': WeightedOhemCELoss(0.7, n_min, 7, ignore_lb=ignore_idx),
        '16': WeightedOhemCELoss(0.7, n_min, 7, ignore_lb=ignore_idx),
        '32': WeightedOhemCELoss(0.7, n_min, 7, ignore_lb=ignore_idx)
    }
    boundary_loss = DetailAggregateLoss()
    optimizer = Optimizer(model=net.module, loss=boundary_loss,
                          lr0=1e-3, momentum=0.9, wd=5e-4,
                          warmup_steps=1000, warmup_start_lr=1e-5,
                          max_iter=len(dl_train)*args.epochs,
                          power=0.9)

    
    
    

    # Epoch loop
    best_miou = 0.0
    for epoch in range(1, args.epochs + 1):
        # -- Training --
        net.train()
        running_loss = 0.0
        train_bar = tqdm(dl_train, desc=f"Epoch {epoch}/{args.epochs} Training", unit="batch")
        for imgs, lbs in train_bar:
            imgs = imgs.to(device)
            lbs = lbs.squeeze(1).to(device)

            optimizer.zero_grad()
            outputs = net(imgs)
            out, out16, out32 = outputs[:3]

            lp = criteria['p'](out, lbs)
            l16 = criteria['16'](out16, lbs)
            l32 = criteria['32'](out32, lbs)
            lbce, ldice = 0.0, 0.0
            if args.use_boundary_2 or args.use_boundary_4 or args.use_boundary_8 or args.use_boundary_16:
                for det in outputs[3:]:
                    bce, dice = boundary_loss(det, lbs)
                    lbce += bce; ldice += dice

            loss = lp + l16 + l32 + lbce + ldice
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.set_postfix(loss=f"{loss.item():.4f}")

        avg_train_loss = running_loss / len(dl_train)
        print(f"Epoch {epoch} Training complete. Avg Loss: {avg_train_loss:.4f}")
        logger.info(f"Epoch [{epoch}/{args.epochs}] Train Loss: {avg_train_loss:.4f}")

        # -- Validation --
        net.eval()
        running_val_loss = 0.0
        intersection = torch.zeros(n_classes, dtype=torch.float64)
        union        = torch.zeros(n_classes, dtype=torch.float64)
        val_bar = tqdm(dl_val, desc=f"Epoch {epoch}/{args.epochs} Validation", unit="batch")
        with torch.no_grad():
            for imgs, lbs in val_bar:
                imgs = imgs.to(device)
                lbs = lbs.squeeze(1).to(device)
                outputs = net(imgs)
                out = outputs[0]

                # validation loss
                lval = criteria['p'](out, lbs)
                running_val_loss += lval.item()
                val_bar.set_postfix(valloss=f"{lval.item():.4f}")

                pred = torch.argmax(out, dim=1)
                for cls in range(n_classes):
                    mask = (lbs == cls)
                    if mask.sum().item() == 0:
                        continue
                    inter = ((pred == cls) & mask).sum().item()
                    uni   = ((pred == cls) | mask).sum().item()
                    intersection[cls] += inter
                    union[cls]        += uni

        avg_val_loss = running_val_loss / len(dl_val)
        print(f"Epoch {epoch} Validation complete. Avg Loss: {avg_val_loss:.4f}")
        logger.info(f"Epoch [{epoch}/{args.epochs}] Val Loss: {avg_val_loss:.4f}")

        ious = (intersection / union).tolist()
        miou = float(torch.nanmean(intersection / union))
        print("Per-class IoU:")
        for cls_idx, cls_iou in enumerate(ious):
            print(f"  Class {cls_idx}: {cls_iou:.4f}")
        print(f"Mean IoU: {miou:.4f}")

        for cls_idx, cls_iou in enumerate(ious):
            logger.info(f"Class {cls_idx:>2} IoU: {cls_iou:.4f}")
        logger.info(f"Mean IoU: {miou:.4f}")

        # Save
        ckpt_name = f"epoch{epoch:02d}_miou{miou:.4f}.pth"
        torch.save(net.module.state_dict(), os.path.join(args.respath, ckpt_name))
        if miou > best_miou:
            best_miou = miou
            torch.save(net.module.state_dict(), os.path.join(args.respath, 'best_mIoU.pth'))

    logger.info("Training complete")

if __name__ == '__main__':
    train()

  check_for_updates()


use pretrain model /kaggle/working/STDC-Seg/pretrained_models/STDCNet813M_73.91.tar


Epoch 1/20 Training: 100%|██████████| 193/193 [00:46<00:00,  4.15batch/s, loss=21.8889]


Epoch 1 Training complete. Avg Loss: 27.6975


Epoch 1/20 Validation: 100%|██████████| 166/166 [01:29<00:00,  1.87batch/s, valloss=8.0694] 


Epoch 1 Validation complete. Avg Loss: 9.7881
Per-class IoU:
  Class 0: 0.1040
  Class 1: 0.3158
  Class 2: 0.1210
  Class 3: 0.2569
  Class 4: 0.1200
  Class 5: 0.0768
  Class 6: 0.0257
Mean IoU: 0.1457


Epoch 2/20 Training: 100%|██████████| 193/193 [00:40<00:00,  4.72batch/s, loss=27.5616]


Epoch 2 Training complete. Avg Loss: 21.6751


Epoch 2/20 Validation: 100%|██████████| 166/166 [01:04<00:00,  2.59batch/s, valloss=7.9411] 


Epoch 2 Validation complete. Avg Loss: 9.9730
Per-class IoU:
  Class 0: 0.1322
  Class 1: 0.3897
  Class 2: 0.2821
  Class 3: 0.2955
  Class 4: 0.1077
  Class 5: 0.0811
  Class 6: 0.3741
Mean IoU: 0.2375


Epoch 3/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.83batch/s, loss=22.6214]


Epoch 3 Training complete. Avg Loss: 17.6992


Epoch 3/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.64batch/s, valloss=9.8664] 


Epoch 3 Validation complete. Avg Loss: 11.5656
Per-class IoU:
  Class 0: 0.1583
  Class 1: 0.3766
  Class 2: 0.2699
  Class 3: 0.3365
  Class 4: 0.1327
  Class 5: 0.0784
  Class 6: 0.4073
Mean IoU: 0.2514


Epoch 4/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.89batch/s, loss=14.4742]


Epoch 4 Training complete. Avg Loss: 15.5537


Epoch 4/20 Validation: 100%|██████████| 166/166 [01:03<00:00,  2.63batch/s, valloss=13.4657]


Epoch 4 Validation complete. Avg Loss: 12.0635
Per-class IoU:
  Class 0: 0.2130
  Class 1: 0.3127
  Class 2: 0.3139
  Class 3: 0.3709
  Class 4: 0.1479
  Class 5: 0.0865
  Class 6: 0.4457
Mean IoU: 0.2701


Epoch 5/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.91batch/s, loss=19.5182]


Epoch 5 Training complete. Avg Loss: 15.5480


Epoch 5/20 Validation: 100%|██████████| 166/166 [01:03<00:00,  2.63batch/s, valloss=7.9275] 


Epoch 5 Validation complete. Avg Loss: 11.4511
Per-class IoU:
  Class 0: 0.3448
  Class 1: 0.3507
  Class 2: 0.2110
  Class 3: 0.3563
  Class 4: 0.1181
  Class 5: 0.0903
  Class 6: 0.3633
Mean IoU: 0.2621


Epoch 6/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.88batch/s, loss=25.2022]


Epoch 6 Training complete. Avg Loss: 15.5879


Epoch 6/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.64batch/s, valloss=13.3801]


Epoch 6 Validation complete. Avg Loss: 14.0007
Per-class IoU:
  Class 0: 0.1126
  Class 1: 0.3362
  Class 2: 0.2480
  Class 3: 0.5015
  Class 4: 0.1922
  Class 5: 0.0843
  Class 6: 0.3984
Mean IoU: 0.2676


Epoch 7/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.92batch/s, loss=13.3293]


Epoch 7 Training complete. Avg Loss: 13.9890


Epoch 7/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.66batch/s, valloss=4.4157] 


Epoch 7 Validation complete. Avg Loss: 9.0180
Per-class IoU:
  Class 0: 0.2142
  Class 1: 0.2613
  Class 2: 0.3168
  Class 3: 0.3220
  Class 4: 0.1420
  Class 5: 0.1026
  Class 6: 0.4422
Mean IoU: 0.2573


Epoch 8/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.85batch/s, loss=10.4569]


Epoch 8 Training complete. Avg Loss: 12.7447


Epoch 8/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.65batch/s, valloss=6.5751] 


Epoch 8 Validation complete. Avg Loss: 10.5087
Per-class IoU:
  Class 0: 0.5195
  Class 1: 0.2242
  Class 2: 0.3627
  Class 3: 0.4932
  Class 4: 0.1531
  Class 5: 0.1629
  Class 6: 0.3933
Mean IoU: 0.3299


Epoch 9/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.86batch/s, loss=11.5342]


Epoch 9 Training complete. Avg Loss: 11.7095


Epoch 9/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.64batch/s, valloss=2.9561] 


Epoch 9 Validation complete. Avg Loss: 12.7916
Per-class IoU:
  Class 0: 0.5343
  Class 1: 0.4062
  Class 2: 0.3950
  Class 3: 0.3080
  Class 4: 0.1664
  Class 5: 0.1203
  Class 6: 0.4652
Mean IoU: 0.3422


Epoch 10/20 Training: 100%|██████████| 193/193 [00:40<00:00,  4.81batch/s, loss=11.7035]


Epoch 10 Training complete. Avg Loss: 11.0777


Epoch 10/20 Validation: 100%|██████████| 166/166 [01:03<00:00,  2.61batch/s, valloss=3.6728] 


Epoch 10 Validation complete. Avg Loss: 13.2460
Per-class IoU:
  Class 0: 0.5260
  Class 1: 0.3695
  Class 2: 0.3706
  Class 3: 0.3556
  Class 4: 0.2139
  Class 5: 0.1465
  Class 6: 0.4560
Mean IoU: 0.3483


Epoch 11/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.90batch/s, loss=9.3982] 


Epoch 11 Training complete. Avg Loss: 10.5084


Epoch 11/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.64batch/s, valloss=4.6132] 


Epoch 11 Validation complete. Avg Loss: 15.9478
Per-class IoU:
  Class 0: 0.5378
  Class 1: 0.4205
  Class 2: 0.3007
  Class 3: 0.3236
  Class 4: 0.1898
  Class 5: 0.1639
  Class 6: 0.4576
Mean IoU: 0.3420


Epoch 12/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.90batch/s, loss=9.3677] 


Epoch 12 Training complete. Avg Loss: 9.8825


Epoch 12/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.64batch/s, valloss=4.1675] 


Epoch 12 Validation complete. Avg Loss: 18.8689
Per-class IoU:
  Class 0: 0.4837
  Class 1: 0.4497
  Class 2: 0.3620
  Class 3: 0.2808
  Class 4: 0.1775
  Class 5: 0.1417
  Class 6: 0.4414
Mean IoU: 0.3338


Epoch 13/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.92batch/s, loss=9.0567] 


Epoch 13 Training complete. Avg Loss: 9.5637


Epoch 13/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.66batch/s, valloss=5.3387] 


Epoch 13 Validation complete. Avg Loss: 17.7695
Per-class IoU:
  Class 0: 0.4945
  Class 1: 0.3834
  Class 2: 0.3791
  Class 3: 0.3005
  Class 4: 0.1618
  Class 5: 0.1903
  Class 6: 0.4547
Mean IoU: 0.3377


Epoch 14/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.88batch/s, loss=11.5207]


Epoch 14 Training complete. Avg Loss: 9.3384


Epoch 14/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.65batch/s, valloss=5.2183] 


Epoch 14 Validation complete. Avg Loss: 21.5191
Per-class IoU:
  Class 0: 0.5352
  Class 1: 0.4655
  Class 2: 0.3430
  Class 3: 0.2811
  Class 4: 0.2211
  Class 5: 0.0937
  Class 6: 0.4475
Mean IoU: 0.3410


Epoch 15/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.92batch/s, loss=9.1401] 


Epoch 15 Training complete. Avg Loss: 8.9203


Epoch 15/20 Validation: 100%|██████████| 166/166 [01:03<00:00,  2.63batch/s, valloss=6.6484] 


Epoch 15 Validation complete. Avg Loss: 21.4881
Per-class IoU:
  Class 0: 0.5324
  Class 1: 0.4693
  Class 2: 0.3309
  Class 3: 0.3390
  Class 4: 0.1617
  Class 5: 0.1871
  Class 6: 0.4310
Mean IoU: 0.3502


Epoch 16/20 Training: 100%|██████████| 193/193 [00:40<00:00,  4.82batch/s, loss=9.4671] 


Epoch 16 Training complete. Avg Loss: 8.8081


Epoch 16/20 Validation: 100%|██████████| 166/166 [01:03<00:00,  2.62batch/s, valloss=7.1461] 


Epoch 16 Validation complete. Avg Loss: 28.3798
Per-class IoU:
  Class 0: 0.5357
  Class 1: 0.3443
  Class 2: 0.3455
  Class 3: 0.2871
  Class 4: 0.1316
  Class 5: 0.1247
  Class 6: 0.4485
Mean IoU: 0.3168


Epoch 17/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.84batch/s, loss=8.2077] 


Epoch 17 Training complete. Avg Loss: 8.6263


Epoch 17/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.65batch/s, valloss=6.0330] 


Epoch 17 Validation complete. Avg Loss: 24.0907
Per-class IoU:
  Class 0: 0.5440
  Class 1: 0.4230
  Class 2: 0.3346
  Class 3: 0.3089
  Class 4: 0.1685
  Class 5: 0.1492
  Class 6: 0.4508
Mean IoU: 0.3399


Epoch 18/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.90batch/s, loss=8.6792] 


Epoch 18 Training complete. Avg Loss: 8.4582


Epoch 18/20 Validation: 100%|██████████| 166/166 [01:02<00:00,  2.64batch/s, valloss=10.4303]


Epoch 18 Validation complete. Avg Loss: 23.7006
Per-class IoU:
  Class 0: 0.5379
  Class 1: 0.4203
  Class 2: 0.3134
  Class 3: 0.3579
  Class 4: 0.1440
  Class 5: 0.1431
  Class 6: 0.4270
Mean IoU: 0.3348


Epoch 19/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.90batch/s, loss=9.1082] 


Epoch 19 Training complete. Avg Loss: 8.2817


Epoch 19/20 Validation: 100%|██████████| 166/166 [01:03<00:00,  2.61batch/s, valloss=10.9586]


Epoch 19 Validation complete. Avg Loss: 28.1996
Per-class IoU:
  Class 0: 0.5354
  Class 1: 0.3817
  Class 2: 0.2937
  Class 3: 0.3050
  Class 4: 0.1432
  Class 5: 0.1367
  Class 6: 0.4347
Mean IoU: 0.3186


Epoch 20/20 Training: 100%|██████████| 193/193 [00:39<00:00,  4.92batch/s, loss=10.1442]


Epoch 20 Training complete. Avg Loss: 8.1664


Epoch 20/20 Validation: 100%|██████████| 166/166 [01:03<00:00,  2.63batch/s, valloss=10.6387]


Epoch 20 Validation complete. Avg Loss: 27.4577
Per-class IoU:
  Class 0: 0.5401
  Class 1: 0.4190
  Class 2: 0.3360
  Class 3: 0.3257
  Class 4: 0.1456
  Class 5: 0.1115
  Class 6: 0.4359
Mean IoU: 0.3305
