In [1]:
%%writefile utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math

def argmax_accuracy(output: torch.Tensor, target: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Count the samples in a batch whose predicted index (argmax of ``output``)
    coincides with *any* index at which the target row attains its maximum.

    Two counts are produced:

    * "original": the row maximum is taken over ``target`` as given.
    * "clipped":  ``target`` is first clipped to ``[-1, 1]``, so every entry
      that saturates at the clip bound ties for the maximum.

    The counts are returned as tensors (not Python numbers) so they can be
    accumulated on-device without forcing a host synchronisation.

    Args:
        output (torch.Tensor): Model output tensor of shape (n, w), where n is
                               batch size and w is the vector dimension.
        target (torch.Tensor): Target tensor of the same shape (n, w).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(correct_count_original,
        correct_count_clipped)``, two scalar ``long`` tensors on the same
        device as ``output``. Both are zero when the batch is empty.

    Raises:
        ValueError: If the shapes differ or the tensors are not 2D.
    """
    if output.shape != target.shape:
        raise ValueError(f"Output and target tensors must have the same shape. Got {output.shape} and {target.shape}")
    if output.dim() != 2:
         raise ValueError(f"Output and target tensors must be 2D (n, w). Got {output.dim()} dimensions.")

    n = output.shape[0]

    if n == 0:
        # Keep the return arity identical to the non-empty path so callers can
        # always unpack two values.
        return (
            torch.tensor(0, dtype=torch.long, device=output.device),
            torch.tensor(0, dtype=torch.long, device=output.device),
        )

    # The predicted index and the row selector are shared by both variants.
    predicted_index = torch.argmax(output, dim=1)
    row_index = torch.arange(n, device=output.device)

    def _count_hits(reference: torch.Tensor) -> torch.Tensor:
        # Mark every position that ties for the row maximum, then look up
        # whether the predicted index is one of them.
        row_max = torch.max(reference, dim=1, keepdim=True)[0]
        is_row_max = reference == row_max
        return is_row_max[row_index, predicted_index].sum()

    correct_count_original = _count_hits(target)
    correct_count_clipped = _count_hits(torch.clip(target, -1, 1))

    return correct_count_original, correct_count_clipped

Overwriting utils.py


In [2]:
%%writefile layers.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math

class LayerNorm2d(nn.LayerNorm):
    """Layer normalisation over the channel axis of NCHW feature maps.

    ``nn.LayerNorm`` normalises trailing dimensions, so the input is viewed
    channels-last for the normalisation and then restored to NCHW.
    """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NCHW -> NHWC so that channels are the last (normalised) dimension.
        channels_last = x.permute(0, 2, 3, 1)
        normalised = F.layer_norm(
            channels_last,
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        # NHWC -> NCHW for the surrounding convolutional layers.
        return normalised.permute(0, 3, 1, 2)

class SEBlock(nn.Module):
    """Squeeze-and-excitation style gate that predicts a per-channel scale and offset.

    The input is globally average-pooled to one value per channel. Two small
    bottleneck MLPs then map that vector to a sigmoid-bounded multiplicative
    scale and an unbounded additive offset, which are broadcast over H and W.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        # Bottleneck width, never smaller than one unit.
        squeezed = max(channel // reduction, 1)

        def bottleneck():
            # Fresh layers on every call; the two branches share no weights.
            return [
                nn.Linear(channel, squeezed, bias=True),
                nn.SiLU(inplace=True),
                nn.Linear(squeezed, channel, bias=True),
            ]

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc_scale = nn.Sequential(*bottleneck(), nn.Sigmoid())
        self.fc_offset = nn.Sequential(*bottleneck())

    def forward(self, x):
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gain = self.fc_scale(pooled).view(batch, channels, 1, 1)
        shift = self.fc_offset(pooled).view(batch, channels, 1, 1)
        return x * gain.expand_as(x) + shift.expand_as(x)

class InitialExtractor(nn.Module):
    """Input stem: batch-normalise the raw planes, then apply one bias-free convolution.

    The convolution uses stride 1 and ``kernel_size // 2`` padding, so the
    spatial size is preserved for odd kernel sizes.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.convs = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                stride=1,
                bias=False,
            ),
        )

    def forward(self, x):
        return self.convs(x)

class Grouped1x1SumConv(nn.Module):
    """Grouped 1x1 convolution whose per-group outputs are summed together.

    Each of the ``num_groups`` input groups is projected to ``out_channels``
    channels by a single grouped convolution; the ``num_groups`` resulting
    feature maps are then added, giving an output with ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, num_groups):
        super().__init__()
        assert in_channels % num_groups == 0, "in_channels must be divisible by num_groups"

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_groups = num_groups

        # One grouped conv emits all per-group projections in a single call.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * num_groups,
            kernel_size=1,
            groups=num_groups,
            bias=True,
        )

    def forward(self, x):
        batch, _, height, width = x.shape
        # Split the channel axis into (group, out_channel) and reduce over groups.
        per_group = self.conv(x).view(batch, self.num_groups, self.out_channels, height, width)
        return per_group.sum(dim=1)

class ConnectFourBlock(nn.Module):
    """Residual block combining a depthwise-conv MLP stage with a grouped-conv stage.

    Stage one (``conv_next``): depthwise convolution, channel LayerNorm, then a
    1x1 expand / SiLU / 1x1 project MLP, added back to the input.

    Stage two: the result is normalised and widened (``proj_up``), refined by two
    residual grouped-convolution sub-blocks (``grouped_0``, ``grouped_1``), gated
    by an ``SEBlock`` and projected back to ``in_channels`` (``proj_down``), and
    finally added to the stage-one output.

    Args:
        in_channels: Number of input and output channels.
        kernel_size: ``(height, width)`` of the depthwise convolution.
        mlp_factor: Width multiplier of the hidden layer in stage one.
        group_factor: Width multiplier of stage two; the grouped convolutions
            use ``2 * group_factor`` groups.
    """

    def __init__(self, in_channels, kernel_size = (5, 5), mlp_factor=4, group_factor=2):
        super().__init__()

        dw_h, dw_w = kernel_size
        mlp_width = in_channels * mlp_factor

        self.conv_next = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=(dw_h, dw_w), padding=(dw_h // 2, dw_w // 2), groups=in_channels, bias=False),
            LayerNorm2d(in_channels),
            nn.Conv2d(in_channels, mlp_width, kernel_size=1, bias=False),
            nn.SiLU(inplace=True),
            nn.Conv2d(mlp_width, in_channels, kernel_size=1, bias=False),
        )

        wide_channels = in_channels * group_factor
        n_groups = 2 * group_factor
        grouped_kernel = (3, 3)
        grouped_padding = (grouped_kernel[0] // 2, grouped_kernel[1] // 2)

        self.proj_up = nn.Sequential(
            LayerNorm2d(in_channels),
            nn.Conv2d(in_channels, wide_channels, kernel_size=1, bias=False),
        )

        def grouped_stage():
            # Spatial grouped conv, cross-group mixing, then the non-linearity.
            return nn.Sequential(
                nn.Conv2d(wide_channels, wide_channels, kernel_size=grouped_kernel, padding=grouped_padding, groups=n_groups, bias=False),
                Grouped1x1SumConv(wide_channels, wide_channels, num_groups=n_groups),
                nn.SiLU(inplace=True),
            )

        self.grouped_0 = grouped_stage()
        self.grouped_1 = grouped_stage()

        self.proj_down = nn.Sequential(
            SEBlock(wide_channels),
            nn.Conv2d(wide_channels, in_channels, kernel_size=1, bias=False),
        )

    def forward(self, x):
        # Stage one: depthwise + MLP residual.
        out = x + self.conv_next(x)

        # Stage two: widen, two residual grouped refinements, project back.
        widened = self.proj_up(out)
        widened = widened + self.grouped_0(widened)
        widened = widened + self.grouped_1(widened)

        return out + self.proj_down(widened)

Overwriting layers.py


In [3]:
%%writefile vit.py
import torch
import torch.nn as nn
import math

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place from a truncated normal distribution and return it.

    Thin keyword-forwarding wrapper around ``torch.nn.init.trunc_normal_``;
    ``a`` and ``b`` are the lower and upper truncation bounds.
    """
    init_kwargs = dict(mean=mean, std=std, a=a, b=b)
    return nn.init.trunc_normal_(tensor, **init_kwargs)

class VisionTransformerWithTasks(nn.Module):
    """Transformer encoder over per-position tokens with learnable task tokens.

    Every spatial position of a ``(C, H, W)`` feature map becomes one token
    (linear projection of its ``C`` channels). ``num_tasks`` learnable task
    tokens are prepended, learned position embeddings are added, and the
    sequence is run through a standard ``nn.TransformerEncoder``. Each task
    token is then reduced to one scalar by its own linear readout, implemented
    as a grouped 1x1 ``Conv1d`` with one group per task.

    Args:
        input_shape: ``(C, H, W)`` of the incoming feature map.
        num_layers: Number of transformer encoder layers.
        embed_dim: Token embedding width.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the encoder feed-forward network.
        num_tasks: Number of task tokens, which is also the number of output logits.
    """

    def __init__(
        self,
        input_shape,
        num_layers,
        embed_dim=160,
        num_heads=5,
        mlp_dim=320,
        num_tasks=7 + 7 + 1,
    ):
        super().__init__()
        channels, height, width = input_shape
        self.num_patches = height * width
        self.embed_dim = embed_dim
        self.num_tasks = num_tasks

        self.proj = nn.Linear(channels, embed_dim)
        self.task_tokens = nn.Parameter(torch.randn(1, num_tasks, embed_dim))
        # One learned position per task token and per spatial position.
        self.pos_embed = nn.Embedding(self.num_patches + num_tasks, embed_dim)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=mlp_dim,
                batch_first=True,
                activation="gelu"
            ),
            num_layers=num_layers,
        )

        # groups=num_tasks gives each task token an independent readout.
        self.output_head = nn.Conv1d(
            in_channels=embed_dim * num_tasks,
            out_channels=num_tasks,
            kernel_size=1,
            groups=num_tasks
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise embeddings and Linear/Conv1d/LayerNorm sub-modules."""
        trunc_normal_(self.task_tokens, std=0.02)
        trunc_normal_(self.pos_embed.weight, std=0.02)
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        batch, _channels, _height, _width = x.shape

        # (B, C, H, W) -> (B, H*W, C) -> (B, H*W, embed_dim)
        patch_tokens = self.proj(x.flatten(2).transpose(1, 2))

        # Task tokens occupy the first num_tasks positions of the sequence.
        sequence = torch.cat((self.task_tokens.expand(batch, -1, -1), patch_tokens), dim=1)

        positions = torch.arange(sequence.size(1), device=sequence.device).unsqueeze(0).expand(batch, -1)
        encoded = self.transformer(sequence + self.pos_embed(positions))

        # Flatten the task tokens so each one lines up with its Conv1d group.
        task_features = encoded[:, :self.num_tasks].reshape(batch, self.num_tasks * self.embed_dim, 1)
        return self.output_head(task_features).squeeze(-1)


Overwriting vit.py


In [4]:
%%writefile c4model.py
import layers
from vit import VisionTransformerWithTasks
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math

import numpy as np

# Shape of a single observation fed to the model: (input_channels, input_height, input_width).
# input_channels is the channel count expected by the initial extractor.
input_channels = 2
# Spatial size of the observation grid; together they fix the number of
# position tokens seen by the transformer (input_height * input_width).
input_height = 6
input_width = 7
# NOTE(review): not referenced in the code visible here; presumably the number
# of policy outputs (one per column) — confirm against callers.
output_width = 7

class Connect4Model(nn.Module):
    """Convolutional feature extractor followed by a task-token transformer.

    The forward pass returns three tensors split from the transformer's 15
    logits along the last dimension: 7 policy outputs, 1 value output and
    7 next-value outputs (in that order).
    """

    def __init__(self):
        super().__init__()

        stage_0_width = 64
        num_conv_blocks = 2
        num_transformer_layers = 5

        stem = layers.InitialExtractor(in_channels=input_channels, out_channels=stage_0_width)
        residual_blocks = [
            layers.ConnectFourBlock(in_channels=stage_0_width)
            for _ in range(num_conv_blocks)
        ]
        self.conv_extractor = nn.Sequential(stem, *residual_blocks)

        self.transformer = VisionTransformerWithTasks(
            input_shape=(stage_0_width, input_height, input_width),
            num_layers=num_transformer_layers,
        )

    def forward(self, x):
        features = self.conv_extractor(x)
        task_logits = self.transformer(features)

        # Layout of the task logits: [policy (7) | value (1) | next value (7)].
        policy_outputs, value_outputs, next_value_outputs = torch.split(
            task_logits, [7, 1, 7], dim=-1
        )
        return policy_outputs, value_outputs, next_value_outputs
    
def get_custom_model():
    """Build and return a freshly initialised ``Connect4Model``."""
    model = Connect4Model()
    return model

def count_model_flops(model=None):
    """Print the model architecture and the FLOPs of one single-sample forward pass.

    A random dummy observation of shape
    ``(1, input_channels, input_height, input_width)`` is pushed through the
    model while ``FlopCounterMode`` records the operations.

    The forward pass runs in whatever mode the model is currently in; this
    function calls neither ``eval()`` nor ``no_grad()``.

    Args:
        model: Model to profile. When ``None`` a new model is created with
            ``get_custom_model()``.

    Returns:
        The total FLOP count reported by ``FlopCounterMode.get_total_flops()``.
    """
    from torch.utils.flop_counter import FlopCounterMode
    if model is None:
        model = get_custom_model()
    print("Model Architecture:")
    print(model)

    # Create the dummy input on the model's device so profiling also works
    # after the model has been moved off the CPU (e.g. by training).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(1, input_channels, input_height, input_width, device=device)

    # The `mods` argument is deliberately not passed: recent PyTorch releases
    # track the module hierarchy automatically and warn when it is supplied.
    with FlopCounterMode() as count:
        dummy_output = model(dummy_input)
        total_flops = count.get_total_flops()

    print(f"Total FLOPS: {total_flops}")
    print("Dummy Input Shape:", dummy_input.shape)
    print("Dummy Output:", dummy_output)
    return total_flops

Overwriting c4model.py


In [5]:
%%writefile train.py
from c4model import *
from utils import *
from transformers import get_cosine_schedule_with_warmup

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
from typing import Optional
import os
import itertools

def train_model(
    train_dataset,
    val_dataset,
    model: nn.Module,
    checkpoint_path: str = "checkpoint.pth",
    best_model_path: str = "best_model.pth",
    batch_size: int = 32,
    learning_rate: float = 0.001,
    epochs: int = 10,
    subsample_size: Optional[int] = None,
    warmup_fraction: float = 0.1,
    weight_decay: float = 0.01,
):
    """Train ``model`` with AdamW and a cosine schedule with warmup.

    The model must return ``(policy_outputs, value_outputs, next_value_outputs)``.
    The loss for a batch is the sum of:

    * ``10000 * KL(log_softmax(policy / 4) || log_softmax(labels / 10))``
    * ``MSE(value_outputs, row-wise max of labels)``
    * ``0.1 * MSE(next_value_outputs, labels)``

    Every epoch consumes at most ``subsample_size // batch_size`` shuffled
    training batches; validation is capped at the same number of batches.
    After each epoch the weights are written to ``checkpoint_path``; they are
    additionally written to ``best_model_path`` whenever the "strong"
    validation policy accuracy improves.

    Args:
        train_dataset: Dataset yielding ``(observation, targets)`` pairs.
        val_dataset: Dataset of the same form, used for validation.
        model: Network to optimise; it is moved to the selected device.
        checkpoint_path: File receiving the weights after every epoch.
        best_model_path: File receiving the best-accuracy weights.
        batch_size: Batch size for both loaders.
        learning_rate: Peak learning rate for AdamW.
        epochs: Number of epochs to run.
        subsample_size: Number of training samples to use per epoch. ``None``
            (the default) uses the whole training set; larger values are
            clamped to the dataset size.
        warmup_fraction: Fraction of all optimiser steps spent in warmup.
        weight_decay: AdamW weight decay.

    Raises:
        ValueError: If ``epochs > 0`` but fewer than one full batch per epoch
            is available.
    """
    # Prefer Apple MPS (the original hard-coded target), then fall back to
    # CUDA and CPU so the function also runs on machines without MPS.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    def move_to_device(batch):
        obs, targets = batch
        return obs.to(device, torch.float32), targets.to(device, torch.float32)

    # Pinned host memory only helps CUDA transfers.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)

    # `None` means "use everything"; previously it crashed inside min().
    dataset_size = len(train_dataset)
    if subsample_size is None:
        subsample_size = dataset_size
    subsample_size = max(0, min(subsample_size, dataset_size))
    num_batches_per_epoch = subsample_size // batch_size
    if epochs > 0 and num_batches_per_epoch == 0:
        raise ValueError(
            f"Not enough training samples for a single batch: subsample_size={subsample_size}, batch_size={batch_size}"
        )

    total_steps = num_batches_per_epoch * epochs
    # Emit roughly ten progress lines per epoch.
    log_every = max(1, math.ceil(num_batches_per_epoch / 10))

    mse_loss = nn.MSELoss()
    kl_criterion = nn.KLDivLoss(reduction='batchmean', log_target=True)

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    scheduler = get_cosine_schedule_with_warmup(optimizer, math.ceil(total_steps * warmup_fraction), total_steps)

    def compute_losses(inputs, labels):
        # Shared by the training and validation loops so they cannot drift apart.
        target_policy_log_probs = F.log_softmax(labels / 10, dim=-1)
        policy_outputs, value_outputs, next_value_outputs = model(inputs)
        policy_log_probs = F.log_softmax(policy_outputs / 4, dim=-1)

        kl_loss = 10000 * kl_criterion(policy_log_probs, target_policy_log_probs)
        value_loss = mse_loss(value_outputs, torch.max(labels, dim=-1).values.unsqueeze(1))
        next_value_loss = 0.1 * mse_loss(next_value_outputs, labels)
        return policy_log_probs, kl_loss, value_loss, next_value_loss

    start_epoch = 0
    best_val_acc = 0
    model.to(device)

    # Training loop
    for epoch in range(start_epoch, epochs):
        print(f"Begin epoch {epoch} with LR: {optimizer.param_groups[0]['lr']:.6f}")
        model.train()
        running_loss = 0.0

        num_samples_in_epoch = 0
        loader_subset = itertools.islice(train_loader, num_batches_per_epoch)
        for step, batch in enumerate(loader_subset):
            inputs, labels = move_to_device(batch)

            optimizer.zero_grad()

            _, kl_loss, value_loss, next_value_loss = compute_losses(inputs, labels)
            loss = kl_loss + value_loss + next_value_loss

            loss.backward()
            optimizer.step()
            scheduler.step()
            batch_train_loss = loss.item()
            num_samples_in_epoch += inputs.size(0)
            running_loss += batch_train_loss * inputs.size(0)
            if (step + 1) % log_every == 0:
                running_loss_avg = running_loss / num_samples_in_epoch
                print(f"Epoch [{epoch}/{epochs}], Step [{step} / {num_batches_per_epoch}], Running Loss: {running_loss_avg} Batch Loss: {kl_loss.item():.4f} + {value_loss.item():.4f} + {next_value_loss.item():.4f} = {batch_train_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Validation loop
        model.eval()
        num_samples_in_val = 0
        val_loss = 0.0
        # Accumulated as tensors on the device; converted to floats once per epoch.
        val_correct_count_strong = val_correct_count_weak = 0
        with torch.no_grad():
            for batch in itertools.islice(val_loader, num_batches_per_epoch):
                inputs, labels = move_to_device(batch)

                policy_log_probs, kl_loss, value_loss, next_value_loss = compute_losses(inputs, labels)
                loss = kl_loss + value_loss + next_value_loss

                batch_correct_count_strong, batch_correct_count_weak = argmax_accuracy(policy_log_probs, labels)
                val_correct_count_strong = val_correct_count_strong + batch_correct_count_strong
                val_correct_count_weak = val_correct_count_weak + batch_correct_count_weak

                val_loss += loss.item() * inputs.size(0)
                num_samples_in_val += inputs.size(0)

        epoch_train_loss = running_loss / num_samples_in_epoch
        if num_samples_in_val > 0:
            epoch_val_loss = val_loss / num_samples_in_val
            epoch_val_policy_acc = {
                "strong": float(val_correct_count_strong) / num_samples_in_val,
                "weak": float(val_correct_count_weak) / num_samples_in_val,
            }
        else:
            # Empty validation set: report NaN instead of dividing by zero.
            epoch_val_loss = float("nan")
            epoch_val_policy_acc = {"strong": float("nan"), "weak": float("nan")}

        print(f"Epoch [{epoch}/{epochs}], Training Loss: {epoch_train_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation Policy Acc: {epoch_val_policy_acc}")

        # Save checkpoint and best model based on validation accuracy
        if epoch_val_policy_acc["strong"] > best_val_acc:
            best_val_acc = epoch_val_policy_acc["strong"]
            torch.save(model.state_dict(), best_model_path)
            print(f"Validation accuracy improved. Saved best model to {best_model_path}")
        torch.save(model.state_dict(), checkpoint_path)

    print("Finished Training")
    print(f"Best validation acc achieved: {best_val_acc:.4f}")

Overwriting train.py


In [6]:
import os
import sys

# The project modules are rewritten by the %%writefile cells above. Drop every
# cached copy (including `vit` and `utils`, which `c4model` and `train` import)
# so the imports below pick up the files currently on disk. pop() with a
# default keeps this cell safe on a fresh kernel where nothing is cached yet.
for _stale_module in ("utils", "vit", "layers", "c4model", "train"):
    sys.modules.pop(_stale_module, None)
del _stale_module

import layers
import c4model
import train

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def load_dataset(dataset_path):
    """
    Load the ``x_train`` / ``y_train`` arrays stored at ``dataset_path`` and
    turn them into train and validation ``TensorDataset`` objects.

    Steps: load the file with ``np.load``, split 90/10 with a fixed random
    seed, shift the non-zero targets away from zero (see ``adjust_targets``),
    and convert all four arrays to ``float16`` tensors. Intermediate NumPy
    arrays are deleted and garbage-collected one at a time between
    conversions.

    Args:
        dataset_path: Path passed to ``np.load``; the loaded object must
            support lookup by the keys ``"x_train"`` and ``"y_train"``.

    Returns:
        ``(train_dataset, val_dataset)`` on success, or ``None`` if any
        exception is raised while loading (the error is printed).
    """
    import numpy as np
    from sklearn.model_selection import train_test_split
    from torch.utils.data import TensorDataset
    import gc
    import torch

    def adjust_targets(targets_np):
        """Shift targets in place: +100 for strictly positive entries, -100 for
        strictly negative ones; zeros are left unchanged."""
        # NOTE(review): the arithmetic is done in the array's own dtype — if the
        # targets are a narrow integer type this could overflow; confirm dtype.
        mask_pos = targets_np > 0
        targets_np[mask_pos] += 100
        # Computed after the positive shift; shifted positives stay positive,
        # so they are not caught by this mask.
        mask_neg = targets_np < 0
        targets_np[mask_neg] -= 100

    try:
        print("Loading dataset with mmap_mode='r'...")

        # NOTE(review): NumPy may ignore mmap_mode for .npz archives, in which
        # case the arrays below are read fully into memory — confirm.
        dataset = np.load(dataset_path, mmap_mode='r')
        full_train_obs = dataset["x_train"]
        full_train_targets = dataset["y_train"]

        print("Creating train/validation splits...")
        # Fixed random_state keeps the 90/10 split reproducible across runs.
        train_obs_np, val_obs_np, train_targets_np, val_targets_np = train_test_split(
            full_train_obs, full_train_targets, test_size=0.1, random_state=12505
        )

        # Mutates the split target arrays in place.
        adjust_targets(train_targets_np)
        adjust_targets(val_targets_np)

        # Release the unsplit arrays before allocating the tensors.
        del full_train_obs
        del full_train_targets
        del dataset
        gc.collect() # Force garbage collection

        # Convert one array at a time, dropping each NumPy source right after
        # its tensor exists.
        print("Converting NumPy arrays to PyTorch tensors with float16 precision (1/4)...")
        train_obs = torch.from_numpy(train_obs_np).to(dtype=torch.float16)
        del train_obs_np
        gc.collect()

        print("Converting NumPy arrays to PyTorch tensors with float16 precision (2/4)...")
        val_obs = torch.from_numpy(val_obs_np).to(dtype=torch.float16)
        del val_obs_np
        gc.collect()

        print("Converting NumPy arrays to PyTorch tensors with float16 precision (3/4)...")
        train_targets = torch.from_numpy(train_targets_np).to(dtype=torch.float16)
        del train_targets_np
        gc.collect()

        print("Converting NumPy arrays to PyTorch tensors with float16 precision (4/4)...")
        val_targets = torch.from_numpy(val_targets_np).to(dtype=torch.float16)
        del val_targets_np
        gc.collect()

        # Create TensorDatasets for use with PyTorch DataLoaders
        train_dataset = TensorDataset(train_obs, train_targets)
        val_dataset = TensorDataset(val_obs, val_targets)

        print("Dataset preparation complete.")
        return train_dataset, val_dataset
    except Exception as e:
        # Any failure is reported and signalled to the caller with None.
        print(f"could not load data: {e}")
        return None


In [8]:
import torch

def load_weights_only(filepath):
    """
    Loads only the weights (state_dict) from a .pth file.

    Two layouts are recognised:

    * a full checkpoint, i.e. a dict holding the weights under the key
      ``'state_dict'`` -- the nested dict is returned;
    * a bare state_dict, i.e. a dict whose keys are all strings -- returned as is.

    Anything else is returned unchanged after printing a warning.

    The file is read with ``weights_only=True`` (no arbitrary pickled objects
    are executed) and mapped to the CPU, so a checkpoint saved on another
    device can be opened anywhere.

    Args:
        filepath (str): The path to the .pth file.

    Returns:
        dict: The state dictionary containing the model weights, or None if an error occurs.
    """
    try:
        checkpoint = torch.load(filepath, map_location="cpu", weights_only=True)
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}")
        return None
    except Exception as e:
        print(f"An error occurred while loading the file: {e}")
        return None

    if isinstance(checkpoint, dict):
        # Check for a wrapping checkpoint FIRST: its keys are strings too, so
        # testing "all keys are strings" first would return the whole
        # checkpoint instead of the nested weights.
        nested_state = checkpoint.get('state_dict')
        if isinstance(nested_state, dict):
            print(f"Successfully loaded 'state_dict' from checkpoint in {filepath}")
            return nested_state
        if all(isinstance(k, str) for k in checkpoint.keys()):
            print(f"Successfully loaded state_dict from {filepath}")
            return checkpoint

    print(f"Warning: The .pth file at {filepath} does not seem to contain a standard state_dict or a checkpoint with 'state_dict'.")
    print("Attempting to return the loaded object directly. Please inspect its content.")
    return checkpoint # Return whatever was loaded, might be a raw tensor or other data

In [9]:
# Build the network once; this same `model` object is trained further below.
model = c4model.get_custom_model()
# Print the architecture and the FLOP count of a single-sample forward pass.
c4model.count_model_flops(model)

Model Architecture:
Connect4Model(
  (conv_extractor): Sequential(
    (0): InitialExtractor(
      (convs): Sequential(
        (0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (1): ConnectFourBlock(
      (conv_next): Sequential(
        (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=64, bias=False)
        (1): LayerNorm2d((64,), eps=1e-06, elementwise_affine=True)
        (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): SiLU(inplace=True)
        (4): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (proj_up): Sequential(
        (0): LayerNorm2d((64,), eps=1e-06, elementwise_affine=True)
        (1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (grouped_0): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), s

  with FlopCounterMode(model) as count:


In [10]:
dataset_splits = load_dataset("./c4_data.npz")
if dataset_splits is None:
    # Raise instead of os._exit(1): the cell still stops "Run All", but the
    # kernel and everything already computed in it survive, and the failure is
    # shown as a normal traceback.
    raise RuntimeError("Dataset could not be loaded from ./c4_data.npz; see the message printed above.")

train_dataset, val_dataset = dataset_splits

Loading dataset with mmap_mode='r'...
Creating train/validation splits...
Converting NumPy arrays to PyTorch tensors with float16 precision (1/4)...
Converting NumPy arrays to PyTorch tensors with float16 precision (2/4)...
Converting NumPy arrays to PyTorch tensors with float16 precision (3/4)...
Converting NumPy arrays to PyTorch tensors with float16 precision (4/4)...
Dataset preparation complete.


In [11]:
# Tag used as the prefix of every weight file written or read by this run.
arch = "v1_0"

checkpoint_file = f"{arch}_checkpoint.pth"
best_model_file = f"{arch}_best_model.pth"

# Optional warm start: if this file exists its weights are loaded into `model`.
load_file = f"{arch}_starting_checkpoint.pth"

# load_weights_only returns None when the file is missing or unreadable, in
# which case training starts from the freshly initialised model.
weights = load_weights_only(load_file)
if weights:
    print("Keys in loaded weights:", weights.keys())
    try:
        model.load_state_dict(weights)
        print("Successfully loaded weights into a new model.")
    except Exception as e:
        # A mismatching checkpoint is reported but does not stop the run.
        print(f"Could not load weights into a new model: {e}")

# Hyper-parameters for this run, forwarded unchanged to train.train_model.
args = dict(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    model=model,
    batch_size=2 ** 11,
    learning_rate=1e-3,
    epochs=12,
    subsample_size=2 ** 25,
    warmup_fraction=0.2,
    weight_decay=0.06,
    checkpoint_path=checkpoint_file,
    best_model_path=best_model_file,
)

train.train_model(**args)

Error: File not found at v1_0_starting_checkpoint.pth
Begin epoch 0 with LR: 0.000000




Epoch [0/12], Step [1638 / 16384], Running Loss: 21076.174521669083 Batch Loss: 4800.4590 + 8690.2803 + 5717.6406 = 19208.3789, LR: 0.000042
Epoch [0/12], Step [3277 / 16384], Running Loss: 18928.92789107783 Batch Loss: 3331.3420 + 6303.9111 + 5063.1714 = 14698.4238, LR: 0.000083
Epoch [0/12], Step [4916 / 16384], Running Loss: 16838.458346044335 Batch Loss: 2760.4419 + 3197.9858 + 4498.6162 = 10457.0439, LR: 0.000125
Epoch [0/12], Step [6555 / 16384], Running Loss: 15010.384118332086 Batch Loss: 2618.9136 + 1916.7490 + 4140.8403 = 8676.5029, LR: 0.000167
Epoch [0/12], Step [8194 / 16384], Running Loss: 13606.58859340442 Batch Loss: 2541.7820 + 1915.8936 + 3340.4583 = 7798.1338, LR: 0.000208
Epoch [0/12], Step [9833 / 16384], Running Loss: 12872.303769477625 Batch Loss: 2832.5063 + 2080.2229 + 2361.7163 = 7274.4458, LR: 0.000250
Epoch [0/12], Step [11472 / 16384], Running Loss: 11930.67132232674 Batch Loss: 2307.3345 + 1817.3192 + 1131.4760 = 5256.1299, LR: 0.000292
Epoch [0/12], Step 

In [12]:
import os
#os._exit(00)