In [3]:
import argparse
import build_index
import logging
import os
import random
import warnings

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from semilearn.algorithms import get_algorithm, name2alg
from semilearn.core.utils import ALGORITHMS
from semilearn.core.utils import (
    TBLog,
    count_parameters,
    get_logger,
    get_net_builder,
    get_port,
    over_write_args_from_file,
    send_model_cuda,
)
from semilearn.imb_algorithms import get_imb_algorithm, name2imbalg
from train import get_config

# import clarabel
# import qpsolvers

import sys
import types

from typing import Iterator, Optional
import torch
from torch.utils.data.dataloader import _BaseDataLoaderIter
from torch.utils.data import Dataset, _DatasetKind
from torch.utils.data.distributed import DistributedSampler
from operator import itemgetter
import torch.distributed as dist
from scipy.sparse import csc_matrix
from build_index import *
from build_index import confidence_convergence
import build_index
from typing import Iterable

import contextlib
import numpy as np
from inspect import signature
from collections import OrderedDict
from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, top_k_accuracy_score

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

from semilearn.core.hooks import Hook, get_priority, CheckpointHook, TimerHook, LoggingHook, DistSamplerSeedHook, ParamUpdateHook, EvaluationHook, EMAHook, WANDBHook, AimHook
from semilearn.core.utils import get_dataset, get_data_loader, get_optimizer, get_cosine_schedule_with_warmup, Bn_Controller
from semilearn.core.criterions import CELoss, ConsistencyLoss
from tqdm import tqdm

import importlib
importlib.reload(build_index)

sys.argv=['inference.ipynb','--c','config/classic_cv/fixmatch/fixmatch_cifar100_400_0.yaml']

def get_config():
    from semilearn.algorithms.utils import str2bool

    parser = argparse.ArgumentParser(description="Semi-Supervised Learning (USB)")

    """
    Saving & loading of the model.
    """
    parser.add_argument("--save_dir", type=str, default="./saved_models")
    parser.add_argument("-sn", "--save_name", type=str, default="fixmatch")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--load_path", type=str)
    parser.add_argument("-o", "--overwrite", action="store_true", default=True)
    parser.add_argument(
        "--use_tensorboard",
        action="store_true",
        help="Use tensorboard to plot and save curves",
    )
    parser.add_argument(
        "--use_wandb", action="store_true", help="Use wandb to plot and save curves"
    )
    parser.add_argument(
        "--use_aim", action="store_true", help="Use aim to plot and save curves"
    )

    """
    Training Configuration of FixMatch
    """
    parser.add_argument("--epoch", type=int, default=1)
    parser.add_argument(
        "--num_train_iter",
        type=int,
        default=20,
        help="total number of training iterations",
    )
    parser.add_argument(
        "--num_warmup_iter", type=int, default=0, help="cosine linear warmup iterations"
    )
    parser.add_argument(
        "--num_eval_iter", type=int, default=10, help="evaluation frequency"
    )
    parser.add_argument("--num_log_iter", type=int, default=5, help="logging frequency")
    parser.add_argument("-nl", "--num_labels", type=int, default=400)
    parser.add_argument("-bsz", "--batch_size", type=int, default=8)
    parser.add_argument(
        "--uratio",
        type=int,
        default=1,
        help="the ratio of unlabeled data to labeled data in each mini-batch",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=16,
        help="batch size of evaluation data loader (it does not affect the accuracy)",
    )
    parser.add_argument(
        "--ema_m", type=float, default=0.999, help="ema momentum for eval_model"
    )
    parser.add_argument("--ulb_loss_ratio", type=float, default=1.0)

    """
    Optimizer configurations
    """
    parser.add_argument("--optim", type=str, default="SGD")
    parser.add_argument("--lr", type=float, default=3e-2)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument(
        "--layer_decay",
        type=float,
        default=1.0,
        help="layer-wise learning rate decay, default to 1.0 which means no layer "
        "decay",
    )

    """
    Backbone Net Configurations
    """
    parser.add_argument("--net", type=str, default="wrn_28_2")
    parser.add_argument("--net_from_name", type=str2bool, default=False)
    parser.add_argument("--use_pretrain", default=False, type=str2bool)
    parser.add_argument("--pretrain_path", default="", type=str)

    """
    Algorithms Configurations
    """

    ## core algorithm setting
    parser.add_argument(
        "-alg", "--algorithm", type=str, default="fixmatch", help="ssl algorithm"
    )
    parser.add_argument(
        "--use_cat", type=str2bool, default=True, help="use cat operation in algorithms"
    )
    parser.add_argument(
        "--amp",
        type=str2bool,
        default=False,
        help="use mixed precision training or not",
    )
    parser.add_argument("--clip_grad", type=float, default=0)

    ## imbalance algorithm setting
    parser.add_argument(
        "-imb_alg",
        "--imb_algorithm",
        type=str,
        default=None,
        help="imbalance ssl algorithm",
    )

    """
    Data Configurations
    """

    ## standard setting configurations
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("-ds", "--dataset", type=str, default="cifar10")
    parser.add_argument("-nc", "--num_classes", type=int, default=10)
    parser.add_argument("--train_sampler", type=str, default="RandomSampler")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument(
        "--include_lb_to_ulb",
        type=str2bool,
        default="True",
        help="flag of including labeled data into unlabeled data, default to True",
    )

    ## imbalanced setting arguments
    parser.add_argument(
        "--lb_imb_ratio",
        type=int,
        default=1,
        help="imbalance ratio of labeled data, default to 1",
    )
    parser.add_argument(
        "--ulb_imb_ratio",
        type=int,
        default=1,
        help="imbalance ratio of unlabeled data, default to 1",
    )
    parser.add_argument(
        "--ulb_num_labels",
        type=int,
        default=None,
        help="number of labels for unlabeled data, used for determining the maximum "
        "number of labels in imbalanced setting",
    )

    ## cv dataset arguments
    parser.add_argument("--img_size", type=int, default=32)
    parser.add_argument("--crop_ratio", type=float, default=0.875)

    ## nlp dataset arguments
    parser.add_argument("--max_length", type=int, default=512)

    ## speech dataset algorithms
    parser.add_argument("--max_length_seconds", type=float, default=4.0)
    parser.add_argument("--sample_rate", type=int, default=16000)

    """
    multi-GPUs & Distributed Training
    """

    ## args for distributed training (from https://github.com/pytorch/examples/blob/master/imagenet/main.py)  # noqa: E501
    parser.add_argument(
        "--world-size",
        default=1,
        type=int,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--rank", default=0, type=int, help="**node rank** for distributed training"
    )
    parser.add_argument(
        "-du",
        "--dist-url",
        default="tcp://127.0.0.1:11111",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--seed", default=1, type=int, help="seed for initializing training. "
    )
    parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
    parser.add_argument(
        "--multiprocessing-distributed",
        type=str2bool,
        default=False,
        help="Use multi-processing distributed training to launch "
        "N processes per node, which has N GPUs. This is the "
        "fastest way to use PyTorch for either single node or "
        "multi node data parallel training",
    )
    # config file
    parser.add_argument("--c", type=str, default="")
    
    parser.add_argument("--use-prefetcher", action="store_true", default=False)

    # add algorithm specific parameters
    args = parser.parse_args()
    over_write_args_from_file(args, args.c)
    for argument in name2alg[args.algorithm].get_argument():
        parser.add_argument(
            argument.name,
            type=argument.type,
            default=argument.default,
            help=argument.help,
        )

    # add imbalanced algorithm specific parameters
    args = parser.parse_args()
    over_write_args_from_file(args, args.c)
    if args.imb_algorithm is not None:
        for argument in name2imbalg[args.imb_algorithm].get_argument():
            parser.add_argument(
                argument.name,
                type=argument.type,
                default=argument.default,
                help=argument.help,
            )
    args = parser.parse_args('--c config/classic_cv/fixmatch/fixmatch_cifar10_40_0.yaml'.split())
    over_write_args_from_file(args, args.c)
    return args

args = get_config()
port = get_port()
args.dist_url = "tcp://127.0.0.1:" + str(port)
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
ngpus_per_node = torch.cuda.device_count()
args.distributed = False
args.gpu=0

_net_builder = get_net_builder(args.net, args.net_from_name)

model = ALGORITHMS['fixmatch']( # name2alg[args.algorithm](
            args=args,
            net_builder=_net_builder
        )

model.model = send_model_cuda(args, model.model)
model.ema_model = send_model_cuda(args, model.ema_model)
# checkpoint = torch.load('latest_model.pth', map_location='cpu')


def load_model(self, load_path):
    """
    load model and specified parameters for resume
    """
    checkpoint = torch.load(load_path, map_location='cpu')
    if not self.distributed and next(iter(checkpoint['model'].keys())).startswith('module.'):
        checkpoint['model'] = {k[7:]:v for k,v in checkpoint['model'].items()}
    if not self.distributed and next(iter(checkpoint['ema_model'].keys())).startswith('module.'):
        checkpoint['ema_model'] = {k[7:]:v for k,v in checkpoint['ema_model'].items()}
    self.model.load_state_dict(checkpoint['model'])
    self.ema_model.load_state_dict(checkpoint['ema_model'])
    self.loss_scaler.load_state_dict(checkpoint['loss_scaler'])
    self.it = checkpoint['it']
    self.start_epoch = checkpoint['epoch']
    self.epoch = self.start_epoch
    self.best_it = checkpoint['best_it']
    self.best_eval_acc = checkpoint['best_eval_acc']
    self.optimizer.load_state_dict(checkpoint['optimizer'])
    if self.scheduler is not None and 'scheduler' in checkpoint:
        self.scheduler.load_state_dict(checkpoint['scheduler'])
    self.print_fn('Model loaded')
    return checkpoint

model.load_model = types.MethodType(load_model, model)
model.load_model('continual_fixmatch_cifar10_add500_to_500_1400/model_best.pth')
model.gpu=0

transform_val = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225],)
    ])
cifar10_dataset = torchvision.datasets.CIFAR10('./data',train=True, download=True,transform=transform_val)
cifar10_dataloader = DataLoader(cifar10_dataset, batch_size=100, shuffle=False)

ann_index = build_index.Singlemodal_index(dim=128,n=50000,submodular_k=16,num_classes=10)

cifar10_dataloader = DataLoader(cifar10_dataset, batch_size=100, shuffle=False)
model.cifar_train_loader = cifar10_dataloader

def build_ann_index(self, eval_dest='train_ulb', out_key='logits', return_logits=False):
    """
    evaluation function
    """
    self.model.eval()
    # self.ema.apply_shadow()

    # eval_loader = self.loader_dict[eval_dest]
    data_loader = self.cifar_train_loader
    total_loss = 0.0
    total_num = 0.0
    y_true = []
    y_pred = []
    y_probs = []
    y_logits = []
    confidence_list = []
    with torch.no_grad():
        for x,y in tqdm(data_loader):
            if isinstance(x, dict):
                x = {k: v.cuda(self.gpu) for k, v in x.items()}
            else:
                x = x.cuda(self.gpu)
            y = y.cuda(self.gpu)

            num_batch = y.shape[0]
            total_num += num_batch

            # out = self.model(x)
            out = self.ema_model(x)
            
            logits = out[out_key]
            prob = torch.softmax(logits, dim=-1)
            feat = out['feat'].detach().cpu()
            pred = torch.max(logits, dim=-1)[1].cpu()
            conf = (prob.max(dim=-1)[0]-0.1)/0.9
            
            for f,label,confid in zip(feat, pred.detach().cpu(),conf.detach().cpu()):
                ann_index.add_item(build_index.DataPoint(None,f,label,confid))
            loss = F.cross_entropy(logits, y, reduction='mean', ignore_index=-1)
            y_true.extend(y.cpu().tolist())
            y_pred.extend(pred.tolist())
            y_logits.append(logits.cpu().numpy())
            y_probs.extend(prob.cpu().tolist())
            total_loss += loss.item() * num_batch
            
            
            confidence_list.extend(conf.cpu().tolist())
    
    return confidence_list, y_true, y_pred


model.build_ann_index = types.MethodType(build_ann_index,model)

print('building ann index')
conf, y_t, y_p = model.build_ann_index()
print('ann_index built')

# print(len(ann_index.data)

/bin/sh: 1: netstat: not found


Using downloaded and verified file: ./data/cifar10/cifar-10-python.tar.gz
Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10
lb count: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
ulb count: [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
Files already downloaded and verified
unlabeled data number: 50000, labeled data number 40
Create train and test data loaders
[!] data loader keys: dict_keys(['train_lb', 'train_ulb', 'eval'])
Create optimizer and scheduler
Model loaded
Files already downloaded and verified
building ann index


100%|██████████| 500/500 [00:19<00:00, 25.26it/s]

ann_index built





In [4]:
def reannotate_gain(relabel_confidence, use_ann_index=False,num_classes=10):
    #try directly using confidence instead of making it into entropy, so that the cosine space is more reasonable
    gain_list = np.zeros(50000)
    c_list = np.zeros(50000)
    p_list = np.zeros(50000)
    for idx in tqdm(range(50000)):
        if ann_index.data[idx].confidence<relabel_confidence:
            if use_ann_index and ann_index is not None:
                preds_dataset = ann_index.knn_pred(ann_index.data[idx], k=8, skip_one=True)
                preds_model = (ann_index.data[idx].label, ann_index.data[idx].confidence)
                merged_p,c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                p_list[idx] = merged_p

            else:
                c = ann_index.data[idx].confidence
                p_list[idx] = ann_index.data[idx].label

            gain_list[idx] = max(0, relabel_confidence-c)
            c_list[idx] = c
        else:
            c_list[idx] = 1
            p_list[idx] = ann_index.data[idx].label
            continue
    # here return positive gains in ablation
    return gain_list, c_list, p_list

In [5]:
expected_confidence=1
full_gain,c_list, p_list = reannotate_gain(expected_confidence,True)
# print('non zero gain number',len(np.nonzero(gain)[0]))
# print(sorted(gain[np.nonzero(gain)[0]].tolist()))
non_zero_gain_idx = np.nonzero(full_gain)[0]
# emb_list = np.array([ann_index.data[idx].I_feature for idx in non_zero_gain_idx])


100%|██████████| 50000/50000 [00:13<00:00, 3842.79it/s]


In [19]:
ann_index.data[1].I_feature

tensor([ 1.0638e+00,  3.7323e+00,  8.2429e-01,  1.9195e+00,  5.2162e+00,
         2.3690e+00,  7.9272e-01,  2.9340e+00,  6.4577e+00,  1.0876e+00,
         5.3472e+00,  5.3533e+00,  2.4846e+00,  4.8310e-01,  6.4177e-01,
        -1.9481e-01,  8.7279e-01,  3.3311e+00,  6.7676e+00,  4.9254e+00,
         5.3077e-01,  5.9769e+00, -1.9307e-01,  1.4010e+00,  4.0692e+00,
         5.0716e+00,  3.2406e+00,  3.5962e+00,  5.9761e+00, -6.4964e-02,
        -1.0025e-01,  2.1499e+00,  2.2794e-01,  4.3380e-01,  5.5643e-01,
         5.8197e+00,  3.0632e+00,  2.5042e+00,  1.5730e+00,  2.2729e+00,
         2.8092e+00,  5.7527e+00,  4.9005e+00, -2.3075e-01,  3.1336e-01,
         9.3917e-01,  4.0123e+00,  1.8686e-01,  9.9656e-02,  1.3687e-01,
         4.4088e-01,  1.2722e+00,  4.4434e+00,  1.8519e-01,  1.9006e+00,
         1.4828e-01,  2.6330e-01,  7.4683e-01, -1.8756e-01,  2.0648e-03,
         4.6652e-01, -1.4679e-02,  4.1223e+00,  7.6769e-02,  1.6504e-01,
         5.1360e+00, -1.2434e-01,  5.1920e-02, -2.8

In [28]:
emb_list = np.stack([data.I_feature for data in ann_index.data])

# np.random.choice baseline 随机选择

In [8]:
selected_samples = [41044,20892,20660,9228,21975,42229,4271,28845,18464, 22901,
  33517, 1726,26194, 49944, 14070, 16655, 31077, 8224,
  1949, 32060, 45898, 2738, 31227, 11017, 33096, 47157,
  34363, 41449, 28579, 21458, 35040, 22122, 44931, 13857,
  29046, 18920, 4799, 44263, 31903, 46107, 9959, 41442, 
  49523, 29231,  8548, 29083, 38689,  3849, 46568,
  40878, 11766, 15412,  4259,  2803, 32791,  6082, 22287,  5725,
  24817,  4231, 29442, 30134, 19606, 14942, 15330, 31440,  6207,
  49077, 46513, 39905, 30715,  9258, 16038, 21485, 20140,  3587,
  29755,  1499, 27685, 39228, 36052, 16914, 48245, 10971, 10512,
  13803, 44581, 38286, 23389, 44788, 32732, 23343, 22460, 10242,
  33312, 10950,  8485,  5996, 36392, 27392,  3837, 19135, 20325,
  41041,  7853,  1801, 45819, 10857,  9842, 24662,  7865, 45931,
  12153,  2130, 36083,  2787, 10875, 46297,  7997,  5749, 32685,
  38916, 37621, 29795,   602, 40366, 44782, 43366, 42862, 12508,
  13434,  6037, 10502, 17404,  3763,   901, 17712, 23044,  9820,
  11166,  8802, 21841, 33428, 14919, 18167, 35318, 23233, 45854,
  25118, 11699, 19154, 35265, 14034,  6319,  9008,  2457, 35593,
  19130, 24565, 15949, 24842, 24628, 22893, 27136, 20139, 47160,
  3989, 39318, 38517, 26081, 10542, 37613, 18790, 46923, 21951,
  16522, 15257, 40719, 29971, 24130, 38684, 18617, 42401, 47475,
  10495, 12689, 37567, 41550, 25504, 23859, 27259,  5347, 12156,
  12687, 34461, 35151, 12966,  3932, 35635, 49544, 49614, 28315,
  16733, 23413, 24100, 23012, 10115, 35044, 20920,  8395,
  17919, 25428, 36692, 42365, 37983, 43915, 35584, 39211, 24859,
  2387, 24174, 18450,  1124, 14161,   712, 13905, 44516,    17,
    717, 35109, 23261, 29265, 31161, 12147, 21487, 13825, 12821,
  34197, 29087, 25207, 34428, 26890, 36511,  9841, 14820, 34345,
  25787, 46899, 49953, 37514, 49492, 22720,  3546, 23788, 26008,
  43546,  7099, 36425, 16157, 45619, 43576,  8080,  1793, 26559,
  43730, 32558, 11873, 45746,  6324,  1282, 20514, 30735, 21926,
  7223, 32322, 19857, 21787, 41892, 32818, 45330,  4319, 25997,
  5108, 46541, 32717, 12993, 36738, 40846,  3679, 23128, 34754,
  36645, 37038, 25934, 33632, 13307, 13687,  1705,  5581, 24779,
  34218,  8760, 11528, 33463, 36518,  5744, 37767, 36870, 37369,
  27118,  5147, 46707,  4025, 34626, 36554, 31823, 11730, 30660,
  13618, 13673,  5867, 39220, 27540, 25311, 29196, 13733, 21211,
  45467, 47194, 33714, 36138, 21525, 39898, 44532, 27173,  8173,
  14289, 30836, 14244, 46245, 14929, 41478, 37310, 13012, 11663,
  41610, 15224, 14171, 26979, 20117, 31887, 42219, 21533, 47591,
  27560, 14838, 28834, 31030, 36292, 38445, 47933, 36403, 34895,
  19738, 49117, 43793, 26911, 22453, 15916,  2867, 29451, 15411,
  27144,   283, 39718, 46089, 12332, 29446, 40954, 44029, 10849,
  18222,  8918, 38528, 44717, 29198, 19712, 19649, 41377, 46475,
  6343,  7207, 47325, 44572, 18824, 27167, 26907, 35921, 33202,
  9803, 40659,  4220, 46650,  3454, 14614, 32216, 41804, 46570,
  26860,  1630, 34189, 47534, 34015, 47632, 44802, 48796,  7533,
  49666, 18449, 29025,  9790,  1744, 31337, 32956, 10482,  9798,
  38971, 19868, 10518, 22218, 24337, 40973, 41805, 30507, 26823,
  21753, 42999, 27753, 30602, 21739, 40389, 14692, 13089,  3738,
  3125, 10720, 19637,  3902, 48878, 27743, 41488, 32826, 22923,
  42881, 36905, 17003,  8347, 32055, 38025, 47892]

In [None]:
full_gain[selected_samples]=0
num = 500
return_list = np.random.choice(list(range(50000)), num,replace=False, p=full_gain/sum(full_gain),replace=False)
print(return_list)

# Dynamic Selection Process with Neighbour-Updated

In [18]:
# ann作为初始集：
import time
count = 0
labeled_set = set(selected_samples)
return_list = set()
gain = full_gain.copy()
sum_gain = sum(gain)
start = time.time()
while count<500 and sum_gain>1:
    # print(sum(gain),sum_gain)
    idx = np.random.choice(50000,p=gain/sum(gain))
    sum_gain -= gain[idx]
    gain[idx] = 0
    if idx in return_list or idx in labeled_set: 
        continue
    
    relabel_y = y_t[idx]
    ann_index.data[idx].label = relabel_y
    ann_index.data[idx].confidence = 1
    return_list.add(idx)
    
    I_near_labels, I_near_distances = ann_index.k_nearest_neighbour_I(ann_index.data[idx], 8, skip_one=True)
    selected_ids = I_near_labels[I_near_distances<=0.1]
    sim = 1-I_near_distances[I_near_distances<=0.1]
    classes = np.array([ann_index.data[idx].label for idx in selected_ids])
    # confidences = np.array([ann_index.data[idx].confidence for idx in selected_ids])
    
    for idx_neigh,s, cls in zip(selected_ids,sim,classes):
        if cls==relabel_y:
            if s>0.85:
                preds_dataset = ann_index.knn_pred(ann_index.data[idx_neigh], k=8, skip_one=True)
                preds_model = ann_index.data[idx_neigh].label,ann_index.data[idx_neigh].confidence
                new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                new_gain = max(0, 1-new_c)
                sum_gain += (new_gain-gain[idx_neigh])
                gain[idx_neigh] = new_gain
        else:
            if s>0.85:
                preds_dataset = ann_index.knn_pred(ann_index.data[idx_neigh], k=8, skip_one=True)
                preds_model = ann_index.data[idx_neigh].label,ann_index.data[idx_neigh].confidence
                new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                new_gain = max(gain[idx_neigh], 1-new_c)
                sum_gain += (new_gain-gain[idx_neigh])
                gain[idx_neigh] = new_gain
                # print('found conflicting neighbour!')
    count+=1
end = time.time()
print(end-start)

1.727950096130371


In [11]:
print(return_list)

{34818, 8196, 16392, 34835, 2069, 22552, 40989, 24605, 2078, 30753, 26660, 38, 6184, 28712, 22576, 24627, 36922, 4157, 45118, 49218, 26692, 24645, 34889, 20558, 85, 41056, 32864, 4198, 39021, 39024, 20593, 34928, 2162, 30838, 30839, 22648, 20601, 10361, 41083, 39037, 24704, 30848, 24708, 16519, 141, 32923, 49307, 28835, 39092, 12469, 39100, 41148, 39104, 45251, 39112, 22729, 22731, 20690, 28883, 35026, 47322, 8410, 232, 47349, 33016, 37117, 30976, 22785, 35072, 259, 49419, 12563, 16662, 49440, 45348, 24877, 18739, 16694, 35129, 10554, 43328, 12616, 33099, 49484, 6490, 10587, 24924, 4456, 12649, 29040, 20849, 49521, 20850, 10620, 33152, 47492, 39301, 14724, 27021, 401, 39325, 37278, 22943, 39329, 4519, 25000, 31144, 14774, 14775, 39355, 35266, 4547, 10694, 39366, 33223, 14800, 39378, 41427, 22996, 23003, 41436, 16859, 49632, 33250, 27107, 43495, 35311, 45577, 33313, 8738, 12837, 45606, 18986, 10798, 49710, 39474, 33333, 25142, 39477, 12853, 19007, 14918, 6729, 23114, 31308, 599, 6747, 4

## Dynamic Selection with larger batch step

In [21]:
# ann作为初始集：
import time
selection_batchsize = 100
count = 0
labeled_set = set(selected_samples)
return_list = set()
gain = full_gain.copy()
sum_gain = sum(gain)
start = time.time()
target_count = 500
while count<target_count and sum_gain>1:
    
    # print(sum(gain),sum_gain)
    
    idxs = np.random.choice(50000,min(selection_batchsize,target_count-count),p=gain/sum(gain),replace=False)
    keep = []
    for idx in idxs:
        sum_gain -= gain[idx]
        gain[idx] = 0
        if idx in return_list or idx in labeled_set: 
            continue
        else:
            keep.append(idx)
            
    
    for idx in keep:
        relabel_y = y_t[idx]
        ann_index.data[idx].label = relabel_y
        ann_index.data[idx].confidence = 1
        return_list.add(idx)

        I_near_labels, I_near_distances = ann_index.k_nearest_neighbour_I(ann_index.data[idx], 8, skip_one=True)
        selected_ids = I_near_labels[I_near_distances<=0.1]
        sim = 1-I_near_distances[I_near_distances<=0.1]
        classes = np.array([ann_index.data[idx].label for idx in selected_ids])
        # confidences = np.array([ann_index.data[idx].confidence for idx in selected_ids])

        for idx_neigh,s, cls in zip(selected_ids,sim,classes):
            if cls==relabel_y:
                if s>0.85:
                    preds_dataset = ann_index.knn_pred(ann_index.data[idx_neigh], k=8, skip_one=True)
                    preds_model = ann_index.data[idx_neigh].label,ann_index.data[idx_neigh].confidence
                    new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                    new_gain = max(0, 1-new_c)
                    sum_gain += (new_gain-gain[idx_neigh])
                    gain[idx_neigh] = new_gain
            else:
                if s>0.85:
                    preds_dataset = ann_index.knn_pred(ann_index.data[idx_neigh], k=8, skip_one=True)
                    preds_model = ann_index.data[idx_neigh].label,ann_index.data[idx_neigh].confidence
                    new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                    new_gain = max(gain[idx_neigh], 1-new_c)
                    sum_gain += (new_gain-gain[idx_neigh])
                    gain[idx_neigh] = new_gain
                    # print('found conflicting neighbour!')
    count+=len(keep)
end = time.time()
print(end-start)

0.809117317199707


In [14]:
print(return_list)

{34818, 8196, 16392, 2069, 22552, 40989, 2078, 30753, 26660, 38, 28712, 6184, 24627, 36922, 4157, 49218, 34889, 20558, 85, 39021, 39024, 10361, 20601, 39037, 30848, 24708, 16519, 141, 24723, 32923, 49307, 28835, 39092, 12469, 41148, 39100, 45251, 39112, 22729, 22731, 20690, 28883, 8410, 47322, 47349, 30976, 259, 49419, 12563, 49440, 45348, 24877, 18739, 16694, 35129, 10554, 12616, 33099, 6490, 10587, 24924, 4456, 12649, 20849, 49521, 33152, 47492, 39301, 27021, 39325, 37278, 22943, 39329, 4519, 25000, 14774, 14775, 35266, 10694, 33223, 39378, 41427, 22996, 41436, 27107, 45606, 18986, 10798, 49710, 33333, 25142, 12853, 19007, 14918, 6729, 23114, 6747, 47710, 35429, 25191, 37485, 41583, 43635, 14965, 6782, 640, 23174, 2695, 25252, 10916, 39591, 12970, 15020, 686, 33459, 8887, 6853, 43720, 41681, 37586, 21208, 2776, 43742, 8931, 29422, 47860, 17141, 31495, 47880, 41737, 19210, 41740, 804, 31527, 9009, 35645, 31550, 6979, 9035, 4940, 25420, 2907, 33638, 27505, 13177, 23417, 45956, 5000, 91

## Further add gain of distance to labeled and high confidence labels

In [None]:
average_k = 4 # hyp to control neighbour number to estimate the distance gain
labeled_set = set(selected_samples)
labeled_index = build_index.Singlemodal_index(dim=128,n=50000,submodular_k=8,num_classes=10)
conf = np.array(conf)
high_conf_set = np.where(conf>0.95)[0]
cleaner_set = set(high_conf_set).union(labeled_set)
for id in cleaner_set:
    labeled_index.add_item(build_index.DataPoint(None,emb_list[id],y_t[id],conf[id]))
dis_gain = np.zeros(50000)
for id in (set(range(50000))-cleaner_set):
    l,dis = labeled_index.k_nearest_neighbour_I(emb_list[id],average_k)
    dis_gain[id] = np.mean(dis)
dis_gain = dis_gain/(np.max(dis_gain)+1e-9) # normlize the impact of this gain

In [None]:
# ann作为初始集：
lamb = 1 # the hyp to control balance of distance gain and confidence gain. Tune between 0.0 and 1
count = 0
return_list = set()
gain = full_gain+lamb * dis_gain
while count<500 and sum(gain)>1:
    
    idx = np.random.choice(list(range(50000)),p=gain/sum(gain),replace=False)
    gain[idx] = 0
    
    if idx in return_list or idx in labeled_set: continue
    
    relabel_y = y_t[idx]
    ann_index.data[idx].label = relabel_y
    ann_index.data[idx].confidence = 1
    return_list.add(idx)
    
    I_near_labels, I_near_distances = ann_index.k_nearest_neighbour_I(ann_index.data[idx], 8, skip_one=True)
    selected_ids = I_near_labels[I_near_distances<=0.1]
    sim = 1-I_near_distances[I_near_distances<=0.1]
    classes = np.array([ann_index.data[idx].label for idx in selected_ids])
    # confidences = np.array([ann_index.data[idx].confidence for idx in selected_ids])
    
    for idx_neigh,s, cls in zip(selected_ids,sim,classes):
        if cls==relabel_y:
            if s>0.85:
                preds_dataset = ann_index.knn_pred(ann_index.data[idx_neigh], k=8, skip_one=True)
                preds_model = ann_index.data[idx_neigh].label,ann_index.data[idx_neigh].confidence
                new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                gain[idx_neigh] = max(0, 1-new_c)
        else:
            if s>0.85:
                preds_dataset = ann_index.knn_pred(ann_index.data[idx_neigh], k=8, skip_one=True)
                preds_model = ann_index.data[idx_neigh].label,ann_index.data[idx_neigh].confidence
                new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                gain[idx_neigh] = max(gain[idx_neigh], 1-new_c)
                # print('found conflicting neighbour!')
    count+=1

## New version of adding gain of distance to labeled and high confidence labels (For extending semi-supervised unlabeled data)

Equation should be: dis_to_labeled + dis_to_selected_unlabeled. that is to say, once we get a batch of midpoints, then we uniformly expand the boundarys. One way is to add selected_unlabeled into distance index and estimate the gain.

Another term: for samples higher than 0.5 but lower than 0.9 acc, they are pesudo labeled actually. We may need these samples for middle stage learning.

In [12]:
emb_list[0]

-0.016977593

In [30]:
average_k = 4 # hyp to control neighbour number to estimate the distance gain
labeled_set = set(selected_samples) # change it to 10000
labeled_index = build_index.Singlemodal_index(dim=128,n=50000,submodular_k=8,num_classes=10)
conf = np.array(conf)
high_conf_set = np.where(conf>0.9)[0]
cleaner_set = set(high_conf_set).union(labeled_set)
for id in cleaner_set:
    labeled_index.add_item(build_index.DataPoint(None,emb_list[id],None,0)) # 扩展集数据的embeding和对应index注意处理，可以拼接起来
dis_gain = np.zeros(50000)
for id in (set(range(50000))-cleaner_set):
    l,dis = labeled_index.k_nearest_neighbour_I(emb_list[id],average_k)
    dis_gain[id] = np.mean(dis)
dis_gain = dis_gain/(np.max(dis_gain)+1e-9) # normlize the impact of this gain

In [34]:
import time
# ann作为初始集：
lamb = 1 # the hyp to control balance of distance gain and confidence gain. For extended dataset, 
         # we do not need to tune this as previous experiments show that lamb=1 is more stable for semi-supervised setting.
gamma = 1 # set 0 or 1, lets see whether to turn off this
count = 0
return_list = set()
gain = full_gain.copy()
selection_batchsize = 100
### For semi-supervised learning with threshold 0.5, this line emphasize those semi-labeled data. 
gain[gain>0.5] = 0 
### Try using this line and without this line. NOTE: For extending supervised data, comment it out.

gain = gamma * gain + lamb * dis_gain
sum_gain = sum(gain)
start = time.time()
target_count = 500
while count<target_count and sum_gain>1:
    # print(sum(gain),sum_gain)
    idxs = np.random.choice(50000,min(selection_batchsize,target_count-count),p=gain/sum(gain),replace=False)
    keep = []
    for idx in idxs:
        sum_gain -= gain[idx]
        gain[idx] = 0
        if idx in return_list or idx in labeled_set: 
            continue
        else:
            keep.append(idx)
            
    
    for idx in keep:
        relabel_y = y_t[idx]
        ann_index.data[idx].label = relabel_y
        ann_index.data[idx].confidence = 1
        return_list.add(idx)

        I_near_labels, I_near_distances = ann_index.k_nearest_neighbour_I(ann_index.data[idx], 8, skip_one=True)
        selected_ids = I_near_labels[I_near_distances<=0.15]
        sim = 1-I_near_distances[I_near_distances<=0.15]
        # classes = np.array([ann_index.data[idx].label for idx in selected_ids])
        # confidences = np.array([ann_index.data[idx].confidence for idx in selected_ids])

        labeled_index.add_item(build_index.DataPoint(None,emb_list[id],None,0))
        
        for idx_neigh,s in zip(selected_ids,sim):
            if s>=0.85:
                # preds_dataset = ann_index.knn_pred(ann_index.data[idx_neigh], k=8, skip_one=True)
                # preds_model = ann_index.data[idx_neigh].label,ann_index.data[idx_neigh].confidence
                # new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                l,dis = labeled_index.k_nearest_neighbour_I(emb_list[id],average_k)
                new_dis_gain = np.mean(dis)

                new_gain = max(0, gamma * full_gain[idx_neigh] + lamb * new_dis_gain)

                sum_gain += (new_gain-gain[idx_neigh])
                gain[idx_neigh] = new_gain

    count+=len(keep)
end = time.time()
print(end-start)

1.4113731384277344


## Also dynamic rechecking the bayesian gain

In [None]:
## Also dynamic rechecking the bayesian gain

import time
# ann作为初始集：
lamb = 1 # the hyp to control balance of distance gain and confidence gain. For extended dataset, 
         # we do not need to tune this as previous experiments show that lamb=1 is more stable for semi-supervised setting.
gamma = 1 # set 0 or 1, lets see whether to turn off this
count = 0
return_list = set()
gain = full_gain.copy()
selection_batchsize = 100
### For semi-supervised learning with threshold 0.5, this line emphasize those semi-labeled data. 
gain[gain>0.5] = 0 
### Try using this line and without this line. NOTE: For extending supervised data, comment it out.

gain = gamma * gain + lamb * dis_gain
sum_gain = sum(gain)
start = time.time()
target_count = 500
while count<target_count:
    # print(sum(gain),sum_gain)
    idxs = np.random.choice(50000,min(selection_batchsize,target_count-count),p=gain/sum(gain),replace=False)
    keep = []
    for idx in idxs:
        sum_gain -= gain[idx]
        gain[idx] = 0
        if idx in return_list or idx in labeled_set: 
            continue
        else:
            keep.append(idx)
            
    
    for idx in keep:
        relabel_y = y_t[idx]
        ann_index.data[idx].label = relabel_y
        ann_index.data[idx].confidence = 1
        return_list.add(idx)

        I_near_labels, I_near_distances = ann_index.k_nearest_neighbour_I(ann_index.data[idx], 8, skip_one=True)
        selected_ids = I_near_labels[I_near_distances<=0.15]
        sim = 1-I_near_distances[I_near_distances<=0.15]
        # classes = np.array([ann_index.data[idx].label for idx in selected_ids])
        # confidences = np.array([ann_index.data[idx].confidence for idx in selected_ids])

        labeled_index.add_item(build_index.DataPoint(None,emb_list[id],None,0))
        
        for idx_neigh,s in zip(selected_ids,sim):
            if s>=0.85:
                preds_dataset = ann_index.knn_pred(ann_index.data[idx_neigh], k=8, skip_one=True)
                preds_model = ann_index.data[idx_neigh].label,ann_index.data[idx_neigh].confidence
                new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                gain_bayesian = max(0,1-new_c)

                l,dis = labeled_index.k_nearest_neighbour_I(emb_list[id],average_k)
                new_dis_gain = np.mean(dis)

                new_gain = max(0, gamma * gain_bayesian + lamb * new_dis_gain)

                sum_gain += (new_gain-gain[idx_neigh])
                gain[idx_neigh] = new_gain

    count+=len(keep)
end = time.time()
print(end-start)