In [1]:
# imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import defaultdict
from collections.abc import Callable
import tqdm
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss

from utilities import (
    Embedding,
    LayerNorm,
    Dropout,
    Linear,
    DictReLU,
    pbt_init,
    pbt_update,
    Optimizer,
    get_dataloader_random_reshuffle,
    evaluate_model,
    update_model,
    get_array_minibatch,
    get_accuracy,
    AdamW
)

In [2]:
config = {
    "device": "cpu",
    "ensemble_shape": (1,),
    "float_dtype": torch.float32,
    "hyperparameter_raw_init_distributions": {
        "dropout_p": torch.distributions.Uniform(
            torch.tensor(0, device="cpu", dtype=torch.float32),
            torch.tensor(.01, device="cpu", dtype=torch.float32)
        ),
        "epsilon": torch.distributions.Uniform(
            torch.tensor(-10, device="cpu", dtype=torch.float32),
            torch.tensor(-5, device="cpu", dtype=torch.float32)
        ),
        "first_moment_decay": torch.distributions.Uniform(
            torch.tensor(-3, device="cpu", dtype=torch.float32),
            torch.tensor(0, device="cpu", dtype=torch.float32)
        ),
        "learning_rate": torch.distributions.Uniform(
            torch.tensor(-5, device="cpu", dtype=torch.float32),
            torch.tensor(-1, device="cpu", dtype=torch.float32)
        ),
        "second_moment_decay": torch.distributions.Uniform(
            torch.tensor(-5, device="cpu", dtype=torch.float32),
            torch.tensor(-1, device="cpu", dtype=torch.float32)
        ),
        "weight_decay": torch.distributions.Uniform(
            torch.tensor(-5, device="cpu", dtype=torch.float32),
            torch.tensor(-1, device="cpu", dtype=torch.float32)
        )
    },
    "hyperparameter_raw_perturb": {
        "dropout_p": torch.distributions.Normal(
            torch.tensor(0, device="cpu", dtype=torch.float32),
            torch.tensor(.01, device="cpu", dtype=torch.float32)
        ),
        "epsilon": torch.distributions.Normal(
            torch.tensor(0, device="cpu", dtype=torch.float32),
            torch.tensor(1, device="cpu", dtype=torch.float32)
        ),
        "first_moment_decay": torch.distributions.Normal(
            torch.tensor(0, device="cpu", dtype=torch.float32),
            torch.tensor(1, device="cpu", dtype=torch.float32)
        ),
        "learning_rate": torch.distributions.Normal(
            torch.tensor(0, device="cpu", dtype=torch.float32),
            torch.tensor(1, device="cpu", dtype=torch.float32)
        ),
        "second_moment_decay": torch.distributions.Normal(
            torch.tensor(0, device="cpu", dtype=torch.float32),
            torch.tensor(1, device="cpu", dtype=torch.float32)
        ),
        "weight_decay": torch.distributions.Normal(
            torch.tensor(0, device="cpu", dtype=torch.float32),
            torch.tensor(1, device="cpu", dtype=torch.float32)
        ),
    },
    "hyperparameter_transforms": {
        "dropout_p": lambda p: p.clip(0,1),
        "epsilon": lambda log10: 10 ** log10,
        "first_moment_decay": lambda x: (1 - 10 ** x).clamp(0, 1),
        "learning_rate": lambda log10: 10 ** log10,
        "second_moment_decay": lambda x: (1 - 10 ** x).clamp(0, 1),
        "weight_decay": lambda log10: 10 ** log10,
    },
    "improvement_threshold": 1e-4,
    "minibatch_size": 100,
    "pbt": True,
    "seed": 1,
    "sequence_size": 64,
    "steps_num": 10000,
    "steps_without_improvement": 1_000,
    "valid_interval": 100,
    "welch_confidence_level": .95,
    "welch_sample_size": 1024,
    "embedding_dim": 36,
    "dropout_p": torch.tensor([1], device='mps'),
    "n_heads": 3,
    "block_num": 2,
    "minibatch_size_eval": 100
}

In [3]:
torch.manual_seed(config["seed"])

<torch._C.Generator at 0x130352ad0>

In [4]:
def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d, device=config["device"])
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

pos_embed = get_positional_embeddings(64, config["embedding_dim"])[None, :]

In [5]:
class EmbeddingBlock(nn.Module):
    def __init__(
        self,
        config,
        embedding_dim
    ):
        super().__init__()
        
        self.piece_embed = Embedding(config, embedding_dim, vocabulary_size=7)
        self.player_embed = Embedding(config, embedding_dim, vocabulary_size=2)
        self.dropout = Dropout(config)
        self.pos_embed = pos_embed
        
    def forward(self, input):
        input = input[0]
        embedded_pieces = self.dropout({'features': self.piece_embed(input[...,0])})
        embedded_players = self.dropout({'features': self.player_embed(input[...,1])})
        embedded_pos = self.pos_embed
        
        output = embedded_pieces['features'] + embedded_players['features'] + embedded_pos
        
        return {'features': output.to(config["device"]), 'mask': torch.ones((config['minibatch_size'], config['sequence_size']), device=config['device']).to(torch.bool)}
        

In [6]:
class MultiHeadSelfAttentionBlock(nn.Module):
    """
    Pre-LN multi-head self-attention block.

    Parameters
    ----------
    attention_head_num : `int`
        The number of attention heads.
    config : `int`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to store parameters on.
        `"dropout_p"` : `torch.Tensor`
            Dropout probability tensor, of shape `ensemble_shape`.
        `"ensemble_shape"` : `tuple[int]`
            Ensemble shape.      
        `"float_dtype"` : `torch.dtype`
            The floating point datatype to use for the parameters.
    embedding_dim : `int`
        The feature dimension of internal representations.

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. Required keys:
        `"features"` : `torch.Tensor`
            Tensor of element-level features, of shape
            `batch_shape + (sequence_dim, embedding_dim)` or
            `ensemble_shape + batch_shape + (sequence_dim, embedding_dim)`
        `"mask"` : `torch.Tensor`
            Mask showing which entries are not padding, of shape
            `batch_shape + (sequence_dim,)` or
            `ensemble_shape + batch_shape + (sequence_dim,)`
    """
    def __init__(
        self,
        attention_head_num: int,
        config: dict,
        embedding_dim: int
    ):
        super().__init__()

        self.attention_head_num = attention_head_num
        self.config = config
        self.dropout = Dropout(config)
        self.layer_norm = LayerNorm(
            config,
            embedding_dim
        )

        (
            self.key_weights,
            self.output_weights,
            self.query_weights,
            self.value_weights
        ) = (
            Linear(
                config,
                embedding_dim,
                embedding_dim,
                bias=False
            )
            for _ in range(4)
        )


    def forward(
        self,
        batch: dict
    ) -> dict:
        skip = batch["features"]
        batch = self.layer_norm(batch)
        residual, mask = (batch[key] for key in ("features", "mask"))

        sequence_dim, embedding_dim = residual.shape[-2:]
        key_dim = embedding_dim // self.attention_head_num

        key, query, value = (
            (
                linear(residual)
            ).reshape(
                residual.shape[:-1] + (self.attention_head_num, key_dim)
            ).transpose(-3, -2)
            for linear in (
                self.key_weights,
                self.query_weights,
                self.value_weights
            )
        )
        
        
        arange = torch.arange(sequence_dim, device=mask.device)
        attention_mask = mask[..., None, :] & mask[..., None]
        attention_mask |= (arange == arange[:, None])

        pooled_values = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attention_mask[..., None, :, :]
        )

        residual = pooled_values.transpose(-3, -2).reshape(residual.shape)
        residual = self.output_weights(residual)
        residual = self.dropout({"features": residual})["features"]

        features = skip + residual

        return batch | {"features": features}

In [7]:
class FeedForwardBlock(torch.nn.Module):
    """
    Pre-LN feedforward block.

    Parameters
    ----------
    config : `int`
        Configuration dictionary. Required key-value pairs:
        `"device"` : `str`
            The device to store parameters on.
        `"dropout_p"` : `torch.Tensor`
            Dropout probability tensor, of shape `ensemble_shape`.
        `"ensemble_shape"` : `tuple[int]`
            Ensemble shape.      
        `"float_dtype"` : `torch.dtype`
            The floating point datatype to use for the parameters.
    embedding_dim : `int`
        The feature dimension of internal representations.

    Calling
    -------
    Instance calls require one positional argument:
    batch : `dict`
        The input data dictionary. Required key:
        `"features"` : `torch.Tensor`
            Tensor of element-level features, of shape
            `batch_shape + (sequence_dim, embedding_dim)` or
            `ensemble_shape + batch_shape + (sequence_dim, embedding_dim)`
    """
    def __init__(
        self,
        config: dict,
        embedding_dim: int
    ):
        super().__init__()

        self.residual_f = torch.nn.Sequential(
            LayerNorm(
                config,
                embedding_dim
            ),
            Linear(
                config,
                embedding_dim,
                embedding_dim,
                init_multiplier=2 ** .5
            ),
            DictReLU(),
            Linear(
                config,
                embedding_dim,
                embedding_dim
            ),
            Dropout(config)
        )


    def forward(self, batch: dict) -> dict:
        skip = batch

        residual = self.residual_f(batch)

        features = skip["features"] + residual["features"]

        return batch | {"features": features}

In [8]:
def train_supervised(
    config: dict,
    dataset_train: dict,
    dataset_valid: dict,
    get_loss: Callable[[dict, torch.Tensor], torch.Tensor],
    get_metric: Callable[[dict, torch.Tensor], torch.Tensor],
    model: torch.nn.Module,
    optimizer: Optimizer,
    target_key="target"
) -> dict:
    """
    Population-based training on a supervised learning task.
    Tuned hyperparameters are given by raw values and transformations.
    This way, the hyperparameters are perturbed by
    additive noise on raw values.

    Parameters
    ----------
    config : `dict`
        Configuration dictionary. Required key-value pairs:
        `"ensemble_shape"` : tuple[int]
            Ensemble shape. We assume this is a 1-dimensional tuple
            with dimensions the population size.
        `"hyperparameter_raw_init_distributions"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to `torch.distributions.Distribution` of raw hyperparameter values.
            Required keys:
            `"learning_rate"`:
                The learning rate of stochastic gradient descent.
        `"hyperparameter_raw_perturbs"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to `torch.distributions.Distribution` of additive noise.
        `"hyperparameter_transforms"` : `dict`
            Dictionary that maps tuned hyperparameter names
            to transformations of raw hyperparameter values.
        `"improvement_threshold"` : `float`
            A new metric score has to be this much better
            than the previous best to count as an improvement.
        `"minibatch_size"` : `int`
            Minibatch size to use in a training step.
        `"minibatch_size_eval"` : `int`
            Minibatch size to use in evaluation.
            On CPU, should be about the same as `minibatch_size`.
            On GPU, should be as big as possible without
            incurring an Out of Memory error.
        `"pbt"` : `bool`
            Whether to use PBT updates in validations.
            If `False`, the algorithm just samples hyperparameters at start,
            then keeps them constant.
        `"steps_num"` : `int`
            Maximum number of training steps.
        `"steps_without_improvement`" : `int`
            If the number of training steps without improvement
            exceeds this value, then training is stopped.
        `"valid_interval"` : `int`
            Frequency of evaluations, measured in number of training steps.
        `"welch_confidence_level"` : `float`
            The confidence level in Welch's t-test
            that is used in determining if a population member
            is to be replaced by another member with perturbed hyperparameters.
        `"welch_sample_size"` : `int`
            The last this many validation metrics are used
            in Welch's t-test.
    dataset_train : `dict`
        The dataset to train the model on.
    dataset_valid : `dict`
        The dataset to evaluate the model on.
    `get_loss` : `Callable[[dict, torch.Tensor], torch.Tensor]`
        A function that maps a pair of model output and target value tensor
        to a tensor of losses per ensemble member.
    `get_metric` : `Callable[[dict, torch.Tensor], torch.Tensor]`
        A function that maps a pair of model output and target value tensor
        to a tensor of metrics per ensemble member.
        We assume a greater metric is better.
    `model` : `torch.nn.Module`
        The model ensemble to tune.
    `optimizer` : `Optimizer`
        An optimizer that tracks the parameters of `model`.
    indptr_key : `str`, optional
        If the dataset has sequential entries,
        then this is the key of the index pointer tensor.
        Default: `"indptr"`.
    target_key : `str`, optional
        The key mapped to the target value tensor in the dataset.
        Default: `"target"`
        
    Returns
    -------
    An output dictionary with the following key-value pairs:
        `"source mask"` : `torch.Tensor`
            The source masks of population members
            that were replace by other members in a PBT update
        `"target indices"` : `torch.Tensor`
            The indices of population members
            that the member where the source mask is to were replaced with.
        `"validation metric"` : `torch.Tensor`
            The validation metrics at evaluation steps.

        In addition, for each tuned hyperparameter name,
        we include a `torch.Tensor` of values per update.
    """
    ensemble_shape = config["ensemble_shape"]
    if len(ensemble_shape) != 1:
        raise ValueError(f"The number of dimensions in the ensemble shape should be 1 for the  population size, but it is {len(ensemble_shape)}")

    config_local = dict(config)
    log = defaultdict(list)

    pbt_init(config_local, log)

    update_model(config_local, model)
    optimizer.update_config(config_local)

    best_valid_metric = -torch.inf
    progress_bar = tqdm.trange(config["steps_num"])
    steps_without_improvement = 0
    train_dataloader = get_dataloader_random_reshuffle(
        config,
        dataset_train["features"],
        dataset_train["labels"]
    )

    for step_id in progress_bar:
        model.train()
        minibatch = next(train_dataloader)
        optimizer.zero_grad()
        
        predict = model(minibatch[0])[0][...,0]
        target = minibatch[1][0]

        loss = get_loss(predict, target).sum()
        
        loss.backward()
        optimizer.step()
        
        if step_id % config["valid_interval"] == 0:
            model.eval()
            with torch.no_grad():
                split_name = "validation"
                minibatch = next(train_dataloader)
                optimizer.zero_grad()

                predict = model(minibatch[0])[0][...,0]
                target = minibatch[1][0]
                
                print(predict.argmax(dim=-1))
                print(target)
                
                metric = (predict.argmax(dim=-1) == target).to(config["float_dtype"]).mean()
                
                # log[f"{split_name} loss"].append(loss)
                log[f"{split_name} metric"].append(metric)
                # print(
                #     f"{split_name} loss {loss.min().cpu().item():.4f}"
                # )
                print(
                    f"{split_name} metric {metric.max().cpu().item():.4f}"
                )

                best_last_metric = log["validation metric"][-1].max()
                print(
                    f"Best last metric {best_last_metric.cpu().item():.2f}",
                    flush=True
                )
                if (
                    best_valid_metric + config["improvement_threshold"]
                ) < best_last_metric:
                    print(
                        f"New best metric",
                        flush=True
                    )
                    best_valid_metric = best_last_metric
                    steps_without_improvement = 0
                else:
                    print(
                        f"Best metric {best_valid_metric.cpu().item():.2f}",
                        flush=True
                    )
                    steps_without_improvement += config["valid_interval"]
                    if steps_without_improvement > config[
                        "steps_without_improvement"
                    ]:
                        break

                if config["pbt"] and (len(log["validation metric"]) >= config[
                    "welch_sample_size"
                ]):
                    evaluations = torch.stack(
                        log["validation metric"][-config["welch_sample_size"]:]
                    )
                    pbt_update(
                        config_local, evaluations, log, optimizer.get_parameters()
                    )

                    update_model(config_local, model)
                    optimizer.update_config(config_local)


    progress_bar.close()
    for key, value in log.items():
        if isinstance(value, list):
            log[key] = torch.stack(value)

    return log

In [9]:
class Transformer(nn.Module):
    def __init__(self, config):
        super(Transformer, self).__init__()
        
        embedding_dim = config["embedding_dim"]
        n_heads = config["n_heads"]
        block_num = config["block_num"]
        
        # create embedding block
        self.EmbeddingBlock = EmbeddingBlock(config, embedding_dim)
        
        # make repeated transformer block
        blocks = []
        for _ in range(block_num):
            blocks.extend([
                MultiHeadSelfAttentionBlock(n_heads, config, embedding_dim),
                FeedForwardBlock(config, embedding_dim)
            ])
        self.MHA_FF_Block = nn.Sequential(*blocks)
        
        # use Linear to create logits
        self.Linear = Linear(config, embedding_dim, 1)
        
        # use softmax
        self.softmax = nn.Softmax()
        

    def forward(self, input):
        embedded_inputs = self.EmbeddingBlock(input)
        
        out = self.MHA_FF_Block(embedded_inputs)
        
        logits = self.Linear(out)["features"]
        
        logits = self.softmax(logits)
        
        return logits

In [10]:
dataset = torch.load('data/dataset_v2.pt', map_location=config["device"])
dataset_train, dataset_valid = {}, {}
dataset_train["features"], dataset_valid["features"], dataset_train["labels"], dataset_valid["labels"] = train_test_split(
    dataset["features"], dataset["labels"][:,0], test_size=0.2, random_state=config['seed'])


In [11]:
dataset_train["features"][0]

tensor([[4, 1],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [4, 1],
        [6, 1],
        [0, 0],
        [1, 1],
        [1, 1],
        [1, 1],
        [5, 1],
        [0, 0],
        [1, 1],
        [1, 1],
        [1, 1],
        [0, 0],
        [0, 0],
        [2, 1],
        [0, 0],
        [0, 0],
        [2, 1],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 1],
        [1, 1],
        [1, 1],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 1],
        [0, 0],
        [0, 0],
        [1, 0],
        [0, 0],
        [1, 0],
        [1, 0],
        [0, 0],
        [1, 0],
        [0, 0],
        [1, 0],
        [3, 0],
        [1, 0],
        [5, 0],
        [2, 0],
        [1, 0],
        [3, 0],
        [1, 0],
        [4, 0],
        [2, 0],
        [0, 0],
        [0, 0],
        [6, 0],
        [0, 0],
        

In [12]:
model = Transformer(config).to(config["device"])
optimizer = AdamW(model.parameters())

log = train_supervised(
    config,
    dataset_train,
    dataset_valid,
    CrossEntropyLoss(),
    get_accuracy,
    model,
    optimizer,
    target_key="labels"
)

  0%|          | 0/10000 [00:00<?, ?it/s]

tensor([59, 38, 59, 13,  4, 11, 59, 44, 16, 18, 36,  1, 25, 52, 52, 59, 16, 52,
        55,  6, 28, 15, 36, 16, 36, 56, 30, 60, 19, 45, 11, 49, 59, 43, 36, 52,
        26, 46, 58, 51, 13, 47, 11,  0, 36, 13, 51, 50, 49, 59, 47,  5, 33, 12,
        28, 59, 20, 62, 61, 59, 31, 61, 61, 46, 41,  5,  6, 34, 27,  2, 59, 55,
        51, 54, 39, 51,  7, 51, 46, 55, 58, 59,  6, 18, 59,  4, 57, 49, 11, 18,
        24, 36,  4,  6,  6, 36, 31, 12, 29, 63])
tensor([57, 62, 54, 10,  4, 50, 59, 38, 34, 51, 13, 18, 26,  5, 11, 52, 52, 27,
        21, 51, 30, 38, 46, 50,  3, 12, 31, 29, 52, 26, 60, 14, 61, 55,  3,  2,
        32, 14, 26,  3, 60, 21, 54, 36,  8, 26,  3, 11,  3, 30, 52, 38, 27, 49,
         3, 59, 16, 19, 17, 62, 17, 40, 24, 26, 35, 38,  2, 36, 45, 42, 34,  2,
         5, 23, 60, 29, 38,  3, 58, 27, 51, 50, 14, 36, 52, 55, 38, 12, 52, 41,
        32, 28, 52, 10, 28,  5, 21, 47, 43, 26])
validation metric 0.0300
Best last metric 0.03
New best metric


  1%|          | 99/10000 [00:05<08:31, 19.34it/s]

tensor([59, 26, 41, 19, 12, 12, 11, 51, 56, 59, 59,  3,  0, 51,  7,  3, 59, 27,
        51, 52, 18, 57,  6, 44,  1,  2,  8, 45, 25, 15, 48,  3, 47, 28, 28, 56,
        59, 59, 28, 22, 10, 42,  3, 28, 59, 33, 23, 52, 33,  0,  3,  3,  3, 30,
         0, 46, 59,  3, 53, 60, 61, 18, 59,  4, 59, 11, 43, 39,  3, 13,  6, 45,
        52, 21, 51, 32,  2,  3, 16,  3, 36, 49,  0, 38,  3, 59, 51,  3,  3, 14,
        28, 12, 29,  3, 34,  3, 59, 54, 59,  2])
tensor([30, 26, 29, 12,  5, 12,  8, 56, 56, 51, 46, 36,  0, 50,  9,  1, 59, 29,
        61, 58, 18, 41,  4, 44,  1, 21, 44, 41, 48,  0, 44,  9,  5, 56, 28, 56,
        25, 52, 56, 22, 58, 26, 21, 14, 58,  9, 23, 60, 18, 10, 11, 27,  6, 23,
        15, 55, 52, 38,  5, 51, 29, 14, 55,  4, 39, 11, 43, 35, 26, 38,  4, 56,
        29, 36, 53, 49, 38, 11, 23,  3, 52, 63, 27, 56, 13, 55, 42,  2,  3, 20,
        56, 19, 29,  1, 55, 12, 57, 63, 35,  4])
validation metric 0.1800
Best last metric 0.18
New best metric


  2%|▏         | 198/10000 [00:09<07:51, 20.81it/s]

tensor([ 0, 56, 59, 32, 56, 24, 59,  0,  0,  0, 63, 23, 60, 14,  0,  0,  8, 18,
        33, 62, 40, 22, 11, 11, 56, 61, 59, 59, 62,  0,  0, 57, 26, 54, 59,  0,
         0, 53,  8, 59, 10, 60, 38, 59,  0, 62, 59, 62,  4,  0, 49, 62,  3,  0,
        12, 31,  0, 59, 59,  1, 21, 43, 15, 56, 59, 59, 44,  7, 61, 22,  0, 42,
        59, 59,  0,  0,  6, 59,  0,  0, 43, 32, 59, 19, 59,  0, 59, 58, 60, 43,
         9, 33, 60, 59, 29,  0, 59,  0,  0,  5])
tensor([26, 28, 62, 32, 41,  6, 56,  3,  1, 21, 47, 27, 57, 14, 21, 11,  8, 18,
        11, 57, 54,  4, 11,  2, 34, 52, 49, 41, 48, 11,  2, 20, 26, 39, 46,  5,
         5, 53, 28, 61, 10, 51, 16, 61,  4, 41, 52, 34,  4,  4, 39, 39, 13, 15,
        12, 27,  0, 44,  9,  9, 21, 43, 18, 45, 58, 62,  5, 19, 47, 27, 20, 42,
        58, 45, 29,  2,  4, 59,  1, 22, 41, 26, 60, 31, 60,  8, 62,  6, 58, 55,
        25, 19, 45, 51, 48,  6, 48, 14,  1,  2])
validation metric 0.1500
Best last metric 0.15
Best metric 0.18


  3%|▎         | 300/10000 [00:14<07:46, 20.79it/s]

tensor([54, 56, 28,  7, 18, 16, 25,  7, 44, 56, 50,  3, 56, 56, 19,  1, 62, 12,
        29, 11,  7, 17, 62, 26, 56,  7, 63, 14, 13, 58, 29, 54, 22, 55,  7,  8,
        56, 25, 20, 56, 22, 41, 22, 29,  7, 25, 27, 49, 51, 58, 59, 14, 38, 56,
         5, 14, 56,  9, 10, 53, 39, 63,  5, 24, 51, 15, 47, 62,  5, 52, 16, 32,
         0, 42, 59, 58, 40, 16,  7, 33, 56,  6,  7, 40,  5, 14, 52, 56, 14, 56,
         7, 28, 56, 30,  7, 56, 23, 52, 29, 16])
tensor([54, 61, 61, 13, 26,  4, 25, 14, 27, 45, 46, 34, 58, 60, 55,  2, 43, 36,
        15, 27, 12,  4, 62, 23, 41, 10, 43,  6,  4, 45,  3,  6, 27, 37, 36,  8,
        57, 18, 57, 60,  5, 13, 20, 16, 10, 25, 37,  8, 54, 61, 52, 26,  1, 52,
        26,  3, 61,  0, 31, 48, 43, 62,  9, 24, 40, 35, 37, 60, 14, 46, 43, 32,
         2, 44, 60, 42, 40, 11, 12, 59, 51, 18, 15, 61,  5,  4, 48, 61,  2, 51,
        22, 11, 44, 30, 12, 60, 31, 39, 12, 18])
validation metric 0.1000
Best last metric 0.10
Best metric 0.18


  4%|▍         | 399/10000 [00:19<07:41, 20.82it/s]

tensor([ 0, 55, 48, 25, 39, 57,  3,  0, 43, 23, 57, 20, 58,  5,  0,  6, 21, 63,
        17, 57, 35, 10, 56,  9, 32, 18,  6,  1,  7, 62, 58,  5,  5, 55, 60, 57,
        63,  0, 35, 28, 46, 55, 16, 26,  0, 41, 13,  6, 11, 40,  2, 57, 49, 24,
        55, 57, 18, 55, 57, 55, 20, 50,  6, 63, 43, 63,  0,  0,  0,  0,  8, 63,
         4,  5,  0, 56, 63,  0,  0, 27,  0,  0, 38, 55, 57, 55, 57, 18, 57,  6,
        31,  0, 61, 57,  5,  0, 52, 55, 42, 28])
tensor([ 1, 44, 51, 25, 52, 62,  3,  6, 54, 29, 60, 20, 37, 24,  5,  5, 21, 53,
        25, 50, 19, 44, 34, 62, 32, 18, 50, 19,  4, 45, 59, 19, 26, 38, 43, 62,
        51, 29, 53, 49, 58, 52, 11, 35,  1, 46, 33, 21, 22, 10, 37, 55, 49, 25,
        51, 55,  0, 62, 49, 51, 36, 49, 10, 62, 27, 48,  8,  6, 12, 11, 12, 52,
        19, 22, 14, 56, 46, 18, 21, 27,  9, 52, 47, 51, 61, 42, 49,  4, 25, 35,
        31,  5, 62, 50, 30,  6, 14, 53, 17, 55])
validation metric 0.1000
Best last metric 0.10
Best metric 0.18


  5%|▍         | 499/10000 [00:24<08:06, 19.53it/s]

tensor([48, 23,  3, 44, 47, 49,  3,  1, 44, 62, 56, 13, 11,  3, 19,  3, 33,  3,
         1,  3,  1,  2,  1,  1, 63, 18, 63,  4, 56, 17, 63, 57, 24, 38, 16, 63,
        36, 20,  0,  1, 62, 40, 60, 63, 58, 16, 16, 15,  3, 63, 16,  8, 32, 42,
        44, 18,  7, 30,  3, 63, 40,  3, 63, 59, 28,  3, 37, 16, 63, 63, 31, 41,
        30, 52,  1,  3, 53,  3, 46,  3, 36, 35, 63, 63,  3, 42, 63, 41, 62, 57,
         3,  6, 16, 44, 52, 39, 34, 14,  5, 41])
tensor([12, 19,  5, 48, 47, 39,  2,  2, 34, 51, 17, 27, 11, 26, 53,  3, 33,  6,
        10,  2, 10, 22, 12, 52, 52, 27, 55, 24, 48, 17, 59, 60, 47, 21, 17, 60,
        56, 37,  9,  1, 62, 37, 51, 52, 60, 13, 33, 36,  9, 28,  1,  8, 52, 15,
         2, 21, 28, 59,  5, 50, 51, 14, 58, 61, 19,  8, 49, 11, 44, 30,  0, 41,
        18, 13, 26,  6, 39, 29, 58, 26, 50, 45, 62, 51, 11, 43, 28, 52, 59, 50,
        12, 21, 61, 41, 45,  0, 39, 14, 13, 50])
validation metric 0.1000
Best last metric 0.10
Best metric 0.18


  6%|▌         | 599/10000 [00:29<07:30, 20.88it/s]

tensor([ 4, 50, 23, 59, 41, 21,  0, 28, 57, 17,  0, 58, 10, 59, 51, 36,  3,  0,
        63,  8, 45, 38, 39, 51, 57, 18,  0, 11,  4, 34, 29, 60, 44, 49,  0, 18,
        26, 60, 53, 59, 40, 48, 57, 11,  7, 25,  3, 31, 12, 59, 57,  0,  6, 11,
        46, 43, 52, 41, 59, 57, 42,  0,  9, 57, 47, 23, 56, 54, 59, 32, 14, 52,
        31,  0,  0, 52, 19,  0, 10,  0,  2, 24, 57, 22, 57, 27, 57,  0, 17, 30,
        59, 29, 42,  1, 24, 62, 44, 52, 35, 52])
tensor([ 4, 60, 49,  5, 45, 21, 10, 28, 53,  6,  0, 26,  5, 58, 45,  5, 16,  5,
         3, 45, 45, 13, 54, 51, 57, 41, 21, 11, 36, 40, 55, 60, 55, 49, 13,  5,
        27, 42, 36, 60, 54, 60, 51, 11,  7, 25, 23, 13, 12, 45, 62,  5, 13, 33,
        46, 43, 13, 10, 61, 45, 45, 20, 56, 52, 31, 54, 15, 55, 58, 32, 14, 55,
        31,  4, 26, 44, 19, 35,  0,  9,  3, 38, 57, 22, 61,  0, 52, 11, 17, 13,
        36,  2, 52, 55, 24, 14, 54, 45, 35, 63])
validation metric 0.2500
Best last metric 0.25
New best metric


  7%|▋         | 698/10000 [00:34<07:27, 20.76it/s]

tensor([38, 34, 58, 35, 56, 59, 56, 22, 21, 59,  0, 56, 56, 46, 16, 45,  4, 56,
        21, 56, 45,  2, 31, 18, 52, 56, 21, 52,  6, 35, 21, 21, 54, 11, 56, 49,
        61,  5, 20, 29,  6,  6, 61,  6, 46,  1, 17,  0, 21, 33, 41, 54, 35,  9,
        23, 50,  6, 52, 30, 21, 35, 56, 41, 56, 24, 43, 25, 56, 56, 21, 19, 56,
        10, 21,  7, 19, 60,  6, 35,  3,  0, 12, 30, 57,  4, 56, 21, 41, 58, 44,
        11, 27, 41,  8, 21, 21, 40, 21,  6,  0])
tensor([59, 34, 51, 35, 61, 45, 44,  9, 21, 50, 20, 35, 58, 51, 31, 44, 24, 35,
         4, 62, 56, 24, 15, 24, 61, 63, 12, 38, 10, 35, 38,  1, 54, 35, 36, 56,
        37,  5, 20, 10,  3, 10, 29, 10, 55, 26, 42,  5,  5, 62, 61, 20, 30, 44,
        33, 57,  6, 37, 30,  5, 35, 44, 40, 51, 24, 42, 25, 58, 51, 14, 62, 60,
        10, 12, 37,  2,  0,  6, 60, 28,  2, 12, 37, 62, 36, 59, 18, 46, 26, 44,
        22, 13, 52, 43,  4, 35, 40,  2, 25,  9])
validation metric 0.1700
Best last metric 0.17
Best metric 0.25


  7%|▋         | 736/10000 [00:36<07:38, 20.21it/s]


KeyboardInterrupt: 

In [None]:
dataset_train['features'].shape

torch.Size([667274, 64, 2])