In [3]:
import os
from typing import Optional, Literal

import optuna
import json
import pandas as pd
import numpy as np
import plotly.io as pio
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from deepdiff import DeepDiff
import itertools
import torch
from typing_extensions import LiteralString

from settings.config import *
from src.commons.exp_config import ExpConfig

pio.templates.default = "plotly"

idx = pd.IndexSlice

In [12]:
# Experiment under analysis: the cells below read and write files under <exp_dir>/<exp_name>/.
exp_dir = os.path.join(EXPERIMENTS_PATH, "basic")
exp_name = "resnets_training_BM_F1_INGS"
# Selects which split's per-label F1 columns are analysed further down.
use_val_data = False  # if True, use validation data, otherwise use train data

#### Collection of metrics data

In [6]:
# loading data from local saves [general code]
TRIAL_NAME_PREFIX = "trial_"

trials_metrics_array = []
for trial_id in range(0, 3):  # 3 for the study {10k}_sel, otherwise 4
    trial_dir = os.path.join(exp_dir, exp_name, f"{TRIAL_NAME_PREFIX}{trial_id}")
    raw_metrics = pd.read_csv(os.path.join(trial_dir, "metrics.csv"))

    # keep only the metric columns (val first, then train) plus the bookkeeping ones; params columns are dropped
    val_columns = [col for col in raw_metrics.columns if "val" in col]
    train_columns = [col for col in raw_metrics.columns if "train" in col]
    metrics = raw_metrics[val_columns + train_columns + ["epoch", "step"]]

    non_step_columns = [col for col in metrics.columns if col != "step"]
    metrics = (
        metrics
        .dropna(how="all", axis=0, subset=non_step_columns)  # remove NaN rows (used as separators)
        .ffill()  # fill NaN values with previous values ...
        .bfill()  # ... and leading NaNs with the first following value
        .groupby("epoch")
        .last()  # take last value of each epoch
    )
    metrics.index = metrics.index.astype(int)

    trials_metrics_array.append(metrics.drop(columns=["step"]))

In [78]:
# Stack the per-trial frames (keyed by their position in the list) into one frame and save it.
full_metrics = pd.concat(dict(enumerate(trials_metrics_array)), axis=0)
# full_metrics['val_loss'] = full_metrics['val_loss'].fillna(110)  # for last study's bug
full_metrics.index.names = ["trial", "epoch"]
full_metrics.to_csv(os.path.join(exp_dir, exp_name, "full_metrics.csv"))

#### Extraction of per-label F1 metrics and related statistics

In [86]:
trials_labels_stats, trials_f1 = [], []

set_src = "val" if use_val_data else "train"
# Per-label F1 columns are named "<set_src>_f1_label_<id>". Matching the exact prefix (instead of the
# substrings "f1" and set_src) keeps any other F1 column out of the integer conversion below.
f1_prefix = f"{set_src}_f1_label_"
f1_columns = [col for col in full_metrics.columns if col.startswith(f1_prefix)]

# Fail here with an actionable message instead of later inside pd.concat / describe().
assert trials_metrics_array, "trials_metrics_array is empty: run the metrics-loading cell first"
assert f1_columns, f"no column of full_metrics starts with '{f1_prefix}'"

for trial_metrics in trials_metrics_array:
    trial_metrics = trial_metrics[f1_columns]
    # strip the prefix so that the column names become the integer label ids
    trial_metrics.columns = [col[len(f1_prefix):] for col in trial_metrics.columns]
    trial_metrics.columns = trial_metrics.columns.values.astype(int)
    trial_metrics.columns.name = "label"
    trial_metrics = trial_metrics.sort_index(axis=1)

    trials_f1.append(trial_metrics)
    # describe() summarises each label column over the rows (epochs); keep three of its statistics
    trials_labels_stats.append(trial_metrics.describe().loc[["mean", "std", "max"]])


In [87]:
# One row per (trial, epoch), one column per label id; the file name carries the analysed split.
full_f1 = pd.concat(dict(enumerate(trials_f1)), axis=0)
full_f1 = full_f1.rename_axis(index=["trial", "epoch"], columns="label")
full_f1.to_csv(os.path.join(exp_dir, exp_name, f"full_f1_{set_src}.csv"))

In [88]:
# One row per (trial, stat), one column per label id.
full_labels_stats = pd.concat(dict(enumerate(trials_labels_stats)), axis=0)
full_labels_stats = full_labels_stats.rename_axis(index=["trial", "stat"], columns="label")
# NOTE(review): unlike full_f1_{set_src}.csv this file name does not include set_src, so train and
# val runs write to the same file — confirm whether that is intended.
full_labels_stats.to_csv(os.path.join(exp_dir, exp_name, "full_labels_stats.csv"))

ValueError: No objects to concatenate

In [28]:
# Per-trial "mean"/"std"/"max" of each label's F1 column.
full_labels_stats

Unnamed: 0_level_0,label,0,1,2,3,4,5,6,7,8,9,...,173,174,175,176,177,178,179,180,181,182
trial,stat,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
0,mean,9.208702e-07,5e-06,0.0,0.0,3e-05,0.0,9.508171e-07,9.065931e-07,0.0,0.128403,...,0.0,0.005086,0.0,0.008948,9.065931e-07,0.017114,0.0,0.058297,2.8e-05,0.0
0,std,5.824095e-06,3.4e-05,0.0,0.0,0.00014,0.0,6.013496e-06,5.733798e-06,0.0,0.111027,...,0.0,0.007471,0.0,0.014031,5.733798e-06,0.023531,0.0,0.055745,0.000151,0.0
0,max,3.683481e-05,0.000217,0.0,0.0,0.00078,0.0,3.803269e-05,3.626373e-05,0.0,0.359696,...,0.0,0.029713,0.0,0.048984,3.626373e-05,0.079242,0.0,0.18221,0.000936,0.0
1,mean,0.0,3.8e-05,0.0,2e-06,4.1e-05,0.0,0.0,2.762951e-06,0.0,0.168982,...,4e-06,0.009098,4e-06,0.008542,7.184934e-06,0.028568,0.0,0.065021,0.002474,0.0
1,std,0.0,0.000192,0.0,1.2e-05,0.000124,0.0,0.0,1.747443e-05,0.0,0.098793,...,2.3e-05,0.009212,2.4e-05,0.010135,4.544151e-05,0.029798,0.0,0.043597,0.004357,0.0
1,max,0.0,0.00117,0.0,7.5e-05,0.000426,0.0,0.0,0.000110518,0.0,0.315699,...,0.000143,0.033577,0.000152,0.032993,0.0002873974,0.092447,0.0,0.151003,0.016373,0.0
2,mean,2.0099e-06,1e-05,0.0,6e-06,2.3e-05,9.065931e-07,0.0,0.0,0.0,0.091205,...,0.0,0.001079,4e-06,0.000431,0.0,0.001144,2e-06,0.032718,0.0,0.0
2,std,1.271172e-05,6.4e-05,0.0,3.8e-05,0.000143,5.733798e-06,0.0,0.0,0.0,0.062847,...,0.0,0.00123,2.4e-05,0.000597,0.0,0.002066,1.1e-05,0.020087,0.0,0.0
2,max,8.0396e-05,0.000402,0.0,0.000243,0.000901,3.626373e-05,0.0,0.0,0.0,0.184825,...,0.0,0.003885,0.00015,0.002017,0.0,0.007407,7e-05,0.066379,0.0,0.0
3,mean,0.0,0.0,0.0,0.0,2e-05,0.0,0.0,0.0,1e-06,0.140395,...,0.0,0.007492,6e-06,0.010708,0.0,0.023333,0.0,0.068682,0.000181,0.0


In [29]:
# Names of the statistics present in the second index level, in order of appearance.
stats = full_labels_stats.index.get_level_values(1).unique().values.tolist()
# For every statistic: one value per trial, obtained by averaging that statistic over all labels.
stats_mean = [full_labels_stats.xs(stat_name, level="stat").mean(axis=1) for stat_name in stats]

labels_avg_stats = pd.concat(
    {f"avg_labels_{stat_name}": per_trial_avg for stat_name, per_trial_avg in zip(stats, stats_mean)},
    axis=1,
)
labels_avg_stats # stats averaged over labels

Unnamed: 0_level_0,avg_labels_mean,avg_labels_std,avg_labels_max
trial,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0.044877,0.021567,0.089498
1,0.051134,0.018524,0.08037
2,0.036833,0.010296,0.052092
3,0.047562,0.023069,0.095194


In [30]:
# stats averaged over trials: one row per statistic, one column per label
trial_avg_stats = full_labels_stats.groupby(level="stat").mean()
trial_avg_stats

label,0,1,2,3,4,5,6,7,8,9,...,173,174,175,176,177,178,179,180,181,182
stat,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
max,2.93077e-05,0.000447,0.0,7.9e-05,0.000722,9.065931e-06,9.508171e-06,3.669544e-05,1.008194e-05,0.306928,...,3.584849e-05,0.026777,0.000133,0.035851,8.1e-05,0.069515,1.74553e-05,0.154489,0.004884,0.0
mean,7.326925e-07,1.3e-05,0.0,2e-06,2.9e-05,2.266483e-07,2.377043e-07,9.17386e-07,2.520485e-07,0.132246,...,8.962124e-07,0.005689,3e-06,0.007157,2e-06,0.01754,4.363825e-07,0.056179,0.000671,0.0
std,4.633955e-06,7.2e-05,0.0,1.3e-05,0.000132,1.43345e-06,1.503374e-06,5.802058e-06,1.594095e-06,0.097185,...,5.668145e-06,0.007232,2.1e-05,0.010287,1.3e-05,0.021594,2.759925e-06,0.045501,0.001261,0.0


#### Loading label encoder

In [13]:
# loading the label-encoder config stored in each trial's best checkpoint
label_encoder_configs = []
# Cover exactly the trials whose metrics were loaded above, so the number of trials is set in one place.
for trial_id in range(len(trials_metrics_array)):
    ckpt_path = os.path.join(exp_dir, exp_name, f"{TRIAL_NAME_PREFIX}{trial_id}", "best_model.ckpt")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects — only load checkpoints
    # produced by trusted runs.
    checkpoint_data = torch.load(ckpt_path, weights_only=False)
    trial_config = ExpConfig.load_from_ckpt_data(checkpoint_data)
    label_encoder_configs.append(trial_config.label_encoder)

In [14]:
# check if they are the same: every later config must have an empty DeepDiff against the first one
assert all(
    DeepDiff(label_encoder_configs[0], other_config).to_dict() == {}
    for other_config in label_encoder_configs[1:]
)

In [15]:
# Rebuild the encoder from the first trial's config: the object stored under 'type' exposes
# load_from_config, which receives the whole config.
label_encoder = label_encoder_configs[0]['type'].load_from_config(label_encoder_configs[0])
label_encoder

<src.data_processing.labels_encoders.MultiLabelBinarizerRobust at 0x1a9d1ce0f40>

#### Plotting

In [11]:
# Copy of full_f1 whose columns carry the decoded label instead of the integer id.
full_f1_pretty = full_f1.copy()
# decode_labels is called with a one-element batch; [0][0] picks the single decoded value out of the result.
full_f1_pretty.columns = [label_encoder.decode_labels([label])[0][0] for label in full_f1.columns]

NameError: name 'full_f1' is not defined

##### Plotting f1 scores for each label over epochs

In [99]:
def make_scatter(trial_data):
    """Build a line trace for one trial's series.

    The index is replaced by a 0-based position, which becomes the x axis; the series values
    are the y axis and the series name is used in the legend entry ("trial_<name>").
    """
    series = trial_data.reset_index(drop=True)
    legend_name = f"trial_{series.name}"
    return go.Scatter(x=series.index, y=series.values, mode="lines", name=legend_name)

n_cols = 6
# NOTE(review): the last column of full_f1 is left out of the grid (the "- 1" and "[:-1]") — confirm this is intended.
plotted_columns = full_f1_pretty.columns[:-1]
n_rows = int(np.ceil((len(full_f1.columns) - 1) / n_cols))

# One subplot per label, filled row by row; every subplot holds one line per trial.
fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=plotted_columns, shared_yaxes=True, shared_xaxes=True).update_yaxes(range=[0, 1])
for i, column in enumerate(plotted_columns):
    row_offset, col_offset = divmod(i, n_cols)
    row, col = row_offset + 1, col_offset + 1  # subplot positions are 1-based

    scatters_plot = full_f1_pretty[column].groupby(level=0).apply(make_scatter).values
    for scatter_fig in scatters_plot:
        fig.add_trace(scatter_fig, row=row, col=col)

ValueError: 
The 'rows' argument to make_subplots must be an int greater than 0.
    Received value of type <class 'int'>: 0

In [39]:
# Set axis titles and figure size, then open the figure with the "browser" renderer.
fig.update_layout(xaxis_title="Epoch", yaxis_title="F1 score", width=2000, height=3000).show(renderer="browser")

In [37]:
# Per-trial label statistics again, as a reference for the box plots below.
full_labels_stats

Unnamed: 0_level_0,label,0,1,2,3,4,5,6,7,8,9,...,173,174,175,176,177,178,179,180,181,182
trial,stat,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
0,mean,9.208702e-07,5e-06,0.0,0.0,3e-05,0.0,9.508171e-07,9.065931e-07,0.0,0.128403,...,0.0,0.005086,0.0,0.008948,9.065931e-07,0.017114,0.0,0.058297,2.8e-05,0.0
0,std,5.824095e-06,3.4e-05,0.0,0.0,0.00014,0.0,6.013496e-06,5.733798e-06,0.0,0.111027,...,0.0,0.007471,0.0,0.014031,5.733798e-06,0.023531,0.0,0.055745,0.000151,0.0
0,max,3.683481e-05,0.000217,0.0,0.0,0.00078,0.0,3.803269e-05,3.626373e-05,0.0,0.359696,...,0.0,0.029713,0.0,0.048984,3.626373e-05,0.079242,0.0,0.18221,0.000936,0.0
1,mean,0.0,3.8e-05,0.0,2e-06,4.1e-05,0.0,0.0,2.762951e-06,0.0,0.168982,...,4e-06,0.009098,4e-06,0.008542,7.184934e-06,0.028568,0.0,0.065021,0.002474,0.0
1,std,0.0,0.000192,0.0,1.2e-05,0.000124,0.0,0.0,1.747443e-05,0.0,0.098793,...,2.3e-05,0.009212,2.4e-05,0.010135,4.544151e-05,0.029798,0.0,0.043597,0.004357,0.0
1,max,0.0,0.00117,0.0,7.5e-05,0.000426,0.0,0.0,0.000110518,0.0,0.315699,...,0.000143,0.033577,0.000152,0.032993,0.0002873974,0.092447,0.0,0.151003,0.016373,0.0
2,mean,2.0099e-06,1e-05,0.0,6e-06,2.3e-05,9.065931e-07,0.0,0.0,0.0,0.091205,...,0.0,0.001079,4e-06,0.000431,0.0,0.001144,2e-06,0.032718,0.0,0.0
2,std,1.271172e-05,6.4e-05,0.0,3.8e-05,0.000143,5.733798e-06,0.0,0.0,0.0,0.062847,...,0.0,0.00123,2.4e-05,0.000597,0.0,0.002066,1.1e-05,0.020087,0.0,0.0
2,max,8.0396e-05,0.000402,0.0,0.000243,0.000901,3.626373e-05,0.0,0.0,0.0,0.184825,...,0.0,0.003885,0.00015,0.002017,0.0,0.007407,7e-05,0.066379,0.0,0.0
3,mean,0.0,0.0,0.0,0.0,2e-05,0.0,0.0,0.0,1e-06,0.140395,...,0.0,0.007492,6e-06,0.010708,0.0,0.023333,0.0,0.068682,0.000181,0.0


##### Boxplot of f1 scores averaged over trials

In [42]:
# Transposed so that every statistic becomes a column (one box each) and every label one point.
fig = px.box(trial_avg_stats.T, title="F1 labels scores averaged over different trials", points="all")
fig.show()

##### Boxplot of F1 scores for a single trial

In [44]:
# Trial to inspect; must be one of the values of the "trial" index level of full_labels_stats.
trial_id = 2
# Same layout as the trial-averaged box plot: one box per statistic, one point per label.
fig = px.box(full_labels_stats.xs(trial_id, level="trial").T, title=f"F1 scores for trial {trial_id}", points="all")
fig.show()

In [50]:
# Take the trial ids from the data instead of assuming four trials: the loading cell reads 3 or 4
# trials depending on the study, and a hard-coded range(4) raises a KeyError in .xs() for a 3-trial study.
trial_ids = full_labels_stats.index.get_level_values("trial").unique().tolist()
n_rows, n_cols = 1, len(trial_ids)

box_subplots = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=[f"Trial {trial_id}" for trial_id in trial_ids], shared_yaxes=True, shared_xaxes=True).update_yaxes(range=[0, 1])
for position, trial_id in enumerate(trial_ids):
    # 1-based subplot coordinates, derived from the position in the list (not from the id itself)
    row = position // n_cols + 1
    col = position % n_cols + 1

    trial_stats = full_labels_stats.xs(trial_id, level="trial")

    # one box per statistic, built from that statistic's values over all labels
    for stat in trial_stats.index:
        box_subplots.add_trace(go.Box(y=trial_stats.loc[stat].values, name=stat), row=row, col=col)

box_subplots.update_layout(width=1600, height=800).show(renderer="browser")

### Selection of ingredients

In [51]:
def select_labels(labels_stats_full, labels_stats_avg, trial_id: Optional[int] = 0, stat_criterion: Literal["mean", "std", "max"] = "max",
                  threshold: Literal["avg", "q1", "q2", "q3"] | float = "q3", plot: bool = True):
    stats_src = labels_stats_avg if trial_id is None else labels_stats_full.xs(trial_id, level="trial")
    labels_f1 = stats_src.loc[stat_criterion]
    
    if plot:
        px.box(labels_f1, title=f"F1 scores for labels (criterion: {stat_criterion}) for trial {trial_id}").show()
    
    if isinstance(threshold, str):
        if threshold == "avg":
            threshold = labels_f1.mean()
        else:
            threshold = labels_f1.quantile(0.25 * float(threshold[1]))
    
    labels_selected = labels_f1.loc[labels_f1 >= threshold].index
    return labels_selected.values

def print_selected_labels(labels, label_encoder):
    """Print the selected label ids, how many there are, and their decoded form.

    ``label_encoder.decode_labels`` is called with ``labels`` wrapped in a one-element list
    and the first element of its result is printed.
    """
    print(f"Selected labels:\n{labels}\n")
    n_selected = len(labels)
    print(f"Number of selected labels: {n_selected}")
    decoded = label_encoder.decode_labels([labels])[0]
    print(f"Selected Labels translated: \n{decoded}")
    


In [54]:
plot = True

# Trial ids are read from the stats frame instead of being hard-coded as 0..3, so the cell also works
# for studies with a different number of trials. None adds the selection based on the trial-averaged stats.
trial_ids = full_labels_stats.index.get_level_values("trial").unique().tolist()
select_labels_dict = {str(trial_id): select_labels(full_labels_stats, trial_avg_stats, trial_id=trial_id, threshold="q3", plot=plot) for trial_id in [*trial_ids, None]}

if plot:
    for name, selected_labels in select_labels_dict.items():
        print(f"\n\n\n\nTrial {name}:\n")
        print_selected_labels(selected_labels, label_encoder)
    





Trial 0:

Selected labels:
[  9  15  18  19  25  26  29  30  32  33  40  44  46  48  49  51  56  57
  59  60  61  62  64  69  70  72  88  91  96 104 106 111 113 117 119 136
 142 149 151 152 154 156 164 168 170 180]

Number of selected labels: 46
Selected Labels translated: 
['avocado' 'basil' 'beans' 'beef' 'bread' 'broccoli' 'butter' 'cabbage'
 'cardamom' 'carrot' 'cheese' 'chicken' 'chili' 'chocolate' 'cilantro'
 'cinnamon' 'coriander' 'corn' 'cream' 'cucumber' 'cumin' 'curry' 'egg'
 'flour' 'garam masala' 'garlic' 'lettuce' 'liquor' 'milk' 'oil' 'onion'
 'parsley' 'pasta' 'peas' 'pepper' 'rice' 'salt' 'shrimp' 'soy' 'spinach'
 'strawberries' 'sugar' 'tomato' 'turmeric' 'vanilla' 'yogurt']




Trial 1:

Selected labels:
[  9  15  18  19  24  25  26  29  30  33  40  44  46  48  49  51  56  57
  60  61  64  69  70  72  88  91  96  99 104 106 111 113 117 119 136 142
 149 151 152 154 156 164 168 170 178 180]

Number of selected labels: 46
Selected Labels translated: 
['avocado' 'basi

#### Comparison of selected labels

In [55]:
# Pairwise comparison of the selections. Entry (row, column) holds the number of labels that are
# in the row's selection but not in the column's (setdif), or in exactly one of the two (setxor).
selection_keys = list(select_labels_dict.keys())
cmp_setdif_df = pd.DataFrame(
    data=[[len(np.setdiff1d(select_labels_dict[e1], select_labels_dict[e2])) for e2 in selection_keys] for e1 in selection_keys],
    index=selection_keys,
    columns=selection_keys,
)
cmp_setxor_df = pd.DataFrame(
    data=[[len(np.setxor1d(select_labels_dict[e1], select_labels_dict[e2])) for e2 in selection_keys] for e1 in selection_keys],
    index=selection_keys,
    columns=selection_keys,
)

In [57]:
# Size of the symmetric difference between every pair of selections (0 on the diagonal).
cmp_setxor_df

Unnamed: 0,0,1,2,3,None
0.0,0,6,8,6,2
1.0,6,0,8,10,4
2.0,8,8,0,6,6
3.0,6,10,6,0,6
,2,4,6,6,0


In [58]:
# Labels selected in every individual trial. All keys except the trial-averaged "None" entry are
# intersected, rather than naming "0".."3" explicitly, so any number of trials works.
trial_keys = [key for key in select_labels_dict if key != "None"]
common_intersection = np.unique(select_labels_dict[trial_keys[0]])  # sorted and unique, like np.intersect1d output
for key in trial_keys[1:]:
    common_intersection = np.intersect1d(common_intersection, select_labels_dict[key])

In [59]:
SELECTED_KEY: Literal["0", "1", "2", "3", "None", "common"] = "common"

# "common" takes the intersection stored in common_intersection; any other key takes that entry of select_labels_dict
selected_labels = common_intersection if SELECTED_KEY == "common" else select_labels_dict[SELECTED_KEY]

selected_labels_translated = label_encoder.decode_labels([selected_labels])[0]
print_selected_labels(selected_labels, label_encoder)

Selected labels:
[  9  15  18  19  25  29  30  33  40  44  46  48  49  51  56  57  60  61
  64  69  70  72  88  91  96 104 106 111 113 119 136 142 149 151 154 156
 164 168 170 180]

Number of selected labels: 40
Selected Labels translated: 
['avocado' 'basil' 'beans' 'beef' 'bread' 'butter' 'cabbage' 'carrot'
 'cheese' 'chicken' 'chili' 'chocolate' 'cilantro' 'cinnamon' 'coriander'
 'corn' 'cucumber' 'cumin' 'egg' 'flour' 'garam masala' 'garlic' 'lettuce'
 'liquor' 'milk' 'oil' 'onion' 'parsley' 'pasta' 'pepper' 'rice' 'salt'
 'shrimp' 'soy' 'strawberries' 'sugar' 'tomato' 'turmeric' 'vanilla'
 'yogurt']


In [None]:
# "max" row of the trial-averaged stats, restricted to the selected labels
trial_avg_stats.loc["max", selected_labels] # max f1 scores for selected labels

#### Plotting f1 scores for selected labels

In [61]:
# F1 curves restricted to the selected labels, with decoded label names as column headers.
f1_selected = full_f1.loc[idx[:, :], selected_labels]
f1_selected_pretty = f1_selected.copy()
f1_selected_pretty.columns = [label_encoder.decode_labels([label])[0][0] for label in f1_selected.columns]


n_cols = 6
# Every selected label gets a subplot. The "- 1" / "[:-1]" carried over from the all-labels grid
# silently left the last selected label out of this figure.
n_rows = int(np.ceil(len(f1_selected.columns) / n_cols))

fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=list(f1_selected_pretty.columns), shared_yaxes=True, shared_xaxes=True).update_yaxes(range=[0, 1])
for i, column in enumerate(f1_selected_pretty.columns):
    # 1-based subplot coordinates, filled row by row
    row = i // n_cols + 1
    col = i % n_cols + 1

    # one line per trial (first index level) in this label's subplot
    scatters_plot = f1_selected_pretty[column].groupby(level=0).apply(make_scatter).values
    for scatter_fig in scatters_plot:
        fig.add_trace(scatter_fig, row=row, col=col)

fig.update_layout(xaxis_title="Epoch", yaxis_title="F1 score", width=2000, height=1500).show(renderer="browser")

## Creation of new metadata file for training over selected labels

In [62]:
# Decoded names of the selected labels: the only ingredients kept in the new metadata.
ingredients = selected_labels_translated
# Recipes left with fewer kept ingredients than this are dropped from the new metadata.
RECIPE_MIN_INGREDIENTS = 3

In [63]:
path = YUMMLY_PATH


def _load_recipes(split: str) -> pd.DataFrame:
    """Read the metadata JSON of one split folder ('train', 'val' or 'test') into a DataFrame."""
    # The context manager closes the file; the previous json.load(open(...)) left the handle open.
    with open(os.path.join(path, split, METADATA_FILENAME)) as metadata_file:
        return pd.DataFrame(json.load(metadata_file))


train_recipes_df = _load_recipes('train')
val_recipes_df = _load_recipes('val')
test_recipes_df = _load_recipes('test')

In [64]:
# Passed to value_counts: False gives absolute counts, True gives relative frequencies.
normalize = False

# Occurrences of each ingredient per split: explode the per-recipe 'ingredients_ok' entries, then count the values.
train_count = train_recipes_df['ingredients_ok'].explode().value_counts(normalize=normalize)
val_count = val_recipes_df['ingredients_ok'].explode().value_counts(normalize=normalize)
test_count = test_recipes_df['ingredients_ok'].explode().value_counts(normalize=normalize)

# Train-split counts of the selected ingredients only, most frequent first.
train_count[ingredients].sort_values(ascending=False)

ingredients_ok
salt            33985
oil             30007
pepper          26638
garlic          24494
onion           23404
cheese          16302
sugar           14497
tomato          13333
chicken         12642
egg             12604
butter          11686
liquor          10318
flour            9743
milk             7737
chili            7678
soy              6795
corn             6762
rice             6721
cumin            6590
cilantro         6299
bread            6113
parsley          5384
beans            4873
basil            4621
carrot           4312
vanilla          3836
yogurt           3715
beef             3657
cinnamon         3631
coriander        3238
pasta            2922
turmeric         2841
avocado          2744
cucumber         2455
shrimp           2044
lettuce          1933
garam masala     1833
cabbage          1519
chocolate        1456
strawberries      578
Name: count, dtype: int64 




In [65]:
# Keep only the selected ingredients of each recipe, de-duplicated and in the recipe's original order.
# list(set(x).intersection(...)) returned them in arbitrary set order, so the saved metadata could
# differ from run to run.
selected_ingredients = set(ingredients)
train_recipes_sel = train_recipes_df.copy()
train_recipes_sel['ingredients_ok'] = train_recipes_sel['ingredients_ok'].apply(
    lambda recipe_ingredients: [ing for ing in dict.fromkeys(recipe_ingredients) if ing in selected_ingredients]
)
# Drop the recipes left with too few ingredients after the filtering.
train_recipes_sel = train_recipes_sel[train_recipes_sel['ingredients_ok'].apply(len) >= RECIPE_MIN_INGREDIENTS]
print("Before: ", len(train_recipes_df), "After: ", len(train_recipes_sel))

Before:  54724 After:  50866


In [66]:
# Keep only the selected ingredients of each recipe, de-duplicated and in the recipe's original order.
# list(set(x).intersection(...)) returned them in arbitrary set order, so the saved metadata could
# differ from run to run.
selected_ingredients = set(ingredients)
val_recipes_sel = val_recipes_df.copy()
val_recipes_sel['ingredients_ok'] = val_recipes_sel['ingredients_ok'].apply(
    lambda recipe_ingredients: [ing for ing in dict.fromkeys(recipe_ingredients) if ing in selected_ingredients]
)
# Drop the recipes left with too few ingredients after the filtering.
val_recipes_sel = val_recipes_sel[val_recipes_sel['ingredients_ok'].apply(len) >= RECIPE_MIN_INGREDIENTS]
print("Before: ", len(val_recipes_df), "After: ", len(val_recipes_sel))

Before:  5210 After:  4802


In [67]:
# Keep only the selected ingredients of each recipe, de-duplicated and in the recipe's original order.
# list(set(x).intersection(...)) returned them in arbitrary set order, so the saved metadata could
# differ from run to run.
selected_ingredients = set(ingredients)
test_recipes_sel = test_recipes_df.copy()
test_recipes_sel['ingredients_ok'] = test_recipes_sel['ingredients_ok'].apply(
    lambda recipe_ingredients: [ing for ing in dict.fromkeys(recipe_ingredients) if ing in selected_ingredients]
)
# Drop the recipes left with too few ingredients after the filtering.
test_recipes_sel = test_recipes_sel[test_recipes_sel['ingredients_ok'].apply(len) >= RECIPE_MIN_INGREDIENTS]
print("Before: ", len(test_recipes_df), "After: ", len(test_recipes_sel))

Before:  5212 After:  4854


In [68]:
# Write each filtered split into its own folder, under the original metadata file name with a prefix.
sel_metadata_filename = "sel_ing_2410_" + METADATA_FILENAME
for split_name, split_recipes in (("train", train_recipes_sel), ("val", val_recipes_sel), ("test", test_recipes_sel)):
    split_recipes.to_json(os.path.join(path, split_name, sel_metadata_filename), orient="records", indent=4)

#### Comparing the number of ingredients per recipe in the train data

In [69]:
# Summary statistics of the number of 'ingredients_ok' entries per train recipe,
# after the selection (selected_data) and before it (original_data).
pd.DataFrame({
    "selected_data": train_recipes_sel['ingredients_ok'].apply(len).describe(),
    "original_data": train_recipes_df['ingredients_ok'].apply(len).describe()
})

Unnamed: 0,selected_data,original_data
count,50866.0,54724.0
mean,6.791393,9.132739
std,2.606156,3.630393
min,3.0,3.0
25%,5.0,6.0
50%,6.0,9.0
75%,8.0,11.0
max,22.0,42.0
