In [39]:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import os
from transformers import get_cosine_schedule_with_warmup
import itertools
from collections import defaultdict

os.makedirs('checkpoints', exist_ok=True)

In [40]:

def argmax_accuracy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Calculates the count of correctly predicted samples for a batch based on
    matching the output argmax with *any* of the target argmax indices,
    returning the count as a tensor suitable for XLA accumulation.

    Args:
        output (torch.Tensor): Model output tensor of shape (n, w), where n is
                               batch size and w is the vector dimension.
        target (torch.Tensor): Target tensor of the same shape (n, w).

    Returns:
        torch.Tensor: A scalar tensor containing the count of correct predictions
                      for the batch, located on the same device as inputs.
                      Returns a zero tensor if batch size is 0.
    """
    if output.shape != target.shape:
        raise ValueError(f"Output and target tensors must have the same shape. Got {output.shape} and {target.shape}")
    if output.dim() != 2:
         raise ValueError(f"Output and target tensors must be 2D (n, w). Got {output.dim()} dimensions.")

    n, w = output.shape

    if n == 0:
        # Return a scalar tensor on the correct device
        return torch.tensor(0, dtype=torch.long, device=output.device) # Use long for counts

    # 0. Find the index of the maximum value in the output (common for both calculations)
    output_argmax = torch.argmax(output, dim=1)
    
    # --- Calculation for the original target ---

    # 1. Find the maximum value in the original target
    target_max_values = torch.max(target, dim=1, keepdim=True)[0]

    # 2. Create a boolean mask for the original target
    target_max_mask = (target == target_max_values)

    # 3. Check if the output_argmax index is True in the original target_max_mask
    # Ensure arange is on the same device as the tensors
    correct_predictions_bool_original = target_max_mask[torch.arange(n, device=output.device), output_argmax]

    # 4. Sum the boolean tensor to get the count of correct predictions for original target.
    correct_count_original = correct_predictions_bool_original.sum()

    # --- Calculation for the clipped target ---

    # 1. Clip the target values to (-1, 1)
    target_clipped = torch.clip(target, -1, 1)

    # 2. Find the maximum value in the clipped target
    target_clipped_max_values = torch.max(target_clipped, dim=1, keepdim=True)[0]

    # 3. Create a boolean mask for the clipped target
    target_clipped_max_mask = (target_clipped == target_clipped_max_values)

    # 4. Check if the output_argmax index is True in the clipped target_max_mask
    correct_predictions_bool_clipped = target_clipped_max_mask[torch.arange(n, device=output.device), output_argmax]

    # 5. Sum the boolean tensor to get the count of correct predictions for clipped target.
    correct_count_clipped = correct_predictions_bool_clipped.sum()

    return correct_count_original, correct_count_clipped

In [41]:
class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=4):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        reduced_channel = max(channel // reduction, 1)
        self.fc_scale = nn.Sequential(
            nn.Linear(channel, reduced_channel, bias=True),
            nn.SiLU(inplace=True),
            nn.Linear(reduced_channel, channel, bias=True),
            nn.Sigmoid()
        )
        self.fc_offset = nn.Sequential(
            nn.Linear(channel, reduced_channel, bias=True),
            nn.SiLU(inplace=True),
            nn.Linear(reduced_channel, channel, bias=True),
        )


    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y_scale = self.fc_scale(y).view(b, c, 1, 1)
        y_offset = self.fc_offset(y).view(b, c, 1, 1)
        return x * y_scale.expand_as(x) + y_offset.expand_as(x)

class InitialExtractor(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(InitialExtractor, self).__init__()
        layers = []
        layers.append(nn.BatchNorm2d(in_channels))
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, stride=1, bias=False))
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        out = self.convs(x)
        return out

class Grouped1x1SumConv(nn.Module):
    def __init__(self, in_channels, out_channels, num_groups):
        super().__init__()
        assert in_channels % num_groups == 0, "in_channels must be divisible by num_groups"

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_groups = num_groups

        self.conv = nn.Conv2d(
            in_channels,
            out_channels * num_groups,
            kernel_size=1,
            groups=num_groups,
            bias=True
        )

    def forward(self, x):
        B, _, H, W = x.shape
        G = self.num_groups
        Cout = self.out_channels

        out = self.conv(x)

        out = out.view(B, G, Cout, H, W)
        out = out.sum(dim=1)
        return out

class ConnectFourBlock(nn.Module):
    def __init__(self, in_channels, kernel_size = (5, 5), mlp_factor=4, group_factor=2):
        super().__init__()

        ks_h, ks_w = kernel_size
        hidden_dim = in_channels * mlp_factor
        layers_0 = []

        layers_0.append(nn.Conv2d(in_channels, in_channels, kernel_size=(ks_h, ks_w),  padding=(ks_h // 2, ks_w // 2), groups=in_channels, bias=False))
        layers_0.append(LayerNorm2d(in_channels))

        layers_0.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False))
        layers_0.append(nn.SiLU(inplace=True))

        layers_0.append(nn.Conv2d(hidden_dim, in_channels, kernel_size=1, bias=False))

        self.conv_next = nn.Sequential(*layers_0)

        hidden_dim = in_channels * group_factor
        n_groups = 2 * group_factor
        ks_h, ks_w = 3, 3
        
        layers_1 = []
        layers_1.append(LayerNorm2d(in_channels))
        layers_1.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False))
        self.proj_up = nn.Sequential(*layers_1)

        layers_2 = []
        layers_2.append(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=(ks_h, ks_w), padding=(ks_h // 2, ks_w // 2), groups=n_groups, bias=False))
        layers_2.append(Grouped1x1SumConv(hidden_dim, hidden_dim, num_groups=n_groups))
        layers_2.append(nn.SiLU(inplace=True))
        self.grouped_0 = nn.Sequential(*layers_2)

        layers_3 = []
        layers_3.append(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=(ks_h, ks_w), padding=(ks_h // 2, ks_w // 2), groups=n_groups, bias=False))
        layers_3.append(Grouped1x1SumConv(hidden_dim, hidden_dim, num_groups=n_groups))
        layers_3.append(nn.SiLU(inplace=True))
        self.grouped_1 = nn.Sequential(*layers_3)

        layers_4 = []
        layers_4.append(SEBlock(hidden_dim))
        layers_4.append(nn.Conv2d(hidden_dim, in_channels, kernel_size=1, bias=False))
        self.proj_down = nn.Sequential(*layers_4)

    def forward(self, x):
        out = x
        summand = self.conv_next(x)
        out = out + summand

        up_proj = self.proj_up(out)
        summand = self.grouped_0(up_proj)
        up_proj = up_proj + summand
        summand = self.grouped_1(up_proj)
        up_proj = up_proj + summand

        summand = self.proj_down(up_proj)
        out = out + summand
        
        return out 

In [42]:
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)

class VisionTransformerWithTasks(nn.Module):
    def __init__(
        self,
        input_shape,
        num_layers,
        embed_dim,
        num_heads,
        mlp_dim,
        num_tasks=7 + 7 + 1,
    ):
        super().__init__()
        C, H, W = input_shape
        self.num_patches = H * W
        self.embed_dim = embed_dim
        self.num_tasks = num_tasks

        self.proj = nn.Linear(C, embed_dim)
        self.task_tokens = nn.Parameter(torch.randn(1, num_tasks, embed_dim))
        self.pos_embed = nn.Embedding(self.num_patches + num_tasks, embed_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            batch_first=True,
            activation="gelu"
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.output_head = nn.Conv1d(
            in_channels=embed_dim * num_tasks,
            out_channels=num_tasks,
            kernel_size=1,
            groups=num_tasks
        )

        self._init_weights()

    def _init_weights(self):
        trunc_normal_(self.task_tokens, std=0.02)
        trunc_normal_(self.pos_embed.weight, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv1d):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)

        task_tokens = self.task_tokens.expand(B, -1, -1)
        x = torch.cat((task_tokens, x), dim=1)

        pos_ids = torch.arange(x.size(1), device=x.device).unsqueeze(0).expand(B, -1)
        x = x + self.pos_embed(pos_ids)

        x = self.transformer(x)
        task_outputs = x[:, :self.num_tasks]

        task_outputs = task_outputs.reshape(B, self.num_tasks * self.embed_dim, 1)
        logits = self.output_head(task_outputs).squeeze(-1)

        return logits


In [43]:
class PolicyHead(nn.Module):
    def __init__(self, in_channels, input_height, input_width, expansion_factor_c=4, expansion_factor_w=4):
        super(PolicyHead, self).__init__()
        intermediate_channels = expansion_factor_c * in_channels
        self.conv_reduce_height_pw_0 = nn.Conv2d(in_channels, intermediate_channels, kernel_size=1, stride=1, bias=False)
        self.conv_reduce_height_pw_1 = nn.Conv2d(intermediate_channels, intermediate_channels, kernel_size=1, stride=1, bias=False)
        self.conv_reduce_height_dw = nn.Conv2d(intermediate_channels, intermediate_channels, kernel_size=(input_height, 1), stride=(input_height, 1), groups=intermediate_channels, bias=False)
        self.conv_reduce_height_bn = nn.BatchNorm2d(intermediate_channels)
        self.silu_reduce_height_1 = nn.SiLU(inplace=True)
        self.silu_reduce_height_2 = nn.SiLU(inplace=True)
        self.conv_expand_pw = nn.Conv2d(intermediate_channels, in_channels, kernel_size=1, stride=1, bias=False)

        self.expansion_factor_w = expansion_factor_w
        intermediate_channels = expansion_factor_w * input_width
        self.conv_permuted_input_width_as_channels = nn.Conv2d(
            in_channels=input_width, 
            out_channels=intermediate_channels, 
            kernel_size=1, 
            stride=1, 
            bias=False
        )
        self.silu_permuted_input_width_as_channels = nn.SiLU(inplace=True)

        intermediate_channels = expansion_factor_w * in_channels
        self.final_logits_conv = nn.Conv2d(
            in_channels=intermediate_channels, 
            out_channels=1,
            kernel_size=1, 
            stride=1, 
            bias=True
        )
    
    def forward(self, x):
        out = self.conv_reduce_height_pw_0(x)
        out = self.conv_reduce_height_pw_1(out)
        out = self.silu_reduce_height_1(out)
        out = self.conv_reduce_height_dw(out)
        out = self.conv_reduce_height_bn(out)
        out = self.silu_reduce_height_2(out)
        out = self.conv_expand_pw(out)
        
        B, C2, H_one, W_orig = out.shape 
        
        permuted_for_conv = out.permute(0, 3, 2, 1).contiguous()
        
         # Shape: (B, W_orig, 1, C2)
        convolved_output = self.conv_permuted_input_width_as_channels(permuted_for_conv)
        convolved_output = self.silu_permuted_input_width_as_channels(convolved_output)
        
        # Shape: (B, k_expansion_factor * W_orig, 1, C2)
        current_tensor = convolved_output.permute(0, 3, 2, 1).contiguous()
        # Shape: (B, C2, 1, self.k_expansion_factor * W_orig)
        
        # --- Step 2: Rearrange to (B, k*C2, 1, W_orig) ---
        k = self.expansion_factor_w
        
        #(B, C2, 1, k * W_orig)
        temp_rearrange = current_tensor.squeeze(2) 
        #(B, C2, k * W_orig)
        
        temp_rearrange = temp_rearrange.view(B, C2, k, W_orig)
        #(B, C2, k, W_orig)
        
        temp_rearrange = temp_rearrange.permute(0, 2, 1, 3) 
        #(B, k, C2, W_orig)
        
        temp_rearrange = temp_rearrange.contiguous().view(B, k * C2, W_orig)
        #(B, k*C2, W_orig)
        
        tensor_for_final_conv = temp_rearrange.unsqueeze(2)
        #(B, k*C2, 1, W_orig)

        logits_intermediate = self.final_logits_conv(tensor_for_final_conv)
        # (B, 1, 1, W_orig)

        policy_logits = logits_intermediate.view(B, W_orig)
        # Shape: (B, W_orig)

        return policy_logits

class ValueHead(nn.Module):
    def __init__(self, in_channels, input_height, input_width, expansion_factor_c=4, expansion_factor_w=2):
        super(ValueHead, self).__init__()
        intermediate_channels = expansion_factor_c * in_channels
        self.conv_reduce_height_pw_0 = nn.Conv2d(in_channels, intermediate_channels, kernel_size=1, stride=1, bias=False)
        self.conv_reduce_height_pw_1 = nn.Conv2d(intermediate_channels, intermediate_channels, kernel_size=1, stride=1, bias=False)
        self.conv_reduce_height_dw = nn.Conv2d(intermediate_channels, intermediate_channels, kernel_size=(input_height, 1), stride=(input_height, 1), groups=intermediate_channels, bias=False)
        self.conv_reduce_height_bn = nn.BatchNorm2d(intermediate_channels)
        self.silu_reduce_height_1 = nn.SiLU(inplace=True)
        self.silu_reduce_height_2 = nn.SiLU(inplace=True)
        self.conv_expand_pw = nn.Conv2d(intermediate_channels, in_channels, kernel_size=1, stride=1, bias=False)

        self.expansion_factor_w = expansion_factor_w
        intermediate_channels = expansion_factor_w * input_width
        self.conv_permuted_input_width_as_channels = nn.Conv2d(
            in_channels=input_width, 
            out_channels=intermediate_channels, 
            kernel_size=1, 
            stride=1, 
            bias=False
        )
        self.silu_flattened_output = nn.SiLU(inplace=True)
        
        final_flattened_features = intermediate_channels * 1 * in_channels
        self.final_linear = nn.Linear(final_flattened_features, 1)

    
    def forward(self, x):
        out = self.conv_reduce_height_pw_0(x)
        out = self.conv_reduce_height_pw_1(out)
        out = self.silu_reduce_height_1(out)
        out = self.conv_reduce_height_dw(out)
        out = self.conv_reduce_height_bn(out)
        out = self.silu_reduce_height_2(out)
        out = self.conv_expand_pw(out)
        
        B, C2, H_one, W_orig = out.shape 
        
        permuted_for_conv = out.permute(0, 3, 2, 1).contiguous()
        
        # Shape: (B, W_orig, 1, C2)
        convolved_output = self.conv_permuted_input_width_as_channels(permuted_for_conv)
        flattened_output = convolved_output.view(B, -1) 
        silu_output = self.silu_flattened_output(flattened_output)
        scalar_value = self.final_linear(silu_output)
        
        return scalar_value

In [44]:
input_channels = 2
input_height = 6
input_width = 7
output_width = 7

class Connect4ModelVit(nn.Module):
    def __init__(self):
        super().__init__()

        extractor_layers = []
        stage_0_width = 64
        num_conv_blocks = 5
        num_transformer_layers = 5
        
        transformer_args = {
            "input_shape" : (stage_0_width, input_height, input_width),
            "num_layers" : num_transformer_layers,
            "embed_dim" : 256,
            "num_heads" : 4,
            "mlp_dim" : 1024,
        }
        
        extractor_layers.append(
            InitialExtractor(in_channels=input_channels, out_channels=stage_0_width),
        )
        
        extractor_layers.extend([
            ConnectFourBlock(in_channels=stage_0_width,) for _ in range(num_conv_blocks)
        ])
        
        self.conv_extractor = nn.Sequential(*extractor_layers)

        self.transformer = VisionTransformerWithTasks(
            **transformer_args
        )

    def forward(self, x):
        convnet_features = self.conv_extractor(x)
        concatenated_output = self.transformer(convnet_features)

        policy_outputs, value_outputs, next_value_outputs \
            = torch.split(concatenated_output, [7, 1, 7], dim=-1)
        return policy_outputs, value_outputs, next_value_outputs
    
class Connect4ModelCvn(nn.Module):
    def __init__(self):
        super().__init__()

        extractor_layers = []
        stage_0_width = 64
        num_conv_blocks = 7
        
        extractor_layers.append(
            InitialExtractor(in_channels=input_channels, out_channels=stage_0_width),
        )
        
        extractor_layers.extend([
            ConnectFourBlock(in_channels=stage_0_width,) for _ in range(num_conv_blocks)
        ])
        
        self.conv_extractor = nn.Sequential(*extractor_layers)

        self.policy_logits_head = nn.Sequential(
                ConnectFourBlock(in_channels=stage_0_width,),
                PolicyHead(
                    in_channels=64, 
                    input_height=input_height,
                    input_width=input_width)
            )

        self.policy_regression_head = nn.Sequential(
                ConnectFourBlock(in_channels=stage_0_width,),
                PolicyHead(
                    in_channels=64, 
                    input_height=input_height,
                    input_width=input_width)
            )

        self.value_head = nn.Sequential(
                ConnectFourBlock(in_channels=stage_0_width,),
                ValueHead(
                    in_channels=64, 
                    input_height=input_height,
                    input_width=input_width)
            )

    def forward(self, x):
        convnet_features = self.conv_extractor(x)

        policy_outputs = self.policy_logits_head(convnet_features)
        value_outputs = self.value_head(convnet_features)
        next_value_outputs = self.policy_regression_head(convnet_features)
        return policy_outputs, value_outputs, next_value_outputs
    
def get_custom_model(model_type="vit_medium"):
    if model_type == "vit_medium":
        return Connect4ModelVit()
    if model_type == "cvn_tiny":
        return Connect4ModelCvn()
    raise Exception(f"unknown model_type: {model_type}")

def count_model_flops(model=None):
    from torch.utils.flop_counter import FlopCounterMode
    if model is None:
        model = get_custom_model()
    print("Model Architecture:")
    print(model)
    
    dummy_input = torch.randn(1, input_channels, input_height, input_width)
    
    with FlopCounterMode(model) as count:
        dummy_output = model(dummy_input)
        total_flops = count.get_total_flops()
    
    print(f"Total FLOPS: {total_flops}")
    print("Dummy Input Shape:", dummy_input.shape)
    print("Dummy Output:", dummy_output)

In [45]:
from tqdm import tqdm

def move_to_device(batch, device):
        obs, targets = batch
        obs = obs.to(device, torch.float32)
        targets = targets.to(device, torch.float32)
        return obs, targets

mse_loss = nn.MSELoss(reduction='none')
kl_criterion = nn.KLDivLoss(reduction='none', log_target=True)

def calculate_loss(policy_outputs, value_outputs, next_value_outputs, labels):
    target_policy_log_probs = F.log_softmax(labels / 10, dim=-1)
    policy_log_probs = F.log_softmax(policy_outputs / 4, dim=-1)

    _, best_choice_indices = torch.max(torch.clamp(labels, min=-1, max=1), dim=-1)
    bad_choice_mask = torch.ones_like(policy_log_probs, dtype=torch.float)
    bad_choice_mask.scatter_(1, best_choice_indices.unsqueeze(1), 0)
    bad_choice_loss = - (policy_log_probs * bad_choice_mask).mean(dim=-1)

    policy_loss = kl_criterion(policy_log_probs, target_policy_log_probs).sum(dim=1)
    policy_loss = 0.8 * policy_loss + 0.002 * bad_choice_loss
    
    value_loss = mse_loss(value_outputs, torch.max(labels, dim=-1).values.unsqueeze(1)).mean(dim=1)
    next_value_loss = mse_loss(next_value_outputs, labels).mean(dim=1)

    combined_loss = 10000 * policy_loss + value_loss + 0.1 * next_value_loss

    loss_parts = {
        'combined': combined_loss,
        'policy': policy_loss,
        'value': value_loss,
        'next_value': next_value_loss
    }

    return combined_loss, loss_parts

def validate_model(model, val_loader, device, max_batches=None, use_tqdm=False, mine_hard=False):
    model.to(device)
    
    if mine_hard:
        hard_examples = []
        
    ema_loss = 0.0
    ema_alpha = 0.05
    
    # Validation loop
    model.eval()
    num_samples_in_val = 0
    val_loss = kl_loss_total = value_loss_total = next_value_loss_total = 0.0
    val_correct_count_strong = val_correct_count_weak = 0.0
    val_correct_count_strong_nv = val_correct_count_weak_nv = 0.0
    loader_slice =  itertools.islice(val_loader, max_batches) if max_batches else val_loader
    if use_tqdm:
        loader_slice = tqdm(loader_slice)
    with torch.no_grad():
        for batch in loader_slice:
            inputs, labels = move_to_device(batch, device)
            policy_outputs, value_outputs, next_value_outputs = model(inputs)

            loss, loss_parts = calculate_loss(policy_outputs, value_outputs, next_value_outputs, labels)
            
            batch_correct_count_strong, batch_correct_count_weak = argmax_accuracy(policy_outputs, labels)
            val_correct_count_strong += batch_correct_count_strong
            val_correct_count_weak += batch_correct_count_weak
            
            batch_correct_count_strong_nv, batch_correct_count_weak_nv = argmax_accuracy(next_value_outputs, labels)
            val_correct_count_strong_nv += batch_correct_count_strong_nv
            val_correct_count_weak_nv += batch_correct_count_weak_nv
            
            num_samples_in_batch = inputs.size(0)
            batch_loss_mean = loss.mean(dim=0).item()
            val_loss += batch_loss_mean * num_samples_in_batch
            kl_loss_total += loss_parts["policy"].sum(dim=0).item()
            value_loss_total += loss_parts["value"].sum(dim=0).item()
            next_value_loss_total += loss_parts["next_value"].sum(dim=0).item()
            num_samples_in_val += num_samples_in_batch
            
            if mine_hard:
                # Update the exponential moving average (EMA) of the loss
                if ema_loss == 0.0:
                    ema_loss = 0.8 * batch_loss_mean
                else:
                    ema_loss = ema_alpha * batch_loss_mean + (1 - ema_alpha) * ema_loss
                
                threshold = ema_loss * 10
                hard_example_indices = torch.where(loss > threshold)[0]
                
                if len(hard_example_indices) > 0:
                    hard_inputs = inputs[hard_example_indices]
                    hard_labels = labels[hard_example_indices]
                    hard_losses = loss[hard_example_indices]
                    hard_examples.append((hard_inputs, hard_labels, hard_losses))
            
    epoch_val_loss = {
            "combined": val_loss / num_samples_in_val, "policy": kl_loss_total / num_samples_in_val,
            "value": value_loss_total / num_samples_in_val, "next_value": next_value_loss_total / num_samples_in_val,
        }
    epoch_val_policy_acc = {
            "strong": val_correct_count_strong / num_samples_in_val, "weak": val_correct_count_weak / num_samples_in_val,
            "strong_nv": val_correct_count_strong_nv / num_samples_in_val, "weak_nv": val_correct_count_weak_nv / num_samples_in_val,
        }

    print(f"Validation Loss: {epoch_val_loss}")
    print(f"Validation Policy Acc: {epoch_val_policy_acc}")
    if mine_hard:
        if len(hard_examples) > 0:
            hard_inputs_tensor = torch.cat([ex[0] for ex in hard_examples], dim=0)
            hard_labels_tensor = torch.cat([ex[1] for ex in hard_examples], dim=0)
            hard_losses_tensor = torch.cat([ex[2] for ex in hard_examples], dim=0)
            return hard_inputs_tensor, hard_labels_tensor, hard_losses_tensor
        else:
            return None
            
    return epoch_val_policy_acc

def train_model(
    train_dataset,
    val_dataset,
    model: nn.Module,
    device,
    checkpoint_path: str = "checkpoint.pth",
    best_model_path: str = "best_model.pth",
    batch_size: int = 32,
    learning_rate: float = 0.001,
    epochs: int = 10,
    num_batches_per_epoch = 10,
    warmup_fraction: float = 0.1,
    weight_decay: float = 0.01,
    use_tqdm=False,
):  
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    
    total_steps = num_batches_per_epoch * epochs
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    scheduler = get_cosine_schedule_with_warmup(optimizer, math.ceil(total_steps * warmup_fraction), total_steps)
        
    start_epoch = 0
    best_val_acc = 0
    model.to(device)
    
    epoch = start_epoch - 1
    step = -1
    model.train()
    continue_training = True
    
    # Training loop
    while continue_training:
        if use_tqdm:
            loader_subset = tqdm(train_loader)
        for batch in loader_subset:
            step += 1
            if step >= num_batches_per_epoch:
                step = 0
                running_loss_avg = {loss_name: loss_value / num_samples_in_epoch for loss_name, loss_value in running_loss.items()}
                print(f"Epoch [{epoch}/{epochs}] done, Training Loss: {running_loss_avg}")
                
                epoch_val_policy_acc = validate_model(model, val_loader, device, num_batches_per_epoch, use_tqdm=use_tqdm)
                # Save checkpoint and best model based on validation accuracy
                if epoch_val_policy_acc["strong"] > best_val_acc:
                    best_val_acc = epoch_val_policy_acc["strong"]
                    torch.save(model.state_dict(), best_model_path)
                    print(f"Validation accuracy improved. Saved best model to {best_model_path}")
                torch.save(model.state_dict(), checkpoint_path)
                
            if step == 0:
                epoch += 1
                running_loss = defaultdict(float)
                num_samples_in_epoch = 0
                if epoch >= epochs:
                    continue_training = False
                    break
                print(f"Begin epoch {epoch} with LR: {optimizer.param_groups[0]['lr']:.6f}")
                
            inputs, labels = move_to_device(batch, device)

            optimizer.zero_grad()

            policy_outputs, value_outputs, next_value_outputs = model(inputs)

            combined_loss, loss_parts = calculate_loss(policy_outputs, value_outputs, next_value_outputs, labels)

            combined_loss.mean(dim=0).backward()
            optimizer.step()
            scheduler.step()

            num_samples_in_epoch += inputs.size(0)
            for loss_name, loss_value in loss_parts.items():
                running_loss[loss_name] += loss_value.sum(dim=0).item()
            if (step + 1) % math.ceil(num_batches_per_epoch / 10) == 0:
                running_loss_avg = {loss_name: loss_value / num_samples_in_epoch for loss_name, loss_value in running_loss.items()}
                print(f"Epoch [{epoch}/{epochs}], Step [{step} / {num_batches_per_epoch}], Running Loss: {running_loss_avg}, LR: {optimizer.param_groups[0]['lr']:.6f}")
        

    print("Finished Training")
    print(f"Best validation acc achieved: {best_val_acc:.4f}")

In [46]:
    
def load_dataset(dataset_path):
    import numpy as np
    import gc
    import torch

    try:
        print("Loading dataset with mmap_mode='r'...")

        dataset = np.load(dataset_path, mmap_mode='r')
        train_obs_np = dataset["x_train"]
        train_targets_np = dataset["y_train"]
        val_obs_np = dataset["x_val"]
        val_targets_np = dataset["y_val"]

        del dataset
        gc.collect() # Force garbage collection

        print("Converting NumPy arrays to PyTorch tensors with float16 precision (1/4)...")
        train_obs = torch.from_numpy(train_obs_np).to(dtype=torch.float16)
        del train_obs_np
        gc.collect()

        print("Converting NumPy arrays to PyTorch tensors with float16 precision (2/4)...")
        val_obs = torch.from_numpy(val_obs_np).to(dtype=torch.float16)
        del val_obs_np
        gc.collect()

        print("Converting NumPy arrays to PyTorch tensors with float16 precision (3/4)...")
        train_targets = torch.from_numpy(train_targets_np).to(dtype=torch.float16)
        del train_targets_np
        gc.collect()

        print("Converting NumPy arrays to PyTorch tensors with float16 precision (4/4)...")
        val_targets = torch.from_numpy(val_targets_np).to(dtype=torch.float16)
        del val_targets_np
        gc.collect()

        # Create TensorDatasets for use with PyTorch DataLoaders
        train_dataset = TensorDataset(train_obs, train_targets)
        val_dataset = TensorDataset(val_obs, val_targets)
        
        print("Dataset preparation complete.")
        return train_dataset, val_dataset
    except Exception as e:
        print(f"could not load data: {e}")
        return None


In [47]:
import torch

def load_weights_only(filepath):
    """
    Loads only the weights (state_dict) from a .pth file.

    Args:
        filepath (str): The path to the .pth file.

    Returns:
        dict: The state dictionary containing the model weights, or None if an error occurs.
    """
    try:
        # Load the entire checkpoint
        checkpoint = torch.load(filepath, map_location="cpu")

        # Check if the loaded object is a state_dict directly
        if isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
            # Assume it's a state_dict if all keys are strings (common for weights)
            print(f"Successfully loaded state_dict from {filepath}")
            state_dict = checkpoint
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            # If it's a dictionary containing a 'state_dict' key (common for full checkpoints)
            print(f"Successfully loaded 'state_dict' from checkpoint in {filepath}")
            state_dict = checkpoint['state_dict']
        else:
            print(f"Warning: The .pth file at {filepath} does not seem to contain a standard state_dict or a checkpoint with 'state_dict'.")
            print("Attempting to return the loaded object directly. Please inspect its content.")
            state_dict = checkpoint

        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith('_orig_mod.'):
                # Remove the prefix and add the key-value pair to the new dict
                new_key = key[len('_orig_mod.'):]
                new_state_dict[new_key] = value
            else:
                # Keep keys that don't have the prefix as they are
                new_state_dict[key] = value
        return new_state_dict
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}")
        return None
    except Exception as e:
        print(f"An error occurred while loading the file: {e}")
        return None

In [48]:
arch = "vit_medium"
num_run = 0

model = get_custom_model(arch)
count_model_flops(model)

Model Architecture:
Connect4ModelVit(
  (conv_extractor): Sequential(
    (0): InitialExtractor(
      (convs): Sequential(
        (0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (1): ConnectFourBlock(
      (conv_next): Sequential(
        (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64, bias=False)
        (1): LayerNorm2d((64,), eps=1e-06, elementwise_affine=True)
        (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): SiLU(inplace=True)
        (4): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (proj_up): Sequential(
        (0): LayerNorm2d((64,), eps=1e-06, elementwise_affine=True)
        (1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (grouped_0): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3)

  with FlopCounterMode(model) as count:


In [49]:
dataset_path = "./data/c4_data_enriched.npz"

In [50]:
dataset_splits = load_dataset(dataset_path)
if dataset_splits is None:
    os._exit(1)

train_dataset, val_dataset = dataset_splits

Loading dataset with mmap_mode='r'...
Converting NumPy arrays to PyTorch tensors with float16 precision (1/4)...
Converting NumPy arrays to PyTorch tensors with float16 precision (2/4)...
Converting NumPy arrays to PyTorch tensors with float16 precision (3/4)...
Converting NumPy arrays to PyTorch tensors with float16 precision (4/4)...
Dataset preparation complete.


In [51]:
save_model_name = f"checkpoints/{arch}_r{num_run}"
checkpoint_file = f"{save_model_name}_checkpoint.pth"
best_model_file = f"{save_model_name}_best_model.pth"

load_file = f"{save_model_name}_starting_checkpoint.pth"

weights = load_weights_only(load_file)
if weights:
    print("Keys in loaded weights:", weights.keys())
    try:
        model.load_state_dict(weights)
        print("Successfully loaded weights into a new model.")
    except Exception as e:
        print(f"Could not load weights into a new model: {e}")


Successfully loaded state_dict from checkpoints/vit_medium_r0_starting_checkpoint.pth
Keys in loaded weights: dict_keys(['conv_extractor.0.convs.0.weight', 'conv_extractor.0.convs.0.bias', 'conv_extractor.0.convs.0.running_mean', 'conv_extractor.0.convs.0.running_var', 'conv_extractor.0.convs.0.num_batches_tracked', 'conv_extractor.0.convs.1.weight', 'conv_extractor.1.conv_next.0.weight', 'conv_extractor.1.conv_next.1.weight', 'conv_extractor.1.conv_next.1.bias', 'conv_extractor.1.conv_next.2.weight', 'conv_extractor.1.conv_next.4.weight', 'conv_extractor.1.proj_up.0.weight', 'conv_extractor.1.proj_up.0.bias', 'conv_extractor.1.proj_up.1.weight', 'conv_extractor.1.grouped_0.0.weight', 'conv_extractor.1.grouped_0.1.conv.weight', 'conv_extractor.1.grouped_0.1.conv.bias', 'conv_extractor.1.grouped_1.0.weight', 'conv_extractor.1.grouped_1.1.conv.weight', 'conv_extractor.1.grouped_1.1.conv.bias', 'conv_extractor.1.proj_down.0.fc_scale.0.weight', 'conv_extractor.1.proj_down.0.fc_scale.0.bias

In [52]:
device = torch.device("mps")
# INFO: epochs and batches decreased for presentation
args = dict(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    model=model,
    batch_size=2 ** 11,
    learning_rate=5e-5,
    epochs=4,
    num_batches_per_epoch=20,
    warmup_fraction=0.2,
    weight_decay=4e-5,
    checkpoint_path=checkpoint_file,
    best_model_path=best_model_file,
    device=device,
    use_tqdm=True
)


In [53]:
train_model(**args)



Begin epoch 0 with LR: 0.000000


  0%|          | 2/57415 [00:11<76:23:22,  4.79s/it] 

Epoch [0/4], Step [1 / 20], Running Loss: {'combined': 6455.566162109375, 'policy': 0.4191662669181824, 'value': 1520.9827880859375, 'next_value': 7429.206298828125}, LR: 0.000006


  0%|          | 4/57415 [00:13<40:13:39,  2.52s/it]

Epoch [0/4], Step [3 / 20], Running Loss: {'combined': 6245.2410888671875, 'policy': 0.39965417236089706, 'value': 1481.2819519042969, 'next_value': 7674.1729736328125}, LR: 0.000013


  0%|          | 6/57415 [00:16<29:20:34,  1.84s/it]

Epoch [0/4], Step [5 / 20], Running Loss: {'combined': 6095.862060546875, 'policy': 0.382262443502744, 'value': 1484.161844889323, 'next_value': 7890.757405598958}, LR: 0.000019


  0%|          | 8/57415 [00:19<25:04:39,  1.57s/it]

Epoch [0/4], Step [7 / 20], Running Loss: {'combined': 5811.6326904296875, 'policy': 0.3578588180243969, 'value': 1436.7593078613281, 'next_value': 7962.851989746094}, LR: 0.000025


  0%|          | 10/57415 [00:21<23:12:10,  1.46s/it]

Epoch [0/4], Step [9 / 20], Running Loss: {'combined': 5620.75244140625, 'policy': 0.34198285937309264, 'value': 1392.0317138671876, 'next_value': 8088.920849609375}, LR: 0.000031


  0%|          | 12/57415 [00:24<22:22:17,  1.40s/it]

Epoch [0/4], Step [11 / 20], Running Loss: {'combined': 5458.125406901042, 'policy': 0.3294528548916181, 'value': 1349.6328430175781, 'next_value': 8139.6396891276045}, LR: 0.000038


  0%|          | 14/57415 [00:27<22:10:26,  1.39s/it]

Epoch [0/4], Step [13 / 20], Running Loss: {'combined': 5347.727992466518, 'policy': 0.3197121428591864, 'value': 1343.777596609933, 'next_value': 8068.289202008928}, LR: 0.000044


  0%|          | 16/57415 [00:29<21:32:40,  1.35s/it]

Epoch [0/4], Step [15 / 20], Running Loss: {'combined': 5206.093719482422, 'policy': 0.3088137209415436, 'value': 1309.4091300964355, 'next_value': 8085.4735107421875}, LR: 0.000050


  0%|          | 18/57415 [00:32<21:17:38,  1.34s/it]

Epoch [0/4], Step [17 / 20], Running Loss: {'combined': 5094.994656032986, 'policy': 0.3000914032260577, 'value': 1290.1588711208767, 'next_value': 8039.217122395833}, LR: 0.000050


  0%|          | 20/57415 [00:35<21:13:47,  1.33s/it]

Epoch [0/4], Step [19 / 20], Running Loss: {'combined': 5011.550329589843, 'policy': 0.29277342185378075, 'value': 1280.660922241211, 'next_value': 8031.551171875}, LR: 0.000050
Epoch [0/4] done, Training Loss: {'combined': 5011.550329589843, 'policy': 0.29277342185378075, 'value': 1280.660922241211, 'next_value': 8031.551171875}


20it [00:07,  2.57it/s]


Validation Loss: {'combined': 1218.2769470214844, 'policy': 0.08711761981248856, 'value': 271.2638626098633, 'next_value': 758.3688659667969}
Validation Policy Acc: {'strong': tensor(0.9098, device='mps:0'), 'weak': tensor(0.9816, device='mps:0'), 'strong_nv': tensor(0.8337, device='mps:0'), 'weak_nv': tensor(0.9885, device='mps:0')}
Validation accuracy improved. Saved best model to checkpoints/vit_medium_r0_best_model.pth
Begin epoch 1 with LR: 0.000050


  0%|          | 22/57415 [00:46<48:25:13,  3.04s/it]

Epoch [1/4], Step [1 / 20], Running Loss: {'combined': 1886.5298461914062, 'policy': 0.1216006949543953, 'value': 575.4322509765625, 'next_value': 950.9069519042969}, LR: 0.000049


  0%|          | 24/57415 [00:48<33:19:44,  2.09s/it]

Epoch [1/4], Step [3 / 20], Running Loss: {'combined': 1934.5039367675781, 'policy': 0.12195104360580444, 'value': 618.8231048583984, 'next_value': 961.7043762207031}, LR: 0.000048


  0%|          | 26/57415 [00:50<25:56:58,  1.63s/it]

Epoch [1/4], Step [5 / 20], Running Loss: {'combined': 1871.9553629557292, 'policy': 0.11906309922536214, 'value': 585.7340342203776, 'next_value': 955.9037373860677}, LR: 0.000047


  0%|          | 28/57415 [00:53<22:28:31,  1.41s/it]

Epoch [1/4], Step [7 / 20], Running Loss: {'combined': 1824.9727478027344, 'policy': 0.11698203813284636, 'value': 560.9809913635254, 'next_value': 941.7138595581055}, LR: 0.000046


  0%|          | 30/57415 [00:55<20:46:15,  1.30s/it]

Epoch [1/4], Step [9 / 20], Running Loss: {'combined': 1821.641748046875, 'policy': 0.11582648903131484, 'value': 568.7159576416016, 'next_value': 946.6091430664062}, LR: 0.000044


  0%|          | 32/57415 [00:58<19:47:44,  1.24s/it]

Epoch [1/4], Step [11 / 20], Running Loss: {'combined': 1790.8223266601562, 'policy': 0.11394132611652215, 'value': 556.6498235066732, 'next_value': 947.5925750732422}, LR: 0.000043


  0%|          | 34/57415 [01:00<19:23:23,  1.22s/it]

Epoch [1/4], Step [13 / 20], Running Loss: {'combined': 1765.6448887416295, 'policy': 0.11199862137436867, 'value': 550.7519792829241, 'next_value': 949.0671648297991}, LR: 0.000041


  0%|          | 36/57415 [01:02<19:11:04,  1.20s/it]

Epoch [1/4], Step [15 / 20], Running Loss: {'combined': 1750.298484802246, 'policy': 0.11065924493595958, 'value': 548.6295166015625, 'next_value': 950.7653884887695}, LR: 0.000039


  0%|          | 38/57415 [01:05<19:02:36,  1.19s/it]

Epoch [1/4], Step [17 / 20], Running Loss: {'combined': 1746.952880859375, 'policy': 0.1100816097524431, 'value': 550.1558295355903, 'next_value': 959.8096516927084}, LR: 0.000037


  0%|          | 40/57415 [01:07<18:57:57,  1.19s/it]

Epoch [1/4], Step [19 / 20], Running Loss: {'combined': 1724.4182739257812, 'policy': 0.10846277512609959, 'value': 543.7958801269531, 'next_value': 959.946499633789}, LR: 0.000035
Epoch [1/4] done, Training Loss: {'combined': 1724.4182739257812, 'policy': 0.10846277512609959, 'value': 543.7958801269531, 'next_value': 959.946499633789}


20it [00:07,  2.82it/s]


Validation Loss: {'combined': 834.0098815917969, 'policy': 0.05389855224639177, 'value': 226.0476448059082, 'next_value': 689.7672393798828}
Validation Policy Acc: {'strong': tensor(0.9270, device='mps:0'), 'weak': tensor(0.9950, device='mps:0'), 'strong_nv': tensor(0.8430, device='mps:0'), 'weak_nv': tensor(0.9941, device='mps:0')}
Validation accuracy improved. Saved best model to checkpoints/vit_medium_r0_best_model.pth
Begin epoch 2 with LR: 0.000035


  0%|          | 42/57415 [01:17<43:07:16,  2.71s/it]

Epoch [2/4], Step [1 / 20], Running Loss: {'combined': 1648.7371215820312, 'policy': 0.09573324397206306, 'value': 600.3723449707031, 'next_value': 910.3231506347656}, LR: 0.000032


  0%|          | 44/57415 [01:19<30:48:15,  1.93s/it]

Epoch [2/4], Step [3 / 20], Running Loss: {'combined': 1584.0894470214844, 'policy': 0.09331063367426395, 'value': 560.5438232421875, 'next_value': 904.3923950195312}, LR: 0.000030


  0%|          | 46/57415 [01:21<24:45:25,  1.55s/it]

Epoch [2/4], Step [5 / 20], Running Loss: {'combined': 1493.5968017578125, 'policy': 0.08988862484693527, 'value': 505.5275624593099, 'next_value': 891.8297729492188}, LR: 0.000027


  0%|          | 48/57415 [01:24<22:11:43,  1.39s/it]

Epoch [2/4], Step [7 / 20], Running Loss: {'combined': 1497.7239990234375, 'policy': 0.0896861394867301, 'value': 511.93869400024414, 'next_value': 889.2390899658203}, LR: 0.000025


  0%|          | 50/57415 [01:26<20:48:13,  1.31s/it]

Epoch [2/4], Step [9 / 20], Running Loss: {'combined': 1500.903271484375, 'policy': 0.08989177271723747, 'value': 513.3483184814453, 'next_value': 886.37216796875}, LR: 0.000023


  0%|          | 52/57415 [01:29<19:54:39,  1.25s/it]

Epoch [2/4], Step [11 / 20], Running Loss: {'combined': 1509.890645345052, 'policy': 0.090319716061155, 'value': 518.8029708862305, 'next_value': 878.9049580891927}, LR: 0.000020


  0%|          | 54/57415 [01:31<19:32:37,  1.23s/it]

Epoch [2/4], Step [13 / 20], Running Loss: {'combined': 1501.8181326729912, 'policy': 0.09045373914497239, 'value': 509.89683750697543, 'next_value': 873.8389238630023}, LR: 0.000018


  0%|          | 56/57415 [01:34<19:23:47,  1.22s/it]

Epoch [2/4], Step [15 / 20], Running Loss: {'combined': 1499.010871887207, 'policy': 0.09047550149261951, 'value': 507.54509353637695, 'next_value': 867.1075439453125}, LR: 0.000015


  0%|          | 58/57415 [01:36<19:20:50,  1.21s/it]

Epoch [2/4], Step [17 / 20], Running Loss: {'combined': 1494.731913248698, 'policy': 0.09003447575701608, 'value': 508.4146033393012, 'next_value': 859.7254876030815}, LR: 0.000013


  0%|          | 60/57415 [01:38<19:14:44,  1.21s/it]

Epoch [2/4], Step [19 / 20], Running Loss: {'combined': 1484.5274963378906, 'policy': 0.08927823938429355, 'value': 507.032698059082, 'next_value': 847.1240142822265}, LR: 0.000011
Epoch [2/4] done, Training Loss: {'combined': 1484.5274963378906, 'policy': 0.08927823938429355, 'value': 507.032698059082, 'next_value': 847.1240142822265}


20it [00:07,  2.68it/s]


Validation Loss: {'combined': 752.4904907226562, 'policy': 0.04819699060171843, 'value': 213.56237182617187, 'next_value': 569.5821685791016}
Validation Policy Acc: {'strong': tensor(0.9345, device='mps:0'), 'weak': tensor(0.9957, device='mps:0'), 'strong_nv': tensor(0.8254, device='mps:0'), 'weak_nv': tensor(0.9946, device='mps:0')}
Validation accuracy improved. Saved best model to checkpoints/vit_medium_r0_best_model.pth
Begin epoch 3 with LR: 0.000011


  0%|          | 62/57415 [01:48<45:01:50,  2.83s/it]

Epoch [3/4], Step [1 / 20], Running Loss: {'combined': 1452.849853515625, 'policy': 0.08292215317487717, 'value': 544.6606140136719, 'next_value': 789.6776733398438}, LR: 0.000009


  0%|          | 64/57415 [01:51<32:03:06,  2.01s/it]

Epoch [3/4], Step [3 / 20], Running Loss: {'combined': 1437.3251953125, 'policy': 0.08382716774940491, 'value': 518.9060440063477, 'next_value': 801.4747924804688}, LR: 0.000007


  0%|          | 66/57415 [01:53<25:39:14,  1.61s/it]

Epoch [3/4], Step [5 / 20], Running Loss: {'combined': 1424.0284220377605, 'policy': 0.08342431982358296, 'value': 510.47133382161456, 'next_value': 793.1387023925781}, LR: 0.000006


  0%|          | 68/57415 [01:56<22:30:13,  1.41s/it]

Epoch [3/4], Step [7 / 20], Running Loss: {'combined': 1429.1421813964844, 'policy': 0.08429753594100475, 'value': 506.36793518066406, 'next_value': 797.9889068603516}, LR: 0.000004


  0%|          | 70/57415 [01:58<20:57:54,  1.32s/it]

Epoch [3/4], Step [9 / 20], Running Loss: {'combined': 1420.0154052734374, 'policy': 0.08433568179607391, 'value': 496.26243591308594, 'next_value': 803.9615234375}, LR: 0.000003


  0%|          | 72/57415 [02:01<20:12:47,  1.27s/it]

Epoch [3/4], Step [11 / 20], Running Loss: {'combined': 1414.1031799316406, 'policy': 0.08402304723858833, 'value': 494.00375111897785, 'next_value': 798.6896362304688}, LR: 0.000002


  0%|          | 74/57415 [02:03<19:51:52,  1.25s/it]

Epoch [3/4], Step [13 / 20], Running Loss: {'combined': 1393.6715262276787, 'policy': 0.08333131085549082, 'value': 480.73399353027344, 'next_value': 796.2442932128906}, LR: 0.000001


  0%|          | 76/57415 [02:06<19:44:59,  1.24s/it]

Epoch [3/4], Step [15 / 20], Running Loss: {'combined': 1391.1943893432617, 'policy': 0.08350273361429572, 'value': 476.3062381744385, 'next_value': 798.6081924438477}, LR: 0.000000


  0%|          | 78/57415 [02:08<19:55:59,  1.25s/it]

Epoch [3/4], Step [17 / 20], Running Loss: {'combined': 1398.4647759331597, 'policy': 0.0836206976738241, 'value': 482.51224941677515, 'next_value': 797.4556206597222}, LR: 0.000000


  0%|          | 80/57415 [02:11<19:44:16,  1.24s/it]

Epoch [3/4], Step [19 / 20], Running Loss: {'combined': 1389.3052795410156, 'policy': 0.08318107016384602, 'value': 477.8706939697266, 'next_value': 796.2389831542969}, LR: 0.000000
Epoch [3/4] done, Training Loss: {'combined': 1389.3052795410156, 'policy': 0.08318107016384602, 'value': 477.8706939697266, 'next_value': 796.2389831542969}


20it [00:07,  2.64it/s]


Validation Loss: {'combined': 746.3404388427734, 'policy': 0.047507242672145365, 'value': 215.40302581787108, 'next_value': 558.6499664306641}
Validation Policy Acc: {'strong': tensor(0.9353, device='mps:0'), 'weak': tensor(0.9959, device='mps:0'), 'strong_nv': tensor(0.8243, device='mps:0'), 'weak_nv': tensor(0.9947, device='mps:0')}
Validation accuracy improved. Saved best model to checkpoints/vit_medium_r0_best_model.pth


  0%|          | 80/57415 [02:19<27:48:13,  1.75s/it]

Finished Training
Best validation acc achieved: 0.9353





In [54]:
# Mine hard examples
if False:
    model.to(device)
    train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=False, num_workers=4, pin_memory=True)
    hard_examples = validate_model(model, train_loader, device, max_batches=100, use_tqdm=True, mine_hard=True)
    if hard_examples:
        print(len(hard_examples[0]))
        hard_inputs, hard_targets, hard_losses = hard_examples
        hard_examples_tensor = {
            "obs": hard_inputs.cpu(),
            "targets": hard_targets.cpu(),
            "loss": hard_losses.cpu()
        }

        torch.save(hard_examples_tensor, "hard_examples.pt")

In [55]:
del train_dataset
del val_dataset

In [56]:

dataset = np.load(dataset_path, mmap_mode='r')
test_obs = torch.from_numpy(dataset["x_test"])
test_targets = torch.from_numpy(dataset["y_test"])

random_seed = 534984
indices = np.arange(len(test_obs))
np.random.seed(random_seed)
np.random.shuffle(indices)

test_obs = test_obs[indices]
test_targets = test_targets[indices]
test_dataset = TensorDataset(test_obs, test_targets)

In [57]:
model.to(device)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=4, pin_memory=True)
validate_model(model, test_loader, device, max_batches=500, use_tqdm=True) # INFO: batches decreased for presentation

500it [01:37,  5.12it/s]

Validation Loss: {'combined': 741.1239390869141, 'policy': 0.046999649949371815, 'value': 215.28080293273925, 'next_value': 558.4663616333008}
Validation Policy Acc: {'strong': tensor(0.9348, device='mps:0'), 'weak': tensor(0.9962, device='mps:0'), 'strong_nv': tensor(0.8247, device='mps:0'), 'weak_nv': tensor(0.9951, device='mps:0')}





{'strong': tensor(0.9348, device='mps:0'),
 'weak': tensor(0.9962, device='mps:0'),
 'strong_nv': tensor(0.8247, device='mps:0'),
 'weak_nv': tensor(0.9951, device='mps:0')}

In [58]:
save_model = model
save_model.to(torch.device("cpu"))
save_model.eval()
save_model(torch.randn((1, 2, 6, 7)))

example_inputs = (torch.randn(1, 2, 6, 7),)
onnx_program = torch.onnx.export(save_model, example_inputs, dynamo=True)
onnx_program.save(f"{save_model_name}.onnx")
print(f"Model saved to {save_model_name}.onnx")

[torch.onnx] Obtain model graph for `Connect4ModelVit([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Connect4ModelVit([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


Skipping constant folding for op SequenceEmpty with multiple outputs.
Skipping constant folding for op SequenceEmpty with multiple outputs.
Skipping constant folding for op SequenceEmpty with multiple outputs.
Skipping constant folding for op SequenceEmpty with multiple outputs.
Skipping constant folding for op SequenceEmpty with multiple outputs.


[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 1 of general pattern rewrite rules.


Skipping constant folding for op SequenceEmpty with multiple outputs.
Skipping constant folding for op SequenceEmpty with multiple outputs.
Skipping constant folding for op SequenceEmpty with multiple outputs.
Skipping constant folding for op SequenceEmpty with multiple outputs.
Skipping constant folding for op SequenceEmpty with multiple outputs.


Model saved to checkpoints/vit_medium_r0.onnx


In [59]:
import onnx

onnx_model = onnx.load(f"{save_model_name}.onnx")
onnx.checker.check_model(onnx_model)
import onnxruntime

example_inputs = (torch.randn(1, 2, 6, 7),torch.randn(1, 2, 6, 7),torch.randn(1, 2, 6, 7),)
onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs]
print(f"Input length: {len(onnx_inputs)}")
print(f"Sample input: {onnx_inputs}")

ort_session = onnxruntime.InferenceSession(
    f"./{save_model_name}.onnx", providers=["CPUExecutionProvider"]
)

onnxruntime_input = {input_arg.name: input_value for input_arg, input_value in zip(ort_session.get_inputs(), onnx_inputs)}
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
onnxruntime_outputs

Input length: 3
Sample input: [array([[[[-5.25287628e-01,  6.11568429e-02,  5.17404735e-01,
           5.96397400e-01,  3.84037286e-01, -1.59243613e-01,
           5.42762220e-01],
         [ 2.14245528e-01,  7.87356138e-01, -1.58834112e+00,
           1.75697207e-01,  8.48437369e-01, -4.95422870e-01,
           2.99675852e-01],
         [-3.49775493e-01,  3.61759448e+00, -3.38464499e-01,
           4.38362896e-01, -4.41578239e-01,  8.47287774e-01,
           6.50815010e-01],
         [-8.28610063e-01,  5.07761165e-02,  1.85775951e-01,
          -5.23240447e-01, -1.61857158e-01,  9.11434650e-01,
          -7.95467079e-01],
         [-1.02821314e+00, -1.25863028e+00,  4.69664156e-01,
           1.06807184e+00, -7.85515010e-01, -2.44781435e-01,
          -2.26164132e-01],
         [ 3.58530849e-01,  5.18474758e-01,  1.41793907e-01,
           6.62241220e-01, -7.16749191e-01,  8.76955807e-01,
          -2.35519099e+00]],

        [[-8.96165967e-01,  5.93281925e-01,  1.39522731e+00,
      

[array([[-27.231363 , -20.608765 ,  -6.5820737,  23.44077  ,  -7.0962567,
         -18.519863 , -25.27468  ]], dtype=float32),
 array([[102.15377]], dtype=float32),
 array([[-98.50069, -93.21299, -88.82956, 103.82227, -88.43998, -96.47022,
         -98.75507]], dtype=float32)]

In [None]:
save_model(example_inputs[0])

(tensor([[-27.2314, -20.6088,  -6.5821,  23.4408,  -7.0963, -18.5199, -25.2747]],
        grad_fn=<SplitWithSizesBackward0>),
 tensor([[102.1539]], grad_fn=<SplitWithSizesBackward0>),
 tensor([[-98.5007, -93.2131, -88.8296, 103.8223, -88.4400, -96.4703, -98.7551]],
        grad_fn=<SplitWithSizesBackward0>))