In [1]:
import numpy as np
import pandas as pd

In [3]:
# Load the panel training labels and keep only the fine-grained tasting
# category; the first CSV column becomes the row index.
labels_csv = '../data/beer_labels_panel_train.csv'
beer_types = pd.read_csv(labels_csv, index_col=0)[['tasting_category_fine']]
beer_types

Unnamed: 0,tasting_category_fine
6,Amber
210,Amber
32,Amber
158,Amber
137,Amber
...,...
172,Wheat
197,Wheat
152,Wheat
22,Wheat


In [4]:
# Number of beers per fine tasting category, smallest class first.
#
# Count the single label column directly instead of
# `beer_types.groupby(...).value_counts()`: the groupby form counts unique
# combinations of every *other* column, so it stops producing per-category
# counts as soon as `beer_types` gains extra columns (which a later cell does
# in place). Counting the column keeps this cell safe to re-run at any point.
fine_labels = beer_types['tasting_category_fine']
type_counts = (
    fine_labels
    .value_counts(sort=False)  # raw counts per category
    .sort_index()              # alphabetical by category before ordering by size
    .sort_values()             # ascending by count
)
type_counts

tasting_category_fine
Faro                   2
Scotch                 3
Brut                   3
Flanders old brown     3
Brett/cofermented      4
West Flanders ale      5
Christmas              6
Fruitbeer              6
Low/No alcohol         6
Saison                 7
Dubbel                 7
Kriek                  7
Amber                  8
Pils/Lager             8
Wheat                  8
Brown                  8
Lambic                 8
Stout/Porter           9
Hoppy                 12
Strong ale            13
Tripel                20
Blond                 22
Name: count, dtype: int64

In [6]:
# Candidate coarse groups: each list holds the fine tasting categories that
# are merged into one broad type.
base = ['Blond', 'Tripel', 'Strong ale', 'Hoppy', 'Dubbel']
light = ['Wheat', 'Pils/Lager']
dark = ['Brown', 'Stout/Porter', 'Amber']
kriek = ['Kriek', 'Fruitbeer']
sour = ['Lambic', 'Flanders old brown', 'Brett/cofermented', 'Faro']
season = ['Saison', 'Christmas']
rest = ['Scotch', 'West Flanders ale', 'Low/No alcohol', 'Brut']

# Print the total number of beers in each coarse group, one total per line,
# in the order the groups are defined above.
for members in (base, light, dark, kriek, sour, season, rest):
    group_total = type_counts[members].sum()
    print(group_total)

74
16
25
13
17
13
17


In [7]:
# Broad type -> list of fine tasting categories that belong to it.
type_mapping = {
    'base': ['Blond', 'Tripel', 'Strong ale', 'Hoppy', 'Dubbel'],
    'light': ['Wheat', 'Pils/Lager'],
    'dark': ['Brown', 'Stout/Porter', 'Amber'],
    'kriek': ['Kriek', 'Fruitbeer'],
    'sour': ['Lambic', 'Flanders old brown', 'Brett/cofermented', 'Faro'],
    'season': ['Saison', 'Christmas'],
    'rest': ['Scotch', 'West Flanders ale', 'Low/No alcohol', 'Brut'],
}

# Inverse lookup: fine tasting category -> broad type. If a fine category were
# listed under two broad types, the later entry would win.
inv_type_mapping = {
    fine_type: broad_type
    for broad_type, fine_types in type_mapping.items()
    for fine_type in fine_types
}

In [37]:
# Attach the broad type of every beer by looking its fine category up in
# `inv_type_mapping`. `Series.map` performs the dictionary lookup directly
# (no lambda, and no parameter named `type` shadowing the builtin).
beer_types['broad_types'] = beer_types['tasting_category_fine'].map(inv_type_mapping)

# `.map` yields NaN for labels missing from the mapping; keep the fail-fast
# behaviour of a plain dictionary lookup by raising KeyError in that case.
unmapped = beer_types.loc[beer_types['broad_types'].isna(), 'tasting_category_fine'].unique()
if len(unmapped) > 0:
    raise KeyError(f"no broad type defined for: {list(unmapped)}")

# Move the original row labels into a regular column so that rows are
# addressed by their 0..n-1 position from here on. The reset is skipped when
# the frame already has a RangeIndex, so re-running this cell neither fails
# nor adds a second index column.
if not isinstance(beer_types.index, pd.RangeIndex):
    beer_types = beer_types.reset_index()
beer_types

Unnamed: 0,index,tasting_category_fine,broad_types
0,6,Amber,dark
1,210,Amber,dark
2,32,Amber,dark
3,158,Amber,dark
4,137,Amber,dark
...,...,...,...
170,172,Wheat,light
171,197,Wheat,light
172,152,Wheat,light
173,22,Wheat,light


In [39]:
# Row labels of `beer_types` belonging to each broad type, collected in a
# fixed group order: base, light, dark, kriek, sour, season, rest.
broad_type_column = beer_types['broad_types']
indices = [
    beer_types[broad_type_column == group_name].index
    for group_name in ('base', 'light', 'dark', 'kriek', 'sour', 'season', 'rest')
]

# Keep the individual per-group names available as well.
base_idx, light_idx, dark_idx, kriek_idx, sour_idx, season_idx, rest_idx = indices

In [40]:
# Number of cross-validation folds.
FOLDS = 5

# Build one validation list per fold. For every group in `indices`, fold i
# receives every FOLDS-th entry starting at offset i (idx[i::FOLDS]), so each
# group is dealt out round-robin across the folds.
val_indices = []
for i in range(FOLDS):
    val_indices.append([])
    for idx in indices:
        val_indices[-1].extend(idx[i::FOLDS])

# Manual rebalancing with hand-picked counts:
#   - the first three entries of fold 0 are removed and placed at the front of
#     the last fold;
#   - the first two entries of fold 1 are removed and placed at the front of
#     the second-to-last fold.
# pop(0) mutates folds 0 and 1 in place, and the pops run left to right, so the
# moved entries keep their original relative order.
# NOTE(review): the counts 3 and 2 appear chosen to even out fold sizes for
# this particular dataset and FOLDS = 5 — confirm before reusing with other
# data or a different number of folds.
val_indices[-1] = [val_indices[0].pop(0), val_indices[0].pop(0), val_indices[0].pop(0), *val_indices[-1]]
val_indices[-2] = [val_indices[1].pop(0), val_indices[1].pop(0), *val_indices[-2]]

In [42]:
# Display the validation folds (one list of row labels per fold).
val_indices

[[23,
  28,
  54,
  70,
  75,
  80,
  133,
  138,
  143,
  148,
  153,
  158,
  102,
  107,
  169,
  174,
  0,
  5,
  36,
  41,
  124,
  63,
  68,
  85,
  30,
  59,
  89,
  94,
  45,
  50,
  114,
  42,
  98,
  118,
  165],
 [19,
  24,
  29,
  55,
  71,
  76,
  129,
  134,
  139,
  144,
  149,
  154,
  159,
  103,
  108,
  170,
  1,
  6,
  37,
  120,
  125,
  64,
  81,
  86,
  31,
  60,
  90,
  95,
  46,
  110,
  115,
  43,
  99,
  119,
  166],
 [10,
  15,
  20,
  25,
  51,
  56,
  72,
  77,
  130,
  135,
  140,
  145,
  150,
  155,
  160,
  104,
  109,
  171,
  2,
  7,
  38,
  121,
  126,
  65,
  82,
  87,
  32,
  61,
  91,
  47,
  111,
  116,
  44,
  100,
  162],
 [9,
  14,
  11,
  16,
  21,
  26,
  52,
  57,
  73,
  78,
  131,
  136,
  141,
  146,
  151,
  156,
  161,
  105,
  167,
  172,
  3,
  34,
  39,
  122,
  127,
  66,
  83,
  33,
  62,
  92,
  48,
  112,
  96,
  101,
  163],
 [8,
  13,
  18,
  12,
  17,
  22,
  27,
  53,
  69,
  74,
  79,
  132,
  137,
  142,
  147,
  152,
  1

In [63]:
# Hard-coded validation folds: one row per fold, 35 row indices each.
# np.sort along axis 1 orders the indices inside every fold ascending and
# returns a 2-D integer array.
val_indices = np.sort(
    [
        [23, 28, 54, 70, 75, 80, 133, 138, 143, 148, 153, 158, 102, 107, 169, 174, 0, 5,
         36, 41, 124, 63, 68, 85, 30, 59, 89, 94, 45, 50, 114, 42, 98, 118, 165],
        [19, 24, 29, 55, 71, 76, 129, 134, 139, 144, 149, 154, 159, 103, 108, 170, 1, 6,
         37, 120, 125, 64, 81, 86, 31, 60, 90, 95, 46, 110, 115, 43, 99, 119, 166],
        [10, 15, 20, 25, 51, 56, 72, 77, 130, 135, 140, 145, 150, 155, 160, 104, 109, 171,
         2, 7, 38, 121, 126, 65, 82, 87, 32, 61, 91, 47, 111, 116, 44, 100, 162],
        [9, 14, 11, 16, 21, 26, 52, 57, 73, 78, 131, 136, 141, 146, 151, 156, 161, 105,
         167, 172, 3, 34, 39, 122, 127, 66, 83, 33, 62, 92, 48, 112, 96, 101, 163],
        [8, 13, 18, 12, 17, 22, 27, 53, 69, 74, 79, 132, 137, 142, 147, 152, 157, 106,
         168, 173, 164, 35, 40, 123, 128, 67, 84, 58, 88, 93, 49, 113, 97, 117, 4],
    ],
    axis=1,
)

In [64]:
# Inspect the last validation fold after sorting.
val_indices[-1]

array([  4,   8,  12,  13,  17,  18,  22,  27,  35,  40,  49,  53,  58,
        67,  69,  74,  79,  84,  88,  93,  97, 106, 113, 117, 123, 128,
       132, 137, 142, 147, 152, 157, 164, 168, 173])

In [71]:
# Training rows of each fold = every row position of `beer_types` that is not
# held out in that fold's validation set.
#
# The number of rows is taken from the data rather than hard-coded, so the
# split stays correct if the label file changes. np.setdiff1d already returns
# sorted, unique values, so no extra sort is needed; stacking the per-fold
# results gives a 2-D array (this requires all folds to have the same size).
all_rows = np.arange(len(beer_types))
train_indices = np.array(
    [np.setdiff1d(all_rows, val_indices[fold]) for fold in range(FOLDS)]
)

In [72]:
# Display the training row indices (one row per fold).
train_indices

array([[  1,   2,   3,   4,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  19,  20,  21,  22,  24,  25,  26,  27,  29,
         31,  32,  33,  34,  35,  37,  38,  39,  40,  43,  44,  46,  47,
         48,  49,  51,  52,  53,  55,  56,  57,  58,  60,  61,  62,  64,
         65,  66,  67,  69,  71,  72,  73,  74,  76,  77,  78,  79,  81,
         82,  83,  84,  86,  87,  88,  90,  91,  92,  93,  95,  96,  97,
         99, 100, 101, 103, 104, 105, 106, 108, 109, 110, 111, 112, 113,
        115, 116, 117, 119, 120, 121, 122, 123, 125, 126, 127, 128, 129,
        130, 131, 132, 134, 135, 136, 137, 139, 140, 141, 142, 144, 145,
        146, 147, 149, 150, 151, 152, 154, 155, 156, 157, 159, 160, 161,
        162, 163, 164, 166, 167, 168, 170, 171, 172, 173],
       [  0,   2,   3,   4,   5,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  20,  21,  22,  23,  25,  26,  27,  28,  30,
         32,  33,  34,  35,  36,  38,  39,  40,  41,  42,  44,  4