Batch cos_sim for community_detection? #1023

Closed
mmaybeno opened this issue Jun 23, 2021 · 13 comments

Comments

@mmaybeno

I've been experimenting with the community_detection method but noticed that I quickly get OOM errors if the embedding matrix is too large.

Seeing how it uses cos_sim to compute all pairwise embedding similarities, do you think it would make sense to add an option for batching? There will probably still be other bottlenecks when iterating over the entries, but at least it would complete on larger embeddings.
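
For context, here is a rough sketch of the kind of batching I have in mind (chunked_topk_cos_sim and chunk_size are just illustrative names, not anything that exists in the library): score one chunk of rows against all embeddings at a time and keep only the top-k per row, so the full n x n matrix is never materialized.

import torch

def chunked_topk_cos_sim(embeddings, k=10, chunk_size=1000):
    # Normalize once so that a plain matrix product gives cosine similarities
    emb = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    top_vals, top_idxs = [], []
    for start in range(0, emb.shape[0], chunk_size):
        chunk = emb[start:start + chunk_size]
        scores = chunk @ emb.T                      # shape (chunk, n) instead of (n, n)
        vals, idxs = scores.topk(k=min(k, emb.shape[0]), dim=-1)
        top_vals.append(vals)
        top_idxs.append(idxs)
    return torch.cat(top_vals), torch.cat(top_idxs)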

@nreimers
Member

Hi @mmaybeno
@yjernite has created this batched version. I sadly have not yet had time to review and test it, but I hope to do so soon and integrate it into sentence-transformers.

import math
import torch
from tqdm import tqdm

def community_detection(embeddings, threshold=0.75, min_community_size=1, init_max_size=1000):
    # Embeddings are assumed to be L2-normalized, so the matrix product below
    # equals cosine similarity
    init_max_size = min(init_max_size, len(embeddings))

    # Accumulators for the per-row top-k scores and indices
    top_val = torch.empty(0, init_max_size)
    top_idx = torch.empty(0, init_max_size, dtype=torch.long)

    # Compute cosine similarity scores batch by batch, keeping only the top-k per row
    n_batches = math.ceil(len(embeddings) / init_max_size)
    print("computing scores")
    for b in tqdm(range(n_batches)):
        cos_scores = torch.mm(embeddings[b*init_max_size:(b+1)*init_max_size], embeddings.t())
        top_val_large, top_idx_large = cos_scores.topk(k=init_max_size, dim=-1, largest=True)
        top_val = torch.cat([top_val, top_val_large], dim=0)
        top_idx = torch.cat([top_idx, top_idx_large], dim=0)
    print("done computing scores")
    print()
    # Minimum size for a community
    top_k_values = top_val[:,:min_community_size]
    # Filter for rows >= min_threshold
    print("clustering")
    extracted_communities = []
    for i in tqdm(range(len(top_k_values))):
        if top_k_values[i][-1] >= threshold:
            new_cluster = []
            # Only check top k most similar entries
            top_idx_large = top_idx[i].tolist()
            top_val_large = top_val[i].tolist()
            if top_val_large[-1] < threshold:
                for idx, val in zip(top_idx_large, top_val_large):
                    if val < threshold:
                        break
                    new_cluster.append(idx)
            else:
                # Iterate over all entries (slow)
                cos_scores = torch.mv(embeddings, embeddings[i])
                for idx, val in enumerate(cos_scores.tolist()):
                    if val >= threshold:
                        new_cluster.append(idx)
            extracted_communities.append(new_cluster)
    # Largest cluster first
    extracted_communities = sorted(extracted_communities, key=lambda x: len(x), reverse=True)
    # Step 2) Remove overlapping communities
    unique_communities = []
    extracted_ids = set()
    for community in extracted_communities:
        add_cluster = True
        for idx in community:
            if idx in extracted_ids:
                add_cluster = False
                break
        if add_cluster:
            unique_communities.append(community)
            for idx in community:
                extracted_ids.add(idx)
    return unique_communities

@mmaybeno
Author

mmaybeno commented Jun 23, 2021

This is great. I'll try it out on my own data to see how it looks. Thank you for the quick reply!

@ScottishFold007

This is great. I'll try it out on my own data to see how it looks. Thank you for the quick reply!

Excuse me, did this method work?

@mmaybeno
Author

I'm still reviewing the results. What I did notice, though, is that even when the scores are computed in batches, it still quickly runs out of memory on larger datasets during the clustering step. So even if this implementation gives similar results, it can still break on larger data, just in a different section of the code :). A rough estimate of why is below.
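
To give a sense of scale (illustrative numbers, assuming init_max_size=1000 as in the snippet above): each score batch is a (1000, n) float32 matrix, the accumulated top_val/top_idx buffers are (n, 1000) float32/int64, and torch.cat copies those buffers on every batch, so memory still grows linearly with n.

init_max_size = 1000
for n in (1_000_000, 10_000_000):
    batch_scores = init_max_size * n * 4          # one (init_max_size, n) float32 score batch
    topk_buffers = n * init_max_size * (4 + 8)    # float32 values + int64 indices
    print(f"n={n:,}: score batch {batch_scores / 1024**3:.1f} GiB, "
          f"top-k buffers {topk_buffers / 1024**3:.1f} GiB")
# n=1,000,000: score batch 3.7 GiB, top-k buffers 11.2 GiB
# n=10,000,000: score batch 37.3 GiB, top-k buffers 111.8 GiB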

@mmaybeno
Author

It didn't work for me. I did find this Kaggle post that does the batching in TensorFlow, and I tried my best to convert it below. Swapping out cos_sim for the batched version gives the same clusters (on my own data).

import math
import torch
from tqdm import tqdm

def batch_generator(embeddings, batch_size=1000):
    # Yield consecutive row chunks of the embedding matrix
    start = 0
    stop = 0
    while stop < embeddings.shape[0]:
        stop = stop + batch_size
        yield embeddings[start:stop]
        start = stop

def batch_cos_sim(emb_a, emb_b, batch_size=500):
    cos_scores_batch = torch.zeros((emb_a.shape[0], emb_b.shape[0]), device=emb_a.device)
    n_batches = math.ceil(len(emb_a) / batch_size)
    for i, i_emb in enumerate(tqdm((batch_generator(emb_a, batch_size=batch_size)), total=n_batches)):
        for j, j_emb in enumerate(batch_generator(emb_b, batch_size=batch_size)):

            similarity = torch.sum(i_emb.unsqueeze(1) * j_emb, axis=-1)
            similarity /= torch.linalg.norm(i_emb.unsqueeze(1), axis=-1) * torch.linalg.norm(j_emb, axis=-1)
            
            # Write the block at its absolute row/column offset; using the actual chunk
            # sizes keeps the last (possibly smaller) batch aligned correctly
            cos_scores_batch[
                i * batch_size:i * batch_size + i_emb.shape[0],
                j * batch_size:j * batch_size + j_emb.shape[0]
            ] = similarity
    return cos_scores_batch
def community_detection(embeddings, threshold=0.75, min_community_size=10, init_max_size=1000):
    """
    Function for Fast Community Detection
    Finds in the embeddings all communities, i.e. embeddings that are close (closer than threshold).
    Returns only communities that are larger than min_community_size. The communities are returned
    in decreasing order. The first element in each list is the central point in the community.
    """

    # Compute cosine similarity scores
    cos_scores = batch_cos_sim(embeddings, embeddings) # <--- only changed piece of code

    # Minimum size for a community
    top_k_values, _ = cos_scores.topk(k=min_community_size, largest=True)

    # Filter for rows >= min_threshold
    extracted_communities = []
    for i in range(len(top_k_values)):
        if top_k_values[i][-1] >= threshold:
            new_cluster = []

            # Only check top k most similar entries
            top_val_large, top_idx_large = cos_scores[i].topk(k=init_max_size, largest=True)
            top_idx_large = top_idx_large.tolist()
            top_val_large = top_val_large.tolist()

            if top_val_large[-1] < threshold:
                for idx, val in zip(top_idx_large, top_val_large):
                    if val < threshold:
                        break

                    new_cluster.append(idx)
            else:
                # Iterate over all entries (slow)
                for idx, val in enumerate(cos_scores[i].tolist()):
                    if val >= threshold:
                        new_cluster.append(idx)

            extracted_communities.append(new_cluster)

    # Largest cluster first
    extracted_communities = sorted(extracted_communities, key=lambda x: len(x), reverse=True)

    # Step 2) Remove overlapping communities
    unique_communities = []
    extracted_ids = set()

    for community in extracted_communities:
        add_cluster = True
        for idx in community:
            if idx in extracted_ids:
                add_cluster = False
                break

        if add_cluster:
            unique_communities.append(community)
            for idx in community:
                extracted_ids.add(idx)

    return unique_communities

There is a slight numerical difference in the results when doing it in batches (possibly just floating-point rounding from the different operation order?). Maybe I'm missing something?

import numpy as np
import torch
from sentence_transformers import util

rand_tensor = torch.rand((10000, 300))
np.sum(np.abs(util.cos_sim(rand_tensor, rand_tensor).cpu().numpy()) - np.abs(batch_cos_sim(rand_tensor, rand_tensor).cpu().numpy()))
# Returns a non-zero value, for example: -0.0050705075

@mmaybeno
Author

Never mind: with a large enough dataset you run into memory problems again 😂. The batching works, but it still materializes a full n x n score matrix, where n is the number of embeddings, so no matter what you will have to find a way to avoid storing that whole matrix. A rough back-of-the-envelope calculation is below.
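
For example, the dense float32 score matrix alone grows quadratically (illustrative numbers, not measurements):

# Memory needed just for a dense n x n float32 cosine-similarity matrix
for n in (10_000, 100_000, 1_000_000):
    gib = n * n * 4 / 1024**3   # 4 bytes per float32 entry
    print(f"n={n:>9,}: {gib:,.1f} GiB")
# n=   10,000: 0.4 GiB
# n=  100,000: 37.3 GiB
# n=1,000,000: 3,725.3 GiB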

@mmaybeno
Author

With large datasets it may be best to use some kind of ANN (approximate nearest neighbor) index to find the top-k neighbors. Going to look into it more over the weekend and see what I can come up with. The problem with that is it would add a dependency for this type of operation.

@ScottishFold007

With large datasets it may be best to use some kind of ANN (approximate nearest neighbor) index to find the top-k neighbors. Going to look into it more over the weekend and see what I can come up with. The problem with that is it would add a dependency for this type of operation.

Maybe you can do dimensionality reduction first and then do matrix multiplication, which might be much better for memory consumption~
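
For anyone who wants to try that, here is a rough sketch using torch.pca_lowrank (the target dimension q=128 is just an example; note that reducing the dimensionality changes the cosine similarities, so the threshold may need retuning):

import torch

def reduce_embeddings(embeddings, q=128):
    # Project onto the top-q principal components, then re-normalize so that
    # dot products on the reduced vectors are cosine similarities again
    mean = embeddings.mean(dim=0, keepdim=True)
    _, _, V = torch.pca_lowrank(embeddings, q=q)
    reduced = (embeddings - mean) @ V
    return torch.nn.functional.normalize(reduced, p=2, dim=1)

The reduced embeddings can then be fed into community_detection as before.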

@mmaybeno
Author

mmaybeno commented Jul 6, 2021

I ended up using an ANN index to approximate the nearest-neighbor search. It works, just not quite as well. I'm specifically using Annoy since it stores the index on disk and is overall efficient with memory. There is probably something I am missing to make the clusters a little better, but for now this is what I have.

import numpy as np
from annoy import AnnoyIndex
from tqdm.notebook import trange, tqdm

def community_detection(
    embeddings, threshold=0.75, min_community_size=10, init_max_size=1000
):
    """
    Function for Fast Community Detection
    Finds in the embeddings all communities, i.e. embeddings that are close (closer than threshold).
    Returns only communities that are larger than min_community_size. The communities are returned
    in decreasing order. The first element in each list is the central point in the community.
    """

    # Compute cosine similarity scores with ANN index
    index = AnnoyIndex(embeddings.shape[1], "angular")
    for i, v in tqdm(
        enumerate(embeddings.cpu().numpy()),
        total=embeddings.shape[0],
        desc="Building ANN",
    ):
        index.add_item(i, v)
    index.build(30)  # configure to change results
    index.save("test.ann")

    ## Minimum size for a community
    top_k_values = []
    for emb_i in tqdm(
        range(embeddings.shape[0]), desc="Finding Minimum Community Size"
    ):
        _, distance = index.get_nns_by_item(
            emb_i, min_community_size, include_distances=True
        )
        top_k_values.append(distance)

    # Annoy's "angular" distance is sqrt(2 * (1 - cos)), so convert back to cosine similarity
    top_k_values = 1 - np.array(top_k_values) ** 2 / 2

    # Filter for rows >= min_threshold
    extracted_communities = []
    total_entries = embeddings.shape[0]
    for i in tqdm(range(total_entries), desc="Extracting Communities"):
        if top_k_values[i][-1] >= threshold:
            new_cluster = []

            top_idx_large, top_val_large = index.get_nns_by_item(
                i, init_max_size, include_distances=True
            )
            # Convert angular distances back to cosine similarities
            top_val_large = (1 - np.array(top_val_large) ** 2 / 2).tolist()

            if top_val_large[-1] < threshold:
                for idx, val in zip(top_idx_large, top_val_large):
                    if val < threshold:
                        break

                    new_cluster.append(idx)
            else:
                # Iterate over most entries (slow)
                # Only searches 50% of possible entries or 10000
                # otherwise this can be a very expensive search
                min_most_size = min([int(total_entries * 0.5), 10000])
                idx_large, val_large = index.get_nns_by_item(
                    i, min_most_size, include_distances=True
                )
                val_large = (1 - np.array(val_large) ** 2 / 2).tolist()
                for idx, val in zip(idx_large, val_large):
                    if val >= threshold:
                        new_cluster.append(idx)

            extracted_communities.append(new_cluster)

    # Largest cluster first
    extracted_communities = sorted(
        extracted_communities, key=lambda x: len(x), reverse=True
    )

    # Step 2) Remove overlapping communities
    unique_communities = []
    extracted_ids = set()

    for community in extracted_communities:
        add_cluster = True
        for idx in community:
            if idx in extracted_ids:
                add_cluster = False
                break

        if add_cluster:
            unique_communities.append(community)
            for idx in community:
                extracted_ids.add(idx)

    return unique_communities

@randomgambit

Hello there! Very interesting thread. Sorry if this is a noob question, but I have a lot of RAM, much more than my GPU card has. Is there a way to move the embeddings to the CPU so that I can compute the clusters without hitting OOM errors? Thanks!

@nreimers
Member

Yes. When the model returns PyTorch tensors (convert_to_tensor=True), you can move them to the CPU like this:

my_embeddings = my_embeddings.to('cpu')

Computation will then be done on the CPU.
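
For example (a minimal sketch; the model name and sentence list are placeholders, and min_community_size is lowered only because the placeholder list is tiny):

from sentence_transformers import SentenceTransformer, util

sentences = ["replace these", "with your own texts"]   # placeholder data
model = SentenceTransformer('all-MiniLM-L6-v2')        # placeholder model name

embeddings = model.encode(sentences, convert_to_tensor=True)

# Move the embeddings off the GPU so clustering runs in system RAM instead of VRAM
embeddings = embeddings.to('cpu')
clusters = util.community_detection(embeddings, threshold=0.75, min_community_size=2)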

@YashGadhia

Hi @mmaybeno, @yjernite has created this batched version. I sadly have not yet had time to review and test it, but I hope to do so soon and integrate it into sentence-transformers.

Hi @nreimers
I also faced OOM errors when doing community detection on a large dataset, and hence created a batched version where one can specify the number of batches in which the embeddings are processed.

import gc
import math

import torch
from sentence_transformers.util import cos_sim

def community_detection(embeddings, threshold=0.75, min_community_size=10, init_max_size=1000, n_batches=1):
    """
    Function for Fast Community Detection

    Finds in the embeddings all communities, i.e. embeddings that are close (closer than threshold).

    Returns only communities that are larger than min_community_size. The communities are returned
    in decreasing order. The first element in each list is the central point in the community.

    n_batches is the number of batches in which the embeddings are processed
    Increasing it would lead to longer run time but less memory usage
    Default: 1, increase if running out of memory
    """

    #Find batch size
    batch_size = math.floor(len(embeddings)/n_batches)
    
    # Maximum size for community
    init_max_size = min(init_max_size, len(embeddings))

    extracted_communities = []

    for batch in range(n_batches):

        # Compute cosine similarity scores for batch
        if batch != n_batches-1:
            cos_scores = cos_sim(embeddings[batch*batch_size:(batch+1)*batch_size,:], embeddings)
        else:
            cos_scores = cos_sim(embeddings[batch*batch_size:len(embeddings),:], embeddings)
            
        # Minimum size for a community
        top_k_values, _ = cos_scores.topk(k=min_community_size, largest=True)

        # Filter for rows >= min_threshold
        for i in range(len(top_k_values)):
            if top_k_values[i][-1] >= threshold:
                new_cluster = []
                
                # Only check top k most similar entries
                top_val_large, top_idx_large = cos_scores[i].topk(k=init_max_size, largest=True)
                top_idx_large = top_idx_large.tolist()
                top_val_large = top_val_large.tolist()

                if top_val_large[-1] < threshold:
                    for idx, val in zip(top_idx_large, top_val_large):
                        if val < threshold:
                            break

                        new_cluster.append(idx)
                else:
                    # Iterate over all entries (slow)
                    for idx, val in enumerate(cos_scores[i].tolist()):
                        if val >= threshold:
                            new_cluster.append(idx)
                            
                extracted_communities.append(new_cluster)
                
        #Delete tensors once the batch is processed & free up memory
        del cos_scores
        del top_k_values
        gc.collect()
        torch.cuda.empty_cache()

    # Largest cluster first
    extracted_communities = sorted(extracted_communities, key=lambda x: len(x), reverse=True)

    # Step 2) Remove overlapping communities
    unique_communities = []
    extracted_ids = set()

    for community in extracted_communities:
        add_cluster = True
        for idx in community:
            if idx in extracted_ids:
                add_cluster = False
                break

        if add_cluster:
            unique_communities.append(community)
            for idx in community:
                extracted_ids.add(idx)

    return unique_communities

I have tested this and created PR #1603 for it.
Hope this resolves the issue for everyone else too!
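
For reference, a quick usage sketch of this batched variant (embeddings stands for whatever tensor of sentence embeddings you already have):

# Splitting into 4 batches keeps each cos_sim block at roughly a quarter of the full size
clusters = community_detection(embeddings, threshold=0.75, min_community_size=10, n_batches=4)
if clusters:
    print(f"Found {len(clusters)} communities; the largest has {len(clusters[0])} members")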

@nreimers
Member

Fixed with the most recent commit, will be part of version 2.2.1
