# Imports

In [None]:
import os
import wandb
import tempfile
import pandas as pd

from pathlib import Path
from sklearn.model_selection import StratifiedGroupKFold

import project_config as pc

# Retrieve artifact

In [None]:
# Start a W&B run for this data-processing step. The run is left open here
# and is closed by the final logging cell (run.finish()).
run = wandb.init(
    project=pc.WANDB_PROJECT,
    entity=pc.WANDB_ENTITY,
    dir=pc.WANDB_LOCAL_LOGS_PATH,
    job_type='data_processing',
)

# Declare the latest version of the dataset artifact as an input of this run
dataset_artifact = run.use_artifact(f'{pc.DATASET_ARTIFACT_NAME}:latest')

# Build the local download directory from the last path component of the
# artifact's default root.
#   * os.path.join works whether or not WANDB_LOCAL_ARTIFACTS_PATH ends with
#     a path separator (plain string concatenation silently glued the two
#     parts together when it did not).
#   * .name keeps the whole component; .stem would cut it at the last dot,
#     making different versions of a dotted artifact name share one folder.
# NOTE(review): _default_root() is a private wandb API - presumably its last
# component identifies the artifact name and version; confirm on upgrade.
dataset_dir = os.path.join(
    pc.WANDB_LOCAL_ARTIFACTS_PATH,
    Path(dataset_artifact._default_root()).name,
)

# Download only when the target directory does not exist yet
if not os.path.exists(dataset_dir):
    dataset_artifact.download(root=dataset_dir)

# Data processing (filtering, cleaning, etc.)

In [None]:
# Load the table stored in the downloaded artifact directory
csv_path = dataset_dir + '/data.csv'
df = pd.read_csv(csv_path)

# Keep every row whose breed label is not 'Abyssinian', then renumber the
# index from 0 so that it is contiguous again
keep_row = df['label_breed'].ne('Abyssinian')
df = df.loc[keep_row].reset_index(drop=True)

# Show the filtered dataframe as the cell output
df

# Data split

In [6]:
# Split parameters: rows are distributed over n_splits folds. Folds listed in
# valid_splits_ids become validation data, folds listed in test_splits_ids
# become test data, and all remaining folds are training data.
n_splits = 10
valid_splits_ids = [0, 1]
test_splits_ids = [2]

# Inputs for the splitter: samples, labels to stratify on, and group ids
X = df['file_path'].values
y = df['label_breed'].values
groups = df['group'].values

# Give every row the id of the fold in which it is held out.
# The seed is fixed so the assignment is reproducible between runs.
cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=18)
fold_ids = pd.Series(-1, index=df.index)
for i, (_, held_out_idxs) in enumerate(cv.split(X, y, groups)):
    # cv.split yields positional indices, so assign by position (.iloc);
    # label-based .loc is only equivalent while df has a default RangeIndex.
    fold_ids.iloc[held_out_idxs] = i

# Turn fold ids into boolean flags. The fold ids live in their own Series, so
# no temporary column has to be added to (and dropped from) df.
df['is_valid'] = fold_ids.isin(valid_splits_ids)
df['is_test'] = fold_ids.isin(test_splits_ids)

# Separate dataframes: train+valid keeps the 'is_valid' flag, while the test
# set carries neither flag
df_train_valid = df[~df['is_test']].drop(columns=['is_test'])
df_test = df[df['is_test']].drop(columns=['is_test', 'is_valid'])

# Logs

In [None]:
# Open a draft based on the dataset artifact used by this run
new_dataset_artifact = dataset_artifact.new_draft()

# Write the dataframes to a throwaway directory and attach them to the draft
with tempfile.TemporaryDirectory() as temp_dir:
    train_valid_csv = temp_dir + '/dataset.csv'
    df_train_valid.to_csv(train_valid_csv, index=False)
    new_dataset_artifact.add_file(train_valid_csv, 'dataset.csv')

    # The test file is attached only when the test split has at least one row
    has_test_rows = df_test.shape[0] > 0
    if has_test_rows:
        test_csv = temp_dir + '/dataset_test.csv'
        df_test.to_csv(test_csv, index=False)
        new_dataset_artifact.add_file(test_csv, 'dataset_test.csv')

# Log the draft through the run, then close the run
run.log_artifact(new_dataset_artifact)
run.finish()