In [1]:
import argparse
import os
import datetime
import logging
import time

import torch
import torch.nn as nn
import torch.utils
import torch.distributed
from torch.utils.data import DataLoader
import multiprocessing

import numpy as np

from core.configs import cfg
from core.datasets import build_dataset
from core.models import build_feature_extractor, build_classifier
from core.solver import adjust_learning_rate
from core.utils.misc import mkdir
from core.utils.logger import setup_logger
from core.utils.metric_logger import MetricLogger
from core.active.build import PixelSelection, RegionSelection
from core.datasets.dataset_path_catalog import DatasetCatalog
from core.loss.negative_learning_loss import NegativeLearningLoss
from core.loss.local_consistent_loss import LocalConsistentLoss
from core.utils.utils import set_random_seed
from torch.utils.tensorboard import SummaryWriter

import warnings
warnings.filterwarnings('ignore')

In [2]:
os.system("nvidia-smi")

Tue Oct 18 02:09:31 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  On   | 00000000:C3:00.0 Off |                    0 |
| N/A   32C    P0    33W / 250W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

0

In [3]:
torch.cuda.is_available()

True

In [4]:
parser = argparse.ArgumentParser(description="Active Domain Adaptive Semantic Segmentation Training")
parser.add_argument("-cfg",
                    "--config-file",
                    default="",
                    metavar="FILE",
                    help="path to config file",
                    type=str)
parser.add_argument("--proctitle",
                    type=str,
                    default="RCL-AAA",
                    help="allow a process to change its title",)
parser.add_argument(
    "opts",
    help="Modify config options using the command-line",
    default=None,
    nargs=argparse.REMAINDER
)

# args = parser.parse_args()

args = parser.parse_args(args=['-cfg', 'configs/gtav/deeplabv3plus_r101_RA.yaml','--proctitle', 'RCL-AAA'])
args.opts = ['OUTPUT_DIR', 'results/v3plus_gtav_ra_5.0_precent', 'DEBUG', '0']

if args.opts is not None:
    args.opts[-1] = args.opts[-1].strip('\r\n')

torch.backends.cudnn.benchmark = True

cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()

output_dir = cfg.OUTPUT_DIR
if output_dir:
    mkdir(output_dir)

logger = setup_logger("RCL-AAA", output_dir, 0)
logger.info(args)

logger.info("Loaded configuration file {}".format(args.config_file))
logger.info("Running with config:\n{}".format(cfg))

logger.info('Initializing Cityscapes label mask...')

set_random_seed(cfg.SEED)

2022-10-18 02:09:32,376 RCL-AAA INFO: Namespace(config_file='configs/gtav/deeplabv3plus_r101_RA.yaml', opts=['OUTPUT_DIR', 'results/v3plus_gtav_ra_5.0_precent', 'DEBUG', '0'], proctitle='RCL-AAA')
2022-10-18 02:09:32,393 RCL-AAA INFO: Loaded configuration file configs/gtav/deeplabv3plus_r101_RA.yaml
2022-10-18 02:09:32,394 RCL-AAA INFO: Running with config:
ACTIVE:
  NAME: RCL-AAA
  PIXELS: 40
  RADIUS_K: 1
  RATIO: 0.05
  SELECT_ITER: [10000, 12000, 14000, 16000, 18000]
  SETTING: RA
DATASETS:
  SOURCE_TRAIN: gtav_train
  TARGET_TRAIN: cityscapes_train
  TEST: cityscapes_val
DEBUG: 0
INPUT:
  IGNORE_LABEL: 255
  INPUT_SCALES_TRAIN: (1.0, 1.0)
  INPUT_SIZE_TEST: (1280, 640)
  PIXEL_MEAN: [0.485, 0.456, 0.406]
  PIXEL_STD: [0.229, 0.224, 0.225]
  SOURCE_INPUT_SIZE_TRAIN: (1280, 720)
  TARGET_INPUT_SIZE_TRAIN: (1280, 640)
  TO_BGR255: False
LOSS:
  NEG_WEIGHT_INCRE_STEP: 0.02
  NUMPARTS_H: 2
  NUMPARTS_W: 4
  POS_WEIGHT_INCRE_STEP: 0.02
  TEMPERATURE: 0.05
MODEL:
  DEVICE: cuda
  FREEZE_

In [5]:
print("Here is {} CPU, {} GPU.".format(multiprocessing.cpu_count(), torch.cuda.device_count()))
logger = logging.getLogger("RCL-AAA.trainer")
# tb_writer = SummaryWriter('./{}_{}_tensorboard_log_{}_confidence'.format(cfg.DATASETS.SOURCE_TRAIN.split('_')[0], cfg.DATASETS.TARGET_TRAIN.split('_')[0], cfg.CONFIDENCE.WEIGHT))
# print('Tensorboard writer log has been created at {}'.format('./{}_{}_tensorboard_log_{}_confidence'.format(cfg.DATASETS.SOURCE_TRAIN.split('_')[0], cfg.DATASETS.TARGET_TRAIN.split('_')[0], cfg.CONFIDENCE.WEIGHT)))

# create network
device = torch.device(cfg.MODEL.DEVICE)
feature_extractor = build_feature_extractor(cfg)
#feature_extractor = nn.DataParallel(feature_extractor)
feature_extractor.to(device)

classifier = build_classifier(cfg)
#classifier = nn.DataParallel(classifier)
classifier.to(device)
print()

Here is 256 CPU, 1 GPU.
load checkpoint from http path: https://download.pytorch.org/models/resnet101-5d3b4d8f.pth



In [6]:
# init optimizer
optimizer_fea = torch.optim.SGD(feature_extractor.parameters(), lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)
optimizer_fea.zero_grad()

optimizer_cls = torch.optim.SGD(classifier.parameters(), lr=cfg.SOLVER.BASE_LR * 10, momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)
optimizer_cls.zero_grad()

iteration = 0

In [7]:
# # load checkpoint
# if cfg.resume:
#     logger.info("Loading checkpoint from {}".format(cfg.resume))
#     checkpoint = torch.load(cfg.OUTPUT_DIR + '/' + cfg.resume, map_location=torch.device('cpu'))
#     iteration = checkpoint['iteration']
#     feature_extractor.load_state_dict(checkpoint['feature_extractor'])
#     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
#     classifier.load_state_dict(checkpoint['classifier'])
#     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
# # feature_extractor = nn.DataParallel(feature_extractor)      # modified by CZC
# # classifier = nn.DataParallel(classifier)            # modified by CZC

In [8]:
# init mask for cityscape
# DatasetCatalog.initMask(cfg)

In [9]:
src_train_data = build_dataset(cfg, mode='train', is_source=True)
tgt_train_data = build_dataset(cfg, mode='train', is_source=False)
tgt_epoch_data = build_dataset(cfg, mode='active', is_source=False, epochwise=True)

src_train_loader = DataLoader(src_train_data, batch_size=2, shuffle=True, num_workers=4,
                              pin_memory=False, drop_last=True)
tgt_train_loader = DataLoader(tgt_train_data, batch_size=2, shuffle=True, num_workers=4,
                              pin_memory=False, drop_last=True)
tgt_epoch_loader = DataLoader(tgt_epoch_data, batch_size=1, shuffle=False, num_workers=4,
                              pin_memory=False, drop_last=True)

Compose(
    <core.datasets.transform.Resize object at 0x152b34922e80>
    <core.datasets.transform.ToTensor object at 0x152b34922d60>
    <core.datasets.transform.Normalize object at 0x152b34922d90>
)
Compose(
    <core.datasets.transform.Resize object at 0x152b54234730>
    <core.datasets.transform.ToTensor object at 0x152b34922fd0>
    <core.datasets.transform.Normalize object at 0x152b34922e20>
)
Compose(
    <core.datasets.transform.Resize object at 0x152b34a216d0>
    <core.datasets.transform.ToTensor object at 0x152b54db9880>
    <core.datasets.transform.Normalize object at 0x152b54234910>
)


In [10]:
torch.set_printoptions(threshold=np.inf)    # added by czc
# src_train_data[0]

In [11]:
print(src_train_data[0]['img'].size())
print(src_train_data[0]['label'].size())

torch.Size([3, 720, 1280])
torch.Size([720, 1280])


In [12]:
src_input, src_label = src_train_data[0]['img'], src_train_data[0]['label']
src_input = src_input.cuda(non_blocking=True)
src_label = src_label.cuda(non_blocking=True)

In [13]:
# init loss
sup_criterion = nn.CrossEntropyLoss(ignore_index=255)
# negative_criterion = NegativeLearningLoss(threshold=cfg.SOLVER.NEGATIVE_THRESHOLD)
# local_consistent_loss = LocalConsistentLoss(cfg.MODEL.NUM_CLASSES, cfg.SOLVER.LCR_TYPE).cuda()


start_warmup_time = time.time()
end = time.time()
max_iters = cfg.SOLVER.MAX_ITER
warmup_iters = 10000
meters = MetricLogger(delimiter="  ")

logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
feature_extractor.train()
classifier.train()
active_round = 1

2022-10-18 02:09:46,348 RCL-AAA.trainer INFO: >>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>


## Warm Up

In [14]:
iteration=0
for batch_index, (src_data, tgt_data) in enumerate(zip(src_train_loader, tgt_train_loader)):       
    data_time = time.time() - end

    current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR, iteration, max_iters,
                                      power=cfg.SOLVER.LR_POWER)
    # tb_writer.add_scalar(tag="lr", scalar_value=current_lr, global_step=iteration)      # added by czc
    for index in range(len(optimizer_fea.param_groups)):
        optimizer_fea.param_groups[index]['lr'] = current_lr
    for index in range(len(optimizer_cls.param_groups)):
        optimizer_cls.param_groups[index]['lr'] = current_lr * 10

    optimizer_fea.zero_grad()
    optimizer_cls.zero_grad()

    src_input, src_label = src_data['img'], src_data['label']
    src_input = src_input.cuda(non_blocking=True)
    src_label = src_label.cuda(non_blocking=True)
    # print(src_input.size())
    # print(src_label.size())
    # print(src_label.size()[0])

    tgt_input, tgt_mask = tgt_data['img'], tgt_data['mask']
    tgt_input = tgt_input.cuda(non_blocking=True)
    tgt_mask = tgt_mask.cuda(non_blocking=True)


    src_size = src_input.shape[-2:]
    # print(src_size)
    src_out = classifier(feature_extractor(src_input), size=src_size)
    
    tgt_size = tgt_input.shape[-2:]
    tgt_out = classifier(feature_extractor(tgt_input), size=tgt_size)
    predict = torch.softmax(tgt_out, dim=1)

    # target active supervision loss
    if torch.sum((tgt_mask != 255)) != 0:  # target has labeled pixels
        loss_sup_tgt = sup_criterion(tgt_out, tgt_mask)
        # meters.update(loss_sup_tgt=loss_sup_tgt.item())



    # predict = torch.softmax(src_out, dim=1)

    iteration += 1
    if iteration == 1:
        break

In [15]:
sup_criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
# predict[0,:,0:10,0:10]

In [16]:
tgt_label = torch.argmax(predict[:,:,:,:],dim=1)
print(tgt_label.size())

torch.Size([2, 640, 1280])


In [17]:
tgt_label = torch.argmax(predict[0,:,0:5,0:5],dim=0)
print(tgt_label)
print(tgt_label.size())
# print(tgt_out[0,0:3,0:2,0:2])
loss_sup_tgt = sup_criterion(tgt_out[0,0:19,0:5,0:5].permute(1,0,2), tgt_label)
print(loss_sup_tgt)
print(torch.topk(loss_sup_tgt, k=3, dim=1, largest=False))
print(torch.min(loss_sup_tgt))
print(torch.max(loss_sup_tgt))

tensor([[ 6,  6,  6,  6,  6],
        [ 6,  6,  6,  6,  6],
        [ 6,  6,  6,  6, 14],
        [ 6,  6,  6,  6, 14],
        [ 6,  6,  6,  6, 14]], device='cuda:0')
torch.Size([5, 5])
tensor([[2.3433, 2.2887, 2.2379, 2.1907, 2.1471],
        [2.3106, 2.2600, 2.2134, 2.1707, 2.1317],
        [2.2795, 2.2331, 2.1911, 2.1532, 2.1166],
        [2.2500, 2.2079, 2.1707, 2.1382, 2.0971],
        [2.2221, 2.1845, 2.1524, 2.1256, 2.0806]], device='cuda:0',
       grad_fn=<ViewBackward>)
torch.return_types.topk(
values=tensor([[2.1471, 2.1907, 2.2379],
        [2.1317, 2.1707, 2.2134],
        [2.1166, 2.1532, 2.1911],
        [2.0971, 2.1382, 2.1707],
        [2.0806, 2.1256, 2.1524]], device='cuda:0', grad_fn=<TopkBackward>),
indices=tensor([[4, 3, 2],
        [4, 3, 2],
        [4, 3, 2],
        [4, 3, 2],
        [4, 3, 2]], device='cuda:0'))
tensor(2.0806, device='cuda:0', grad_fn=<MinBackward1>)
tensor(2.3433, device='cuda:0', grad_fn=<MaxBackward1>)


### RegionSplit_CentroidCal函数编写

In [18]:
from collections import Counter

from cv2 import threshold
import math
h = 720
w = 1280
numparts_h = 1
numparts_w = 2
parts_h = int(h / numparts_h)
parts_w = int(w / numparts_w)
batch_size = 2
batch_centroids = {}        
is_source = False
tgt_centroids_base_ratio = 0.9999

if is_source:
    predict = torch.softmax(src_out, dim=1)
else:
    predict = torch.softmax(tgt_out, dim=1)
# batch_centroids:{
#                  'img_idx': {
#                               {}_{}': {
#                                         classID: {centroid}, classID: {centroid}, classID: {centroid}, ...
cross_entropy_computation = nn.CrossEntropyLoss(ignore_index=255, reduction='none')

import time
start = time.perf_counter()

tgt_label = torch.argmax(predict, dim=1)


for k in range(batch_size):
    batch_centroids[k] = {}     # k是batch里面img的序号
    for i in range(numparts_h):
        for j in range(numparts_w):
            batch_centroids[k]['{}_{}'.format(i,j)] = {} 
            
            # Get region coordinates 
            if [i,j] == [range(numparts_h)[-1], range(numparts_w)[-1]]:
                rg_id = [i*parts_h, h-1, j*parts_w, w-1]                 # rg_id: region_index
            else:
                rg_id = [i*parts_h, (i+1)*parts_h-1, j*parts_w, (j+1)*parts_w-1]
            # batch_centroids['img_idx_{}'.format(k)]['region_{}_{}'.format(i,j)]['rg_id'] = rg_id
            
            # Get all class ID in a single region
            if is_source == True:
                classID = dict(Counter(src_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].cpu().numpy().flatten()))     # 放到代码文件中得改一下
            else:
                classID = dict(Counter(tgt_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].cpu().numpy().flatten()))     # 放到代码文件中得改一下
            
            if classID.__contains__(255): del classID[255]
            print(classID)
            # batch_centroids['img_idx_{}'.format(k)]['region_{}_{}'.format(i,j)]['region_class']['classID'] = classID
            
            # Get all predict mean as centroids
            centroids = {}
            for key in classID:
                predict_sum = torch.zeros([1,19], requires_grad=True)
                predict_sum = predict_sum.cuda(non_blocking=True)
                if is_source == True:
                    mask = src_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].eq(key)
                else:
                    origin_tgt_mask = tgt_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].eq(key)       # 代表是这个类的

                    tgt_ce = cross_entropy_computation(tgt_out[k, :, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].permute(1,0,2), \
                                                        tgt_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]])
                    tgt_cls_uncertainty = tgt_ce * origin_tgt_mask      # 取最小的几个值注意可能会取到0
                    
                    unselected_sample_num = math.ceil(origin_tgt_mask.sum().item() * (1 - tgt_centroids_base_ratio))
                    unselected_samples, _ = torch.topk(torch.flatten(tgt_cls_uncertainty), k=unselected_sample_num, dim=-1, largest=True)
                    uncertainty_thres = unselected_samples.min().item()
                    
                    uncertainty_mask = tgt_cls_uncertainty.le(uncertainty_thres)        # 代表uncertainty不会过高的样本

                    mask = origin_tgt_mask * uncertainty_mask
                    
                predict_mask = predict[k, :, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]] * mask
                centroids[key] = predict_mask.sum(axis=[1,2]) / classID[key]


            batch_centroids[k]['{}_{}'.format(i,j)] = centroids
            
end = time.perf_counter()
print(str(end-start))

{6: 117940, 14: 16264, 0: 430, 9: 245361, 4: 3991, 8: 70, 13: 11882, 3: 9107, 18: 78, 1: 3355, 17: 482}
{6: 147136, 14: 24933, 4: 108, 9: 224833, 3: 3566, 13: 6118, 1: 1308, 18: 269, 17: 101, 0: 588}
{3: 57364, 6: 171679, 4: 22884, 13: 4220, 9: 88047, 14: 53254, 1: 7784, 16: 90, 0: 2738, 18: 216, 8: 684}
{6: 169796, 14: 79845, 4: 6258, 3: 28963, 13: 11436, 0: 1358, 9: 98483, 1: 10065, 16: 762, 18: 1840, 17: 154}
0.3419200790813193


### contrastive_loss函数编写

In [19]:
def contrastive_loss(pos_set, neg_set, temperature):
    assert pos_set.size() != 0, "Positive pairs should not be EMPTY!"
    assert neg_set.size() != 0, "Negative pairs should not be EMPTY!"

    pos_head = torch.index_select(pos_set, 0, torch.tensor([0]).cuda(non_blocking=True))
    pos_pairs = torch.mm(pos_head, pos_set.permute(1,0))
    neg_pairs = torch.mm(pos_head, neg_set.permute(1,0))

    all_pairs = torch.cat([neg_pairs.repeat(pos_pairs.size()[1],1), pos_pairs.permute(1,0)], dim=1)
    all_pairs = torch.exp(all_pairs / temperature)

    exp_aggregation_row = all_pairs.sum(dim=1, keepdim=True)
    frac_row = torch.index_select(all_pairs, 1, torch.tensor([all_pairs.size()[1] - 1]).cuda(non_blocking=True)) / exp_aggregation_row
    log_row = torch.log(frac_row)
    
    if pos_set.size()[0] == 1:
        cl_loss = torch.mean(log_row) * (-1)
    else:
        cl_loss = torch.mean(log_row[1:,:]) * (-1)
    
    return cl_loss


### intra-image level contrastive loss 编写

In [20]:
# intra contrastive loss v3

from core.loss.contrastive_loss import ContrastiveLoss

num_classes = 19
positive_weight_increment_step = 0.01
negative_weight_increment_step = 0.01
temperature = 0.07

loss = []

contrastive_loss_criterion = ContrastiveLoss()

for k in batch_centroids:
    # calculate per clas
    for cls in range(num_classes):
        pos, neg = {}, {}
        # pos: {'region_{}_{}': {tensor([...])} }
        # neg: {'region_{}_{}': {cls: tensor([...]), cls: tensor([...]), cls: tensor([...]), ...}}
        for region in batch_centroids[k]:  # 此时的region是 0_0
            neg[region] = {}
            for intra_cls in batch_centroids[k][region]:
                neg[region][intra_cls] = batch_centroids[k][region][intra_cls]    # 只有这样copy tensor，才不至于改变batch_centroids本身
            if batch_centroids[k][region].__contains__(cls):
                pos[region] = batch_centroids[k][region][cls]
                del neg[region][cls]
        pos_region, neg_region = [], []
        pos_region = list(pos.keys()) # positive pairs所在的region
        neg_region = list(neg.keys()) # negative pairs所在的region
        all_region = list(set(pos_region + neg_region)) # 合并两个list，并且删除重复元素
        for region_1 in all_region:
            pos_per_region, neg_per_region = [], []         # 该类以该region为中心的positive和negative pairs
            cl_per_region = []
            region_1_index = np.array([int(region_1.split('_')[0]), int(region_1.split('_')[1])])
            if pos.__contains__(region_1):
                pos_per_region.append(pos[region_1])        # positive pairs的头，即cls在region_1的centroids
            for neg_cls_1 in neg[region_1]:
                neg_per_region.append(neg[region_1][neg_cls_1])   # negatiave pairs的头，即在region_1中除了cls之外的所有centroids
            for region_2 in all_region:                         # 收集其它region的positive pairs和negative pairs
                if region_2 != region_1:
                    region_2_index = np.array([int(region_2.split('_')[0]), int(region_2.split('_')[1])])
                    positive_weight = 1 + positive_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm
                    negative_weight = 1 - negative_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm
                    if pos.__contains__(region_2):     # 其它区域的positive pairs，将会乘上权重
                        pos_per_region.append(pos[region_2] * positive_weight)
                    # print(pos_per_region)
                    if neg.__contains__(region_2):     # 其它区域的negative pairs，将会乘上权重
                        for neg_cls_2 in neg[region_2]:
                            neg_per_region.append(neg[region_2][neg_cls_2] * negative_weight)
            if pos_per_region != [] and neg_per_region != []:    # 否则会报错，stack不能对empty list操作
                pos_set_cl = torch.stack(pos_per_region, dim=0).cuda(non_blocking=True)       # 第一行tensor就是用于query的positive pair头，剩下都是所有乘上权重后的positive pairs
                neg_set_cl = torch.stack(neg_per_region, dim=0).cuda(non_blocking=True)       # 所有乘上权重后的negative pairs
                
                # cl_per_region = contrastive_loss(pos_set_cl, neg_set_cl, temperature)
                cl_per_region = contrastive_loss_criterion(pos_set_cl, neg_set_cl, temperature)
                # print(cl_per_region)
                loss.append(cl_per_region)

            
if loss != []:
        loss = torch.stack(loss, dim=0).cuda(non_blocking=True)  
        loss = torch.mean(loss)    
else:
    loss = torch.tensor([0]).cuda(non_blocking=True)
print(loss)  

tensor(2.8694, device='cuda:0', grad_fn=<MeanBackward0>)


### inter-images level contrastive loss 编写

In [21]:
# inter images contrastive loss v1
from core.loss.contrastive_loss import ContrastiveLoss

num_classes = 19
positive_weight_increment_step = 0.01
negative_weight_increment_step = 0.01
temperature = 0.07

loss = []

contrastive_loss_criterion = ContrastiveLoss()

for k in batch_centroids:
    # calculate per image
    for j in batch_centroids:
        if k != j:
            # calculate per class
            for cls in range(num_classes):      # pos: 另一张图片的positive samples，pos_origin：该张图片的positive samples
                pos_origin, pos, neg = {}, {}, {}
                # pos: {'region_{}_{}': {tensor([...])} }
                # neg: {'region_{}_{}': {cls: tensor([...]), cls: tensor([...]), cls: tensor([...]), ...}}
                for region in batch_centroids[j]:  # 找到每一个region里该cls对应的其它所有pos所在的region和neg所在的region
                    neg[region] = {}
                    for inter_cls in batch_centroids[j][region]:
                        neg[region][inter_cls] = batch_centroids[j][region][inter_cls]    # 只有这样copy tensor，才不至于改变batch_centroids本身
                    if batch_centroids[j][region].__contains__(cls):
                        pos[region] = batch_centroids[j][region][cls]     # 另一张图片的该区域的pos_region
                        del neg[region][cls]
                    if batch_centroids[k][region].__contains__(cls):        # 本张图片的该区域的pos_region
                        pos_origin[region] = batch_centroids[k][region][cls]

                pos_origin_region, pos_region, neg_region = [], [], []
                pos_region = list(pos.keys()) # positive pairs所在的batch内另一张图片的region
                neg_region = list(neg.keys()) # negative pairs所在的batch内另一张图片region
                pos_origin_region = list(pos_origin.keys()) # positive pairs所在的batch内此张图片的region
                inter_region = list(set(pos_region + neg_region))   # 另一张图片上pos和neg sample的所在region
                all_region = list(set(pos_origin_region + pos_region + neg_region)) # 合并多个list，并且删除重复元素
                
                for region_1 in all_region:
                    pos_per_region, neg_per_region = [], []         # 该类以该region为中心的positive和negative pairs
                    cl_per_region = []
                    region_1_index = np.array([int(region_1.split('_')[0]), int(region_1.split('_')[1])])
                    
                    if pos_origin.__contains__(region_1):
                        pos_per_region.append(pos_origin[region_1])    # positive pairs的头，即cls在该图中region_1的centroids     
                    
                    for neg_cls_1 in neg[region_1]:
                        neg_per_region.append(neg[region_1][neg_cls_1])   # negatiave pairs的头，即在region_1中除了cls之外的所有centroids

                    for region_2 in inter_region:                         # 收集另一张图片每个region的positive pairs和negative pairs
                        region_2_index = np.array([int(region_2.split('_')[0]), int(region_2.split('_')[1])])
                        positive_weight = 1 + positive_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm
                        negative_weight = 1 - negative_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm

                        if pos.__contains__(region_2):     # 其它区域的positive pairs，将会乘上权重
                            pos_per_region.append(pos[region_2] * positive_weight)

                        if neg.__contains__(region_2):     # 其它区域的negative pairs，将会乘上权重
                            for neg_cls_2 in neg[region_2]:
                                neg_per_region.append(neg[region_2][neg_cls_2] * negative_weight)

                    if pos_per_region != [] and neg_per_region != []:    # 否则会报错，stack不能对empty list操作
                        pos_set_cl = torch.stack(pos_per_region, dim=0).cuda(non_blocking=True)       # 第一行tensor就是用于query的positive pair头，剩下都是所有乘上权重后的positive pairs
                        neg_set_cl = torch.stack(neg_per_region, dim=0).cuda(non_blocking=True)       # 所有乘上权重后的negative pairs
                        
                        cl_per_region = contrastive_loss_criterion(pos_set_cl, neg_set_cl, temperature)

                        loss.append(cl_per_region)
            
if loss != []:
        loss = torch.stack(loss, dim=0).cuda(non_blocking=True)  
        loss = torch.mean(loss)    
else:
    loss = torch.tensor([0]).cuda(non_blocking=True)
print(loss)  

tensor(3.2716, device='cuda:0', grad_fn=<MeanBackward0>)


### cross-domain level contrastive loss 编写

In [22]:
from core.models.region_spliter import RegionSplit_CentroidCal
batch_centroids_src = RegionSplit_CentroidCal(
    predict=torch.softmax(src_out, dim=1),
    label=src_label,
    is_source=True,
    numparts_h=2, 
    numparts_w=4
)
batch_centroids_tgt = RegionSplit_CentroidCal(
    predict=torch.softmax(tgt_out, dim=1),
    label=tgt_label,
    is_source=False,
    numparts_h=2, 
    numparts_w=4,
    tgt_centroids_base_ratio=0.9,
    tgt_out=tgt_out
)

In [23]:
# cross-domain images contrastive loss v1
from core.loss.contrastive_loss import ContrastiveLoss

num_classes = 19
positive_weight_increment_step = 0.01
negative_weight_increment_step = 0.01
temperature = 0.07

loss = []

contrastive_loss_criterion = ContrastiveLoss()

for cl_round in [0,1]:
    if cl_round == 0:
        batch_centroids_0 = batch_centroids_src
        batch_centroids_1 = batch_centroids_tgt
    else:
        batch_centroids_1 = batch_centroids_src
        batch_centroids_0 = batch_centroids_tgt
    for k in batch_centroids_0:
        # calculate per image from another domain
        for j in batch_centroids_1:
            # calculate per class
            print(k,j)
            for cls in range(num_classes):      # pos: 另一张图片的positive samples，pos_origin：该张图片的positive samples
                pos_origin, pos, neg = {}, {}, {}
                # pos: {'region_{}_{}': {tensor([...])} }
                # neg: {'region_{}_{}': {cls: tensor([...]), cls: tensor([...]), cls: tensor([...]), ...}}
                for region in batch_centroids_1[j]:  # 找到每一个region里该cls对应的其它所有pos所在的region和neg所在的region
                    neg[region] = {}
                    for inter_cls in batch_centroids_1[j][region]:
                        neg[region][inter_cls] = batch_centroids_1[j][region][inter_cls]    # 只有这样copy tensor，才不至于改变batch_centroids本身
                    if batch_centroids_1[j][region].__contains__(cls):
                        pos[region] = batch_centroids_1[j][region][cls]     # 另一张图片的该区域的pos_region
                        del neg[region][cls]
                    if batch_centroids_0[k][region].__contains__(cls):        # 本张图片的该区域的pos_region
                        pos_origin[region] = batch_centroids_0[k][region][cls]

                pos_origin_region, pos_region, neg_region = [], [], []
                pos_region = list(pos.keys()) # positive pairs所在的batch内另一张图片的region
                neg_region = list(neg.keys()) # negative pairs所在的batch内另一张图片region
                pos_origin_region = list(pos_origin.keys()) # positive pairs所在的batch内此张图片的region

                inter_region = list(set(pos_region + neg_region))   # 另一张图片上pos和neg sample的所在region
                all_region = list(set(pos_origin_region + pos_region + neg_region)) # 合并多个list，并且删除重复元素
                
                for region_1 in all_region:
                    pos_per_region, neg_per_region = [], []         # 该类以该region为中心的positive和negative pairs
                    cl_per_region = []
                    region_1_index = np.array([int(region_1.split('_')[0]), int(region_1.split('_')[1])])
                    
                    if pos_origin.__contains__(region_1):
                        pos_per_region.append(pos_origin[region_1])    # positive pairs的头，即cls在该图中region_1的centroids     
                    
                    for neg_cls_1 in neg[region_1]:
                        neg_per_region.append(neg[region_1][neg_cls_1])   # negatiave pairs的头，即在region_1中除了cls之外的所有centroids

                    for region_2 in inter_region:                         # 收集另一张图片每个region的positive pairs和negative pairs
                        region_2_index = np.array([int(region_2.split('_')[0]), int(region_2.split('_')[1])])
                        positive_weight = 1 + positive_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm
                        negative_weight = 1 - negative_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm

                        if pos.__contains__(region_2):     # 其它区域的positive pairs，将会乘上权重
                            pos_per_region.append(pos[region_2] * positive_weight)

                        if neg.__contains__(region_2):     # 其它区域的negative pairs，将会乘上权重
                            for neg_cls_2 in neg[region_2]:
                                neg_per_region.append(neg[region_2][neg_cls_2] * negative_weight)

                    if pos_per_region != [] and neg_per_region != []:    # 否则会报错，stack不能对empty list操作
                        pos_set_cl = torch.stack(pos_per_region, dim=0).cuda(non_blocking=True)       # 第一行tensor就是用于query的positive pair头，剩下都是所有乘上权重后的positive pairs
                        neg_set_cl = torch.stack(neg_per_region, dim=0).cuda(non_blocking=True)       # 所有乘上权重后的negative pairs
                        
                        cl_per_region = contrastive_loss_criterion(pos_set_cl, neg_set_cl, temperature)
                        
                        loss.append(cl_per_region)
                
if loss != []:
        loss = torch.stack(loss, dim=0).cuda(non_blocking=True)  
        loss = torch.mean(loss)    
else:
    loss = torch.tensor([0]).cuda(non_blocking=True)
print(loss)  

0 0
0 1
1 0
1 1
0 0
0 1
1 0
1 1
tensor(4.1016, device='cuda:0', grad_fn=<MeanBackward0>)


### Early Stage Annotation Method

In [24]:
from core.models.region_spliter import RegionSplit_CentroidCal
global_batch_centroids_src = RegionSplit_CentroidCal(
    predict=torch.softmax(src_out, dim=1),
    label=src_label,
    is_source=True,
    numparts_h=1, 
    numparts_w=1
)
batch_centroids_tgt = RegionSplit_CentroidCal(
    predict=torch.softmax(tgt_out, dim=1),
    label=tgt_label,
    is_source=False,
    numparts_h=2, 
    numparts_w=4,
    tgt_centroids_base_ratio=0.9,
    tgt_out=tgt_out
)

#### Similarity Score Measure

In [26]:
from torch import unsqueeze
import torch.nn.functional as F

h = 640
w = 1280
numparts_h = 2  #此处应和上面生成的batch_centroids_tgt一致
numparts_w = 4  #此处应和上面生成的batch_centroids_tgt一致
parts_h = int(h / numparts_h)
parts_w = int(w / numparts_w)

tgt_predict = torch.softmax(tgt_out, dim=1)
# print(tgt_predict.size())     # torch.Size([2, 19, 640, 1280])
# print(tgt_label.size())       # torch.Size([2, 640, 1280])
early_anno_score = torch.zeros(tgt_label.size()).cuda()
early_anno_score[:,:,:] = 0

for img_idx in batch_centroids_tgt:
    for region_index in batch_centroids_tgt[img_idx]:
        i, j = int(region_index.split('_')[0]), int(region_index.split('_')[1])
        
        if [i,j] == [range(numparts_h)[-1], range(numparts_w)[-1]]:
            rg_id = [i*parts_h, h-1, j*parts_w, w-1]
        else:
            rg_id = [i*parts_h, (i+1)*parts_h-1, j*parts_w, (j+1)*parts_w-1]
        for cls in batch_centroids_tgt[img_idx][region_index]:
            if global_batch_centroids_src[img_idx]['0_0'].__contains__(cls) and global_batch_centroids_src[1-img_idx]['0_0'].__contains__(cls):
                src_cls_prototype = (global_batch_centroids_src[img_idx]['0_0'][cls] + global_batch_centroids_src[1-img_idx]['0_0'][cls]) / len(global_batch_centroids_src)
            elif global_batch_centroids_src[img_idx]['0_0'].__contains__(cls):
                src_cls_prototype = global_batch_centroids_src[img_idx]['0_0'][cls]
            elif global_batch_centroids_src[1-img_idx]['0_0'].__contains__(cls):
                src_cls_prototype = global_batch_centroids_src[1-img_idx]['0_0'][cls]
            else:
                continue
            tgt_cls_prototype = batch_centroids_tgt[img_idx][region_index][cls]
            cross_cls_unsimilarity = torch.tensor([1]).cuda() - F.cosine_similarity(src_cls_prototype, tgt_cls_prototype, dim=0)
            cls_mask = tgt_label[img_idx, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].eq(cls)
            intra_cls_similarity = F.cosine_similarity(tgt_predict[img_idx, :, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]], tgt_cls_prototype.unsqueeze(dim=1).unsqueeze(dim=2), dim=0)
            score_mask = cls_mask * cross_cls_unsimilarity * intra_cls_similarity
            early_anno_score[img_idx, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]] += score_mask


        

# print(similarity_score)

#### Uncertainty Score Measure

In [27]:
from core.active.early_anno_score import EarlyAnnoScore
src_predict = torch.softmax(src_out, dim=1)
batch_centroids_tgt=RegionSplit_CentroidCal(predict=tgt_predict, 
                                                        label=tgt_label, 
                                                        is_source=False,
                                                        numparts_h=2, 
                                                        numparts_w=4,
                                                        tgt_centroids_base_ratio=0.9,
                                                        tgt_out=tgt_out)
global_batch_centroids_src=RegionSplit_CentroidCal(predict=src_predict, 
                                                        label=src_label, 
                                                        is_source=True,
                                                        numparts_h=1, 
                                                        numparts_w=1)
score = EarlyAnnoScore(batch_centroids_tgt=batch_centroids_tgt,
                        global_batch_centroids_src=global_batch_centroids_src,
                        tgt_label=tgt_label,
                        tgt_predict=tgt_predict,
                        numparts_h=2,
                        numparts_w=4
                    )

In [28]:
origin_size

NameError: name 'origin_size' is not defined

In [29]:
print(tgt_label.unsqueeze(dim=0).size())
print(tgt_label.float().size())
tgt_label_interpolation = F.interpolate(tgt_label.float().unsqueeze(dim=0), size=(1024, 2048), mode='bilinear', align_corners=True)

torch.Size([1, 2, 640, 1280])
torch.Size([2, 640, 1280])


In [31]:
values, indices = torch.max(score, dim=0)
print(values.size(), indices.size())

torch.Size([640, 1280]) torch.Size([640, 1280])


In [34]:
import math
import torch

import numpy as np
import torch.nn.functional as F

from PIL import Image
from tqdm import tqdm
from core.active.floating_region import FloatingRegionScore
from core.active.spatial_purity import SpatialPurity
from core.active.early_anno_score import EarlyAnnoScore

now_iteration=10000
max_iter = 50000

feature_extractor.eval()
classifier.eval()


# floating_region_score = FloatingRegionScore(in_channels=cfg.MODEL.NUM_CLASSES, size=2 * cfg.ACTIVE.RADIUS_K + 1).cuda()
# per_region_pixels = (2 * cfg.ACTIVE.RADIUS_K + 1) ** 2
# active_radius = cfg.ACTIVE.RADIUS_K
# mask_radius = cfg.ACTIVE.RADIUS_K * 2
active_ratio = cfg.ACTIVE.RATIO / len(cfg.ACTIVE.SELECT_ITER)

flag = 1
cross_entropy_computation = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
with torch.no_grad():
    for (src_data, tgt_data) in tqdm(zip(src_train_loader, tgt_epoch_loader)):

        src_input, src_label = src_data['img'], src_data['label']
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        src_out = classifier(feature_extractor(src_input), size=src_size)
        src_predict = torch.softmax(src_out, dim=1)

        tgt_input, path2mask = tgt_data['img'], tgt_data['path_to_mask']    # tgt_input: torch.Size([1, 3, 640, 1280])
        origin_mask, origin_label = tgt_data['origin_mask'], tgt_data['origin_label']       # origin_mask, origin_label: torch.Size([1, 1024, 2048])
        origin_size = tgt_data['size']
        active_indicator = tgt_data['active']
        selected_indicator = tgt_data['selected']
        path2indicator = tgt_data['path_to_indicator']

        tgt_input = tgt_input.cuda(non_blocking=True)

        tgt_size = tgt_input.shape[-2:]
        tgt_feat = feature_extractor(tgt_input)
        tgt_out = classifier(tgt_feat, size=tgt_size)       # tgt_out: torch.Size([1, 19, 640, 1280])
        tgt_predict = torch.softmax(tgt_out, dim=1)
        tgt_label = torch.argmax(tgt_predict[:,:,:,:],dim=1)    # tgt_label: torch.Size([1, 640, 1280])

        for i in range(len(origin_mask)):       # 一个batch内的图片数量  # origin_mask, origin_label: torch.Size([1, 1024, 2048])
            active_mask = origin_mask[i].cuda(non_blocking=True)
            ground_truth = origin_label[i].cuda(non_blocking=True)
            size = (origin_size[i][0], origin_size[i][1])       # size: tensor(1024), tensor(2048)
            num_pixel_cur = size[0] * size[1]
            active = active_indicator[i]        # torch.Size([1024, 2048])，最开始都是False
            selected = selected_indicator[i]

            output = tgt_out[i:i + 1, :, :, :]
            output = F.interpolate(output, size=size, mode='bilinear', align_corners=True)
            # score, purity, entropy = floating_region_score(output, now_iteration=now_iteration, cfg=cfg)

            tgt_label_interpolation = F.interpolate(tgt_label.float().unsqueeze(dim=0), size=size, mode='bilinear', align_corners=True)
            tgt_label_interpolation = tgt_label_interpolation.squeeze(dim=0).long()     # torch.Size([1, 1024, 2048]) 
            tgt_predict_interpolation = F.interpolate(tgt_predict, size=size, mode='bilinear', align_corners=True)  # torch.Size([1, 19, 1024, 2048])

            batch_centroids_tgt=RegionSplit_CentroidCal(predict=tgt_predict_interpolation, 
                                                        label=tgt_label_interpolation, 
                                                        is_source=False,
                                                        numparts_h=2, 
                                                        numparts_w=4,
                                                        tgt_centroids_base_ratio=0.9,
                                                        tgt_out=output) # !!!
            global_batch_centroids_src=RegionSplit_CentroidCal(predict=src_predict, 
                                                                    label=src_label, 
                                                                    is_source=True,
                                                                    numparts_h=1, 
                                                                    numparts_w=1)
            
       
            similarity_score = EarlyAnnoScore(batch_centroids_tgt=batch_centroids_tgt,
                                    global_batch_centroids_src=global_batch_centroids_src,
                                    tgt_label=tgt_label_interpolation,
                                    tgt_predict=tgt_predict_interpolation,
                                    numparts_h=2,
                                    numparts_w=4)

            # similarity_score: torch.Size([1, 1024, 2048])
            similarity_score = similarity_score.squeeze(dim=0) # similarity_score: torch.Size([1024, 2048])
            similarity_score[active] = 0.0     # 把上一轮已经标注过的pixel给置0

            # uncertainty_score: torch.Size([1, 1024, 2048])
            uncertainty_score = cross_entropy_computation(output, tgt_label_interpolation)
            uncertainty_score = uncertainty_score.squeeze(dim=0)

            active_budget = math.ceil(num_pixel_cur * active_ratio)    # 将要actively selected pixel的数量
            similarity_budget = math.ceil(active_budget * ((now_iteration - max_iter) / max_iter) ** 2)
            uncertainty_budget = active_budget - similarity_budget

            for pixel in range(active_budget):
                if pixel < similarity_budget:      # similarity annotation
                    values, indices_h = torch.max(similarity_score, dim=0)
                else:                                        # uncertainty annotation
                    values, indices_h = torch.max(uncertainty_score, dim=0)
                _, indices_w = torch.max(values, dim=0)
                w = indices_w.item()
                h = indices_h[w].item()

                # mask out
                similarity_score[h,w] = 0.0
                uncertainty_score[h,w] = 0.0
                active[h,w] = True
                selected[h,w] = True
                # active sampling
                active_mask[h,w] = ground_truth[h,w]
                

            active_mask = Image.fromarray(np.array(active_mask.cpu().numpy(), dtype=np.uint8))
            active_mask.save(path2mask[i])
            indicator = {
                'active': active,
                'selected': selected
            }
            torch.save(indicator, path2indicator[i])

feature_extractor.train()
classifier.train()


0it [00:00, ?it/s]

torch.Size([1024, 2048])


1it [00:04,  4.24s/it]

torch.Size([1024, 2048])


2it [00:08,  4.04s/it]

torch.Size([1024, 2048])


2it [00:09,  4.90s/it]


KeyboardInterrupt: 