# ICS 504 Project

In [1]:
import logging
import os
import pprint

import torch
from torch import nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import yaml

from dataset.semi import SemiDataset
from model.semseg.deeplabv3plus import DeepLabV3Plus
from no_distrib import evaluate
from util.classes import CLASSES
from util.ohem import ProbOhemCrossEntropy2d
from util.utils import count_params, init_log, AverageMeter
from util.dist_helper import setup_distributed

from earlystopper import EarlyStopper

In [2]:
dataset='pascal'
method='unimatch'
exp='r101'
split='732'

config = f"configs/{dataset}.yaml"
labeled_id_path = f"splits/{dataset}/{split}/labeled.txt"
unlabeled_id_path = f"splits/{dataset}/{split}/unlabeled.txt"
save_path = f"exp/{dataset}/{method}/{exp}/{split}"
local_rank = 0
port = 1202

In [3]:
cfg = yaml.load(open(config, "r"), Loader=yaml.Loader)
print(cfg)

{'dataset': 'pascal', 'nclass': 21, 'crop_size': 321, 'data_root': 'E:\\ICS_504\\project_datasets\\Pascal\\', 'epochs': 80, 'batch_size': 2, 'lr': 0.001, 'lr_multi': 10.0, 'criterion': {'name': 'CELoss', 'kwargs': {'ignore_index': 255}}, 'conf_thresh': 0.95, 'weak_threshold': 0.7, 'model': 'deeplabv3plus', 'backbone': 'resnet101', 'replace_stride_with_dilation': [False, False, True], 'dilations': [6, 12, 18]}


In [4]:
logger = init_log('global', logging.INFO)
logger.propagate = 0

In [5]:
# rank, world_size = setup_distributed(port=port)
os.makedirs(save_path, exist_ok=True)
writer = SummaryWriter(save_path)
cudnn.enabled = True
cudnn.benchmark = True

In [6]:
model = DeepLabV3Plus(cfg)
optimizer = SGD([{'params': model.backbone.parameters(), 'lr': cfg['lr']},
                {'params': [param for name, param in model.named_parameters() if 'backbone' not in name],
                'lr': cfg['lr'] * cfg['lr_multi']}], lr=cfg['lr'], momentum=0.9, weight_decay=1e-4)


# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.cuda()

# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], broadcast_buffers=False,
#                                                       output_device=local_rank, find_unused_parameters=False)
logger.info('Total params: {:.1f}M\n'.format(count_params(model)))

[2023-05-22 23:17:12,711][    INFO] Total params: 60.2M



In [7]:
if cfg['criterion']['name'] == 'CELoss':
    criterion_l = nn.CrossEntropyLoss(**cfg['criterion']['kwargs']).cuda(local_rank)
elif cfg['criterion']['name'] == 'OHEM':
    criterion_l = ProbOhemCrossEntropy2d(**cfg['criterion']['kwargs']).cuda(local_rank)
else:
    raise NotImplementedError('%s criterion is not implemented' % cfg['criterion']['name'])

criterion_u = nn.CrossEntropyLoss(reduction='none').cuda(local_rank)

In [8]:
trainset_u = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_u',
                            cfg['crop_size'], unlabeled_id_path)
trainset_l = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_l',
                            cfg['crop_size'], labeled_id_path, nsample=len(trainset_u.ids))
valset = SemiDataset(cfg['dataset'], cfg['data_root'], 'val')

trainsampler_l = torch.utils.data.SequentialSampler(trainset_l)
trainloader_l = DataLoader(trainset_l, batch_size=cfg['batch_size'],
                            pin_memory=True, num_workers=1, drop_last=True, sampler=trainsampler_l)
trainsampler_u = torch.utils.data.SequentialSampler(trainset_u)
trainloader_u = DataLoader(trainset_u, batch_size=cfg['batch_size'],
                            pin_memory=True, num_workers=1, drop_last=True, sampler=trainsampler_u)
valsampler = torch.utils.data.SequentialSampler(valset)
valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=1,
                        drop_last=False, sampler=valsampler)


In [9]:
total_iters = len(trainloader_u) * cfg['epochs']
previous_best = 0.0
epoch = -1

early_stopper = EarlyStopper(patience=5, min_delta=3)

logger.info(f"Total iters: {total_iters}")

[2023-05-22 23:17:12,740][    INFO] Total iters: 394000


In [10]:
if os.path.exists(os.path.join(save_path, 'latest.pth')):
    checkpoint = torch.load(os.path.join(save_path, 'latest.pth'))
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']
    previous_best = checkpoint['previous_best']
    
    logger.info('************ Load from checkpoint at epoch %i\n' % epoch)

In [11]:
for epoch in range(epoch + 1, cfg['epochs']):
    logger.info('===========> Epoch: {:}, LR: {:.5f}, Previous best: {:.2f}'.format(
        epoch, optimizer.param_groups[0]['lr'], previous_best))

    total_loss = AverageMeter()
    total_loss_x = AverageMeter()
    total_loss_s = AverageMeter()
    total_loss_w_fp = AverageMeter()
    total_mask_ratio = AverageMeter()

    # trainloader_l.sampler.set_epoch(epoch)
    # trainloader_u.sampler.set_epoch(epoch)

    loader = zip(trainloader_l, trainloader_u, trainloader_u)

    for i, ((img_x, mask_x),
            (img_u_w, img_u_s1, img_u_s2, ignore_mask, cutmix_box1, cutmix_box2),
            (img_u_w_mix, img_u_s1_mix, img_u_s2_mix, ignore_mask_mix, _, _)) in enumerate(loader):
        
        img_x, mask_x = img_x.cuda(), mask_x.cuda()
        img_u_w = img_u_w.cuda()
        img_u_s1, img_u_s2, ignore_mask = img_u_s1.cuda(), img_u_s2.cuda(), ignore_mask.cuda()
        cutmix_box1, cutmix_box2 = cutmix_box1.cuda(), cutmix_box2.cuda()
        img_u_w_mix = img_u_w_mix.cuda()
        img_u_s1_mix, img_u_s2_mix = img_u_s1_mix.cuda(), img_u_s2_mix.cuda()
        ignore_mask_mix = ignore_mask_mix.cuda()

        with torch.no_grad():
            model.eval()

            pred_u_w_mix = model(img_u_w_mix).detach()
            conf_u_w_mix = pred_u_w_mix.softmax(dim=1).max(dim=1)[0]
            mask_u_w_mix = pred_u_w_mix.argmax(dim=1)

        img_u_s1[cutmix_box1.unsqueeze(1).expand(img_u_s1.shape) == 1] = \
            img_u_s1_mix[cutmix_box1.unsqueeze(1).expand(img_u_s1.shape) == 1]
        img_u_s2[cutmix_box2.unsqueeze(1).expand(img_u_s2.shape) == 1] = \
            img_u_s2_mix[cutmix_box2.unsqueeze(1).expand(img_u_s2.shape) == 1]

        model.train()

        num_lb, num_ulb = img_x.shape[0], img_u_w.shape[0]

        preds, preds_fp = model(torch.cat((img_x, img_u_w)), need_fp=True)
        pred_x, pred_u_w = preds.split([num_lb, num_ulb])
        pred_u_w_fp = preds_fp[num_lb:]

        pred_u_s1, pred_u_s2 = model(torch.cat((img_u_s1, img_u_s2))).chunk(2)

        pred_u_w = pred_u_w.detach()
        conf_u_w = pred_u_w.softmax(dim=1).max(dim=1)[0]
        mask_u_w = pred_u_w.argmax(dim=1)

        mask_u_w_cutmixed1, conf_u_w_cutmixed1, ignore_mask_cutmixed1 = \
            mask_u_w.clone(), conf_u_w.clone(), ignore_mask.clone()
        mask_u_w_cutmixed2, conf_u_w_cutmixed2, ignore_mask_cutmixed2 = \
            mask_u_w.clone(), conf_u_w.clone(), ignore_mask.clone()

        mask_u_w_cutmixed1[cutmix_box1 == 1] = mask_u_w_mix[cutmix_box1 == 1]
        conf_u_w_cutmixed1[cutmix_box1 == 1] = conf_u_w_mix[cutmix_box1 == 1]
        ignore_mask_cutmixed1[cutmix_box1 == 1] = ignore_mask_mix[cutmix_box1 == 1]

        mask_u_w_cutmixed2[cutmix_box2 == 1] = mask_u_w_mix[cutmix_box2 == 1]
        conf_u_w_cutmixed2[cutmix_box2 == 1] = conf_u_w_mix[cutmix_box2 == 1]
        ignore_mask_cutmixed2[cutmix_box2 == 1] = ignore_mask_mix[cutmix_box2 == 1]

        loss_x = criterion_l(pred_x, mask_x)

        loss_u_s1 = criterion_u(pred_u_s1, mask_u_w_cutmixed1)
        loss_u_s1 = loss_u_s1 * ((conf_u_w_cutmixed1 >= cfg['conf_thresh']) & (ignore_mask_cutmixed1 != 255))
        loss_u_s1 = loss_u_s1.sum() / (ignore_mask_cutmixed1 != 255).sum().item()

        loss_u_s2 = criterion_u(pred_u_s2, mask_u_w_cutmixed2)
        loss_u_s2 = loss_u_s2 * ((conf_u_w_cutmixed2 >= cfg['conf_thresh']) & (ignore_mask_cutmixed2 != 255))
        loss_u_s2 = loss_u_s2.sum() / (ignore_mask_cutmixed2 != 255).sum().item()

        loss_u_w_fp = criterion_u(pred_u_w_fp, mask_u_w)
        loss_u_w_fp = loss_u_w_fp * ((conf_u_w >= cfg['conf_thresh']) & (ignore_mask != 255))
        loss_u_w_fp = loss_u_w_fp.sum() / (ignore_mask != 255).sum().item()

        loss = (loss_x + loss_u_s1 * 0.25 + loss_u_s2 * 0.25 + loss_u_w_fp * 0.5) / 2.0

        # torch.distributed.barrier()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss.update(loss.item())
        total_loss_x.update(loss_x.item())
        total_loss_s.update((loss_u_s1.item() + loss_u_s2.item()) / 2.0)
        total_loss_w_fp.update(loss_u_w_fp.item())
        
        mask_ratio = ((conf_u_w >= cfg['conf_thresh']) & (ignore_mask != 255)).sum().item() / \
            (ignore_mask != 255).sum()
        total_mask_ratio.update(mask_ratio.item())

        iters = epoch * len(trainloader_u) + i
        lr = cfg['lr'] * (1 - iters / total_iters) ** 0.9
        optimizer.param_groups[0]["lr"] = lr
        optimizer.param_groups[1]["lr"] = lr * cfg['lr_multi']
        
        writer.add_scalar('train/loss_all', loss.item(), iters)
        writer.add_scalar('train/loss_x', loss_x.item(), iters)
        writer.add_scalar('train/loss_s', (loss_u_s1.item() + loss_u_s2.item()) / 2.0, iters)
        writer.add_scalar('train/loss_w_fp', loss_u_w_fp.item(), iters)
        writer.add_scalar('train/mask_ratio', mask_ratio, iters)
    
        if (i % (len(trainloader_u) // 8) == 0):
            logger.info('Iters: {:}, Total loss: {:.3f}, Loss x: {:.3f}, Loss s: {:.3f}, Loss w_fp: {:.3f}, Mask ratio: '
                        '{:.3f}'.format(i, total_loss.avg, total_loss_x.avg, total_loss_s.avg,
                                        total_loss_w_fp.avg, total_mask_ratio.avg))

    eval_mode = 'sliding_window' if cfg['dataset'] == 'cityscapes' else 'original'
    mIoU, iou_class = evaluate(model, valloader, eval_mode, cfg)

    for (cls_idx, iou) in enumerate(iou_class):
        logger.info('***** Evaluation ***** >>>> Class [{:} {:}] '
                    'IoU: {:.2f}'.format(cls_idx, CLASSES[cfg['dataset']][cls_idx], iou))
    logger.info('***** Evaluation {} ***** >>>> MeanIoU: {:.2f}\n'.format(eval_mode, mIoU))
    
    writer.add_scalar('eval/mIoU', mIoU, epoch)
    for i, iou in enumerate(iou_class):
        writer.add_scalar('eval/%s_IoU' % (CLASSES[cfg['dataset']][i]), iou, epoch)

    is_best = mIoU > previous_best
    previous_best = max(mIoU, previous_best)
    
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'previous_best': previous_best,
    }
    torch.save(checkpoint, os.path.join(save_path, 'latest.pth'))
    if is_best:
        torch.save(checkpoint, os.path.join(save_path, 'best.pth'))

    if early_stopper.early_stop(mIoU):
        logger.info("***** Early Stoping *****")
        break

[2023-05-22 23:17:30,072][    INFO] Iters: 0, Total loss: 1.509, Loss x: 3.019, Loss s: 0.000, Loss w_fp: 0.000, Mask ratio: 0.000
[2023-05-22 23:25:12,032][    INFO] Iters: 615, Total loss: 0.626, Loss x: 1.218, Loss s: 0.058, Loss w_fp: 0.008, Mask ratio: 0.242
[2023-05-22 23:32:24,877][    INFO] Iters: 1230, Total loss: 0.567, Loss x: 1.092, Loss s: 0.078, Loss w_fp: 0.008, Mask ratio: 0.268
[2023-05-22 23:39:42,386][    INFO] Iters: 1845, Total loss: 0.523, Loss x: 0.995, Loss s: 0.094, Loss w_fp: 0.008, Mask ratio: 0.292
[2023-05-22 23:47:03,924][    INFO] Iters: 2460, Total loss: 0.489, Loss x: 0.921, Loss s: 0.106, Loss w_fp: 0.008, Mask ratio: 0.308
[2023-05-22 23:54:25,007][    INFO] Iters: 3075, Total loss: 0.458, Loss x: 0.855, Loss s: 0.112, Loss w_fp: 0.008, Mask ratio: 0.324
[2023-05-23 00:01:45,549][    INFO] Iters: 3690, Total loss: 0.439, Loss x: 0.815, Loss s: 0.120, Loss w_fp: 0.008, Mask ratio: 0.335
[2023-05-23 00:09:56,202][    INFO] Iters: 4305, Total loss: 0.418

[2023-05-23 02:17:27,066][    INFO] ***** Evaluation ***** >>>> Class [3 bird] IoU: 72.90
[2023-05-23 02:17:27,067][    INFO] ***** Evaluation ***** >>>> Class [4 boat] IoU: 55.20
[2023-05-23 02:17:27,067][    INFO] ***** Evaluation ***** >>>> Class [5 bottle] IoU: 49.78
[2023-05-23 02:17:27,068][    INFO] ***** Evaluation ***** >>>> Class [6 bus] IoU: 85.48
[2023-05-23 02:17:27,068][    INFO] ***** Evaluation ***** >>>> Class [7 car] IoU: 70.37
[2023-05-23 02:17:27,069][    INFO] ***** Evaluation ***** >>>> Class [8 cat] IoU: 76.38
[2023-05-23 02:17:27,069][    INFO] ***** Evaluation ***** >>>> Class [9 chair] IoU: 22.96
[2023-05-23 02:17:27,070][    INFO] ***** Evaluation ***** >>>> Class [10 cow] IoU: 54.30
[2023-05-23 02:17:27,071][    INFO] ***** Evaluation ***** >>>> Class [11 dining table] IoU: 42.07
[2023-05-23 02:17:27,071][    INFO] ***** Evaluation ***** >>>> Class [12 dog] IoU: 68.23
[2023-05-23 02:17:27,072][    INFO] ***** Evaluation ***** >>>> Class [13 horse] IoU: 58.73

[2023-05-23 04:13:19,088][    INFO] ***** Evaluation ***** >>>> Class [20 tv/monitor] IoU: 42.78
[2023-05-23 04:13:19,089][    INFO] ***** Evaluation original ***** >>>> MeanIoU: 59.09

[2023-05-23 04:13:23,953][    INFO] Iters: 0, Total loss: 0.323, Loss x: 0.205, Loss s: 0.847, Loss w_fp: 0.036, Mask ratio: 0.753
[2023-05-23 04:20:25,813][    INFO] Iters: 615, Total loss: 0.121, Loss x: 0.138, Loss s: 0.194, Loss w_fp: 0.014, Mask ratio: 0.685
[2023-05-23 04:27:27,813][    INFO] Iters: 1230, Total loss: 0.115, Loss x: 0.129, Loss s: 0.187, Loss w_fp: 0.013, Mask ratio: 0.687
[2023-05-23 04:34:29,694][    INFO] Iters: 1845, Total loss: 0.110, Loss x: 0.122, Loss s: 0.184, Loss w_fp: 0.013, Mask ratio: 0.697
[2023-05-23 04:41:31,247][    INFO] Iters: 2460, Total loss: 0.109, Loss x: 0.120, Loss s: 0.183, Loss w_fp: 0.013, Mask ratio: 0.700
[2023-05-23 04:48:32,934][    INFO] Iters: 3075, Total loss: 0.109, Loss x: 0.121, Loss s: 0.181, Loss w_fp: 0.013, Mask ratio: 0.699
[2023-05-23 04

[2023-05-23 07:06:41,023][    INFO] ***** Evaluation ***** >>>> Class [1 aeroplane] IoU: 77.79
[2023-05-23 07:06:41,024][    INFO] ***** Evaluation ***** >>>> Class [2 bicycle] IoU: 53.35
[2023-05-23 07:06:41,024][    INFO] ***** Evaluation ***** >>>> Class [3 bird] IoU: 72.16
[2023-05-23 07:06:41,025][    INFO] ***** Evaluation ***** >>>> Class [4 boat] IoU: 58.49
[2023-05-23 07:06:41,025][    INFO] ***** Evaluation ***** >>>> Class [5 bottle] IoU: 56.58
[2023-05-23 07:06:41,026][    INFO] ***** Evaluation ***** >>>> Class [6 bus] IoU: 89.55
[2023-05-23 07:06:41,026][    INFO] ***** Evaluation ***** >>>> Class [7 car] IoU: 75.37
[2023-05-23 07:06:41,026][    INFO] ***** Evaluation ***** >>>> Class [8 cat] IoU: 32.15
[2023-05-23 07:06:41,027][    INFO] ***** Evaluation ***** >>>> Class [9 chair] IoU: 26.91
[2023-05-23 07:06:41,027][    INFO] ***** Evaluation ***** >>>> Class [10 cow] IoU: 60.11
[2023-05-23 07:06:41,028][    INFO] ***** Evaluation ***** >>>> Class [11 dining table] IoU:

[2023-05-23 09:02:07,983][    INFO] ***** Evaluation ***** >>>> Class [18 sofa] IoU: 44.42
[2023-05-23 09:02:07,984][    INFO] ***** Evaluation ***** >>>> Class [19 train] IoU: 75.76
[2023-05-23 09:02:07,984][    INFO] ***** Evaluation ***** >>>> Class [20 tv/monitor] IoU: 62.94
[2023-05-23 09:02:07,984][    INFO] ***** Evaluation original ***** >>>> MeanIoU: 63.41

[2023-05-23 09:02:12,826][    INFO] Iters: 0, Total loss: 0.113, Loss x: 0.111, Loss s: 0.227, Loss w_fp: 0.005, Mask ratio: 0.766
[2023-05-23 09:09:14,391][    INFO] Iters: 615, Total loss: 0.080, Loss x: 0.078, Loss s: 0.156, Loss w_fp: 0.010, Mask ratio: 0.760
[2023-05-23 09:16:16,264][    INFO] Iters: 1230, Total loss: 0.080, Loss x: 0.081, Loss s: 0.151, Loss w_fp: 0.010, Mask ratio: 0.756
[2023-05-23 09:23:17,892][    INFO] Iters: 1845, Total loss: 0.077, Loss x: 0.075, Loss s: 0.150, Loss w_fp: 0.010, Mask ratio: 0.761
[2023-05-23 09:30:19,743][    INFO] Iters: 2460, Total loss: 0.076, Loss x: 0.074, Loss s: 0.147, L

[2023-05-23 11:55:57,332][    INFO] ***** Evaluation ***** >>>> Class [0 background] IoU: 88.21
[2023-05-23 11:55:57,333][    INFO] ***** Evaluation ***** >>>> Class [1 aeroplane] IoU: 82.34
[2023-05-23 11:55:57,333][    INFO] ***** Evaluation ***** >>>> Class [2 bicycle] IoU: 58.15
[2023-05-23 11:55:57,334][    INFO] ***** Evaluation ***** >>>> Class [3 bird] IoU: 70.99
[2023-05-23 11:55:57,334][    INFO] ***** Evaluation ***** >>>> Class [4 boat] IoU: 67.62
[2023-05-23 11:55:57,335][    INFO] ***** Evaluation ***** >>>> Class [5 bottle] IoU: 65.09
[2023-05-23 11:55:57,335][    INFO] ***** Evaluation ***** >>>> Class [6 bus] IoU: 65.53
[2023-05-23 11:55:57,335][    INFO] ***** Evaluation ***** >>>> Class [7 car] IoU: 60.60
[2023-05-23 11:55:57,336][    INFO] ***** Evaluation ***** >>>> Class [8 cat] IoU: 46.28
[2023-05-23 11:55:57,336][    INFO] ***** Evaluation ***** >>>> Class [9 chair] IoU: 28.15
[2023-05-23 11:55:57,337][    INFO] ***** Evaluation ***** >>>> Class [10 cow] IoU: 64

[2023-05-23 13:52:22,254][    INFO] ***** Evaluation ***** >>>> Class [17 sheep] IoU: 55.27
[2023-05-23 13:52:22,254][    INFO] ***** Evaluation ***** >>>> Class [18 sofa] IoU: 44.29
[2023-05-23 13:52:22,255][    INFO] ***** Evaluation ***** >>>> Class [19 train] IoU: 27.01
[2023-05-23 13:52:22,255][    INFO] ***** Evaluation ***** >>>> Class [20 tv/monitor] IoU: 52.25
[2023-05-23 13:52:22,255][    INFO] ***** Evaluation original ***** >>>> MeanIoU: 58.79

[2023-05-23 13:52:27,062][    INFO] Iters: 0, Total loss: 0.089, Loss x: 0.136, Loss s: 0.078, Loss w_fp: 0.007, Mask ratio: 0.565
[2023-05-23 13:59:31,001][    INFO] Iters: 615, Total loss: 0.060, Loss x: 0.050, Loss s: 0.132, Loss w_fp: 0.008, Mask ratio: 0.794
[2023-05-23 14:06:35,099][    INFO] Iters: 1230, Total loss: 0.060, Loss x: 0.050, Loss s: 0.131, Loss w_fp: 0.008, Mask ratio: 0.799
[2023-05-23 14:13:39,311][    INFO] Iters: 1845, Total loss: 0.059, Loss x: 0.051, Loss s: 0.128, Loss w_fp: 0.008, Mask ratio: 0.801
[2023-0

In [12]:
%load_ext tensorboard

In [16]:
%tensorboard --logdir $save_path