# ATP Tennis Analysis

In [1]:
import pandas as pd
import numpy as np
import re
from glob import glob
from tqdm import tqdm

## Data Loading

In [2]:
# Load all data from 2016-2019
years_to_consider = ["2010", "2011", "2012", "2013", "2014", "2015", "2016", "2017", "2018", "2019"]
all_files = ["tennis_atp_data/atp_matches_" + year + ".csv" for year in years_to_consider]

In [3]:
df_all_matches = pd.concat((pd.read_csv(f) for f in all_files), ignore_index=True)

## Data Set Construction

We need 3 different sources of data. For each of the 2 players, we need:
1. All past matches they have played
2. The most recent 3 matches against the same handed opponent on the same court

We also need past H2H data which will be the same for both players and thus does not need to duplicated.

The data should then be stored in data structures with efficient lookup for fast training time.

### Data Cleaning

First, to clean the data:
1. Remove names of tournament and players - these are captured by ids already
2. Remove the match number
3. Remove draw size
4. Remove the IOC (Country of origin)
5. Convert seed to seed score (1/s if s is not NaN, otherwise 0)
6. Convert score to games won for each player
7. Convert entry (take lowercase) into OHE
8. Remove minutes
9. Remove rank
10. Drop best of
11. Drop round
13. Drop tourney level
14. Drop height and age

In [4]:
# 1.
df_all_matches.drop(["tourney_name", "winner_name", "loser_name"], axis = 1, inplace = True)

In [5]:
# 2.
df_all_matches.drop(["match_num"], axis = 1, inplace = True)

In [6]:
# 3.
df_all_matches.drop(["draw_size"], axis = 1, inplace = True)

In [7]:
# 4.
df_all_matches.drop(["winner_ioc", "loser_ioc"], axis = 1, inplace = True)

In [8]:
# 5.

def calculate_seed_score(seed):
    if np.isnan(seed):
        return 0
    return 1/seed

df_all_matches["winner_seed_score"] = df_all_matches["winner_seed"].apply(calculate_seed_score)
df_all_matches["loser_seed_score"] = df_all_matches["loser_seed"].apply(calculate_seed_score)

df_all_matches.drop(["winner_seed", "loser_seed"], axis = 1, inplace = True)

In [9]:
# 6.
def remove_tiebreak_scores(score_str):
    # This removes any parenthetical group like (8), (10), etc.
    cleaned = re.sub(r'\(\d+\)', '', score_str)
    
    # Remove 'RET', 'W/O', 'WO', 'DEF', case-insensitively and any extra whitespace
    cleaned_2 = re.sub(r'\b(RET|W\/O|WO|DEF|Def|Abandoned|and|Played|Default|ABD)\b\.?', '', cleaned, flags=re.IGNORECASE)
    
    # Normalize spaces
    cleaned_3 = ' '.join(cleaned_2.split())
    
    return cleaned_3

def winner_loser_games_won(string_score):
    winner_games = 0
    loser_games = 0
    cleaned_score = remove_tiebreak_scores(string_score)
    sets = cleaned_score.split(" ")
    
    for set in sets:
        if set.startswith("["):
            winner_games += 1
        elif set == "" or set == "Walkover" or set == "":
            pass
        else:
            winner, loser = set.split("-")
            winner_games += float(winner)
            loser_games += float(loser)
    return winner_games, loser_games

df_all_matches[["w_games_won", "l_games_won"]] = df_all_matches["score"].apply(lambda x: pd.Series(winner_loser_games_won(x)))
df_all_matches.drop(["score"], axis = 1, inplace = True)

In [10]:
# 7.

def lower_if_not_nan(entry):
    if type(entry) == float:
        return entry
    return entry.lower()

df_all_matches["winner_entry"] = df_all_matches["winner_entry"].apply(lower_if_not_nan)
df_all_matches["loser_entry"] = df_all_matches["loser_entry"].apply(lower_if_not_nan)
df_all_matches = pd.get_dummies(df_all_matches, columns = ["winner_entry", "loser_entry"], dtype = float)
df_all_matches.drop("loser_entry_s", axis = 1, inplace = True)

In [11]:
# 8.
df_all_matches.drop(["minutes"], axis = 1, inplace = True)

In [12]:
# 9.
df_all_matches.drop(["winner_rank", "loser_rank"], axis = 1, inplace = True)

In [13]:
# 10.
df_all_matches.drop(["best_of"], axis = 1, inplace = True)

In [14]:
# 11.
df_all_matches.drop(["round"], axis = 1, inplace = True)

In [15]:
# 13.
df_all_matches.drop(["tourney_level"], axis = 1, inplace = True)

In [16]:
# 14.
df_all_matches.drop(["winner_age", "winner_ht", "loser_age", "loser_ht"], axis = 1, inplace = True)

In [17]:
# For hand, we will use to create the past X match performance against
# opponents of the same hand then drop that column

In [18]:
df_all_matches.columns

Index(['tourney_id', 'surface', 'tourney_date', 'winner_id', 'winner_hand',
       'loser_id', 'loser_hand', 'w_ace', 'w_df', 'w_svpt', 'w_1stIn',
       'w_1stWon', 'w_2ndWon', 'w_SvGms', 'w_bpSaved', 'w_bpFaced', 'l_ace',
       'l_df', 'l_svpt', 'l_1stIn', 'l_1stWon', 'l_2ndWon', 'l_SvGms',
       'l_bpSaved', 'l_bpFaced', 'winner_rank_points', 'loser_rank_points',
       'winner_seed_score', 'loser_seed_score', 'w_games_won', 'l_games_won',
       'winner_entry_alt', 'winner_entry_ll', 'winner_entry_pr',
       'winner_entry_q', 'winner_entry_se', 'winner_entry_wc',
       'loser_entry_alt', 'loser_entry_ll', 'loser_entry_pr', 'loser_entry_q',
       'loser_entry_se', 'loser_entry_wc'],
      dtype='object')

We can now split the data into train, validation, test - 70%, 15%, 15% chronologically

In [19]:
df_all_matches_train = df_all_matches[:20578]
df_all_matches_validation = df_all_matches[20578:24987]
df_all_matches_test = df_all_matches[24987:]

Fit a normalizer to the training data, then also normalize the validation and test sets with the same normalizer

In [20]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import MinMaxScaler

In [21]:
# Columns to normalize
columns_to_normalize = ['w_ace', 'w_df', 'w_svpt', 'w_1stIn',
       'w_1stWon', 'w_2ndWon', 'w_SvGms', 'w_bpSaved', 'w_bpFaced', 'l_ace',
       'l_df', 'l_svpt', 'l_1stIn', 'l_1stWon', 'l_2ndWon', 'l_SvGms',
       'l_bpSaved', 'l_bpFaced', 'winner_rank_points', 'loser_rank_points',
       'winner_seed_score', 'loser_seed_score', 'w_games_won', 'l_games_won']

# Create the normalizer (fit only on training set)
normalizer = ColumnTransformer(
    transformers=[
        ('norm', MinMaxScaler(), columns_to_normalize)
    ],
    remainder='passthrough'  # Keep non-normalized columns
)
normalizer.fit(df_all_matches_train)

The format of the columns of the 'remainder' transformer in ColumnTransformer.transformers_ will change in version 1.7 to match the format of the other transformers.
At the moment the remainder columns are stored as indices (of type int). With the same ColumnTransformer configuration, in the future they will be stored as column names (of type str).



In [22]:
X_train_norm = normalizer.transform(df_all_matches_train)
new_columns = columns_to_normalize + [col for col in df_all_matches_train.columns if col not in columns_to_normalize]
df_all_matches_train = pd.DataFrame(X_train_norm, columns=new_columns)

In [23]:
X_validation_norm = normalizer.transform(df_all_matches_validation)
new_columns = columns_to_normalize + [col for col in df_all_matches_validation.columns if col not in columns_to_normalize]
df_all_matches_validation = pd.DataFrame(X_validation_norm, columns=new_columns)

In [24]:
X_test_norm = normalizer.transform(df_all_matches_test)
new_columns = columns_to_normalize + [col for col in df_all_matches_test.columns if col not in columns_to_normalize]
df_all_matches_test = pd.DataFrame(X_test_norm, columns=new_columns)

In [25]:
df_all_matches_train.columns

Index(['w_ace', 'w_df', 'w_svpt', 'w_1stIn', 'w_1stWon', 'w_2ndWon', 'w_SvGms',
       'w_bpSaved', 'w_bpFaced', 'l_ace', 'l_df', 'l_svpt', 'l_1stIn',
       'l_1stWon', 'l_2ndWon', 'l_SvGms', 'l_bpSaved', 'l_bpFaced',
       'winner_rank_points', 'loser_rank_points', 'winner_seed_score',
       'loser_seed_score', 'w_games_won', 'l_games_won', 'tourney_id',
       'surface', 'tourney_date', 'winner_id', 'winner_hand', 'loser_id',
       'loser_hand', 'winner_entry_alt', 'winner_entry_ll', 'winner_entry_pr',
       'winner_entry_q', 'winner_entry_se', 'winner_entry_wc',
       'loser_entry_alt', 'loser_entry_ll', 'loser_entry_pr', 'loser_entry_q',
       'loser_entry_se', 'loser_entry_wc'],
      dtype='object')

## Input Match

Set up a dataset to represent the data for an upcoming match

In [26]:
df_current_match_train = df_all_matches_train.copy()
df_current_match_valid = df_all_matches_validation.copy()
df_current_match_test = df_all_matches_test.copy()

In [27]:
def preprocess_current_match(df_current_match):
    # Drop all columns relating to the match performance itself
    df_current_match.drop(['w_ace', 'w_df',
           'w_svpt', 'w_1stIn', 'w_1stWon', 'w_2ndWon', 'w_SvGms', 'w_bpSaved',
           'w_bpFaced', 'l_ace', 'l_df', 'l_svpt', 'l_1stIn', 'l_1stWon',
           'l_2ndWon', 'l_SvGms', 'l_bpSaved', 'l_bpFaced', 'w_games_won', 'l_games_won'], axis = 1, inplace = True)
    # Rename winner and loser to player A and player B and randomly shuffle so that sometimes
    # player A wins and sometime player B wins
    
    # Generate random mask for shuffling
    mask = np.random.rand(len(df_current_match)) < 0.5
    
    # Define base stat fields
    stat_fields = ['id', 'hand', "seed_score", 
                   "entry_alt", "entry_ll",
                   "entry_pr", "entry_q",
                   "entry_se", "entry_wc",
                   "rank_points"
                  ]
    
    for field in stat_fields:
        df_current_match[f'player_A_{field}'] = np.where(mask, df_current_match[f'winner_{field}'], df_current_match[f'loser_{field}'])
        df_current_match[f'player_B_{field}'] = np.where(mask, df_current_match[f'loser_{field}'], df_current_match[f'winner_{field}'])
    
    # Add new column to indicate winner ('A' or 'B')
    df_current_match['winner_label'] = np.where(mask, 'A', 'B')
    
    # Drop original columns
    cols_to_drop = [f'winner_{f}' for f in stat_fields] + [f'loser_{f}' for f in stat_fields]
    df_current_match = df_current_match.drop(columns=cols_to_drop)
    return df_current_match


In [28]:
df_current_match_train = preprocess_current_match(df_current_match_train)
df_current_match_valid = preprocess_current_match(df_current_match_valid)
df_current_match_test = preprocess_current_match(df_current_match_test)

In [29]:
df_current_match_train.columns

Index(['tourney_id', 'surface', 'tourney_date', 'player_A_id', 'player_B_id',
       'player_A_hand', 'player_B_hand', 'player_A_seed_score',
       'player_B_seed_score', 'player_A_entry_alt', 'player_B_entry_alt',
       'player_A_entry_ll', 'player_B_entry_ll', 'player_A_entry_pr',
       'player_B_entry_pr', 'player_A_entry_q', 'player_B_entry_q',
       'player_A_entry_se', 'player_B_entry_se', 'player_A_entry_wc',
       'player_B_entry_wc', 'player_A_rank_points', 'player_B_rank_points',
       'winner_label'],
      dtype='object')

## All Past Performances

In [30]:
def process_past_performance_df(df, prefix, won):
    columns_to_keep = ['ace', 'df', 'svpt', '1stIn', '1stWon',
       '2ndWon', 'SvGms', 'bpSaved', 'bpFaced', "games_won"]
    df_processed = df[[f"{prefix}{col}" for col in columns_to_keep] + ["tourney_date"]].copy()
    df_processed.rename(columns = lambda x: x.replace(prefix, ""), inplace = True)
    df_processed["won"] = float(won)
    return df_processed

In [31]:
def all_past_performances(player_id: int, tourney_date: int, df):
    df_wins = df[(df['winner_id'] == player_id) & (df['tourney_date'] < tourney_date)].copy()
    df_loses = df[(df['loser_id'] == player_id) & (df['tourney_date'] < tourney_date)].copy()

    df_wins_processed = process_past_performance_df(df_wins, "w_", True)
    df_loses_processed = process_past_performance_df(df_loses, "l_", False)

    df_total = pd.concat([df_wins_processed, df_loses_processed], ignore_index = True)

    # Drop rows with incomplete data
    df_total.dropna(inplace = True)
    return df_total

In [32]:
all_past_performances(104745, 20190920, df_all_matches_train)

Unnamed: 0,ace,df,svpt,1stIn,1stWon,2ndWon,SvGms,bpSaved,bpFaced,games_won,tourney_date,won
0,0.044248,0.0,0.10998,0.099723,0.089041,0.158537,0.1,0.125,0.1,0.130435,20100104,1.0
1,0.026549,0.043478,0.101833,0.088643,0.089041,0.109756,0.088889,0.0,0.0,0.130435,20100104,1.0
2,0.017699,0.0,0.038697,0.033241,0.037671,0.060976,0.044444,0.0,0.0,0.086957,20100104,1.0
3,0.00885,0.0,0.09165,0.077562,0.078767,0.109756,0.088889,0.0,0.033333,0.130435,20100104,1.0
4,0.097345,0.043478,0.211813,0.199446,0.188356,0.219512,0.166667,0.25,0.233333,0.206522,20100118,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...
489,0.048544,0.047619,0.216769,0.25,0.172535,0.108911,0.153846,0.32,0.419355,0.131868,20160808,0.0
490,0.009709,0.238095,0.100204,0.10061,0.073944,0.019802,0.087912,0.12,0.258065,0.043956,20160815,0.0
491,0.009709,0.190476,0.306748,0.271341,0.225352,0.29703,0.252747,0.36,0.451613,0.252747,20160829,0.0
492,0.0,0.095238,0.132924,0.146341,0.088028,0.069307,0.098901,0.4,0.483871,0.065934,20161003,0.0


## Past 5 Performance on Same Court and Handed-Opponent

For each player, we want to filter for the past 5 performances against the same handed opponent on the same surface.

In [33]:
def past_5_performance_same_surface(player_id: int, tourney_date: int, surface: str, opponent_hand: str, df):
    df_wins = df[(df['winner_id'] == player_id) & (df['tourney_date'] < tourney_date)].copy()
    df_loses = df[(df['loser_id'] == player_id) & (df['tourney_date'] < tourney_date)].copy()

    df_wins_same_surface_hand = df_wins[(df_wins["surface"] == surface) & (df_wins["loser_hand"] == opponent_hand)].copy()
    df_loses_same_surface_hand = df_loses[(df_loses["surface"] == surface) & (df_loses["winner_hand"] == opponent_hand)].copy()

    df_wins_processed = process_past_performance_df(df_wins_same_surface_hand, "w_", True)
    df_loses_processed = process_past_performance_df(df_loses_same_surface_hand, "l_", False)

    df_total = pd.concat([df_wins_processed, df_loses_processed], ignore_index = True)

    # Drop rows with incomplete data
    df_total.dropna(inplace = True)

    # Sort by tourney_date
    df_total.sort_values(by='tourney_date', ascending=False, inplace=True)
    return df_total[:5]

In [34]:
past_5_performance_same_surface(104745, 20190920, "Clay", "L", df_all_matches_train)

Unnamed: 0,ace,df,svpt,1stIn,1stWon,2ndWon,SvGms,bpSaved,bpFaced,games_won,tourney_date,won
19,0.0,0.086957,0.14664,0.155125,0.133562,0.109756,0.133333,0.291667,0.3,0.195652,20160523,1.0
18,0.0,0.304348,0.169043,0.182825,0.133562,0.085366,0.122222,0.125,0.233333,0.141304,20150727,1.0
17,0.017699,0.0,0.162933,0.160665,0.126712,0.158537,0.122222,0.291667,0.3,0.163043,20150727,1.0
16,0.0,0.086957,0.077393,0.066482,0.068493,0.109756,0.077778,0.0,0.0,0.130435,20150223,1.0
15,0.017699,0.0,0.087576,0.085873,0.068493,0.097561,0.088889,0.0,0.066667,0.130435,20150216,1.0


## Past H2H Performance

We also want the past H2H performances between the 2 players

In [35]:
def process_h2h(df, winner: str, loser: str):
    performance_columns_to_keep = ['ace', 'df', 'svpt', '1stIn', '1stWon',
       '2ndWon', 'SvGms', 'bpSaved', 'bpFaced', "games_won"]

    df_processed_winner = df[[f"w_{col}" for col in performance_columns_to_keep]].copy()
    df_processed_winner.rename(columns = lambda x: x.replace("w_", winner+"_"), inplace = True)

    df_processed_loser = df[[f"l_{col}" for col in performance_columns_to_keep]].copy()
    df_processed_loser.rename(columns = lambda x: x.replace("l_", loser+"_"), inplace = True)


    df_extra_columns = df[["surface", "tourney_date"]].copy()
    
    df_processed_total = pd.concat([df_processed_winner, df_processed_loser, df_extra_columns], axis = 1)
    df_processed_total["winner"] = float(winner == "A")
    df_processed_total.dropna(inplace = True)
    return df_processed_total

In [36]:
def get_head_to_head(df, player_A_id, player_B_id, tourney_date):
    
    df_player_A_won = df[
        ((df['winner_id'] == player_A_id) & (df['loser_id'] == player_B_id)) &
        (df["tourney_date"] < tourney_date)
        ].copy()
    df_player_A_won_processed = process_h2h(df_player_A_won, "A", "B")
    
    df_player_B_won = df[
        ((df['winner_id'] == player_B_id) & (df['loser_id'] == player_A_id)) &
        (df["tourney_date"] < tourney_date)
        ].copy()
    df_player_B_won_processed = process_h2h(df_player_B_won, "B", "A")

    return pd.concat([df_player_A_won_processed, df_player_B_won_processed])


In [37]:
get_head_to_head(df_all_matches_train, 106233, 133430, 20190920)

Unnamed: 0,A_ace,A_df,A_svpt,A_1stIn,A_1stWon,A_2ndWon,A_SvGms,A_bpSaved,A_bpFaced,A_games_won,...,B_1stIn,B_1stWon,B_2ndWon,B_SvGms,B_bpSaved,B_bpFaced,B_games_won,surface,tourney_date,winner


## Single data point

For a given row of the Input Match data frame, construct a single data point ready for modelling

In [38]:
df_current_match_train.iloc[0]

tourney_id              2010-339
surface                     Hard
tourney_date            20100103
player_A_id               104053
player_B_id               103429
player_A_hand                  R
player_B_hand                  R
player_A_seed_score          1.0
player_B_seed_score          0.0
player_A_entry_alt           0.0
player_B_entry_alt           0.0
player_A_entry_ll            0.0
player_B_entry_ll            0.0
player_A_entry_pr            0.0
player_B_entry_pr            0.0
player_A_entry_q             0.0
player_B_entry_q             0.0
player_A_entry_se            0.0
player_B_entry_se            0.0
player_A_entry_wc            0.0
player_B_entry_wc            0.0
player_A_rank_points    0.260133
player_B_rank_points    0.035223
winner_label                   A
Name: 0, dtype: object

In [39]:
def single_data_point(df_single_row):
    temp = df_single_row
    player_a_past_performance = all_past_performances(temp["player_A_id"], temp["tourney_date"], df_all_matches_train).drop(["tourney_date"], axis = 1)
    player_a_past_5_performance = past_5_performance_same_surface(
        temp["player_A_id"], temp["tourney_date"], temp["surface"], temp["player_B_hand"], df_all_matches_train).drop(["tourney_date"], axis = 1)
    
    player_b_past_performance = all_past_performances(temp["player_B_id"], temp["tourney_date"], df_all_matches_train).drop(["tourney_date"], axis = 1)
    player_b_past_5_performance = past_5_performance_same_surface(
        temp["player_B_id"], temp["tourney_date"], temp["surface"], temp["player_A_hand"], df_all_matches_train).drop(["tourney_date"], axis = 1)

    past_h2h_data = get_head_to_head(
        df_all_matches_train, temp["player_A_id"], temp["player_B_id"], temp["tourney_date"]
    ).drop(["tourney_date", "surface"], axis = 1)

    temp.drop(["tourney_id", "surface", "tourney_date", "player_A_id", "player_B_id", "player_A_hand", "player_B_hand"], inplace = True)
    temp["winner_label"] = float(temp["winner_label"] == "A")

    return [list(temp), player_a_past_performance, player_b_past_performance, player_a_past_5_performance, player_b_past_5_performance, past_h2h_data]

In [40]:
single_data_point(df_current_match_train.iloc[20000])

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  temp.drop(["tourney_id", "surface", "tourney_date", "player_A_id", "player_B_id", "player_A_hand", "player_B_hand"], inplace = True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  temp["winner_label"] = float(temp["winner_label"] == "A")


[[1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.10820697386276476,
  0.007257065313587822,
  1.0],
           ace        df      svpt     1stIn    1stWon    2ndWon     SvGms  \
 0    0.115044  0.086957  0.179226  0.163435  0.150685  0.170732  0.166667   
 1    0.088496  0.173913  0.244399  0.163435  0.143836  0.426829  0.188889   
 2    0.035398   0.26087  0.183299  0.144044  0.150685  0.219512  0.166667   
 3    0.088496  0.173913   0.14664  0.102493  0.113014  0.231707  0.111111   
 4    0.106195  0.173913  0.169043  0.119114  0.116438  0.268293  0.122222   
 ..        ...       ...       ...       ...       ...       ...       ...   
 373  0.097087  0.047619  0.153374  0.161585  0.137324  0.089109  0.120879   
 374  0.058252  0.047619  0.237219   0.22561  0.158451  0.188119  0.153846   
 375  0.058252  0.190476  0.132924  0.115854  0.091549  0.148515   0.10989   
 376  0.116505  0.095238  0.167689   0.14939  0.123239  0.188119  0

# Model Building

In [41]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Masking, GlobalAveragePooling1D, Concatenate
from tensorflow.keras.losses import BinaryCrossentropy

# === Define input shapes ===
X1 = 11   # Features in time series 1
X2 = 11   # Features in time series 2
dense_input_size = 16

# Inputs (variable-length)
input_ts1 = Input(shape=(None, X1), name='player_a_past_performance')
input_ts2 = Input(shape=(None, X2), name='player_b_past_performance')
input_tab1 = Input(shape=(None, 11), name='player_a_past_5_performance')  # 0–5 rows
input_tab2 = Input(shape=(None, 11), name='player_b_past_5_performance')  # 0–5 rows
input_tab3 = Input(shape=(None, 21), name='h2h_data')  # 0–3 rows
input_dense = Input(shape=(dense_input_size,), name='single_row_16_features')


# === Branch for time series 1 ===
masked_ts1 = Masking()(input_ts1)
lstm_ts1 = LSTM(64)(masked_ts1)

# === Branch for time series 2 ===
masked_ts2 = Masking()(input_ts2)
lstm_ts2 = LSTM(64)(masked_ts2)

# === Branch for tabular data 1 ===
masked_tab1 = Masking()(input_tab1)
dense_tab1 = Dense(32, activation='relu')(masked_tab1)
pooled_tab1 = GlobalAveragePooling1D()(dense_tab1)

# === Branch for tabular data 2 ===
masked_tab2 = Masking()(input_tab2)
dense_tab2 = Dense(32, activation='relu')(masked_tab2)
pooled_tab2 = GlobalAveragePooling1D()(dense_tab2)

# === Branch for tabular data 3 ===
masked_tab3 = Masking()(input_tab3)
dense_tab3 = Dense(32, activation='relu')(masked_tab3)
pooled_tab3 = GlobalAveragePooling1D()(dense_tab3)

dense_branch = Dense(32, activation='relu')(input_dense)

# === Merge all branches ===
merged = Concatenate()([lstm_ts1, lstm_ts2, pooled_tab1, pooled_tab2, pooled_tab3, dense_branch])
dense = Dense(64, activation='relu')(merged)
output = Dense(1, activation='sigmoid')(dense)

model = Model(inputs=[input_ts1, input_ts2, input_tab1, input_tab2, input_tab3, input_dense], outputs=output)
model.compile(optimizer='adam', loss=BinaryCrossentropy(), metrics=['accuracy'])

model.summary()


2025-06-05 18:54:03.080512: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [42]:
def train_model_samplewise(model, train_data, epochs=1):
    for epoch in range(epochs):
        total_loss = 0.0
        sample_count = 0

        print(f"\n=== Epoch {epoch + 1} ===")
        
        for i, sample in tqdm(train_data.iterrows(), total=len(train_data)):
            try:
                dense_input, ts1, ts2, tab1, tab2, tab3 = single_data_point(sample)

                # Skip incomplete inputs
                if tab1.shape[0] == 0 or tab2.shape[0] == 0 or tab3.shape[0] == 0:
                    continue

                # Ensure all inputs are properly shaped float32 arrays
                ts1 = np.expand_dims(np.array(ts1, dtype=np.float32), axis=0)
                ts2 = np.expand_dims(np.array(ts2, dtype=np.float32), axis=0)
                tab1 = np.expand_dims(np.array(tab1, dtype=np.float32), axis=0)
                tab2 = np.expand_dims(np.array(tab2, dtype=np.float32), axis=0)
                tab3 = np.expand_dims(np.array(tab3, dtype=np.float32), axis=0)

                dense_input_array = np.array(dense_input, dtype=np.float32)
                target_value = float(dense_input_array[-1])  # Ensure it's a scalar float

                dense_input_trimmed = np.expand_dims(dense_input_array[:-1], axis=0)
                target = np.array([[target_value]], dtype=np.float32)

                inputs = [ts1, ts2, tab1, tab2, tab3, dense_input_trimmed]
                for input_arr in [ts1, ts2, tab1, tab2, tab3, dense_input_trimmed]:
                    assert not np.isnan(input_arr).any(), "NaN in input"
                    assert not np.isinf(input_arr).any(), "Inf in input"

                # === Train step ===
                loss = model.train_on_batch(inputs, target)

                # # Log raw loss
                # print(f"Sample {i} — Raw loss return: {loss}")

                # If using metrics, loss is a list
                if isinstance(loss, (list, tuple)):
                    loss_value = loss[1]
                else:
                    loss_value = loss

                total_loss += loss_value
                sample_count += 1

                # Optional: print a prediction every 10 samples
                if i % 10 == 0:
                    pred = model.predict(inputs, verbose=0)
                    print(f"  ➤ Predicted: {pred[0][0]:.4f}, Target: {target_value:.1f}, Loss: {loss_value:.4f}")

            except Exception as e:
                print(f"Error on sample {i}: {e}")
                continue

        # if sample_count == 0:
        #     print("⚠️ No valid samples used this epoch.")
        #     continue

        avg_loss = total_loss / sample_count
        print(f"✅ Epoch {epoch + 1} — Trained on {sample_count} samples — Avg Loss: {avg_loss:.4f}")


In [43]:
train_model_samplewise(model, df_current_match_train[-8000:])


=== Epoch 1 ===


  0%|                                          | 8/8000 [00:07<1:41:49,  1.31it/s]



  0%|                                           | 21/8000 [00:08<14:39,  9.07it/s]



  0%|                                           | 23/8000 [00:10<45:07,  2.95it/s]

  ➤ Predicted: 0.4322, Target: 1.0, Loss: 0.5714


  0%|▏                                        | 33/8000 [00:16<1:11:17,  1.86it/s]

  ➤ Predicted: 0.4192, Target: 0.0, Loss: 0.6429


  1%|▎                                        | 63/8000 [00:30<1:02:53,  2.10it/s]

  ➤ Predicted: 0.3804, Target: 1.0, Loss: 0.6129


  1%|▍                                        | 93/8000 [00:44<1:23:23,  1.58it/s]

  ➤ Predicted: 0.3888, Target: 0.0, Loss: 0.6327


  1%|▌                                       | 103/8000 [00:50<1:18:33,  1.68it/s]

  ➤ Predicted: 0.4743, Target: 1.0, Loss: 0.5965


  2%|▌                                       | 123/8000 [01:00<1:13:34,  1.78it/s]

  ➤ Predicted: 0.5516, Target: 1.0, Loss: 0.5942


  2%|▊                                       | 153/8000 [01:13<1:13:13,  1.79it/s]

  ➤ Predicted: 0.5225, Target: 1.0, Loss: 0.5783


  2%|█                                         | 195/8000 [01:23<27:02,  4.81it/s]

  ➤ Predicted: 0.4412, Target: 0.0, Loss: 0.6000


  3%|█                                       | 223/8000 [01:37<1:19:22,  1.63it/s]

  ➤ Predicted: 0.4436, Target: 0.0, Loss: 0.5929


  3%|█▏                                        | 235/8000 [01:43<51:39,  2.51it/s]

  ➤ Predicted: 0.5629, Target: 1.0, Loss: 0.6000


  3%|█▏                                      | 243/8000 [01:48<1:28:07,  1.47it/s]

  ➤ Predicted: 0.4923, Target: 1.0, Loss: 0.6032


  3%|█▍                                        | 263/8000 [01:56<44:31,  2.90it/s]

  ➤ Predicted: 0.4551, Target: 1.0, Loss: 0.6029


  3%|█▍                                        | 273/8000 [01:59<45:00,  2.86it/s]

  ➤ Predicted: 0.4896, Target: 1.0, Loss: 0.6043


  4%|█▍                                        | 283/8000 [02:02<48:25,  2.66it/s]

  ➤ Predicted: 0.5390, Target: 1.0, Loss: 0.6084


  4%|█▋                                        | 313/8000 [02:17<56:55,  2.25it/s]

  ➤ Predicted: 0.6584, Target: 1.0, Loss: 0.6173


  4%|█▋                                      | 333/8000 [02:30<1:36:14,  1.33it/s]

  ➤ Predicted: 0.5245, Target: 0.0, Loss: 0.6257


  4%|█▋                                      | 343/8000 [02:37<1:32:40,  1.38it/s]

  ➤ Predicted: 0.4047, Target: 1.0, Loss: 0.6223


  5%|█▉                                        | 369/8000 [02:42<13:42,  9.28it/s]

  ➤ Predicted: 0.2934, Target: 1.0, Loss: 0.6250


  6%|██▎                                       | 443/8000 [02:51<44:13,  2.85it/s]

  ➤ Predicted: 0.5084, Target: 1.0, Loss: 0.6332


  6%|██▍                                       | 466/8000 [02:55<20:48,  6.03it/s]

  ➤ Predicted: 0.4097, Target: 1.0, Loss: 0.6355


  6%|██▍                                     | 493/8000 [03:04<1:00:26,  2.07it/s]

  ➤ Predicted: 0.5259, Target: 1.0, Loss: 0.6279


  6%|██▋                                       | 505/8000 [03:09<33:37,  3.72it/s]

  ➤ Predicted: 0.5428, Target: 1.0, Loss: 0.6364


  7%|██▋                                       | 523/8000 [03:14<50:18,  2.48it/s]

  ➤ Predicted: 0.7195, Target: 0.0, Loss: 0.6327


  7%|██▋                                     | 533/8000 [03:20<1:22:24,  1.51it/s]

  ➤ Predicted: 0.6588, Target: 1.0, Loss: 0.6368


  7%|██▋                                     | 543/8000 [03:27<1:18:59,  1.57it/s]

  ➤ Predicted: 0.4438, Target: 0.0, Loss: 0.6405


  7%|██▊                                     | 553/8000 [03:34<1:44:40,  1.19it/s]

  ➤ Predicted: 0.4423, Target: 1.0, Loss: 0.6414


  7%|███                                       | 585/8000 [03:43<39:50,  3.10it/s]

  ➤ Predicted: 0.4205, Target: 1.0, Loss: 0.6360


  7%|███                                       | 593/8000 [03:47<54:44,  2.26it/s]

  ➤ Predicted: 0.7128, Target: 1.0, Loss: 0.6391


  8%|███▎                                      | 623/8000 [03:58<57:31,  2.14it/s]

  ➤ Predicted: 0.4630, Target: 0.0, Loss: 0.6380


  8%|███▎                                      | 636/8000 [04:02<26:54,  4.56it/s]

  ➤ Predicted: 0.6077, Target: 1.0, Loss: 0.6408


  8%|███▌                                      | 673/8000 [04:14<50:24,  2.42it/s]

  ➤ Predicted: 0.4807, Target: 1.0, Loss: 0.6364


  9%|███▍                                    | 683/8000 [04:20<1:12:04,  1.69it/s]

  ➤ Predicted: 0.6497, Target: 0.0, Loss: 0.6361


  9%|███▍                                    | 693/8000 [04:26<1:11:19,  1.71it/s]

  ➤ Predicted: 0.5999, Target: 1.0, Loss: 0.6314


  9%|███▌                                    | 713/8000 [04:37<1:21:45,  1.49it/s]

  ➤ Predicted: 0.4541, Target: 0.0, Loss: 0.6227


  9%|███▌                                    | 723/8000 [04:43<1:06:28,  1.82it/s]

  ➤ Predicted: 0.5964, Target: 1.0, Loss: 0.6198


  9%|███▋                                    | 733/8000 [04:49<1:11:10,  1.70it/s]

  ➤ Predicted: 0.5194, Target: 0.0, Loss: 0.6158


  9%|███▉                                      | 755/8000 [05:00<39:00,  3.10it/s]

  ➤ Predicted: 0.4754, Target: 1.0, Loss: 0.6186


 10%|████                                      | 775/8000 [05:10<49:29,  2.43it/s]

  ➤ Predicted: 0.5960, Target: 1.0, Loss: 0.6164


 10%|███▉                                    | 783/8000 [05:15<1:25:04,  1.41it/s]

  ➤ Predicted: 0.6029, Target: 1.0, Loss: 0.6183


 10%|███▉                                    | 793/8000 [05:23<1:36:27,  1.25it/s]

  ➤ Predicted: 0.3514, Target: 0.0, Loss: 0.6194


 10%|████▏                                     | 803/8000 [05:27<57:59,  2.07it/s]

  ➤ Predicted: 0.6186, Target: 1.0, Loss: 0.6208


 10%|████▎                                     | 826/8000 [05:34<33:30,  3.57it/s]

  ➤ Predicted: 0.5017, Target: 1.0, Loss: 0.6132


 11%|████▌                                     | 873/8000 [05:51<51:12,  2.32it/s]

  ➤ Predicted: 0.5719, Target: 1.0, Loss: 0.6053


 11%|████▋                                     | 902/8000 [05:59<22:35,  5.24it/s]

  ➤ Predicted: 0.4311, Target: 0.0, Loss: 0.6076


 11%|████▊                                     | 913/8000 [06:05<50:41,  2.33it/s]

  ➤ Predicted: 0.5676, Target: 0.0, Loss: 0.6061


 12%|████▊                                     | 923/8000 [06:10<59:29,  1.98it/s]

  ➤ Predicted: 0.6346, Target: 1.0, Loss: 0.6083


 12%|████▋                                   | 943/8000 [06:19<1:04:40,  1.82it/s]

  ➤ Predicted: 0.5863, Target: 1.0, Loss: 0.6067


 12%|████▊                                   | 953/8000 [06:26<1:27:52,  1.34it/s]

  ➤ Predicted: 0.5566, Target: 0.0, Loss: 0.6071


 12%|████▊                                   | 973/8000 [06:41<1:29:25,  1.31it/s]

  ➤ Predicted: 0.3380, Target: 1.0, Loss: 0.6093


 12%|████▉                                   | 983/8000 [06:47<1:08:54,  1.70it/s]

  ➤ Predicted: 0.2760, Target: 1.0, Loss: 0.6109


 13%|█████▏                                   | 1003/8000 [06:53<57:03,  2.04it/s]

  ➤ Predicted: 0.4246, Target: 0.0, Loss: 0.6103


 14%|█████▋                                   | 1104/8000 [07:31<37:10,  3.09it/s]

  ➤ Predicted: 0.4618, Target: 0.0, Loss: 0.6167


 14%|█████▊                                   | 1135/8000 [07:46<40:21,  2.84it/s]

  ➤ Predicted: 0.4729, Target: 0.0, Loss: 0.6213


 14%|█████▉                                   | 1152/8000 [07:49<19:59,  5.71it/s]

  ➤ Predicted: 0.6610, Target: 0.0, Loss: 0.6211


 15%|█████▉                                   | 1162/8000 [07:52<21:47,  5.23it/s]

  ➤ Predicted: 0.6846, Target: 1.0, Loss: 0.6207


 15%|██████                                   | 1183/8000 [08:00<53:24,  2.13it/s]

  ➤ Predicted: 0.7520, Target: 1.0, Loss: 0.6219


 15%|██████▏                                  | 1213/8000 [08:12<52:26,  2.16it/s]

  ➤ Predicted: 0.6815, Target: 1.0, Loss: 0.6189


 15%|█████▉                                 | 1223/8000 [08:19<1:26:02,  1.31it/s]

  ➤ Predicted: 0.5736, Target: 1.0, Loss: 0.6172


 16%|██████                                 | 1243/8000 [08:29<1:14:08,  1.52it/s]

  ➤ Predicted: 0.6525, Target: 1.0, Loss: 0.6193


 16%|██████▋                                  | 1316/8000 [08:47<19:44,  5.64it/s]

  ➤ Predicted: 0.5766, Target: 1.0, Loss: 0.6141


 17%|██████▊                                  | 1325/8000 [08:51<36:51,  3.02it/s]

  ➤ Predicted: 0.4129, Target: 0.0, Loss: 0.6134


 17%|██████▉                                  | 1353/8000 [08:55<24:02,  4.61it/s]

  ➤ Predicted: 0.5179, Target: 1.0, Loss: 0.6143


 17%|██████▉                                  | 1365/8000 [09:01<41:41,  2.65it/s]

  ➤ Predicted: 0.5189, Target: 1.0, Loss: 0.6132


 17%|███████                                  | 1372/8000 [09:02<23:18,  4.74it/s]

  ➤ Predicted: 0.7396, Target: 1.0, Loss: 0.6134


 17%|███████                                  | 1383/8000 [09:09<58:27,  1.89it/s]

  ➤ Predicted: 0.4623, Target: 0.0, Loss: 0.6123


 17%|███████▏                                 | 1393/8000 [09:14<56:14,  1.96it/s]

  ➤ Predicted: 0.4126, Target: 1.0, Loss: 0.6129


 18%|███████▏                                 | 1405/8000 [09:19<38:47,  2.83it/s]

  ➤ Predicted: 0.7346, Target: 0.0, Loss: 0.6128


 18%|███████▏                                 | 1413/8000 [09:21<36:07,  3.04it/s]

  ➤ Predicted: 0.4026, Target: 0.0, Loss: 0.6124


 18%|███████▎                                 | 1422/8000 [09:24<33:48,  3.24it/s]

  ➤ Predicted: 0.3959, Target: 1.0, Loss: 0.6117


 18%|███████▎                                 | 1433/8000 [09:30<55:21,  1.98it/s]

  ➤ Predicted: 0.5072, Target: 1.0, Loss: 0.6137


 19%|███████▌                                 | 1485/8000 [09:50<45:30,  2.39it/s]

  ➤ Predicted: 0.3990, Target: 1.0, Loss: 0.6100


 19%|███████▋                                 | 1505/8000 [09:57<43:09,  2.51it/s]

  ➤ Predicted: 0.3244, Target: 0.0, Loss: 0.6088


 19%|███████▊                                 | 1535/8000 [10:05<32:06,  3.36it/s]

  ➤ Predicted: 0.4878, Target: 1.0, Loss: 0.6104


 19%|███████▉                                 | 1543/8000 [10:09<48:09,  2.23it/s]

  ➤ Predicted: 0.6598, Target: 1.0, Loss: 0.6098


 19%|███████▌                               | 1553/8000 [10:16<1:09:53,  1.54it/s]

  ➤ Predicted: 0.4814, Target: 0.0, Loss: 0.6108


 20%|███████▌                               | 1563/8000 [10:23<1:25:22,  1.26it/s]

  ➤ Predicted: 0.5222, Target: 1.0, Loss: 0.6110


 20%|████████                                 | 1575/8000 [10:28<37:56,  2.82it/s]

  ➤ Predicted: 0.4544, Target: 0.0, Loss: 0.6137


 20%|████████                                 | 1583/8000 [10:32<49:56,  2.14it/s]

  ➤ Predicted: 0.8195, Target: 1.0, Loss: 0.6130


 20%|███████▊                               | 1603/8000 [10:45<1:23:31,  1.28it/s]

  ➤ Predicted: 0.6419, Target: 1.0, Loss: 0.6114


 20%|███████▊                               | 1613/8000 [10:54<1:35:11,  1.12it/s]

  ➤ Predicted: 0.3777, Target: 1.0, Loss: 0.6099


 21%|████████▍                                | 1645/8000 [11:12<50:32,  2.10it/s]

  ➤ Predicted: 0.4969, Target: 0.0, Loss: 0.6149


 21%|████████                               | 1653/8000 [11:17<1:08:09,  1.55it/s]

  ➤ Predicted: 0.7181, Target: 1.0, Loss: 0.6166


 21%|████████                               | 1663/8000 [11:26<1:33:57,  1.12it/s]

  ➤ Predicted: 0.5847, Target: 1.0, Loss: 0.6176


 21%|████████▋                                | 1693/8000 [11:38<43:54,  2.39it/s]

  ➤ Predicted: 0.6697, Target: 1.0, Loss: 0.6179


 21%|████████▎                              | 1703/8000 [11:45<1:25:25,  1.23it/s]

  ➤ Predicted: 0.6287, Target: 0.0, Loss: 0.6167


 21%|████████▎                              | 1713/8000 [11:52<1:09:09,  1.52it/s]

  ➤ Predicted: 0.4141, Target: 1.0, Loss: 0.6143


 22%|████████▉                                | 1746/8000 [12:02<20:57,  4.97it/s]

  ➤ Predicted: 0.5584, Target: 1.0, Loss: 0.6142


 22%|████████▉                                | 1752/8000 [12:03<23:37,  4.41it/s]

  ➤ Predicted: 0.2819, Target: 1.0, Loss: 0.6139


 22%|█████████                                | 1775/8000 [12:13<39:18,  2.64it/s]

  ➤ Predicted: 0.3923, Target: 0.0, Loss: 0.6169


 23%|████████▊                              | 1803/8000 [12:28<1:15:09,  1.37it/s]

  ➤ Predicted: 0.3582, Target: 0.0, Loss: 0.6182


 23%|█████████▎                               | 1815/8000 [12:35<45:12,  2.28it/s]

  ➤ Predicted: 0.3786, Target: 1.0, Loss: 0.6178


 23%|█████████▎                               | 1823/8000 [12:39<59:55,  1.72it/s]

  ➤ Predicted: 0.3391, Target: 1.0, Loss: 0.6189


 23%|████████▉                              | 1833/8000 [12:46<1:06:53,  1.54it/s]

  ➤ Predicted: 0.4999, Target: 0.0, Loss: 0.6185


 24%|█████████▋                               | 1881/8000 [12:55<05:08, 19.86it/s]

  ➤ Predicted: 0.6643, Target: 0.0, Loss: 0.6137


 24%|█████████▋                               | 1895/8000 [13:00<22:16,  4.57it/s]

  ➤ Predicted: 0.3511, Target: 0.0, Loss: 0.6150


 24%|█████████▊                               | 1913/8000 [13:04<35:10,  2.88it/s]

  ➤ Predicted: 0.6128, Target: 1.0, Loss: 0.6156


 24%|█████████▊                               | 1923/8000 [13:08<38:02,  2.66it/s]

  ➤ Predicted: 0.4322, Target: 1.0, Loss: 0.6151


 25%|██████████                               | 1965/8000 [13:24<33:42,  2.98it/s]

  ➤ Predicted: 0.2232, Target: 0.0, Loss: 0.6146


 25%|██████████▏                              | 1993/8000 [13:38<56:39,  1.77it/s]

  ➤ Predicted: 0.2274, Target: 0.0, Loss: 0.6172


 25%|█████████▊                             | 2023/8000 [14:00<1:28:42,  1.12it/s]

  ➤ Predicted: 0.5195, Target: 1.0, Loss: 0.6231


 25%|█████████▉                             | 2033/8000 [14:07<1:04:27,  1.54it/s]

  ➤ Predicted: 0.5271, Target: 1.0, Loss: 0.6238


 26%|██████████▍                              | 2043/8000 [14:10<39:24,  2.52it/s]

  ➤ Predicted: 0.6703, Target: 0.0, Loss: 0.6229


 26%|██████████▌                              | 2053/8000 [14:15<50:43,  1.95it/s]

  ➤ Predicted: 0.6103, Target: 1.0, Loss: 0.6228


 26%|██████████▌                              | 2063/8000 [14:21<57:29,  1.72it/s]

  ➤ Predicted: 0.3894, Target: 0.0, Loss: 0.6255


 26%|██████████▌                              | 2073/8000 [14:25<50:22,  1.96it/s]

  ➤ Predicted: 0.7155, Target: 0.0, Loss: 0.6260


 26%|██████████▎                            | 2103/8000 [14:42<1:02:55,  1.56it/s]

  ➤ Predicted: 0.6727, Target: 1.0, Loss: 0.6258


 26%|██████████▊                              | 2116/8000 [14:51<42:40,  2.30it/s]

  ➤ Predicted: 0.3008, Target: 0.0, Loss: 0.6241


 27%|███████████                              | 2165/8000 [15:11<36:49,  2.64it/s]

  ➤ Predicted: 0.4021, Target: 0.0, Loss: 0.6250


 27%|██████████▋                            | 2193/8000 [15:25<1:17:59,  1.24it/s]

  ➤ Predicted: 0.6704, Target: 1.0, Loss: 0.6265


 28%|██████████▋                            | 2203/8000 [15:32<1:13:25,  1.32it/s]

  ➤ Predicted: 0.2855, Target: 0.0, Loss: 0.6263


 28%|███████████▍                             | 2223/8000 [15:42<50:15,  1.92it/s]

  ➤ Predicted: 0.1531, Target: 0.0, Loss: 0.6290


 28%|██████████▉                            | 2243/8000 [15:58<1:24:09,  1.14it/s]

  ➤ Predicted: 0.6789, Target: 0.0, Loss: 0.6267


 28%|██████████▉                            | 2253/8000 [16:05<1:26:29,  1.11it/s]

  ➤ Predicted: 0.3924, Target: 1.0, Loss: 0.6276


 28%|███████████                            | 2265/8000 [16:16<1:05:20,  1.46it/s]

  ➤ Predicted: 0.4353, Target: 0.0, Loss: 0.6285


 28%|███████████▋                             | 2273/8000 [16:21<58:19,  1.64it/s]

  ➤ Predicted: 0.7550, Target: 1.0, Loss: 0.6274


 29%|███████████▏                           | 2283/8000 [16:29<1:24:34,  1.13it/s]

  ➤ Predicted: 0.3900, Target: 0.0, Loss: 0.6293


 29%|███████████▏                           | 2293/8000 [16:37<1:27:25,  1.09it/s]

  ➤ Predicted: 0.6116, Target: 1.0, Loss: 0.6320


 29%|███████████▏                           | 2303/8000 [16:45<1:27:39,  1.08it/s]

  ➤ Predicted: 0.1088, Target: 0.0, Loss: 0.6310


 29%|███████████▎                           | 2313/8000 [16:54<1:32:18,  1.03it/s]

  ➤ Predicted: 0.5810, Target: 1.0, Loss: 0.6325


 30%|███████████▌                           | 2373/8000 [17:34<1:05:52,  1.42it/s]

  ➤ Predicted: 0.6938, Target: 1.0, Loss: 0.6357


 30%|████████████▏                            | 2385/8000 [17:39<34:11,  2.74it/s]

  ➤ Predicted: 0.3924, Target: 1.0, Loss: 0.6353


 30%|███████████▋                           | 2403/8000 [17:50<1:09:13,  1.35it/s]

  ➤ Predicted: 0.5017, Target: 1.0, Loss: 0.6370


 30%|████████████▍                            | 2415/8000 [17:57<39:05,  2.38it/s]

  ➤ Predicted: 0.5077, Target: 1.0, Loss: 0.6354


 30%|████████████▍                            | 2433/8000 [18:05<39:17,  2.36it/s]

  ➤ Predicted: 0.4754, Target: 1.0, Loss: 0.6325


 31%|████████████▋                            | 2476/8000 [18:15<17:43,  5.20it/s]

  ➤ Predicted: 0.6596, Target: 0.0, Loss: 0.6290


 31%|████████████▊                            | 2495/8000 [18:24<38:25,  2.39it/s]

  ➤ Predicted: 0.2835, Target: 0.0, Loss: 0.6290


 31%|████████████▊                            | 2503/8000 [18:25<27:21,  3.35it/s]

  ➤ Predicted: 0.4735, Target: 1.0, Loss: 0.6285


 32%|████████████▉                            | 2523/8000 [18:33<46:07,  1.98it/s]

  ➤ Predicted: 0.6089, Target: 1.0, Loss: 0.6295


 32%|████████████▎                          | 2533/8000 [18:42<1:32:48,  1.02s/it]

  ➤ Predicted: 0.7595, Target: 1.0, Loss: 0.6295


 32%|█████████████                            | 2545/8000 [18:50<46:32,  1.95it/s]

  ➤ Predicted: 0.6951, Target: 1.0, Loss: 0.6316


 32%|█████████████▏                           | 2563/8000 [19:01<57:54,  1.56it/s]

  ➤ Predicted: 0.6997, Target: 1.0, Loss: 0.6314


 32%|████████████▌                          | 2583/8000 [19:22<1:53:44,  1.26s/it]

  ➤ Predicted: 0.4485, Target: 1.0, Loss: 0.6321


 33%|█████████████▎                           | 2603/8000 [19:32<43:11,  2.08it/s]

  ➤ Predicted: 0.5842, Target: 1.0, Loss: 0.6342


 33%|████████████▋                          | 2613/8000 [19:41<1:20:54,  1.11it/s]

  ➤ Predicted: 0.7490, Target: 1.0, Loss: 0.6341


 33%|█████████████▋                           | 2665/8000 [19:57<33:24,  2.66it/s]

  ➤ Predicted: 0.5077, Target: 1.0, Loss: 0.6318


 34%|█████████████▊                           | 2683/8000 [20:06<37:31,  2.36it/s]

  ➤ Predicted: 0.3936, Target: 0.0, Loss: 0.6326


 34%|█████████████▊                           | 2693/8000 [20:12<57:52,  1.53it/s]

  ➤ Predicted: 0.7824, Target: 1.0, Loss: 0.6317


 34%|█████████████▏                         | 2703/8000 [20:19<1:07:36,  1.31it/s]

  ➤ Predicted: 0.4141, Target: 0.0, Loss: 0.6313


 34%|█████████████▏                         | 2713/8000 [20:26<1:02:05,  1.42it/s]

  ➤ Predicted: 0.2963, Target: 0.0, Loss: 0.6323


 34%|█████████████▎                         | 2723/8000 [20:35<1:29:22,  1.02s/it]

  ➤ Predicted: 0.8017, Target: 1.0, Loss: 0.6333


 34%|██████████████                           | 2744/8000 [20:46<33:57,  2.58it/s]

  ➤ Predicted: 0.5599, Target: 0.0, Loss: 0.6327


 35%|██████████████▎                          | 2785/8000 [21:02<37:47,  2.30it/s]

  ➤ Predicted: 0.6229, Target: 1.0, Loss: 0.6309


 35%|██████████████▎                          | 2803/8000 [21:11<52:14,  1.66it/s]

  ➤ Predicted: 0.6019, Target: 0.0, Loss: 0.6301


 36%|██████████████▋                          | 2854/8000 [21:38<27:21,  3.14it/s]

  ➤ Predicted: 0.6815, Target: 0.0, Loss: 0.6327


 36%|█████████████▉                         | 2863/8000 [21:45<1:12:07,  1.19it/s]

  ➤ Predicted: 0.6948, Target: 1.0, Loss: 0.6338


 36%|██████████████                         | 2873/8000 [21:54<1:20:42,  1.06it/s]

  ➤ Predicted: 0.3396, Target: 1.0, Loss: 0.6340


 36%|██████████████▊                          | 2883/8000 [21:59<48:45,  1.75it/s]

  ➤ Predicted: 0.5702, Target: 0.0, Loss: 0.6332


 36%|██████████████                         | 2893/8000 [22:05<1:00:21,  1.41it/s]

  ➤ Predicted: 0.4711, Target: 0.0, Loss: 0.6331


 36%|██████████████▏                        | 2913/8000 [22:21<1:04:17,  1.32it/s]

  ➤ Predicted: 0.6112, Target: 1.0, Loss: 0.6341


 37%|██████████████▏                        | 2923/8000 [22:29<1:05:27,  1.29it/s]

  ➤ Predicted: 0.6387, Target: 0.0, Loss: 0.6348


 38%|███████████████▌                         | 3032/8000 [22:42<11:38,  7.12it/s]

  ➤ Predicted: 0.4438, Target: 1.0, Loss: 0.6343


 38%|███████████████▋                         | 3053/8000 [22:52<50:19,  1.64it/s]

  ➤ Predicted: 0.6608, Target: 0.0, Loss: 0.6350


 38%|███████████████▋                         | 3066/8000 [22:56<22:24,  3.67it/s]

  ➤ Predicted: 0.4411, Target: 0.0, Loss: 0.6350


 38%|███████████████▊                         | 3078/8000 [23:04<26:02,  3.15it/s]

Error on sample 15655: NaN in input


 39%|███████████████▊                         | 3093/8000 [23:09<33:00,  2.48it/s]

  ➤ Predicted: 0.5556, Target: 0.0, Loss: 0.6326


 39%|███████████████▉                         | 3103/8000 [23:14<45:44,  1.78it/s]

  ➤ Predicted: 0.7490, Target: 1.0, Loss: 0.6339


 39%|███████████████▏                       | 3113/8000 [23:23<1:25:39,  1.05s/it]

  ➤ Predicted: 0.3945, Target: 0.0, Loss: 0.6352


 39%|███████████████▎                       | 3133/8000 [23:38<1:10:44,  1.15it/s]

  ➤ Predicted: 0.3950, Target: 0.0, Loss: 0.6344


 39%|███████████████▎                       | 3143/8000 [23:46<1:06:58,  1.21it/s]

  ➤ Predicted: 0.1132, Target: 0.0, Loss: 0.6350


 39%|███████████████▎                       | 3153/8000 [23:54<1:09:51,  1.16it/s]

  ➤ Predicted: 0.3732, Target: 1.0, Loss: 0.6356


 40%|████████████████▎                        | 3183/8000 [24:09<38:31,  2.08it/s]

  ➤ Predicted: 0.6425, Target: 0.0, Loss: 0.6380


 40%|████████████████▍                        | 3215/8000 [24:31<42:54,  1.86it/s]

  ➤ Predicted: 0.7989, Target: 1.0, Loss: 0.6364


 40%|████████████████▌                        | 3233/8000 [24:42<55:23,  1.43it/s]

  ➤ Predicted: 0.2970, Target: 0.0, Loss: 0.6377


 41%|████████████████▌                        | 3243/8000 [24:51<57:24,  1.38it/s]

  ➤ Predicted: 0.7966, Target: 1.0, Loss: 0.6383


 41%|███████████████▊                       | 3253/8000 [25:01<1:21:36,  1.03s/it]

  ➤ Predicted: 0.6536, Target: 1.0, Loss: 0.6397


 41%|████████████████▋                        | 3263/8000 [25:06<55:28,  1.42it/s]

  ➤ Predicted: 0.4234, Target: 0.0, Loss: 0.6390


 41%|████████████████▊                        | 3283/8000 [25:16<53:07,  1.48it/s]

  ➤ Predicted: 0.6019, Target: 0.0, Loss: 0.6378


 41%|████████████████▊                        | 3292/8000 [25:19<18:09,  4.32it/s]

  ➤ Predicted: 0.6959, Target: 1.0, Loss: 0.6379


 41%|████████████████▉                        | 3298/8000 [25:22<32:57,  2.38it/s]

Error on sample 15875: NaN in input


 41%|████████████████                       | 3303/8000 [25:27<1:13:00,  1.07it/s]

  ➤ Predicted: 0.7500, Target: 0.0, Loss: 0.6369


 41%|████████████████▉                        | 3312/8000 [25:31<35:58,  2.17it/s]

Error on sample 15889: NaN in input


 42%|█████████████████                        | 3335/8000 [25:43<36:09,  2.15it/s]

  ➤ Predicted: 0.6001, Target: 0.0, Loss: 0.6376
Error on sample 15911: NaN in input


 42%|█████████████████▏                       | 3343/8000 [25:49<52:24,  1.48it/s]

  ➤ Predicted: 0.4895, Target: 0.0, Loss: 0.6381


 42%|████████████████▎                      | 3353/8000 [25:58<1:17:28,  1.00s/it]

  ➤ Predicted: 0.5630, Target: 1.0, Loss: 0.6389


 42%|█████████████████▎                       | 3373/8000 [26:09<27:07,  2.84it/s]

  ➤ Predicted: 0.3715, Target: 0.0, Loss: 0.6387


 42%|█████████████████▎                       | 3383/8000 [26:16<45:14,  1.70it/s]

  ➤ Predicted: 0.3606, Target: 1.0, Loss: 0.6394


 42%|█████████████████▍                       | 3393/8000 [26:24<56:19,  1.36it/s]

  ➤ Predicted: 0.8240, Target: 1.0, Loss: 0.6394


 43%|████████████████▋                      | 3433/8000 [26:50<1:01:42,  1.23it/s]

  ➤ Predicted: 0.5390, Target: 0.0, Loss: 0.6375


 43%|█████████████████▋                       | 3453/8000 [26:58<32:54,  2.30it/s]

  ➤ Predicted: 0.6340, Target: 1.0, Loss: 0.6372


 44%|█████████████████▉                       | 3505/8000 [27:27<25:56,  2.89it/s]

  ➤ Predicted: 0.5822, Target: 1.0, Loss: 0.6386


 44%|█████████████████▎                     | 3553/8000 [28:03<1:05:34,  1.13it/s]

  ➤ Predicted: 0.4353, Target: 0.0, Loss: 0.6377


 45%|██████████████████▎                      | 3575/8000 [28:22<39:56,  1.85it/s]

  ➤ Predicted: 0.4421, Target: 1.0, Loss: 0.6396


 45%|██████████████████▍                      | 3594/8000 [28:31<28:54,  2.54it/s]

  ➤ Predicted: 0.3916, Target: 0.0, Loss: 0.6395


 45%|█████████████████▌                     | 3603/8000 [28:38<1:01:55,  1.18it/s]

  ➤ Predicted: 0.5072, Target: 1.0, Loss: 0.6402


 45%|█████████████████▌                     | 3613/8000 [28:45<1:06:23,  1.10it/s]

  ➤ Predicted: 0.6500, Target: 1.0, Loss: 0.6410


 45%|█████████████████▋                     | 3623/8000 [28:56<1:19:26,  1.09s/it]

  ➤ Predicted: 0.1448, Target: 0.0, Loss: 0.6414


 45%|██████████████████▌                      | 3633/8000 [29:02<33:57,  2.14it/s]

  ➤ Predicted: 0.4600, Target: 0.0, Loss: 0.6418


 46%|██████████████████▋                      | 3643/8000 [29:07<46:22,  1.57it/s]

  ➤ Predicted: 0.3335, Target: 0.0, Loss: 0.6429


 46%|██████████████████▊                      | 3665/8000 [29:16<31:15,  2.31it/s]

  ➤ Predicted: 0.5134, Target: 0.0, Loss: 0.6422


 46%|██████████████████▊                      | 3673/8000 [29:20<36:37,  1.97it/s]

  ➤ Predicted: 0.7684, Target: 1.0, Loss: 0.6428


 46%|██████████████████▉                      | 3703/8000 [29:34<22:25,  3.19it/s]

  ➤ Predicted: 0.4312, Target: 1.0, Loss: 0.6424


 46%|███████████████████                      | 3715/8000 [29:40<28:27,  2.51it/s]

  ➤ Predicted: 0.1780, Target: 0.0, Loss: 0.6423


 47%|███████████████████                      | 3723/8000 [29:45<44:59,  1.58it/s]

  ➤ Predicted: 0.3661, Target: 0.0, Loss: 0.6420


 47%|███████████████████▏                     | 3733/8000 [29:50<43:39,  1.63it/s]

  ➤ Predicted: 0.4071, Target: 0.0, Loss: 0.6422


 47%|███████████████████▏                     | 3743/8000 [29:55<36:37,  1.94it/s]

  ➤ Predicted: 0.7113, Target: 1.0, Loss: 0.6425


 47%|███████████████████▏                     | 3755/8000 [30:03<32:28,  2.18it/s]

  ➤ Predicted: 0.1374, Target: 0.0, Loss: 0.6433


 47%|███████████████████▎                     | 3763/8000 [30:09<55:27,  1.27it/s]

  ➤ Predicted: 0.6790, Target: 0.0, Loss: 0.6423


 47%|███████████████████▍                     | 3783/8000 [30:21<39:18,  1.79it/s]

  ➤ Predicted: 0.6161, Target: 1.0, Loss: 0.6420


 47%|███████████████████▍                     | 3793/8000 [30:29<46:04,  1.52it/s]

  ➤ Predicted: 0.0538, Target: 0.0, Loss: 0.6429


 48%|██████████████████▌                    | 3803/8000 [30:39<1:14:48,  1.07s/it]

  ➤ Predicted: 0.7454, Target: 0.0, Loss: 0.6435


 48%|███████████████████▌                     | 3815/8000 [30:46<28:56,  2.41it/s]

  ➤ Predicted: 0.3366, Target: 0.0, Loss: 0.6431


 48%|███████████████████▋                     | 3833/8000 [30:54<30:03,  2.31it/s]

  ➤ Predicted: 0.7322, Target: 1.0, Loss: 0.6426


 48%|███████████████████▋                     | 3848/8000 [31:00<22:59,  3.01it/s]

Error on sample 16425: NaN in input


 48%|██████████████████▊                    | 3853/8000 [31:07<1:23:52,  1.21s/it]

  ➤ Predicted: 0.8055, Target: 0.0, Loss: 0.6425


 48%|███████████████████▊                     | 3865/8000 [31:16<42:40,  1.62it/s]

  ➤ Predicted: 0.5527, Target: 1.0, Loss: 0.6430


 49%|███████████████████▉                     | 3883/8000 [31:22<38:49,  1.77it/s]

  ➤ Predicted: 0.7169, Target: 1.0, Loss: 0.6423


 49%|██████████████████▉                    | 3893/8000 [31:32<1:06:22,  1.03it/s]

  ➤ Predicted: 0.9503, Target: 1.0, Loss: 0.6422


 49%|████████████████████                     | 3903/8000 [31:37<38:05,  1.79it/s]

  ➤ Predicted: 0.5133, Target: 0.0, Loss: 0.6427


 49%|████████████████████                     | 3925/8000 [31:56<42:18,  1.61it/s]

  ➤ Predicted: 0.9379, Target: 1.0, Loss: 0.6432


 49%|████████████████████▏                    | 3936/8000 [32:00<24:03,  2.82it/s]

  ➤ Predicted: 0.7149, Target: 0.0, Loss: 0.6433


 49%|████████████████████▎                    | 3954/8000 [32:07<30:46,  2.19it/s]

  ➤ Predicted: 0.5840, Target: 0.0, Loss: 0.6418


 50%|████████████████████▎                    | 3963/8000 [32:13<44:08,  1.52it/s]

  ➤ Predicted: 0.8251, Target: 1.0, Loss: 0.6417


 50%|████████████████████▋                    | 4025/8000 [32:45<36:01,  1.84it/s]

  ➤ Predicted: 0.5144, Target: 0.0, Loss: 0.6417


 50%|████████████████████▋                    | 4033/8000 [32:49<27:33,  2.40it/s]

  ➤ Predicted: 0.4487, Target: 0.0, Loss: 0.6417


 51%|████████████████████▊                    | 4073/8000 [33:13<57:22,  1.14it/s]

  ➤ Predicted: 0.5937, Target: 1.0, Loss: 0.6425


 51%|████████████████████▉                    | 4083/8000 [33:20<53:31,  1.22it/s]

  ➤ Predicted: 0.1797, Target: 0.0, Loss: 0.6426


 51%|███████████████████▉                   | 4093/8000 [33:29<1:08:36,  1.05s/it]

  ➤ Predicted: 0.8542, Target: 1.0, Loss: 0.6435


 53%|█████████████████████▋                   | 4233/8000 [33:56<18:25,  3.41it/s]

  ➤ Predicted: 0.2124, Target: 0.0, Loss: 0.6437


 53%|█████████████████████▊                   | 4253/8000 [34:06<31:01,  2.01it/s]

  ➤ Predicted: 0.5124, Target: 0.0, Loss: 0.6426


 55%|██████████████████████▎                  | 4363/8000 [34:52<41:26,  1.46it/s]

  ➤ Predicted: 0.2443, Target: 0.0, Loss: 0.6421


 55%|██████████████████████▌                  | 4393/8000 [35:17<58:38,  1.03it/s]

  ➤ Predicted: 0.6423, Target: 0.0, Loss: 0.6421


 55%|██████████████████████▋                  | 4415/8000 [35:32<25:24,  2.35it/s]

  ➤ Predicted: 0.8697, Target: 1.0, Loss: 0.6438


 55%|██████████████████████▋                  | 4425/8000 [35:40<34:51,  1.71it/s]

  ➤ Predicted: 0.3023, Target: 1.0, Loss: 0.6426


 55%|██████████████████████▋                  | 4433/8000 [35:44<25:41,  2.31it/s]

  ➤ Predicted: 0.4350, Target: 0.0, Loss: 0.6427


 56%|██████████████████████▊                  | 4455/8000 [35:54<24:09,  2.45it/s]

  ➤ Predicted: 0.5162, Target: 0.0, Loss: 0.6424


 56%|██████████████████████▊                  | 4463/8000 [36:00<38:52,  1.52it/s]

  ➤ Predicted: 0.5694, Target: 0.0, Loss: 0.6423


 56%|██████████████████████▉                  | 4473/8000 [36:08<49:23,  1.19it/s]

  ➤ Predicted: 0.5982, Target: 1.0, Loss: 0.6426


 56%|███████████████████████                  | 4503/8000 [36:28<42:48,  1.36it/s]

  ➤ Predicted: 0.6679, Target: 0.0, Loss: 0.6421


 57%|██████████████████████                 | 4523/8000 [36:46<1:03:58,  1.10s/it]

  ➤ Predicted: 0.9125, Target: 1.0, Loss: 0.6418


 57%|██████████████████████                 | 4533/8000 [36:57<1:07:54,  1.18s/it]

  ➤ Predicted: 0.4191, Target: 0.0, Loss: 0.6421


 57%|███████████████████████▎                 | 4553/8000 [37:16<50:29,  1.14it/s]

  ➤ Predicted: 0.3723, Target: 0.0, Loss: 0.6415


 57%|██████████████████████▎                | 4573/8000 [37:34<1:04:25,  1.13s/it]

  ➤ Predicted: 0.2497, Target: 0.0, Loss: 0.6426


 57%|██████████████████████▎                | 4583/8000 [37:45<1:07:46,  1.19s/it]

  ➤ Predicted: 0.3894, Target: 0.0, Loss: 0.6433


 57%|███████████████████████▌                 | 4595/8000 [37:51<27:57,  2.03it/s]

  ➤ Predicted: 0.6148, Target: 0.0, Loss: 0.6421


 58%|███████████████████████▌                 | 4603/8000 [37:57<37:08,  1.52it/s]

  ➤ Predicted: 0.5562, Target: 0.0, Loss: 0.6416


 58%|███████████████████████▋                 | 4632/8000 [38:09<20:03,  2.80it/s]

  ➤ Predicted: 0.3404, Target: 1.0, Loss: 0.6416


 58%|███████████████████████▉                 | 4673/8000 [38:32<28:23,  1.95it/s]

  ➤ Predicted: 0.5566, Target: 0.0, Loss: 0.6404


 59%|████████████████████████                 | 4685/8000 [38:38<22:43,  2.43it/s]

  ➤ Predicted: 0.6075, Target: 1.0, Loss: 0.6412


 59%|████████████████████████                 | 4703/8000 [38:54<58:53,  1.07s/it]

  ➤ Predicted: 0.7286, Target: 1.0, Loss: 0.6426


 59%|████████████████████████▏                | 4713/8000 [39:03<48:33,  1.13it/s]

  ➤ Predicted: 0.2862, Target: 1.0, Loss: 0.6425


 59%|████████████████████████▏                | 4723/8000 [39:11<45:31,  1.20it/s]

  ➤ Predicted: 0.3545, Target: 0.0, Loss: 0.6426


 59%|███████████████████████                | 4733/8000 [39:22<1:03:02,  1.16s/it]

  ➤ Predicted: 0.2804, Target: 0.0, Loss: 0.6418


 59%|████████████████████████▎                | 4743/8000 [39:33<51:52,  1.05it/s]

  ➤ Predicted: 0.8223, Target: 1.0, Loss: 0.6422


 59%|███████████████████████▏               | 4753/8000 [39:43<1:03:43,  1.18s/it]

  ➤ Predicted: 0.3377, Target: 0.0, Loss: 0.6427


 60%|████████████████████████▋                | 4813/8000 [39:58<19:40,  2.70it/s]

  ➤ Predicted: 0.1435, Target: 0.0, Loss: 0.6443


 60%|████████████████████████▊                | 4833/8000 [40:04<18:17,  2.89it/s]

  ➤ Predicted: 0.2509, Target: 0.0, Loss: 0.6449


 61%|████████████████████████▊                | 4845/8000 [40:10<22:46,  2.31it/s]

  ➤ Predicted: 0.4876, Target: 1.0, Loss: 0.6448


 61%|████████████████████████▊                | 4853/8000 [40:18<53:02,  1.01s/it]

  ➤ Predicted: 0.1413, Target: 0.0, Loss: 0.6459


 61%|████████████████████████▉                | 4875/8000 [40:32<31:19,  1.66it/s]

  ➤ Predicted: 0.6173, Target: 1.0, Loss: 0.6458


 61%|█████████████████████████▏               | 4913/8000 [40:56<31:02,  1.66it/s]

  ➤ Predicted: 0.5695, Target: 1.0, Loss: 0.6474


 62%|█████████████████████████▎               | 4933/8000 [41:08<45:24,  1.13it/s]

  ➤ Predicted: 0.1237, Target: 0.0, Loss: 0.6482


 62%|█████████████████████████▎               | 4945/8000 [41:16<28:52,  1.76it/s]

  ➤ Predicted: 0.1516, Target: 1.0, Loss: 0.6475


 62%|█████████████████████████▍               | 4963/8000 [41:30<54:03,  1.07s/it]

  ➤ Predicted: 0.4429, Target: 0.0, Loss: 0.6482


 62%|█████████████████████████▍               | 4973/8000 [41:39<39:46,  1.27it/s]

  ➤ Predicted: 0.3625, Target: 0.0, Loss: 0.6488


 62%|█████████████████████████▌               | 4985/8000 [41:50<31:48,  1.58it/s]

  ➤ Predicted: 0.8714, Target: 1.0, Loss: 0.6484


 62%|█████████████████████████▌               | 4993/8000 [41:56<40:33,  1.24it/s]

  ➤ Predicted: 0.6159, Target: 1.0, Loss: 0.6483


 63%|█████████████████████████▋               | 5003/8000 [42:06<52:20,  1.05s/it]

  ➤ Predicted: 0.6454, Target: 1.0, Loss: 0.6487


 63%|█████████████████████████▋               | 5013/8000 [42:13<34:30,  1.44it/s]

  ➤ Predicted: 0.7773, Target: 1.0, Loss: 0.6490


 63%|█████████████████████████▋               | 5023/8000 [42:21<46:21,  1.07it/s]

  ➤ Predicted: 0.4563, Target: 1.0, Loss: 0.6492


 63%|█████████████████████████▊               | 5033/8000 [42:29<37:14,  1.33it/s]

  ➤ Predicted: 0.1724, Target: 0.0, Loss: 0.6496


 63%|█████████████████████████▊               | 5043/8000 [42:39<57:09,  1.16s/it]

  ➤ Predicted: 0.7814, Target: 1.0, Loss: 0.6500


 63%|█████████████████████████▉               | 5063/8000 [42:57<44:58,  1.09it/s]

  ➤ Predicted: 0.3903, Target: 1.0, Loss: 0.6491


 63%|█████████████████████████▉               | 5073/8000 [43:03<28:13,  1.73it/s]

  ➤ Predicted: 0.8637, Target: 1.0, Loss: 0.6488


 64%|██████████████████████████               | 5095/8000 [43:17<21:01,  2.30it/s]

  ➤ Predicted: 0.4623, Target: 0.0, Loss: 0.6500


 64%|██████████████████████████▏              | 5103/8000 [43:24<46:43,  1.03it/s]

  ➤ Predicted: 0.2271, Target: 0.0, Loss: 0.6505


 64%|██████████████████████████▎              | 5123/8000 [43:48<49:05,  1.02s/it]

  ➤ Predicted: 0.1176, Target: 0.0, Loss: 0.6515


 64%|██████████████████████████▎              | 5133/8000 [43:59<54:40,  1.14s/it]

  ➤ Predicted: 0.5573, Target: 0.0, Loss: 0.6514


 64%|██████████████████████████▍              | 5153/8000 [44:20<55:08,  1.16s/it]

  ➤ Predicted: 0.8234, Target: 0.0, Loss: 0.6520


 65%|██████████████████████████▌              | 5193/8000 [44:47<36:18,  1.29it/s]

  ➤ Predicted: 0.4526, Target: 1.0, Loss: 0.6526


 65%|██████████████████████████▋              | 5203/8000 [44:54<32:05,  1.45it/s]

  ➤ Predicted: 0.2334, Target: 1.0, Loss: 0.6528


 65%|██████████████████████████▋              | 5215/8000 [45:04<30:30,  1.52it/s]

  ➤ Predicted: 0.4599, Target: 1.0, Loss: 0.6520


 65%|██████████████████████████▊              | 5223/8000 [45:13<49:53,  1.08s/it]

  ➤ Predicted: 0.7288, Target: 1.0, Loss: 0.6525


 65%|██████████████████████████▊              | 5233/8000 [45:22<46:34,  1.01s/it]

  ➤ Predicted: 0.6889, Target: 1.0, Loss: 0.6531


 66%|█████████████████████████▌             | 5243/8000 [45:35<1:01:57,  1.35s/it]

  ➤ Predicted: 0.8170, Target: 1.0, Loss: 0.6541


 66%|█████████████████████████▌             | 5253/8000 [45:50<1:14:48,  1.63s/it]

  ➤ Predicted: 0.7705, Target: 0.0, Loss: 0.6539


 66%|███████████████████████████              | 5283/8000 [46:13<37:50,  1.20it/s]

  ➤ Predicted: 0.3706, Target: 0.0, Loss: 0.6547


 66%|███████████████████████████▏             | 5313/8000 [46:29<17:25,  2.57it/s]

  ➤ Predicted: 0.3284, Target: 1.0, Loss: 0.6545


 67%|███████████████████████████▍             | 5345/8000 [46:52<28:35,  1.55it/s]

  ➤ Predicted: 0.8715, Target: 1.0, Loss: 0.6552


 67%|███████████████████████████▍             | 5355/8000 [46:59<26:24,  1.67it/s]

  ➤ Predicted: 0.5156, Target: 0.0, Loss: 0.6556


 67%|███████████████████████████▍             | 5363/8000 [47:07<38:27,  1.14it/s]

  ➤ Predicted: 0.2745, Target: 0.0, Loss: 0.6553


 67%|███████████████████████████▌             | 5373/8000 [47:15<42:07,  1.04it/s]

  ➤ Predicted: 0.9062, Target: 1.0, Loss: 0.6549


 67%|███████████████████████████▌             | 5383/8000 [47:22<35:24,  1.23it/s]

  ➤ Predicted: 0.6802, Target: 1.0, Loss: 0.6546


 67%|███████████████████████████▋             | 5393/8000 [47:29<28:06,  1.55it/s]

  ➤ Predicted: 0.3528, Target: 1.0, Loss: 0.6541


 68%|███████████████████████████▋             | 5405/8000 [47:39<27:50,  1.55it/s]

  ➤ Predicted: 0.4861, Target: 1.0, Loss: 0.6540


 68%|███████████████████████████▉             | 5453/8000 [48:01<11:52,  3.57it/s]

Error on sample 18028: NaN in input


 69%|████████████████████████████▏            | 5503/8000 [48:27<32:07,  1.30it/s]

  ➤ Predicted: 0.7504, Target: 1.0, Loss: 0.6554


 69%|████████████████████████████▎            | 5513/8000 [48:39<50:54,  1.23s/it]

  ➤ Predicted: 0.4256, Target: 0.0, Loss: 0.6558


 69%|████████████████████████████▎            | 5523/8000 [48:50<52:45,  1.28s/it]

  ➤ Predicted: 0.1175, Target: 0.0, Loss: 0.6566


 69%|████████████████████████████▎            | 5535/8000 [49:01<22:55,  1.79it/s]

  ➤ Predicted: 0.5215, Target: 0.0, Loss: 0.6573


 69%|████████████████████████████▍            | 5543/8000 [49:07<30:54,  1.32it/s]

  ➤ Predicted: 0.8109, Target: 0.0, Loss: 0.6568


 70%|████████████████████████████▌            | 5583/8000 [49:24<26:33,  1.52it/s]

  ➤ Predicted: 0.3773, Target: 0.0, Loss: 0.6568


 70%|████████████████████████████▋            | 5603/8000 [49:34<23:09,  1.73it/s]

  ➤ Predicted: 0.4160, Target: 0.0, Loss: 0.6571


 70%|████████████████████████████▊            | 5633/8000 [50:02<44:18,  1.12s/it]

  ➤ Predicted: 0.1706, Target: 1.0, Loss: 0.6573


 71%|█████████████████████████████            | 5673/8000 [50:24<27:46,  1.40it/s]

  ➤ Predicted: 0.2357, Target: 0.0, Loss: 0.6571


 71%|█████████████████████████████▏           | 5685/8000 [50:33<18:38,  2.07it/s]

  ➤ Predicted: 0.5605, Target: 1.0, Loss: 0.6569


 71%|█████████████████████████████▏           | 5693/8000 [50:40<41:57,  1.09s/it]

  ➤ Predicted: 0.6618, Target: 0.0, Loss: 0.6562


 72%|█████████████████████████████▎           | 5722/8000 [50:53<06:15,  6.07it/s]

  ➤ Predicted: 0.6749, Target: 0.0, Loss: 0.6560


 72%|█████████████████████████████▍           | 5736/8000 [51:02<14:47,  2.55it/s]

  ➤ Predicted: 0.3811, Target: 1.0, Loss: 0.6560


 72%|█████████████████████████████▍           | 5753/8000 [51:18<44:23,  1.19s/it]

  ➤ Predicted: 0.4506, Target: 0.0, Loss: 0.6560


 72%|█████████████████████████████▋           | 5793/8000 [51:42<24:13,  1.52it/s]

  ➤ Predicted: 0.3472, Target: 0.0, Loss: 0.6561


 73%|█████████████████████████████▊           | 5813/8000 [51:57<22:49,  1.60it/s]

  ➤ Predicted: 0.3741, Target: 0.0, Loss: 0.6557


 73%|█████████████████████████████▊           | 5823/8000 [52:03<22:03,  1.65it/s]

  ➤ Predicted: 0.4363, Target: 1.0, Loss: 0.6558


 75%|██████████████████████████████▋          | 5976/8000 [52:37<05:14,  6.43it/s]

  ➤ Predicted: 0.5340, Target: 1.0, Loss: 0.6557


 75%|██████████████████████████████▋          | 5983/8000 [52:43<23:31,  1.43it/s]

  ➤ Predicted: 0.5089, Target: 1.0, Loss: 0.6562


 75%|██████████████████████████████▉          | 6033/8000 [53:13<32:13,  1.02it/s]

  ➤ Predicted: 0.6295, Target: 1.0, Loss: 0.6557


 76%|██████████████████████████████▉          | 6043/8000 [53:21<30:56,  1.05it/s]

  ➤ Predicted: 0.6026, Target: 0.0, Loss: 0.6557


 76%|███████████████████████████████          | 6053/8000 [53:33<39:33,  1.22s/it]

  ➤ Predicted: 0.7190, Target: 1.0, Loss: 0.6564


 76%|███████████████████████████████          | 6073/8000 [53:50<41:16,  1.29s/it]

  ➤ Predicted: 0.9486, Target: 1.0, Loss: 0.6573


 76%|███████████████████████████████▏         | 6095/8000 [54:15<23:53,  1.33it/s]

  ➤ Predicted: 0.5371, Target: 0.0, Loss: 0.6580


 76%|███████████████████████████████▎         | 6105/8000 [54:19<12:08,  2.60it/s]

  ➤ Predicted: 0.2879, Target: 0.0, Loss: 0.6580


 77%|███████████████████████████████▍         | 6133/8000 [54:38<28:15,  1.10it/s]

  ➤ Predicted: 0.1419, Target: 1.0, Loss: 0.6590


 77%|███████████████████████████████▍         | 6142/8000 [54:39<06:42,  4.62it/s]

  ➤ Predicted: 0.3649, Target: 0.0, Loss: 0.6592


 77%|███████████████████████████████▌         | 6165/8000 [54:57<19:00,  1.61it/s]

  ➤ Predicted: 0.2544, Target: 0.0, Loss: 0.6585


 77%|███████████████████████████████▋         | 6173/8000 [55:06<31:49,  1.05s/it]

  ➤ Predicted: 0.3113, Target: 0.0, Loss: 0.6587


 77%|██████████████████████████████▏        | 6183/8000 [55:24<1:08:54,  2.28s/it]

  ➤ Predicted: 0.5090, Target: 0.0, Loss: 0.6589


 77%|███████████████████████████████▋         | 6193/8000 [55:34<28:56,  1.04it/s]

  ➤ Predicted: 0.6230, Target: 0.0, Loss: 0.6594


 78%|███████████████████████████████▊         | 6215/8000 [55:48<15:09,  1.96it/s]

  ➤ Predicted: 0.7868, Target: 0.0, Loss: 0.6587


 78%|████████████████████████████████         | 6253/8000 [56:11<23:38,  1.23it/s]

  ➤ Predicted: 0.5045, Target: 0.0, Loss: 0.6586


 78%|████████████████████████████████         | 6263/8000 [56:18<22:53,  1.26it/s]

  ➤ Predicted: 0.5131, Target: 1.0, Loss: 0.6581


 79%|████████████████████████████████▏        | 6283/8000 [56:35<22:23,  1.28it/s]

  ➤ Predicted: 0.6870, Target: 1.0, Loss: 0.6582


 79%|████████████████████████████████▎        | 6293/8000 [56:47<35:05,  1.23s/it]

  ➤ Predicted: 0.7814, Target: 1.0, Loss: 0.6582


 79%|████████████████████████████████▎        | 6315/8000 [56:59<14:09,  1.98it/s]

  ➤ Predicted: 0.8290, Target: 1.0, Loss: 0.6575


 79%|████████████████████████████████▍        | 6333/8000 [57:13<23:45,  1.17it/s]

  ➤ Predicted: 0.5385, Target: 1.0, Loss: 0.6583


 79%|████████████████████████████████▌        | 6343/8000 [57:23<33:11,  1.20s/it]

  ➤ Predicted: 0.8103, Target: 1.0, Loss: 0.6589


 80%|████████████████████████████████▊        | 6395/8000 [57:42<11:31,  2.32it/s]

  ➤ Predicted: 0.3971, Target: 0.0, Loss: 0.6577


 80%|████████████████████████████████▉        | 6415/8000 [57:53<12:29,  2.11it/s]

  ➤ Predicted: 0.5115, Target: 1.0, Loss: 0.6582


 80%|████████████████████████████████▉        | 6423/8000 [58:00<22:59,  1.14it/s]

  ➤ Predicted: 0.7550, Target: 0.0, Loss: 0.6584


 81%|█████████████████████████████████        | 6445/8000 [58:14<15:50,  1.64it/s]

  ➤ Predicted: 0.5625, Target: 1.0, Loss: 0.6580


 81%|█████████████████████████████████        | 6453/8000 [58:21<23:54,  1.08it/s]

  ➤ Predicted: 0.6517, Target: 0.0, Loss: 0.6579


 81%|█████████████████████████████████        | 6463/8000 [58:29<18:09,  1.41it/s]

  ➤ Predicted: 0.4078, Target: 1.0, Loss: 0.6576


 81%|█████████████████████████████████▏       | 6473/8000 [58:40<28:54,  1.14s/it]

  ➤ Predicted: 0.6500, Target: 1.0, Loss: 0.6579


 81%|█████████████████████████████████▏       | 6483/8000 [58:49<22:05,  1.14it/s]

  ➤ Predicted: 0.1879, Target: 0.0, Loss: 0.6582


 81%|█████████████████████████████████▎       | 6495/8000 [59:01<17:05,  1.47it/s]

  ➤ Predicted: 0.5408, Target: 1.0, Loss: 0.6586


 81%|█████████████████████████████████▎       | 6503/8000 [59:12<32:09,  1.29s/it]

  ➤ Predicted: 0.0711, Target: 0.0, Loss: 0.6591


 81%|█████████████████████████████████▍       | 6513/8000 [59:22<29:34,  1.19s/it]

  ➤ Predicted: 0.3884, Target: 0.0, Loss: 0.6600


 82%|█████████████████████████████████▍       | 6533/8000 [59:42<27:17,  1.12s/it]

  ➤ Predicted: 0.5933, Target: 1.0, Loss: 0.6606


 82%|█████████████████████████████████▌       | 6543/8000 [59:54<25:29,  1.05s/it]

  ➤ Predicted: 0.7641, Target: 1.0, Loss: 0.6612


 82%|███████████████████████████████▉       | 6553/8000 [1:00:06<31:49,  1.32s/it]

  ➤ Predicted: 0.6613, Target: 1.0, Loss: 0.6603


 82%|███████████████████████████████▉       | 6563/8000 [1:00:18<29:06,  1.22s/it]

  ➤ Predicted: 0.5737, Target: 1.0, Loss: 0.6609


 82%|████████████████████████████████       | 6583/8000 [1:00:34<26:11,  1.11s/it]

  ➤ Predicted: 0.2875, Target: 0.0, Loss: 0.6619


 83%|████████████████████████████████▏      | 6605/8000 [1:00:47<12:10,  1.91it/s]

  ➤ Predicted: 0.2951, Target: 0.0, Loss: 0.6619


 83%|████████████████████████████████▎      | 6623/8000 [1:01:02<21:48,  1.05it/s]

  ➤ Predicted: 0.5008, Target: 1.0, Loss: 0.6615


 83%|████████████████████████████████▍      | 6643/8000 [1:01:13<12:31,  1.81it/s]

  ➤ Predicted: 0.4696, Target: 0.0, Loss: 0.6617


 83%|████████████████████████████████▍      | 6662/8000 [1:01:25<10:14,  2.18it/s]

  ➤ Predicted: 0.1468, Target: 0.0, Loss: 0.6623


 84%|████████████████████████████████▌      | 6683/8000 [1:01:40<17:37,  1.25it/s]

  ➤ Predicted: 0.3813, Target: 0.0, Loss: 0.6628


 84%|████████████████████████████████▋      | 6693/8000 [1:01:46<11:43,  1.86it/s]

  ➤ Predicted: 0.6262, Target: 1.0, Loss: 0.6629


 84%|████████████████████████████████▋      | 6713/8000 [1:01:56<14:52,  1.44it/s]

  ➤ Predicted: 0.6693, Target: 1.0, Loss: 0.6633


 84%|████████████████████████████████▊      | 6723/8000 [1:02:07<21:28,  1.01s/it]

  ➤ Predicted: 0.1729, Target: 0.0, Loss: 0.6632


 84%|████████████████████████████████▊      | 6733/8000 [1:02:19<22:10,  1.05s/it]

  ➤ Predicted: 0.1687, Target: 0.0, Loss: 0.6638


 84%|████████████████████████████████▉      | 6755/8000 [1:02:36<08:09,  2.55it/s]

  ➤ Predicted: 0.5619, Target: 0.0, Loss: 0.6638


 85%|████████████████████████████████▉      | 6765/8000 [1:02:42<10:16,  2.00it/s]

  ➤ Predicted: 0.3874, Target: 1.0, Loss: 0.6639


 85%|█████████████████████████████████      | 6785/8000 [1:02:55<09:12,  2.20it/s]

  ➤ Predicted: 0.7941, Target: 1.0, Loss: 0.6633


 85%|█████████████████████████████████      | 6793/8000 [1:03:02<16:07,  1.25it/s]

  ➤ Predicted: 0.9636, Target: 1.0, Loss: 0.6629


 85%|█████████████████████████████████▏     | 6815/8000 [1:03:19<11:40,  1.69it/s]

  ➤ Predicted: 0.1702, Target: 1.0, Loss: 0.6620


 85%|█████████████████████████████████▎     | 6823/8000 [1:03:30<25:38,  1.31s/it]

  ➤ Predicted: 0.4470, Target: 1.0, Loss: 0.6622


 85%|█████████████████████████████████▎     | 6833/8000 [1:03:40<21:37,  1.11s/it]

  ➤ Predicted: 0.6386, Target: 0.0, Loss: 0.6621


 86%|█████████████████████████████████▎     | 6843/8000 [1:03:53<23:39,  1.23s/it]

  ➤ Predicted: 0.8695, Target: 1.0, Loss: 0.6612


 86%|█████████████████████████████████▍     | 6853/8000 [1:04:04<20:03,  1.05s/it]

  ➤ Predicted: 0.6780, Target: 1.0, Loss: 0.6614


 86%|█████████████████████████████████▍     | 6866/8000 [1:04:11<07:15,  2.60it/s]

  ➤ Predicted: 0.5882, Target: 1.0, Loss: 0.6619


 86%|█████████████████████████████████▌     | 6885/8000 [1:04:20<10:01,  1.86it/s]

  ➤ Predicted: 0.7080, Target: 1.0, Loss: 0.6619


 86%|█████████████████████████████████▌     | 6893/8000 [1:04:24<11:50,  1.56it/s]

  ➤ Predicted: 0.3632, Target: 0.0, Loss: 0.6622


 86%|█████████████████████████████████▋     | 6916/8000 [1:04:37<06:14,  2.89it/s]

  ➤ Predicted: 0.4440, Target: 0.0, Loss: 0.6616


 87%|█████████████████████████████████▊     | 6933/8000 [1:04:47<11:54,  1.49it/s]

  ➤ Predicted: 0.6914, Target: 1.0, Loss: 0.6614


 87%|█████████████████████████████████▉     | 6974/8000 [1:05:08<06:40,  2.56it/s]

  ➤ Predicted: 0.6345, Target: 0.0, Loss: 0.6619


 87%|██████████████████████████████████     | 6983/8000 [1:05:15<13:36,  1.25it/s]

  ➤ Predicted: 0.3327, Target: 1.0, Loss: 0.6618


 87%|██████████████████████████████████     | 6993/8000 [1:05:22<11:01,  1.52it/s]

  ➤ Predicted: 0.6808, Target: 0.0, Loss: 0.6618


 88%|██████████████████████████████████▍    | 7055/8000 [1:06:05<04:26,  3.55it/s]

  ➤ Predicted: 0.2961, Target: 1.0, Loss: 0.6626


 88%|██████████████████████████████████▍    | 7063/8000 [1:06:10<08:58,  1.74it/s]

  ➤ Predicted: 0.6932, Target: 1.0, Loss: 0.6627


 89%|██████████████████████████████████▋    | 7105/8000 [1:06:29<04:06,  3.63it/s]

  ➤ Predicted: 0.4651, Target: 0.0, Loss: 0.6634


 90%|███████████████████████████████████    | 7183/8000 [1:06:51<06:35,  2.06it/s]

  ➤ Predicted: 0.6236, Target: 0.0, Loss: 0.6629


 91%|███████████████████████████████████▌   | 7294/8000 [1:07:26<05:40,  2.07it/s]

  ➤ Predicted: 0.2827, Target: 1.0, Loss: 0.6622


 91%|███████████████████████████████████▌   | 7303/8000 [1:07:35<10:58,  1.06it/s]

  ➤ Predicted: 0.2592, Target: 0.0, Loss: 0.6623


 91%|███████████████████████████████████▋   | 7313/8000 [1:07:44<09:39,  1.19it/s]

  ➤ Predicted: 0.3922, Target: 1.0, Loss: 0.6623


 92%|███████████████████████████████████▋   | 7325/8000 [1:07:57<07:43,  1.46it/s]

  ➤ Predicted: 0.3954, Target: 0.0, Loss: 0.6625


 92%|███████████████████████████████████▊   | 7335/8000 [1:08:00<04:50,  2.29it/s]

  ➤ Predicted: 0.3523, Target: 1.0, Loss: 0.6624


 92%|███████████████████████████████████▊   | 7343/8000 [1:08:08<11:03,  1.01s/it]

  ➤ Predicted: 0.3129, Target: 0.0, Loss: 0.6626


 92%|███████████████████████████████████▊   | 7353/8000 [1:08:16<09:46,  1.10it/s]

  ➤ Predicted: 0.2181, Target: 0.0, Loss: 0.6625


 92%|███████████████████████████████████▉   | 7373/8000 [1:08:37<14:21,  1.37s/it]

  ➤ Predicted: 0.5344, Target: 0.0, Loss: 0.6622


 92%|████████████████████████████████████   | 7395/8000 [1:08:48<05:28,  1.84it/s]

  ➤ Predicted: 0.6495, Target: 1.0, Loss: 0.6626


 93%|████████████████████████████████████   | 7403/8000 [1:08:51<05:04,  1.96it/s]

  ➤ Predicted: 0.6275, Target: 0.0, Loss: 0.6622


 93%|████████████████████████████████████▏  | 7415/8000 [1:08:57<03:26,  2.83it/s]

  ➤ Predicted: 0.3943, Target: 0.0, Loss: 0.6623


 93%|████████████████████████████████████▏  | 7423/8000 [1:09:03<06:35,  1.46it/s]

  ➤ Predicted: 0.7897, Target: 1.0, Loss: 0.6621


 93%|████████████████████████████████████▎  | 7445/8000 [1:09:22<04:43,  1.96it/s]

  ➤ Predicted: 0.6430, Target: 1.0, Loss: 0.6616


 93%|████████████████████████████████████▍  | 7463/8000 [1:09:35<08:40,  1.03it/s]

  ➤ Predicted: 0.2378, Target: 0.0, Loss: 0.6616


 93%|████████████████████████████████████▍  | 7473/8000 [1:09:40<05:34,  1.58it/s]

  ➤ Predicted: 0.3598, Target: 1.0, Loss: 0.6611


 94%|████████████████████████████████████▍  | 7483/8000 [1:09:49<07:18,  1.18it/s]

  ➤ Predicted: 0.5561, Target: 0.0, Loss: 0.6611


 94%|████████████████████████████████████▌  | 7493/8000 [1:10:02<10:17,  1.22s/it]

  ➤ Predicted: 0.5578, Target: 1.0, Loss: 0.6611


 94%|████████████████████████████████████▋  | 7513/8000 [1:10:18<07:39,  1.06it/s]

  ➤ Predicted: 0.3862, Target: 0.0, Loss: 0.6610


 94%|████████████████████████████████████▋  | 7533/8000 [1:10:39<10:51,  1.40s/it]

  ➤ Predicted: 0.2609, Target: 1.0, Loss: 0.6607


 94%|████████████████████████████████████▊  | 7543/8000 [1:10:52<09:22,  1.23s/it]

  ➤ Predicted: 0.5587, Target: 1.0, Loss: 0.6607


 94%|████████████████████████████████████▊  | 7555/8000 [1:11:00<03:44,  1.98it/s]

  ➤ Predicted: 0.4313, Target: 0.0, Loss: 0.6609


 95%|████████████████████████████████████▉  | 7565/8000 [1:11:03<03:02,  2.38it/s]

  ➤ Predicted: 0.5223, Target: 1.0, Loss: 0.6610


 95%|████████████████████████████████████▉  | 7583/8000 [1:11:14<06:30,  1.07it/s]

  ➤ Predicted: 0.6090, Target: 0.0, Loss: 0.6609


 95%|█████████████████████████████████████▏ | 7623/8000 [1:11:40<05:04,  1.24it/s]

  ➤ Predicted: 0.4903, Target: 0.0, Loss: 0.6609


 96%|█████████████████████████████████████▍ | 7683/8000 [1:12:23<03:52,  1.36it/s]

  ➤ Predicted: 0.4734, Target: 1.0, Loss: 0.6605


 96%|█████████████████████████████████████▌ | 7693/8000 [1:12:36<07:14,  1.42s/it]

  ➤ Predicted: 0.2645, Target: 0.0, Loss: 0.6602


 96%|█████████████████████████████████████▌ | 7703/8000 [1:12:45<03:38,  1.36it/s]

  ➤ Predicted: 0.5501, Target: 1.0, Loss: 0.6602


 96%|█████████████████████████████████████▌ | 7713/8000 [1:12:59<07:11,  1.50s/it]

  ➤ Predicted: 0.0914, Target: 0.0, Loss: 0.6605


 97%|█████████████████████████████████████▉ | 7775/8000 [1:13:28<01:42,  2.19it/s]

  ➤ Predicted: 0.5071, Target: 0.0, Loss: 0.6606


 97%|█████████████████████████████████████▉ | 7793/8000 [1:13:39<03:15,  1.06it/s]

  ➤ Predicted: 0.2445, Target: 0.0, Loss: 0.6610


 98%|██████████████████████████████████████▏| 7825/8000 [1:14:09<02:20,  1.25it/s]

  ➤ Predicted: 0.3626, Target: 1.0, Loss: 0.6610


 99%|██████████████████████████████████████▍| 7893/8000 [1:14:57<02:06,  1.18s/it]

  ➤ Predicted: 0.4937, Target: 0.0, Loss: 0.6617


 99%|██████████████████████████████████████▌| 7903/8000 [1:15:04<01:11,  1.36it/s]

  ➤ Predicted: 0.1192, Target: 0.0, Loss: 0.6618


 99%|██████████████████████████████████████▋| 7943/8000 [1:15:43<00:55,  1.02it/s]

  ➤ Predicted: 0.4964, Target: 0.0, Loss: 0.6626


 99%|██████████████████████████████████████▊| 7953/8000 [1:15:53<00:50,  1.07s/it]

  ➤ Predicted: 0.4556, Target: 1.0, Loss: 0.6624


100%|██████████████████████████████████████▊| 7963/8000 [1:16:01<00:27,  1.34it/s]

  ➤ Predicted: 0.3416, Target: 1.0, Loss: 0.6620


100%|██████████████████████████████████████▉| 7983/8000 [1:16:20<00:20,  1.22s/it]

  ➤ Predicted: 0.2690, Target: 1.0, Loss: 0.6613


100%|███████████████████████████████████████| 8000/8000 [1:16:40<00:00,  1.74it/s]

✅ Epoch 1 — Trained on 3837 samples — Avg Loss: 0.6407





In [44]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from tqdm import tqdm

def evaluate_model_samplewise(model, data):
    total_loss = 0.0
    y_true = []
    y_pred = []
    sample_count = 0

    for i, sample in tqdm(data.iterrows(), total=len(data)):
        try:
            dense_input, ts1, ts2, tab1, tab2, tab3 = single_data_point(sample)

            if tab1.shape[0] == 0 or tab2.shape[0] == 0 or tab3.shape[0] == 0:
                continue

            # Format inputs
            ts1 = np.expand_dims(np.array(ts1, dtype=np.float32), axis=0)
            ts2 = np.expand_dims(np.array(ts2, dtype=np.float32), axis=0)
            tab1 = np.expand_dims(np.array(tab1, dtype=np.float32), axis=0)
            tab2 = np.expand_dims(np.array(tab2, dtype=np.float32), axis=0)
            tab3 = np.expand_dims(np.array(tab3, dtype=np.float32), axis=0)

            dense_input_array = np.array(dense_input, dtype=np.float32)
            target_value = float(dense_input_array[-1])  # scalar
            dense_input_trimmed = np.expand_dims(dense_input_array[:-1], axis=0)
            target = np.array([[target_value]], dtype=np.float32)

            inputs = [ts1, ts2, tab1, tab2, tab3, dense_input_trimmed]

            # Prediction
            pred = model.predict(inputs, verbose=0)[0][0]  # scalar

            # Record actual and predicted labels
            y_true.append(int(target_value))
            y_pred.append(int(pred >= 0.5))  # threshold at 0.5

            # Evaluate sample loss
            loss = model.evaluate(inputs, target, verbose=0)
            loss_value = loss[0] if isinstance(loss, (list, tuple)) else loss
            total_loss += loss_value
            sample_count += 1

        except Exception as e:
            print(f"Error on sample {i}: {e}")
            continue

    # Compute metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    avg_loss = total_loss / max(sample_count, 1)

    print("\n📊 Evaluation Metrics:")
    print(f"Samples evaluated: {sample_count}")
    print(f"Average Loss      : {avg_loss:.4f}")
    print(f"Accuracy          : {accuracy:.4f}")
    print(f"Precision         : {precision:.4f}")
    print(f"Recall            : {recall:.4f}")
    print(f"F1 Score          : {f1:.4f}")

    # Optionally return metrics
    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }


In [45]:
evaluate_model_samplewise(model, df_current_match_test)

100%|█████████████████████████████████████████| 4410/4410 [07:41<00:00,  9.56it/s]


📊 Evaluation Metrics:
Samples evaluated: 1197
Average Loss      : nan
Accuracy          : 0.6358
Precision         : 0.6655
Recall            : 0.6190
F1 Score          : 0.6414





{'loss': nan,
 'accuracy': 0.6357560568086884,
 'precision': 0.6655290102389079,
 'recall': 0.6190476190476191,
 'f1': 0.6414473684210527}