# ICS 504 Project

In [9]:
import logging
import os
import pprint

import torch
from torch import nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import yaml

from dataset.semi import SemiDataset
from model.semseg.deeplabv3plus import DeepLabV3Plus
from no_distrib import evaluate
from util.classes import CLASSES
from util.ohem import ProbOhemCrossEntropy2d
from util.utils import count_params, init_log, AverageMeter
from util.dist_helper import setup_distributed

In [10]:
# Experiment identifiers: dataset name, training method, experiment tag and
# labeled-split name. Every path below is derived from them.
dataset, method, exp, split = 'pascal', 'unimatch', 'r101', '732'

# Device index and port. In the cells visible here, `port` is referenced only
# by the disabled distributed setup call.
local_rank, port = 0, 1202

# YAML config for the dataset, id lists for the chosen split, and the output
# directory for this run.
config = f"configs/{dataset}.yaml"
save_path = f"exp/{dataset}/{method}/{exp}/{split}"
labeled_id_path = f"splits/{dataset}/{split}/labeled.txt"
unlabeled_id_path = f"splits/{dataset}/{split}/unlabeled.txt"

In [11]:
# Read the experiment configuration. The context manager closes the file
# handle as soon as parsing is done.
# NOTE(review): yaml.Loader can construct arbitrary Python objects, so only
# load config files you trust; consider yaml.safe_load if the configs contain
# plain scalars, lists and mappings only.
with open(config, "r") as config_file:
    cfg = yaml.load(config_file, Loader=yaml.Loader)
print(cfg)

{'dataset': 'pascal', 'nclass': 21, 'crop_size': 321, 'data_root': 'E:\\ICS_504\\project_datasets\\Pascal\\', 'epochs': 2, 'batch_size': 2, 'lr': 0.001, 'lr_multi': 10.0, 'criterion': {'name': 'CELoss', 'kwargs': {'ignore_index': 255}}, 'conf_thresh': 0.95, 'model': 'deeplabv3plus', 'backbone': 'resnet101', 'replace_stride_with_dilation': [False, False, True], 'dilations': [6, 12, 18]}


In [12]:
# Logger named 'global' at INFO level. Propagation is switched off so records
# are not passed on to the handlers of ancestor loggers.
logger = init_log('global', logging.INFO)
logger.propagate = False

In [13]:
# Single-process run: the distributed setup (setup_distributed) is skipped.

# Enable cuDNN and let it benchmark its kernels for the input sizes it sees.
cudnn.enabled = True
cudnn.benchmark = True

# The SummaryWriter logs into save_path; create the directory first.
os.makedirs(save_path, exist_ok=True)
writer = SummaryWriter(save_path)

In [14]:
model = DeepLabV3Plus(cfg)

# Two parameter groups: the backbone trains at the base learning rate, every
# parameter whose name does not contain 'backbone' trains at the base rate
# scaled by cfg['lr_multi'].
base_lr = cfg['lr']
backbone_params = list(model.backbone.parameters())
head_params = [p for n, p in model.named_parameters() if 'backbone' not in n]
optimizer = SGD(
    [
        {'params': backbone_params, 'lr': base_lr},
        {'params': head_params, 'lr': base_lr * cfg['lr_multi']},
    ],
    lr=base_lr,
    momentum=0.9,
    weight_decay=1e-4,
)

# Single-GPU run: SyncBatchNorm conversion and DistributedDataParallel
# wrapping are not applied; the model is simply moved to the GPU.
model.cuda()

logger.info('Total params: {:.1f}M\n'.format(count_params(model)))

[2023-05-21 18:34:15,538][    INFO] Total params: 59.5M



In [15]:
# Supervised criterion: looked up by the name given in the config and built
# with the configured keyword arguments.
criterion_classes = {
    'CELoss': nn.CrossEntropyLoss,
    'OHEM': ProbOhemCrossEntropy2d,
}
criterion_name = cfg['criterion']['name']
if criterion_name not in criterion_classes:
    raise NotImplementedError('%s criterion is not implemented' % criterion_name)
criterion_l = criterion_classes[criterion_name](**cfg['criterion']['kwargs']).cuda(local_rank)

# Unsupervised criterion: reduction='none' keeps the per-element losses so
# they can be masked before averaging.
criterion_u = nn.CrossEntropyLoss(reduction='none').cuda(local_rank)

In [16]:
# The unlabeled set is built first because its id count is passed to the
# labeled set as `nsample`.
# NOTE(review): presumably nsample makes the labeled set as long as the
# unlabeled one so the zipped loaders stay in step - confirm in SemiDataset.
trainset_u = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_u',
                            cfg['crop_size'], unlabeled_id_path)
trainset_l = SemiDataset(cfg['dataset'], cfg['data_root'], 'train_l',
                            cfg['crop_size'], labeled_id_path, nsample=len(trainset_u.ids))
valset = SemiDataset(cfg['dataset'], cfg['data_root'], 'val')

# Training data must be shuffled. A sequential sampler would feed SGD the
# samples in the same fixed order every epoch, and the two iterators the
# training loop draws from trainloader_u would walk the same indices in
# lock-step, so each batch would be CutMixed with the very same samples.
# RandomSampler draws a fresh permutation for every iterator.
trainsampler_l = torch.utils.data.RandomSampler(trainset_l)
trainloader_l = DataLoader(trainset_l, batch_size=cfg['batch_size'],
                            pin_memory=True, num_workers=1, drop_last=True, sampler=trainsampler_l)
trainsampler_u = torch.utils.data.RandomSampler(trainset_u)
trainloader_u = DataLoader(trainset_u, batch_size=cfg['batch_size'],
                            pin_memory=True, num_workers=1, drop_last=True, sampler=trainsampler_u)

# Validation keeps a fixed order, one image at a time, and never drops the
# last (possibly partial) batch.
valsampler = torch.utils.data.SequentialSampler(valset)
valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=1,
                        drop_last=False, sampler=valsampler)


In [17]:
# The training loop starts at `epoch + 1`, so -1 means "begin with epoch 0".
epoch = -1
# Best validation mIoU seen so far.
previous_best = 0.0

# One iteration per unlabeled batch, for every configured epoch.
total_iters = cfg['epochs'] * len(trainloader_u)

logger.info(f"Total iters: {total_iters}")

[2023-05-21 18:34:15,568][    INFO] Total iters: 9850


In [18]:
# Resume from the most recent checkpoint in save_path, if there is one:
# restore model and optimizer state plus the epoch / best-score bookkeeping.
latest_ckpt_path = os.path.join(save_path, 'latest.pth')
if os.path.exists(latest_ckpt_path):
    checkpoint = torch.load(latest_ckpt_path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']
    previous_best = checkpoint['previous_best']

    logger.info('************ Load from checkpoint at epoch %i\n' % epoch)

In [19]:
def _masked_unsup_loss(pred, target, conf, ignore, conf_thresh):
    """Average the per-pixel unsupervised loss over the non-ignored pixels.

    A pixel contributes only where ``conf >= conf_thresh`` and
    ``ignore != 255``. The masked sum is divided by the number of pixels
    with ``ignore != 255``; that divisor is clamped to 1 so a batch in which
    every pixel is ignored yields a zero loss instead of NaN.

    Uses the notebook-level ``criterion_u`` (per-element cross entropy).
    """
    valid = ignore != 255
    keep = (conf >= conf_thresh) & valid
    return (criterion_u(pred, target) * keep).sum() / max(valid.sum().item(), 1)


iters_per_epoch = len(trainloader_u)
# Log about eight times per epoch. Clamp to 1 so a loader with fewer than
# eight batches does not cause a modulo-by-zero.
log_interval = max(iters_per_epoch // 8, 1)

for epoch in range(epoch + 1, cfg['epochs']):
    logger.info('===========> Epoch: {:}, LR: {:.5f}, Previous best: {:.2f}'.format(
        epoch, optimizer.param_groups[0]['lr'], previous_best))

    # Running averages for the periodic progress log.
    total_loss = AverageMeter()
    total_loss_x = AverageMeter()
    total_loss_s = AverageMeter()
    total_loss_w_fp = AverageMeter()
    total_mask_ratio = AverageMeter()

    # Each step consumes one labeled batch, one unlabeled batch, and a second
    # unlabeled batch (second iterator over the same loader) that provides the
    # CutMix partner.
    loader = zip(trainloader_l, trainloader_u, trainloader_u)

    for i, ((img_x, mask_x),
            (img_u_w, img_u_s1, img_u_s2, ignore_mask, cutmix_box1, cutmix_box2),
            (img_u_w_mix, img_u_s1_mix, img_u_s2_mix, ignore_mask_mix, _, _)) in enumerate(loader):

        img_x, mask_x = img_x.cuda(), mask_x.cuda()
        img_u_w = img_u_w.cuda()
        img_u_s1, img_u_s2, ignore_mask = img_u_s1.cuda(), img_u_s2.cuda(), ignore_mask.cuda()
        cutmix_box1, cutmix_box2 = cutmix_box1.cuda(), cutmix_box2.cuda()
        img_u_w_mix = img_u_w_mix.cuda()
        img_u_s1_mix, img_u_s2_mix = img_u_s1_mix.cuda(), img_u_s2_mix.cuda()
        ignore_mask_mix = ignore_mask_mix.cuda()

        # Pseudo-labels and confidences for the CutMix partner, computed in
        # eval mode and without gradients.
        with torch.no_grad():
            model.eval()

            pred_u_w_mix = model(img_u_w_mix).detach()
            conf_u_w_mix = pred_u_w_mix.softmax(dim=1).max(dim=1)[0]
            mask_u_w_mix = pred_u_w_mix.argmax(dim=1)

        # CutMix on the images: where a box equals 1, paste the pixels of the
        # partner image (in place).
        img_u_s1[cutmix_box1.unsqueeze(1).expand(img_u_s1.shape) == 1] = \
            img_u_s1_mix[cutmix_box1.unsqueeze(1).expand(img_u_s1.shape) == 1]
        img_u_s2[cutmix_box2.unsqueeze(1).expand(img_u_s2.shape) == 1] = \
            img_u_s2_mix[cutmix_box2.unsqueeze(1).expand(img_u_s2.shape) == 1]

        model.train()

        num_lb, num_ulb = img_x.shape[0], img_u_w.shape[0]

        # One forward pass over labeled and unlabeled images together. With
        # the second argument True the model returns two outputs.
        # NOTE(review): preds_fp is presumably the feature-perturbed
        # prediction - confirm in DeepLabV3Plus.forward.
        preds, preds_fp = model(torch.cat((img_x, img_u_w)), True)
        pred_x, pred_u_w = preds.split([num_lb, num_ulb])
        pred_u_w_fp = preds_fp[num_lb:]

        pred_u_s1, pred_u_s2 = model(torch.cat((img_u_s1, img_u_s2))).chunk(2)

        # Pseudo-labels of the first unlabeled batch; detached so no gradient
        # flows through the targets.
        pred_u_w = pred_u_w.detach()
        conf_u_w = pred_u_w.softmax(dim=1).max(dim=1)[0]
        mask_u_w = pred_u_w.argmax(dim=1)

        # Apply the same CutMix boxes to pseudo-labels, confidences and ignore
        # masks so the targets line up with the mixed images.
        mask_u_w_cutmixed1, conf_u_w_cutmixed1, ignore_mask_cutmixed1 = \
            mask_u_w.clone(), conf_u_w.clone(), ignore_mask.clone()
        mask_u_w_cutmixed2, conf_u_w_cutmixed2, ignore_mask_cutmixed2 = \
            mask_u_w.clone(), conf_u_w.clone(), ignore_mask.clone()

        mask_u_w_cutmixed1[cutmix_box1 == 1] = mask_u_w_mix[cutmix_box1 == 1]
        conf_u_w_cutmixed1[cutmix_box1 == 1] = conf_u_w_mix[cutmix_box1 == 1]
        ignore_mask_cutmixed1[cutmix_box1 == 1] = ignore_mask_mix[cutmix_box1 == 1]

        mask_u_w_cutmixed2[cutmix_box2 == 1] = mask_u_w_mix[cutmix_box2 == 1]
        conf_u_w_cutmixed2[cutmix_box2 == 1] = conf_u_w_mix[cutmix_box2 == 1]
        ignore_mask_cutmixed2[cutmix_box2 == 1] = ignore_mask_mix[cutmix_box2 == 1]

        loss_x = criterion_l(pred_x, mask_x)

        conf_thresh = cfg['conf_thresh']
        loss_u_s1 = _masked_unsup_loss(pred_u_s1, mask_u_w_cutmixed1,
                                       conf_u_w_cutmixed1, ignore_mask_cutmixed1, conf_thresh)
        loss_u_s2 = _masked_unsup_loss(pred_u_s2, mask_u_w_cutmixed2,
                                       conf_u_w_cutmixed2, ignore_mask_cutmixed2, conf_thresh)
        loss_u_w_fp = _masked_unsup_loss(pred_u_w_fp, mask_u_w,
                                         conf_u_w, ignore_mask, conf_thresh)

        # The supervised term and the unsupervised terms (0.25 + 0.25 + 0.5)
        # each carry half of the total loss.
        loss = (loss_x + loss_u_s1 * 0.25 + loss_u_s2 * 0.25 + loss_u_w_fp * 0.5) / 2.0

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss.update(loss.item())
        total_loss_x.update(loss_x.item())
        total_loss_s.update((loss_u_s1.item() + loss_u_s2.item()) / 2.0)
        total_loss_w_fp.update(loss_u_w_fp.item())

        # Fraction of non-ignored pixels whose pseudo-label confidence reaches
        # the threshold. Kept as a plain Python float for the meter and for
        # TensorBoard; the divisor is clamped so an all-ignored batch gives 0.
        valid_u = ignore_mask != 255
        mask_ratio = ((conf_u_w >= conf_thresh) & valid_u).sum().item() / \
            max(valid_u.sum().item(), 1)
        total_mask_ratio.update(mask_ratio)

        # Polynomial learning-rate decay (power 0.9) over all iterations. It
        # is set after the optimizer step, so it takes effect on the next one.
        iters = epoch * iters_per_epoch + i
        lr = cfg['lr'] * (1 - iters / total_iters) ** 0.9
        optimizer.param_groups[0]["lr"] = lr
        optimizer.param_groups[1]["lr"] = lr * cfg['lr_multi']

        writer.add_scalar('train/loss_all', loss.item(), iters)
        writer.add_scalar('train/loss_x', loss_x.item(), iters)
        writer.add_scalar('train/loss_s', (loss_u_s1.item() + loss_u_s2.item()) / 2.0, iters)
        writer.add_scalar('train/loss_w_fp', loss_u_w_fp.item(), iters)
        writer.add_scalar('train/mask_ratio', mask_ratio, iters)

        if i % log_interval == 0:
            logger.info('Iters: {:}, Total loss: {:.3f}, Loss x: {:.3f}, Loss s: {:.3f}, Loss w_fp: {:.3f}, Mask ratio: '
                        '{:.3f}'.format(i, total_loss.avg, total_loss_x.avg, total_loss_s.avg,
                                        total_loss_w_fp.avg, total_mask_ratio.avg))

    # Validate after every epoch; only cityscapes uses sliding-window mode.
    eval_mode = 'sliding_window' if cfg['dataset'] == 'cityscapes' else 'original'
    mIoU, iou_class = evaluate(model, valloader, eval_mode, cfg)

    for (cls_idx, iou) in enumerate(iou_class):
        logger.info('***** Evaluation ***** >>>> Class [{:} {:}] '
                    'IoU: {:.2f}'.format(cls_idx, CLASSES[cfg['dataset']][cls_idx], iou))
    logger.info('***** Evaluation {} ***** >>>> MeanIoU: {:.2f}\n'.format(eval_mode, mIoU))

    writer.add_scalar('eval/mIoU', mIoU, epoch)
    for cls_idx, iou in enumerate(iou_class):
        writer.add_scalar('eval/%s_IoU' % (CLASSES[cfg['dataset']][cls_idx]), iou, epoch)

    is_best = mIoU > previous_best
    previous_best = max(mIoU, previous_best)

    # Always refresh latest.pth (used for resuming); also write best.pth when
    # this epoch improved the best mIoU.
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'previous_best': previous_best,
    }
    torch.save(checkpoint, os.path.join(save_path, 'latest.pth'))
    if is_best:
        torch.save(checkpoint, os.path.join(save_path, 'best.pth'))

[2023-05-21 18:34:33,722][    INFO] Iters: 0, Total loss: 1.422, Loss x: 2.843, Loss s: 0.000, Loss w_fp: 0.000, Mask ratio: 0.000
[2023-05-21 18:40:25,944][    INFO] Iters: 615, Total loss: 0.637, Loss x: 1.237, Loss s: 0.065, Loss w_fp: 0.009, Mask ratio: 0.246
[2023-05-21 18:46:28,559][    INFO] Iters: 1230, Total loss: 0.569, Loss x: 1.092, Loss s: 0.084, Loss w_fp: 0.008, Mask ratio: 0.269
[2023-05-21 18:52:30,206][    INFO] Iters: 1845, Total loss: 0.517, Loss x: 0.984, Loss s: 0.093, Loss w_fp: 0.008, Mask ratio: 0.289
[2023-05-21 18:58:33,647][    INFO] Iters: 2460, Total loss: 0.487, Loss x: 0.920, Loss s: 0.100, Loss w_fp: 0.008, Mask ratio: 0.305
[2023-05-21 19:04:40,866][    INFO] Iters: 3075, Total loss: 0.457, Loss x: 0.856, Loss s: 0.109, Loss w_fp: 0.008, Mask ratio: 0.324
[2023-05-21 19:10:48,802][    INFO] Iters: 3690, Total loss: 0.431, Loss x: 0.798, Loss s: 0.118, Loss w_fp: 0.009, Mask ratio: 0.340
[2023-05-21 19:16:51,025][    INFO] Iters: 4305, Total loss: 0.408