In [5]:
import argparse
import os
import datetime
import logging
import time

import torch
import torch.nn as nn
import torch.utils
import torch.distributed
from torch.utils.data import DataLoader
import multiprocessing

import numpy as np

from core.configs import cfg
from core.datasets import build_dataset
from core.models import build_feature_extractor, build_classifier
from core.solver import adjust_learning_rate
from core.utils.misc import mkdir
from core.utils.logger import setup_logger
from core.utils.metric_logger import MetricLogger
from core.active.build import PixelSelection, RegionSelection
from core.datasets.dataset_path_catalog import DatasetCatalog
from core.loss.negative_learning_loss import NegativeLearningLoss
from core.loss.local_consistent_loss import LocalConsistentLoss
from core.utils.utils import set_random_seed
from torch.utils.tensorboard import SummaryWriter

import warnings
warnings.filterwarnings('ignore')

In [6]:
os.system("nvidia-smi")

Sat Oct 15 08:12:31 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-PCI...  On   | 00000000:C3:00.0 Off |                    0 |
| N/A   33C    P0    33W / 250W |      2MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

0

In [7]:
torch.cuda.is_available()

True

In [8]:
parser = argparse.ArgumentParser(description="Active Domain Adaptive Semantic Segmentation Training")
parser.add_argument("-cfg",
                    "--config-file",
                    default="",
                    metavar="FILE",
                    help="path to config file",
                    type=str)
parser.add_argument("--proctitle",
                    type=str,
                    default="RCL-AAA",
                    help="allow a process to change its title",)
parser.add_argument(
    "opts",
    help="Modify config options using the command-line",
    default=None,
    nargs=argparse.REMAINDER
)

# args = parser.parse_args()

args = parser.parse_args(args=['-cfg', 'configs/gtav/deeplabv3plus_r101_RA.yaml','--proctitle', 'RCL-AAA'])
args.opts = ['OUTPUT_DIR', 'results/v3plus_gtav_ra_5.0_precent', 'DEBUG', '0']

if args.opts is not None:
    args.opts[-1] = args.opts[-1].strip('\r\n')

torch.backends.cudnn.benchmark = True

cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()

output_dir = cfg.OUTPUT_DIR
if output_dir:
    mkdir(output_dir)

logger = setup_logger("RCL-AAA", output_dir, 0)
logger.info(args)

logger.info("Loaded configuration file {}".format(args.config_file))
logger.info("Running with config:\n{}".format(cfg))

logger.info('Initializing Cityscapes label mask...')

set_random_seed(cfg.SEED)

2022-10-15 08:12:32,767 RCL-AAA INFO: Namespace(config_file='configs/gtav/deeplabv3plus_r101_RA.yaml', opts=['OUTPUT_DIR', 'results/v3plus_gtav_ra_5.0_precent', 'DEBUG', '0'], proctitle='RCL-AAA')
2022-10-15 08:12:32,768 RCL-AAA INFO: Loaded configuration file configs/gtav/deeplabv3plus_r101_RA.yaml
2022-10-15 08:12:32,768 RCL-AAA INFO: Running with config:
ACTIVE:
  NAME: RCL-AAA
  PIXELS: 40
  RADIUS_K: 1
  RATIO: 0.05
  SELECT_ITER: [10000, 12000, 14000, 16000, 18000]
  SETTING: RA
DATASETS:
  SOURCE_TRAIN: gtav_train
  TARGET_TRAIN: cityscapes_train
  TEST: cityscapes_val
DEBUG: 0
INPUT:
  IGNORE_LABEL: 255
  INPUT_SCALES_TRAIN: (1.0, 1.0)
  INPUT_SIZE_TEST: (1280, 640)
  PIXEL_MEAN: [0.485, 0.456, 0.406]
  PIXEL_STD: [0.229, 0.224, 0.225]
  SOURCE_INPUT_SIZE_TRAIN: (1280, 720)
  TARGET_INPUT_SIZE_TRAIN: (1280, 640)
  TO_BGR255: False
MODEL:
  DEVICE: cuda
  FREEZE_BN: True
  NAME: deeplabv3plus_resnet101
  NUM_CLASSES: 19
  WEIGHTS: https://download.pytorch.org/models/resnet101-5d

In [12]:
print("Here is {} CPU, {} GPU.".format(multiprocessing.cpu_count(), torch.cuda.device_count()))
logger = logging.getLogger("RCL-AAA.trainer")
# tb_writer = SummaryWriter('./{}_{}_tensorboard_log_{}_confidence'.format(cfg.DATASETS.SOURCE_TRAIN.split('_')[0], cfg.DATASETS.TARGET_TRAIN.split('_')[0], cfg.CONFIDENCE.WEIGHT))
# print('Tensorboard writer log has been created at {}'.format('./{}_{}_tensorboard_log_{}_confidence'.format(cfg.DATASETS.SOURCE_TRAIN.split('_')[0], cfg.DATASETS.TARGET_TRAIN.split('_')[0], cfg.CONFIDENCE.WEIGHT)))

# create network
device = torch.device(cfg.MODEL.DEVICE)
feature_extractor = build_feature_extractor(cfg)
#feature_extractor = nn.DataParallel(feature_extractor)
feature_extractor.to(device)

classifier = build_classifier(cfg)
#classifier = nn.DataParallel(classifier)
classifier.to(device)
print()

Here is 256 CPU, 1 GPU.
load checkpoint from http path: https://download.pytorch.org/models/resnet101-5d3b4d8f.pth



In [13]:
# init optimizer
optimizer_fea = torch.optim.SGD(feature_extractor.parameters(), lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)
optimizer_fea.zero_grad()

optimizer_cls = torch.optim.SGD(classifier.parameters(), lr=cfg.SOLVER.BASE_LR * 10, momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)
optimizer_cls.zero_grad()

iteration = 0

In [14]:
# # load checkpoint
# if cfg.resume:
#     logger.info("Loading checkpoint from {}".format(cfg.resume))
#     checkpoint = torch.load(cfg.OUTPUT_DIR + '/' + cfg.resume, map_location=torch.device('cpu'))
#     iteration = checkpoint['iteration']
#     feature_extractor.load_state_dict(checkpoint['feature_extractor'])
#     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
#     classifier.load_state_dict(checkpoint['classifier'])
#     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
# # feature_extractor = nn.DataParallel(feature_extractor)      # modified by CZC
# # classifier = nn.DataParallel(classifier)            # modified by CZC

In [15]:
# init mask for cityscape
# DatasetCatalog.initMask(cfg)

In [17]:
src_train_data = build_dataset(cfg, mode='train', is_source=True)
tgt_train_data = build_dataset(cfg, mode='train', is_source=False)
tgt_epoch_data = build_dataset(cfg, mode='active', is_source=False, epochwise=True)

src_train_loader = DataLoader(src_train_data, batch_size=2, shuffle=True, num_workers=4,
                              pin_memory=False, drop_last=True)
tgt_train_loader = DataLoader(tgt_train_data, batch_size=2, shuffle=True, num_workers=4,
                              pin_memory=False, drop_last=True)
tgt_epoch_loader = DataLoader(tgt_epoch_data, batch_size=1, shuffle=False, num_workers=4,
                              pin_memory=False, drop_last=True)

Compose(
    <core.datasets.transform.Resize object at 0x150f14c99730>
    <core.datasets.transform.ToTensor object at 0x150f14c99a90>
    <core.datasets.transform.Normalize object at 0x150f14c99970>
)
Compose(
    <core.datasets.transform.Resize object at 0x150fd8161c10>
    <core.datasets.transform.ToTensor object at 0x150e82d9e1f0>
    <core.datasets.transform.Normalize object at 0x150e82d9e280>
)
Compose(
    <core.datasets.transform.Resize object at 0x150e82d9eb80>
    <core.datasets.transform.ToTensor object at 0x150f14c99610>
    <core.datasets.transform.Normalize object at 0x150f14c99940>
)


In [18]:
torch.set_printoptions(threshold=np.inf)    # added by czc
# src_train_data[0]

In [19]:
print(src_train_data[0]['img'].size())
print(src_train_data[0]['label'].size())

torch.Size([3, 720, 1280])
torch.Size([720, 1280])


In [20]:
src_input, src_label = src_train_data[0]['img'], src_train_data[0]['label']
src_input = src_input.cuda(non_blocking=True)
src_label = src_label.cuda(non_blocking=True)

In [21]:
# init loss
sup_criterion = nn.CrossEntropyLoss(ignore_index=255)
# negative_criterion = NegativeLearningLoss(threshold=cfg.SOLVER.NEGATIVE_THRESHOLD)
# local_consistent_loss = LocalConsistentLoss(cfg.MODEL.NUM_CLASSES, cfg.SOLVER.LCR_TYPE).cuda()


start_warmup_time = time.time()
end = time.time()
max_iters = cfg.SOLVER.MAX_ITER
warmup_iters = 10000
meters = MetricLogger(delimiter="  ")

logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
feature_extractor.train()
classifier.train()
active_round = 1

2022-10-15 08:15:37,055 RCL-AAA.trainer INFO: >>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>


## Warm Up

In [22]:
iteration=0
for batch_index, (src_data, tgt_data) in enumerate(zip(src_train_loader, tgt_train_loader)):       
    data_time = time.time() - end

    current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR, iteration, max_iters,
                                      power=cfg.SOLVER.LR_POWER)
    # tb_writer.add_scalar(tag="lr", scalar_value=current_lr, global_step=iteration)      # added by czc
    for index in range(len(optimizer_fea.param_groups)):
        optimizer_fea.param_groups[index]['lr'] = current_lr
    for index in range(len(optimizer_cls.param_groups)):
        optimizer_cls.param_groups[index]['lr'] = current_lr * 10

    optimizer_fea.zero_grad()
    optimizer_cls.zero_grad()

    src_input, src_label = src_data['img'], src_data['label']
    src_input = src_input.cuda(non_blocking=True)
    src_label = src_label.cuda(non_blocking=True)
    # print(src_input.size())
    # print(src_label.size())
    # print(src_label.size()[0])

    tgt_input, tgt_mask = tgt_data['img'], tgt_data['mask']
    tgt_input = tgt_input.cuda(non_blocking=True)
    tgt_mask = tgt_mask.cuda(non_blocking=True)


    src_size = src_input.shape[-2:]
    # print(src_size)
    src_out = classifier(feature_extractor(src_input), size=src_size)
    
    tgt_size = tgt_input.shape[-2:]
    tgt_out = classifier(feature_extractor(tgt_input), size=tgt_size)
    predict = torch.softmax(tgt_out, dim=1)

    # target active supervision loss
    if torch.sum((tgt_mask != 255)) != 0:  # target has labeled pixels
        loss_sup_tgt = sup_criterion(tgt_out, tgt_mask)
        # meters.update(loss_sup_tgt=loss_sup_tgt.item())



    # predict = torch.softmax(src_out, dim=1)

    iteration += 1
    if iteration == 1:
        break

In [23]:
sup_criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
# predict[0,:,0:10,0:10]

In [24]:
tgt_label = torch.argmax(predict[:,:,0:5,0:5],dim=1)
print(tgt_label.size())

torch.Size([2, 5, 5])


In [25]:
tgt_label = torch.argmax(predict[0,:,0:5,0:5],dim=0)
print(tgt_label)
print(tgt_label.size())
# print(tgt_out[0,0:3,0:2,0:2])
loss_sup_tgt = sup_criterion(tgt_out[0,0:19,0:5,0:5].permute(1,0,2), tgt_label)
print(loss_sup_tgt)
print(torch.topk(loss_sup_tgt, k=3, dim=1, largest=False))
print(torch.min(loss_sup_tgt))
print(torch.max(loss_sup_tgt))

tensor([[13, 13, 13, 10, 10],
        [13, 13, 13, 13, 10],
        [13, 13, 13, 13, 13],
        [ 7, 13, 13, 13, 13],
        [ 7,  7, 13, 13, 13]], device='cuda:0')
torch.Size([5, 5])
tensor([[2.6342, 2.5844, 2.5380, 2.4616, 2.3411],
        [2.6260, 2.5670, 2.5114, 2.4594, 2.3631],
        [2.6190, 2.5511, 2.4869, 2.4265, 2.3700],
        [2.6060, 2.5367, 2.4644, 2.3965, 2.3330],
        [2.5586, 2.5100, 2.4440, 2.3693, 2.2997]], device='cuda:0',
       grad_fn=<ViewBackward>)
torch.return_types.topk(
values=tensor([[2.3411, 2.4616, 2.5380],
        [2.3631, 2.4594, 2.5114],
        [2.3700, 2.4265, 2.4869],
        [2.3330, 2.3965, 2.4644],
        [2.2997, 2.3693, 2.4440]], device='cuda:0', grad_fn=<TopkBackward>),
indices=tensor([[4, 3, 2],
        [4, 3, 2],
        [4, 3, 2],
        [4, 3, 2],
        [4, 3, 2]], device='cuda:0'))
tensor(2.2997, device='cuda:0', grad_fn=<MinBackward1>)
tensor(2.6342, device='cuda:0', grad_fn=<MaxBackward1>)


In [26]:
from collections import Counter

from cv2 import threshold
import math
h = 720
w = 1280
numparts_h = 2
numparts_w = 4
parts_h = int(h / numparts_h)
parts_w = int(w / numparts_w)
batch_size = 2
batch_centroids = {}        
is_source = False
tgt_centroids_base_ratio = 0.9999

if is_source:
    predict = torch.softmax(src_out, dim=1)
else:
    predict = torch.softmax(tgt_out, dim=1)
# batch_centroids:{
#                  'img_idx': {
#                               {}_{}': {
#                                         classID: {centroid}, classID: {centroid}, classID: {centroid}, ...
cross_entropy_computation = nn.CrossEntropyLoss(ignore_index=255, reduction='none')

import time
start = time.perf_counter()

tgt_label = torch.argmax(predict, dim=1)


for k in range(batch_size):
    batch_centroids[k] = {}     # k是batch里面img的序号
    for i in range(numparts_h):
        for j in range(numparts_w):
            batch_centroids[k]['{}_{}'.format(i,j)] = {} 
            
            # Get region coordinates 
            if [i,j] == [range(numparts_h)[-1], range(numparts_w)[-1]]:
                rg_id = [i*parts_h, h-1, j*parts_w, w-1]                 # rg_id: region_index
            else:
                rg_id = [i*parts_h, (i+1)*parts_h-1, j*parts_w, (j+1)*parts_w-1]
            # batch_centroids['img_idx_{}'.format(k)]['region_{}_{}'.format(i,j)]['rg_id'] = rg_id
            
            # Get all class ID in a single region
            if is_source == True:
                classID = dict(Counter(src_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].cpu().numpy().flatten()))     # 放到代码文件中得改一下
            else:
                classID = dict(Counter(tgt_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].cpu().numpy().flatten()))     # 放到代码文件中得改一下
            
            if classID.__contains__(255): del classID[255]
            # batch_centroids['img_idx_{}'.format(k)]['region_{}_{}'.format(i,j)]['region_class']['classID'] = classID
            
            # Get all predict mean as centroids
            centroids = {}
            for key in classID:
                predict_sum = torch.zeros([1,19], requires_grad=True)
                predict_sum = predict_sum.cuda(non_blocking=True)
                if is_source == True:
                    mask = src_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].eq(key)
                else:
                    origin_tgt_mask = tgt_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].eq(key)       # 代表是这个类的

                    tgt_ce = cross_entropy_computation(tgt_out[k, :, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]].permute(1,0,2), \
                                                        tgt_label[k, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]])
                    tgt_cls_uncertainty = tgt_ce * origin_tgt_mask      # 取最小的几个值注意可能会取到0
                    
                    unselected_sample_num = math.ceil(origin_tgt_mask.sum().item() * (1 - tgt_centroids_base_ratio))
                    unselected_samples, _ = torch.topk(torch.flatten(tgt_cls_uncertainty), k=unselected_sample_num, dim=-1, largest=True)
                    uncertainty_thres = unselected_samples.min().item()
                    
                    uncertainty_mask = tgt_cls_uncertainty.le(uncertainty_thres)        # 代表uncertainty不会过高的样本

                    mask = origin_tgt_mask * uncertainty_mask
                    
                predict_mask = predict[k, :, rg_id[0]:rg_id[1], rg_id[2]:rg_id[3]] * mask
                centroids[key] = predict_mask.sum(axis=[1,2]) / classID[key]


            batch_centroids[k]['{}_{}'.format(i,j)] = centroids
            
end = time.perf_counter()
print(str(end-start))

0.615081629017368


In [27]:
def contrastive_loss(pos_set, neg_set, temperature):
    assert pos_set.size() != 0, "Positive pairs should not be EMPTY!"
    assert neg_set.size() != 0, "Negative pairs should not be EMPTY!"

    pos_head = torch.index_select(pos_set, 0, torch.tensor([0]).cuda(non_blocking=True))
    pos_pairs = torch.mm(pos_head, pos_set.permute(1,0))
    neg_pairs = torch.mm(pos_head, neg_set.permute(1,0))

    all_pairs = torch.cat([neg_pairs.repeat(pos_pairs.size()[1],1), pos_pairs.permute(1,0)], dim=1)
    all_pairs = torch.exp(all_pairs / temperature)

    exp_aggregation_row = all_pairs.sum(dim=1, keepdim=True)
    frac_row = torch.index_select(all_pairs, 1, torch.tensor([all_pairs.size()[1] - 1]).cuda(non_blocking=True)) / exp_aggregation_row
    log_row = torch.log(frac_row)
    
    if pos_set.size()[0] == 1:
        cl_loss = torch.mean(log_row) * (-1)
    else:
        cl_loss = torch.mean(log_row[1:,:]) * (-1)
    
    return cl_loss


In [28]:
# intra contrastive loss v3

from core.loss.contrastive_loss import ContrastiveLoss

num_classes = 19
positive_weight_increment_step = 0.01
negative_weight_increment_step = 0.01
temperature = 0.07

loss = []

contrastive_loss_criterion = ContrastiveLoss()

for k in batch_centroids:
    # calculate per clas
    for cls in range(num_classes):
        pos, neg = {}, {}
        # pos: {'region_{}_{}': {tensor([...])} }
        # neg: {'region_{}_{}': {cls: tensor([...]), cls: tensor([...]), cls: tensor([...]), ...}}
        for region in batch_centroids[k]:  # 此时的region是 0_0
            neg[region] = {}
            for intra_cls in batch_centroids[k][region]:
                neg[region][intra_cls] = batch_centroids[k][region][intra_cls]    # 只有这样copy tensor，才不至于改变batch_centroids本身
            if batch_centroids[k][region].__contains__(cls):
                pos[region] = batch_centroids[k][region][cls]
                del neg[region][cls]
        pos_region, neg_region = [], []
        pos_region = list(pos.keys()) # positive pairs所在的region
        neg_region = list(neg.keys()) # negative pairs所在的region
        all_region = list(set(pos_region + neg_region)) # 合并两个list，并且删除重复元素
        for region_1 in all_region:
            pos_per_region, neg_per_region = [], []         # 该类以该region为中心的positive和negative pairs
            cl_per_region = []
            region_1_index = np.array([int(region_1.split('_')[0]), int(region_1.split('_')[1])])
            if pos.__contains__(region_1):
                pos_per_region.append(pos[region_1])        # positive pairs的头，即cls在region_1的centroids
            for neg_cls_1 in neg[region_1]:
                neg_per_region.append(neg[region_1][neg_cls_1])   # negatiave pairs的头，即在region_1中除了cls之外的所有centroids
            for region_2 in all_region:                         # 收集其它region的positive pairs和negative pairs
                if region_2 != region_1:
                    region_2_index = np.array([int(region_2.split('_')[0]), int(region_2.split('_')[1])])
                    positive_weight = 1 + positive_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm
                    negative_weight = 1 - negative_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm
                    if pos.__contains__(region_2):     # 其它区域的positive pairs，将会乘上权重
                        pos_per_region.append(pos[region_2] * positive_weight)
                    # print(pos_per_region)
                    if neg.__contains__(region_2):     # 其它区域的negative pairs，将会乘上权重
                        for neg_cls_2 in neg[region_2]:
                            neg_per_region.append(neg[region_2][neg_cls_2] * negative_weight)
            if pos_per_region != []:    # 否则会报错，stack不能对empty list操作
                pos_set_cl = torch.stack(pos_per_region, dim=0)       # 第一行tensor就是用于query的positive pair头，剩下都是所有乘上权重后的positive pairs
                neg_set_cl = torch.stack(neg_per_region, dim=0)       # 所有乘上权重后的negative pairs
                
                # cl_per_region = contrastive_loss(pos_set_cl, neg_set_cl, temperature)
                cl_per_region = contrastive_loss_criterion(pos_set_cl, neg_set_cl, temperature)
                # print(cl_per_region)
                loss.append(cl_per_region)
            
loss = torch.stack(loss, dim=0)  
loss = torch.mean(loss)
print(loss)  

tensor(4.2321, device='cuda:0', grad_fn=<MeanBackward0>)


In [29]:
for intra_cls in batch_centroids[0]['0_0']:
    print(intra_cls)

13
10
7
2
17
3
16
12
9
15
18
8
0


In [30]:
for k in batch_centroids:
    print(k)

0
1


In [34]:
# inter images contrastive loss v1
from core.loss.contrastive_loss import ContrastiveLoss

num_classes = 19
positive_weight_increment_step = 0.01
negative_weight_increment_step = 0.01
temperature = 0.07

loss = []

contrastive_loss_criterion = ContrastiveLoss()

for k in batch_centroids:
    # calculate per clas
    for cls in range(num_classes):
        pos, neg = {}, {}
        # pos: {'region_{}_{}': {tensor([...])} }
        # neg: {'region_{}_{}': {cls: tensor([...]), cls: tensor([...]), cls: tensor([...]), ...}}
        for region in batch_centroids[1-k]:  # 找到每一个region里该cls对应的其它所有pos所在的region和neg所在的region
            neg[region] = {}
            for intra_cls in batch_centroids[1-k][region]:
                neg[region][intra_cls] = batch_centroids[1-k][region][intra_cls]    # 只有这样copy tensor，才不至于改变batch_centroids本身
            if batch_centroids[1-k][region].__contains__(cls):
                pos[region] = batch_centroids[1-k][region][cls]     # 另一张图片的相同区域的pos_region
                del neg[region][cls]
        pos_region, neg_region = [], []
        pos_region = list(pos.keys()) # positive pairs所在的batch内另一张图片的region
        neg_region = list(neg.keys()) # negative pairs所在的batch内另一张图片region
        all_region = list(set(pos_region + neg_region)) # 合并两个list，并且删除重复元素
        for region_1 in all_region:
            pos_per_region, neg_per_region = [], []         # 该类以该region为中心的positive和negative pairs
            cl_per_region = []
            region_1_index = np.array([int(region_1.split('_')[0]), int(region_1.split('_')[1])])
            if pos.__contains__(region_1):
                pos_per_region.append(pos[region_1])        # positive pairs的头，即cls在region_1的centroids
            for neg_cls_1 in neg[region_1]:
                neg_per_region.append(neg[region_1][neg_cls_1])   # negatiave pairs的头，即在region_1中除了cls之外的所有centroids
            for region_2 in all_region:                         # 收集其它region的positive pairs和negative pairs
                if region_2 != region_1:
                    region_2_index = np.array([int(region_2.split('_')[0]), int(region_2.split('_')[1])])
                    positive_weight = 1 + positive_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm
                    negative_weight = 1 - negative_weight_increment_step * np.linalg.norm(region_1_index - region_2_index, ord=2) # L2 norm
                    if pos.__contains__(region_2):     # 其它区域的positive pairs，将会乘上权重
                        pos_per_region.append(pos[region_2] * positive_weight)
                    # print(pos_per_region)
                    if neg.__contains__(region_2):     # 其它区域的negative pairs，将会乘上权重
                        for neg_cls_2 in neg[region_2]:
                            neg_per_region.append(neg[region_2][neg_cls_2] * negative_weight)
            if pos_per_region != []:    # 否则会报错，stack不能对empty list操作
                pos_set_cl = torch.stack(pos_per_region, dim=0)       # 第一行tensor就是用于query的positive pair头，剩下都是所有乘上权重后的positive pairs
                neg_set_cl = torch.stack(neg_per_region, dim=0)       # 所有乘上权重后的negative pairs
                
                # cl_per_region = contrastive_loss(pos_set_cl, neg_set_cl, temperature)
                cl_per_region = contrastive_loss_criterion(pos_set_cl, neg_set_cl, temperature)
                # print(cl_per_region)
                loss.append(cl_per_region)
            
loss = torch.stack(loss, dim=0)  
loss = torch.mean(loss)
print(loss)  

tensor([0.1881, 0.0024, 0.0575, 0.0282, 0.0838, 0.0065, 0.0025, 0.0485, 0.0425,
        0.0318, 0.0143, 0.0296, 0.0172, 0.0925, 0.0184, 0.0873, 0.1335, 0.0928,
        0.0224], device='cuda:0', grad_fn=<DivBackward0>)
tensor([0.2191, 0.0034, 0.0743, 0.0337, 0.0619, 0.0086, 0.0043, 0.0714, 0.0393,
        0.0239, 0.0204, 0.0331, 0.0176, 0.0987, 0.0203, 0.0626, 0.1109, 0.0655,
        0.0309], device='cuda:0', grad_fn=<DivBackward0>)
tensor([0.1976, 0.0017, 0.0611, 0.0296, 0.0846, 0.0062, 0.0023, 0.0670, 0.0317,
        0.0226, 0.0131, 0.0316, 0.0160, 0.1072, 0.0209, 0.0982, 0.0788, 0.1056,
        0.0244], device='cuda:0', grad_fn=<DivBackward0>)
tensor([0.1585, 0.0028, 0.0585, 0.0366, 0.0718, 0.0134, 0.0037, 0.1042, 0.0322,
        0.0237, 0.0129, 0.0434, 0.0156, 0.0882, 0.0170, 0.0716, 0.0974, 0.1290,
        0.0196], device='cuda:0', grad_fn=<DivBackward0>)
tensor([0.1407, 0.0046, 0.0585, 0.0333, 0.0705, 0.0128, 0.0066, 0.0704, 0.0533,
        0.0388, 0.0207, 0.0293, 0.0262, 0.0868, 