In [1]:
!pip install -q pytorch_tabular --find-links=/kaggle/input/pytorch-tabular --no-index

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 0.22.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.10.0, but you have google-cloud-bigquery 2.34.4 which is incompatible.
bigframes 0.22.0 requires google-cloud-storage>=2.0.0, but you have google-cloud-storage 1.44.0 which is incompatible.
bigframes 0.22.0 requires pandas<2.1.4,>=1.5.0, but you have pandas 2.2.2 which is incompatible.
cesium 0.12.3 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.
dataproc-jupyter-plugin 0.1.79 requires pydantic~=1.10.0, but you have pydantic 2.9.2 which is incompatible.[0m[31m
[0m

# About PyTorch Tabular
PyTorch Tabular is a high-level library built on top of PyTorch and PyTorch Lightning that aims to make deep learning with tabular data accessible and efficient. It provides a range of tabular neural-network models, such as the Category Embedding Model, TabNet, and others. This notebook demonstrates how to use PyTorch Tabular in the "UM - Game-Playing Strength of MCTS Variants" competition.


docs: https://pytorch-tabular.readthedocs.io/en/latest/ \
github: https://github.com/manujosephv/pytorch_tabular/tree/main

# Import

In [2]:
import glob
import os

import numpy as np
import pandas as pd
import polars as pl
import torch
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig, TabNetModelConfig, TabTransformerConfig, \
    GatedAdditiveTreeEnsembleConfig, GANDALFConfig, FTTransformerConfig, AutoIntConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from sklearn.model_selection import GroupKFold, StratifiedGroupKFold
from torch import nn

from sklearn.metrics import mean_squared_error
import random
import warnings; warnings.filterwarnings('ignore')

# Fix Seed

In [3]:
def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    Covers Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices). Without the PyTorch seeds, weight
    initialisation, dropout masks and DataLoader shuffling differ on every
    run even though NumPy and ``random`` are fixed.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call on CPU-only machines: the call is silently ignored when
    # CUDA is not available.
    torch.cuda.manual_seed_all(seed)


seed_everything(seed=2024)


# Code for Training

In [4]:

class CFG:
    """Run-wide settings read by the preprocessing and training code."""

    # --- data ---------------------------------------------------------
    # Location of the competition training CSV.
    train_path = '/kaggle/input/um-game-playing-strength-of-mcts-variants/train.csv'
    # When True, each agent string is split on "-" into separate columns.
    split_agent_features = True

    # --- cross-validation / optimisation ------------------------------
    # Number of cross-validation splits.
    folds = 5
    # Maximum epochs per fold; early-stopping patience is derived from it.
    epochs = 30
    batch_size = 128

    # --- model selection / output -------------------------------------
    # Key passed to get_model_config(); unrecognised names select the
    # CategoryEmbedding model.
    model_name = "default"
    base_save_folder = "save_models"
    

# Regression target; passed to DataConfig as the target and kept out of the
# feature lists built in preprocess_data().
target_col = 'utility_agent1'
# Game identifier column; used as the grouping key for the CV split and
# excluded from the numeric feature list.
game_col = 'GameRulesetName'
# Agent description columns; cast to categorical and, when
# CFG.split_agent_features is set, split on "-" into one column per part.
agent_cols = ['agent1', 'agent2']

#%%
nan_columns = ['Behaviour', 'StateRepetition', 'Duration', 'Complexity', 'BoardCoverage', 'GameOutcome', 'StateEvaluation', 'Clarity', 'Decisiveness', 'Drama', 'MoveEvaluation', 'StateEvaluationDifference', 'BoardSitesOccupied', 'BranchingFactor', 'DecisionFactor', 'MoveDistance', 'PieceNumber', 'ScoreDifference']
zero_columns = ['Realtime', 'Simultaneous', 'HiddenInformation', 'Match', 'AsymmetricRules', 'AsymmetricPlayRules', 'AsymmetricEndRules', 'AsymmetricSetup', 'Simulation', 'Solitaire', 'Multiplayer', 'Coalition', 'Puzzle', 'DeductionPuzzle', 'PlanningPuzzle', 'PrismShape', 'ParallelogramShape', 'RectanglePyramidalShape', 'TargetShape', 'BrickTiling', 'CelticTiling', 'QuadHexTiling', 'Hints', 'DiceD3', 'BiasedDice', 'Card', 'Domino', 'SituationalTurnKo', 'SituationalSuperko', 'InitialAmount', 'InitialPot', 'BetDecision', 'BetDecisionFrequency', 'VoteDecisionFrequency', 'ChooseTrumpSuitDecision', 'ChooseTrumpSuitDecisionFrequency', 'LeapDecisionToFriend', 'LeapDecisionToFriendFrequency', 'HopDecisionEnemyToFriend', 'HopDecisionEnemyToFriendFrequency', 'HopDecisionFriendToFriend', 'FromToDecisionWithinBoard', 'FromToDecisionBetweenContainers', 'BetEffect', 'BetEffectFrequency', 'VoteEffectFrequency', 'SwapPlayersEffectFrequency', 'TakeControl', 'TakeControlFrequency', 'PassEffectFrequency', 'SetCost', 'SetCostFrequency', 'SetPhase', 'SetPhaseFrequency', 'SetTrumpSuit', 'SetTrumpSuitFrequency', 'StepEffectFrequency', 'SlideEffectFrequency', 'LeapEffectFrequency', 'HopEffectFrequency', 'FromToEffectFrequency', 'SwapPiecesEffect', 'SwapPiecesEffectFrequency', 'ShootEffect', 'ShootEffectFrequency', 'MaxCapture', 'OffDiagonalDirection', 'Information', 'HidePieceType', 'HidePieceOwner', 'HidePieceCount', 'HidePieceRotation', 'HidePieceValue', 'HidePieceState', 'InvisiblePiece', 'LineDrawFrequency', 'ConnectionDraw', 'ConnectionDrawFrequency', 'GroupLossFrequency', 'GroupDrawFrequency', 'LoopLossFrequency', 'LoopDraw', 'LoopDrawFrequency', 'PatternLoss', 'PatternLossFrequency', 'PatternDraw', 'PatternDrawFrequency', 'PathExtentEndFrequency', 'PathExtentWinFrequency', 'PathExtentLossFrequency', 'PathExtentDraw', 'PathExtentDrawFrequency', 'TerritoryLoss', 'TerritoryLossFrequency', 'TerritoryDraw', 'TerritoryDrawFrequency', 'CheckmateLoss', 'CheckmateLossFrequency', 
'CheckmateDraw', 'CheckmateDrawFrequency', 'NoTargetPieceLoss', 'NoTargetPieceLossFrequency', 'NoTargetPieceDraw', 'NoTargetPieceDrawFrequency', 'NoOwnPiecesDraw', 'NoOwnPiecesDrawFrequency', 'FillLoss', 'FillLossFrequency', 'FillDraw', 'FillDrawFrequency', 'ScoringDrawFrequency', 'NoProgressWin', 'NoProgressWinFrequency', 'NoProgressLoss', 'NoProgressLossFrequency', 'SolvedEnd', 'PositionalRepetition', 'SituationalRepetition', 'Narrowness', 'Variance', 'DecisivenessMoves', 'DecisivenessThreshold', 'LeadChange', 'Stability', 'DramaAverage', 'DramaMedian', 'DramaMaximum', 'DramaMinimum', 'DramaVariance', 'DramaChangeAverage', 'DramaChangeSign', 'DramaChangeLineBestFit', 'DramaChangeNumTimes', 'DramaMaxIncrease', 'DramaMaxDecrease', 'MoveEvaluationAverage', 'MoveEvaluationMedian', 'MoveEvaluationMaximum', 'MoveEvaluationMinimum', 'MoveEvaluationVariance', 'MoveEvaluationChangeAverage', 'MoveEvaluationChangeSign', 'MoveEvaluationChangeLineBestFit', 'MoveEvaluationChangeNumTimes', 'MoveEvaluationMaxIncrease', 'MoveEvaluationMaxDecrease', 'StateEvaluationDifferenceAverage', 'StateEvaluationDifferenceMedian', 'StateEvaluationDifferenceMaximum', 'StateEvaluationDifferenceMinimum', 'StateEvaluationDifferenceVariance', 'StateEvaluationDifferenceChangeAverage', 'StateEvaluationDifferenceChangeSign', 'StateEvaluationDifferenceChangeLineBestFit', 'StateEvaluationDifferenceChangeNumTimes', 'StateEvaluationDifferenceMaxIncrease', 'StateEvaluationDifferenceMaxDecrease', 'BoardSitesOccupiedMinimum', 'BranchingFactorMinimum', 'DecisionFactorMinimum', 'MoveDistanceMinimum', 'PieceNumberMinimum', 'ScoreDifferenceMinimum', 'ScoreDifferenceChangeNumTimes', 'Roots', 'Cosine', 'Sine', 'Tangent', 'Exponential', 'Logarithm', 'ExclusiveDisjunction', 'Float', 'HandComponent', 'SetHidden', 'SetInvisible', 'SetHiddenCount', 'SetHiddenRotation', 'SetHiddenState', 'SetHiddenValue', 'SetHiddenWhat', 'SetHiddenWho']
one_columns = ['Id', 'NumPlayers', 'Properties', 'Format', 'Time', 'Discrete', 'Turns', 'Alternating', 'Players', 'TwoPlayer', 'Equipment', 'Container', 'Board', 'PlayableSites', 'Component', 'Rules', 'Play', 'End']
frequency_columns = ['BetDecisionFrequency', 'VoteDecisionFrequency', 'SwapPlayersDecisionFrequency', 'ChooseTrumpSuitDecisionFrequency', 'PassDecisionFrequency', 'ProposeDecisionFrequency', 'AddDecisionFrequency', 'PromotionDecisionFrequency', 'RemoveDecisionFrequency', 'RotationDecisionFrequency', 'StepDecisionFrequency', 'StepDecisionToEmptyFrequency', 'StepDecisionToFriendFrequency', 'StepDecisionToEnemyFrequency', 'SlideDecisionFrequency', 'SlideDecisionToEmptyFrequency', 'SlideDecisionToEnemyFrequency', 'SlideDecisionToFriendFrequency', 'LeapDecisionFrequency', 'LeapDecisionToEmptyFrequency', 'LeapDecisionToFriendFrequency', 'LeapDecisionToEnemyFrequency', 'HopDecisionFrequency', 'HopDecisionMoreThanOneFrequency', 'HopDecisionEnemyToEmptyFrequency', 'HopDecisionFriendToEmptyFrequency', 'HopDecisionEnemyToFriendFrequency', 'HopDecisionFriendToFriendFrequency', 'HopDecisionEnemyToEnemyFrequency', 'HopDecisionFriendToEnemyFrequency', 'FromToDecisionFrequency', 'FromToDecisionWithinBoardFrequency', 'FromToDecisionBetweenContainersFrequency', 'FromToDecisionEmptyFrequency', 'FromToDecisionEnemyFrequency', 'FromToDecisionFriendFrequency', 'SwapPiecesDecisionFrequency', 'ShootDecisionFrequency', 'BetEffectFrequency', 'VoteEffectFrequency', 'SwapPlayersEffectFrequency', 'TakeControlFrequency', 'PassEffectFrequency', 'RollFrequency', 'ProposeEffectFrequency', 'AddEffectFrequency', 'SowFrequency', 'SowCaptureFrequency', 'SowRemoveFrequency', 'SowBacktrackingFrequency', 'PromotionEffectFrequency', 'RemoveEffectFrequency', 'PushEffectFrequency', 'FlipFrequency', 'SetNextPlayerFrequency', 'MoveAgainFrequency', 'SetValueFrequency', 'SetCountFrequency', 'SetCostFrequency', 'SetPhaseFrequency', 'SetTrumpSuitFrequency', 'SetRotationFrequency', 'StepEffectFrequency', 'SlideEffectFrequency', 'LeapEffectFrequency', 'HopEffectFrequency', 'FromToEffectFrequency', 'SwapPiecesEffectFrequency', 'ShootEffectFrequency', 'ReplacementCaptureFrequency', 'HopCaptureFrequency', 
'HopCaptureMoreThanOneFrequency', 'DirectionCaptureFrequency', 'EncloseCaptureFrequency', 'CustodialCaptureFrequency', 'InterveneCaptureFrequency', 'SurroundCaptureFrequency', 'CaptureSequenceFrequency', 'LineEndFrequency', 'LineWinFrequency', 'LineLossFrequency', 'LineDrawFrequency', 'ConnectionEndFrequency', 'ConnectionWinFrequency', 'ConnectionLossFrequency', 'ConnectionDrawFrequency', 'GroupEndFrequency', 'GroupWinFrequency', 'GroupLossFrequency', 'GroupDrawFrequency', 'LoopEndFrequency', 'LoopWinFrequency', 'LoopLossFrequency', 'LoopDrawFrequency', 'PatternEndFrequency', 'PatternWinFrequency', 'PatternLossFrequency', 'PatternDrawFrequency', 'PathExtentEndFrequency', 'PathExtentWinFrequency', 'PathExtentLossFrequency', 'PathExtentDrawFrequency', 'TerritoryEndFrequency', 'TerritoryWinFrequency', 'TerritoryLossFrequency', 'TerritoryDrawFrequency', 'CheckmateFrequency', 'CheckmateWinFrequency', 'CheckmateLossFrequency', 'CheckmateDrawFrequency', 'NoTargetPieceEndFrequency', 'NoTargetPieceWinFrequency', 'NoTargetPieceLossFrequency', 'NoTargetPieceDrawFrequency', 'EliminatePiecesEndFrequency', 'EliminatePiecesWinFrequency', 'EliminatePiecesLossFrequency', 'EliminatePiecesDrawFrequency', 'NoOwnPiecesEndFrequency', 'NoOwnPiecesWinFrequency', 'NoOwnPiecesLossFrequency', 'NoOwnPiecesDrawFrequency', 'FillEndFrequency', 'FillWinFrequency', 'FillLossFrequency', 'FillDrawFrequency', 'ReachEndFrequency', 'ReachWinFrequency', 'ReachLossFrequency', 'ReachDrawFrequency', 'ScoringEndFrequency', 'ScoringWinFrequency', 'ScoringLossFrequency', 'ScoringDrawFrequency', 'NoMovesEndFrequency', 'NoMovesWinFrequency', 'NoMovesLossFrequency', 'NoMovesDrawFrequency', 'NoProgressEndFrequency', 'NoProgressWinFrequency', 'NoProgressLossFrequency', 'NoProgressDrawFrequency', 'DrawFrequency']
component_columns = ['ComponentStyle', 'AnimalComponent', 'ChessComponent', 'KingComponent', 'QueenComponent', 'KnightComponent', 'RookComponent', 'BishopComponent', 'PawnComponent', 'FairyChessComponent', 'PloyComponent', 'ShogiComponent', 'XiangqiComponent', 'StrategoComponent', 'JanggiComponent', 'HandComponent', 'CheckersComponent', 'BallComponent', 'TaflComponent', 'DiscComponent', 'MarkerComponent']
rules_columns =  ['EnglishRules', 'LudRules', 'GameRulesetName']
# Correlation columns with correlation threshold 0.9
corr_columns = ['Asymmetric', 'AsymmetricForces', 'Cooperation', 'Shape', 'RegularShape', 'PolygonShape', 'CircleShape', 'SpiralShape', 'MancalaBoard', 'NumPlayableSitesOnBoard', 'SquarePyramidalShape', 'NumInnerSites', 'NumEdges', 'NumCells', 'NumOuterSites', 'NumTopSites', 'NumRightSites', 'Hand', 'NumVertices', 'PlayersWithDirections', 'Stochastic', 'NumComponentsType', 'OpeningContract', 'Repetition', 'NumStartComponentsBoard', 'NumStartComponentsHand', 'NumStartComponents', 'SwapOption', 'VoteDecision', 'StepDecision', 'LeapDecision', 'LeapDecisionToEmpty', 'MovesNonDecision', 'Dice', 'TrackLoop', 'Sow', 'SowWithEffect', 'SowProperties', 'SowOriginFirst', 'SetMove', 'PieceValue', 'PieceRotation', 'HopDecisionEnemyToEmpty', 'AutoMove', 'CanNotMove', 'SlideDecision', 'Track', 'Directions', 'RightwardDirection', 'RightwardsDirection', 'ForwardLeftDirection', 'BackwardLeftDirection', 'LineEnd', 'Connection', 'ConnectionEnd', 'Loop', 'PathExtent', 'PatternEnd', 'LoopLoss', 'PathExtentEnd', 'PathExtentWin', 'Territory', 'TerritoryEnd', 'Threat', 'Checkmate', 'NoPieceMover', 'NoOwnPiecesEnd', 'FillEnd', 'Scoring', 'ScoringEnd', 'NoMoves', 'ProgressCheck', 'NoProgressEnd', 'Completion', 'DurationTurns', 'BoardSitesOccupiedAverage', 'BoardSitesOccupiedChangeAverage', 'BranchingFactorAverage', 'BranchingFactorMaximum', 'BranchingFactorVariance', 'BranchingFactorMedian', 'DecisionFactorAverage', 'BranchingFactorChangeMaxDecrease', 'DecisionFactorMaximum', 'BranchingFactorChangeAverage', 'BranchingFactorChangeLineBestFit', 'BranchingFactorChangeSign', 'DecisionFactorChangeAverage', 'BranchingFactorChangeMaxIncrease', 'DecisionFactorVariance', 'MoveDistanceAverage', 'MoveDistanceChangeAverage', 'MoveDistanceMaximum', 'MoveDistanceMaxIncrease', 'PieceNumberAverage', 'PieceNumberMedian', 'BoardSitesOccupiedChangeSign', 'PieceNumberChangeAverage', 'BoardSitesOccupiedChangeNumTimes', 'ScoreDifferenceAverage', 'ScoreDifferenceMedian', 'ScoreDifferenceMaximum', 
'ScoreDifferenceVariance', 'ScoreDifferenceChangeAverage', 'ScoreDifferenceChangeLineBestFit', 'ScoreDifferenceMaxIncrease', 'Arithmetic', 'Visual', 'Vertex', 'BoardStyle', 'SowCCW', 'NumLayers', 'LeapDecisionToEnemy', 'KingComponent', 'KnightComponent', 'ChessComponent', 'StackType', 'StateType', 'PieceState', 'RememberValues', 'Implementation', 'ComplexityBalanceInteraction', 'LeftwardsDirection']
#corr_columns = [AsymmetricForces, AsymmetricPiecesType, Team, RegularShape, PolygonShape, Tiling, CircleTiling, SpiralTiling, TrackLoop, NumInnerSites, NumLayers, NumEdges, NumCells, NumVertices, NumPerimeterSites, NumBottomSites, NumLeftSites, NumContainers, NumPlayableSites, PieceDirection, Dice, NumComponentsTypePerPlayer, SwapOption, PositionalSuperko, NumStartComponentsBoardPerPlayer, NumStartComponentsHandPerPlayer, NumStartComponentsPerPlayer, SwapPlayersDecision, ProposeDecision, StepDecisionToEmpty, LeapDecisionToEmpty, LeapDecisionToEnemy, MovesEffects, Roll, Sow, SowWithEffect, SowProperties, SowOriginFirst, SowCCW, MoveAgain, SetValue, SetRotation, HopCapture, PathExtent, Threat, LineOfSight, Directions, AbsoluteDirections, LeftwardDirection, LeftwardsDirection, ForwardRightDirection, BackwardRightDirection, LineWin, ConnectionEnd, ConnectionWin, LoopEnd, LoopLoss, PatternWin, PathExtentEnd, PathExtentWin, PathExtentLoss, TerritoryEnd, TerritoryWin, Checkmate, CheckmateWin, NoOwnPiecesEnd, NoOwnPiecesWin, FillWin, ScoringEnd, ScoringWin, NoMovesEnd, NoProgressEnd, NoProgressDraw, Drawishness, Timeouts, BoardSitesOccupiedMedian, BoardSitesOccupiedChangeLineBestFit, BranchingFactorMedian, BranchingFactorVariance, BranchingFactorChangeMaxDecrease, DecisionFactorAverage, DecisionFactorMedian, DecisionFactorMaximum, DecisionFactorVariance, DecisionFactorChangeAverage, DecisionFactorChangeSign, DecisionFactorChangeLineBestFit, DecisionFactorMaxIncrease, DecisionFactorMaxDecrease, MoveDistanceMedian, MoveDistanceChangeLineBestFit, MoveDistanceMaxIncrease, MoveDistanceMaxDecrease, PieceNumberMedian, PieceNumberMaximum, PieceNumberChangeSign, PieceNumberChangeLineBestFit, PieceNumberChangeNumTimes, ScoreDifferenceMedian, ScoreDifferenceMaximum, ScoreDifferenceVariance, ScoreDifferenceChangeAverage, ScoreDifferenceChangeLineBestFit, ScoreDifferenceMaxIncrease, ScoreDifferenceMaxDecrease, Comparison, Style, BoardStyle, GraphStyle, MancalaStyle, ShibumiStyle, 
KnightComponent, RookComponent, PawnComponent, StateType, StackState, SiteState, ForgetValues, SetInternalCounter, Efficiency, EfficiencyPerPlayout]
# Per-outcome count columns for agent1; always part of the drop list.
output_cols = ['num_wins_agent1', 'num_draws_agent1', 'num_losses_agent1']
# Master list of column names removed by preprocess_data(). Names may repeat
# across the source lists; only names actually present in a frame are dropped.
# The list is extended further by the `+=` statement that follows.
dropped_cols = output_cols + nan_columns + zero_columns + one_columns + frequency_columns + component_columns + rules_columns + corr_columns

dropped_cols +=['Cooperation', 'Team', 'TriangleShape', 'DiamondShape', 'SpiralShape', 'StarShape', 'SquarePyramidalShape', 'SemiRegularTiling', 'CircleTiling', 'SpiralTiling', 'MancalaThreeRows', 'MancalaSixRows', 'MancalaCircular', 'AlquerqueBoardWithOneTriangle', 'AlquerqueBoardWithTwoTriangles', 'AlquerqueBoardWithFourTriangles', 'AlquerqueBoardWithEightTriangles', 'ThreeMensMorrisBoard', 'ThreeMensMorrisBoardWithTwoTriangles', 'NineMensMorrisBoard', 'StarBoard', 'PachisiBoard', 'Boardless', 'NumColumns', 'NumCorners', 'NumOffDiagonalDirections', 'NumLayers', 'NumCentreSites', 'NumConvexCorners', 'NumPhasesBoard', 'NumContainers', 'Piece', 'PieceValue', 'PieceRotation', 'PieceDirection', 'LargePiece', 'Tile', 'NumComponentsType', 'NumDice', 'OpeningContract', 'SwapOption', 'Repetition', 'TurnKo', 'PositionalSuperko', 'AutoMove', 'InitialRandomPlacement', 'InitialScore', 'InitialCost', 'Moves', 'VoteDecision', 'SwapPlayersDecision', 'SwapPlayersDecisionFrequency', 'ProposeDecision', 'ProposeDecisionFrequency', 'PromotionDecisionFrequency', 'RotationDecision', 'RotationDecisionFrequency', 'StepDecisionToFriend', 'StepDecisionToFriendFrequency', 'StepDecisionToEnemy', 'SlideDecisionToEnemy', 'SlideDecisionToEnemyFrequency', 'SlideDecisionToFriend', 'SlideDecisionToFriendFrequency', 'LeapDecision', 'LeapDecisionFrequency', 'LeapDecisionToEmpty', 'LeapDecisionToEmptyFrequency', 'LeapDecisionToEnemy', 'LeapDecisionToEnemyFrequency', 'HopDecisionFriendToEmpty', 'HopDecisionFriendToEmptyFrequency', 'HopDecisionFriendToFriendFrequency', 'HopDecisionEnemyToEnemy', 'HopDecisionEnemyToEnemyFrequency', 'HopDecisionFriendToEnemy', 'HopDecisionFriendToEnemyFrequency', 'FromToDecisionFrequency', 'FromToDecisionEnemy', 'FromToDecisionEnemyFrequency', 'FromToDecisionFriend', 'SwapPiecesDecision', 'SwapPiecesDecisionFrequency', 'ShootDecision', 'ShootDecisionFrequency', 'VoteEffect', 'SwapPlayersEffect', 'PassEffect', 'ProposeEffect', 'ProposeEffectFrequency', 
'AddEffectFrequency', 'SowFrequency', 'SowCapture', 'SowCaptureFrequency', 'SowRemove', 'SowBacktracking', 'SowBacktrackingFrequency', 'SowProperties', 'SowOriginFirst', 'SowCCW', 'PromotionEffectFrequency', 'PushEffect', 'PushEffectFrequency', 'Flip', 'FlipFrequency', 'SetNextPlayer', 'SetValue', 'SetValueFrequency', 'SetCount', 'SetCountFrequency', 'SetRotation', 'SetRotationFrequency', 'StepEffect', 'SlideEffect', 'LeapEffect', 'ByDieMove', 'MaxDistance', 'ReplacementCaptureFrequency', 'HopCaptureMoreThanOne', 'DirectionCapture', 'DirectionCaptureFrequency', 'EncloseCaptureFrequency', 'CustodialCapture', 'CustodialCaptureFrequency', 'InterveneCapture', 'InterveneCaptureFrequency', 'SurroundCapture', 'SurroundCaptureFrequency', 'CaptureSequence', 'CaptureSequenceFrequency', 'Group', 'Loop', 'Pattern', 'PathExtent', 'Territory', 'Fill', 'CanNotMove', 'Threat', 'CountPiecesMoverComparison', 'ProgressCheck', 'RotationalDirection', 'SameLayerDirection', 'ForwardDirection', 'BackwardDirection', 'BackwardsDirection', 'LeftwardDirection', 'RightwardsDirection', 'LeftwardsDirection', 'ForwardLeftDirection', 'ForwardRightDirection', 'BackwardLeftDirection', 'BackwardRightDirection', 'SameDirection', 'OppositeDirection', 'NumPlayPhase', 'LineLoss', 'LineLossFrequency', 'LineDraw', 'ConnectionEnd', 'ConnectionEndFrequency', 'ConnectionWinFrequency', 'ConnectionLoss', 'ConnectionLossFrequency', 'GroupEnd', 'GroupEndFrequency', 'GroupWin', 'GroupWinFrequency', 'GroupLoss', 'GroupDraw', 'LoopEnd', 'LoopEndFrequency', 'LoopWin', 'LoopWinFrequency', 'LoopLoss', 'PatternEnd', 'PatternEndFrequency', 'PatternWin', 'PatternWinFrequency', 'PathExtentEnd', 'PathExtentWin', 'PathExtentLoss', 'TerritoryEnd', 'TerritoryWin', 'TerritoryWinFrequency', 'Checkmate', 'CheckmateWin', 'NoTargetPieceEndFrequency', 'NoTargetPieceWin', 'NoTargetPieceWinFrequency', 'EliminatePiecesLoss', 'EliminatePiecesLossFrequency', 'EliminatePiecesDraw', 'EliminatePiecesDrawFrequency', 'NoOwnPiecesEnd', 
'NoOwnPiecesWin', 'NoOwnPiecesLoss', 'NoOwnPiecesLossFrequency', 'FillEnd', 'FillEndFrequency', 'FillWin', 'FillWinFrequency', 'ReachWin', 'ReachLoss', 'ReachLossFrequency', 'ReachDraw', 'ReachDrawFrequency', 'ScoringLoss', 'ScoringLossFrequency', 'ScoringDraw', 'NoMovesLoss', 'NoMovesDrawFrequency', 'NoProgressEnd', 'NoProgressEndFrequency', 'NoProgressDraw', 'NoProgressDrawFrequency', 'BoardCoverageFull', 'BoardSitesOccupiedChangeNumTimes', 'BranchingFactorChangeLineBestFit', 'BranchingFactorChangeNumTimesn', 'DecisionFactorChangeNumTimes', 'MoveDistanceChangeSign', 'MoveDistanceChangeLineBestFit', 'PieceNumberChangeNumTimes', 'PieceNumberMaxIncrease', 'ScoreDifferenceMedian', 'ScoreDifferenceVariance', 'ScoreDifferenceChangeAverage', 'ScoreDifferenceChangeSign', 'ScoreDifferenceChangeLineBestFit', 'Math', 'Division', 'Modulo', 'Absolute', 'Exponentiation', 'Minimum', 'Maximum', 'Even', 'Odd', 'Visual', 'GraphStyle', 'MancalaStyle', 'PenAndPaperStyle', 'ShibumiStyle', 'BackgammonStyle', 'JanggiStyle', 'XiangqiStyle', 'ShogiStyle', 'TableStyle', 'SurakartaStyle', 'NoBoard', 'ChessComponent', 'KingComponent', 'QueenComponent', 'KnightComponent', 'RookComponent', 'BishopComponent', 'PawnComponent', 'FairyChessComponent', 'PloyComponent', 'ShogiComponent', 'XiangqiComponent', 'StrategoComponent', 'JanggiComponent', 'TaflComponent', 'StackType', 'Stack', 'ShowPieceValue', 'ShowPieceState', 'Implementation', 'StateType', 'StackState', 'VisitedSites', 'InternalCounter', 'SetInternalCounter', 'Efficiency', 'NumOffDiagonalDirections_0.0', 'NumOffDiagonalDirections_4.82', 'NumOffDiagonalDirections_2.0', 'NumOffDiagonalDirections_5.18', 'NumOffDiagonalDirections_3.08', 'NumOffDiagonalDirections_0.06', 'NumLayers_1', 'NumLayers_0', 'NumLayers_4', 'NumLayers_5', 'NumPhasesBoard_1', 'NumPhasesBoard_5', 'NumDice_0', 'NumDice_2', 'NumDice_6', 'NumDice_3', 'NumDice_5', 'NumDice_7', 'ProposeDecisionFrequency_0.0', 'ProposeDecisionFrequency_0.05', 'ProposeDecisionFrequency_0.01', 
'PromotionDecisionFrequency_0.0', 'PromotionDecisionFrequency_0.01', 'PromotionDecisionFrequency_0.03', 'PromotionDecisionFrequency_0.02', 'PromotionDecisionFrequency_0.11', 'PromotionDecisionFrequency_0.05', 'PromotionDecisionFrequency_0.04', 'SlideDecisionToFriendFrequency_0.0', 'SlideDecisionToFriendFrequency_0.19', 'SlideDecisionToFriendFrequency_0.06', 'LeapDecisionToEnemyFrequency_0.0', 'LeapDecisionToEnemyFrequency_0.04', 'LeapDecisionToEnemyFrequency_0.01', 'LeapDecisionToEnemyFrequency_0.02', 'LeapDecisionToEnemyFrequency_0.07', 'LeapDecisionToEnemyFrequency_0.03', 'LeapDecisionToEnemyFrequency_0.14', 'LeapDecisionToEnemyFrequency_0.08', 'HopDecisionFriendToFriendFrequency_0.0', 'HopDecisionFriendToFriendFrequency_0.13', 'HopDecisionFriendToFriendFrequency_0.09', 'HopDecisionEnemyToEnemyFrequency_0.0', 'HopDecisionEnemyToEnemyFrequency_0.01', 'HopDecisionEnemyToEnemyFrequency_0.2', 'HopDecisionEnemyToEnemyFrequency_0.03', 'HopDecisionFriendToEnemyFrequency_0.0', 'HopDecisionFriendToEnemyFrequency_0.01', 'HopDecisionFriendToEnemyFrequency_0.09', 'HopDecisionFriendToEnemyFrequency_0.25', 'HopDecisionFriendToEnemyFrequency_0.02', 'FromToDecisionFrequency_0.0', 'FromToDecisionFrequency_0.38', 'FromToDecisionFrequency_1.0', 'FromToDecisionFrequency_0.31', 'FromToDecisionFrequency_0.94', 'FromToDecisionFrequency_0.67', 'ProposeEffectFrequency_0.0', 'ProposeEffectFrequency_0.01', 'ProposeEffectFrequency_0.03', 'PushEffectFrequency_0.0', 'PushEffectFrequency_0.5', 'PushEffectFrequency_0.96', 'PushEffectFrequency_0.25', 'FlipFrequency_0.0', 'FlipFrequency_0.87', 'FlipFrequency_1.0', 'FlipFrequency_0.96', 'SetCountFrequency_0.0', 'SetCountFrequency_0.62', 'SetCountFrequency_0.54', 'SetCountFrequency_0.02', 'DirectionCaptureFrequency_0.0', 'DirectionCaptureFrequency_0.55', 'DirectionCaptureFrequency_0.54', 'EncloseCaptureFrequency_0.0', 'EncloseCaptureFrequency_0.08', 'EncloseCaptureFrequency_0.1', 'EncloseCaptureFrequency_0.07', 'EncloseCaptureFrequency_0.12', 
'EncloseCaptureFrequency_0.02', 'EncloseCaptureFrequency_0.09', 'InterveneCaptureFrequency_0.0', 'InterveneCaptureFrequency_0.01', 'InterveneCaptureFrequency_0.14', 'InterveneCaptureFrequency_0.04', 'SurroundCaptureFrequency_0.0', 'SurroundCaptureFrequency_0.01', 'SurroundCaptureFrequency_0.03', 'SurroundCaptureFrequency_0.02', 'NumPlayPhase_3', 'NumPlayPhase_4', 'NumPlayPhase_5', 'NumPlayPhase_6', 'NumPlayPhase_7', 'NumPlayPhase_8', 'LineLossFrequency_0.0', 'LineLossFrequency_0.96', 'LineLossFrequency_0.87', 'LineLossFrequency_0.46', 'LineLossFrequency_0.26', 'LineLossFrequency_0.88', 'LineLossFrequency_0.94', 'ConnectionEndFrequency_0.0', 'ConnectionEndFrequency_0.19', 'ConnectionEndFrequency_1.0', 'ConnectionEndFrequency_0.23', 'ConnectionEndFrequency_0.94', 'ConnectionEndFrequency_0.35', 'ConnectionEndFrequency_0.97', 'ConnectionLossFrequency_0.0', 'ConnectionLossFrequency_0.54', 'ConnectionLossFrequency_0.78', 'GroupEndFrequency_0.0', 'GroupEndFrequency_1.0', 'GroupEndFrequency_0.11', 'GroupEndFrequency_0.79', 'GroupWinFrequency_0.0', 'GroupWinFrequency_0.11', 'GroupWinFrequency_1.0', 'LoopEndFrequency_0.0', 'LoopEndFrequency_0.14', 'LoopEndFrequency_0.66', 'LoopWinFrequency_0.0', 'LoopWinFrequency_0.14', 'LoopWinFrequency_0.66', 'PatternEndFrequency_0.0', 'PatternEndFrequency_0.63', 'PatternEndFrequency_0.35', 'PatternWinFrequency_0.0', 'PatternWinFrequency_0.63', 'PatternWinFrequency_0.35', 'NoTargetPieceWinFrequency_0.0', 'NoTargetPieceWinFrequency_0.72', 'NoTargetPieceWinFrequency_0.77', 'NoTargetPieceWinFrequency_0.95', 'NoTargetPieceWinFrequency_0.32', 'NoTargetPieceWinFrequency_1.0', 'EliminatePiecesLossFrequency_0.0', 'EliminatePiecesLossFrequency_0.85', 'EliminatePiecesLossFrequency_0.96', 'EliminatePiecesLossFrequency_0.68', 'EliminatePiecesDrawFrequency_0.0', 'EliminatePiecesDrawFrequency_0.03', 'EliminatePiecesDrawFrequency_0.91', 'EliminatePiecesDrawFrequency_1.0', 'EliminatePiecesDrawFrequency_0.36', 'EliminatePiecesDrawFrequency_0.86', 
'NoOwnPiecesLossFrequency_0.0', 'NoOwnPiecesLossFrequency_1.0', 'NoOwnPiecesLossFrequency_0.68', 'FillEndFrequency_0.0', 'FillEndFrequency_1.0', 'FillEndFrequency_0.04', 'FillEndFrequency_0.01', 'FillEndFrequency_0.99', 'FillEndFrequency_0.72', 'FillWinFrequency_0.0', 'FillWinFrequency_1.0', 'FillWinFrequency_0.04', 'FillWinFrequency_0.01', 'FillWinFrequency_0.99', 'ReachDrawFrequency_0.0', 'ReachDrawFrequency_0.9', 'ReachDrawFrequency_0.98', 'ScoringLossFrequency_0.0', 'ScoringLossFrequency_0.6', 'ScoringLossFrequency_0.62', 'NoMovesLossFrequency_0.0', 'NoMovesLossFrequency_1.0', 'NoMovesLossFrequency_0.13', 'NoMovesLossFrequency_0.06', 'NoMovesDrawFrequency_0.0', 'NoMovesDrawFrequency_0.01', 'NoMovesDrawFrequency_0.04', 'NoMovesDrawFrequency_0.03', 'NoMovesDrawFrequency_0.22', 'BoardSitesOccupiedChangeNumTimes_0.0', 'BoardSitesOccupiedChangeNumTimes_0.06', 'BoardSitesOccupiedChangeNumTimes_0.42', 'BoardSitesOccupiedChangeNumTimes_0.12', 'BoardSitesOccupiedChangeNumTimes_0.14', 'BoardSitesOccupiedChangeNumTimes_0.94', 'BranchingFactorChangeNumTimesn_0.0', 'BranchingFactorChangeNumTimesn_0.3', 'BranchingFactorChangeNumTimesn_0.02', 'BranchingFactorChangeNumTimesn_0.07', 'BranchingFactorChangeNumTimesn_0.04', 'BranchingFactorChangeNumTimesn_0.13', 'BranchingFactorChangeNumTimesn_0.01', 'BranchingFactorChangeNumTimesn_0.21', 'BranchingFactorChangeNumTimesn_0.03', 'PieceNumberChangeNumTimes_0.0', 'PieceNumberChangeNumTimes_0.06', 'PieceNumberChangeNumTimes_0.42', 'PieceNumberChangeNumTimes_0.12', 'PieceNumberChangeNumTimes_0.14', 'PieceNumberChangeNumTimes_1.0', 'KintsBoard', 'FortyStonesWithFourGapsBoard', 'Roll', 'SumDice', 'CheckmateFrequency', 'NumDice_4']


In [5]:

def preprocess_data(df):
    """Turn the raw polars frame into a model-ready pandas frame.

    Steps:
      1. Drop every column listed in ``dropped_cols`` that the frame contains.
      2. If ``CFG.split_agent_features`` is set, split each agent string on
         "-" into ``<agent>_<idx>`` columns and discard the first part.
      3. Cast agent columns to categorical and every other column except
         ``game_col`` to Float32 (this includes the target when present).
      4. Convert to pandas and add hand-crafted ratio / interaction features.
      5. Give every categorical column an explicit 'NA' level for missing
         values, plus an extra ``0`` level.

    Args:
        df: polars DataFrame with the raw competition columns. The columns
            referenced by the engineered features below must be present.

    Returns:
        Tuple ``(df, num_cols, cat_cols)``: the processed pandas DataFrame,
        the list of continuous feature names and the list of categorical
        feature names. Neither list contains ``target_col``.
    """
    # Only drop names that exist, so frames lacking some blacklisted columns
    # still work.
    df = df.drop(filter(lambda x: x in df.columns, dropped_cols))
    if CFG.split_agent_features:
        for col in agent_cols:
            df = df.with_columns(
                pl.col(col).str.split(by="-").list.to_struct(fields=lambda idx: f"{col}_{idx}")
            ).unnest(col).drop(f"{col}_0")
    # The first six characters ("agent1" / "agent2") identify both the
    # original and the split agent columns.
    df = df.with_columns([pl.col(col).cast(pl.Categorical) for col in df.columns if col[:6] in agent_cols])
    df = df.with_columns([pl.col(col).cast(pl.Float32) for col in df.columns if col[:6] not in agent_cols and col != game_col])
    df = df.to_pandas()

    # Engineered features. The tiny epsilon keeps zero denominators from
    # raising or producing NaN through 0/0.
    df['Playouts/Moves'] = df['PlayoutsPerSecond'] / (df['MovesPerSecond'] + 1e-15)
    df['EfficiencyPerPlayout'] = df['MovesPerSecond'] / (df['PlayoutsPerSecond'] + 1e-15)
    df['TurnsDurationEfficiency'] = df['DurationActions'] / (df['DurationTurnsStdDev'] + 1e-15)
    df['AdvantageBalanceRatio'] = df['AdvantageP1'] / (df['Balance'] + 1e-15)
    df['ActionTimeEfficiency'] = df['DurationActions'] / (df['MovesPerSecond'] + 1e-15)
    df['StandardizedTurnsEfficiency'] = df['DurationTurnsStdDev'] / (df['DurationActions'] + 1e-15)
    df['AdvantageTimeImpact'] = df['AdvantageP1'] / (df['DurationActions'] + 1e-15)
    df['DurationToComplexityRatio'] = df['DurationActions'] / (df['StateTreeComplexity'] + 1e-15)
    df['NormalizedGameTreeComplexity'] = df['GameTreeComplexity'] / (df['StateTreeComplexity'] + 1e-15)
    df['ComplexityBalanceInteraction'] = df['Balance'] * df['GameTreeComplexity']
    df['OverallComplexity'] = df['StateTreeComplexity'] + df['GameTreeComplexity']
    df['ComplexityPerPlayout'] = df['GameTreeComplexity'] / (df['PlayoutsPerSecond'] + 1e-15)
    df['TurnsNotTimeouts/Moves'] = df['DurationTurnsNotTimeouts'] / (df['MovesPerSecond'] + 1e-15)
    df['Timeouts/DurationActions'] = df['Timeouts'] / (df['DurationActions'] + 1e-15)
    df['OutcomeUniformity/AdvantageP1'] = df['OutcomeUniformity'] / (df['AdvantageP1'] + 1e-15)
    #df['ComplexDecisionRatio'] = df['StepDecisionToEnemy'] + df['SlideDecisionToEnemy'] + df['HopDecisionMoreThanOne']
    #df['AggressiveActionsRatio'] = df['StepDecisionToEnemy'] + df['HopDecisionEnemyToEnemy'] + df['HopDecisionFriendToEnemy'] + df['SlideDecisionToEnemy']

    print(f'Data shape: {df.shape}\n')
    num_cols = df.select_dtypes(exclude=['category']).columns.tolist()
    num_cols = [num for num in num_cols if num not in [target_col, game_col]]
    # Select categorical features by dtype rather than "everything that is
    # not numeric": a retained string column such as game_col would otherwise
    # land here and fail on the `.cat` accessor below.
    cat_cols = [
        col for col in df.select_dtypes(include=['category']).columns
        if col != target_col
    ]

    for col in cat_cols:
        # add_categories raises ValueError for a level that already exists,
        # so guard each addition (e.g. a literal 'NA' value in the data, or
        # this function being applied to already-processed columns).
        if 'NA' not in df[col].cat.categories:
            df[col] = df[col].cat.add_categories('NA')
        df[col] = df[col].fillna('NA')
        # NOTE(review): the integer level 0 is presumably a placeholder the
        # downstream model uses for unknown categories - confirm.
        if 0 not in df[col].cat.categories:
            df[col] = df[col].cat.add_categories(0)

    return df, num_cols, cat_cols


In [6]:


class RMSELoss(nn.Module):
    """Root-mean-squared-error loss with a numerically safe gradient.

    The derivative of ``sqrt(x)`` is infinite at ``x == 0``, so a batch that
    is predicted exactly yields NaN gradients with a plain ``sqrt(mse)``.
    A small ``eps`` is added under the root to keep the gradient finite.

    Args:
        eps: Non-negative constant added to the MSE before the square root.
            Pass ``0`` to recover the exact, unguarded RMSE.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.mse = nn.MSELoss(reduction='mean')
        self.eps = eps

    def forward(self, y_hat, y):
        """Return ``sqrt(mean((y_hat - y) ** 2) + eps)`` as a scalar tensor."""
        return torch.sqrt(self.mse(y_hat, y) + self.eps)

In [7]:

def get_model_config(model_name, head_config):
    """Return the pytorch_tabular model config selected by ``model_name``.

    Every config is a regression model with a "LinearHead" built from
    ``head_config``; only the architecture-specific settings differ.

    Args:
        model_name: One of "AutoInt", "TabTransformer",
            "GatedAdditiveTreeEnsemble", "TabNet" or "FTTransformer".
            Any other value selects the CategoryEmbedding model.
        head_config: Settings forwarded as the linear head configuration.

    Returns:
        The instantiated model config object.
    """
    # Settings shared by every architecture.
    shared = dict(
        task="regression",
        head="LinearHead",  # Linear Head
        head_config=head_config,  # Linear Head Config
    )

    # (name, config class, architecture-specific settings)
    registry = (
        ("AutoInt", AutoIntConfig, dict(learning_rate=1e-5)),
        ("TabTransformer", TabTransformerConfig, dict(
            learning_rate=1e-3,
            share_embedding=False,
        )),
        ("GatedAdditiveTreeEnsemble", GatedAdditiveTreeEnsembleConfig, dict(
            learning_rate=1e-3,
            gflu_stages=4,
            num_trees=30,
            tree_depth=5,
            chain_trees=False,
        )),
        ("TabNet", TabNetModelConfig, dict(
            learning_rate=1e-5,
            n_d=16,
            n_a=16,
            n_steps=4,
        )),
        ("FTTransformer", FTTransformerConfig, dict(
            learning_rate=1e-3,
            share_embedding_strategy="add",
        )),
    )
    for known_name, config_cls, specific in registry:
        if model_name == known_name:
            return config_cls(**shared, **specific)

    # Fallback for any unrecognised name: plain MLP over category embeddings.
    return CategoryEmbeddingModelConfig(
        **shared,
        layers="1024-512-256-128",  # Number of nodes in each layer
        activation="ReLU",  # Activation between each layers
        learning_rate=1e-3,
        embedding_dropout=0.2,
        dropout=0.25,
        use_batch_norm=True,
    )


In [8]:
def run_train():
    """Train one model per cross-validation fold and save each of them to disk.

    The training CSV at ``CFG.train_path`` is read with polars and passed
    through ``preprocess_data``. Folds come from ``StratifiedGroupKFold``:
    stratified on the target binned as ``round(y * 15)`` and grouped by the
    ``GameRulesetName`` column. For every fold a ``TabularModel`` is built
    from ``get_model_config(CFG.model_name, head_config)``, fitted with
    ``RMSELoss`` and scored on the validation split.

    Side effects:
        * prints the RMSE of every fold and the mean RMSE over all folds;
        * saves each fitted model under
          ``{CFG.base_save_folder}/fold_{fold}_rmse_{rmse:.4f}/``.

    Returns:
        None.
    """
    base_save_folder = CFG.base_save_folder
    model_name = CFG.model_name

    data = pl.read_csv(CFG.train_path)
    df, num_cols, cat_cols = preprocess_data(data)

    # Only the target column is needed as pandas; avoid converting the whole frame.
    y = data.get_column(target_col).to_pandas()
    # Bin the continuous target so it can be used as a stratification label.
    y_int = (y * 15).round()
    # 1-D group labels (one per row), as expected by scikit-learn splitters.
    groups = data.get_column('GameRulesetName').to_numpy()

    splitter = StratifiedGroupKFold(n_splits=CFG.folds, random_state=2024, shuffle=True)
    rmses = []

    for fold, (train_idx, valid_idx) in enumerate(splitter.split(df, y_int, groups=groups)):
        # The splitter yields positional indices, so select rows by position.
        df_train = df.iloc[train_idx]
        df_valid = df.iloc[valid_idx]

        head_config = LinearHeadConfig(
            layers="",
            dropout=0.2,
            use_batch_norm=True,
        ).__dict__
        data_config = DataConfig(
            target=[target_col],
            continuous_cols=num_cols,
            categorical_cols=cat_cols,
            normalize_continuous_features=True,
            continuous_feature_transform="quantile_normal",
        )

        model_config = get_model_config(model_name, head_config)
        optimizer_config = OptimizerConfig(
            optimizer="AdamW",
            optimizer_params={
                "weight_decay": 1e-5
            },
            lr_scheduler="ReduceLROnPlateau",
        )
        trainer_config = TrainerConfig(
            batch_size=CFG.batch_size,
            max_epochs=CFG.epochs,
            early_stopping_patience=CFG.epochs//2,
            load_best=True,
            early_stopping_mode="min",
            early_stopping="valid_loss",
            checkpoints="valid_loss",
            checkpoints_mode="min",
            progress_bar="simple",
            gradient_clip_val=70
        )

        model = TabularModel(
            data_config=data_config,
            model_config=model_config,
            optimizer_config=optimizer_config,
            trainer_config=trainer_config,
            suppress_lightning_logger=True
        )
        # fit model
        model.fit(train=df_train, validation=df_valid, loss=RMSELoss())

        # Get predictions to compute RMSE score
        result = model.trainer.predict(
            model=model.model,
            dataloaders=model.datamodule.prepare_inference_dataloader(df_valid),
        )

        predictions = np.concatenate(
            [res["logits"].detach().cpu().numpy() for res in result], axis=0
        ).flatten()
        # sqrt(MSE) instead of `squared=False`, which newer scikit-learn no longer accepts.
        rmse = np.sqrt(mean_squared_error(df_valid[target_col], predictions))
        print(f'Fold {fold} RMSE: {rmse:.4f}')
        rmses.append(rmse)

        # save model
        # NOTE: the folder name embeds the RMSE, so re-running training leaves
        # folders of earlier runs in place; run_inference globs `fold_*` and
        # would pick those up as well.
        save_folder = f'{base_save_folder}/fold_{fold}_rmse_{rmse:.4f}/'
        os.makedirs(save_folder, exist_ok=True)
        model.save_model(save_folder)
    print(f"Average RMSE over folds: {np.mean(rmses):.4f}")

# Train & Inference for Submitting

In [9]:

def run_inference(test, submission):
    """Predict ``utility_agent1`` for ``test`` by averaging all saved fold models.

    Every folder matching ``{CFG.base_save_folder}/fold_*/`` is loaded with
    ``TabularModel.load_model``; the per-model predictions are averaged and
    clipped to ``[-1, 1]``.

    Args:
        test: raw test frame, passed to ``preprocess_data``.
        submission: frame whose ``with_columns`` method is used to attach the
            predictions as the ``utility_agent1`` column.

    Returns:
        The result of ``submission.with_columns(...)`` carrying the averaged,
        clipped predictions.

    Raises:
        FileNotFoundError: if no saved fold model folder is found.
    """
    test_data, _num_cols, _cat_cols = preprocess_data(test)

    # Sorted so the summation order (and thus the float result) is reproducible.
    model_folder_list = sorted(glob.glob(f'{CFG.base_save_folder}/fold_*/'))
    if not model_folder_list:
        # Fail with a clear message instead of a ZeroDivisionError further down.
        raise FileNotFoundError(
            f"No trained models found under '{CFG.base_save_folder}/fold_*/'; "
            "run run_train() first."
        )

    y_pred = 0.0
    for model_path in model_folder_list:
        model = TabularModel.load_model(
            model_path
        )

        result = model.trainer.predict(
            model=model.model,
            dataloaders=model.datamodule.prepare_inference_dataloader(test_data),
        )
        predictions = np.concatenate(
            [res["logits"].detach().cpu().numpy() for res in result], axis=0
        ).flatten()
        y_pred += predictions

    # Simple mean over the fold models, limited to the range [-1, 1].
    y_pred = y_pred / len(model_folder_list)
    y_pred = np.clip(y_pred, -1, 1)
    print(y_pred)
    return submission.with_columns(pl.Series('utility_agent1', y_pred))


In [10]:
import os
import kaggle_evaluation
import kaggle_evaluation.mcts_inference_server
# Number of completed `predict` calls; 0 means training has not run yet.
counter = 0


def predict(test, submission):
    """Entry point handed to the inference server.

    On the first call (while ``counter`` is still 0) the models are trained
    via ``run_train()``. Every call then returns
    ``run_inference(test, submission)``.
    """
    global counter

    needs_training = (counter == 0)
    if needs_training:
        # Increment only after training returns, so a failed run is retried.
        run_train()
    counter = counter + 1

    return run_inference(test, submission)


In [11]:
# Wrap `predict` in the competition's inference server.
inference_server = kaggle_evaluation.mcts_inference_server.MCTSInferenceServer(predict)

if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    # Variable unset or empty: drive `predict` through the local gateway
    # using the given test and sample-submission CSV files.
    inference_server.run_local_gateway(
        (
            '/kaggle/input/um-game-playing-strength-of-mcts-variants/test.csv',
            '/kaggle/input/um-game-playing-strength-of-mcts-variants/sample_submission.csv',
        )
    )
    print('Done')
else:
    # Variable set: start serving requests.
    inference_server.serve()

Data shape: (233234, 217)



Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Predicting: |          | 0/? [00:00<?, ?it/s]

Fold 0 RMSE: 0.4511


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Predicting: |          | 0/? [00:00<?, ?it/s]

Fold 1 RMSE: 0.4543


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Predicting: |          | 0/? [00:00<?, ?it/s]

Fold 2 RMSE: 0.4347


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Predicting: |          | 0/? [00:00<?, ?it/s]

Fold 3 RMSE: 0.4462


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Predicting: |          | 0/? [00:00<?, ?it/s]

Fold 4 RMSE: 0.4496
Average RMSE over folds: 0.4472
Data shape: (3, 216)



Predicting: |          | 0/? [00:00<?, ?it/s]

Predicting: |          | 0/? [00:00<?, ?it/s]

Predicting: |          | 0/? [00:00<?, ?it/s]

Predicting: |          | 0/? [00:00<?, ?it/s]

Predicting: |          | 0/? [00:00<?, ?it/s]

[ 0.05644159 -0.10247562  0.00136217]
Done
