In [1]:
!pip install python-chess

Collecting python-chess
  Downloading python_chess-1.999-py3-none-any.whl.metadata (776 bytes)
Collecting chess<2,>=1 (from python-chess)
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m6.1/6.1 MB[0m [31m54.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading python_chess-1.999-py3-none-any.whl (1.4 kB)
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=f004a3b000e6afef78592c2500f65017ebe139f883d98c9e0b1f83b33e9d67e0
  Stored in directory: /root/.cache/pip/wheels/83/1f/4e/8f4300f7dd554eb8de70ddfed96e94d3d030ace10c5b53d447
Successfully built chess
Installing collected packages: chess, python-chess
Successfully installed chess-1.11.2 python-chess-1.99

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
import numpy as np
import glob
import os
import time
import re
import chess

# ============================================================================
# 1. CONFIGURATION (ResNet-20)
# ============================================================================
# "ResNet-20" in this notebook means a tower of 20 residual blocks (two 3x3 convs
# each), not the 20-layer CIFAR ResNet.
CONFIG = {
    # --- ARCHITECTURE (ResNet-20) ---
    "num_features": 18,         # input planes built by fen_to_tensor_18ch
    "num_moves": 4096,          # 64 from-squares x 64 to-squares
    "num_res_blocks": 20,
    "num_channels": 256,        # Wide layers

    # --- TRAINING PARAMS ---
    "batch_size": 2048,
    "num_epochs": 3,            # absolute epoch count; a resumed run continues up to this
    "lr": 0.001,                # Initial Learning Rate
    "seed": 42,                 # seeds the NumPy and PyTorch RNGs for reproducible runs

    # --- TIME LIMIT SAFETY ---
    "max_train_hours": 12.0,

    # --- PATHS ---
    "data_dir": "/kaggle/input/chess-stockfish-data",
    "save_dir": "./",
    "resume_from": "/kaggle/input/teacher-model-weights/resnet20_epoch2.pth",
}

# Seed before any model or DataLoader is created: weight initialisation and the
# DataLoader's worker seeding both draw from these global RNGs.
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_GPUS = torch.cuda.device_count()

print(f"üöÄ Configuration:")
print(f"   ‚Ä¢ Device: {DEVICE} ({NUM_GPUS} GPUs)")
print(f"   ‚Ä¢ Dataset Root: {CONFIG['data_dir']}")

# ============================================================================
# 2. DATA ENCODING
# ============================================================================
# Input-plane index for each FEN piece letter: white P, N, B, R, Q, K on planes 0-5,
# black p, n, b, r, q, k on planes 6-11.
PIECE_MAP = {symbol: plane for plane, symbol in enumerate("PNBRQKpnbrqk")}

# Square name -> square index with a1=0, b1=1, ..., h8=63 (rank * 8 + file), the same
# numbering chess.parse_square uses. A dict lookup avoids raising/catching an exception
# for every malformed move in the data-loading hot path.
_SQUARE_INDEX = {
    file_char + rank_char: rank * 8 + file
    for rank, rank_char in enumerate("12345678")
    for file, file_char in enumerate("abcdefgh")
}


def encode_move(move_uci, fallback=0):
    """Map a UCI move string to a policy index in [0, 4096).

    The index is ``from_square * 64 + to_square`` with a1=0 ... h8=63. Only the first
    four characters are read, so a promotion suffix is dropped: ``"e7e8q"`` and
    ``"e7e8n"`` encode to the same index.

    Args:
        move_uci: UCI move text such as ``"e2e4"``.
        fallback: value returned when ``move_uci`` is not a string whose first four
            characters are two valid square names. Defaults to 0 (the historical
            behaviour — note 0 is also the valid index of "a1a1"). Passing -100 makes
            ``nn.CrossEntropyLoss`` (default ``ignore_index=-100``) skip such samples.

    Returns:
        int: the encoded move, or ``fallback`` for unparseable input.
    """
    if not isinstance(move_uci, str):
        return fallback
    from_sq = _SQUARE_INDEX.get(move_uci[:2])
    to_sq = _SQUARE_INDEX.get(move_uci[2:4])
    if from_sq is None or to_sq is None:
        return fallback
    return from_sq * 64 + to_sq

def fen_to_tensor_18ch(fen):
    """Encode a FEN string as an (18, 8, 8) float32 array of input planes.

    Planes:
        0-11  one-hot piece occupancy, plane chosen by ``PIECE_MAP``; row 0 is the first
              rank listed in the FEN (rank 8), column 0 is file a.
        12-15 castling rights K, Q, k, q (whole plane set to 1.0 when present).
        16    side to move (1.0 when white is to move).
        17    halfmove clock / 100 (0 when the field is missing or non-numeric).

    The board is never flipped for black to move (plane 16 carries that), and the
    en-passant field is ignored.

    Raises:
        ValueError: if ``fen`` has fewer than three whitespace-separated fields.
        KeyError / IndexError: if the placement field contains an unknown piece letter
            or places a piece outside the 8x8 board.
    """
    # split() (not split(' ')) tolerates repeated whitespace between FEN fields.
    fields = fen.split()
    if len(fields) < 3:
        raise ValueError(f"FEN needs at least 3 fields (placement, side to move, castling): {fen!r}")
    board_str, turn, castling = fields[0], fields[1], fields[2]
    halfmove = fields[4] if len(fields) > 4 else 0

    matrix = np.zeros((18, 8, 8), dtype=np.float32)
    for row_idx, row_data in enumerate(board_str.split('/')):
        col_idx = 0
        for char in row_data:
            if char.isdigit():
                col_idx += int(char)  # run of empty squares
            else:
                matrix[PIECE_MAP[char], row_idx, col_idx] = 1.0
                col_idx += 1

    # One constant plane per castling right.
    for plane, right in zip(range(12, 16), "KQkq"):
        if right in castling:
            matrix[plane, :, :] = 1.0
    if turn == 'w':
        matrix[16, :, :] = 1.0
    try:
        matrix[17, :, :] = float(halfmove) / 100.0
    except ValueError:
        pass  # non-numeric halfmove field: leave the plane at zero
    return matrix

# ============================================================================
# 3. DATASET CLASS 
# ============================================================================
class ChessDataset18(IterableDataset):
    """Stream ``(planes, move_id, score)`` samples from pipe-delimited text files.

    Each line is split on ``|``; field 0 is a FEN, field 1 a UCI move and field 3 a
    score clamped to [0, 1] (0.5 when non-numeric). Lines with fewer than four fields
    or an unparseable FEN are skipped individually. Files are collected recursively
    from every ``data_*`` folder directly under ``root_dir``; hidden files are ignored.
    """

    def __init__(self, root_dir):
        data_folders = sorted(glob.glob(os.path.join(root_dir, "data_*")))

        print(f"üìÇ Found {len(data_folders)} data folders. Scanning for files...")

        # Recursively collect every non-hidden file inside those folders. Sorted so the
        # list — and therefore the seeded shuffle in __iter__ — is reproducible.
        self.files = sorted(
            os.path.join(root, name)
            for folder in data_folders
            for root, _dirs, names in os.walk(folder)
            for name in names
            if not name.startswith('.')
        )

        print(f"‚úÖ Total Training Files Found: {len(self.files)}")
        if len(self.files) == 0:
            print("‚ùå WARNING: No files found! Check your dataset path.")

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: one fresh shuffle from the global NumPy RNG.
            my_files = [self.files[i] for i in np.random.permutation(len(self.files))]
        else:
            # All workers must shuffle with the SAME seed so that their shards partition
            # the file list. PyTorch gives each worker seed = base_seed + worker_id, where
            # base_seed is shared by the workers of one DataLoader iteration (and changes
            # every epoch). Shuffling with each worker's own NumPy state instead makes the
            # shards overlap and leaves some files unread.
            shared_seed = worker_info.seed - worker_info.id
            order = np.random.default_rng(shared_seed).permutation(len(self.files))
            my_files = [self.files[i] for i in order[worker_info.id::worker_info.num_workers]]

        for f_path in my_files:
            yield from self._iter_file(f_path)

    @staticmethod
    def _iter_file(f_path):
        """Yield samples from one file, dropping malformed lines rather than the whole file."""
        try:
            with open(f_path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('|')
                    if len(parts) < 4:
                        continue
                    fen, move_uci, score_str = parts[0], parts[1], parts[3]

                    try:
                        tensor = fen_to_tensor_18ch(fen)
                    except (ValueError, KeyError, IndexError):
                        continue  # malformed FEN: skip just this line
                    move_id = encode_move(move_uci)
                    try:
                        score_val = max(0.0, min(1.0, float(score_str)))
                    except ValueError:
                        score_val = 0.5

                    yield tensor, move_id, score_val
        except (OSError, UnicodeDecodeError) as e:
            # Best-effort: keep training on the other files, but say which one failed.
            print(f"WARNING: skipping unreadable file {f_path}: {e}")

# ============================================================================
# 4. MODEL: ResNet-20
# ============================================================================
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped by an identity skip connection.

    Computes ``relu(x + bn2(conv2(relu(bn1(conv1(x))))))``; channel count and spatial
    size are preserved (channels -> channels, padding=1). The sub-module names are the
    keys of saved state_dicts, so they must not be renamed.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return F.relu(x + branch)

class ChessResNet(nn.Module):
    """Policy/value network: conv stem, residual tower, and separate policy/value heads.

    Expects input shaped (batch, num_features, 8, 8); the 8x8 board size is baked into
    the ``32 * 8 * 8`` head inputs. ``forward`` returns ``(policy_logits, value)``:
    policy_logits is (batch, num_moves) and un-normalised; value is (batch, 1) squashed
    into (0, 1) by a sigmoid. Attribute names are state_dict keys of saved checkpoints.
    """

    def __init__(self, num_features, num_moves, num_res_blocks, num_channels):
        super().__init__()
        # Stem
        self.conv_input = nn.Conv2d(num_features, num_channels, 3, padding=1, bias=False)
        self.bn_input = nn.BatchNorm2d(num_channels)
        # Residual tower
        self.res_tower = nn.Sequential(*(ResidualBlock(num_channels) for _ in range(num_res_blocks)))
        # Policy head
        self.p_conv = nn.Conv2d(num_channels, 32, 1)
        self.p_bn = nn.BatchNorm2d(32)
        self.p_fc = nn.Linear(32 * 8 * 8, num_moves)
        # Value head
        self.v_conv = nn.Conv2d(num_channels, 32, 1)
        self.v_bn = nn.BatchNorm2d(32)
        self.v_fc1 = nn.Linear(32 * 8 * 8, 128)
        self.v_fc2 = nn.Linear(128, 1)

    def forward(self, x):
        batch = x.size(0)
        trunk = self.res_tower(F.relu(self.bn_input(self.conv_input(x))))

        policy_features = F.relu(self.p_bn(self.p_conv(trunk)))
        policy_logits = self.p_fc(policy_features.view(batch, -1))

        value_features = F.relu(self.v_bn(self.v_conv(trunk)))
        hidden = F.relu(self.v_fc1(value_features.view(batch, -1)))
        value = torch.sigmoid(self.v_fc2(hidden))
        return policy_logits, value

# ============================================================================
# 5. TRAINING LOOP
# ============================================================================
def _unwrap(model):
    """Return the module inside a DataParallel wrapper, or the model itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _epochs_done_from_path(path):
    """Parse the completed-epoch count from a checkpoint name like ``resnet20_epoch2.pth``.

    Only the file name is searched, so a directory containing "epoch" cannot match.
    Returns 0 when no ``epoch<N>`` token is present.
    """
    match = re.search(r"epoch(\d+)", os.path.basename(path))
    return int(match.group(1)) if match else 0


def _step_lr(epoch, base_lr, step_size, gamma):
    """Closed-form StepLR: learning rate for the 0-based absolute ``epoch``."""
    return base_lr * gamma ** (epoch // step_size)


def train_model():
    """Train ChessResNet on ChessDataset18 using the settings in CONFIG.

    - Optionally loads model weights from CONFIG["resume_from"] (a plain state_dict);
      the starting epoch is parsed from that file name. A missing file is reported.
    - Uses Adam with a StepLR-style schedule (CONFIG "lr_step_size"/"lr_gamma", default
      5 / 0.1) computed from the absolute epoch, so a resumed run continues the schedule
      instead of restarting it.
    - Saves ``resnet20_epoch<N>.pth`` (model state_dict only) after every epoch.
    - Checks CONFIG["max_train_hours"] after every batch. When exceeded mid-epoch it saves
      ``resnet20_epoch<N>_partial.pth``, N being the number of fully completed epochs, so
      resuming from that file re-runs the interrupted epoch; then it stops.

    NOTE(review): the optimizer (Adam) state is not checkpointed, so its moment
    estimates restart from zero on every resume.
    """
    session_start_time = time.time()
    time_budget_s = CONFIG["max_train_hours"] * 3600

    # 1. Init Model
    model = ChessResNet(CONFIG["num_features"], CONFIG["num_moves"],
                        CONFIG["num_res_blocks"], CONFIG["num_channels"]).to(DEVICE)
    if NUM_GPUS > 1:
        print(f"‚ö° Enabling DataParallel on {NUM_GPUS} GPUs")
        model = nn.DataParallel(model)

    # 2. Resume Logic
    start_epoch = 0
    resume_path = CONFIG["resume_from"]
    if resume_path:
        if os.path.exists(resume_path):
            print(f"üîÑ Resuming from {CONFIG['resume_from']}")
            checkpoint = torch.load(resume_path, map_location=DEVICE)
            _unwrap(model).load_state_dict(checkpoint)
            start_epoch = _epochs_done_from_path(resume_path)
        else:
            print(f"WARNING: resume_from not found, training from scratch: {resume_path}")

    # 3. Optimizer & Data
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    criterion_p = nn.CrossEntropyLoss()
    criterion_v = nn.MSELoss()
    lr_step_size = CONFIG.get("lr_step_size", 5)
    lr_gamma = CONFIG.get("lr_gamma", 0.1)

    dataset = ChessDataset18(CONFIG["data_dir"])
    # Several worker processes decode FENs in parallel to keep the GPU(s) fed.
    dataloader = DataLoader(dataset, batch_size=CONFIG["batch_size"],
                            num_workers=CONFIG.get("num_workers", 4), pin_memory=True)

    model.train()

    # 4. The Loop
    for epoch in range(start_epoch, CONFIG["num_epochs"]):
        epoch_lr = _step_lr(epoch, CONFIG["lr"], lr_step_size, lr_gamma)
        for group in optimizer.param_groups:
            group["lr"] = epoch_lr

        total_p_loss, total_v_loss = 0.0, 0.0
        batch_count = 0
        epoch_start = time.time()
        out_of_time = False

        print(f"\n--- Starting Epoch {epoch+1}/{CONFIG['num_epochs']} ---")

        for boards, move_ids, scores in dataloader:
            # non_blocking pairs with pin_memory=True to overlap copies with compute.
            boards = boards.to(DEVICE, non_blocking=True)
            move_ids = move_ids.long().to(DEVICE, non_blocking=True)
            scores = scores.float().to(DEVICE, non_blocking=True).view(-1, 1)

            optimizer.zero_grad()
            p_out, v_out = model(boards)
            loss_p = criterion_p(p_out, move_ids)
            loss_v = criterion_v(v_out, scores)
            (loss_p + loss_v).backward()
            optimizer.step()

            total_p_loss += loss_p.item()
            total_v_loss += loss_v.item()
            batch_count += 1

            if batch_count % 100 == 0:
                print(f"E{epoch+1}|B{batch_count} >> P:{loss_p.item():.3f} | V:{loss_v.item():.3f}")

            # Check the budget inside the epoch: a single epoch can take hours, so an
            # end-of-epoch check alone may never run before the session is cut off.
            if time.time() - session_start_time > time_budget_s:
                out_of_time = True
                break

        avg_p = total_p_loss / max(1, batch_count)
        avg_v = total_v_loss / max(1, batch_count)

        if out_of_time:
            # Named after the epochs fully completed so the resume regex re-runs this one.
            save_name = f"resnet20_epoch{epoch}_partial.pth"
            torch.save(_unwrap(model).state_dict(), os.path.join(CONFIG["save_dir"], save_name))
            elapsed_hours = (time.time() - session_start_time) / 3600
            print(f"WARNING: time limit reached ({elapsed_hours:.2f}h) during epoch {epoch+1} "
                  f"after {batch_count} batches (avg loss P:{avg_p:.4f} V:{avg_v:.4f}). "
                  f"Partial checkpoint saved: {save_name}. Stopping.")
            break

        duration = (time.time() - epoch_start) / 60
        print(f"‚úÖ Epoch {epoch+1} Done ({duration:.1f}m). Avg Loss P:{avg_p:.4f} V:{avg_v:.4f}")

        # Save Checkpoint (plain state_dict, without any DataParallel "module." prefix)
        save_name = f"resnet20_epoch{epoch+1}.pth"
        torch.save(_unwrap(model).state_dict(), os.path.join(CONFIG["save_dir"], save_name))
        print(f"üíæ Checkpoint Saved: {save_name}")

        # Don't start another epoch once the budget is spent.
        elapsed_hours = (time.time() - session_start_time) / 3600
        if elapsed_hours > CONFIG["max_train_hours"]:
            print(f"‚ö†Ô∏è Limit Reached ({elapsed_hours:.2f}h). Stopping.")
            break

if __name__ == "__main__":
    train_model()

üöÄ Configuration:
   ‚Ä¢ Device: cuda (2 GPUs)
   ‚Ä¢ Dataset Root: /kaggle/input/chess-stockfish-data
‚ö° Enabling DataParallel on 2 GPUs
üîÑ Resuming from /kaggle/input/teacher-model-weights/resnet20_epoch2.pth
üìÇ Found 5 data folders. Scanning for files...
‚úÖ Total Training Files Found: 36

--- Starting Epoch 3/3 ---
E3|B100 >> P:1.370 | V:0.032
E3|B200 >> P:1.388 | V:0.032
E3|B300 >> P:1.321 | V:0.039
E3|B400 >> P:1.248 | V:0.036
E3|B500 >> P:1.364 | V:0.031
E3|B600 >> P:1.272 | V:0.035
E3|B700 >> P:1.289 | V:0.022
E3|B800 >> P:1.278 | V:0.034
E3|B900 >> P:1.338 | V:0.030
E3|B1000 >> P:1.293 | V:0.032
E3|B1100 >> P:1.306 | V:0.027
E3|B1200 >> P:1.361 | V:0.030
E3|B1300 >> P:1.414 | V:0.030
E3|B1400 >> P:1.393 | V:0.031
E3|B1500 >> P:1.351 | V:0.038
E3|B1600 >> P:1.238 | V:0.029
E3|B1700 >> P:1.314 | V:0.029
E3|B1800 >> P:1.283 | V:0.026
E3|B1900 >> P:1.332 | V:0.027
E3|B2000 >> P:1.276 | V:0.034
E3|B2100 >> P:1.281 | V:0.032
E3|B2200 >> P:1.395 | V:0.033
E3|B2300 >> P:1.236 |