In [None]:
import os
import numpy as np
import pandas as pd
import json
import itertools
import pickle
import matplotlib.pyplot as plt
import torch
from sentence_transformers import SentenceTransformer, util

## Load data

In [None]:
# Local directories that hold the two input CSV files.
# Both start out as the empty string; they are joined with the file names in
# the loading cell, so an empty value resolves to the current working directory.
text_source_path, var_source_path = '', ''

In [None]:
# Ad-level text table: every column of the file is read.
text = pd.read_csv(
    filepath_or_buffer=os.path.join(text_source_path, 'g2022_adid_text.csv.gz'),
)

# Ad-level variable table: only the two ID columns are read, which is all the
# merge and the creative-ID mapping below need.
var = pd.read_csv(
    filepath_or_buffer=os.path.join(var_source_path, 'g2022_adid_var.csv.gz'),
    usecols=['ad_id', 'wmp_creative_id'],
)

In [None]:
# Left join on `ad_id`: every row of `var` is kept and the matching columns of
# `text` are attached. Rows that are identical across all columns are then
# collapsed to a single row.
df = pd.merge(var, text, how='left', on='ad_id')
df = df.drop_duplicates()

In [None]:
# One row per distinct creative ID. For repeated IDs the LAST occurrence is the
# one retained, which determines the row order; the index is then renumbered
# 0..n-1 so the row position can serve as the embedding index.
unique = (
    df.loc[:, ['wmp_creative_id']]
    .drop_duplicates(subset='wmp_creative_id', keep='last')
    .reset_index(drop=True)
)

In [None]:
# Persist the creative-ID table. The DataFrame index is not written, so the
# row order in the file is the only record of each ID's position.
unique.to_csv(
    path_or_buf='../input_data/unique_creative_id_index_mapping.csv',
    index=False,
)

In [None]:
'''
Mapping between corpus embedding index and creative ID 
'''
# Read the saved mapping back from disk; this rebinds `unique`, so the notebook
# can be resumed from this cell without re-running the data-loading cells.
unique = pd.read_csv(
    filepath_or_buffer='../input_data/unique_creative_id_index_mapping.csv',
)

## Load original model (skip ahead to "Import the reordered embeddings" if the reordered file already exists)

In [None]:
# Original (not yet reordered) corpus embeddings. `map_location` forces all
# tensors onto the CPU regardless of the device they were saved from.
corpus_embeddings0 = torch.load(
    '../model/corpus_embedding_google2022_unique_lite.pt',
    map_location='cpu',
)

In [None]:
# Display the dimensions of the loaded embedding tensor.
corpus_embeddings0.size()

## Reorder corpus embedding indices to match cid ordering

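**This makes the pairwise similarity computation much faster**

(After reordering, the embedding of a creative ID can be read directly at the row given by its numeric ID, instead of going through a pandas index lookup, which is slow.)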

In [None]:
# Parse the numeric part of each creative ID into an integer column `cid_index`.
# IDs are assumed to look like "cid_<digits>" -- TODO confirm against the data.
#
# The prefix is removed with an anchored pattern instead of `str.lstrip('cid_')`:
# lstrip treats its argument as a *set of characters* ('c', 'i', 'd', '_'), not
# as a prefix, so it only happened to work because the remainder is all digits.
# The vectorized `.str` accessor also avoids a Python-level call per row.
unique['cid_index'] = (
    unique['wmp_creative_id']
    .str.replace(r'^cid_', '', regex=True)
    .astype('int64')  # explicit width so the dtype is the same on every platform
)

In [None]:
# Number of distinct numeric creative indices (displayed as the cell output).
unique['cid_index'].nunique(dropna=False)

In [None]:
reordered = torch.zeros(corpus_embeddings.shape[0], corpus_embeddings.shape[1], dtype=corpus_embeddings.dtype)

In [None]:
assert reordered.shape == corpus_embeddings.shape

In [None]:
indices = unique.cid_index.tolist()

indices = torch.tensor([indices for i in range(corpus_embeddings.shape[1])])

In [None]:
indices = indices.T

In [None]:
torch.save(reordered, 'model/corpus_embedding_google2022_unique_lite_reordered.pt') 

## Import the reordered embeddings

In [None]:
'''
This model should be loaded for google data analysis: calculation performs much faster
'''
# Reordered embeddings written by the reordering section; the path is relative
# to the current working directory.
corpus_embeddings = torch.load(
    f='model/corpus_embedding_google2022_unique_lite_reordered.pt',
)