In [None]:
import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings
from tqdm import tqdm
import numpy as np
import faiss

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models


import pcl.loader
import pcl.dense_builder

In [3]:
# Names of every public, lowercase callable in torchvision.models (i.e. the
# model constructors such as "resnet50"), sorted alphabetically.
model_names = sorted(
    name
    for name, obj in models.__dict__.items()
    if name.islower() and not name.startswith("__") and callable(obj)
)

In [4]:
class Args:
    """Plain-attribute stand-in for parsed command-line arguments.

    Every hyper-parameter the notebook reads lives here so it can be tuned in
    one place.
    """

    start_epoch = 0
    epochs = 100
    warmup_epoch = 5
    # NOTE(review): not read by any cell in this notebook; the cluster counts
    # actually used are num_cluster_global / num_cluster_dense below.
    num_cluster = [10, 20, 30]
    low_dim = 128          # passed to DenseCL as `dim`
    gpu = 0                # device index used for the loss module
    exp_dir = './test'
    num_cluster_global = "10,20,30"   # comma-separated; split into a list later
    num_cluster_dense = "5,15,20"     # comma-separated; split into a list later
    arch = "resnet50"
    pcl_r = 20             # passed to DenseCL as `r`
    moco_m = 0.999         # passed to DenseCL as `m` (key-encoder momentum)
    temperature = 0.2
    mlp = True
    weight_decay = 1e-4
    lr = 0.03
    # SGD momentum. This was 0.999, an accidental copy of moco_m; the MoCo/PCL
    # reference uses 0.9. A value of 0.999 multiplies the effective step size
    # by ~1000 instead of ~10.
    momentum = 0.9
    # Seed for python/numpy/torch RNGs; set to None to leave them unseeded.
    seed = 0


args = Args()

# Seed early so model initialisation and data augmentation are reproducible.
if args.seed is not None:
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

In [5]:
# Turn the comma-separated cluster counts into lists of strings. The
# isinstance guards make the cell idempotent: re-running it no longer calls
# .split() on an already-split list (which raised AttributeError).
if isinstance(args.num_cluster_global, str):
    args.num_cluster_global = args.num_cluster_global.split(',')
if isinstance(args.num_cluster_dense, str):
    args.num_cluster_dense = args.num_cluster_dense.split(',')

# Output directory for this experiment; no error if it already exists.
os.makedirs(args.exp_dir, exist_ok=True)

In [6]:
# Build the DenseCL model around the backbone selected by args.arch. This was
# previously hard-coded to "resnet50", which ignored the config.
# NOTE(review): `loss_lambda` receives args.temperature (0.2). In the reference
# DenseCL implementation, loss_lambda weights the global vs. dense losses and
# is not a temperature. Confirm against pcl.dense_builder.DenseCL.
model = pcl.dense_builder.DenseCL(
    models.__dict__[args.arch],
    head=None,
    dim=args.low_dim,
    r=args.pcl_r,
    m=args.moco_m,
    loss_lambda=args.temperature,
    mlp=args.mlp,
)

In [7]:
# Loss: cross-entropy over [positive | negatives] logits, moved to args.gpu.
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda(args.gpu)

# Plain SGD over all model parameters, configured from Args.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)

In [8]:
# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest
# one per input shape (faster for fixed-size inputs, but not deterministic).
cudnn.benchmark = True

In [9]:
# Dataset root. Override it with the PCL_DATA_DIR environment variable instead
# of editing this machine-specific default path.
args.data = os.environ.get(
    'PCL_DATA_DIR', '/home/siyi/DefectDetection/CL_for_Real/PCLv1/data/hg'
)
traindir = os.path.join(args.data, 'train')
# Per-channel mean / std used to normalise input tensors (ImageNet statistics).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

In [35]:
args.aug_plus = True

# Only the colour/blur stage differs between the two recipes; the crop, flip,
# tensor conversion and normalisation are shared.
if args.aug_plus:
    # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
    color_ops = [
        transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8  # not strengthened
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([pcl.loader.GaussianBlur([0.1, 2.0])], p=0.5),
    ]
else:
    # MoCo v1's aug: same as InstDisc https://arxiv.org/abs/1805.01978
    color_ops = [
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    ]

augmentation = [
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    *color_ops,
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]

# Deterministic center-crop pipeline for the evaluation loader.
eval_augmentation = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

# Training samples get two independent augmented views (TwoCropsTransform);
# evaluation samples get the single center crop.
train_transform = pcl.loader.TwoCropsTransform(transforms.Compose(augmentation))
train_dataset = pcl.loader.ImageFolderInstance(traindir, train_transform)
eval_dataset = pcl.loader.ImageFolderInstance(traindir, eval_augmentation)

In [36]:
# Single-process setup: no distributed samplers.
train_sampler = eval_sampler = None
args.batch_size = 32
args.workers = 1

# Shuffled training batches; the ragged last batch is dropped.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=train_sampler is None,
    sampler=train_sampler,
    num_workers=args.workers,
    pin_memory=True,
    drop_last=True,
)

# Center-cropped images in fixed order; 5x batch size since no augmentation
# or gradients are involved.
eval_loader = torch.utils.data.DataLoader(
    dataset=eval_dataset,
    batch_size=5 * args.batch_size,
    shuffle=False,
    sampler=eval_sampler,
    num_workers=args.workers,
    pin_memory=True,
)

In [37]:
# Cluster assignments fed to the model; both stay None because no clustering
# is run in this notebook.
cluster_result = dict(cluster_result_dense=None, cluster_result_global=None)

In [38]:
class AverageMeter(object):
    """Computes and stores the average and current value.

    Args:
        name: label printed by ``__str__``.
        fmt: format spec (e.g. ``':6.3f'``) applied to both the latest value
            and the running average when printing.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples in the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when nothing has been counted yet (e.g. a
        # first update with n=0); the average then keeps its previous value.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        """Render as ``'<name> <val> (<avg>)'`` using ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Prints a tab-separated progress line: prefix + ``[batch/total]``, then each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for ``batch``."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build e.g. ``'[{:3d}/100]'`` so batch numbers pad to the total's width."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))

In [39]:
# One running meter per logged quantity, given as (label, format spec).
batch_time, data_time, losses, acc_inst, acc_proto_global, acc_proto_dense = [
    AverageMeter(label, spec)
    for label, spec in (
        ("Time", ":6.3f"),
        ("Data", ":6.3f"),
        ("Loss", ":.4e"),
        ("Acc@Inst", ":6.2f"),
        ("Acc@Proto_global", ":6.2f"),
        ("Acc@Proto_dense", ":6.2f"),
    )
]

# Progress printer over one pass of train_loader.
progress = ProgressMeter(
    num_batches=len(train_loader),
    meters=[batch_time, data_time, losses, acc_inst, acc_proto_global, acc_proto_dense],
    prefix="Epoch: [{}]".format(0),
)

In [40]:
# Wall-clock timestamp. Not read by any later cell in this notebook
# (presumably left over from a training loop's batch timer).
end = time.time()

In [41]:
# Pull one training batch for the step-by-step forward pass. `images` is later
# indexed as images[0] / images[1] (the two TwoCropsTransform views); `index`
# is presumably the per-sample dataset index from ImageFolderInstance (confirm).
images, index = next(iter(train_loader))

In [70]:
# Rather than calling model.forward(im_q, im_k, cluster_global, cluster_dense,
# index) in one go, the following cells reproduce its steps one at a time.
# The two augmented crops become the query and key inputs.
im_q, im_k = images[0], images[1]
# No clustering has been run, so both cluster inputs are None.
cluster_global = cluster_dense = None

In [71]:
# Make both crops contiguous in memory before feeding the encoders.
im_q, im_k = (view.contiguous() for view in (im_q, im_k))

In [72]:
q_b = model.encoderq_features(im_q)  # backbone features (query side; presumably NxCxHxW, since a later cell flattens dims 2+)

In [73]:
# Query neck: global embedding q, dense grid q_grid and dense vector q2.
# Shapes assumed NxC / NxCxS^2 / NxC by analogy with the key-side comment; confirm in pcl.dense_builder.
q, q_grid, q2 = model.encoder_q[1](q_b)

In [74]:
# Keep the batch and channel dims and merge all spatial dims: NxCxS^2.
q_b = q_b.view(*q_b.shape[:2], -1)

In [75]:
# L2-normalise each query tensor along dim 1:
# q (global), q2 (dense), q_grid (dense grid), q_b (flattened backbone features).
q, q2, q_grid, q_b = (
    nn.functional.normalize(t, dim=1) for t in (q, q2, q_grid, q_b)
)

In [76]:
# Momentum update of the key encoder (per the method name; confirm in
# pcl.dense_builder). Mutates model state, so every re-run applies another update.
model._momentum_update_key_encoder()

In [77]:
# Key-side backbone features. Keys serve only as targets, so compute them
# without autograd, as the reference DenseCL forward does. This avoids building
# an autograd graph for the key encoder.
with torch.no_grad():
    k_b = model.encoderk_features(im_k)

In [78]:
# Key neck, flattening and normalisation, without autograd (keys are targets
# only), mirroring the query-side cells.
with torch.no_grad():
    k, k_grid, k2 = model.encoder_k[1](k_b)  # keys: NxC; NxCxS^2
    # Merge spatial dims of the backbone features: NxCxS^2.
    k_b = k_b.view(k_b.size(0), k_b.size(1), -1)

    k = nn.functional.normalize(k, dim=1)    # global
    k2 = nn.functional.normalize(k2, dim=1)  # dense
    k_grid = nn.functional.normalize(k_grid, dim=1)
    k_b = nn.functional.normalize(k_b, dim=1)

In [79]:
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # positive sample: per-row dot product of q with its own key -> Nx1

In [80]:
l_neg = torch.einsum('nc,ck->nk', [q, model.queue.clone().detach()]) # negative samples: q against every queue column -> NxK

In [81]:
# Select r prototype negative samples (one column per queue entry; r = args.pcl_r)
l_neg.shape

torch.Size([32, 20])

In [82]:
# Cosine similarity between every query location and every key location: NxS^2xS^2.
backbone_sim_matrix = q_b.transpose(1, 2) @ k_b

In [83]:
# For each query location, the index of the most similar key location (NxS^2).
# Bug fix: these argmax *indices* were previously stored in densecl_sim_q and
# used directly as positive logits. The matched key vectors must be gathered
# and dotted with the query grid instead.
densecl_sim_ind = backbone_sim_matrix.max(dim=2)[1]   # NxS^2
# Gather the matched key-grid vectors: NxCxS^2.
indexed_k_grid = torch.gather(
    k_grid, 2, densecl_sim_ind.unsqueeze(1).expand(-1, k_grid.size(1), -1)
)
# Similarity between each query location and its matched key location: NxS^2.
densecl_sim_q = (q_grid * indexed_k_grid).sum(1)

In [84]:
# Flatten densecl_sim_q (NxS^2) into one positive-logit column per query location.
# NOTE(review): densecl_sim_q must hold similarity values here, not argmax indices.
l_pos_dense = densecl_sim_q.view(-1).unsqueeze(-1)  # NS^2x1

In [85]:
# Flatten the query grid NxCxS^2 -> (N*S^2)xC into a new name instead of
# overwriting q_grid, so this cell is re-runnable and q_grid keeps its
# NxCxS^2 shape for any other cell that needs it.
q_grid_flat = q_grid.permute(0, 2, 1).reshape(-1, q_grid.size(1))
# Dense negative logits of every query location against the dense queue: (N*S^2)xK.
l_neg_dense = torch.einsum('nc,ck->nk', [q_grid_flat, model.queue2.clone().detach()])

In [86]:
# Positive logit in column 0 followed by the negatives:
# global Nx(1+K), dense (N*S^2)x(1+K). Scaled by the configured temperature
# instead of a hard-coded 0.2.
logits_global = torch.cat([l_pos, l_neg], dim=1) / args.temperature
logits_dense = torch.cat([l_pos_dense, l_neg_dense], dim=1) / args.temperature

In [87]:
# Every row's positive sits in column 0, so all global targets are class 0.
labels_global = torch.zeros(len(logits_global), dtype=torch.long)

In [91]:
# Every dense row's positive sits in column 0, so all dense targets are class 0.
labels_dense = torch.zeros(len(logits_dense), dtype=torch.long)

In [94]:
l_pos_dense.shape   # one positive sample per row; the pcl_r negative prototypes live in l_neg_dense

torch.Size([1568, 1])

In [96]:
logits_dense.shape  # (N*S^2) x (1 + K): positive column followed by K negatives

torch.Size([1568, 21])