#Predicting Olympic medals with Deep Learning

- It's almost Olympics time, and we're here to predict medals!
- Let's start out by installing and importing everything we need:

In [5]:
!pip install optuna  > /dev/null 2>&1

In [6]:
!pip install pandas numpy scikit-learn torch transformers  > /dev/null 2>&1

In [7]:
!pip install wikipedia-api > /dev/null 2>&1

In [8]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, mean_squared_error

import torch
import torch.nn as nn
from transformers import pipeline
from torch.utils.data import DataLoader, TensorDataset

import datetime
import wikipediaapi

If you have a GPU, this is your best friend:

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


Let's hard-code the participating countries. Each country has a 3-letter code called its NOC. It will be super handy today.

In [131]:
# List of tuples mapping countries to their NOCs
participating_countries = [
    # Africa
    ("Algeria", "ALG"),
    ("Angola", "ANG"),
    ("Benin", "BEN"),
    ("Botswana", "BOT"),
    ("Burkina Faso", "BUR"),
    ("Burundi", "BDI"),
    ("Cameroon", "CMR"),
    ("Cabo Verde", "CPV"),
    ("Central African Republic", "CAF"),
    ("Chad", "CHA"),
    ("Comoros", "COM"),
    ("Congo", "CGO"),
    ("Democratic Republic of the Congo", "COD"),
    ("Côte d’Ivoire", "CIV"),
    ("Djibouti", "DJI"),
    ("Egypt", "EGY"),
    ("Eritrea", "ERI"),
    ("Eswatini", "SWZ"),
    ("Ethiopia", "ETH"),
    ("Gabon", "GAB"),
    ("Gambia", "GAM"),
    ("Ghana", "GHA"),
    ("Guinea", "GUI"),
    ("Guinea-Bissau", "GBS"),
    ("Equatorial Guinea", "GEQ"),
    ("Kenya", "KEN"),
    ("Lesotho", "LES"),
    ("Liberia", "LBR"),
    ("Libya", "LBA"),
    ("Madagascar", "MAD"),
    ("Malawi", "MAW"),
    ("Mali", "MLI"),
    ("Morocco", "MAR"),
    ("Mauritius", "MRI"),
    ("Mauritania", "MTN"),
    ("Mozambique", "MOZ"),
    ("Namibia", "NAM"),
    ("Niger", "NIG"),
    ("Nigeria", "NGR"),
    ("Uganda", "UGA"),
    ("Rwanda", "RWA"),
    ("Sao Tome and Principe", "STP"),
    ("Senegal", "SEN"),
    ("Seychelles", "SEY"),
    ("Sierra Leone", "SLE"),
    ("Somalia", "SOM"),
    ("South Africa", "RSA"),
    ("South Sudan", "SSD"),
    ("Sudan", "SUD"),
    ("United Republic of Tanzania", "TAN"),
    ("Togo", "TOG"),
    ("Tunisia", "TUN"),
    ("Zambia", "ZAM"),
    ("Zimbabwe", "ZIM"),

    # The Americas
    ("Antigua and Barbuda", "ANT"),
    ("Argentina", "ARG"),
    ("Aruba", "ARU"),
    ("Bahamas", "BAH"),
    ("Barbados", "BAR"),
    ("Belize", "BIZ"),
    ("Bermuda", "BER"),
    ("Bolivia", "BOL"),
    ("Brazil", "BRA"),
    ("Cayman Islands", "CAY"),
    ("Canada", "CAN"),
    ("Chile", "CHI"),
    ("Colombia", "COL"),
    ("Costa Rica", "CRC"),
    ("Cuba", "CUB"),
    ("Dominican Republic", "DOM"),
    ("Dominica", "DMA"),
    ("El Salvador", "ESA"),
    ("Ecuador", "ECU"),
    ("Grenada", "GRN"),
    ("Guatemala", "GUA"),
    ("Guyana", "GUY"),
    ("Haiti", "HAI"),
    ("Honduras", "HON"),
    ("Jamaica", "JAM"),
    ("Mexico", "MEX"),
    ("Nicaragua", "NCA"),
    ("Panama", "PAN"),
    ("Paraguay", "PAR"),
    ("Peru", "PER"),
    ("Puerto Rico", "PUR"),
    ("Saint Kitts and Nevis", "SKN"),
    ("Saint Lucia", "LCA"),
    ("St. Vincent and the Grenadines", "VIN"),
    ("Suriname", "SUR"),
    ("Trinidad and Tobago", "TTO"),
    ("United States", "USA"),
    ("Uruguay", "URU"),
    ("Venezuela", "VEN"),
    ("Virgin Islands, British", "IVB"),
    ("United States Virgin Islands", "ISV"),

    # Asia
    ("Afghanistan", "AFG"),
    ("Bahrain", "BRN"),
    ("Bangladesh", "BAN"),
    ("Bhutan", "BHU"),
    ("Brunei Darussalam", "BRU"),
    ("Cambodia", "CAM"),
    ("China", "CHN"),
    ("Republic of Korea", "KOR"),
    ("Hong Kong, China", "HKG"),
    ("India", "IND"),
    ("Indonesia", "INA"),
    ("Islamic Republic of Iran", "IRI"),
    ("Iraq", "IRQ"),
    ("Japan", "JPN"),
    ("Jordan", "JOR"),
    ("Kazakhstan", "KAZ"),
    ("Kyrgyzstan", "KGZ"),
    ("Kuwait", "KUW"),
    ("Lao People’s Democratic Republic", "LAO"),
    ("Lebanon", "LBN"),
    ("Malaysia", "MAS"),
    ("Maldives", "MDV"),
    ("Mongolia", "MGL"),
    ("Myanmar", "MYA"),
    ("Nepal", "NEP"),
    ("Oman", "OMA"),
    ("Pakistan", "PAK"),
    ("Palestine", "PLE"),
    ("Philippines", "PHI"),
    ("Qatar", "QAT"),
    ("Democratic People’s Republic of Korea", "PRK"),
    ("Saudi Arabia", "KSA"),
    ("Singapore", "SGP"),
    ("Sri Lanka", "SRI"),
    ("Syrian Arab Republic", "SYR"),
    ("Tajikistan", "TJK"),
    ("Chinese Taipei", "TPE"),
    ("Thailand", "THA"),
    ("East Timor", "TLS"),
    ("Turkmenistan", "TKM"),
    ("United Arab Emirates", "UAE"),
    ("Uzbekistan", "UZB"),
    ("Vietnam", "VIE"),
    ("Yemen", "YEM"),

    # Europe
    ("Albania", "ALB"),
    ("Andorra", "AND"),
    ("Armenia", "ARM"),
    ("Austria", "AUT"),
    ("Azerbaijan", "AZE"),
    ("Belgium", "BEL"),
    ("Bosnia and Herzegovina", "BIH"),
    ("Bulgaria", "BUL"),
    ("Cyprus", "CYP"),
    ("Croatia", "CRO"),
    ("Czechia", "CZE"),
    ("Denmark", "DEN"),
    ("Spain", "ESP"),
    ("Estonia", "EST"),
    ("Finland", "FIN"),
    ("France", "FRA"),
    ("Georgia", "GEO"),
    ("Germany", "GER"),
    ("Great Britain", "GBR"),
    ("Greece", "GRE"),
    ("Hungary", "HUN"),
    ("Ireland", "IRL"),
    ("Iceland", "ISL"),
    ("Israel", "ISR"),
    ("Italy", "ITA"),
    ("Kosovo", "KOS"),
    ("Latvia", "LAT"),
    ("Liechtenstein", "LIE"),
    ("Lithuania", "LTU"),
    ("Luxembourg", "LIE"),
    ("North Macedonia", "MKD"),
    ("Malta", "MLT"),
    ("Republic of Moldova", "MDA"),
    ("Monaco", "MON"),
    ("Montenegro", "MNE"),
    ("Netherlands", "NED"),
    ("Norway", "NOR"),
    ("Poland", "POL"),
    ("Portugal", "POR"),
    ("Romania", "ROU"),
    ("San Marino", "SMR"),
    ("Serbia", "SRB"),
    ("Slovakia", "SVK"),
    ("Slovenia", "SLO"),
    ("Sweden", "SWE"),
    ("Switzerland", "SUI"),
    ("Türkiye", "TUR"),
    ("Ukraine", "UKR"),

    # Oceania
    ("American Samoa", "ASA"),
    ("Australia", "AUS"),
    ("Cook Islands", "COK"),
    ("Fiji", "FIJ"),
    ("Guam", "GUM"),
    ("Kiribati", "KIR"),
    ("Marshall Islands", "MHL"),
    ("Federated States of Micronesia", "FSM"),
    ("Nauru", "NRU"),
    ("New Zealand", "NZL"),
    ("Palau", "PLW"),
    ("Papua New Guinea", "PNG"),
    ("Solomon Islands", "SOL"),
    ("Samoa", "SAM"),
    ("Tonga", "TGA"),
    ("Tuvalu", "TUV"),
    ("Vanuatu", "VAN")
]


Sorry *insert name of your elementary school teacher here*, we'll be using Wikipedia!

In [11]:
#Initialize Wikipedia API
wiki_wiki = wikipediaapi.Wikipedia(user_agent='My_Olympic_Data_Bot/1.0 (randomemail@gmail.com)')

#Initialize transformer pipeline for text feature extraction
nlp = pipeline("feature-extraction", model="bert-base-uncased", device=device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Transformers are super cool. Let's see what a transformer can fetch and tell us about different countries from Wikipedia.

In [133]:
# Function to extract text features from Wikipedia
def get_country_features(country_name):
    page = wiki_wiki.page(country_name)

    if not page.exists():
        print(f"Wikipedia page not found for: {country_name}")
        return None

    # Extract summary text
    summary_text = page.summary

    # Truncate the summary text to fit within the BERT model's limit
    # BERT models typically have a maximum sequence length of 512 tokens
    tokens = summary_text.split()
    tokens = tokens[:512] # Truncate to 512 tokens

    # Use the transformer model to extract features
    features = nlp(summary_text, truncation=True, max_length=512 )

    # Aggregating features (e.g., mean across the tokens)
    aggregated_features = [sum(x)/len(x) for x in zip(*features[0])]

    return aggregated_features

In [134]:
# Function to compile features for all countries
def compile_country_features():
    all_country_features = {}

    for country_name, NOC in participating_countries:
        features = get_country_features(country_name)
        if features:
            all_country_features[NOC] = features

    return all_country_features

# Compile features for all countries
country_features = compile_country_features()

# Display sample features
print(f"Sample features for United States: {country_features.get('United States')}")

Sample features for United States: None


This is why we need NOC's. So we can use our data easily and in a uniform way. Look:

U S A! U S A!

In [135]:
print(f"Sample features for United States: {country_features.get('USA')}")

Sample features for United States: [-0.4611632887737187, -0.05057181595475413, 0.04607897073651657, 0.03584708038806639, 0.07485528242796136, -0.037889812905973486, 0.32460331383856555, 0.6609784091178881, 0.016430025981662766, -0.03134424225658705, 0.01804722530619074, -0.3637100045598345, -0.3478598260585386, 0.746809978840929, -0.13590981190054663, 0.36191020858905176, 0.39534079708084846, 0.08246361629213084, -0.31315066141041825, 0.58237973106543, 0.24851076140134865, 0.1570819512462549, 0.01223808385054781, 0.8492528961123753, 0.6276254809317834, -0.027949743085798673, -0.1848762339005816, -0.27075870790940826, -0.08797866366830931, -0.1051676446016927, 0.512074131146619, -0.2959192017474379, -0.27408951300367335, -0.17718606423136407, 0.2642809900868315, 0.0018256630521022998, -0.08326344722354406, -0.24185672438306938, 0.03404435755987478, 0.18204859425156883, -0.6288830170960864, -0.41536303806148567, 0.010657959245286008, 0.10857912525960955, 0.006136733687640117, -0.09405646

You should be getting long vectors of numbers, which are our bert-generated features. If you're getting None here, something is wrong.

In [136]:
print(f"Sample features for China: {country_features.get('CHN')}")

Sample features for China: [-0.4225372159318539, 0.07738997079513865, -0.00876026387300044, -0.0453335340087051, 0.0705637495798328, 0.06537318900478795, 0.20705597874439263, 0.8367145173560857, 0.07901886195443808, -0.027523922080376906, 0.0548064801831174, -0.41264560758099833, -0.3421829218887069, 0.5686385280141977, -0.1718181520982398, 0.48052678947721006, 0.3179122218880366, 0.05206499707946932, -0.2997113193325731, 0.600626012282305, 0.17527564164268838, 0.06800104731212286, 0.035996212052737064, 0.9558984380591937, 0.592128022322413, 0.006404183533845753, -0.07678748904572785, 0.016648983480337165, -0.11963628074613553, -0.2725862703767632, 0.400021027551702, -0.11248992903676935, -0.3268878680675016, -0.20129200645212109, 0.12896272074476656, 0.05852384559238999, -0.023480092269551278, -0.11404357923265707, -0.0850943872908374, 0.08804373230987039, -0.7671799780866877, -0.3456515371608475, -0.01934311230934327, 0.28416899558988007, 0.12581931660329815, -0.20934594558373476, 0.

Ok, now let's look at some data and start preprocessing it. Paris2024 is a Summer Olympics, so we'll only be looking at Summer editions.

In [143]:
historical_medal_tally_path = 'Olympic_Games_Medal_Tally.csv'
paris_2024_medal_tally_path = 'medals_total.csv'

# Load the CSV files into dataframes
historical_medal_tally = pd.read_csv(historical_medal_tally_path)
historical_medal_tally.rename(columns={'country_noc': 'country_code'}, inplace=True)

# Filter the data for Summer Olympics only
historical_medal_tally['edition'] = historical_medal_tally['edition'].str.replace(' Olympics', '')
historical_medal_tally[['Year', 'Season']] = historical_medal_tally['edition'].str.split(' ', expand=True)
historical_medal_tally = historical_medal_tally.drop(columns=['year'])
historical_medal_tally_summer = historical_medal_tally[historical_medal_tally['Season'] == 'Summer']
paris_2024_medal_tally = pd.read_csv(paris_2024_medal_tally_path)

participating_countries_df = pd.DataFrame(participating_countries, columns=['country', 'country_code'])


historical_medal_tally_summer


Unnamed: 0,edition,edition_id,country,country_code,gold,silver,bronze,total,Year,Season
0,1896 Summer,1,United States,USA,11,7,2,20,1896,Summer
1,1896 Summer,1,Greece,GRE,10,18,19,47,1896,Summer
2,1896 Summer,1,Germany,GER,6,5,2,13,1896,Summer
3,1896 Summer,1,France,FRA,5,4,2,11,1896,Summer
4,1896 Summer,1,Great Britain,GBR,2,3,2,7,1896,Summer
...,...,...,...,...,...,...,...,...,...,...
1338,2020 Summer,61,Ghana,GHA,0,0,1,1,2020,Summer
1339,2020 Summer,61,Grenada,GRN,0,0,1,1,2020,Summer
1340,2020 Summer,61,Kuwait,KUW,0,0,1,1,2020,Summer
1341,2020 Summer,61,Republic of Moldova,MDA,0,0,1,1,2020,Summer


Not all countries have won medals yet, and yet we can't completely disregard them- there might be surprises!

In [138]:
# Function to add missing countries with 0 medals
def add_missing_countries(medal_tally_df, left_on, right_on):
    # Merge the medal tally with participating countries
    merged_df = pd.merge(participating_countries_df, medal_tally_df, left_on=left_on, right_on=right_on, how='left')

    # Fill NaN values (for countries with no medals) with 0
    merged_df.fillna(-1, inplace=True)

    return merged_df

We'll try two test cases: First we'll filter out countries which haven't won medals recently. This is mainly in order to disregard former countries like East Germany, Soviet Russia and such, which were IMMENSE Olympic powerhouses but no longer exist. We don't want them biasing our data...

In [144]:
# Test case 1: country filtering - predicting only for countries which have won medals in recent Olympics
historical_medal_tally = add_missing_countries(historical_medal_tally_summer, ['country_code'], ['country_code'])
paris_2024_medal_tally = add_missing_countries(paris_2024_medal_tally, ['country_code'], ['country_code'])
historical_medal_tally = historical_medal_tally.drop(columns=['country_y']).rename(columns={'country_x': 'Country'})


historical_medal_tally['Year'] = historical_medal_tally['Year'].astype(int) # Convert 'year' column to integer type
historical_medal_tally

Unnamed: 0,Country,country_code,edition,edition_id,gold,silver,bronze,total,Year,Season
0,Algeria,ALG,1984 Summer,21.0,0.0,0.0,2.0,2.0,1984,Summer
1,Algeria,ALG,1992 Summer,23.0,1.0,0.0,1.0,2.0,1992,Summer
2,Algeria,ALG,1996 Summer,24.0,2.0,0.0,1.0,3.0,1996,Summer
3,Algeria,ALG,2000 Summer,25.0,1.0,1.0,3.0,5.0,2000,Summer
4,Algeria,ALG,2008 Summer,53.0,0.0,1.0,1.0,2.0,2008,Summer
...,...,...,...,...,...,...,...,...,...,...
1322,Solomon Islands,SOL,-1,-1.0,-1.0,-1.0,-1.0,-1.0,-1,-1
1323,Samoa,SAM,2008 Summer,53.0,0.0,1.0,0.0,1.0,2008,Summer
1324,Tonga,TGA,1996 Summer,24.0,0.0,1.0,0.0,1.0,1996,Summer
1325,Tuvalu,TUV,-1,-1.0,-1.0,-1.0,-1.0,-1.0,-1,-1


Our second case doesn't do that, but rather looks at all the countries participating in Paris2024.

In [145]:
## Test case 2: without country filtering - predicting for all participating countries
historical_medal_tally_case_2 = historical_medal_tally
paris_2024_medal_tally_case_2 = paris_2024_medal_tally
historical_medal_tally_case_2

Unnamed: 0,Country,country_code,edition,edition_id,gold,silver,bronze,total,Year,Season
0,Algeria,ALG,1984 Summer,21.0,0.0,0.0,2.0,2.0,1984,Summer
1,Algeria,ALG,1992 Summer,23.0,1.0,0.0,1.0,2.0,1992,Summer
2,Algeria,ALG,1996 Summer,24.0,2.0,0.0,1.0,3.0,1996,Summer
3,Algeria,ALG,2000 Summer,25.0,1.0,1.0,3.0,5.0,2000,Summer
4,Algeria,ALG,2008 Summer,53.0,0.0,1.0,1.0,2.0,2008,Summer
...,...,...,...,...,...,...,...,...,...,...
1322,Solomon Islands,SOL,-1,-1.0,-1.0,-1.0,-1.0,-1.0,-1,-1
1323,Samoa,SAM,2008 Summer,53.0,0.0,1.0,0.0,1.0,2008,Summer
1324,Tonga,TGA,1996 Summer,24.0,0.0,1.0,0.0,1.0,1996,Summer
1325,Tuvalu,TUV,-1,-1.0,-1.0,-1.0,-1.0,-1.0,-1,-1


In [146]:

# Test case 1: Keep only countries which have won at least 8 medals since Sydney 2000 for relevance
historical_medal_tally_filtered = historical_medal_tally[historical_medal_tally['Year'] >= 2000]
medals_by_country = historical_medal_tally_filtered.groupby('country_code')['total'].sum()
countries_with_at_least_8_medals = medals_by_country[medals_by_country >= 8].index
countries_with_at_least_8_medals_df = pd.DataFrame(countries_with_at_least_8_medals)
medals = pd.DataFrame(medals_by_country)
historical_medal_tally_refined = historical_medal_tally[historical_medal_tally['country_code'].isin(countries_with_at_least_8_medals)]
paris_2024_medal_tally_refined = paris_2024_medal_tally[paris_2024_medal_tally['country_code'].isin(countries_with_at_least_8_medals)]
historical_medal_tally_df = pd.DataFrame(historical_medal_tally_refined)
paris_2024_medal_tally_df = pd.DataFrame(paris_2024_medal_tally_refined)

#Test case 2:
historical_medal_tally_case_2_df = pd.DataFrame(historical_medal_tally_case_2)
paris_2024_medal_tally_case_2_df = pd.DataFrame(paris_2024_medal_tally_case_2)

Here are the countries which won medals recently(at least 8 since Sydney2000)

In [147]:
countries_with_at_least_8_medals_df

Unnamed: 0,country_code
0,ALG
1,ARG
2,ARM
3,AUS
4,AUT
...,...
65,TUR
66,UKR
67,USA
68,UZB


Feel free to view the dataframes

In [148]:
historical_medal_tally_df

Unnamed: 0,Country,country_code,edition,edition_id,gold,silver,bronze,total,Year,Season
0,Algeria,ALG,1984 Summer,21.0,0.0,0.0,2.0,2.0,1984,Summer
1,Algeria,ALG,1992 Summer,23.0,1.0,0.0,1.0,2.0,1992,Summer
2,Algeria,ALG,1996 Summer,24.0,2.0,0.0,1.0,3.0,1996,Summer
3,Algeria,ALG,2000 Summer,25.0,1.0,1.0,3.0,5.0,2000,Summer
4,Algeria,ALG,2008 Summer,53.0,0.0,1.0,1.0,2.0,2008,Summer
...,...,...,...,...,...,...,...,...,...,...
1315,New Zealand,NZL,2004 Summer,26.0,3.0,2.0,0.0,5.0,2004,Summer
1316,New Zealand,NZL,2008 Summer,53.0,3.0,2.0,4.0,9.0,2008,Summer
1317,New Zealand,NZL,2012 Summer,54.0,6.0,2.0,5.0,13.0,2012,Summer
1318,New Zealand,NZL,2016 Summer,59.0,4.0,9.0,5.0,18.0,2016,Summer


In [149]:
paris_2024_medal_tally_df

Unnamed: 0,country,country_code,Gold Medal,Silver Medal,Bronze Medal,Total
0,Algeria,ALG,2.0,0.0,1.0,3.0
15,Egypt,EGY,1.0,1.0,1.0,3.0
18,Ethiopia,ETH,1.0,3.0,0.0,4.0
25,Kenya,KEN,4.0,2.0,5.0,11.0
32,Morocco,MAR,1.0,0.0,1.0,2.0
...,...,...,...,...,...,...
184,Switzerland,SUI,1.0,2.0,5.0,8.0
185,Türkiye,TUR,0.0,3.0,5.0,8.0
186,Ukraine,UKR,3.0,5.0,4.0,12.0
188,Australia,AUS,18.0,19.0,16.0,53.0


We promise, this is as much data processing as we'll be doing today. Pretty straightforward

In [154]:
# Preprocessing steps to standardize country codes and names for merging

# Standardize column names for easier merging
historical_medal_tally_df.rename(columns={'total': 'Total', 'gold': 'Gold', 'silver': 'Silver', 'bronze': 'Bronze'}, inplace=True)
historical_medal_tally_case_2_df.rename(columns={'total': 'Total', 'gold': 'Gold', 'silver': 'Silver', 'bronze': 'Bronze'}, inplace=True)
paris_2024_medal_tally_df.rename(columns={'Gold Medal': 'Gold', 'Silver Medal': 'Silver', 'Bronze Medal': 'Bronze'}, inplace=True)
paris_2024_medal_tally_case_2_df.rename(columns={'Gold Medal': 'Gold', 'Silver Medal': 'Silver', 'Bronze Medal': 'Bronze'}, inplace=True)
historical_medal_tally_df



Unnamed: 0,Country,country_code,edition,edition_id,Gold,Silver,Bronze,Total,Year,Season
0,Algeria,ALG,1984 Summer,21.0,0.0,0.0,2.0,2.0,1984,Summer
1,Algeria,ALG,1992 Summer,23.0,1.0,0.0,1.0,2.0,1992,Summer
2,Algeria,ALG,1996 Summer,24.0,2.0,0.0,1.0,3.0,1996,Summer
3,Algeria,ALG,2000 Summer,25.0,1.0,1.0,3.0,5.0,2000,Summer
4,Algeria,ALG,2008 Summer,53.0,0.0,1.0,1.0,2.0,2008,Summer
...,...,...,...,...,...,...,...,...,...,...
1315,New Zealand,NZL,2004 Summer,26.0,3.0,2.0,0.0,5.0,2004,Summer
1316,New Zealand,NZL,2008 Summer,53.0,3.0,2.0,4.0,9.0,2008,Summer
1317,New Zealand,NZL,2012 Summer,54.0,6.0,2.0,5.0,13.0,2012,Summer
1318,New Zealand,NZL,2016 Summer,59.0,4.0,9.0,5.0,18.0,2016,Summer


Let's only keep the NOC, year and medal tallies. EVerything else isn't so important.

In [155]:
# Extract relevant columns for merging: ['country_code', 'Total']
historical_medal_tally_relevant = historical_medal_tally_df[['country_code', 'Year', 'Total', 'Gold', 'Silver', 'Bronze']]
historical_medal_tally_relevant_case_2 = historical_medal_tally_case_2_df[['country_code', 'Year', 'Total', 'Gold', 'Silver', 'Bronze']]

In [156]:
historical_medal_tally_relevant

Unnamed: 0,country_code,Year,Total,Gold,Silver,Bronze
0,ALG,1984,2.0,0.0,0.0,2.0
1,ALG,1992,2.0,1.0,0.0,1.0
2,ALG,1996,3.0,2.0,0.0,1.0
3,ALG,2000,5.0,1.0,1.0,3.0
4,ALG,2008,2.0,0.0,1.0,1.0
...,...,...,...,...,...,...
1315,NZL,2004,5.0,3.0,2.0,0.0
1316,NZL,2008,9.0,3.0,2.0,4.0
1317,NZL,2012,13.0,6.0,2.0,5.0
1318,NZL,2016,18.0,4.0,9.0,5.0


In [157]:
historical_medal_tally_relevant_case_2

Unnamed: 0,country_code,Year,Total,Gold,Silver,Bronze
0,ALG,1984,2.0,0.0,0.0,2.0
1,ALG,1992,2.0,1.0,0.0,1.0
2,ALG,1996,3.0,2.0,0.0,1.0
3,ALG,2000,5.0,1.0,1.0,3.0
4,ALG,2008,2.0,0.0,1.0,1.0
...,...,...,...,...,...,...
1322,SOL,-1,-1.0,-1.0,-1.0,-1.0
1323,SAM,2008,1.0,0.0,1.0,0.0
1324,TGA,1996,1.0,0.0,1.0,0.0
1325,TUV,-1,-1.0,-1.0,-1.0,-1.0


A quick look at the countries before we do the learning. We can see that the second test case has many more countries because we didn't filter anything out.

In [158]:
paris_2024_medal_tally_relevant = paris_2024_medal_tally_df[['country_code', 'Total', 'Gold', 'Silver', 'Bronze']]
paris_2024_medal_tally_relevant_case_2 = paris_2024_medal_tally_case_2_df[['country_code', 'Total', 'Gold', 'Silver', 'Bronze']]
paris_2024_unique_codes = paris_2024_medal_tally_relevant['country_code'].unique()
historical_medal_tally_unique_codes = historical_medal_tally_relevant['country_code'].unique()
paris_2024_unique_codes_case_2 = paris_2024_medal_tally_relevant_case_2['country_code'].unique()
historical_medal_tally_unique_codes_case_2 = historical_medal_tally_relevant_case_2['country_code'].unique()


paris_2024_unique_codes, paris_2024_unique_codes_case_2

(array(['ALG', 'EGY', 'ETH', 'KEN', 'MAR', 'NGR', 'RSA', 'TUN', 'ARG',
        'BAH', 'BRA', 'CAN', 'COL', 'CUB', 'DOM', 'JAM', 'MEX', 'TTO',
        'USA', 'VEN', 'CHN', 'KOR', 'HKG', 'IND', 'INA', 'IRI', 'JPN',
        'KAZ', 'MAS', 'MGL', 'PRK', 'TPE', 'THA', 'UZB', 'ARM', 'AUT',
        'AZE', 'BEL', 'BUL', 'CRO', 'CZE', 'DEN', 'ESP', 'EST', 'FIN',
        'FRA', 'GEO', 'GER', 'GBR', 'GRE', 'HUN', 'IRL', 'ISR', 'ITA',
        'LAT', 'LTU', 'NED', 'NOR', 'POL', 'POR', 'ROU', 'SRB', 'SVK',
        'SLO', 'SWE', 'SUI', 'TUR', 'UKR', 'AUS', 'NZL'], dtype=object),
 array(['ALG', 'ANG', 'BEN', 'BOT', 'BUR', 'BDI', 'CMR', 'CPV', 'CAF',
        'CHA', 'COM', 'CGO', 'COD', 'CIV', 'DJI', 'EGY', 'ERI', 'SWZ',
        'ETH', 'GAB', 'GAM', 'GHA', 'GUI', 'GBS', 'GEQ', 'KEN', 'LES',
        'LBR', 'LBA', 'MAD', 'MAW', 'MLI', 'MAR', 'MRI', 'MTN', 'MOZ',
        'NAM', 'NIG', 'NGR', 'UGA', 'RWA', 'STP', 'SEN', 'SEY', 'SLE',
        'SOM', 'RSA', 'SSD', 'SUD', 'TAN', 'TOG', 'TUN', 'ZAM', 'ZIM',
    

Remember, this is sequential, so just run everything one by one. The DF views are mainly for debugging. Can you find China?

In [164]:
# Preparing to merge BERT features with medal tally data
# We already have the country_names from BERT feature extraction which corresponds to the country names
# We will match these with country codes from the medal tally data

# First, let's create a DataFrame for the BERT features with corresponding country codes
bert_feature_df = pd.DataFrame.from_dict(country_features, orient='index')
num_features = len(next(iter(country_features.values())))
bert_feature_df.columns = [f'feature_{i}' for i in range(num_features)]
bert_feature_df['country_code'] = bert_feature_df.index

bert_feature_df.reset_index(drop=True, inplace=True)

# Drop rows with missing country codes (countries that are not in the medal tally data)
bert_feature_df.dropna(subset=['country_code'], inplace=True)

# Now, let's merge the BERT features with both historical and Paris 2024 medal data
# Case 1
historical_data_with_features = pd.merge(historical_medal_tally_relevant, bert_feature_df, on='country_code', how='inner')
paris_2024_data_with_features = pd.merge(paris_2024_medal_tally_relevant, bert_feature_df, on='country_code', how='inner')
# Case 2
historical_data_with_features_case_2 = pd.merge(historical_medal_tally_relevant_case_2, bert_feature_df, on='country_code', how='inner')
paris_2024_data_with_features_case_2 = pd.merge(paris_2024_medal_tally_relevant_case_2, bert_feature_df, on='country_code', how='inner')

historical_data_with_features.shape, paris_2024_data_with_features.shape, historical_data_with_features_case_2.shape, paris_2024_data_with_features_case_2.shape


((1081, 774), (70, 773), (1327, 774), (204, 773))

In [165]:
historical_data_with_features = pd.DataFrame(historical_data_with_features)
historical_data_with_features

Unnamed: 0,country_code,Year,Total,Gold,Silver,Bronze,feature_0,feature_1,feature_2,feature_3,...,feature_758,feature_759,feature_760,feature_761,feature_762,feature_763,feature_764,feature_765,feature_766,feature_767
0,ALG,1984,2.0,0.0,0.0,2.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
1,ALG,1992,2.0,1.0,0.0,1.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
2,ALG,1996,3.0,2.0,0.0,1.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
3,ALG,2000,5.0,1.0,1.0,3.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
4,ALG,2008,2.0,0.0,1.0,1.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1076,NZL,2004,5.0,3.0,2.0,0.0,-0.230367,0.046630,0.008161,-0.178960,...,0.035039,-0.017088,0.203934,-0.502458,-0.058443,-0.757425,0.182451,-0.143095,-0.097160,0.334275
1077,NZL,2008,9.0,3.0,2.0,4.0,-0.230367,0.046630,0.008161,-0.178960,...,0.035039,-0.017088,0.203934,-0.502458,-0.058443,-0.757425,0.182451,-0.143095,-0.097160,0.334275
1078,NZL,2012,13.0,6.0,2.0,5.0,-0.230367,0.046630,0.008161,-0.178960,...,0.035039,-0.017088,0.203934,-0.502458,-0.058443,-0.757425,0.182451,-0.143095,-0.097160,0.334275
1079,NZL,2016,18.0,4.0,9.0,5.0,-0.230367,0.046630,0.008161,-0.178960,...,0.035039,-0.017088,0.203934,-0.502458,-0.058443,-0.757425,0.182451,-0.143095,-0.097160,0.334275




In [166]:
historical_data_with_features_case_2 = pd.DataFrame(historical_data_with_features_case_2)
historical_data_with_features_case_2

Unnamed: 0,country_code,Year,Total,Gold,Silver,Bronze,feature_0,feature_1,feature_2,feature_3,...,feature_758,feature_759,feature_760,feature_761,feature_762,feature_763,feature_764,feature_765,feature_766,feature_767
0,ALG,1984,2.0,0.0,0.0,2.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
1,ALG,1992,2.0,1.0,0.0,1.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
2,ALG,1996,3.0,2.0,0.0,1.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
3,ALG,2000,5.0,1.0,1.0,3.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
4,ALG,2008,2.0,0.0,1.0,1.0,-0.481400,0.065026,-0.187783,-0.258769,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1322,SOL,-1,-1.0,-1.0,-1.0,-1.0,-0.216952,-0.029954,0.053713,-0.061065,...,-0.068429,0.028338,0.259332,-0.282562,-0.125553,-0.452995,-0.001912,-0.099491,-0.017073,0.296681
1323,SAM,2008,1.0,0.0,1.0,0.0,-0.107072,0.095541,0.039241,-0.190647,...,-0.075343,0.154960,0.312662,-0.548999,-0.194812,-0.554261,0.070452,-0.052134,-0.047149,0.415280
1324,TGA,1996,1.0,0.0,1.0,0.0,-0.183574,-0.019720,0.021747,-0.132610,...,-0.144812,0.119026,0.315878,-0.415945,-0.104811,-0.592223,0.093978,-0.224056,-0.108927,0.425377
1325,TUV,-1,-1.0,-1.0,-1.0,-1.0,-0.128275,0.064639,0.122194,-0.044439,...,0.025434,0.020160,0.245556,-0.472722,-0.151195,-0.494881,0.044703,-0.111421,0.022991,0.363630


In [167]:
paris_2024_data_with_features = pd.DataFrame(paris_2024_data_with_features)
paris_2024_data_with_features

Unnamed: 0,country_code,Total,Gold,Silver,Bronze,feature_0,feature_1,feature_2,feature_3,feature_4,...,feature_758,feature_759,feature_760,feature_761,feature_762,feature_763,feature_764,feature_765,feature_766,feature_767
0,ALG,3.0,2.0,0.0,1.0,-0.481400,0.065026,-0.187783,-0.258769,0.137929,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
1,EGY,3.0,1.0,1.0,1.0,-0.465102,0.146367,-0.131519,-0.321259,0.009605,...,-0.100259,0.216268,0.106795,-0.342460,0.065206,-0.768027,0.079232,-0.363963,0.063853,0.235727
2,ETH,4.0,1.0,3.0,0.0,-0.409535,0.137565,-0.043246,-0.139794,0.046535,...,-0.037561,0.126538,0.145211,-0.145815,-0.052923,-0.547403,0.042519,-0.345376,0.040223,0.112593
3,KEN,11.0,4.0,2.0,5.0,-0.429986,-0.040282,-0.015930,-0.184996,0.135444,...,-0.102464,-0.126728,0.123888,-0.325692,-0.017184,-0.521751,0.115374,-0.270096,0.059250,0.215744
4,MAR,2.0,1.0,0.0,1.0,-0.481812,0.031398,-0.029591,-0.244370,0.209804,...,-0.199334,0.214556,0.214557,-0.366538,-0.042957,-0.733670,0.098782,-0.338699,-0.027667,0.296554
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
65,SUI,8.0,1.0,2.0,5.0,-0.526929,0.023256,0.113287,-0.192864,0.065293,...,0.011658,0.150495,0.219954,-0.333502,0.104974,-0.592063,0.154276,-0.173147,0.064206,0.257630
66,TUR,8.0,0.0,3.0,5.0,-0.390875,0.143982,-0.185502,-0.152473,0.117819,...,-0.102327,0.038741,0.163614,-0.324636,-0.084870,-0.576549,0.056370,-0.379389,0.005728,0.274972
67,UKR,12.0,3.0,5.0,4.0,-0.414747,-0.070224,0.015344,-0.154040,0.086212,...,-0.179419,0.191441,0.151555,-0.249487,-0.094757,-0.643935,0.174749,-0.401643,-0.095841,0.292428
68,AUS,53.0,18.0,19.0,16.0,-0.399308,0.016281,0.165232,-0.058390,0.004677,...,-0.096388,0.015516,0.090935,-0.336500,0.089895,-0.675746,0.157005,-0.149723,-0.009941,0.211812


In [168]:
paris_2024_data_with_features_case_2 = pd.DataFrame(paris_2024_data_with_features_case_2)
paris_2024_data_with_features_case_2

Unnamed: 0,country_code,Total,Gold,Silver,Bronze,feature_0,feature_1,feature_2,feature_3,feature_4,...,feature_758,feature_759,feature_760,feature_761,feature_762,feature_763,feature_764,feature_765,feature_766,feature_767
0,ALG,3.0,2.0,0.0,1.0,-0.481400,0.065026,-0.187783,-0.258769,0.137929,...,-0.210599,0.171756,0.168706,-0.478434,-0.021412,-0.673022,0.176665,-0.294029,0.007613,0.317239
1,ANG,-1.0,-1.0,-1.0,-1.0,-0.531567,-0.169313,-0.094242,-0.033409,0.296487,...,-0.097449,0.125326,0.197422,-0.428187,-0.046271,-0.636142,-0.062633,-0.070587,-0.054669,0.242843
2,BEN,-1.0,-1.0,-1.0,-1.0,-0.489794,0.028737,-0.049838,-0.189908,0.169756,...,-0.141605,0.104256,0.049683,-0.457860,0.043841,-0.658490,0.022212,-0.060198,0.025601,0.235664
3,BOT,2.0,1.0,1.0,0.0,-0.466590,-0.096590,0.113659,-0.024595,0.127837,...,-0.101390,0.209467,0.119671,-0.498387,0.091250,-0.544316,-0.032646,-0.149083,0.001064,0.215812
4,BUR,-1.0,-1.0,-1.0,-1.0,-0.431025,-0.038477,-0.113337,-0.156375,0.116831,...,-0.242285,0.227782,0.051940,-0.384393,0.074634,-0.731360,-0.084426,-0.226577,-0.174809,0.141097
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
199,SOL,-1.0,-1.0,-1.0,-1.0,-0.216952,-0.029954,0.053713,-0.061065,0.143249,...,-0.068429,0.028338,0.259332,-0.282562,-0.125553,-0.452995,-0.001912,-0.099491,-0.017073,0.296681
200,SAM,-1.0,-1.0,-1.0,-1.0,-0.107072,0.095541,0.039241,-0.190647,0.231200,...,-0.075343,0.154960,0.312662,-0.548999,-0.194812,-0.554261,0.070452,-0.052134,-0.047149,0.415280
201,TGA,-1.0,-1.0,-1.0,-1.0,-0.183574,-0.019720,0.021747,-0.132610,0.101105,...,-0.144812,0.119026,0.315878,-0.415945,-0.104811,-0.592223,0.093978,-0.224056,-0.108927,0.425377
202,TUV,-1.0,-1.0,-1.0,-1.0,-0.128275,0.064639,0.122194,-0.044439,0.235125,...,0.025434,0.020160,0.245556,-0.472722,-0.151195,-0.494881,0.044703,-0.111421,0.022991,0.363630


#THE MAIN EVENT

Ok, we made it so far.
Let's scale the data and it into train and validation sets. We'll use the ACTUAL Paris results as a test set(should be ready before our presentation). How cool is that?

In [170]:
scaler1 = StandardScaler()
scaler2 = StandardScaler()
# Splitting features and target (Total medals)
X_train = historical_data_with_features.drop(columns=['Total', 'Gold', 'Silver', 'Bronze', 'country_code'])
X_train_2 = historical_data_with_features_case_2.drop(columns=['Total', 'Gold', 'Silver', 'Bronze', 'country_code'])

X_train = X_train.apply(pd.to_numeric, errors='coerce')  # Convert non-numeric values to NaN
X_train_2 = X_train_2.apply(pd.to_numeric, errors='coerce')  # Convert non-numeric values to NaN

X_train_scaled = scaler1.fit_transform(X_train)
X_train_tensor_scaled = torch.tensor(X_train_scaled, dtype=torch.float32).to(device)

X_train_scaled_2 = scaler2.fit_transform(X_train_2)
X_train_tensor_scaled_2 = torch.tensor(X_train_scaled_2, dtype=torch.float32).to(device)

y_train = historical_data_with_features[['Total', 'Gold', 'Silver', 'Bronze']]
y_train_2 = historical_data_with_features_case_2[['Total', 'Gold', 'Silver', 'Bronze']]

y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32).to(device)
y_train_tensor_2 = torch.tensor(y_train_2.values, dtype=torch.float32).to(device)

paris_2024_data_with_features_sorted = paris_2024_data_with_features.sort_values(by='Gold', ascending=False)
paris_2024_data_with_features_sorted_2 = paris_2024_data_with_features_case_2.sort_values(by='Gold', ascending=False)

X_test = paris_2024_data_with_features_sorted.drop(columns=['Total', 'Gold', 'Silver', 'Bronze', 'country_code'])
X_test_2 = paris_2024_data_with_features_sorted_2.drop(columns=['Total', 'Gold', 'Silver', 'Bronze', 'country_code'])
test_countries = paris_2024_data_with_features_sorted['country_code']
test_countries_2 = paris_2024_data_with_features_sorted_2['country_code']

# Add a constant 'Year' column to X_test with value 2024
X_test.insert(0, 'Year', 2024)
X_test_2.insert(0, 'Year', 2024)
#Scale and transform
X_test_scaled = scaler1.transform(X_test)  # Use the same scaler fitted on training data
X_test_tensor_scaled = torch.tensor(X_test_scaled, dtype=torch.float32).to(device)

X_test_scaled_2 = scaler2.transform(X_test_2)  # Use the same scaler fitted on training data
X_test_tensor_scaled_2 = torch.tensor(X_test_scaled_2, dtype=torch.float32).to(device)

y_test = paris_2024_data_with_features_sorted[['Total', 'Gold', 'Silver', 'Bronze']]
y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32).to(device)

y_test_2 = paris_2024_data_with_features_sorted_2[['Total', 'Gold', 'Silver', 'Bronze']]
y_test_tensor_2 = torch.tensor(y_test_2.values, dtype=torch.float32).to(device)






#TRAINING

We'll use Optuna in order to tuna Hyperparameters and find the best combo. Then we'll train our network and see the results. It's as simple as that.

In [171]:
import optuna

# Splitting train and validation data
X_train_sub, X_val, y_train_sub, y_val = train_test_split(X_train_tensor_scaled, y_train_tensor, test_size=0.2, random_state=42)

def objective(trial):
    # Define the hyperparameter search space
    n_units_1 = trial.suggest_int('n_units_1', 256, 1024)
    n_units_2 = trial.suggest_int('n_units_2', 128, 512)
    n_units_3 = trial.suggest_int('n_units_3', 64, 256)
    dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3)

    # Building the model with trial hyperparameters
    model = nn.Sequential(
        nn.Linear(X_train_sub.shape[1], n_units_1),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(n_units_1, n_units_2),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(n_units_2, n_units_3),
        nn.ReLU(),
        nn.Linear(n_units_3, 4)  # Output layer for regression
    ).to(device)  # Move the model to the GPU

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    for epoch in range(20):
        optimizer.zero_grad()
        outputs = model(X_train_sub).to(device)
        loss = criterion(outputs.squeeze(), y_train_sub)
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
        loss.backward()
        optimizer.step()

    # Evaluate the model on validation data
    with torch.no_grad():
        val_outputs = model(X_val).to(device)
        val_outputs = torch.round(val_outputs)  # Round the predictions to the nearest integer
        val_mae = {}
        val_accuracy = {}
        tolerance = 0.2
        tolerance_biases = [7.0, 3.0, 3.0, 3.0]
        for i, medal_type in enumerate(['Total', 'Gold', 'Silver', 'Bronze']):
          mae = torch.mean(torch.abs(val_outputs[:,i] - y_val[:,i])).item()
          val_mae[medal_type] = mae
          correct = torch.sum(torch.abs(val_outputs[:,i] - y_val[:,i]) <= torch.tensor(tolerance_biases[i]).to(device)*y_val[:,i])
          accuracy = correct.item() / len(y_val) * 100
          val_accuracy[medal_type] = accuracy
          val_loss = criterion(val_outputs[:,i], y_val[:,i])
        print(f"Validation Loss: {val_loss} ; Validation Accuracy: {accuracy}%")

    return sum(val_mae.values())  # Return the validation MAE as a Python numbe

# Create an Optuna study and optimize the objective function
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=50)

# Get the best hyperparameters
best_params = study.best_params
print("Best hyperparameters: ", best_params)

# Build and evaluate the model with the best hyperparameters
n_units_1 = best_params['n_units_1']
n_units_2 = best_params['n_units_2']
n_units_3 = best_params['n_units_3']
dropout_rate = best_params['dropout_rate']
learning_rate = best_params['learning_rate']

    # Building the model with trial hyperparameters
best_model = nn.Sequential(
        nn.Linear(X_train_tensor_scaled.shape[1], n_units_1),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(n_units_1, n_units_2),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(n_units_2, n_units_3),
        nn.ReLU(),
        nn.Linear(n_units_3, 4)  # Output layer for regression
    ).to(device)  # Move the model to the GPU

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(best_model.parameters(), lr=learning_rate)

# Train with the entire training data and evaluate on the test set
for epoch in range(50):
    optimizer.zero_grad()
    outputs = best_model(X_train_tensor_scaled).to(device)
    loss = criterion(outputs.squeeze(), y_train_tensor)
    loss.backward()
    optimizer.step()

    # Evaluate the model on test data
with torch.no_grad():
    test_mae = {}
    test_accuracy = {}
    test_outputs = best_model(X_test_tensor_scaled).to(device)
    test_outputs = torch.round(test_outputs)  # Round the predictions to the nearest integer
    tolerance_biases = [7.0, 3.0, 3.0, 3.0]
    for i, medal_type in enumerate(['Total', 'Gold', 'Silver', 'Bronze']):
      test_loss = criterion(test_outputs[:,i], y_test_tensor[:,i])
      tolerance = 0.25
      correct = torch.sum(torch.abs(test_outputs[:,i] - y_test_tensor[:,i]) <= torch.max(tolerance*y_test_tensor[:,i], torch.tensor(tolerance_biases[i]).to(device)))
      accuracy = correct.item() / len(y_test) * 100
      test_accuracy[medal_type] = accuracy
      mae = torch.mean(torch.abs(test_outputs[:,i] - y_test_tensor[:,i])).item()  # Calculate MAE
      test_mae[medal_type] = mae


print(f"Test MAE: {test_mae}")
print(f"Test Loss: {test_loss} ; Test Accuracy: {test_accuracy}%")


[I 2024-08-18 14:11:25,594] A new study created in memory with name: no-name-a045d54a-b3b0-4bb2-9a1c-d746e19859c7
[I 2024-08-18 14:11:25,702] Trial 0 finished with value: 17.52534532546997 and parameters: {'n_units_1': 1024, 'n_units_2': 154, 'n_units_3': 96, 'dropout_rate': 0.31363767518388097, 'learning_rate': 0.000356703928938053}. Best is trial 0 with value: 17.52534532546997.
[I 2024-08-18 14:11:25,795] Trial 1 finished with value: 21.184332370758057 and parameters: {'n_units_1': 670, 'n_units_2': 145, 'n_units_3': 90, 'dropout_rate': 0.290451126189584, 'learning_rate': 0.0006197938672130309}. Best is trial 0 with value: 17.52534532546997.


Epoch 1, Loss: 212.74293518066406
Epoch 2, Loss: 211.5515899658203
Epoch 3, Loss: 210.41656494140625
Epoch 4, Loss: 209.0385284423828
Epoch 5, Loss: 207.35662841796875
Epoch 6, Loss: 205.5083770751953
Epoch 7, Loss: 203.1262969970703
Epoch 8, Loss: 200.60855102539062
Epoch 9, Loss: 197.6058349609375
Epoch 10, Loss: 193.79930114746094
Epoch 11, Loss: 189.73095703125
Epoch 12, Loss: 185.3426971435547
Epoch 13, Loss: 179.6175994873047
Epoch 14, Loss: 173.4788360595703
Epoch 15, Loss: 167.0694122314453
Epoch 16, Loss: 159.98927307128906
Epoch 17, Loss: 152.24635314941406
Epoch 18, Loss: 143.54420471191406
Epoch 19, Loss: 135.39930725097656
Epoch 20, Loss: 124.47371673583984
Validation Loss: 20.80645179748535 ; Validation Accuracy: 79.72350230414746%
Epoch 1, Loss: 212.831787109375
Epoch 2, Loss: 211.3584747314453
Epoch 3, Loss: 209.78379821777344
Epoch 4, Loss: 207.9241180419922
Epoch 5, Loss: 205.386474609375
Epoch 6, Loss: 202.551513671875
Epoch 7, Loss: 198.55148315429688
Epoch 8, Loss:

[I 2024-08-18 14:11:25,917] Trial 2 finished with value: 20.165898323059082 and parameters: {'n_units_1': 859, 'n_units_2': 330, 'n_units_3': 132, 'dropout_rate': 0.45539826332309385, 'learning_rate': 0.00043399830091555577}. Best is trial 0 with value: 17.52534532546997.
[I 2024-08-18 14:11:25,995] Trial 3 finished with value: 15.552995681762695 and parameters: {'n_units_1': 260, 'n_units_2': 464, 'n_units_3': 120, 'dropout_rate': 0.3134118086960192, 'learning_rate': 0.0007916742584223209}. Best is trial 3 with value: 15.552995681762695.


Epoch 1, Loss: 212.49464416503906
Epoch 2, Loss: 210.7615966796875
Epoch 3, Loss: 208.64370727539062
Epoch 4, Loss: 206.13973999023438
Epoch 5, Loss: 202.53887939453125
Epoch 6, Loss: 198.3521728515625
Epoch 7, Loss: 193.32557678222656
Epoch 8, Loss: 186.17115783691406
Epoch 9, Loss: 178.5065460205078
Epoch 10, Loss: 169.0978546142578
Epoch 11, Loss: 158.4095458984375
Epoch 12, Loss: 146.06109619140625
Epoch 13, Loss: 133.8997802734375
Epoch 14, Loss: 120.94416809082031
Epoch 15, Loss: 110.91838836669922
Epoch 16, Loss: 102.48157501220703
Epoch 17, Loss: 98.06615447998047
Epoch 18, Loss: 91.72135162353516
Epoch 19, Loss: 89.33110809326172
Epoch 20, Loss: 85.55243682861328
Validation Loss: 23.161291122436523 ; Validation Accuracy: 72.81105990783409%
Epoch 1, Loss: 212.8304443359375
Epoch 2, Loss: 211.1222686767578
Epoch 3, Loss: 209.0164031982422
Epoch 4, Loss: 206.12779235839844
Epoch 5, Loss: 202.03750610351562
Epoch 6, Loss: 196.3264617919922
Epoch 7, Loss: 189.36045837402344
Epoch 8

[I 2024-08-18 14:11:26,079] Trial 4 finished with value: 18.470045804977417 and parameters: {'n_units_1': 266, 'n_units_2': 350, 'n_units_3': 236, 'dropout_rate': 0.34925086395673954, 'learning_rate': 0.00027374801554431056}. Best is trial 3 with value: 15.552995681762695.
[I 2024-08-18 14:11:26,189] Trial 5 finished with value: 18.82488441467285 and parameters: {'n_units_1': 715, 'n_units_2': 280, 'n_units_3': 193, 'dropout_rate': 0.2937665311594404, 'learning_rate': 0.0002093529223397296}. Best is trial 3 with value: 15.552995681762695.


Epoch 4, Loss: 210.72410583496094
Epoch 5, Loss: 209.9215850830078
Epoch 6, Loss: 209.0436248779297
Epoch 7, Loss: 207.99794006347656
Epoch 8, Loss: 206.8518524169922
Epoch 9, Loss: 205.61154174804688
Epoch 10, Loss: 203.95352172851562
Epoch 11, Loss: 202.18064880371094
Epoch 12, Loss: 199.91738891601562
Epoch 13, Loss: 197.8220977783203
Epoch 14, Loss: 195.0838623046875
Epoch 15, Loss: 192.14169311523438
Epoch 16, Loss: 188.3900604248047
Epoch 17, Loss: 184.9020538330078
Epoch 18, Loss: 181.1773681640625
Epoch 19, Loss: 176.43650817871094
Epoch 20, Loss: 170.66090393066406
Validation Loss: 31.428571701049805 ; Validation Accuracy: 83.41013824884793%
Epoch 1, Loss: 212.91366577148438
Epoch 2, Loss: 212.16134643554688
Epoch 3, Loss: 211.5099639892578
Epoch 4, Loss: 210.82974243164062
Epoch 5, Loss: 210.03048706054688
Epoch 6, Loss: 209.1823272705078
Epoch 7, Loss: 208.2123565673828
Epoch 8, Loss: 207.1808319091797
Epoch 9, Loss: 205.96522521972656
Epoch 10, Loss: 204.45289611816406
Epoc

[I 2024-08-18 14:11:26,280] Trial 6 finished with value: 21.66820240020752 and parameters: {'n_units_1': 974, 'n_units_2': 128, 'n_units_3': 144, 'dropout_rate': 0.13332668190093688, 'learning_rate': 0.00014091060294237922}. Best is trial 3 with value: 15.552995681762695.
[I 2024-08-18 14:11:26,364] Trial 7 finished with value: 17.972350120544434 and parameters: {'n_units_1': 604, 'n_units_2': 359, 'n_units_3': 149, 'dropout_rate': 0.22130326053825, 'learning_rate': 0.00026705339321250125}. Best is trial 3 with value: 15.552995681762695.


Epoch 6, Loss: 210.6103515625
Epoch 7, Loss: 210.06500244140625
Epoch 8, Loss: 209.5217742919922
Epoch 9, Loss: 208.89321899414062
Epoch 10, Loss: 208.1859130859375
Epoch 11, Loss: 207.51173400878906
Epoch 12, Loss: 206.71627807617188
Epoch 13, Loss: 205.9357147216797
Epoch 14, Loss: 205.04400634765625
Epoch 15, Loss: 204.05047607421875
Epoch 16, Loss: 202.96824645996094
Epoch 17, Loss: 201.72402954101562
Epoch 18, Loss: 200.48590087890625
Epoch 19, Loss: 199.09921264648438
Epoch 20, Loss: 197.55223083496094
Validation Loss: 45.77880096435547 ; Validation Accuracy: 100.0%
Epoch 1, Loss: 213.21652221679688
Epoch 2, Loss: 212.29605102539062
Epoch 3, Loss: 211.45562744140625
Epoch 4, Loss: 210.5421905517578
Epoch 5, Loss: 209.5197296142578
Epoch 6, Loss: 208.44744873046875
Epoch 7, Loss: 207.12049865722656
Epoch 8, Loss: 205.67002868652344
Epoch 9, Loss: 203.82635498046875
Epoch 10, Loss: 201.84024047851562
Epoch 11, Loss: 199.2354278564453
Epoch 12, Loss: 196.43862915039062
Epoch 13, Los

[I 2024-08-18 14:11:26,458] Trial 8 finished with value: 18.728111028671265 and parameters: {'n_units_1': 900, 'n_units_2': 409, 'n_units_3': 102, 'dropout_rate': 0.35450860023266795, 'learning_rate': 0.00030042085828682077}. Best is trial 3 with value: 15.552995681762695.
[I 2024-08-18 14:11:26,536] Trial 9 finished with value: 18.75115180015564 and parameters: {'n_units_1': 848, 'n_units_2': 240, 'n_units_3': 192, 'dropout_rate': 0.3027885238164379, 'learning_rate': 0.00036601188128717325}. Best is trial 3 with value: 15.552995681762695.



Epoch 14, Loss: 158.71531677246094
Epoch 15, Loss: 150.3770751953125
Epoch 16, Loss: 140.43316650390625
Epoch 17, Loss: 131.1533966064453
Epoch 18, Loss: 122.45682525634766
Epoch 19, Loss: 113.03240966796875
Epoch 20, Loss: 105.4325180053711
Validation Loss: 20.354839324951172 ; Validation Accuracy: 73.73271889400922%
Epoch 1, Loss: 212.8175506591797
Epoch 2, Loss: 211.6623992919922
Epoch 3, Loss: 210.3639678955078
Epoch 4, Loss: 208.86827087402344
Epoch 5, Loss: 206.93243408203125
Epoch 6, Loss: 204.71592712402344
Epoch 7, Loss: 201.9177703857422
Epoch 8, Loss: 198.57960510253906
Epoch 9, Loss: 194.54806518554688
Epoch 10, Loss: 189.8844451904297
Epoch 11, Loss: 184.20741271972656
Epoch 12, Loss: 177.6982421875
Epoch 13, Loss: 169.78733825683594
Epoch 14, Loss: 161.18150329589844
Epoch 15, Loss: 152.2978515625
Epoch 16, Loss: 140.50851440429688
Epoch 17, Loss: 129.31886291503906
Epoch 18, Loss: 120.5545883178711
Epoch 19, Loss: 110.96122741699219
Epoch 20, Loss: 100.96894073486328
Va

[I 2024-08-18 14:11:26,703] Trial 10 finished with value: 14.396313428878784 and parameters: {'n_units_1': 284, 'n_units_2': 508, 'n_units_3': 70, 'dropout_rate': 0.4956254424604487, 'learning_rate': 0.0009751946713824293}. Best is trial 10 with value: 14.396313428878784.
[I 2024-08-18 14:11:26,811] Trial 11 finished with value: 14.981566667556763 and parameters: {'n_units_1': 273, 'n_units_2': 510, 'n_units_3': 64, 'dropout_rate': 0.49264037806369326, 'learning_rate': 0.0009974956742627166}. Best is trial 10 with value: 14.396313428878784.


Epoch 5, Loss: 198.6174774169922
Epoch 6, Loss: 190.2308349609375
Epoch 7, Loss: 179.95436096191406
Epoch 8, Loss: 165.3687744140625
Epoch 9, Loss: 147.95089721679688
Epoch 10, Loss: 131.27330017089844
Epoch 11, Loss: 113.33067321777344
Epoch 12, Loss: 101.97590637207031
Epoch 13, Loss: 96.15216064453125
Epoch 14, Loss: 94.64784240722656
Epoch 15, Loss: 86.38350677490234
Epoch 16, Loss: 81.0654296875
Epoch 17, Loss: 74.96206665039062
Epoch 18, Loss: 67.19020080566406
Epoch 19, Loss: 64.80589294433594
Epoch 20, Loss: 65.50674438476562
Validation Loss: 15.036866188049316 ; Validation Accuracy: 81.5668202764977%
Epoch 1, Loss: 211.8452911376953
Epoch 2, Loss: 209.80880737304688
Epoch 3, Loss: 206.6742706298828
Epoch 4, Loss: 202.1737823486328
Epoch 5, Loss: 195.60446166992188
Epoch 6, Loss: 187.07228088378906
Epoch 7, Loss: 175.70547485351562
Epoch 8, Loss: 162.02548217773438
Epoch 9, Loss: 145.55416870117188
Epoch 10, Loss: 127.78980255126953
Epoch 11, Loss: 113.07624816894531
Epoch 12, 

[I 2024-08-18 14:11:26,922] Trial 12 finished with value: 14.456221103668213 and parameters: {'n_units_1': 412, 'n_units_2': 512, 'n_units_3': 74, 'dropout_rate': 0.4851129079510874, 'learning_rate': 0.0009980551171937303}. Best is trial 10 with value: 14.396313428878784.
[I 2024-08-18 14:11:27,037] Trial 13 finished with value: 14.147465467453003 and parameters: {'n_units_1': 419, 'n_units_2': 510, 'n_units_3': 70, 'dropout_rate': 0.42415673608036986, 'learning_rate': 0.0009971962434612467}. Best is trial 13 with value: 14.147465467453003.


Epoch 1, Loss: 212.7864990234375
Epoch 2, Loss: 209.89077758789062
Epoch 3, Loss: 205.7225341796875
Epoch 4, Loss: 199.5088653564453
Epoch 5, Loss: 190.092529296875
Epoch 6, Loss: 177.69969177246094
Epoch 7, Loss: 163.1193084716797
Epoch 8, Loss: 142.34609985351562
Epoch 9, Loss: 125.08699035644531
Epoch 10, Loss: 106.91529083251953
Epoch 11, Loss: 98.28841400146484
Epoch 12, Loss: 90.62796783447266
Epoch 13, Loss: 87.57866668701172
Epoch 14, Loss: 77.8576431274414
Epoch 15, Loss: 70.66056060791016
Epoch 16, Loss: 68.98995208740234
Epoch 17, Loss: 64.70460510253906
Epoch 18, Loss: 63.25703430175781
Epoch 19, Loss: 65.17350006103516
Epoch 20, Loss: 65.02239227294922
Validation Loss: 29.838708877563477 ; Validation Accuracy: 92.62672811059907%
Epoch 1, Loss: 211.88299560546875
Epoch 2, Loss: 209.50889587402344
Epoch 3, Loss: 206.5425567626953
Epoch 4, Loss: 201.71412658691406
Epoch 5, Loss: 193.9304962158203
Epoch 6, Loss: 183.4215850830078
Epoch 7, Loss: 168.59881591796875
Epoch 8, Loss

[I 2024-08-18 14:11:27,161] Trial 14 finished with value: 14.400922060012817 and parameters: {'n_units_1': 435, 'n_units_2': 420, 'n_units_3': 184, 'dropout_rate': 0.42073630652977745, 'learning_rate': 0.0008188824823263361}. Best is trial 13 with value: 14.147465467453003.
[I 2024-08-18 14:11:27,266] Trial 15 finished with value: 14.359447240829468 and parameters: {'n_units_1': 429, 'n_units_2': 446, 'n_units_3': 108, 'dropout_rate': 0.41229976061533635, 'learning_rate': 0.0008234176078282421}. Best is trial 13 with value: 14.147465467453003.


Epoch 1, Loss: 212.47146606445312
Epoch 2, Loss: 209.3510284423828
Epoch 3, Loss: 205.09674072265625
Epoch 4, Loss: 199.1723175048828
Epoch 5, Loss: 189.81121826171875
Epoch 6, Loss: 177.67723083496094
Epoch 7, Loss: 161.59532165527344
Epoch 8, Loss: 142.09512329101562
Epoch 9, Loss: 121.51049041748047
Epoch 10, Loss: 99.6280746459961
Epoch 11, Loss: 90.21898651123047
Epoch 12, Loss: 92.01500701904297
Epoch 13, Loss: 89.20384979248047
Epoch 14, Loss: 85.0226058959961
Epoch 15, Loss: 75.92076110839844
Epoch 16, Loss: 64.5692367553711
Epoch 17, Loss: 62.83222961425781
Epoch 18, Loss: 60.14883041381836
Epoch 19, Loss: 61.04107666015625
Epoch 20, Loss: 62.91022872924805
Validation Loss: 18.502304077148438 ; Validation Accuracy: 84.33179723502305%
Epoch 1, Loss: 213.1354217529297
Epoch 2, Loss: 210.62344360351562
Epoch 3, Loss: 207.53721618652344
Epoch 4, Loss: 202.88516235351562
Epoch 5, Loss: 196.41543579101562
Epoch 6, Loss: 187.54833984375
Epoch 7, Loss: 176.28817749023438
Epoch 8, Loss

[I 2024-08-18 14:11:27,355] Trial 16 finished with value: 15.470046281814575 and parameters: {'n_units_1': 509, 'n_units_2': 430, 'n_units_3': 114, 'dropout_rate': 0.4051124253948556, 'learning_rate': 0.0006172805280082667}. Best is trial 13 with value: 14.147465467453003.
[I 2024-08-18 14:11:27,443] Trial 17 finished with value: 14.359447002410889 and parameters: {'n_units_1': 390, 'n_units_2': 458, 'n_units_3': 165, 'dropout_rate': 0.3997599004008595, 'learning_rate': 0.0008342733789077604}. Best is trial 13 with value: 14.147465467453003.


Epoch 1, Loss: 211.93243408203125
Epoch 2, Loss: 209.64479064941406
Epoch 3, Loss: 206.78286743164062
Epoch 4, Loss: 202.9533233642578
Epoch 5, Loss: 198.08688354492188
Epoch 6, Loss: 190.9298553466797
Epoch 7, Loss: 182.5912628173828
Epoch 8, Loss: 171.91111755371094
Epoch 9, Loss: 158.53964233398438
Epoch 10, Loss: 143.58853149414062
Epoch 11, Loss: 129.0955352783203
Epoch 12, Loss: 114.25471496582031
Epoch 13, Loss: 101.8437271118164
Epoch 14, Loss: 95.41665649414062
Epoch 15, Loss: 92.45362854003906
Epoch 16, Loss: 88.85808563232422
Epoch 17, Loss: 85.64212036132812
Epoch 18, Loss: 82.98191833496094
Epoch 19, Loss: 75.98548889160156
Epoch 20, Loss: 68.65899658203125
Validation Loss: 14.737327575683594 ; Validation Accuracy: 79.26267281105991%
Epoch 1, Loss: 211.9543914794922
Epoch 2, Loss: 209.737548828125
Epoch 3, Loss: 206.76902770996094
Epoch 4, Loss: 202.34214782714844
Epoch 5, Loss: 196.10861206054688
Epoch 6, Loss: 186.6477508544922
Epoch 7, Loss: 174.98355102539062
Epoch 8, 

[I 2024-08-18 14:11:27,528] Trial 18 finished with value: 14.414746284484863 and parameters: {'n_units_1': 547, 'n_units_2': 374, 'n_units_3': 172, 'dropout_rate': 0.3811977937085277, 'learning_rate': 0.0006907444090303945}. Best is trial 13 with value: 14.147465467453003.
[I 2024-08-18 14:11:27,624] Trial 19 finished with value: 13.95852541923523 and parameters: {'n_units_1': 361, 'n_units_2': 470, 'n_units_3': 226, 'dropout_rate': 0.24803673638495743, 'learning_rate': 0.0009007061982876949}. Best is trial 19 with value: 13.95852541923523.


Epoch 14, Loss: 91.29462432861328
Epoch 15, Loss: 89.72115325927734
Epoch 16, Loss: 85.94625091552734
Epoch 17, Loss: 76.44585418701172
Epoch 18, Loss: 71.62676239013672
Epoch 19, Loss: 66.03553771972656
Epoch 20, Loss: 62.76532745361328
Validation Loss: 15.617511749267578 ; Validation Accuracy: 79.26267281105991%
Epoch 1, Loss: 211.74087524414062
Epoch 2, Loss: 208.18251037597656
Epoch 3, Loss: 203.1285858154297
Epoch 4, Loss: 195.16522216796875
Epoch 5, Loss: 183.93194580078125
Epoch 6, Loss: 167.89569091796875
Epoch 7, Loss: 147.00851440429688
Epoch 8, Loss: 125.0538101196289
Epoch 9, Loss: 108.6663589477539
Epoch 10, Loss: 101.48119354248047
Epoch 11, Loss: 95.0735855102539
Epoch 12, Loss: 86.80359649658203
Epoch 13, Loss: 78.03977966308594
Epoch 14, Loss: 70.05602264404297
Epoch 15, Loss: 66.22843170166016
Epoch 16, Loss: 63.31341552734375
Epoch 17, Loss: 64.55101013183594
Epoch 18, Loss: 62.613189697265625
Epoch 19, Loss: 62.45875930786133
Epoch 20, Loss: 62.39643096923828
Valida

[I 2024-08-18 14:11:27,712] Trial 20 finished with value: 14.39170527458191 and parameters: {'n_units_1': 345, 'n_units_2': 204, 'n_units_3': 250, 'dropout_rate': 0.2306984055002, 'learning_rate': 0.0009078406313141639}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:27,796] Trial 21 finished with value: 14.253456115722656 and parameters: {'n_units_1': 362, 'n_units_2': 472, 'n_units_3': 215, 'dropout_rate': 0.23116022238771608, 'learning_rate': 0.0008734715975781915}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:27,882] Trial 22 finished with value: 14.511520862579346 and parameters: {'n_units_1': 509, 'n_units_2': 473, 'n_units_3': 217, 'dropout_rate': 0.2276095641758032, 'learning_rate': 0.0007068569914760426}. Best is trial 19 with value: 13.95852541923523.


Epoch 1, Loss: 212.77655029296875
Epoch 2, Loss: 210.36383056640625
Epoch 3, Loss: 207.27252197265625
Epoch 4, Loss: 202.43385314941406
Epoch 5, Loss: 195.21958923339844
Epoch 6, Loss: 184.52987670898438
Epoch 7, Loss: 170.32574462890625
Epoch 8, Loss: 152.1372833251953
Epoch 9, Loss: 131.84521484375
Epoch 10, Loss: 111.8714828491211
Epoch 11, Loss: 98.34375
Epoch 12, Loss: 96.39141845703125
Epoch 13, Loss: 91.94165802001953
Epoch 14, Loss: 86.46856689453125
Epoch 15, Loss: 78.55831146240234
Epoch 16, Loss: 70.07406616210938
Epoch 17, Loss: 64.00228118896484
Epoch 18, Loss: 62.46696090698242
Epoch 19, Loss: 62.70053482055664
Epoch 20, Loss: 63.00360107421875
Validation Loss: 18.97235107421875 ; Validation Accuracy: 85.71428571428571%
Epoch 1, Loss: 212.67984008789062
Epoch 2, Loss: 210.2757110595703
Epoch 3, Loss: 207.246337890625
Epoch 4, Loss: 202.76023864746094
Epoch 5, Loss: 196.4838409423828
Epoch 6, Loss: 187.1421661376953
Epoch 7, Loss: 175.06101989746094
Epoch 8, Loss: 159.5839

[I 2024-08-18 14:11:27,967] Trial 23 finished with value: 14.216590166091919 and parameters: {'n_units_1': 381, 'n_units_2': 396, 'n_units_3': 216, 'dropout_rate': 0.18277685713060826, 'learning_rate': 0.0009054003926146037}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:28,051] Trial 24 finished with value: 14.25806450843811 and parameters: {'n_units_1': 473, 'n_units_2': 399, 'n_units_3': 213, 'dropout_rate': 0.15024126897578838, 'learning_rate': 0.0007329864145870157}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:28,133] Trial 25 finished with value: 14.414746284484863 and parameters: {'n_units_1': 336, 'n_units_2': 386, 'n_units_3': 235, 'dropout_rate': 0.18529106799587672, 'learning_rate': 0.0009123650417369608}. Best is trial 19 with value: 13.95852541923523.


Epoch 15, Loss: 75.45838165283203
Epoch 16, Loss: 64.63739013671875
Epoch 17, Loss: 62.33631896972656
Epoch 18, Loss: 61.70127487182617
Epoch 19, Loss: 62.51118850708008
Epoch 20, Loss: 62.94084548950195
Validation Loss: 15.921658515930176 ; Validation Accuracy: 84.33179723502305%
Epoch 1, Loss: 212.5330810546875
Epoch 2, Loss: 210.4488525390625
Epoch 3, Loss: 207.88433837890625
Epoch 4, Loss: 204.10250854492188
Epoch 5, Loss: 198.6437530517578
Epoch 6, Loss: 190.7854766845703
Epoch 7, Loss: 180.08616638183594
Epoch 8, Loss: 166.26446533203125
Epoch 9, Loss: 149.28253173828125
Epoch 10, Loss: 130.2224884033203
Epoch 11, Loss: 111.5500259399414
Epoch 12, Loss: 101.26497650146484
Epoch 13, Loss: 95.07647705078125
Epoch 14, Loss: 91.806396484375
Epoch 15, Loss: 87.4474105834961
Epoch 16, Loss: 79.76441192626953
Epoch 17, Loss: 72.09469604492188
Epoch 18, Loss: 66.18570709228516
Epoch 19, Loss: 62.90017318725586
Epoch 20, Loss: 62.97138977050781
Validation Loss: 16.244239807128906 ; Valida

[I 2024-08-18 14:11:28,232] Trial 26 finished with value: 23.686635971069336 and parameters: {'n_units_1': 575, 'n_units_2': 304, 'n_units_3': 253, 'dropout_rate': 0.2611885582891183, 'learning_rate': 2.8751075472526974e-05}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:28,326] Trial 27 finished with value: 17.479262351989746 and parameters: {'n_units_1': 749, 'n_units_2': 481, 'n_units_3': 203, 'dropout_rate': 0.1658455521396557, 'learning_rate': 0.0005082000950475097}. Best is trial 19 with value: 13.95852541923523.


Epoch 1, Loss: 212.76356506347656
Epoch 2, Loss: 212.63502502441406
Epoch 3, Loss: 212.5174560546875
Epoch 4, Loss: 212.449951171875
Epoch 5, Loss: 212.3651580810547
Epoch 6, Loss: 212.25303649902344
Epoch 7, Loss: 212.15829467773438
Epoch 8, Loss: 212.05142211914062
Epoch 9, Loss: 211.9498748779297
Epoch 10, Loss: 211.86378479003906
Epoch 11, Loss: 211.77542114257812
Epoch 12, Loss: 211.66738891601562
Epoch 13, Loss: 211.54318237304688
Epoch 14, Loss: 211.44821166992188
Epoch 15, Loss: 211.3372802734375
Epoch 16, Loss: 211.25192260742188
Epoch 17, Loss: 211.15455627441406
Epoch 18, Loss: 211.023193359375
Epoch 19, Loss: 210.94332885742188
Epoch 20, Loss: 210.81210327148438
Validation Loss: 45.77880096435547 ; Validation Accuracy: 100.0%
Epoch 1, Loss: 213.2725372314453
Epoch 2, Loss: 211.14419555664062
Epoch 3, Loss: 209.00643920898438
Epoch 4, Loss: 206.18162536621094
Epoch 5, Loss: 202.48846435546875
Epoch 6, Loss: 197.52804565429688
Epoch 7, Loss: 190.88247680664062
Epoch 8, Loss: 

[I 2024-08-18 14:11:28,412] Trial 28 finished with value: 14.419354677200317 and parameters: {'n_units_1': 327, 'n_units_2': 436, 'n_units_3': 236, 'dropout_rate': 0.10463856215996875, 'learning_rate': 0.0009164447646238859}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:28,499] Trial 29 finished with value: 15.027649402618408 and parameters: {'n_units_1': 469, 'n_units_2': 497, 'n_units_3': 180, 'dropout_rate': 0.1973422282195312, 'learning_rate': 0.000589362942755398}. Best is trial 19 with value: 13.95852541923523.


Epoch 11, Loss: 95.07563781738281
Epoch 12, Loss: 91.51986694335938
Epoch 13, Loss: 85.21248626708984
Epoch 14, Loss: 77.05103302001953
Epoch 15, Loss: 68.71134948730469
Epoch 16, Loss: 62.95893859863281
Epoch 17, Loss: 62.05506896972656
Epoch 18, Loss: 63.22092819213867
Epoch 19, Loss: 64.88758850097656
Epoch 20, Loss: 63.980369567871094
Validation Loss: 20.156681060791016 ; Validation Accuracy: 88.47926267281106%
Epoch 1, Loss: 212.1226348876953
Epoch 2, Loss: 209.67430114746094
Epoch 3, Loss: 206.78565979003906
Epoch 4, Loss: 202.73379516601562
Epoch 5, Loss: 197.20164489746094
Epoch 6, Loss: 190.28759765625
Epoch 7, Loss: 181.02027893066406
Epoch 8, Loss: 169.2366180419922
Epoch 9, Loss: 155.38055419921875
Epoch 10, Loss: 139.4013214111328
Epoch 11, Loss: 123.17325592041016
Epoch 12, Loss: 107.85258483886719
Epoch 13, Loss: 97.28031158447266
Epoch 14, Loss: 90.16173553466797
Epoch 15, Loss: 87.61430358886719
Epoch 16, Loss: 83.75940704345703
Epoch 17, Loss: 82.30138397216797
Epoch 

[I 2024-08-18 14:11:28,599] Trial 30 finished with value: 13.972350120544434 and parameters: {'n_units_1': 634, 'n_units_2': 438, 'n_units_3': 86, 'dropout_rate': 0.26346453001986725, 'learning_rate': 0.0007611564478193606}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:28,698] Trial 31 finished with value: 14.354838609695435 and parameters: {'n_units_1': 638, 'n_units_2': 443, 'n_units_3': 88, 'dropout_rate': 0.26415214111949037, 'learning_rate': 0.000762290677076859}. Best is trial 19 with value: 13.95852541923523.


Epoch 20, Loss: 61.9443244934082
Validation Loss: 16.71889305114746 ; Validation Accuracy: 83.87096774193549%
Epoch 1, Loss: 212.2660675048828
Epoch 2, Loss: 209.82135009765625
Epoch 3, Loss: 206.66815185546875
Epoch 4, Loss: 202.1748809814453
Epoch 5, Loss: 195.66065979003906
Epoch 6, Loss: 186.5078125
Epoch 7, Loss: 174.76121520996094
Epoch 8, Loss: 159.8308868408203
Epoch 9, Loss: 142.45046997070312
Epoch 10, Loss: 124.3428955078125
Epoch 11, Loss: 107.42793273925781
Epoch 12, Loss: 97.99378204345703
Epoch 13, Loss: 91.38274383544922
Epoch 14, Loss: 87.30152893066406
Epoch 15, Loss: 85.58343505859375
Epoch 16, Loss: 80.10243225097656
Epoch 17, Loss: 71.03823852539062
Epoch 18, Loss: 67.77578735351562
Epoch 19, Loss: 62.98238754272461
Epoch 20, Loss: 59.780155181884766
Validation Loss: 15.967741966247559 ; Validation Accuracy: 80.64516129032258%
Epoch 1, Loss: 213.038818359375
Epoch 2, Loss: 209.84414672851562
Epoch 3, Loss: 205.3153076171875
Epoch 4, Loss: 198.3917694091797
Epoch 5,

[I 2024-08-18 14:11:28,798] Trial 32 finished with value: 14.322580575942993 and parameters: {'n_units_1': 766, 'n_units_2': 485, 'n_units_3': 94, 'dropout_rate': 0.2632214640701826, 'learning_rate': 0.0009142428519767143}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:28,894] Trial 33 finished with value: 13.995391845703125 and parameters: {'n_units_1': 381, 'n_units_2': 411, 'n_units_3': 83, 'dropout_rate': 0.19857683573354426, 'learning_rate': 0.0009508698408811372}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:28,983] Trial 34 finished with value: 14.465437889099121 and parameters: {'n_units_1': 473, 'n_units_2': 424, 'n_units_3': 81, 'dropout_rate': 0.34035318004100423, 'learning_rate': 0.0009574614074849303}. Best is trial 19 with value: 13.95852541923523.


Validation Loss: 15.304147720336914 ; Validation Accuracy: 81.10599078341014%
Epoch 1, Loss: 212.87986755371094
Epoch 2, Loss: 210.6635284423828
Epoch 3, Loss: 207.8802490234375
Epoch 4, Loss: 203.6399688720703
Epoch 5, Loss: 197.59335327148438
Epoch 6, Loss: 188.67662048339844
Epoch 7, Loss: 176.84136962890625
Epoch 8, Loss: 162.1090850830078
Epoch 9, Loss: 144.2797393798828
Epoch 10, Loss: 127.0758285522461
Epoch 11, Loss: 111.16792297363281
Epoch 12, Loss: 102.28707122802734
Epoch 13, Loss: 97.51448822021484
Epoch 14, Loss: 91.7049331665039
Epoch 15, Loss: 85.75067901611328
Epoch 16, Loss: 74.76663208007812
Epoch 17, Loss: 71.19005584716797
Epoch 18, Loss: 66.81936645507812
Epoch 19, Loss: 62.22528839111328
Epoch 20, Loss: 60.367774963378906
Validation Loss: 15.99078369140625 ; Validation Accuracy: 82.94930875576037%
Epoch 1, Loss: 212.58212280273438
Epoch 2, Loss: 209.2946319580078
Epoch 3, Loss: 204.76268005371094
Epoch 4, Loss: 198.20953369140625
Epoch 5, Loss: 188.41525268554688

[I 2024-08-18 14:11:29,080] Trial 35 finished with value: 14.626728296279907 and parameters: {'n_units_1': 669, 'n_units_2': 453, 'n_units_3': 77, 'dropout_rate': 0.2808948833867958, 'learning_rate': 0.0008506142904341496}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:29,168] Trial 36 finished with value: 18.903226137161255 and parameters: {'n_units_1': 302, 'n_units_2': 330, 'n_units_3': 127, 'dropout_rate': 0.3259536385790911, 'learning_rate': 0.0006652575639045725}. Best is trial 19 with value: 13.95852541923523.


Epoch 1, Loss: 213.46609497070312
Epoch 2, Loss: 210.45278930664062
Epoch 3, Loss: 206.43507385253906
Epoch 4, Loss: 200.3511962890625
Epoch 5, Loss: 191.7523956298828
Epoch 6, Loss: 179.49710083007812
Epoch 7, Loss: 163.42552185058594
Epoch 8, Loss: 145.11085510253906
Epoch 9, Loss: 126.09058380126953
Epoch 10, Loss: 110.93995666503906
Epoch 11, Loss: 100.24018096923828
Epoch 12, Loss: 95.251953125
Epoch 13, Loss: 88.54349517822266
Epoch 14, Loss: 80.84398651123047
Epoch 15, Loss: 73.46001434326172
Epoch 16, Loss: 66.38574981689453
Epoch 17, Loss: 62.81071853637695
Epoch 18, Loss: 62.59130859375
Epoch 19, Loss: 62.2843017578125
Epoch 20, Loss: 64.98912048339844
Validation Loss: 20.85714340209961 ; Validation Accuracy: 88.47926267281106%
Epoch 1, Loss: 212.55751037597656
Epoch 2, Loss: 211.08189392089844
Epoch 3, Loss: 209.39352416992188
Epoch 4, Loss: 207.10861206054688
Epoch 5, Loss: 204.2917022705078
Epoch 6, Loss: 200.48855590820312
Epoch 7, Loss: 195.2782440185547
Epoch 8, Loss: 1

[I 2024-08-18 14:11:29,271] Trial 37 finished with value: 14.373272180557251 and parameters: {'n_units_1': 522, 'n_units_2': 484, 'n_units_3': 92, 'dropout_rate': 0.20744047414022457, 'learning_rate': 0.0007682831190117444}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:29,362] Trial 38 finished with value: 17.76497745513916 and parameters: {'n_units_1': 596, 'n_units_2': 345, 'n_units_3': 130, 'dropout_rate': 0.24370591681251974, 'learning_rate': 0.0005589630544623311}. Best is trial 19 with value: 13.95852541923523.


Epoch 5, Loss: 194.01129150390625
Epoch 6, Loss: 184.14205932617188
Epoch 7, Loss: 171.11024475097656
Epoch 8, Loss: 154.53985595703125
Epoch 9, Loss: 137.4119415283203
Epoch 10, Loss: 118.70339965820312
Epoch 11, Loss: 105.80501556396484
Epoch 12, Loss: 95.50927734375
Epoch 13, Loss: 90.43806457519531
Epoch 14, Loss: 86.30628204345703
Epoch 15, Loss: 80.556640625
Epoch 16, Loss: 73.05962371826172
Epoch 17, Loss: 67.86811065673828
Epoch 18, Loss: 63.522499084472656
Epoch 19, Loss: 61.99660873413086
Epoch 20, Loss: 62.33322525024414
Validation Loss: 19.474655151367188 ; Validation Accuracy: 77.88018433179722%
Epoch 1, Loss: 212.36375427246094
Epoch 2, Loss: 210.5669708251953
Epoch 3, Loss: 208.5631103515625
Epoch 4, Loss: 205.77235412597656
Epoch 5, Loss: 202.1884765625
Epoch 6, Loss: 197.28848266601562
Epoch 7, Loss: 191.09146118164062
Epoch 8, Loss: 182.72457885742188
Epoch 9, Loss: 173.29185485839844
Epoch 10, Loss: 161.75778198242188
Epoch 11, Loss: 148.11837768554688
Epoch 12, Loss

[I 2024-08-18 14:11:29,453] Trial 39 finished with value: 14.705068826675415 and parameters: {'n_units_1': 415, 'n_units_2': 293, 'n_units_3': 153, 'dropout_rate': 0.4531952742788031, 'learning_rate': 0.0009502737749043862}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:29,539] Trial 40 finished with value: 14.161290407180786 and parameters: {'n_units_1': 307, 'n_units_2': 370, 'n_units_3': 105, 'dropout_rate': 0.3723608968143012, 'learning_rate': 0.0008676023054400408}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:29,631] Trial 41 finished with value: 15.456221342086792 and parameters: {'n_units_1': 331, 'n_units_2': 371, 'n_units_3': 84, 'dropout_rate': 0.3761882144916038, 'learning_rate': 0.0007902162001629932}. Best is trial 19 with value: 13.95852541923523.


Epoch 14, Loss: 87.6595230102539
Epoch 15, Loss: 78.92527770996094
Epoch 16, Loss: 68.77085876464844
Epoch 17, Loss: 65.6250991821289
Epoch 18, Loss: 63.84963607788086
Epoch 19, Loss: 62.07233428955078
Epoch 20, Loss: 64.58714294433594
Validation Loss: 19.502304077148438 ; Validation Accuracy: 84.33179723502305%
Epoch 1, Loss: 211.78662109375
Epoch 2, Loss: 209.58250427246094
Epoch 3, Loss: 206.53663635253906
Epoch 4, Loss: 202.53085327148438
Epoch 5, Loss: 196.6744384765625
Epoch 6, Loss: 188.68902587890625
Epoch 7, Loss: 178.4856719970703
Epoch 8, Loss: 165.96951293945312
Epoch 9, Loss: 150.73611450195312
Epoch 10, Loss: 132.972412109375
Epoch 11, Loss: 117.30387115478516
Epoch 12, Loss: 102.29467010498047
Epoch 13, Loss: 94.24044036865234
Epoch 14, Loss: 93.49393463134766
Epoch 15, Loss: 88.52960205078125
Epoch 16, Loss: 86.19097137451172
Epoch 17, Loss: 77.34986877441406
Epoch 18, Loss: 71.33174133300781
Epoch 19, Loss: 67.5628433227539
Epoch 20, Loss: 65.7533950805664
Validation L

[I 2024-08-18 14:11:29,743] Trial 42 finished with value: 14.414746284484863 and parameters: {'n_units_1': 1007, 'n_units_2': 411, 'n_units_3': 102, 'dropout_rate': 0.28311732592849154, 'learning_rate': 0.0008668789586693002}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:29,829] Trial 43 finished with value: 14.110599279403687 and parameters: {'n_units_1': 313, 'n_units_2': 463, 'n_units_3': 66, 'dropout_rate': 0.33133434793925187, 'learning_rate': 0.0009653669627000918}. Best is trial 19 with value: 13.95852541923523.


Epoch 1, Loss: 213.06173706054688
Epoch 2, Loss: 209.3726806640625
Epoch 3, Loss: 203.64772033691406
Epoch 4, Loss: 194.6932830810547
Epoch 5, Loss: 181.96188354492188
Epoch 6, Loss: 164.05099487304688
Epoch 7, Loss: 142.60699462890625
Epoch 8, Loss: 120.15232849121094
Epoch 9, Loss: 103.2945785522461
Epoch 10, Loss: 96.66576385498047
Epoch 11, Loss: 93.20664978027344
Epoch 12, Loss: 86.53985595703125
Epoch 13, Loss: 78.77870178222656
Epoch 14, Loss: 70.5789566040039
Epoch 15, Loss: 62.803367614746094
Epoch 16, Loss: 59.883174896240234
Epoch 17, Loss: 62.97063446044922
Epoch 18, Loss: 62.83709716796875
Epoch 19, Loss: 63.5906867980957
Epoch 20, Loss: 63.55143356323242
Validation Loss: 17.562211990356445 ; Validation Accuracy: 84.33179723502305%
Epoch 1, Loss: 211.9345703125
Epoch 2, Loss: 209.77926635742188
Epoch 3, Loss: 206.5873565673828
Epoch 4, Loss: 201.83465576171875
Epoch 5, Loss: 195.01023864746094
Epoch 6, Loss: 185.94406127929688
Epoch 7, Loss: 174.21258544921875
Epoch 8, Los

[I 2024-08-18 14:11:29,917] Trial 44 finished with value: 14.152074098587036 and parameters: {'n_units_1': 376, 'n_units_2': 460, 'n_units_3': 67, 'dropout_rate': 0.31579702435443735, 'learning_rate': 0.0009555159904684474}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:30,017] Trial 45 finished with value: 14.18433165550232 and parameters: {'n_units_1': 454, 'n_units_2': 495, 'n_units_3': 64, 'dropout_rate': 0.44623537947396574, 'learning_rate': 0.0009995584398604458}. Best is trial 19 with value: 13.95852541923523.


Epoch 7, Loss: 169.35446166992188
Epoch 8, Loss: 153.0353240966797
Epoch 9, Loss: 134.2012176513672
Epoch 10, Loss: 115.50457763671875
Epoch 11, Loss: 100.2251968383789
Epoch 12, Loss: 88.90670013427734
Epoch 13, Loss: 88.96238708496094
Epoch 14, Loss: 85.37958526611328
Epoch 15, Loss: 81.547119140625
Epoch 16, Loss: 74.68029022216797
Epoch 17, Loss: 68.10894012451172
Epoch 18, Loss: 62.708770751953125
Epoch 19, Loss: 61.3343391418457
Epoch 20, Loss: 61.830596923828125
Validation Loss: 15.193548202514648 ; Validation Accuracy: 81.5668202764977%
Epoch 1, Loss: 212.3166961669922
Epoch 2, Loss: 209.57101440429688
Epoch 3, Loss: 205.60214233398438
Epoch 4, Loss: 199.8649139404297
Epoch 5, Loss: 191.2379150390625
Epoch 6, Loss: 179.7279510498047
Epoch 7, Loss: 165.10191345214844
Epoch 8, Loss: 147.38357543945312
Epoch 9, Loss: 127.49555206298828
Epoch 10, Loss: 108.59539794921875
Epoch 11, Loss: 98.74280548095703
Epoch 12, Loss: 96.89393615722656
Epoch 13, Loss: 90.89569091796875
Epoch 14, 

[I 2024-08-18 14:11:30,112] Trial 46 finished with value: 17.96774196624756 and parameters: {'n_units_1': 268, 'n_units_2': 465, 'n_units_3': 75, 'dropout_rate': 0.2522982894335175, 'learning_rate': 0.000399122925374557}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:30,230] Trial 47 finished with value: 14.423963069915771 and parameters: {'n_units_1': 870, 'n_units_2': 438, 'n_units_3': 86, 'dropout_rate': 0.2990249153009773, 'learning_rate': 0.0009358447116853913}. Best is trial 19 with value: 13.95852541923523.


Epoch 12, Loss: 199.7741241455078
Epoch 13, Loss: 197.1249237060547
Epoch 14, Loss: 193.96115112304688
Epoch 15, Loss: 189.72994995117188
Epoch 16, Loss: 185.41921997070312
Epoch 17, Loss: 180.1576690673828
Epoch 18, Loss: 174.61964416503906
Epoch 19, Loss: 167.62710571289062
Epoch 20, Loss: 161.41265869140625
Validation Loss: 27.44239616394043 ; Validation Accuracy: 83.41013824884793%
Epoch 1, Loss: 213.1959686279297
Epoch 2, Loss: 209.2243194580078
Epoch 3, Loss: 203.6632537841797
Epoch 4, Loss: 194.86940002441406
Epoch 5, Loss: 182.1730499267578
Epoch 6, Loss: 164.5141143798828
Epoch 7, Loss: 143.46266174316406
Epoch 8, Loss: 122.51454162597656
Epoch 9, Loss: 105.34613800048828
Epoch 10, Loss: 97.87371063232422
Epoch 11, Loss: 90.57041931152344
Epoch 12, Loss: 82.23995208740234
Epoch 13, Loss: 73.57503509521484
Epoch 14, Loss: 67.61985778808594
Epoch 15, Loss: 61.54829406738281
Epoch 16, Loss: 59.84855270385742
Epoch 17, Loss: 61.65462112426758
Epoch 18, Loss: 64.342529296875
Epoch 

[I 2024-08-18 14:11:30,342] Trial 48 finished with value: 14.281105995178223 and parameters: {'n_units_1': 790, 'n_units_2': 496, 'n_units_3': 119, 'dropout_rate': 0.47545115541130756, 'learning_rate': 0.0007966718872323274}. Best is trial 19 with value: 13.95852541923523.
[I 2024-08-18 14:11:30,431] Trial 49 finished with value: 14.0506911277771 and parameters: {'n_units_1': 404, 'n_units_2': 410, 'n_units_3': 100, 'dropout_rate': 0.2138963707028985, 'learning_rate': 0.0008778996500271026}. Best is trial 19 with value: 13.95852541923523.


Epoch 6, Loss: 162.1790008544922
Epoch 7, Loss: 141.11546325683594
Epoch 8, Loss: 118.0960922241211
Epoch 9, Loss: 102.34577941894531
Epoch 10, Loss: 94.1323013305664
Epoch 11, Loss: 91.29647827148438
Epoch 12, Loss: 91.6722183227539
Epoch 13, Loss: 80.49636840820312
Epoch 14, Loss: 75.15314483642578
Epoch 15, Loss: 63.76881790161133
Epoch 16, Loss: 62.77628707885742
Epoch 17, Loss: 62.520408630371094
Epoch 18, Loss: 61.8804817199707
Epoch 19, Loss: 61.99791717529297
Epoch 20, Loss: 62.967918395996094
Validation Loss: 18.594470977783203 ; Validation Accuracy: 86.17511520737328%
Epoch 1, Loss: 211.54888916015625
Epoch 2, Loss: 209.1127471923828
Epoch 3, Loss: 205.9221649169922
Epoch 4, Loss: 201.35009765625
Epoch 5, Loss: 194.5211944580078
Epoch 6, Loss: 184.9001007080078
Epoch 7, Loss: 172.1671600341797
Epoch 8, Loss: 156.40187072753906
Epoch 9, Loss: 137.91390991210938
Epoch 10, Loss: 117.22349548339844
Epoch 11, Loss: 101.85609436035156
Epoch 12, Loss: 92.280029296875
Epoch 13, Loss:

Let's see what the results are. We already have predictions for each medal separately and total medals in the previous cell. Pretty accurate, right? Wait till you see when happens in the second test case.
P.S. if we take the separate medals and sum them up, we'll do a bit better:

In [182]:
medal_manual_sum = torch.sum(test_outputs[:,1:], dim=1)
sum_correct = torch.sum(torch.abs(medal_manual_sum - y_test_tensor[:,0]) <= torch.max(tolerance*y_test_tensor[:,0], torch.tensor(7.0).to(device)))
sum_accuracy = sum_correct.item() / len(y_test) * 100

print(f"Test Accuracy for Manual Medal Total: {sum_accuracy}%")

Test Accuracy for Manual Medal Total: 77.14285714285715%


No let's see the predicted table!

In [181]:
test_countries = np.array(test_countries)  # Convert to a numpy array if it's not already

# Create a DataFrame from the test predictions
predictions_df = pd.DataFrame({
    'Country': test_countries,
    'Predicted Total': test_outputs[:, 0].cpu().numpy(),
    'Predicted Gold': test_outputs[:, 1].cpu().numpy(),
    'Predicted Silver': test_outputs[:, 2].cpu().numpy(),
    'Predicted Bronze': test_outputs[:, 3].cpu().numpy(),
    'Total Predicted': medal_manual_sum.cpu().numpy()
})

predictions_df

Unnamed: 0,Country,Predicted Total,Predicted Gold,Predicted Silver,Predicted Bronze,Total Predicted
0,USA,91.0,38.0,29.0,30.0,97.0
1,CHN,65.0,24.0,21.0,21.0,66.0
2,JPN,20.0,7.0,7.0,7.0,21.0
3,AUS,21.0,6.0,7.0,7.0,20.0
4,FRA,30.0,10.0,10.0,11.0,31.0
...,...,...,...,...,...,...
65,VEN,2.0,0.0,0.0,1.0,1.0
66,FIN,12.0,3.0,4.0,5.0,12.0
67,EST,4.0,1.0,1.0,1.0,3.0
68,LAT,3.0,1.0,1.0,1.0,3.0


And for reference, the actual table(added at the end of the Olympics):

In [174]:
actual_df = pd.DataFrame({
    'Country': test_countries,
    'Actual Total': y_test['Total'].values,
    'Actual Gold': y_test['Gold'].values,
    'Actual Silver': y_test['Silver'].values,
    'Actual Bronze': y_test['Bronze'].values
})

actual_df

Unnamed: 0,Country,Actual Total,Actual Gold,Actual Silver,Actual Bronze
0,USA,126.0,40.0,44.0,42.0
1,CHN,91.0,40.0,27.0,24.0
2,JPN,45.0,20.0,12.0,13.0
3,AUS,53.0,18.0,19.0,16.0
4,FRA,64.0,16.0,26.0,22.0
...,...,...,...,...,...
65,VEN,-1.0,-1.0,-1.0,-1.0
66,FIN,-1.0,-1.0,-1.0,-1.0
67,EST,-1.0,-1.0,-1.0,-1.0
68,LAT,-1.0,-1.0,-1.0,-1.0


Ok, now for the second test case. Here we didn't filter anything out, just looked at all the participating countries in Paris.

In [175]:
# Test case 2:

# Splitting train and validation data
X_train_sub_2, X_val_2, y_train_sub_2, y_val_2 = train_test_split(X_train_tensor_scaled_2, y_train_tensor_2, test_size=0.2, random_state=42)

def objective(trial):
    # Define the hyperparameter search space
    n_units_1 = trial.suggest_int('n_units_1', 256, 1024)
    n_units_2 = trial.suggest_int('n_units_2', 128, 512)
    n_units_3 = trial.suggest_int('n_units_3', 64, 256)
    dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3)

    # Building the model with trial hyperparameters
    model = nn.Sequential(
        nn.Linear(X_train_sub.shape[1], n_units_1),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(n_units_1, n_units_2),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(n_units_2, n_units_3),
        nn.ReLU(),
        nn.Linear(n_units_3, 4)  # Output layer for regression
    ).to(device)  # Move the model to the GPU

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    for epoch in range(20):
        optimizer.zero_grad()
        outputs = model(X_train_sub_2).to(device)
        loss = criterion(outputs.squeeze(), y_train_sub_2)
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
        loss.backward()
        optimizer.step()

    # Evaluate the model on validation data
    with torch.no_grad():
        val_outputs = model(X_val_2).to(device)
        val_outputs = torch.round(val_outputs)  # Round the predictions to the nearest integer
        val_mae = {}
        val_accuracy = {}
        tolerance = 0.2
        tolerance_biases = [7.0, 3.0, 3.0, 3.0]
        for i, medal_type in enumerate(['Total', 'Gold', 'Silver', 'Bronze']):
          mae = torch.mean(torch.abs(val_outputs[:,i] - y_val_2[:,i])).item()
          val_mae[medal_type] = mae
          correct = torch.sum(torch.abs(val_outputs[:,i] - y_val_2[:,i]) <= torch.tensor(tolerance_biases[i]).to(device)*y_val_2[:,i])
          accuracy = correct.item() / len(y_val_2) * 100
          val_accuracy[medal_type] = accuracy
          val_loss = criterion(val_outputs[:,i], y_val_2[:,i])
        print(f"Validation Loss: {val_loss} ; Validation Accuracy: {accuracy}%")

    return sum(val_mae.values())  # Return the validation MAE as a Python numbe

# Create an Optuna study and optimize the objective function
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=50)

# Get the best hyperparameters
best_params = study.best_params
print("Best hyperparameters: ", best_params)

# Build and evaluate the model with the best hyperparameters
n_units_1 = best_params['n_units_1']
n_units_2 = best_params['n_units_2']
n_units_3 = best_params['n_units_3']
dropout_rate = best_params['dropout_rate']
learning_rate = best_params['learning_rate']

    # Building the model with trial hyperparameters
best_model = nn.Sequential(
        nn.Linear(X_train_tensor_scaled_2.shape[1], n_units_1),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(n_units_1, n_units_2),
        nn.ReLU(),
        nn.Dropout(dropout_rate),
        nn.Linear(n_units_2, n_units_3),
        nn.ReLU(),
        nn.Linear(n_units_3, 4)  # Output layer for regression
    ).to(device)  # Move the model to the GPU

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(best_model.parameters(), lr=learning_rate)

# Train with the entire training data and evaluate on the test set
for epoch in range(50):
    optimizer.zero_grad()
    outputs = best_model(X_train_tensor_scaled_2).to(device)
    loss = criterion(outputs.squeeze(), y_train_tensor_2)
    loss.backward()
    optimizer.step()

    # Evaluate the model on test data
with torch.no_grad():
    test_mae = {}
    test_accuracy = {}
    test_outputs_2 = best_model(X_test_tensor_scaled_2).to(device)
    test_outputs_2 = torch.round(test_outputs_2)  # Round the predictions to the nearest integer
    tolerance_biases = [7.0, 3.0, 3.0, 3.0]
    for i, medal_type in enumerate(['Total', 'Gold', 'Silver', 'Bronze']):
      test_loss = criterion(test_outputs_2[:,i], y_test_tensor_2[:,i])
      tolerance = 0.25
      correct = torch.sum(torch.abs(test_outputs_2[:,i] - y_test_tensor_2[:,i]) <= torch.max(tolerance*y_test_tensor_2[:,i], torch.tensor(tolerance_biases[i]).to(device)))
      accuracy = correct.item() / len(y_test_2) * 100
      test_accuracy[medal_type] = accuracy
      mae = torch.mean(torch.abs(test_outputs_2[:,i] - y_test_tensor_2[:,i])).item()  # Calculate MAE
      test_mae[medal_type] = mae


print(f"Test MAE: {test_mae}")
print(f"Test Loss: {test_loss} ; Test Accuracy: {test_accuracy}%")


[I 2024-08-18 14:12:12,477] A new study created in memory with name: no-name-c13fad0a-d6c5-46c7-b4a4-d52b954762ed
[I 2024-08-18 14:12:12,580] Trial 0 finished with value: 16.421052932739258 and parameters: {'n_units_1': 1015, 'n_units_2': 140, 'n_units_3': 101, 'dropout_rate': 0.36670868475270235, 'learning_rate': 0.0006684051592370293}. Best is trial 0 with value: 16.421052932739258.
[I 2024-08-18 14:12:12,663] Trial 1 finished with value: 15.736842393875122 and parameters: {'n_units_1': 317, 'n_units_2': 418, 'n_units_3': 229, 'dropout_rate': 0.4962210913959211, 'learning_rate': 0.0006479387719418275}. Best is trial 1 with value: 15.736842393875122.


Epoch 1, Loss: 147.78903198242188
Epoch 2, Loss: 146.39125061035156
Epoch 3, Loss: 144.77560424804688
Epoch 4, Loss: 142.39364624023438
Epoch 5, Loss: 139.35635375976562
Epoch 6, Loss: 135.76986694335938
Epoch 7, Loss: 130.5080108642578
Epoch 8, Loss: 124.65779876708984
Epoch 9, Loss: 116.93080139160156
Epoch 10, Loss: 109.33895874023438
Epoch 11, Loss: 99.15070343017578
Epoch 12, Loss: 89.80205535888672
Epoch 13, Loss: 81.64229583740234
Epoch 14, Loss: 71.92452239990234
Epoch 15, Loss: 65.83824157714844
Epoch 16, Loss: 62.65933609008789
Epoch 17, Loss: 62.86573791503906
Epoch 18, Loss: 59.81515884399414
Epoch 19, Loss: 59.43532180786133
Epoch 20, Loss: 53.21694564819336
Validation Loss: 25.169174194335938 ; Validation Accuracy: 69.54887218045113%
Epoch 1, Loss: 148.0503692626953
Epoch 2, Loss: 146.51731872558594
Epoch 3, Loss: 144.69802856445312
Epoch 4, Loss: 142.26890563964844
Epoch 5, Loss: 139.1923828125
Epoch 6, Loss: 134.78089904785156
Epoch 7, Loss: 128.78787231445312
Epoch 8, 

[I 2024-08-18 14:12:12,767] Trial 2 finished with value: 14.838346242904663 and parameters: {'n_units_1': 687, 'n_units_2': 442, 'n_units_3': 174, 'dropout_rate': 0.23276395158032703, 'learning_rate': 0.0007289771309258667}. Best is trial 2 with value: 14.838346242904663.
[I 2024-08-18 14:12:12,855] Trial 3 finished with value: 18.552631855010986 and parameters: {'n_units_1': 807, 'n_units_2': 382, 'n_units_3': 105, 'dropout_rate': 0.468257031217563, 'learning_rate': 0.0002948451024404624}. Best is trial 2 with value: 14.838346242904663.


Epoch 5, Loss: 129.93296813964844
Epoch 6, Loss: 120.19876861572266
Epoch 7, Loss: 107.9389419555664
Epoch 8, Loss: 93.4400634765625
Epoch 9, Loss: 79.29341125488281
Epoch 10, Loss: 70.45770263671875
Epoch 11, Loss: 66.07178497314453
Epoch 12, Loss: 64.57382202148438
Epoch 13, Loss: 61.35271453857422
Epoch 14, Loss: 55.069480895996094
Epoch 15, Loss: 49.28738784790039
Epoch 16, Loss: 45.09371566772461
Epoch 17, Loss: 42.62403869628906
Epoch 18, Loss: 42.96051788330078
Epoch 19, Loss: 42.59315872192383
Epoch 20, Loss: 43.59547424316406
Validation Loss: 35.962406158447266 ; Validation Accuracy: 83.83458646616542%
Epoch 1, Loss: 147.12730407714844
Epoch 2, Loss: 146.41366577148438
Epoch 3, Loss: 145.4665985107422
Epoch 4, Loss: 144.45498657226562
Epoch 5, Loss: 143.0809326171875
Epoch 6, Loss: 141.7294921875
Epoch 7, Loss: 139.9495849609375
Epoch 8, Loss: 137.9005889892578
Epoch 9, Loss: 135.71836853027344
Epoch 10, Loss: 132.6471405029297
Epoch 11, Loss: 129.40948486328125
Epoch 12, Loss

[I 2024-08-18 14:12:12,951] Trial 4 finished with value: 16.00000023841858 and parameters: {'n_units_1': 655, 'n_units_2': 428, 'n_units_3': 152, 'dropout_rate': 0.36858282066503445, 'learning_rate': 0.0005540531514784397}. Best is trial 2 with value: 14.838346242904663.
[I 2024-08-18 14:12:13,026] Trial 5 finished with value: 24.36842179298401 and parameters: {'n_units_1': 560, 'n_units_2': 201, 'n_units_3': 178, 'dropout_rate': 0.44656311304742746, 'learning_rate': 6.89951708234466e-05}. Best is trial 2 with value: 14.838346242904663.


Epoch 10, Loss: 103.21979522705078
Epoch 11, Loss: 92.45537567138672
Epoch 12, Loss: 81.32451629638672
Epoch 13, Loss: 72.10151672363281
Epoch 14, Loss: 67.771240234375
Epoch 15, Loss: 64.0843505859375
Epoch 16, Loss: 62.90137481689453
Epoch 17, Loss: 59.95355987548828
Epoch 18, Loss: 55.7463264465332
Epoch 19, Loss: 52.640193939208984
Epoch 20, Loss: 49.11627960205078
Validation Loss: 30.819549560546875 ; Validation Accuracy: 74.06015037593986%
Epoch 1, Loss: 148.00930786132812
Epoch 2, Loss: 147.8647003173828
Epoch 3, Loss: 147.77774047851562
Epoch 4, Loss: 147.65660095214844
Epoch 5, Loss: 147.56886291503906
Epoch 6, Loss: 147.41969299316406
Epoch 7, Loss: 147.3175048828125
Epoch 8, Loss: 147.1971893310547
Epoch 9, Loss: 147.07666015625
Epoch 10, Loss: 146.9386749267578
Epoch 11, Loss: 146.80918884277344
Epoch 12, Loss: 146.69305419921875
Epoch 13, Loss: 146.57327270507812
Epoch 14, Loss: 146.45025634765625
Epoch 15, Loss: 146.2923126220703
Epoch 16, Loss: 146.15469360351562
Epoch 1

[I 2024-08-18 14:12:13,112] Trial 6 finished with value: 24.383459329605103 and parameters: {'n_units_1': 674, 'n_units_2': 319, 'n_units_3': 86, 'dropout_rate': 0.26747830971057696, 'learning_rate': 1.3952876457417355e-05}. Best is trial 2 with value: 14.838346242904663.
[I 2024-08-18 14:12:13,185] Trial 7 finished with value: 19.14661717414856 and parameters: {'n_units_1': 680, 'n_units_2': 205, 'n_units_3': 147, 'dropout_rate': 0.1923419164158986, 'learning_rate': 0.00024608188899460866}. Best is trial 2 with value: 14.838346242904663.
[I 2024-08-18 14:12:13,266] Trial 8 finished with value: 15.1015043258667 and parameters: {'n_units_1': 505, 'n_units_2': 400, 'n_units_3': 103, 'dropout_rate': 0.2987338633194476, 'learning_rate': 0.0007277757276345989}. Best is trial 2 with value: 14.838346242904663.


Epoch 19, Loss: 146.6913299560547
Epoch 20, Loss: 146.67047119140625
Validation Loss: 73.86090087890625 ; Validation Accuracy: 95.48872180451127%
Epoch 1, Loss: 147.61363220214844
Epoch 2, Loss: 146.9125213623047
Epoch 3, Loss: 146.1700897216797
Epoch 4, Loss: 145.4260711669922
Epoch 5, Loss: 144.5520782470703
Epoch 6, Loss: 143.6094970703125
Epoch 7, Loss: 142.48678588867188
Epoch 8, Loss: 141.25303649902344
Epoch 9, Loss: 139.92474365234375
Epoch 10, Loss: 138.33070373535156
Epoch 11, Loss: 136.4904022216797
Epoch 12, Loss: 134.30413818359375
Epoch 13, Loss: 132.11419677734375
Epoch 14, Loss: 129.91229248046875
Epoch 15, Loss: 126.47628784179688
Epoch 16, Loss: 123.39689636230469
Epoch 17, Loss: 120.17086791992188
Epoch 18, Loss: 116.25092315673828
Epoch 19, Loss: 112.21837615966797
Epoch 20, Loss: 107.31991577148438
Validation Loss: 52.763160705566406 ; Validation Accuracy: 76.31578947368422%
Epoch 1, Loss: 148.32150268554688
Epoch 2, Loss: 146.8542022705078
Epoch 3, Loss: 145.10664

[I 2024-08-18 14:12:13,350] Trial 9 finished with value: 15.071428775787354 and parameters: {'n_units_1': 694, 'n_units_2': 310, 'n_units_3': 123, 'dropout_rate': 0.22057998513385935, 'learning_rate': 0.0008492813681009352}. Best is trial 2 with value: 14.838346242904663.
[I 2024-08-18 14:12:13,480] Trial 10 finished with value: 14.161654710769653 and parameters: {'n_units_1': 894, 'n_units_2': 502, 'n_units_3': 205, 'dropout_rate': 0.1086127423744821, 'learning_rate': 0.0009627442684438241}. Best is trial 10 with value: 14.161654710769653.


Epoch 8, Loss: 102.36227416992188
Epoch 9, Loss: 88.23419952392578
Epoch 10, Loss: 77.51371002197266
Epoch 11, Loss: 70.79444122314453
Epoch 12, Loss: 65.03842163085938
Epoch 13, Loss: 63.39413070678711
Epoch 14, Loss: 58.2054443359375
Epoch 15, Loss: 52.89054870605469
Epoch 16, Loss: 48.67753601074219
Epoch 17, Loss: 45.32958984375
Epoch 18, Loss: 43.143470764160156
Epoch 19, Loss: 43.873138427734375
Epoch 20, Loss: 44.72441101074219
Validation Loss: 44.203006744384766 ; Validation Accuracy: 89.09774436090225%
Epoch 1, Loss: 148.04486083984375
Epoch 2, Loss: 144.18435668945312
Epoch 3, Loss: 138.193359375
Epoch 4, Loss: 128.14669799804688
Epoch 5, Loss: 113.05709075927734
Epoch 6, Loss: 94.75874328613281
Epoch 7, Loss: 79.53266143798828
Epoch 8, Loss: 72.30135345458984
Epoch 9, Loss: 65.9863510131836
Epoch 10, Loss: 57.28475570678711
Epoch 11, Loss: 50.49626541137695
Epoch 12, Loss: 46.090614318847656
Epoch 13, Loss: 44.9447021484375
Epoch 14, Loss: 45.2288818359375
Epoch 15, Loss: 45

[I 2024-08-18 14:12:13,620] Trial 11 finished with value: 13.966165781021118 and parameters: {'n_units_1': 926, 'n_units_2': 510, 'n_units_3': 209, 'dropout_rate': 0.10129696624762358, 'learning_rate': 0.0009883728249526539}. Best is trial 11 with value: 13.966165781021118.
[I 2024-08-18 14:12:13,726] Trial 12 finished with value: 13.951128244400024 and parameters: {'n_units_1': 982, 'n_units_2': 511, 'n_units_3': 229, 'dropout_rate': 0.11490857372775848, 'learning_rate': 0.0009932300437282496}. Best is trial 12 with value: 13.951128244400024.


Epoch 1, Loss: 148.01119995117188
Epoch 2, Loss: 144.196533203125
Epoch 3, Loss: 137.85215759277344
Epoch 4, Loss: 126.66635131835938
Epoch 5, Loss: 110.37275695800781
Epoch 6, Loss: 90.89007568359375
Epoch 7, Loss: 77.81260681152344
Epoch 8, Loss: 72.1558609008789
Epoch 9, Loss: 65.13439178466797
Epoch 10, Loss: 55.55926513671875
Epoch 11, Loss: 48.808876037597656
Epoch 12, Loss: 44.98732376098633
Epoch 13, Loss: 44.865013122558594
Epoch 14, Loss: 44.68840408325195
Epoch 15, Loss: 43.747039794921875
Epoch 16, Loss: 42.490234375
Epoch 17, Loss: 41.402591705322266
Epoch 18, Loss: 40.050010681152344
Epoch 19, Loss: 39.98873519897461
Epoch 20, Loss: 40.629302978515625
Validation Loss: 24.7593994140625 ; Validation Accuracy: 81.57894736842105%
Epoch 1, Loss: 147.4447784423828
Epoch 2, Loss: 142.53549194335938
Epoch 3, Loss: 134.13790893554688
Epoch 4, Loss: 120.06047821044922
Epoch 5, Loss: 100.4213638305664
Epoch 6, Loss: 81.02799224853516
Epoch 7, Loss: 73.39625549316406
Epoch 8, Loss: 6

[I 2024-08-18 14:12:13,837] Trial 13 finished with value: 14.000000715255737 and parameters: {'n_units_1': 1021, 'n_units_2': 511, 'n_units_3': 253, 'dropout_rate': 0.10153152686357197, 'learning_rate': 0.0009688707615924732}. Best is trial 12 with value: 13.951128244400024.
[I 2024-08-18 14:12:13,933] Trial 14 finished with value: 13.872180461883545 and parameters: {'n_units_1': 882, 'n_units_2': 466, 'n_units_3': 214, 'dropout_rate': 0.1623735011060553, 'learning_rate': 0.0008547259487968325}. Best is trial 14 with value: 13.872180461883545.


Epoch 1, Loss: 148.07760620117188
Epoch 2, Loss: 143.86936950683594
Epoch 3, Loss: 136.9318084716797
Epoch 4, Loss: 125.12273406982422
Epoch 5, Loss: 107.60895538330078
Epoch 6, Loss: 87.43134307861328
Epoch 7, Loss: 73.66793060302734
Epoch 8, Loss: 70.01732635498047
Epoch 9, Loss: 65.24364471435547
Epoch 10, Loss: 55.43312454223633
Epoch 11, Loss: 47.64844512939453
Epoch 12, Loss: 43.91887664794922
Epoch 13, Loss: 43.76187515258789
Epoch 14, Loss: 43.9118537902832
Epoch 15, Loss: 43.46816635131836
Epoch 16, Loss: 42.09058380126953
Epoch 17, Loss: 41.05534362792969
Epoch 18, Loss: 39.66781234741211
Epoch 19, Loss: 39.847511291503906
Epoch 20, Loss: 40.08685302734375
Validation Loss: 24.827068328857422 ; Validation Accuracy: 81.57894736842105%
Epoch 1, Loss: 147.57424926757812
Epoch 2, Loss: 144.23960876464844
Epoch 3, Loss: 139.3333282470703
Epoch 4, Loss: 131.34339904785156
Epoch 5, Loss: 119.48899841308594
Epoch 6, Loss: 103.98470306396484
Epoch 7, Loss: 87.44599151611328
Epoch 8, Lo

[I 2024-08-18 14:12:14,035] Trial 15 finished with value: 13.879699468612671 and parameters: {'n_units_1': 825, 'n_units_2': 464, 'n_units_3': 250, 'dropout_rate': 0.16708600908754037, 'learning_rate': 0.0008363140350683673}. Best is trial 14 with value: 13.872180461883545.
[I 2024-08-18 14:12:14,127] Trial 16 finished with value: 14.244361162185669 and parameters: {'n_units_1': 816, 'n_units_2': 361, 'n_units_3': 251, 'dropout_rate': 0.17279362801881132, 'learning_rate': 0.0008105953719741311}. Best is trial 14 with value: 13.872180461883545.


Epoch 2, Loss: 145.45884704589844
Epoch 3, Loss: 141.09742736816406
Epoch 4, Loss: 134.09579467773438
Epoch 5, Loss: 123.42449951171875
Epoch 6, Loss: 108.70048522949219
Epoch 7, Loss: 91.63978576660156
Epoch 8, Loss: 77.1233901977539
Epoch 9, Loss: 71.18567657470703
Epoch 10, Loss: 67.80721282958984
Epoch 11, Loss: 62.82144546508789
Epoch 12, Loss: 55.21339797973633
Epoch 13, Loss: 47.97222900390625
Epoch 14, Loss: 44.72732162475586
Epoch 15, Loss: 43.93876647949219
Epoch 16, Loss: 44.0439338684082
Epoch 17, Loss: 43.931190490722656
Epoch 18, Loss: 42.69289779663086
Epoch 19, Loss: 41.258445739746094
Epoch 20, Loss: 40.684383392333984
Validation Loss: 28.654136657714844 ; Validation Accuracy: 83.0827067669173%
Epoch 1, Loss: 148.22837829589844
Epoch 2, Loss: 145.79385375976562
Epoch 3, Loss: 142.39039611816406
Epoch 4, Loss: 137.10293579101562
Epoch 5, Loss: 129.0554656982422
Epoch 6, Loss: 117.70999908447266
Epoch 7, Loss: 103.77977752685547
Epoch 8, Loss: 89.05138397216797
Epoch 9, 

[I 2024-08-18 14:12:14,228] Trial 17 finished with value: 16.28195548057556 and parameters: {'n_units_1': 816, 'n_units_2': 464, 'n_units_3': 202, 'dropout_rate': 0.16749029207942095, 'learning_rate': 0.0004715445095841147}. Best is trial 14 with value: 13.872180461883545.
[I 2024-08-18 14:12:14,319] Trial 18 finished with value: 17.477444171905518 and parameters: {'n_units_1': 884, 'n_units_2': 274, 'n_units_3': 230, 'dropout_rate': 0.15685663212232048, 'learning_rate': 0.0004876826575576539}. Best is trial 14 with value: 13.872180461883545.


Epoch 3, Loss: 144.44076538085938
Epoch 4, Loss: 142.08953857421875
Epoch 5, Loss: 138.9835662841797
Epoch 6, Loss: 134.824951171875
Epoch 7, Loss: 129.56468200683594
Epoch 8, Loss: 122.87124633789062
Epoch 9, Loss: 115.06188201904297
Epoch 10, Loss: 105.64018249511719
Epoch 11, Loss: 95.59222412109375
Epoch 12, Loss: 85.95671081542969
Epoch 13, Loss: 77.93692779541016
Epoch 14, Loss: 72.2890625
Epoch 15, Loss: 68.27706146240234
Epoch 16, Loss: 65.44750213623047
Epoch 17, Loss: 61.68050765991211
Epoch 18, Loss: 57.63917541503906
Epoch 19, Loss: 53.949729919433594
Epoch 20, Loss: 49.48959732055664
Validation Loss: 29.2030086517334 ; Validation Accuracy: 74.43609022556392%
Epoch 1, Loss: 147.9691619873047
Epoch 2, Loss: 146.67181396484375
Epoch 3, Loss: 145.22691345214844
Epoch 4, Loss: 143.2511444091797
Epoch 5, Loss: 140.8135986328125
Epoch 6, Loss: 137.47103881835938
Epoch 7, Loss: 133.19772338867188
Epoch 8, Loss: 127.8003921508789
Epoch 9, Loss: 121.49998474121094
Epoch 10, Loss: 11

[I 2024-08-18 14:12:14,408] Trial 19 finished with value: 14.327068090438843 and parameters: {'n_units_1': 764, 'n_units_2': 361, 'n_units_3': 192, 'dropout_rate': 0.2548161107531169, 'learning_rate': 0.000858069637028237}. Best is trial 14 with value: 13.872180461883545.
[I 2024-08-18 14:12:14,497] Trial 20 finished with value: 20.27819585800171 and parameters: {'n_units_1': 455, 'n_units_2': 457, 'n_units_3': 256, 'dropout_rate': 0.31865295676357847, 'learning_rate': 0.00034656664832386446}. Best is trial 14 with value: 13.872180461883545.


Epoch 12, Loss: 57.916175842285156
Epoch 13, Loss: 53.199092864990234
Epoch 14, Loss: 46.84600830078125
Epoch 15, Loss: 45.287662506103516
Epoch 16, Loss: 43.06166076660156
Epoch 17, Loss: 44.42790603637695
Epoch 18, Loss: 43.514034271240234
Epoch 19, Loss: 43.91461181640625
Epoch 20, Loss: 41.96297073364258
Validation Loss: 26.676692962646484 ; Validation Accuracy: 81.203007518797%
Epoch 1, Loss: 147.66647338867188
Epoch 2, Loss: 146.65562438964844
Epoch 3, Loss: 145.61166381835938
Epoch 4, Loss: 144.29188537597656
Epoch 5, Loss: 142.71697998046875
Epoch 6, Loss: 140.716796875
Epoch 7, Loss: 138.3569793701172
Epoch 8, Loss: 135.5038604736328
Epoch 9, Loss: 131.64523315429688
Epoch 10, Loss: 127.57972717285156
Epoch 11, Loss: 122.19180297851562
Epoch 12, Loss: 116.42692565917969
Epoch 13, Loss: 109.67227172851562
Epoch 14, Loss: 102.40560150146484
Epoch 15, Loss: 94.78589630126953
Epoch 16, Loss: 87.06889343261719
Epoch 17, Loss: 78.76969909667969
Epoch 18, Loss: 73.0162582397461
Epoch

[I 2024-08-18 14:12:14,608] Trial 21 finished with value: 14.082707166671753 and parameters: {'n_units_1': 949, 'n_units_2': 474, 'n_units_3': 229, 'dropout_rate': 0.14373058429673635, 'learning_rate': 0.0008966922977092587}. Best is trial 14 with value: 13.872180461883545.
[I 2024-08-18 14:12:14,710] Trial 22 finished with value: 13.939850091934204 and parameters: {'n_units_1': 951, 'n_units_2': 478, 'n_units_3': 233, 'dropout_rate': 0.1375171263952551, 'learning_rate': 0.0007673567006505344}. Best is trial 14 with value: 13.872180461883545.


Epoch 16, Loss: 42.86384582519531
Epoch 17, Loss: 41.19762420654297
Epoch 18, Loss: 40.191856384277344
Epoch 19, Loss: 40.070274353027344
Epoch 20, Loss: 40.701236724853516
Validation Loss: 25.12782096862793 ; Validation Accuracy: 81.57894736842105%
Epoch 1, Loss: 147.6602783203125
Epoch 2, Loss: 144.49720764160156
Epoch 3, Loss: 139.87364196777344
Epoch 4, Loss: 132.5473175048828
Epoch 5, Loss: 121.46294403076172
Epoch 6, Loss: 107.02953338623047
Epoch 7, Loss: 90.87438201904297
Epoch 8, Loss: 76.83445739746094
Epoch 9, Loss: 71.06444549560547
Epoch 10, Loss: 67.8106918334961
Epoch 11, Loss: 62.279911041259766
Epoch 12, Loss: 55.56439208984375
Epoch 13, Loss: 47.798118591308594
Epoch 14, Loss: 44.535621643066406
Epoch 15, Loss: 43.458221435546875
Epoch 16, Loss: 43.57900619506836
Epoch 17, Loss: 43.37825393676758
Epoch 18, Loss: 43.165706634521484
Epoch 19, Loss: 41.78892135620117
Epoch 20, Loss: 40.763858795166016
Validation Loss: 29.037593841552734 ; Validation Accuracy: 83.45864661

[I 2024-08-18 14:12:14,811] Trial 23 finished with value: 13.845865249633789 and parameters: {'n_units_1': 886, 'n_units_2': 473, 'n_units_3': 218, 'dropout_rate': 0.2056877773141831, 'learning_rate': 0.0007662057725372583}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:14,905] Trial 24 finished with value: 15.180451154708862 and parameters: {'n_units_1': 767, 'n_units_2': 406, 'n_units_3': 213, 'dropout_rate': 0.20710981358994485, 'learning_rate': 0.0006161707931511348}. Best is trial 23 with value: 13.845865249633789.


Epoch 16, Loss: 44.45021438598633
Epoch 17, Loss: 44.60921096801758
Epoch 18, Loss: 43.14231491088867
Epoch 19, Loss: 42.070526123046875
Epoch 20, Loss: 40.50347137451172
Validation Loss: 25.796993255615234 ; Validation Accuracy: 81.57894736842105%
Epoch 1, Loss: 147.98109436035156
Epoch 2, Loss: 146.18423461914062
Epoch 3, Loss: 144.09825134277344
Epoch 4, Loss: 141.0268096923828
Epoch 5, Loss: 136.7238006591797
Epoch 6, Loss: 130.838623046875
Epoch 7, Loss: 123.00880432128906
Epoch 8, Loss: 113.08414459228516
Epoch 9, Loss: 101.25338745117188
Epoch 10, Loss: 89.56535339355469
Epoch 11, Loss: 79.40513610839844
Epoch 12, Loss: 72.43618774414062
Epoch 13, Loss: 68.70188903808594
Epoch 14, Loss: 65.40442657470703
Epoch 15, Loss: 60.42800521850586
Epoch 16, Loss: 56.27895736694336
Epoch 17, Loss: 52.116424560546875
Epoch 18, Loss: 46.774810791015625
Epoch 19, Loss: 44.71263122558594
Epoch 20, Loss: 44.2366943359375
Validation Loss: 37.34586715698242 ; Validation Accuracy: 81.203007518797%

[I 2024-08-18 14:12:15,005] Trial 25 finished with value: 14.624060153961182 and parameters: {'n_units_1': 872, 'n_units_2': 437, 'n_units_3': 64, 'dropout_rate': 0.1910955153860648, 'learning_rate': 0.0008985179557822178}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:15,111] Trial 26 finished with value: 14.229323625564575 and parameters: {'n_units_1': 753, 'n_units_2': 479, 'n_units_3': 184, 'dropout_rate': 0.23535956650302461, 'learning_rate': 0.0007901857288784087}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:15,200] Trial 27 finished with value: 16.165413856506348 and parameters: {'n_units_1': 600, 'n_units_2': 362, 'n_units_3': 242, 'dropout_rate': 0.28278759695735645, 'learning_rate': 0.0005621039694565741}. Best is trial 23 with value: 13.845865249633789.


Validation Loss: 33.462406158447266 ; Validation Accuracy: 81.95488721804512%
Epoch 1, Loss: 147.8393096923828
Epoch 2, Loss: 145.08282470703125
Epoch 3, Loss: 141.0802459716797
Epoch 4, Loss: 134.81182861328125
Epoch 5, Loss: 125.73936462402344
Epoch 6, Loss: 113.61502838134766
Epoch 7, Loss: 98.67695617675781
Epoch 8, Loss: 82.9126968383789
Epoch 9, Loss: 72.18383026123047
Epoch 10, Loss: 68.8981704711914
Epoch 11, Loss: 66.15701293945312
Epoch 12, Loss: 61.998252868652344
Epoch 13, Loss: 55.769287109375
Epoch 14, Loss: 49.451663970947266
Epoch 15, Loss: 45.4793815612793
Epoch 16, Loss: 43.831748962402344
Epoch 17, Loss: 44.282684326171875
Epoch 18, Loss: 44.1623649597168
Epoch 19, Loss: 43.26701736450195
Epoch 20, Loss: 42.757667541503906
Validation Loss: 28.62782096862793 ; Validation Accuracy: 80.45112781954887%
Epoch 1, Loss: 148.49923706054688
Epoch 2, Loss: 147.0184783935547
Epoch 3, Loss: 145.4735565185547
Epoch 4, Loss: 143.3728485107422
Epoch 5, Loss: 140.5697784423828
Epoch

[I 2024-08-18 14:12:15,293] Trial 28 finished with value: 15.007519245147705 and parameters: {'n_units_1': 844, 'n_units_2': 286, 'n_units_3': 168, 'dropout_rate': 0.3377381901695693, 'learning_rate': 0.0006902703579235714}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:15,376] Trial 29 finished with value: 15.515037775039673 and parameters: {'n_units_1': 268, 'n_units_2': 140, 'n_units_3': 218, 'dropout_rate': 0.39044882403731623, 'learning_rate': 0.0009126021509710084}. Best is trial 23 with value: 13.845865249633789.


Epoch 1, Loss: 147.91940307617188
Epoch 2, Loss: 146.1460418701172
Epoch 3, Loss: 143.97433471679688
Epoch 4, Loss: 140.7979736328125
Epoch 5, Loss: 136.18475341796875
Epoch 6, Loss: 130.29214477539062
Epoch 7, Loss: 122.1753921508789
Epoch 8, Loss: 112.27299499511719
Epoch 9, Loss: 100.80813598632812
Epoch 10, Loss: 87.7558364868164
Epoch 11, Loss: 76.91508483886719
Epoch 12, Loss: 68.23738098144531
Epoch 13, Loss: 63.730411529541016
Epoch 14, Loss: 63.886436462402344
Epoch 15, Loss: 62.76936340332031
Epoch 16, Loss: 59.11530303955078
Epoch 17, Loss: 52.52214050292969
Epoch 18, Loss: 47.34000778198242
Epoch 19, Loss: 46.0941047668457
Epoch 20, Loss: 43.020408630371094
Validation Loss: 26.154136657714844 ; Validation Accuracy: 74.43609022556392%
Epoch 1, Loss: 147.69296264648438
Epoch 2, Loss: 146.23646545410156
Epoch 3, Loss: 144.43203735351562
Epoch 4, Loss: 141.89077758789062
Epoch 5, Loss: 138.5266571044922
Epoch 6, Loss: 134.20016479492188
Epoch 7, Loss: 128.40341186523438
Epoch 8

[I 2024-08-18 14:12:15,467] Trial 30 finished with value: 14.19172978401184 and parameters: {'n_units_1': 741, 'n_units_2': 388, 'n_units_3': 194, 'dropout_rate': 0.18657599479999676, 'learning_rate': 0.000822757613494526}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:15,574] Trial 31 finished with value: 13.969925165176392 and parameters: {'n_units_1': 961, 'n_units_2': 471, 'n_units_3': 239, 'dropout_rate': 0.13790742354430632, 'learning_rate': 0.0007567854166880308}. Best is trial 23 with value: 13.845865249633789.


Epoch 12, Loss: 57.17771911621094
Epoch 13, Loss: 49.845481872558594
Epoch 14, Loss: 45.35295104980469
Epoch 15, Loss: 43.45259475708008
Epoch 16, Loss: 42.93147659301758
Epoch 17, Loss: 43.762840270996094
Epoch 18, Loss: 43.75815200805664
Epoch 19, Loss: 42.6025505065918
Epoch 20, Loss: 41.88019943237305
Validation Loss: 26.894737243652344 ; Validation Accuracy: 81.95488721804512%
Epoch 1, Loss: 147.6494140625
Epoch 2, Loss: 144.2989501953125
Epoch 3, Loss: 139.49581909179688
Epoch 4, Loss: 131.95301818847656
Epoch 5, Loss: 120.9319839477539
Epoch 6, Loss: 106.29332733154297
Epoch 7, Loss: 90.12740325927734
Epoch 8, Loss: 75.99080657958984
Epoch 9, Loss: 70.35255432128906
Epoch 10, Loss: 68.65121459960938
Epoch 11, Loss: 63.283180236816406
Epoch 12, Loss: 55.290740966796875
Epoch 13, Loss: 48.218875885009766
Epoch 14, Loss: 44.36857223510742
Epoch 15, Loss: 43.045955657958984
Epoch 16, Loss: 43.500911712646484
Epoch 17, Loss: 43.24843215942383
Epoch 18, Loss: 43.18359375
Epoch 19, Los

[I 2024-08-18 14:12:15,685] Trial 32 finished with value: 14.417293548583984 and parameters: {'n_units_1': 921, 'n_units_2': 483, 'n_units_3': 223, 'dropout_rate': 0.13365243492976286, 'learning_rate': 0.0006638157178215063}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:15,793] Trial 33 finished with value: 14.082706928253174 and parameters: {'n_units_1': 987, 'n_units_2': 445, 'n_units_3': 242, 'dropout_rate': 0.15533315591939342, 'learning_rate': 0.0007603663885963025}. Best is trial 23 with value: 13.845865249633789.


Epoch 9, Loss: 75.11825561523438
Epoch 10, Loss: 70.31021881103516
Epoch 11, Loss: 67.73579406738281
Epoch 12, Loss: 62.97469711303711
Epoch 13, Loss: 55.84856033325195
Epoch 14, Loss: 49.740638732910156
Epoch 15, Loss: 46.053524017333984
Epoch 16, Loss: 44.071617126464844
Epoch 17, Loss: 43.606971740722656
Epoch 18, Loss: 43.764827728271484
Epoch 19, Loss: 42.96919631958008
Epoch 20, Loss: 41.72493362426758
Validation Loss: 31.40225601196289 ; Validation Accuracy: 83.0827067669173%
Epoch 1, Loss: 148.28213500976562
Epoch 2, Loss: 145.55198669433594
Epoch 3, Loss: 141.834716796875
Epoch 4, Loss: 135.99205017089844
Epoch 5, Loss: 127.21246337890625
Epoch 6, Loss: 114.5943832397461
Epoch 7, Loss: 99.6552734375
Epoch 8, Loss: 83.15141296386719
Epoch 9, Loss: 71.52828979492188
Epoch 10, Loss: 67.5406265258789
Epoch 11, Loss: 67.05863952636719
Epoch 12, Loss: 62.24535369873047
Epoch 13, Loss: 53.668575286865234
Epoch 14, Loss: 47.070411682128906
Epoch 15, Loss: 43.81159210205078
Epoch 16, L

[I 2024-08-18 14:12:15,889] Trial 34 finished with value: 15.090226173400879 and parameters: {'n_units_1': 854, 'n_units_2': 413, 'n_units_3': 236, 'dropout_rate': 0.22176020725355516, 'learning_rate': 0.0006102940330161676}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:15,988] Trial 35 finished with value: 14.161654472351074 and parameters: {'n_units_1': 915, 'n_units_2': 436, 'n_units_3': 219, 'dropout_rate': 0.12463216028534006, 'learning_rate': 0.0007184504823382783}. Best is trial 23 with value: 13.845865249633789.


Epoch 7, Loss: 111.34059143066406
Epoch 8, Loss: 98.80656433105469
Epoch 9, Loss: 85.65757751464844
Epoch 10, Loss: 75.6310043334961
Epoch 11, Loss: 70.60355377197266
Epoch 12, Loss: 66.73993682861328
Epoch 13, Loss: 63.14509582519531
Epoch 14, Loss: 58.466766357421875
Epoch 15, Loss: 53.40532302856445
Epoch 16, Loss: 47.207218170166016
Epoch 17, Loss: 44.783470153808594
Epoch 18, Loss: 44.359580993652344
Epoch 19, Loss: 45.0853385925293
Epoch 20, Loss: 43.728424072265625
Validation Loss: 36.5 ; Validation Accuracy: 82.33082706766918%
Epoch 1, Loss: 147.99269104003906
Epoch 2, Loss: 144.97557067871094
Epoch 3, Loss: 140.62762451171875
Epoch 4, Loss: 133.86106872558594
Epoch 5, Loss: 124.23063659667969
Epoch 6, Loss: 111.04853820800781
Epoch 7, Loss: 95.53446197509766
Epoch 8, Loss: 81.06634521484375
Epoch 9, Loss: 71.87696075439453
Epoch 10, Loss: 69.02584838867188
Epoch 11, Loss: 66.44496154785156
Epoch 12, Loss: 59.856143951416016
Epoch 13, Loss: 52.806121826171875
Epoch 14, Loss: 47

[I 2024-08-18 14:12:16,102] Trial 36 finished with value: 17.206766843795776 and parameters: {'n_units_1': 989, 'n_units_2': 487, 'n_units_3': 200, 'dropout_rate': 0.17404104021489653, 'learning_rate': 0.00041401922408266264}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:16,199] Trial 37 finished with value: 13.939850330352783 and parameters: {'n_units_1': 805, 'n_units_2': 449, 'n_units_3': 246, 'dropout_rate': 0.24760039630142094, 'learning_rate': 0.0009178915984088165}. Best is trial 23 with value: 13.845865249633789.


Epoch 6, Loss: 135.411376953125
Epoch 7, Loss: 130.26930236816406
Epoch 8, Loss: 123.87596893310547
Epoch 9, Loss: 116.21074676513672
Epoch 10, Loss: 107.28915405273438
Epoch 11, Loss: 97.40648651123047
Epoch 12, Loss: 87.18228912353516
Epoch 13, Loss: 78.17741394042969
Epoch 14, Loss: 70.69834899902344
Epoch 15, Loss: 66.16133117675781
Epoch 16, Loss: 64.32281494140625
Epoch 17, Loss: 62.68649673461914
Epoch 18, Loss: 59.702606201171875
Epoch 19, Loss: 56.95219802856445
Epoch 20, Loss: 53.10696029663086
Validation Loss: 29.642858505249023 ; Validation Accuracy: 72.93233082706767%
Epoch 1, Loss: 147.37930297851562
Epoch 2, Loss: 143.4989471435547
Epoch 3, Loss: 137.46380615234375
Epoch 4, Loss: 127.70063781738281
Epoch 5, Loss: 113.56109619140625
Epoch 6, Loss: 95.99002838134766
Epoch 7, Loss: 79.3160400390625
Epoch 8, Loss: 73.39997863769531
Epoch 9, Loss: 68.85797882080078
Epoch 10, Loss: 61.467857360839844
Epoch 11, Loss: 53.330501556396484
Epoch 12, Loss: 47.52616882324219
Epoch 13

[I 2024-08-18 14:12:16,288] Trial 38 finished with value: 21.26691770553589 and parameters: {'n_units_1': 408, 'n_units_2': 490, 'n_units_3': 159, 'dropout_rate': 0.19735252163653177, 'learning_rate': 0.00018408663631715963}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:16,391] Trial 39 finished with value: 14.33458662033081 and parameters: {'n_units_1': 945, 'n_units_2': 422, 'n_units_3': 122, 'dropout_rate': 0.1478856636086748, 'learning_rate': 0.0007866324882441009}. Best is trial 23 with value: 13.845865249633789.


Epoch 7, Loss: 144.45298767089844
Epoch 8, Loss: 143.82113647460938
Epoch 9, Loss: 143.1598663330078
Epoch 10, Loss: 142.26336669921875
Epoch 11, Loss: 141.35353088378906
Epoch 12, Loss: 140.38792419433594
Epoch 13, Loss: 139.2038116455078
Epoch 14, Loss: 137.94021606445312
Epoch 15, Loss: 136.4810333251953
Epoch 16, Loss: 134.97122192382812
Epoch 17, Loss: 133.1448516845703
Epoch 18, Loss: 131.02479553222656
Epoch 19, Loss: 129.0832977294922
Epoch 20, Loss: 126.41742706298828
Validation Loss: 64.26692199707031 ; Validation Accuracy: 81.95488721804512%
Epoch 1, Loss: 147.78858947753906
Epoch 2, Loss: 145.25653076171875
Epoch 3, Loss: 141.5924072265625
Epoch 4, Loss: 135.914794921875
Epoch 5, Loss: 127.78500366210938
Epoch 6, Loss: 116.72920227050781
Epoch 7, Loss: 102.65408325195312
Epoch 8, Loss: 88.23297882080078
Epoch 9, Loss: 76.26192474365234
Epoch 10, Loss: 69.88578796386719
Epoch 11, Loss: 66.34884643554688
Epoch 12, Loss: 62.53811264038086
Epoch 13, Loss: 56.324981689453125
Epo

[I 2024-08-18 14:12:16,487] Trial 40 finished with value: 15.199248313903809 and parameters: {'n_units_1': 713, 'n_units_2': 459, 'n_units_3': 180, 'dropout_rate': 0.20646367487332173, 'learning_rate': 0.0005623029454729281}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:16,587] Trial 41 finished with value: 13.996240854263306 and parameters: {'n_units_1': 789, 'n_units_2': 451, 'n_units_3': 246, 'dropout_rate': 0.24224588928892812, 'learning_rate': 0.0009109052522976318}. Best is trial 23 with value: 13.845865249633789.


Epoch 10, Loss: 86.25177764892578
Epoch 11, Loss: 77.46328735351562
Epoch 12, Loss: 72.93048858642578
Epoch 13, Loss: 68.7155532836914
Epoch 14, Loss: 65.0386962890625
Epoch 15, Loss: 59.409873962402344
Epoch 16, Loss: 55.19928741455078
Epoch 17, Loss: 50.18381881713867
Epoch 18, Loss: 45.94685363769531
Epoch 19, Loss: 44.335357666015625
Epoch 20, Loss: 43.395809173583984
Validation Loss: 34.66541290283203 ; Validation Accuracy: 79.69924812030075%
Epoch 1, Loss: 148.18618774414062
Epoch 2, Loss: 145.19088745117188
Epoch 3, Loss: 140.64999389648438
Epoch 4, Loss: 133.19688415527344
Epoch 5, Loss: 122.11161041259766
Epoch 6, Loss: 106.98046875
Epoch 7, Loss: 88.8171157836914
Epoch 8, Loss: 75.54662322998047
Epoch 9, Loss: 71.00923919677734
Epoch 10, Loss: 68.34019470214844
Epoch 11, Loss: 62.70254135131836
Epoch 12, Loss: 53.202510833740234
Epoch 13, Loss: 46.7927131652832
Epoch 14, Loss: 44.59934616088867
Epoch 15, Loss: 44.705039978027344
Epoch 16, Loss: 44.19007110595703
Epoch 17, Los

[I 2024-08-18 14:12:16,694] Trial 42 finished with value: 13.966165781021118 and parameters: {'n_units_1': 825, 'n_units_2': 428, 'n_units_3': 236, 'dropout_rate': 0.2622910131583069, 'learning_rate': 0.0008748291461518355}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:16,785] Trial 43 finished with value: 13.939850091934204 and parameters: {'n_units_1': 627, 'n_units_2': 497, 'n_units_3': 226, 'dropout_rate': 0.22168679873283684, 'learning_rate': 0.0009329407755920355}. Best is trial 23 with value: 13.845865249633789.


Epoch 9, Loss: 70.26602935791016
Epoch 10, Loss: 68.02181243896484
Epoch 11, Loss: 61.73361587524414
Epoch 12, Loss: 52.67353820800781
Epoch 13, Loss: 46.68519592285156
Epoch 14, Loss: 43.026397705078125
Epoch 15, Loss: 43.62514114379883
Epoch 16, Loss: 44.33431625366211
Epoch 17, Loss: 44.32585144042969
Epoch 18, Loss: 41.65413284301758
Epoch 19, Loss: 40.884605407714844
Epoch 20, Loss: 40.66390609741211
Validation Loss: 26.469924926757812 ; Validation Accuracy: 82.33082706766918%
Epoch 1, Loss: 147.52438354492188
Epoch 2, Loss: 144.29005432128906
Epoch 3, Loss: 139.2900848388672
Epoch 4, Loss: 131.13449096679688
Epoch 5, Loss: 118.56920623779297
Epoch 6, Loss: 101.89459991455078
Epoch 7, Loss: 84.23908996582031
Epoch 8, Loss: 73.80195617675781
Epoch 9, Loss: 68.93741607666016
Epoch 10, Loss: 64.82024383544922
Epoch 11, Loss: 57.25575256347656
Epoch 12, Loss: 49.08807373046875
Epoch 13, Loss: 44.96323013305664
Epoch 14, Loss: 44.12272262573242
Epoch 15, Loss: 44.2748908996582
Epoch 16

[I 2024-08-18 14:12:16,883] Trial 44 finished with value: 14.007519245147705 and parameters: {'n_units_1': 571, 'n_units_2': 495, 'n_units_3': 223, 'dropout_rate': 0.22057750206888815, 'learning_rate': 0.0008353075746292328}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:16,979] Trial 45 finished with value: 14.763158321380615 and parameters: {'n_units_1': 629, 'n_units_2': 389, 'n_units_3': 216, 'dropout_rate': 0.1764565352016194, 'learning_rate': 0.0006961499706518196}. Best is trial 23 with value: 13.845865249633789.


Epoch 15, Loss: 45.09840393066406
Epoch 16, Loss: 44.35466384887695
Epoch 17, Loss: 44.611812591552734
Epoch 18, Loss: 43.60784912109375
Epoch 19, Loss: 44.47015380859375
Epoch 20, Loss: 41.98258590698242
Validation Loss: 28.699249267578125 ; Validation Accuracy: 82.33082706766918%
Epoch 1, Loss: 147.97027587890625
Epoch 2, Loss: 146.10096740722656
Epoch 3, Loss: 143.8246307373047
Epoch 4, Loss: 140.3936309814453
Epoch 5, Loss: 135.39614868164062
Epoch 6, Loss: 128.43687438964844
Epoch 7, Loss: 119.142822265625
Epoch 8, Loss: 107.5685806274414
Epoch 9, Loss: 95.24125671386719
Epoch 10, Loss: 82.04314422607422
Epoch 11, Loss: 73.01713562011719
Epoch 12, Loss: 68.30540466308594
Epoch 13, Loss: 65.17474365234375
Epoch 14, Loss: 62.564002990722656
Epoch 15, Loss: 57.14765930175781
Epoch 16, Loss: 52.100154876708984
Epoch 17, Loss: 47.37710952758789
Epoch 18, Loss: 44.305809020996094
Epoch 19, Loss: 44.25358581542969
Epoch 20, Loss: 43.563114166259766
Validation Loss: 33.09022521972656 ; Va

[I 2024-08-18 14:12:17,086] Trial 46 finished with value: 13.992481708526611 and parameters: {'n_units_1': 650, 'n_units_2': 496, 'n_units_3': 190, 'dropout_rate': 0.12731359438677237, 'learning_rate': 0.0009475235349649983}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:17,184] Trial 47 finished with value: 14.89097785949707 and parameters: {'n_units_1': 525, 'n_units_2': 337, 'n_units_3': 210, 'dropout_rate': 0.4256116224261678, 'learning_rate': 0.0007503017970822055}. Best is trial 23 with value: 13.845865249633789.


Epoch 18, Loss: 40.84159851074219
Epoch 19, Loss: 39.783077239990234
Epoch 20, Loss: 38.6664924621582
Validation Loss: 26.330827713012695 ; Validation Accuracy: 81.57894736842105%
Epoch 1, Loss: 148.5465850830078
Epoch 2, Loss: 146.63880920410156
Epoch 3, Loss: 144.54747009277344
Epoch 4, Loss: 141.28372192382812
Epoch 5, Loss: 137.03524780273438
Epoch 6, Loss: 130.82058715820312
Epoch 7, Loss: 122.3909912109375
Epoch 8, Loss: 111.75435638427734
Epoch 9, Loss: 100.4947738647461
Epoch 10, Loss: 86.3582763671875
Epoch 11, Loss: 74.31959533691406
Epoch 12, Loss: 68.49998474121094
Epoch 13, Loss: 65.14684295654297
Epoch 14, Loss: 63.55904006958008
Epoch 15, Loss: 59.59033966064453
Epoch 16, Loss: 55.96901321411133
Epoch 17, Loss: 51.7414665222168
Epoch 18, Loss: 48.082733154296875
Epoch 19, Loss: 45.26488494873047
Epoch 20, Loss: 43.849212646484375
Validation Loss: 28.060150146484375 ; Validation Accuracy: 77.81954887218046%
Epoch 1, Loss: 147.76905822753906
Epoch 2, Loss: 145.504028320312

[I 2024-08-18 14:12:17,285] Trial 48 finished with value: 14.665414094924927 and parameters: {'n_units_1': 716, 'n_units_2': 478, 'n_units_3': 229, 'dropout_rate': 0.2765806665825521, 'learning_rate': 0.0006338931012202939}. Best is trial 23 with value: 13.845865249633789.
[I 2024-08-18 14:12:17,382] Trial 49 finished with value: 14.778195858001709 and parameters: {'n_units_1': 890, 'n_units_2': 227, 'n_units_3': 200, 'dropout_rate': 0.16112352979243888, 'learning_rate': 0.0008508074629990699}. Best is trial 23 with value: 13.845865249633789.


Epoch 19, Loss: 44.405921936035156
Epoch 20, Loss: 43.14606857299805
Validation Loss: 33.582706451416016 ; Validation Accuracy: 80.82706766917293%
Epoch 1, Loss: 148.00987243652344
Epoch 2, Loss: 145.9523468017578
Epoch 3, Loss: 143.05621337890625
Epoch 4, Loss: 138.46690368652344
Epoch 5, Loss: 131.88783264160156
Epoch 6, Loss: 122.79413604736328
Epoch 7, Loss: 110.95260620117188
Epoch 8, Loss: 97.650390625
Epoch 9, Loss: 85.18917846679688
Epoch 10, Loss: 76.16691589355469
Epoch 11, Loss: 70.82804107666016
Epoch 12, Loss: 65.870361328125
Epoch 13, Loss: 60.05363845825195
Epoch 14, Loss: 55.115089416503906
Epoch 15, Loss: 49.256690979003906
Epoch 16, Loss: 46.62143325805664
Epoch 17, Loss: 44.803955078125
Epoch 18, Loss: 44.13697052001953
Epoch 19, Loss: 44.5294303894043
Epoch 20, Loss: 43.32481384277344
Validation Loss: 36.97744369506836 ; Validation Accuracy: 84.21052631578947%
Best hyperparameters:  {'n_units_1': 886, 'n_units_2': 473, 'n_units_3': 218, 'dropout_rate': 0.20568777731

Wow, that's crazy! Let's sum them up:

In [178]:
medal_manual_sum_2 = torch.sum(test_outputs_2[:,1:], dim=1)
sum_correct = torch.sum(torch.abs(medal_manual_sum_2 - y_test_tensor_2[:,0]) <= torch.max(tolerance*y_test_tensor_2[:,0], torch.tensor(7.0).to(device)))
sum_accuracy = sum_correct.item() / len(y_test_2) * 100

print(f"Test Accuracy for Manual Medal Total in Test Case 2: {sum_accuracy}%")

Test Accuracy for Manual Medal Total in Test Case 2: 91.66666666666666%


And for the medal table...

In [180]:
test_countries_2 = np.array(test_countries_2)  # Convert to a numpy array if it's not already

# Create a DataFrame from the test predictions
predictions_case_2_df = pd.DataFrame({
    'Country': test_countries_2,
    'Predicted Total': test_outputs_2[:, 0].cpu().numpy(),
    'Predicted Gold': test_outputs_2[:, 1].cpu().numpy(),
    'Predicted Silver': test_outputs_2[:, 2].cpu().numpy(),
    'Predicted Bronze': test_outputs_2[:, 3].cpu().numpy(),
    'Total Predicted': medal_manual_sum_2.cpu().numpy()
})

predictions_case_2_df

Unnamed: 0,Country,Predicted Total,Predicted Gold,Predicted Silver,Predicted Bronze,Total Predicted
0,CHN,68.0,26.0,21.0,22.0,69.0
1,USA,97.0,42.0,29.0,28.0,99.0
2,JPN,22.0,9.0,7.0,8.0,24.0
3,AUS,21.0,8.0,7.0,8.0,23.0
4,FRA,27.0,9.0,10.0,10.0,29.0
...,...,...,...,...,...,...
199,CHA,1.0,0.0,0.0,0.0,0.0
200,COM,0.0,0.0,0.0,0.0,0.0
201,GUM,0.0,0.0,0.0,0.0,0.0
202,SAM,1.0,0.0,0.0,0.0,0.0


Compare to the actual results:

In [111]:
actual_df_2 = pd.DataFrame({
    'Country': test_countries_2,
    'Actual Total': y_test_2['Total'].values,
    'Actual Gold': y_test_2['Gold'].values,
    'Actual Silver': y_test_2['Silver'].values,
    'Actual Bronze': y_test_2['Bronze'].values
})

actual_df_2

Unnamed: 0,Country,Actual Total,Actual Gold,Actual Silver,Actual Bronze
0,CHN,91.0,40.0,27.0,24.0
1,USA,126.0,40.0,44.0,42.0
2,JPN,45.0,20.0,12.0,13.0
3,AUS,53.0,18.0,19.0,16.0
4,FRA,64.0,16.0,26.0,22.0
...,...,...,...,...,...
201,FSM,0.0,0.0,0.0,0.0
202,NRU,0.0,0.0,0.0,0.0
203,PLW,0.0,0.0,0.0,0.0
204,COL,4.0,0.0,3.0,1.0


And that's it for today. See you in 2028(or 2026 but then you'll have to go up and swap "summer" out for "winter")!