In [6]:
import os, warnings
import wandb

import pandas as pd
from fastai.vision.all import *
from sklearn.model_selection import StratifiedGroupKFold

import params # local import
# Suppress every Python warning for the rest of the session to keep cell output
# quiet. Note this also hides deprecation notices from the libraries above.
warnings.filterwarnings("ignore")

In [7]:
# Start the W&B run for the data-splitting step; the project and entity names
# are read from the local params module.
run = wandb.init(
    job_type="data_split",
    project=params.WANDB_PROJECT,
    entity=params.ENTITY,
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [8]:
# Ask the run for the ":latest" alias of the raw-data artifact, download its
# files, and keep the local download directory as a Path.
raw_data_artifact = run.use_artifact(
    f"{params.RAW_DATA_AT}:latest",
    type="raw_data",
)
local_dir = raw_data_artifact.download()
path = Path(local_dir)

[34m[1mwandb[0m: Downloading large artifact bdd_simple_1k:latest, 813.77MB. 4007 files... 
[34m[1mwandb[0m:   4007 of 4007 files downloaded.  
Done. 0:0:58.3


In [9]:
# Show the top-level entries of the downloaded artifact directory.
path.ls()

(#5) [Path('artifacts/bdd_simple_1k:v0/LICENSE.txt'),Path('artifacts/bdd_simple_1k:v0/eda_table.table.json'),Path('artifacts/bdd_simple_1k:v0/images'),Path('artifacts/bdd_simple_1k:v0/media'),Path('artifacts/bdd_simple_1k:v0/labels')]

### Splitting the dataset using `StratifiedGroupKFold`

This process is similar to using a regular train-test split or k-fold cross-validation, but with the additional consideration of groups and stratification. `StratifiedGroupKFold` is particularly useful when you have a dataset with a large number of groups and/or imbalanced class distributions. It keeps every group entirely within a single fold, so samples from the same group never end up on both sides of a split, while giving each fold a distribution of target labels as close as possible to that of the full dataset. This reduces the risk of leakage between splits and of overly optimistic results, and yields more reliable estimates of model performance.

In [10]:
# Fetch the object stored under the name "eda_table" in the raw-data artifact;
# its columns are read in the following cells to build the split.
eda_table = raw_data_artifact.get("eda_table")

[34m[1mwandb[0m: Downloading large artifact bdd_simple_1k:latest, 813.77MB. 4007 files... 
[34m[1mwandb[0m:   4007 of 4007 files downloaded.  
Done. 0:0:13.7


In [11]:
# Pull the three columns the splitter needs out of the EDA table:
#   "P1"        -> grouping key (geographical location, per the original note)
#   "bicycle"   -> label used for stratification
#   "File_name" -> image identifiers
groups, target_label, fnames = (
    eda_table.get_column(col) for col in ("P1", "bicycle", "File_name")
)

In [12]:
# One row per file name; "fold" starts at the placeholder -1 until a fold index
# is written for the row.
df = pd.DataFrame({"File_Name": fnames}).assign(fold=-1)

In [13]:
# Preview the frame before fold assignment (fold is still the -1 placeholder).
df.head()

Unnamed: 0,File_Name,fold
0,0027eed2-09c90000,-1
1,0027eed2-09c90001,-1
2,00aad4a0-ee8135fe,-1
3,00d79c0a-23befe54,-1
4,00e69ee0-9656df95,-1


In [14]:
# Split into 10 folds, stratifying on the label while respecting the groups,
# and tag each row with the index of the fold in which it is held out.
cv = StratifiedGroupKFold(n_splits=10)
fold_splits = cv.split(fnames, target_label, groups)
for fold_id, (_train_idx, holdout_idx) in enumerate(fold_splits):
    df.loc[holdout_idx, "fold"] = fold_id

In [15]:
# Preview again: "fold" now holds the fold index written by the loop above.
df.head()

Unnamed: 0,File_Name,fold
0,0027eed2-09c90000,4
1,0027eed2-09c90001,4
2,00aad4a0-ee8135fe,5
3,00d79c0a-23befe54,6
4,00e69ee0-9656df95,7


In [16]:
# Map folds to stages: fold 0 -> test, fold 1 -> valid, every other fold -> train
# (about 10% / 10% / 80% of the rows with 10 folds).
stage_by_fold = {0: "test", 1: "valid"}
df["Stage"] = df["fold"].map(stage_by_fold).fillna("train")
# The fold index is no longer needed once the stage is recorded.
df = df.drop(columns="fold")
df["Stage"].value_counts()

train    800
test     100
valid    100
Name: Stage, dtype: int64

In [17]:
# Write the File_Name -> Stage assignment to data_split.csv (row index omitted);
# this file is attached to the processed-data artifact below.
df.to_csv("data_split.csv", index=False)

In [18]:
# Create the output artifact that will carry the split dataset.
processed_data_artifact = wandb.Artifact(
    name=params.PROCESSED_DATA_AT,
    type="split_data",
)

In [19]:
# Attach the split CSV and the whole downloaded raw-data directory.
processed_data_artifact.add_file(local_path="data_split.csv")
processed_data_artifact.add_dir(local_path=path)

[34m[1mwandb[0m: Adding directory to artifact (./artifacts/bdd_simple_1k:v0)... Done. 10.6s


In [20]:
# Wrap the File_Name / Stage columns of the split frame in a W&B Table.
data_split_table = wandb.Table(dataframe=df[["File_Name", "Stage"]])

In [21]:
# JoinedTable matches rows on a single column name that must exist in BOTH
# tables. eda_table exposes the file identifier as "File_name" (the name its
# column is fetched by earlier in this notebook), whereas the split frame calls
# it "File_Name", so joining on "File_Name" has no counterpart in eda_table.
# Build the right-hand table with the key renamed so both sides agree, and join
# on that shared name. The data_split.csv column name is left untouched.
join_key = "File_name"
split_table_for_join = wandb.Table(
    dataframe=df[["File_Name", "Stage"]].rename(columns={"File_Name": join_key})
)
join_table = wandb.JoinedTable(eda_table, split_table_for_join, join_key)

In [22]:
# Store the joined table in the artifact under its existing entry name.
# NOTE(review): the name reads "slit" (probably meant "split"); it is kept
# exactly as-is because anything fetching this entry would look it up by this
# name. Confirm with consumers before renaming.
processed_data_artifact.add(join_table, name="eda_table_data_slit")

ArtifactManifestEntry(path='eda_table_data_slit.joined-table.json', digest='WbEb8a/+8SosXC5YYEjRgw==', size=127, local_path='/root/.local/share/wandb/artifacts/staging/tmpm45ts21h')

In [23]:
# Log the processed-data artifact through the active run, then close the run.
run.log_artifact(processed_data_artifact)
run.finish()