# InfoGrowth Algorithm

In [None]:
import numpy as np
import hnswlib
from numpy import linalg as LA

def cosine_sim(x, y):
    """Return the cosine similarity of the vectors `x` and `y`.

    The dot product is divided by the Euclidean norm of `x` and then by the
    Euclidean norm of `y`. No guard is applied for zero-norm inputs.
    """
    norm_x = LA.norm(x)
    norm_y = LA.norm(y)
    # Divide sequentially (not by the product of the norms) so the
    # floating-point result is unchanged.
    return np.dot(x, y) / norm_x / norm_y

class DataPair:
    """One image/text sample together with its two feature vectors.

    Attributes:
        Image_path: value passed as the image path of the sample.
        Text: caption/text of the sample.
        I_feature: feature vector of the image.
        T_feature: feature vector of the text.
        index: initialised to None; assigned later by code that stores the pair.
    """

    def __init__(self, Image_path, Text, I_feature, T_feature):
        self.Image_path, self.Text = Image_path, Text
        self.I_feature, self.T_feature = I_feature, T_feature
        self.index = None

    def I_sim(self, point):
        """Return the cosine similarity between the image features of self and `point`."""
        return cosine_sim(self.I_feature, point.I_feature)

    def _distance_to(self, points, feature_name):
        """Shared implementation of `I_distance` / `T_distance`.

        `feature_name` is the attribute ('I_feature' or 'T_feature') compared
        on both sides. A single DataPair yields a scalar, a list yields a
        numpy array; any other type raises TypeError.
        """
        # The type is matched by class name, exactly as the public methods
        # always did (an object of a same-named class is accepted).
        if type(points).__name__ == 'DataPair':
            others, single = [points], True
        elif type(points) is list:
            others, single = points, False
        else:
            raise TypeError("data should be list or DataPair!")

        own = getattr(self, feature_name)
        distances = [1. - cosine_sim(own, getattr(p, feature_name)) for p in others]
        return distances[0] if single else np.array(distances)

    def I_distance(self, points):
        """Return 1 - cosine similarity of image features to a DataPair or a list of them."""
        return self._distance_to(points, 'I_feature')

    def T_distance(self, points):
        """Return 1 - cosine similarity of text features to a DataPair or a list of them."""
        return self._distance_to(points, 'T_feature')


class InfoGrowth:
    """Online selection of image/text pairs backed by two HNSW indexes.

    Accepted samples are stored in `self.data`; their image and text features
    are inserted into `I_knn_graph` and `T_knn_graph` (hnswlib, cosine space).
    For every accepted sample a "gain" tuple (mean image-neighbour distance,
    mean text-neighbour distance) is recorded in `self.submodular_gain`.

    Notes carried over from the original author:
      * Newly incoming data should have cleanliness comparable to the
        bootstrapping (seed) set.
      * The current version handles million-scale data on a single machine.
        For larger scale, consider using hyperplanes (vector-space separation)
        to distribute samples across machines; the HNSW index can be modified
        or replaced for better concurrency.
    """

    def __init__(self, initial_points=None, n=3000000, keep_seed=False, submodular_k=4,
                 min_align=0.4, dim=256):
        """Create the indexes and, if given, insert the seed samples.

        Args:
            initial_points: optional list of DataPair seed samples. The list is
                shallow-copied; the DataPair objects themselves get their
                `index` attribute assigned.
            n: capacity (`max_elements`) of both HNSW indexes.
            keep_seed: stored on the instance; when true and seeds are given,
                `seeded_num` records how many seeds there were.
            submodular_k: number of nearest neighbours averaged for the gain.
            min_align: minimum image/text cosine similarity a sample needs to
                be accepted by `process_item` (previously fixed at 0.4).
            dim: feature dimension used when no seed samples are given
                (previously fixed at 256). With seeds, the length of the first
                seed's image feature is used instead.
        """
        self.clusters = None
        self.target_n = n
        self.current_n = 0
        self.keep_seed = keep_seed
        self.dim = dim
        self.submodular_k = submodular_k
        self.min_align = min_align

        if initial_points:
            self.data = initial_points.copy()
            self.current_n = len(initial_points)
            if keep_seed:
                self.seeded_num = len(initial_points)
            self.dim = len(initial_points[0].I_feature)
        else:
            self.data = []
        # Placeholder gains for the seeds; overwritten in precluster().
        self.submodular_gain = [(1, 1)] * len(self.data)

        # Initialize the HNSW indexes. Both use self.dim, i.e. text features
        # are expected to have the same length as image features.
        # allow_replace_deleted must be True: replace_item() inserts with
        # replace_deleted=True, which hnswlib rejects on an index that was
        # initialised with allow_replace_deleted=False.
        self.I_knn_graph = hnswlib.Index(space='cosine', dim=self.dim)
        self.I_knn_graph.init_index(max_elements=n, ef_construction=100, M=48,
                                    allow_replace_deleted=True)
        self.T_knn_graph = hnswlib.Index(space='cosine', dim=self.dim)
        self.T_knn_graph.init_index(max_elements=n, ef_construction=100, M=48,
                                    allow_replace_deleted=True)
        self.precluster(initial_points)

        self.I_knn_graph.set_ef(32)
        self.T_knn_graph.set_ef(32)

    def precluster(self, initial_points):
        """Insert the seed samples held in `self.data` into both indexes.

        Starting from some initial points (the cleaner the better) enables the
        later online selection. Does nothing when `initial_points` is empty or
        None. Each seed's gain is computed against the seeds inserted before
        it, so the very first seed receives the default gain (1., 1.).
        """
        if not initial_points:
            return
        for idx, data in enumerate(self.data):
            data.index = idx

        for idx, data in enumerate(self.data):
            self.submodular_gain[idx] = self.submodular_func(data, True)
            self.I_knn_graph.add_items(data.I_feature, idx)
            self.T_knn_graph.add_items(data.T_feature, idx)

    def submodular_func(self, data, skip_one=False):
        """Return (mean image distance, mean text distance) to the nearest neighbours.

        Up to `submodular_k` neighbours are queried in each index (fewer when
        the image index holds fewer items). Returns (1., 1.) when the image
        index is empty. `skip_one` is accepted for interface compatibility but
        is not used.
        """
        indexed = self.I_knn_graph.get_current_count()
        if indexed == 0:
            return (1., 1.)
        k = min(indexed, self.submodular_k)

        _, I_near_distances = self.k_nearest_neighbour_I(data, k)
        _, T_near_distances = self.k_nearest_neighbour_T(data, k)
        return (np.mean(I_near_distances), np.mean(T_near_distances))

    def align_score(self, data):
        """Return the image/text cosine similarity of a DataPair, or a list of
        scores for a list of DataPairs.

        Raises:
            TypeError: if `data` is neither a DataPair (matched by class name)
                nor a list.
        """
        if type(data).__name__ == 'DataPair':
            return cosine_sim(data.I_feature, data.T_feature)
        elif type(data) is list:
            return [self.align_score(x) for x in data]
        else:
            raise TypeError("data should be list or DataPair!")

    def k_nearest_neighbour_I(self, data, k):
        """Query the image index with the image feature; return (labels, distances)."""
        return self.I_knn_graph.knn_query(data.I_feature, k)

    def k_nearest_neighbour_T(self, data, k):
        """Query the text index with the text feature; return (labels, distances)."""
        return self.T_knn_graph.knn_query(data.T_feature, k)

    def I_to_T_k_nearest(self, data, k):
        """Query the text index with the image feature; return (labels, distances)."""
        return self.T_knn_graph.knn_query(data.I_feature, k)

    def T_to_I_k_nearest(self, data, k):
        """Query the image index with the text feature; return (labels, distances)."""
        return self.I_knn_graph.knn_query(data.T_feature, k)

    def add_item(self, data):
        """Append `data` to the store and insert its features under label `current_n`."""
        data.index = self.current_n
        self.data.append(data)
        self.I_knn_graph.add_items(data.I_feature, self.current_n)
        self.T_knn_graph.add_items(data.T_feature, self.current_n)
        self.current_n += 1

    def replace_item(self, data, index):
        """Replace the sample stored at position `index` of `self.data` with `data`.

        Not used in the current work; provided for future extension. The old
        sample's label is marked deleted in both indexes and the new sample is
        inserted under a fresh label (`current_n`), re-using the deleted slot.
        NOTE: `self.submodular_gain[index]` is left unchanged by this method.
        """
        data_to_rep = self.data[index]
        n_index = data_to_rep.index
        data.index = self.current_n
        self.I_knn_graph.mark_deleted(n_index)
        self.T_knn_graph.mark_deleted(n_index)
        self.I_knn_graph.add_items(data.I_feature, self.current_n, replace_deleted=True)
        self.T_knn_graph.add_items(data.T_feature, self.current_n, replace_deleted=True)
        self.data[index] = data
        self.current_n += 1

    def process_item(self, data: "DataPair", recaptioner=None):
        """Filter one incoming sample and add it if it is sufficiently aligned.

        Args:
            data: the candidate DataPair. Its `Text` and `T_feature` are
                overwritten in place when a recaption is adopted.
            recaptioner: optional mapping from image path to a dict with
                'text_feature' and (possibly) 'caption'.

        A sample whose image/text alignment is below `self.min_align` is given
        one chance: if the recaptioner has an entry with a caption whose text
        feature reaches `self.min_align`, that caption replaces the original.
        Samples still below the threshold are dropped (the method returns
        None without adding them). Accepted samples are added to the indexes
        and their gain, computed before insertion, is appended to
        `self.submodular_gain`.
        """
        align_score = self.align_score(data)

        # Only consult the recaptioner when the original caption fails the
        # threshold; the recaption could not be adopted otherwise.
        if align_score < self.min_align and recaptioner and data.Image_path in recaptioner:
            entry = recaptioner[data.Image_path]
            # Entries without a caption cannot be used for recaptioning, so
            # they are skipped instead of raising KeyError.
            if 'caption' in entry:
                recap_T_feature = entry['text_feature']
                recap_align_score = cosine_sim(data.I_feature, recap_T_feature)
                if recap_align_score >= self.min_align:
                    align_score = recap_align_score
                    data.Text = entry['caption']
                    data.T_feature = recap_T_feature

        if align_score < self.min_align:
            return

        # The gain must be measured before the sample itself is indexed.
        gain = self.submodular_func(data)

        self.add_item(data)
        self.submodular_gain.append(gain)

    def final_gains(self):
        """Return the list of gain tuples, one per stored sample."""
        return self.submodular_gain

In [None]:
def dic_to_DataPairs(d):
    """Convert a feature dictionary into a list of DataPair objects.

    `d['image_path']` is a sequence of records with 'image_path' and 'text'
    keys; `d['image_feature_array']` and `d['text_feature_array']` are indexed
    by the same position. Only the first record for each image path is kept.
    """
    pairs = []
    seen_paths = set()
    for position, record in enumerate(d['image_path']):
        path = record['image_path']
        if path in seen_paths:
            # A later duplicate of an already converted image: skip it.
            continue
        seen_paths.add(path)
        pairs.append(DataPair(path,
                              record['text'],
                              d['image_feature_array'][position],
                              d['text_feature_array'][position]))
    return pairs

Note: To extract the BLIP features, run `python3 process_features.py` after setting the corresponding file paths in configs/feature_processing.yaml. You can also provide your own implementation for other data structures and backbones.

In [None]:
import torch
# Load the precomputed feature dictionary and turn it into DataPair objects.
cc3m_dict = torch.load('/data/common/cc3m/blip_features/cc3m_raw_features.pth')
cc3m_datapairs = dic_to_DataPairs(cc3m_dict)
# Map every image path to the position of its DataPair in cc3m_datapairs.
cc3m_lookup = {pair.Image_path: position for position, pair in enumerate(cc3m_datapairs)}

To initialize with a clean subset:

In [None]:
import json
# Read the clean seed annotations; the context manager closes the file
# instead of leaving the handle open.
with open('top_5_percent_clean_captions.json') as clean_file:
    clean_5 = json.load(clean_file)
# Resolve each annotation to its DataPair through the image-path lookup.
clean_5_datapairs = [cc3m_datapairs[cc3m_lookup[d['image']]] for d in clean_5]

Load the data to grow

In [None]:
# Read the annotations of the samples to grow from; the context manager
# closes the file instead of leaving the handle open.
with open('random_10_percent.json') as random_file:
    random_10 = json.load(random_file)
# Resolve each annotation to its DataPair through the image-path lookup.
random_10_datapairs = [cc3m_datapairs[cc3m_lookup[d['image']]] for d in random_10]

We use a full recaption of the dataset generated by MiniGPT-4. You can substitute it with any other cleaning module. A 40k version will be provided later. For real-time online cleaning, where inference takes longer, consider using batched inputs: clean the filtered-out samples asynchronously while other samples are being added, and put the cleaned samples back into the queue once they have been relabeled.

In [None]:
# Build the recaptioner: image path -> {'text_feature': ..., 'caption': ...}.
recaption_raw = torch.load('mini_text_feature_all.pth')
recaption_features = recaption_raw['text_feature_array']
recaptioner = {image: {'text_feature': recaption_features[i]}
               for i, image in enumerate(recaption_raw['images'])}

# Attach the recaptioned text; the context manager closes the file instead of
# leaving the handle open. Images without a matching caption record keep an
# entry that has no 'caption' key.
with open('cc3m_minigpt4.json') as caption_file:
    mini_captions = json.load(caption_file)
for d in mini_captions:
    if d['image'] in recaptioner:
        recaptioner[d['image']]['caption'] = d['caption']

In [None]:
from tqdm import tqdm
import numpy
# Seed the selector with the clean subset, then stream the candidates through it.
gd = InfoGrowth(initial_points=clean_5_datapairs, submodular_k=4)
for candidate in tqdm(random_10_datapairs):
    gd.process_item(candidate, recaptioner=recaptioner)

In [None]:
# One annotation record per stored sample, in the same order as the gains.
filelist = []
for pair in gd.data:
    filelist.append({'image': pair.Image_path, 'caption': pair.Text})
weights = gd.final_gains()


For dynamic two phase training, save the cleaned samples with gains for dataset loading:

In [None]:
# Store each sample's (image gain, text gain) as strings next to its caption.
w = [(str(x[0]), str(x[1])) for x in weights]
for i, entry in enumerate(filelist):
    entry['gains'] = w[i]
# Use a context manager so the output is flushed and the file closed
# deterministically.
with open('40k_samples_k4.json', 'w+') as out_file:
    json.dump(filelist, out_file)

For static selection, run the following instead:

In [None]:
import numpy as np
# Average the image and text gain of every sample.
gains = np.array([(float(image_gain) + float(text_gain)) / 2. for image_gain, text_gain in
                  ((x[0], x[1]) for x in weights)])
# Normalise by the 99th percentile and clamp the result into [0, 1].
thr = np.quantile(gains, 0.99)
gains = np.clip(gains / thr, 0., 1.)