In [1]:
from pathlib import Path, PosixPath
from typing import Iterable

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, models

In [2]:
# Project paths, relative to the notebook's working directory.
# Path (rather than PosixPath) selects the concrete path class for the host
# OS; PosixPath cannot be instantiated on Windows.
data_dir = Path('../data')
images_dir = data_dir / 'images'
prompts_df_file = data_dir / 'prompts.csv'
submission_df_file = data_dir / 'sample_submission.csv'

## 1 - Load data

In [3]:
# Load the prompt table (an image id and its prompt text per row) and display it.
prompts_df = pd.read_csv(prompts_df_file)
prompts_df

Unnamed: 0,imgId,prompt
0,20057f34d,hyper realistic photo of very friendly and dys...
1,227ef0887,"ramen carved out of fractal rose ebony, in the..."
2,92e911621,ultrasaurus holding a black bean taco in the w...
3,a4e1c55a9,a thundering retro robot crane inks on parchme...
4,c98f79f71,"portrait painting of a shimmering greek hero, ..."
5,d8edf2e40,an astronaut standing on a engaging white rose...
6,f27825b2c,Kaggle employee Phil at a donut shop ordering ...


In [4]:
# Load the sample submission; later cells use it as the reference that the
# computed embeddings are checked against.
submission_df = pd.read_csv(submission_df_file)
submission_df

Unnamed: 0,imgId_eId,val
0,20057f34d_0,0.018848
1,20057f34d_1,0.030190
2,20057f34d_2,0.072792
3,20057f34d_3,-0.000673
4,20057f34d_4,0.016774
...,...,...
2683,f27825b2c_379,0.012124
2684,f27825b2c_380,0.021575
2685,f27825b2c_381,0.030563
2686,f27825b2c_382,0.014047


In [5]:
# Collect every PNG below the images directory (recursive glob).
# glob yields entries in no guaranteed order, so sort for a reproducible list.
images = sorted(images_dir.glob('**/*.png'))

## 2 - Load sentence embedding model

In [6]:
# Load the sentence-embedding model by name; every encode() call in this
# notebook goes through this module-level object.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

## 3 - Calculate embeddings

In [7]:
# Encode all prompts, then collapse the result into a single 1-D array so it
# lines up with the one-value-per-row layout of the sample submission.
prompt_embedding = embedding_model.encode(prompts_df['prompt']).flatten()
# Values per prompt = total number of values divided by the number of prompts.
embedding_size = prompt_embedding.size // prompts_df.shape[0]
assert embedding_size == 384

### 3.1 - Check with validation sample

In [8]:
# The flattened embeddings must reproduce the sample submission's values.
assert np.allclose(submission_df['val'].values, prompt_embedding, atol=1e-07)

## 4 - Build submission csv

In [9]:
# Re-display the sample submission to inspect the layout the output must match.
submission_df

Unnamed: 0,imgId_eId,val
0,20057f34d_0,0.018848
1,20057f34d_1,0.030190
2,20057f34d_2,0.072792
3,20057f34d_3,-0.000673
4,20057f34d_4,0.016774
...,...,...
2683,f27825b2c_379,0.012124
2684,f27825b2c_380,0.021575
2685,f27825b2c_381,0.030563
2686,f27825b2c_382,0.014047


In [10]:
# Column dtypes of the sample submission.
submission_df.dtypes

imgId_eId     object
val          float64
dtype: object

In [11]:
# The id column: each value has the form "<imgId>_<embedding index>".
submission_df.imgId_eId

0         20057f34d_0
1         20057f34d_1
2         20057f34d_2
3         20057f34d_3
4         20057f34d_4
            ...      
2683    f27825b2c_379
2684    f27825b2c_380
2685    f27825b2c_381
2686    f27825b2c_382
2687    f27825b2c_383
Name: imgId_eId, Length: 2688, dtype: object

In [12]:
# Start from an empty frame that has the two submission columns.
output_df = pd.DataFrame(columns=['imgId_eId', 'val'])

In [13]:
# Encode again, this time without flattening, so the result can be indexed
# per prompt. NOTE: this rebinds prompt_embedding, which an earlier cell
# assigned as the flattened 1-D array.
prompt_embedding = embedding_model.encode(prompts_df['prompt'])

In [14]:
# Build every submission row in one shot instead of growing output_df one row
# at a time with .loc, which is slow and relied on the global embedding_size
# for row numbering.
# Ids are "<imgId>_<embedding index>" with prompts in table order; the values
# are the 2-D embedding matrix read row by row, so ids and values stay aligned.
output_df = pd.DataFrame({
    'imgId_eId': [
        f'{imgId}_{embed_idx}'
        for imgId in prompts_df['imgId']
        for embed_idx in range(prompt_embedding.shape[1])
    ],
    'val': prompt_embedding.ravel(),
})

In [15]:
# Column names (and their order) must match the sample submission.
assert (output_df.columns == submission_df.columns).all()

In [16]:
# The id column must be identical to the sample submission's.
assert output_df['imgId_eId'].equals(submission_df['imgId_eId'])

In [17]:
# The embedding values must agree with the sample submission within tolerance.
assert np.allclose(submission_df['val'].values, output_df['val'].values, atol=1e-07)

## 5 - Build reusable functions

In [32]:
def embed_prompts(prompts_df: pd.DataFrame) -> np.ndarray:
    """Encode the ``prompt`` column of ``prompts_df`` with the notebook's model.

    Args:
        prompts_df: Frame with a ``prompt`` column holding the prompt texts.

    Returns:
        The result of ``embedding_model.encode`` for the prompts, passed in
        the row order of ``prompts_df``. The rest of the notebook relies on
        this being one embedding row per prompt.

    Raises:
        KeyError: If ``prompts_df`` has no ``prompt`` column.
    """
    # Hand the encoder a plain list rather than a pandas Series. Integer
    # indexing into a Series is label-based, so a frame with a non-default
    # index (e.g. a filtered or re-indexed table) could fail or pick the
    # wrong rows if the encoder indexes its input by position.
    return embedding_model.encode(prompts_df['prompt'].tolist())

In [40]:
def build_submission(prompts_df: pd.DataFrame) -> pd.DataFrame:
    """Build the long-format submission table for a frame of prompts.

    Each prompt contributes one row per embedding dimension. Rows are ordered
    by prompt (table order), then by embedding index.

    Args:
        prompts_df: Frame with ``imgId`` and ``prompt`` columns.

    Returns:
        A frame with columns ``imgId_eId`` (``"<imgId>_<embedding index>"``)
        and ``val`` (the embedding value), indexed ``0..n-1``. An empty
        ``prompts_df`` yields an empty frame with the same two columns.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If the encoder output is not a 2-D array with one row
            per prompt.
    """
    img_ids = prompts_df['imgId'].tolist()
    if not img_ids:
        # Nothing to encode; keep the expected schema.
        return pd.DataFrame(columns=['imgId_eId', 'val'])

    prompt_embedding = np.asarray(embed_prompts(prompts_df))
    if prompt_embedding.ndim != 2 or prompt_embedding.shape[0] != len(img_ids):
        raise ValueError(
            f'expected a 2-D embedding array with {len(img_ids)} rows, '
            f'got shape {prompt_embedding.shape}'
        )

    # Take the embedding width from the data itself rather than from a
    # notebook-level global, so the function works for any model.
    n_dims = prompt_embedding.shape[1]
    ids = [
        f'{img_id}_{embed_idx}'
        for img_id in img_ids
        for embed_idx in range(n_dims)
    ]
    # ravel() reads the matrix row by row, matching the id order above.
    return pd.DataFrame({'imgId_eId': ids, 'val': prompt_embedding.ravel()})

In [41]:
# Run the reusable builder on the full prompt table.
output_submission_df = build_submission(prompts_df)

In [42]:
# Display the generated submission.
output_submission_df

Unnamed: 0,imgId_eId,val
0,20057f34d_0,0.018848
1,20057f34d_1,0.030190
2,20057f34d_2,0.072792
3,20057f34d_3,-0.000673
4,20057f34d_4,0.016774
...,...,...
2683,f27825b2c_379,0.012124
2684,f27825b2c_380,0.021575
2685,f27825b2c_381,0.030563
2686,f27825b2c_382,0.014047


In [26]:
def compare_validation(submission_df: pd.DataFrame, test_df: pd.DataFrame) -> bool:
    """Check whether ``test_df`` matches the reference ``submission_df``.

    Three conditions must all hold:
      1. both frames have the same column labels in the same order;
      2. their ``imgId_eId`` columns are equal (per ``Series.equals``);
      3. their ``val`` columns agree element-wise within ``atol=1e-07``
         (plus NumPy's default relative tolerance).

    Args:
        submission_df: Reference frame.
        test_df: Candidate frame to validate.

    Returns:
        ``True`` if every condition holds, ``False`` otherwise. Malformed
        input (missing columns, mismatched lengths, non-numeric values)
        counts as a failed comparison rather than an error.
    """
    try:
        if list(test_df.columns) != list(submission_df.columns):
            return False
        if not test_df.imgId_eId.equals(submission_df.imgId_eId):
            return False
        return bool(
            np.allclose(submission_df['val'].values, test_df['val'].values, atol=1e-07)
        )
    except Exception:
        # Deliberately broad so any malformed frame counts as "not equal",
        # but not a bare except: KeyboardInterrupt and SystemExit must
        # still propagate.
        return False

In [43]:
# Positive case: the generated submission must match the sample submission.
assert compare_validation(submission_df, output_submission_df) 

In [44]:
# Negative case: a truncated frame (first 10 rows only) must be rejected.
assert not compare_validation(submission_df, output_submission_df.head(10))

In [45]:
# Negative case: an empty frame must be rejected.
assert not compare_validation(submission_df, pd.DataFrame())

In [46]:
# Negative case: correct ids but random values must be rejected.
# A seeded generator keeps the cell reproducible across re-runs.
bad_df = output_submission_df.copy()
bad_df['val'] = np.random.default_rng(0).standard_normal(len(bad_df))
assert not compare_validation(submission_df, bad_df)

In [47]:
# Same negative case with the roles swapped: a randomized copy of the sample
# submission must not validate against the generated submission.
# A seeded generator keeps the cell reproducible across re-runs.
bad_df = submission_df.copy()
bad_df['val'] = np.random.default_rng(1).standard_normal(len(bad_df))
assert not compare_validation(output_submission_df, bad_df)