# Info-Coevolution
In this notebook, we will go through the usage of Info-Coevolution in classification tasks. 

Info-Coevolution uses HNSW as its ANN (approximate nearest neighbour) search engine, although other ANN backends are also possible.

## Loading

In [None]:
import build_index # implementation of main code in build_index (data structure and Bayesian Fusion)
from build_index import confidence_convergence
import numpy as np
import torch
import types
from tqdm import tqdm
# import your other dependencies

# --- User configuration: replace every TODO placeholder before running this cell. ---
dataset = None #<TODO:load_your_dataset_here> # the mixed labeled and unlabeled data
data_loader = None #<TODO:init_your_dataloader_here> # iterated below as (x, y) batches
model = None #<TODO:load_your_model_here. Model should be trained with some initial data (like 1%~10%).>
embedding_dim = None #<TODO:set your model embedding dim here>
data_len = len(dataset) # raises TypeError while `dataset` is still the None placeholder
device = None #<TODO: set your device here>
num_classes = 10 # also used below to rescale the softmax confidence; set it to your task's class count

# Index that receives one DataPoint (embedding, label, confidence) per sample in the next section.
ann_index = build_index.Singlemodal_index(dim=embedding_dim,n=data_len,submodular_k=16,num_classes=num_classes)

## Generate Embeddings
**Note: choose one of (a) and (b) depending on your model and data loading.**

(a) If you want to embed this code into a codebase where the model is wrapped in a class:

In [None]:
def build_ann_index(self):
    """
    Run inference on all data and build up the ANN index with sample embeddings
    and initial labels/pseudo-labels with confidence values.

    ``self`` is the wrapper object this function is bound to; it must expose
    ``self.model`` and ``self.data_loader``. The notebook-level ``device``,
    ``num_classes`` and ``ann_index`` are read as globals, and one
    ``build_index.DataPoint`` per sample is added to ``ann_index``.

    Returns:
        tuple: ``(confidence_list, y_true, y_pred)`` -- three Python lists in
        data-loader order holding, for every sample, the rescaled confidence,
        the label yielded by the data loader, and the predicted class index.
    """
    self.model.eval()
    data_loader = self.data_loader # dataloader with both labeled and unlabeled data
    y_true = []
    y_pred = []
    confidence_list = []
    with torch.no_grad():
        for x,y in tqdm(data_loader):
            if isinstance(x, dict):
                x = {k: v.to(device) for k, v in x.items()}
            else:
                x = x.to(device)

            out = self.model(x) # Here in this sample, out is a dict {'feat': embedding, 'logits': logits}
            logits = out['logits'] # TODO: let your model output logits somewhere and use here
            prob = torch.softmax(logits, dim=-1)
            feat = out['feat'].detach().cpu()
            pred = torch.max(logits, dim=-1)[1].cpu()
            # Rescale the top-1 probability so that a uniform prediction maps to 0
            # and a one-hot prediction maps to 1.
            conf = ((prob.max(dim=-1)[0]-1./num_classes)/(1-1./num_classes)).cpu()

            # The predicted class serves as the initial (pseudo-)label of each point.
            for f,label,confid in zip(feat, pred, conf):
                ann_index.add_item(build_index.DataPoint(None,f,label,confid))

            # Labels are only recorded, never fed to the model, so they stay on the CPU.
            y_true.extend(y.cpu().tolist())
            y_pred.extend(pred.tolist())
            confidence_list.extend(conf.tolist())

    return confidence_list, y_true, y_pred

# Bind the function as a method of `model`, so that inside it `self` is `model`;
# the function reads `self.model` and `self.data_loader` from that object.
model.build_ann_index = types.MethodType(build_ann_index,model)
print('building ann index')
# conf: per-sample confidences, y_t: labels from the data loader, y_p: predicted classes.
conf, y_t, y_p = model.build_ann_index()
print('ann_index built')


(b) **Or**, if your model is not wrapped and can be used directly in a script:

In [None]:
# Script variant: run the bare `model` over `data_loader` and fill `ann_index`
# with one DataPoint (embedding, predicted label, confidence) per sample.
model.eval()
total_loss = 0.0
total_num = 0.0
y_true, y_pred, y_probs, y_logits, confidence_list = [], [], [], [], []
with torch.no_grad():
    for x,y in tqdm(data_loader):
        # Inputs may be a single tensor or a dict of tensors.
        x = {k: v.to(device) for k, v in x.items()} if isinstance(x, dict) else x.to(device)
        y = y.to(device)

        num_batch = y.shape[0]
        total_num += num_batch

        out = model(x) # Here in this sample, out is a dict {'feat': embedding, 'logits': logits}
        logits = out['logits'] # TODO: let your model output logits somewhere and use here
        feat = out['feat'].detach().cpu()
        prob = torch.softmax(logits, dim=-1)
        _, pred = torch.max(logits, dim=-1)
        pred = pred.cpu()
        # Top-1 probability rescaled so that a uniform prediction gives 0 and a one-hot prediction gives 1.
        top_prob = prob.max(dim=-1)[0]
        batch_conf = (top_prob-1./num_classes)/(1-1./num_classes)
        batch_conf_cpu = batch_conf.detach().cpu()

        for emb, pseudo_label, confid in zip(feat, pred.detach().cpu(), batch_conf_cpu):
            ann_index.add_item(build_index.DataPoint(None,emb,pseudo_label,confid))

        # loss = F.cross_entropy(logits, y, reduction='mean', ignore_index=-1)
        # total_loss += loss.item() * num_batch
        y_true.extend(y.cpu().tolist())
        y_pred.extend(pred.tolist())
        y_logits.append(logits.cpu().numpy())
        y_probs.extend(prob.cpu().tolist())
        confidence_list.extend(batch_conf_cpu.tolist())

# Same names as the values returned by the wrapped variant.
conf = confidence_list
y_t = y_true
y_p = y_pred

## Calculate Gain

In [None]:

def reannotate_gain(relabel_confidence, use_ann_index=False):
    """
    Estimate, for every sample in ``ann_index``, the confidence gained by re-annotating it.

    Confidence is used directly instead of being turned into entropy, so that
    the cosine space is more reasonable.

    Args:
        relabel_confidence: confidence a sample is expected to have after
            annotation. Samples whose stored confidence is not below this
            value get zero gain.
        use_ann_index: when True, the stored (label, confidence) of a sample is
            fused with the index's k-NN prediction (k=8) through
            ``confidence_convergence`` before the gain is computed; when False
            the stored label and confidence are used unchanged.

    Returns:
        tuple: ``(gain_list, c_list, p_list)`` -- NumPy arrays of length
        ``data_len`` with the non-negative gain, the confidence used for the
        gain (recorded as 1 for samples that need no re-annotation) and the
        corresponding label for each sample.
    """
    gain_list, c_list, p_list = (np.zeros(data_len) for _ in range(3))
    for idx in tqdm(range(data_len)):
        point = ann_index.data[idx]

        # Confident enough already: nothing to gain, keep the stored label.
        if not point.confidence<relabel_confidence:
            c_list[idx] = 1
            p_list[idx] = point.label
            continue

        if use_ann_index:
            neighbour_pred = ann_index.knn_pred(point, k=8, skip_one=True)
            own_pred = (point.label, point.confidence)
            label, c = confidence_convergence(neighbour_pred,own_pred,conf_decay=True)
        else:
            label, c = point.label, point.confidence

        p_list[idx] = label
        c_list[idx] = c
        gain_list[idx] = max(0, relabel_confidence-c)
    # here return positive gains in ablation
    return gain_list, c_list, p_list

# Confidence a sample is expected to reach once annotated; the gain is max(0, expected_confidence - confidence).
expected_confidence=1
full_gain,c_list, p_list = reannotate_gain(expected_confidence,True)
# print('non zero gain number',len(np.nonzero(full_gain)[0]))  # debugging print
# print(sorted(full_gain[np.nonzero(full_gain)[0]].tolist()))   # debugging print
# Indices of the samples whose estimated gain is non-zero.
non_zero_gain_idx = np.nonzero(full_gain)[0]

## Online Selection with Dynamic Rechecking
We have single label mode and batch mode. In our in-house data setting, batch mode with a larger batch (10k samples per step) works better than small batch (100 and 1k).

In [None]:
# NOTE: fill in `selected_samples` before running. While it is still None, `full_gain[None] = 0`
# zeroes the WHOLE array (NumPy treats None as np.newaxis), and `set(selected_samples)` in the
# selection cells below raises TypeError.
selected_samples = None #<TODO: put your labeled sample index here>
full_gain[selected_samples]=0 # samples that are already labeled have no gain left
target_number = 128117 #<TODO: put your target annotation number here>

### Option 1: Single label mode

In [None]:
import time

# Bookkeeping for the one-sample-at-a-time selection loop.
count = 0
labeled_set = set(selected_samples)
return_list = set()
gain = full_gain.copy()
sum_gain = sum(gain)   # running total, updated incrementally inside the loop
start = time.time()
stop_criterion = 1 # use this to control automatic stop. If estimated remaining total confidence gain < 1, we stop.
while sum_gain>stop_criterion and count<target_number:
    # Draw one candidate with probability proportional to its remaining gain and
    # retire it from the pool, whether or not it ends up being labeled.
    idx = np.random.choice(data_len,p=gain/sum(gain))
    sum_gain -= gain[idx]
    gain[idx] = 0
    already_covered = idx in return_list or idx in labeled_set
    if already_covered:
        continue

    relabel_y = y_t[idx] # TODO: Check your logic: here we use true label for research simulation/verification.
                         # For real usage, substitute api or generate the sample indices for annotation.

    picked = ann_index.data[idx]
    picked.label = relabel_y
    picked.confidence = 1
    return_list.add(idx)

    # Dynamic rechecking: re-estimate the gain of the close neighbours of the new label.
    I_near_labels, I_near_distances = ann_index.k_nearest_neighbour_I(picked, 8, skip_one=True)
    close_enough = I_near_distances<=0.1
    selected_ids = I_near_labels[close_enough]
    sim = 1-I_near_distances[close_enough]
    classes = np.array([ann_index.data[j].label for j in selected_ids])

    for idx_neigh, s, cls in zip(selected_ids, sim, classes):
        agrees = bool(cls==relabel_y)
        if not s>0.85:
            continue
        neighbour = ann_index.data[idx_neigh]
        preds_dataset = ann_index.knn_pred(neighbour, k=8, skip_one=True)
        preds_model = neighbour.label,neighbour.confidence
        new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
        # A neighbour that agrees with the new label may drop to any lower gain;
        # a conflicting neighbour never goes below its current gain.
        floor = 0 if agrees else gain[idx_neigh]
        new_gain = max(floor, 1-new_c)
        sum_gain += (new_gain-gain[idx_neigh])
        gain[idx_neigh] = new_gain
    count+=1
end = time.time()
print(end-start)

### Option 2: Batch Labelling Mode

In [None]:
# Use the ANN index as the initial set:
import time
selection_batchsize = 100 # <TODO: Adjust according to your data.>
count = 0
labeled_set = set(selected_samples)
return_list = set()
gain = full_gain.copy()
sum_gain = gain.sum()   # running total, updated incrementally inside the loop
stop_criterion = 1 # use this to control automatic stop. If estimated remaining total confidence gain < 1, we stop.
start = time.time()
while count<target_number and sum_gain>stop_criterion:
    # np.random.choice(..., replace=False) raises ValueError when asked for more
    # draws than there are entries with non-zero probability, so cap the batch.
    n_draw = min(selection_batchsize, target_number-count, int(np.count_nonzero(gain)))
    if n_draw == 0:
        break

    # Sample over the whole dataset (len(p) must equal the population size),
    # with probability proportional to the remaining gain.
    idxs = np.random.choice(data_len, n_draw, p=gain/gain.sum(), replace=False)

    # Every drawn sample leaves the pool; only those not labeled yet are kept.
    keep = []
    for idx in idxs:
        sum_gain -= gain[idx]
        gain[idx] = 0
        if idx in return_list or idx in labeled_set:
            continue
        keep.append(idx)

    for idx in keep:
        # The label recorded from the data loader simulates the annotation here;
        # for real usage, send `keep` for annotation instead.
        relabel_y = y_t[idx]
        ann_index.data[idx].label = relabel_y
        ann_index.data[idx].confidence = 1
        return_list.add(idx)

        # Dynamic rechecking: re-estimate the gain of the close neighbours of the new label.
        I_near_labels, I_near_distances = ann_index.k_nearest_neighbour_I(ann_index.data[idx], 8, skip_one=True)
        close_enough = I_near_distances<=0.1
        selected_ids = I_near_labels[close_enough]
        sim = 1-I_near_distances[close_enough]
        classes = np.array([ann_index.data[j].label for j in selected_ids])

        for idx_neigh, s, cls in zip(selected_ids, sim, classes):
            if s>0.85:
                preds_dataset = ann_index.knn_pred(ann_index.data[idx_neigh], k=8, skip_one=True)
                preds_model = ann_index.data[idx_neigh].label,ann_index.data[idx_neigh].confidence
                new_p,new_c = confidence_convergence(preds_dataset,preds_model,conf_decay=True)
                if cls==relabel_y:
                    # Agreeing neighbour: its gain may drop to any lower value.
                    new_gain = max(0, 1-new_c)
                else:
                    # Conflicting neighbour: its gain never goes below the current value.
                    new_gain = max(gain[idx_neigh], 1-new_c)
                sum_gain += (new_gain-gain[idx_neigh])
                gain[idx_neigh] = new_gain
    count+=len(keep)
end = time.time()
print(end-start)

### For extending semi-supervised unlabeled data
We can use this part of code when the target number is not achieved and we still want more samples, or when we want to estimate semi-supervised data for annotation.

The equation is: `gain = dis_to_labeled + dis_to_selected_high_conf_unlabeled`. That is to say, once we get a batch of midpoints, we uniformly expand the boundaries. One way to do this is to add the selected unlabeled samples into the distance index and then estimate the gain.

Another point: samples with confidence higher than 0.5 but lower than 0.9 are effectively pseudo-labeled. We may need these samples for middle-stage learning.

In [None]:
average_k = 4 # hyp to control neighbour number to estimate the distance gain
# Embeddings of EVERY sample, so that emb_list[i] is the embedding of dataset index i.
# (It is indexed with dataset indices below, so it must not be restricted to a subset.)
emb_list = np.array([ann_index.data[idx].I_feature for idx in range(data_len)])
labeled_set = set(selected_samples) # change it to 10000
# Separate index holding only the "clean" samples: high-confidence predictions plus labeled data.
labeled_index = build_index.Singlemodal_index(dim=embedding_dim,n=data_len,submodular_k=16,num_classes=num_classes)
conf = np.array(conf)
high_conf_set = np.where(conf>0.9)[0]
cleaner_set = set(high_conf_set.tolist()).union(labeled_set)
for id in cleaner_set:
    # Take care with the embeddings of an extension set and their indices; they can be concatenated.
    labeled_index.add_item(build_index.DataPoint(None,emb_list[id],None,0))
# Distance gain of each remaining sample: mean distance to its `average_k` nearest clean samples.
dis_gain = np.zeros(data_len)
for id in (set(range(data_len))-cleaner_set):
    l,dis = labeled_index.k_nearest_neighbour_I(emb_list[id],average_k)
    dis_gain[id] = np.mean(dis)
dis_gain = dis_gain/(np.max(dis_gain)+1e-9) # normalize the impact of this gain

In [None]:
import time
# Use the ANN index as the initial set:
lamb = 1 # the hyp to control balance of distance gain and confidence gain. For extended dataset,
         # we do not need to tune this as previous experiments show that lamb=1 is more stable for semi-supervised setting.
gamma = 1 # in our tested setting, both being 1 is the best for semi-supervised data.
count = 0
return_list = set()
gain = full_gain.copy()
selection_batchsize = 100
### For semi-supervised learning with threshold 0.5, this line emphasizes the semi-labeled data.
gain[gain>0.5] = 0
### Try with and without this line. NOTE: For extending supervised data, comment it out.

gain = gamma * gain + lamb * dis_gain
sum_gain = gain.sum()   # running total, updated incrementally inside the loop
start = time.time()
target_count = 500
while count<target_count and sum_gain>1:
    # np.random.choice(..., replace=False) raises ValueError when asked for more
    # draws than there are entries with non-zero probability, so cap the batch.
    n_draw = min(selection_batchsize, target_count-count, int(np.count_nonzero(gain)))
    if n_draw == 0:
        break

    # Sample over the whole dataset (len(p) must equal the population size).
    idxs = np.random.choice(data_len, n_draw, p=gain/gain.sum(), replace=False)

    # Every drawn sample leaves the pool; only those not labeled yet are kept.
    keep = []
    for idx in idxs:
        sum_gain -= gain[idx]
        gain[idx] = 0
        if idx in return_list or idx in labeled_set:
            continue
        keep.append(idx)

    for idx in keep:
        relabel_y = y_t[idx] # TODO: check your logic. If doing annotation selection instead of selection,
                             # send the keep list with deduplication for annotation.
        ann_index.data[idx].label = relabel_y
        ann_index.data[idx].confidence = 1
        return_list.add(idx)

        I_near_labels, I_near_distances = ann_index.k_nearest_neighbour_I(ann_index.data[idx], 8, skip_one=True)
        close_enough = I_near_distances<=0.15
        selected_ids = I_near_labels[close_enough]
        sim = 1-I_near_distances[close_enough]

        # The newly selected sample joins the clean index, so the distance gain
        # of its neighbours is measured against it from now on.
        labeled_index.add_item(build_index.DataPoint(None,emb_list[idx],None,0))

        for idx_neigh,s in zip(selected_ids,sim):
            if s>=0.85:
                # Re-estimate the neighbour's own distance to the clean index.
                l,dis = labeled_index.k_nearest_neighbour_I(emb_list[idx_neigh],average_k)
                # NOTE(review): unlike `dis_gain`, this value is not divided by the
                # normalisation constant used when `dis_gain` was built -- confirm this is intended.
                new_dis_gain = np.mean(dis)

                new_gain = max(0, gamma * full_gain[idx_neigh] + lamb * new_dis_gain)

                sum_gain += (new_gain-gain[idx_neigh])
                gain[idx_neigh] = new_gain

    count+=len(keep)
end = time.time()
print(end-start)

# return_list contains the added samples to train.

# return_list contains the added samples to train. 

## Save and Load
The index is pickle-serializable: as long as the model is not updated, you can use pickle to save the index and load it again the next time you need it.

In [None]:
import pickle

# Persist the current index to disk.
with open('my_index.pkl', 'wb') as sink:
    pickle.dump(ann_index, sink)

# Restore it in a later session. Only unpickle files you created yourself:
# pickle.load can execute arbitrary code from an untrusted file.
with open('my_index.pkl', 'rb') as source:
    ann_index = pickle.load(source)