In [2]:
import os
import sys
import random
import logging
import argparse
import datetime
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
import matplotlib.pyplot as plt

sys.path.append(".")



import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler


from wetr.PAR import PAR
from datasets import coco
from utils import evaluate, imutils
from wetr.model_attn_aff import WeTr
from utils.losses import get_aff_loss
from utils.optimizer import PolyWarmupAdamW
from utils.AverageMeter import AverageMeter
from utils.camutils import (cam_to_label, cams_to_affinity_label, ignore_img_box,
                            multi_scale_cam, multi_scale_cam_with_aff_mat,
                            propagte_aff_cam_with_bkg, refine_cams_with_bkg_v2,
                            refine_cams_with_cls_label)


In [7]:
os.getcwd()

'/home/zephyr/Desktop/Newcastle_University/11_FP_D/code/AFA/afa'

In [11]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument("--config",
                    default='configs/coco_attn_reg.yaml',
                    type=str,
                    help="config")
parser.add_argument("--pooling", default="gmp", type=str, help="pooling method")
parser.add_argument("--seg_detach", action="store_true", help="detach seg")
parser.add_argument("--work_dir", default=None, type=str, help="work_dir")
parser.add_argument("--local_rank", default=-1, type=int, help="local_rank")
parser.add_argument("--radius", default=8, type=int, help="radius")
parser.add_argument("--crop_size", default=224, type=int, help="crop_size")
parser.add_argument('--backend', default='nccl')

_StoreAction(option_strings=['--backend'], dest='backend', nargs=None, const=None, default='nccl', type=None, choices=None, help=None, metavar=None)

In [4]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [5]:
def get_down_size(ori_shape=(224,224), stride=16):
    h, w = ori_shape
    _h = h // stride + 1 - ((h % stride) == 0)
    _w = w // stride + 1 - ((w % stride) == 0)
    return _h, _w

In [8]:
print(torch.cuda.set_device(-1))

None


In [None]:
train_dataset = coco.CocoClsDataset(
    root_dir=cfg.dataset.root_dir,
    name_list_dir=cfg.dataset.name_list_dir,
    split=cfg.train.split,
    stage='train',
    aug=True,
    resize_range=cfg.dataset.resize_range,
    rescale_range=cfg.dataset.rescale_range,
    crop_size=cfg.dataset.crop_size,
    img_fliplr=True,
    ignore_index=cfg.dataset.ignore_index,
    num_classes=cfg.dataset.num_classes,
    )
    
val_dataset = coco.CocoSegDataset(
    root_dir=cfg.dataset.root_dir,
    name_list_dir=cfg.dataset.name_list_dir,
    split=cfg.val.split,
    stage='val',
    aug=False,
    ignore_index=cfg.dataset.ignore_index,
    num_classes=cfg.dataset.num_classes,
    )

In [None]:
train_sampler = DistributedSampler(train_dataset,shuffle=True)
train_loader = DataLoader(train_dataset,
                          batch_size=cfg.train.samples_per_gpu,
                          #shuffle=True,
                          num_workers=num_workers,
                          pin_memory=False,
                          drop_last=True,
                          sampler=train_sampler,
                          prefetch_factor=4)

val_loader = DataLoader(val_dataset,
                        batch_size=1,
                        shuffle=False,
                        num_workers=num_workers,
                        pin_memory=False,
                        drop_last=False)

In [None]:
device = torch.device(args.local_rank)

In [None]:
dist.init_process_group(backend='nccl',)

In [19]:
wetr = WeTr(backbone='mit_b1',
            stride=[4, 2, 2, 1],
            num_classes=3,
            embedding_dim=256,
            pretrained=False,
            pooling='gmp',
           )

param_groups = wetr.get_param_groups()
par = PAR(num_iter=15, dilations=[1,2,4,8,12,24])

wetr.to(device)
par.to(device)

PAR()

In [21]:
wetr(torch.randn(1, 3, 224, 224), cam_only=True)

ValueError: SyncBatchNorm expected input tensor to be on GPU

In [None]:
def train():

    num_workers = 10

    mask_size = int(cfg.dataset.crop_size // 16)
    infer_size = int((224 * max([1, 0.5, 1.5])) // 16)
    attn_mask = get_mask_by_radius(h=mask_size, w=mask_size, radius=args.radius)
    attn_mask_infer = get_mask_by_radius(h=infer_size, w=infer_size, radius=args.radius)
    
    optimizer = PolyWarmupAdamW(
        params=[
            {
                "params": param_groups[0],
                "lr": 6e-5,
                "weight_decay": 0.01,
            },
            {
                "params": param_groups[1],
                "lr": 0.0, ## freeze norm layers
                "weight_decay": 0.0,
            },
            {
                "params": param_groups[2],
                "lr": 6e-5*10,
                "weight_decay": 0.01,
            },
            {
                "params": param_groups[3],
                "lr": 6e-5*10,
                "weight_decay": 0.01,
            },
        ],
        lr = 6e-5,
        weight_decay = 0.01,
        betas = [0.9, 0.999],
        warmup_iter = 1500,
        max_iter = 80000,
        warmup_ratio = 1e-6,
        power = 1.0
    )
    
    #wetr = DistributedDataParallel(wetr, device_ids=[args.local_rank], find_unused_parameters=True)
    # loss_layer = DenseEnergyLoss(weight=1e-7, sigma_rgb=15, sigma_xy=100, scale_factor=0.5)
    train_sampler.set_epoch(np.random.randint(80000))
    train_loader_iter = iter(train_loader)

    #for n_iter in tqdm(range(cfg.train.max_iters), total=cfg.train.max_iters, dynamic_ncols=True):
    avg_meter = AverageMeter()

    bkg_cls = torch.ones(size=(2, 1))
    
    img_box - None

    for n_iter in range(80000):
        
        try:
            #img_name, inputs, cls_labels, img_box = next(train_loader_iter)
            inputs, cls_labels = next(train_loader_iter)
        except:
            train_sampler.set_epoch(np.random.randint(80000))
            train_loader_iter = iter(train_loader)
            img_name, inputs, cls_labels, img_box = next(train_loader_iter)
        
        inputs = inputs.to(device, non_blocking=True)
        inputs_denorm = imutils.denormalize_img2(inputs.clone())
        cls_labels = cls_labels.to(device, non_blocking=True)
        
        cls, segs, attns, attn_pred = wetr(inputs, seg_detach=args.seg_detach)

        cams, aff_mat = multi_scale_cam_with_aff_mat(wetr, inputs=inputs, scales=cfg.cam.scales)
        valid_cam, pseudo_label = cam_to_label(cams.detach(), cls_label=cls_labels, img_box=img_box, ignore_mid=True, cfg=cfg)

        ######################
        valid_cam_resized = F.interpolate(valid_cam, size=(infer_size, infer_size), mode='bilinear', align_corners=False)

        aff_cam_l = propagte_aff_cam_with_bkg(valid_cam_resized, aff=aff_mat.detach().clone(), mask=attn_mask_infer, cls_labels=cls_labels, bkg_score=cfg.cam.low_thre)
        aff_cam_l = F.interpolate(aff_cam_l, size=pseudo_label.shape[1:], mode='bilinear', align_corners=False)
        aff_cam_h = propagte_aff_cam_with_bkg(valid_cam_resized, aff=aff_mat.detach().clone(), mask=attn_mask_infer, cls_labels=cls_labels, bkg_score=cfg.cam.high_thre)
        aff_cam_h = F.interpolate(aff_cam_h, size=pseudo_label.shape[1:], mode='bilinear', align_corners=False)

        
        bkg_cls = bkg_cls.to(cams.device)
        _cls_labels = torch.cat((bkg_cls, cls_labels), dim=1)

        refined_aff_cam_l = refine_cams_with_cls_label(par, inputs_denorm, cams=aff_cam_l, labels=_cls_labels, img_box=img_box)
        refined_aff_label_l = refined_aff_cam_l.argmax(dim=1)
        refined_aff_cam_h = refine_cams_with_cls_label(par, inputs_denorm, cams=aff_cam_h, labels=_cls_labels, img_box=img_box)
        refined_aff_label_h = refined_aff_cam_h.argmax(dim=1)

        aff_cam = aff_cam_l[:,1:]
        refined_aff_cam = refined_aff_cam_l[:,1:,]
        refined_aff_label = refined_aff_label_h.clone()
        refined_aff_label[refined_aff_label_h == 0] = cfg.dataset.ignore_index
        refined_aff_label[(refined_aff_label_h + refined_aff_label_l) == 0] = 0
        refined_aff_label = ignore_img_box(refined_aff_label, img_box=img_box, ignore_index=cfg.dataset.ignore_index)
        ######################

        refined_pseudo_label = refine_cams_with_bkg_v2(par, inputs_denorm, cams=cams, cls_labels=cls_labels, cfg=cfg, img_box=img_box)

        if n_iter <= 15000:
            refined_aff_label = refined_pseudo_label

        aff_label = cams_to_affinity_label(refined_aff_label, mask=attn_mask, ignore_index=cfg.dataset.ignore_index)
        aff_loss, pos_count, neg_count = get_aff_loss(attn_pred, aff_label)

        segs = F.interpolate(segs, size=refined_pseudo_label.shape[1:], mode='bilinear', align_corners=False)

        seg_loss = get_seg_loss(segs, refined_aff_label.type(torch.long), ignore_index=cfg.dataset.ignore_index)
        #reg_loss = get_energy_loss(img=inputs, logit=segs, label=refined_aff_label, img_box=img_box, loss_layer=loss_layer)
        #seg_loss = F.cross_entropy(segs, pseudo_label.type(torch.long), ignore_index=cfg.dataset.ignore_index)
        cls_loss = F.multilabel_soft_margin_loss(cls, cls_labels)
        
        if n_iter <= cfg.train.cam_iters:
            loss = 1.0 * cls_loss + 0.0 * seg_loss + 0.0 * aff_loss# + 0.0 * reg_loss
        else: 
            loss = 1.0 * cls_loss + 0.1 * seg_loss + 0.1 * aff_loss# + 0.01 * reg_loss


        avg_meter.add({'cls_loss': cls_loss.item(), 'seg_loss': seg_loss.item(), 'aff_loss': aff_loss.item()})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    return True