In [1]:
import numpy as np
import pandas as pd

import itertools

import torch
import torch.nn as nn
import torch.optim as optim

from hypll import nn as hnn
from hypll.tensors import TangentTensor
from hypll.optim import RiemannianAdam
from hypll.manifolds.poincare_ball import Curvature, PoincareBall

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [2]:
import util

In [3]:
TRAIN_FILE = '../data/strawberry_samples_big.csv'
VAL_FILE = '../data/strawberry_val_dataset.csv'

data = pd.read_csv(TRAIN_FILE, index_col=0)
val_data = pd.read_csv(VAL_FILE, index_col=0)

data

Unnamed: 0,OVERALL LIKING,TEXTURE LIKING,SWEETNESS INTENSITY,SOURNESS INTENSITY,STRAWBERRY FLAVOR INTENSITY,6915-15-7,77-92-9,50-99-7,57-48-7,57-50-1,...,7786-58-5,15111-96-3,706-14-9,10522-34-6,5881-17-4,128-37-0,40716-66-3,4887-30-3,5454-09-1,2305-05-7
0,0.307068,0.250174,0.276647,0.146214,0.305021,-2.120421,0.171179,0.759180,0.612070,0.314374,...,-0.271610,-0.443110,-0.544167,-0.463164,1.289539,-0.945803,0.122098,-0.127048,0.804970,-0.354773
1,0.307859,0.249023,0.276101,0.147151,0.306364,-2.119978,0.171282,0.759195,0.612046,0.314378,...,-0.276743,-0.334642,-0.545464,-0.568225,1.346308,-0.893509,0.117186,-0.129998,0.800679,-0.374048
2,0.306348,0.248756,0.277115,0.146568,0.306719,-2.120063,0.171283,0.759176,0.612056,0.314430,...,-0.275587,-0.241669,-0.544872,-0.614111,1.283168,-0.997929,0.118824,-0.128236,0.819822,-0.357370
3,0.307694,0.248227,0.277607,0.147135,0.305276,-2.120058,0.171249,0.759163,0.612060,0.314365,...,-0.273644,-0.312912,-0.545357,-0.562826,1.343249,-0.959232,0.118201,-0.126724,0.811123,-0.377366
4,0.308055,0.249070,0.276638,0.145692,0.305707,-2.120678,0.171089,0.759154,0.612069,0.314420,...,-0.273092,-0.211367,-0.545153,-0.589017,1.548846,-0.901830,0.120083,-0.126749,0.807910,-0.384669
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
53995,0.237395,0.249785,0.205025,0.225081,0.258800,-0.222374,0.884429,-0.834332,-0.803436,-0.625512,...,-0.368228,-0.326611,-0.146165,-0.098112,0.032811,-0.408101,0.055410,-0.453151,-0.452553,-0.445544
53996,0.240068,0.250256,0.204477,0.226804,0.257519,-0.221809,0.884286,-0.834295,-0.803437,-0.625499,...,-0.394705,-0.280702,-0.155369,-0.551456,-0.262542,-0.313059,0.052135,-0.480896,-0.518157,-0.462789
53997,0.238440,0.249851,0.204574,0.224211,0.257632,-0.222039,0.884583,-0.834322,-0.803455,-0.625512,...,-0.390017,-0.316657,-0.159865,-0.542431,-0.189396,-0.153318,0.055373,-0.470802,-0.472719,-0.478549
53998,0.238959,0.250376,0.202495,0.225202,0.258081,-0.222588,0.884425,-0.834312,-0.803431,-0.625546,...,-0.372647,-0.229130,-0.157428,-0.625398,-0.242219,-0.471044,0.053125,-0.445828,-0.497520,-0.467094


In [4]:
FEATURE_COLS = data.columns[5:]
LABEL_COLS = data.columns[[0]]
print(FEATURE_COLS)
print(LABEL_COLS)


def get_fold_indices(size, k):
    fold_size = size // k
    rest = size % k
    
    fold_sizes = [fold_size] * k
    
    for i in range(rest):
        fold_sizes[i] += 1

    indices = np.cumsum([fold_sizes])
    
    return list(zip(indices-np.array(fold_sizes), indices))


FOLDS = 3
NUM_SAMPLE_TYPES = len(val_data)
NUM_SAMPLES_PER_TYPE = len(data) // NUM_SAMPLE_TYPES

fold_nums = list(range(FOLDS))
[num*NUM_SAMPLE_TYPES for num in fold_nums]
[(num+1)*NUM_SAMPLE_TYPES for num in fold_nums]

FOLD_INDICIES = get_fold_indices(NUM_SAMPLE_TYPES, FOLDS)

# FOLD_INDICIES = list(zip([num*NUM_SAMPLE_TYPES//FOLDS for num in fold_nums], 
#                          [(num+1)*NUM_SAMPLE_TYPES//FOLDS for num in fold_nums]))

print(FOLD_INDICIES)

ALL_TRAIN_FEATURES = data[FEATURE_COLS].values
ALL_TRAIN_LABELS = data[LABEL_COLS].values
ALL_VAL_FEATURES = val_data[FEATURE_COLS].values
ALL_VAL_LABELS = val_data[LABEL_COLS].values

Index(['6915-15-7', '77-92-9', '50-99-7', '57-48-7', '57-50-1', 'SSC', 'pH',
       'TA', '75-85-4 ', '616-25-1 ', '1629-58-9 ', '96-22-0 ', '110-62-3 ',
       '1534-08-3 ', '105-37-3', '109-60-4 ', '623-42-7 ', '591-78-6 ',
       '108-10-1 ', '1576-87-0 ', '1576-86-9 ', '623-43-8 ', '71-41-0',
       '1576-95-0 ', '556-24-1 ', '589-38-8 ', '105-54-4 ', '66-25-1 ',
       '123-86-4 ', '624-24-8 ', '29674-47-3 ', '96-04-8 ', '638-11-9 ',
       '116-53-0 ', '7452-79-1 ', '6728-26-3 ', '928-95-0 ', '111-27-3 ',
       '123-92-2 ', '624-41-9 ', '110-43-0', '2432-51-1 ', '105-66-8 ',
       '539-82-2 ', '111-71-7 ', '628-63-7 ', '1191-16-8 ', '106-70-7 ',
       '55514-48-2 ', '110-93-0 ', '109-21-7 ', '123-66-0 ', '124-13-0 ',
       '142-92-7 ', '2497-18-9 ', '60415-61-4', '104-76-7 ', ' 2311-46-8 ',
       '109-19-3 ', '2548-87-0 ', '540-18-1 ', '4077-47-8 ', '20664-46-4',
       '821-55-6 ', '5989-33-3 ', '78-70-6 ', '124-19-6 ', '103-09-3',
       '140-11-4 ', '2639-63-6 ', '53398-8

In [5]:
# Define custom PyTorch dataset
class CustomDataset(Dataset):
    def __init__(self, features, labels):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

<h1> Hyperbolic </h1>

In [6]:
# Define your MLP model
class HYP_MLP(nn.Module):
    def __init__(self, input_size, output_size, layer_size, num_hidden_layers, manifold):
        super(HYP_MLP, self).__init__()
        torch.manual_seed(42)
        self.fc_in = hnn.HLinear(input_size, layer_size, manifold=manifold)
        self.relu = hnn.HReLU(manifold=manifold)
        self.hidden_fcs = nn.ModuleList([hnn.HLinear(layer_size, layer_size, manifold=manifold) for _ in range(num_hidden_layers)])
        self.fc_out = hnn.HLinear(layer_size, output_size, manifold=manifold)

    def forward(self, x):
        x = self.fc_in(x)
        x = self.relu(x)
        for fc in self.hidden_fcs:
            x = fc(x)
            x = self.relu(x)
        x = self.fc_out(x)

        return x


# Define training function
def hyp_train_model(model, train_loader, criterion, optimizer, manifold, device):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        tangents = TangentTensor(data=inputs, man_dim=-1, manifold=manifold)
        manifold_inputs = manifold.expmap(tangents)

        outputs = model(manifold_inputs)

        loss = criterion(outputs.tensor, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(train_loader.dataset)

<h1> EUCLIDEAN </h1>

In [7]:
# Define your MLP model
class EUC_MLP(nn.Module):
    def __init__(self, input_size, output_size, layer_size, num_hidden_layers):
        super(EUC_MLP, self).__init__()
        torch.manual_seed(42)
        self.fc_in = nn.Linear(input_size, layer_size)
        self.relu = nn.ReLU()
        self.hidden_fcs = nn.ModuleList([nn.Linear(layer_size, layer_size) for _ in range(num_hidden_layers)])
        self.fc_out = nn.Linear(layer_size, output_size)

    def forward(self, x):
        x = self.fc_in(x)
        x = self.relu(x)
        for fc in self.hidden_fcs:
            x = fc(x)
            x = self.relu(x)
        x = self.fc_out(x)

        return x

# Define training function
def euc_train_model(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(train_loader.dataset)

In [10]:
# param_grid = {
#     'model_type': ['hyp'],
#     'num_hidden_layers': [0,2,8,12,14,16,18,20],
#     'layer_size': [2,8,16,32,48,64,72,80,96,128,256,512],
#     'lr': [0.018,0.02,0.022],
#     'weight_decay': [0.001],
#     'batch_size': [1024],
#     'epochs': [50],
#     'curvature': [-1]
# }

# param_grid = {
#     'model_type': ['euc'],
#     'num_hidden_layers': [0,1,2,3,4,5,8,12],
#     'layer_size': [2,8,16,64,128,192,256,320,448,480,512,544,576],
#     'lr': [0.003,0.004,0.005],
#     'weight_decay': [0.001],
#     'batch_size': [1024],
#     'epochs': [50],
#     'curvature': [-1]
# }


param_grid = {
    'model_type': ['euc'],
    'num_hidden_layers': [0,1,2,4],
    'layer_size': [2,4,8,16,32,64,128,192,256],
    'lr': [0.001,0.002,0.003,0.004,0.005],
    'weight_decay': [0.0005,0.001,0.002,0.005],
    'batch_size': [1024],
    'epochs': [20],
    'curvature': [-1]
}

# param_grid = {
#     'model_type': ['hyp'],
#     'num_hidden_layers': [0,1,2,4,8,16],
#     'layer_size': [2,4,8,16,32,64,128,192,256,448,512],
#     'lr': [0.003,0.004,0.005],
#     'weight_decay': [0.001],
#     'batch_size': [1024],
#     'epochs': [100],
#     'curvature': [-1]
# }


param_combinations = list(itertools.product(*param_grid.values()))
len(param_combinations)

720

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

param_eval_stats = []

for i, params in enumerate(param_combinations):
    print(f'----- Combination {i} -----')
    print(*zip(param_grid.keys(), params))
    model_type, num_hidden_layers, layer_size, lr, weight_decay, batch_size, epochs, curvature = params
    for fold, (fold_start, fold_stop) in enumerate(FOLD_INDICIES):
        print(f'Fold {fold}')

        train_features = ALL_TRAIN_FEATURES[fold_start*NUM_SAMPLES_PER_TYPE:fold_stop*NUM_SAMPLES_PER_TYPE]
        train_labels   =   ALL_TRAIN_LABELS[fold_start*NUM_SAMPLES_PER_TYPE:fold_stop*NUM_SAMPLES_PER_TYPE]
        val_features   = ALL_VAL_FEATURES[fold_start:fold_stop]
        val_labels     =   ALL_VAL_LABELS[fold_start:fold_stop]

        train_dataset = CustomDataset(train_features, train_labels)
        val_dataset = CustomDataset(val_features, val_labels)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

        if model_type == 'hyp':
            manifold = PoincareBall(c=Curvature(curvature))
        elif model_type == 'euc':
            manifold = None

        if model_type == 'hyp':
            model = HYP_MLP(input_size=len(FEATURE_COLS), 
                            output_size=len(LABEL_COLS), 
                            layer_size=layer_size, 
                            num_hidden_layers=num_hidden_layers, 
                            manifold=manifold).to(device)
        elif model_type == 'euc':
            model = EUC_MLP(input_size=len(FEATURE_COLS), 
                            output_size=len(LABEL_COLS), 
                            layer_size=layer_size, 
                            num_hidden_layers=num_hidden_layers).to(device)
            
        criterion = nn.MSELoss()

        if model_type == 'hyp':
            optimizer = RiemannianAdam(model.parameters(), lr=lr, weight_decay=weight_decay)
        elif model_type == 'euc':
            optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        eval_stats = {'loss': {'train': [], 'val': []}, 'mae': {'train': [], 'val': []}}

        for epoch in range(epochs):
            if model_type == 'hyp':
                eval_stats['loss']['train'].append(hyp_train_model(model, train_loader, criterion, optimizer, manifold, device))
                eval_stats['loss']['val'].append(util.h_evaluate_loss(model, val_loader, criterion, manifold, device))

                eval_stats['mae']['train'].append(util.h_evaluate_mae(model, train_loader, manifold, device))
                eval_stats['mae']['val'].append(util.h_evaluate_mae(model, val_loader, manifold, device))
            elif model_type == 'euc':
                eval_stats['loss']['train'].append(euc_train_model(model, train_loader, criterion, optimizer, device))
                eval_stats['loss']['val'].append(util.evaluate_loss(model, val_loader, criterion, device))

                eval_stats['mae']['train'].append(util.evaluate_mae(model, train_loader, device))
                eval_stats['mae']['val'].append(util.evaluate_mae(model, val_loader, device))

        print(eval_stats['mae']['val'])
        param_eval_stats.append(eval_stats)

----- Combination 0 -----
('model_type', 'euc') ('num_hidden_layers', 0) ('layer_size', 2) ('lr', 0.001) ('weight_decay', 0.0005) ('batch_size', 1024) ('epochs', 20) ('curvature', -1)
Fold 0
[0.052228004, 0.036528688, 0.026788082, 0.021803955, 0.019764125, 0.018897802, 0.018077768, 0.017363133, 0.017352888, 0.018132111, 0.018634403, 0.019206855, 0.019558374, 0.019793736, 0.019918274, 0.020041797, 0.019879477, 0.019854344, 0.019868592, 0.019907787]
Fold 1
[0.039654113, 0.021625025, 0.015729602, 0.0112873735, 0.00956693, 0.008587331, 0.008068332, 0.0076405373, 0.0073756743, 0.0073051024, 0.0069012586, 0.006807058, 0.0066419262, 0.006491324, 0.0062474664, 0.006282361, 0.006158597, 0.0060064495, 0.005933149, 0.0059978236]
Fold 2
[0.05607973, 0.03133555, 0.022850363, 0.015363898, 0.012160929, 0.010986801, 0.009812906, 0.009027038, 0.0080780275, 0.007727891, 0.007406407, 0.0071703666, 0.007194487, 0.006843256, 0.0070054065, 0.007098525, 0.007242089, 0.007304793, 0.0073233675, 0.0075343866]
-

0.0260: HYP: 25X, EUC: 5X

In [20]:
all = np.array([
# ('model_type', 'hyp') ('num_hidden_layers', 8) ('layer_size', 64) ('lr', 0.003) ('weight_decay', 0.001) ('batch_size', 1024) ('epochs', 100) ('curvature', -1)
[0.04677703, 0.018618498, 0.010535374, 0.007100117, 0.0061865896, 0.005973243, 0.005628023, 0.0054762075, 0.0051581934, 0.0049592247, 0.004717715, 0.0048807836, 0.0048701987, 0.004797911, 0.0048144087, 0.004597679, 0.0046878597, 0.004398643, 0.004355534, 0.004543802, 0.0042595062, 0.004198495, 0.004242834, 0.0040632617, 0.0039877947, 0.0039501158, 0.0038352327, 0.003732875, 0.0035546804, 0.0033946757, 0.0035182545, 0.003295361, 0.0031672353, 0.0031114172, 0.0031201865, 0.002979276, 0.0030205713, 0.0029632838, 0.0029699951, 0.0029723654, 0.0028702104, 0.0028381662, 0.002702257, 0.0027241888, 0.002672461, 0.0025620228, 0.0025798222, 0.0024672789, 0.0024415264, 0.0023717615, 0.0022829175, 0.0022036498, 0.0021737209, 0.002087931, 0.0021041061, 0.0021189535, 0.0021432796, 0.0021690626, 0.0021654086, 0.0021407423, 0.00219814, 0.0022196148, 0.0022234428, 0.002225408, 0.0021610484, 0.002202836, 0.0021669557, 0.0021621054, 0.0021623506, 0.0022410369, 0.0022139987, 0.0022512185, 0.0022102809, 0.0021926379, 0.0020749792, 0.002008866, 0.0019812137, 0.0019312385, 0.0019574927, 0.0019083893, 0.0019208433, 0.0018422902, 0.0018084554, 0.0017081855, 0.0017243864, 0.0017260313, 0.0018034362, 0.0018858438, 0.0019071823, 0.0017769518, 0.0017347211, 0.0017323138, 0.0018189467, 0.0017853106, 0.0018567336, 0.0018203756, 0.0018029072, 0.0017080746, 0.0017808742, 0.0017206403]
,[0.08356461, 0.02149906, 0.009283004, 0.0074421708, 0.0072741425, 0.007290739, 0.0074288794, 0.007281998, 0.0074881464, 0.006869893, 0.0066665956, 0.006451817, 0.0062157842, 0.0060888627, 0.0056805527, 0.005885435, 0.0057026925, 0.0051901545, 0.004747509, 0.0047931727, 0.004572518, 0.004627298, 0.004624834, 0.004273879, 0.0041643754, 0.0040027276, 0.004059454, 0.0040163463, 0.0039096302, 0.003814121, 0.0037092657, 0.0037919905, 0.0037337616, 0.0036378924, 0.0036664184, 0.0037279949, 0.0036499666, 0.0035580455, 0.0034080925, 0.003351772, 0.0031427923, 0.0030609088, 0.0031783464, 0.003083412, 0.002883195, 0.0029096412, 0.0027006227, 0.0028117613, 0.002647775, 0.0025793314, 0.0025287932, 0.0024020497, 0.0024908888, 0.0023787708, 0.0025924344, 0.002451068, 0.002384304, 0.0023615866, 0.0023536251, 0.002247273, 0.0022275085, 0.0021630973, 0.002106077, 0.001953299, 0.001982722, 0.0019366617, 0.001873143, 0.0018741571, 0.0017218036, 0.0017361608, 0.0016202355, 0.0016474931, 0.0016392278, 0.001632452, 0.0015806324, 0.001591728, 0.001525099, 0.0015206983, 0.0015124786, 0.0014465874, 0.0015253044, 0.0014593378, 0.0014586672, 0.0013879562, 0.001379052, 0.0013398842, 0.0014744443, 0.0015375358, 0.0015431833, 0.001571589, 0.0016104695, 0.0016388115, 0.0016874993, 0.0016886741, 0.001779803, 0.0017600341, 0.0018736249, 0.001989142, 0.0017456992, 0.0016758102]
,[0.035657298, 0.013526415, 0.0050405273, 0.0041267094, 0.003996879, 0.0037419365, 0.0033674075, 0.0034828344, 0.003355518, 0.0032310253, 0.0032764275, 0.0034392162, 0.0032533307, 0.0031282066, 0.0031062795, 0.0031782603, 0.0029336463, 0.0028494836, 0.0026871471, 0.0025413367, 0.002416961, 0.0023716008, 0.0022159277, 0.0022319779, 0.0021736945, 0.002152511, 0.0021494115, 0.0019580266, 0.0018720677, 0.0018475321, 0.0017259626, 0.0016356119, 0.0016006439, 0.0015109902, 0.0015459234, 0.0015493283, 0.0015367634, 0.001388635, 0.0013482115, 0.0011403047, 0.001290328, 0.0011458951, 0.0012005816, 0.0012154207, 0.0013124637, 0.0013786049, 0.0014856202, 0.0015458142, 0.00168654, 0.0017209831, 0.0017425683, 0.0016535082, 0.0016635566, 0.0015650905, 0.0016231057, 0.001589497, 0.0015655516, 0.0015309735, 0.0016550529, 0.0016222075, 0.0017240453, 0.0016443738, 0.0015825289, 0.0014039502, 0.0013974574, 0.0014316035, 0.0014572516, 0.00147496, 0.0015225692, 0.0015932992, 0.0016966859, 0.0018569571, 0.0020283568, 0.0021973301, 0.0023844365, 0.0024331543, 0.0023413813, 0.0022235578, 0.0022187962, 0.0022710536, 0.002271057, 0.0022516176, 0.0021912812, 0.0021257268, 0.0019732234, 0.0018844373, 0.0018250106, 0.0017587476, 0.0017236405, 0.0016774742, 0.0016837219, 0.0017393811, 0.0017633521, 0.0017347435, 0.0017200741, 0.0017682835, 0.0017708979, 0.0017415774, 0.0017499211, 0.0017585887]
# ('model_type', 'hyp') ('num_hidden_layers', 8) ('layer_size', 64) ('lr', 0.004) ('weight_decay', 0.001) ('batch_size', 1024) ('epochs', 100) ('curvature', -1)
,[0.031013899, 0.02235403, 0.00593466, 0.0037623958, 0.004337199, 0.0038571656, 0.0037834223, 0.0040920195, 0.0043135085, 0.0043964814, 0.0043538543, 0.0042198356, 0.0040792874, 0.003790819, 0.0038213, 0.0036683986, 0.0036721295, 0.0036005909, 0.0034336853, 0.0035134298, 0.0034816281, 0.0034001777, 0.003411664, 0.0032992298, 0.0033469605, 0.0032564402, 0.0032552024, 0.00323094, 0.003178571, 0.0031005938, 0.0031099957, 0.0030163585, 0.002969942, 0.0028397557, 0.0028705497, 0.002783205, 0.0028483504, 0.002855969, 0.0028419793, 0.0028365387, 0.0027772875, 0.002745851, 0.0026775391, 0.0026913648, 0.002615201, 0.0025401355, 0.0024861312, 0.0024552974, 0.002490513, 0.0025193403, 0.0024751408, 0.0024601072, 0.0024824839, 0.0024479057, 0.0023843886, 0.002358016, 0.0023999389, 0.002397606, 0.002354664, 0.0023475562, 0.0023433103, 0.0023199006, 0.0022696415, 0.002230319, 0.0021657208, 0.002198216, 0.002140387, 0.0021040787, 0.0021030786, 0.002107401, 0.0020822345, 0.0020866576, 0.0020192415, 0.0022047237, 0.002295485, 0.0023640122, 0.0023246722, 0.0023394201, 0.0023251167, 0.002248658, 0.0022380177, 0.0022697432, 0.0022665104, 0.0022792262, 0.0021600623, 0.0020368223, 0.0019289511, 0.0019656958, 0.0020011058, 0.0018928929, 0.0018603918, 0.001917344, 0.0018502218, 0.0018294818, 0.0019071044, 0.00185777, 0.0018927678, 0.0018239957, 0.0018351749, 0.0018468855]
,[0.07274847, 0.022060897, 0.008362493, 0.004456916, 0.004913576, 0.00433994, 0.004481854, 0.004281972, 0.004175639, 0.0038874869, 0.0037494062, 0.0037523375, 0.0037024394, 0.003541938, 0.0034693745, 0.0034053847, 0.0034031314, 0.0033760022, 0.003183797, 0.003036538, 0.0029748024, 0.0028671871, 0.0029697353, 0.0028782578, 0.0028974365, 0.0029211747, 0.0028862944, 0.0028806021, 0.0029309574, 0.0027892804, 0.0027597446, 0.0027940844, 0.002664407, 0.002614172, 0.0026435927, 0.0024867877, 0.0026036492, 0.0025625858, 0.002509527, 0.0024927738, 0.0024489304, 0.002348927, 0.002363306, 0.0022871038, 0.002224481, 0.0022956124, 0.0022537385, 0.002356694, 0.002288887, 0.0022602975, 0.0022574281, 0.0022352943, 0.0022686596, 0.0022380443, 0.00230252, 0.0023189362, 0.002298313, 0.0022746043, 0.0021567468, 0.0021178615, 0.0021579554, 0.002069426, 0.0021059206, 0.0020492051, 0.002115873, 0.002058049, 0.0020591477, 0.002079655, 0.001977591, 0.0019595465, 0.001922042, 0.0019307642, 0.00197323, 0.0020493949, 0.0022186015, 0.0020194948, 0.0018644365, 0.0019459551, 0.0019145558, 0.0018497804, 0.0018704914, 0.0019013749, 0.0019591649, 0.0019729105, 0.002047422, 0.0018220884, 0.0015648422, 0.0016015023, 0.0014657171, 0.0013248292, 0.0012016818, 0.001197622, 0.0012287886, 0.0012468142, 0.0012945277, 0.0012356051, 0.0013643213, 0.0014216139, 0.0015320422, 0.0015380109]
,[0.037858136, 0.014453037, 0.0075853104, 0.0055604526, 0.005054415, 0.0050522313, 0.004805681, 0.004832993, 0.004949738, 0.0044152383, 0.004371496, 0.004262622, 0.0038097103, 0.0037040014, 0.0034562398, 0.003272583, 0.0030483182, 0.0030017644, 0.0029582407, 0.0027945016, 0.0026755524, 0.002614138, 0.002397822, 0.002315431, 0.002209277, 0.0019910643, 0.0018690726, 0.0017958606, 0.001681745, 0.0015390044, 0.0014238043, 0.0013490112, 0.0013202801, 0.0011244946, 0.0011578095, 0.0010538979, 0.0010659802, 0.0011508175, 0.0011365977, 0.0012360737, 0.0012422842, 0.0012664042, 0.0013195731, 0.0013609023, 0.0013837474, 0.0014284013, 0.001358216, 0.001382283, 0.0014195136, 0.0013978125, 0.0014575555, 0.0015351069, 0.001559231, 0.0016053435, 0.0016753698, 0.0017248815, 0.0017215535, 0.0015243093, 0.0014358842, 0.0015581367, 0.0018187803, 0.0020003791, 0.002158704, 0.0022960296, 0.0024137076, 0.0024782212, 0.0023417026, 0.0021738997, 0.00216231, 0.0022001166, 0.0023434716, 0.002485585, 0.0026076569, 0.0026542991, 0.0026957716, 0.0024149832, 0.002207746, 0.002406371, 0.0025015895, 0.002459979, 0.0024244753, 0.002386151, 0.002354756, 0.0022548577, 0.0021926586, 0.0021449956, 0.002086555, 0.002056827, 0.001962714, 0.001949148, 0.0019469013, 0.0018671858, 0.0018071309, 0.0017952679, 0.001812251, 0.0017222762, 0.0017419507, 0.0016605167, 0.0015584694, 0.0015166054]
# ('model_type', 'hyp') ('num_hidden_layers', 8) ('layer_size', 64) ('lr', 0.005) ('weight_decay', 0.001) ('batch_size', 1024) ('epochs', 100) ('curvature', -1)
,[0.05707022, 0.0170104, 0.008115383, 0.004587623, 0.0025968435, 0.0026396157, 0.0026909485, 0.0025240656, 0.002522952, 0.0024107487, 0.0025100186, 0.002545749, 0.0024475663, 0.0022888582, 0.0023508104, 0.002363306, 0.0023683251, 0.0023821974, 0.0022455778, 0.002243681, 0.0022060524, 0.0021830285, 0.002211064, 0.0021680726, 0.0021713881, 0.0021793535, 0.002181069, 0.0021573603, 0.002196759, 0.0021679716, 0.0021831219, 0.002171822, 0.0021531987, 0.0020969277, 0.0020918557, 0.002007929, 0.0020392637, 0.00205678, 0.0020832336, 0.0021162108, 0.0021355997, 0.0021412158, 0.002212723, 0.002307722, 0.0022067958, 0.0020735478, 0.002095206, 0.0020565863, 0.002087583, 0.0022327336, 0.0021548206, 0.0021822643, 0.0022472234, 0.0022595774, 0.0022615641, 0.002283001, 0.002319225, 0.0023426944, 0.002355939, 0.0023820375, 0.0024054828, 0.0024176687, 0.0024138407, 0.002477722, 0.0025124443, 0.0023480323, 0.002115171, 0.002154097, 0.0022996392, 0.002379178, 0.0023232806, 0.0023239686, 0.0023275176, 0.0022921844, 0.0022661702, 0.0023581642, 0.0024637033, 0.0024899824, 0.0025968626, 0.0025113688, 0.0023333041, 0.002191551, 0.0023099566, 0.002336536, 0.0022529918, 0.0023231474, 0.002320965, 0.0023862927, 0.0024389368, 0.0023075044, 0.0021648183, 0.0021647364, 0.0020054579, 0.0020049943, 0.0019780977, 0.002073032, 0.0022498104, 0.002351491, 0.0022593373, 0.002206809]
,[0.046943802, 0.024497153, 0.011660192, 0.0056382366, 0.0040519964, 0.0034491718, 0.003141794, 0.0031130116, 0.0029810353, 0.0027564382, 0.0026348638, 0.002740504, 0.002917288, 0.0030617814, 0.0030296734, 0.0029660587, 0.0028732668, 0.0027873788, 0.002728044, 0.0026660496, 0.0026219685, 0.0024877554, 0.0025011906, 0.0024587554, 0.0025010689, 0.002347788, 0.0022573753, 0.0021645054, 0.0021019902, 0.0020962516, 0.0021438184, 0.0021075804, 0.0020195013, 0.0020002467, 0.0019519478, 0.0019961712, 0.0020615433, 0.0020004297, 0.0018623463, 0.0018568444, 0.0018760446, 0.0017476968, 0.0017384432, 0.0017304594, 0.0016911469, 0.0018207232, 0.0018268765, 0.0019256688, 0.0018890823, 0.0018728102, 0.0019035629, 0.0018463606, 0.0017558667, 0.0016439491, 0.0016536249, 0.0016247233, 0.0016806523, 0.0017595084, 0.0016885069, 0.0017126749, 0.0017488839, 0.0017300711, 0.0017909184, 0.0017612875, 0.0018308767, 0.0017760553, 0.0018184459, 0.0018953963, 0.0017221346, 0.0016961627, 0.0016111359, 0.0016595381, 0.0016767499, 0.0016764245, 0.0016813973, 0.0016683596, 0.0017231471, 0.0017703524, 0.0017000553, 0.0017815033, 0.0018838164, 0.0018237995, 0.0018061367, 0.0016387121, 0.0016859695, 0.0017533087, 0.0018315489, 0.0018374539, 0.0018914664, 0.0019091757, 0.0019105631, 0.0019185924, 0.0019368422, 0.0018704683, 0.0019706604, 0.0019415559, 0.0019532186, 0.0019783361, 0.0019218773, 0.002105275]
,[0.028578699, 0.011534469, 0.005256858, 0.0028151688, 0.0029381546, 0.0033411474, 0.003355095, 0.0030962261, 0.0031844517, 0.0030277984, 0.0028141248, 0.0027114716, 0.0025988694, 0.0024208028, 0.0023375999, 0.0020764347, 0.0019793957, 0.0019932364, 0.0017645326, 0.0017407826, 0.0015263201, 0.00144431, 0.0014218738, 0.0013341936, 0.0013184233, 0.0012902708, 0.0014604868, 0.0014432304, 0.001466299, 0.0014059626, 0.001429975, 0.0014026338, 0.0014984956, 0.0015376045, 0.001444973, 0.0015801763, 0.0016526066, 0.0017747937, 0.0018245719, 0.0019781606, 0.0020979776, 0.0021696133, 0.00221041, 0.0022824018, 0.0023895742, 0.0024278637, 0.002273649, 0.0023708567, 0.0023027253, 0.0022793296, 0.0023096155, 0.0024171066, 0.002535315, 0.0025916505, 0.0024964793, 0.0024224652, 0.0023092744, 0.0022559587, 0.0024639037, 0.0026324573, 0.002739975, 0.0028461318, 0.0028777933, 0.002911162, 0.002922025, 0.0029117067, 0.0028841437, 0.0027664006, 0.0026435878, 0.0025119674, 0.0023905395, 0.0022796085, 0.0021886851, 0.002148721, 0.0020452647, 0.0020099406, 0.0018366823, 0.0018592658, 0.0017544056, 0.0016451196, 0.0015796497, 0.0015978292, 0.001575148, 0.0015180442, 0.0015235064, 0.0014876773, 0.0014646508, 0.0014870622, 0.0014963025, 0.001544483, 0.0015878355, 0.0016049386, 0.0016075703, 0.0016469144, 0.0016509981, 0.0015604885, 0.0015729567, 0.001514852, 0.0014999774, 0.0015628967]
# ('model_type', 'euc') ('num_hidden_layers', 8) ('layer_size', 64) ('lr', 0.003) ('weight_decay', 0.001) ('batch_size', 1024) ('epochs', 100) ('curvature', -1)
,[0.057034075, 0.042183146, 0.042492047, 0.043101948, 0.043212093, 0.043672048, 0.042732336, 0.042827405, 0.04249148, 0.04288186, 0.042814545, 0.04242559, 0.04244757, 0.04273637, 0.042779993, 0.04224177, 0.04368676, 0.04237274, 0.043129884, 0.04290766, 0.04264776, 0.042837515, 0.042466115, 0.04233065, 0.042911362, 0.042853657, 0.042630523, 0.043111667, 0.043013494, 0.04238677, 0.04248667, 0.04269139, 0.042510845, 0.042724606, 0.043342315, 0.04262378, 0.04286913, 0.042817846, 0.042824198, 0.04255264, 0.042880613, 0.04232632, 0.042875335, 0.04322594, 0.042308982, 0.043304037, 0.04295728, 0.0425272, 0.04259866, 0.042782906, 0.04260964, 0.04257339, 0.042403057, 0.042745933, 0.042955115, 0.04276612, 0.04275702, 0.04222646, 0.043044247, 0.042476293, 0.04316179, 0.04286191, 0.042462055, 0.042598087, 0.04232906, 0.042883433, 0.04246187, 0.04315008, 0.042382125, 0.04270847, 0.04303641, 0.042422436, 0.04253439, 0.0426361, 0.042917658, 0.04288116, 0.042668343, 0.04310403, 0.04246805, 0.04284595, 0.042902842, 0.042970523, 0.04291207, 0.042830862, 0.04313605, 0.042183332, 0.04271035, 0.04287268, 0.042397726, 0.042648192, 0.042418193, 0.043142583, 0.042488024, 0.042616006, 0.042989183, 0.04247412, 0.042602938, 0.042595237, 0.042857423, 0.04234736]
,[0.06535492, 0.056407135, 0.05628889, 0.05628889, 0.056386776, 0.056301195, 0.05669786, 0.05640401, 0.056587536, 0.056659963, 0.05629668, 0.056358144, 0.057430916, 0.056711104, 0.05628889, 0.056586996, 0.05668582, 0.056597922, 0.056587026, 0.056368157, 0.05650528, 0.05683238, 0.056870684, 0.05672684, 0.056695912, 0.056487124, 0.05640798, 0.056797095, 0.056750614, 0.056559086, 0.056709077, 0.056489985, 0.05673807, 0.056477707, 0.056533113, 0.05642627, 0.056332946, 0.05640675, 0.056568902, 0.056629114, 0.056303952, 0.056465294, 0.05685328, 0.05628889, 0.056593962, 0.05631251, 0.05628889, 0.05640495, 0.056906804, 0.05700279, 0.056794584, 0.05675459, 0.056707423, 0.05660575, 0.05664489, 0.05628889, 0.056426447, 0.056971423, 0.056725, 0.056702904, 0.056405, 0.056742456, 0.056618385, 0.056352403, 0.057195187, 0.05662726, 0.056861468, 0.05653499, 0.056743994, 0.05672799, 0.056507334, 0.056299515, 0.056736734, 0.05705118, 0.056418046, 0.05636263, 0.056504183, 0.05638883, 0.056603447, 0.056494884, 0.056790326, 0.05636354, 0.056304973, 0.05655394, 0.057024214, 0.05677517, 0.05657461, 0.05662792, 0.05700229, 0.056571867, 0.056749623, 0.05628889, 0.056592345, 0.056627616, 0.05662281, 0.056637414, 0.05642977, 0.05628889, 0.056347992, 0.056716032]
,[0.045954898, 0.031272106, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.030496687, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.030450003, 0.030450003, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030509956]
# ('model_type', 'euc') ('num_hidden_layers', 8) ('layer_size', 64) ('lr', 0.004) ('weight_decay', 0.001) ('batch_size', 1024) ('epochs', 100) ('curvature', -1)
,[0.05812368, 0.045012254, 0.043070365, 0.043276757, 0.043247964, 0.043752316, 0.042802755, 0.042855315, 0.04259974, 0.042961195, 0.043386456, 0.042825177, 0.042407025, 0.042553734, 0.042496994, 0.04295875, 0.042484645, 0.042585913, 0.042693056, 0.043327104, 0.04264421, 0.042764947, 0.04276885, 0.042249605, 0.04239179, 0.04312956, 0.043305896, 0.04257507, 0.042462517, 0.042531427, 0.04294374, 0.042483594, 0.04276779, 0.042589396, 0.0429072, 0.042983506, 0.042946626, 0.042795066, 0.04310907, 0.042497557, 0.042755708, 0.042595968, 0.042663246, 0.042852543, 0.042658195, 0.042873994, 0.043122336, 0.04293333, 0.04248689, 0.04252468, 0.042571805, 0.042495206, 0.04246741, 0.04249859, 0.042883232, 0.042866167, 0.04286312, 0.04228338, 0.04281269, 0.042454034, 0.04328265, 0.042987105, 0.042484593, 0.04253532, 0.042294107, 0.04267842, 0.04240732, 0.04308173, 0.042497054, 0.04242672, 0.04299976, 0.04253227, 0.042560577, 0.04260447, 0.04286692, 0.04288208, 0.042709455, 0.043076266, 0.04249184, 0.042861745, 0.04281758, 0.043046314, 0.0429652, 0.042801164, 0.04322345, 0.0422341, 0.042928707, 0.042750552, 0.04237203, 0.04259655, 0.04240102, 0.04309641, 0.0425034, 0.04260268, 0.04294383, 0.04247589, 0.042622026, 0.042612854, 0.042899407, 0.042378947]
,[0.06655238, 0.058772463, 0.05644454, 0.056444246, 0.0564312, 0.05634366, 0.056716457, 0.056411162, 0.05653302, 0.056400724, 0.05628889, 0.05628889, 0.056941435, 0.056685396, 0.056355998, 0.05644907, 0.056671925, 0.05667362, 0.056688327, 0.05636664, 0.056417756, 0.056768376, 0.056951746, 0.05680992, 0.05674456, 0.05650602, 0.056439202, 0.056618545, 0.05691008, 0.056666587, 0.056649115, 0.05644544, 0.056598917, 0.05652429, 0.05650246, 0.056459136, 0.0563215, 0.05637547, 0.056519847, 0.056578398, 0.056341514, 0.05635133, 0.05679749, 0.05637375, 0.056474566, 0.05634358, 0.05628889, 0.05635414, 0.05679397, 0.056945667, 0.05683388, 0.05672933, 0.056670945, 0.056577034, 0.056668784, 0.05628889, 0.0563921, 0.056946475, 0.056706626, 0.056614306, 0.056381635, 0.05675108, 0.056605645, 0.056347422, 0.05721636, 0.0566102, 0.05684769, 0.056512926, 0.056766458, 0.056680996, 0.056456342, 0.05628889, 0.056740813, 0.05707029, 0.05628889, 0.056469876, 0.056556515, 0.056376737, 0.056594543, 0.05650654, 0.056791015, 0.056350484, 0.05628889, 0.05649565, 0.05699598, 0.056756854, 0.05646079, 0.056575842, 0.05695786, 0.056566283, 0.056747094, 0.05628909, 0.05658934, 0.05662112, 0.056620784, 0.056655806, 0.05643658, 0.05628889, 0.056312494, 0.056691036]
,[0.048744928, 0.032064382, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.030482935]
# ('model_type', 'euc') ('num_hidden_layers', 8) ('layer_size', 64) ('lr', 0.005) ('weight_decay', 0.001) ('batch_size', 1024) ('epochs', 100) ('curvature', -1)
,[0.045934983, 0.04218333, 0.043339413, 0.043439914, 0.043472953, 0.04384227, 0.042695202, 0.042750847, 0.04244843, 0.04290328, 0.043019462, 0.04248297, 0.04237309, 0.04269458, 0.04251667, 0.042883184, 0.042488664, 0.04273171, 0.042815086, 0.04320028, 0.04252825, 0.042860933, 0.042650636, 0.042184256, 0.042582564, 0.043183636, 0.04309181, 0.042486265, 0.042707667, 0.04250364, 0.042785257, 0.042433158, 0.042783454, 0.042590566, 0.042959355, 0.043014284, 0.042788453, 0.042744488, 0.04310575, 0.042466875, 0.04278715, 0.042483184, 0.0426123, 0.043086194, 0.04252918, 0.04292157, 0.043198224, 0.042739883, 0.04243131, 0.042649068, 0.04262439, 0.04247942, 0.04244338, 0.04252322, 0.04301052, 0.042832687, 0.042841874, 0.042215787, 0.04282055, 0.042484682, 0.04322633, 0.04287239, 0.042453054, 0.04256362, 0.042315446, 0.042889744, 0.042446032, 0.043135077, 0.042404987, 0.04246416, 0.04303179, 0.042437907, 0.042543136, 0.04260614, 0.042936508, 0.042812362, 0.0427054, 0.043116674, 0.04245567, 0.042826388, 0.042884424, 0.04301002, 0.042889446, 0.042799685, 0.043233365, 0.042183332, 0.042592656, 0.042768966, 0.042394333, 0.042617347, 0.042400204, 0.04308708, 0.04250238, 0.042592667, 0.04292005, 0.042474315, 0.042633463, 0.04263272, 0.04293534, 0.04241329]
,[0.06122889, 0.05635441, 0.056298018, 0.056288883, 0.05629451, 0.05628889, 0.05662763, 0.05639525, 0.05673823, 0.056598917, 0.05633509, 0.056288883, 0.057319622, 0.05669329, 0.05628889, 0.056583617, 0.05664506, 0.056595087, 0.05659433, 0.05636206, 0.056466725, 0.05682116, 0.05689226, 0.0567379, 0.056689553, 0.056490727, 0.056411054, 0.05670635, 0.056862302, 0.05657718, 0.056659367, 0.05647484, 0.056674097, 0.056508727, 0.056512635, 0.056431334, 0.056318972, 0.056383636, 0.056542423, 0.056607723, 0.05631562, 0.0564907, 0.056833427, 0.05628889, 0.05652531, 0.056346323, 0.056288883, 0.05636402, 0.05699794, 0.05683572, 0.05662722, 0.056897894, 0.056615815, 0.056544185, 0.056627646, 0.05628889, 0.0565511, 0.056672215, 0.05656777, 0.056602042, 0.05646356, 0.05680579, 0.056543246, 0.056345925, 0.057300024, 0.056515813, 0.056832068, 0.05647462, 0.056786504, 0.056693807, 0.056426153, 0.05628889, 0.056747634, 0.057177346, 0.05637529, 0.056418605, 0.056546703, 0.05637369, 0.056611735, 0.05652168, 0.056786366, 0.05634435, 0.05628889, 0.05649376, 0.057005566, 0.05676518, 0.05655767, 0.0566109, 0.056960225, 0.056547746, 0.05671777, 0.05628889, 0.056552317, 0.05658217, 0.056608345, 0.056678005, 0.056449745, 0.05628889, 0.05628889, 0.056658823]
,[0.035542764, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.030450003, 0.030450003, 0.030450003, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.030450003, 0.030450003, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.03045, 0.03045, 0.030450003, 0.03045, 0.030450003, 0.030450003, 0.030450003, 0.03045]
])

array([0.00170807, 0.00133988, 0.0011403 , 0.001824  , 0.00119762,
       0.0010539 , 0.0019781 , 0.00161114, 0.00129027, 0.04218315,
       0.05628889, 0.03045   , 0.0422341 , 0.05628889, 0.03045   ,
       0.04218333, 0.05628888, 0.03045   ])

In [30]:
hyp = all[:9]
euc = all[9:]

hyp_diffs = []
euc_diffs = []
for i in range(all.shape[1]):
    hyp_diffs.append(np.mean(hyp.min(axis=1) / hyp[:,i]))
    euc_diffs.append(np.mean(euc.min(axis=1) / euc[:,i]))

hyp_diffs = np.array(hyp_diffs)
euc_diffs = np.array(euc_diffs)

In [32]:
euc_diffs

array([0.79499422, 0.97956963, 0.99374658, 0.99143454, 0.99096919,
       0.9880197 , 0.99345998, 0.99456573, 0.99559995, 0.99288789,
       0.99314418, 0.99691124, 0.99270624, 0.99399955, 0.99675901,
       0.99467877, 0.99249438, 0.99519689, 0.99275906, 0.99225147,
       0.99579058, 0.99212499, 0.99304468, 0.99681341, 0.99418414,
       0.99216692, 0.99297041, 0.9934657 , 0.99265593, 0.99601456,
       0.99355058, 0.99630093, 0.9939549 , 0.99533603, 0.99193829,
       0.99388166, 0.99459663, 0.99484828, 0.99220162, 0.99574259,
       0.99508294, 0.99702878, 0.99280742, 0.99322038, 0.99623233,
       0.99328684, 0.99309855, 0.99533919, 0.99402891, 0.99272766,
       0.99413805, 0.99456068, 0.99592176, 0.99526798, 0.99207644,
       0.99516253, 0.99418074, 0.99630535, 0.99240158, 0.99580476,
       0.99135273, 0.99170413, 0.9961409 , 0.99678497, 0.99358998,
       0.99345799, 0.99485609, 0.99158285, 0.99541045, 0.99497175,
       0.9925984 , 0.99790681, 0.99463017, 0.99201856, 0.99408