In [1]:
# models
from typing import List
import numpy as np
from torch import nn
from torch.nn import functional as F
import torch
from transformers import ViTConfig, ViTModel

# schedulers
from enum import auto, Enum
import math

# normalizer
import torch

# dataset
from typing import NamedTuple, Optional
import torch
import numpy as np

from tqdm.auto import tqdm

In [2]:
class WallSample(NamedTuple):
    """One item produced by WallDataset (or, after DataLoader collation, one batch)."""

    # Observation frames, converted to float tensors on the dataset's device.
    states: torch.Tensor
    # Agent locations; an empty tensor when the dataset was built with probing=False.
    locations: torch.Tensor
    # Actions, converted to float tensors on the dataset's device.
    actions: torch.Tensor


class WallDataset:
    """Wall-environment trajectories read from ``.npy`` files under ``data_path``.

    ``states.npy`` is memory-mapped (read lazily); ``actions.npy`` and, when
    ``probing`` is set, ``locations.npy`` are loaded eagerly. Items are returned as
    float tensors already moved to ``device``.
    """

    def __init__(
        self,
        data_path,
        probing=False,
        device="cuda",
    ):
        self.device = device
        self.states = np.load(f"{data_path}/states.npy", mmap_mode="r")
        self.actions = np.load(f"{data_path}/actions.npy")
        self.locations = np.load(f"{data_path}/locations.npy") if probing else None

    def __len__(self):
        return len(self.states)

    def _to_tensor(self, array):
        # copy() materialises the (possibly memory-mapped) slice before handing it to torch
        return torch.from_numpy(array.copy()).float().to(self.device)

    def __getitem__(self, i):
        states = self._to_tensor(self.states[i])
        actions = self._to_tensor(self.actions[i])
        if self.locations is None:
            locations = torch.empty(0).to(self.device)
        else:
            locations = self._to_tensor(self.locations[i])
        return WallSample(states=states, locations=locations, actions=actions)


def create_wall_dataloader(
    data_path,
    probing=False,
    device="cuda",
    batch_size=64,
    train=True,
):
    """Wrap a WallDataset in a DataLoader.

    Training loaders (``train=True``) shuffle; every loader drops the last incomplete
    batch. Memory pinning stays off; samples are moved to ``device`` inside WallDataset.
    """
    return torch.utils.data.DataLoader(
        WallDataset(data_path=data_path, probing=probing, device=device),
        batch_size,
        shuffle=train,
        drop_last=True,
        pin_memory=False,
    )

In [16]:
def build_mlp(layers_dims: List[int]):
    """Build an MLP: (Linear -> BatchNorm1d -> ReLU) per hidden layer, then a final Linear.

    Args:
        layers_dims: widths ``[in, hidden_1, ..., out]``; at least two entries are required.
    """
    modules = []
    for d_in, d_out in zip(layers_dims[:-2], layers_dims[1:-1]):
        modules += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.ReLU(True)]
    modules.append(nn.Linear(layers_dims[-2], layers_dims[-1]))
    return nn.Sequential(*modules)


class MockModel(torch.nn.Module):
    """
    Does nothing. Just for testing.

    ``forward`` ignores its inputs and returns random embeddings of shape
    ``[bs, n_steps, output_dim]``; ``output_dim`` is also exposed as ``repr_dim``.
    """

    def __init__(self, device="cuda", bs=64, n_steps=17, output_dim=256):
        super().__init__()
        self.device = device
        self.bs = bs
        self.n_steps = n_steps
        # Honour the requested embedding width (it used to be hard-coded to 256).
        self.repr_dim = output_dim

    def forward(self, states, actions):
        """
        Args:
            During training:
                states: [B, T, Ch, H, W]
            During inference:
                states: [B, 1, Ch, H, W]
            actions: [B, T-1, 2]

        Output:
            predictions: [bs, n_steps, repr_dim] random values (inputs are ignored)
        """
        return torch.randn((self.bs, self.n_steps, self.repr_dim)).to(self.device)


class Encoder_ViT(nn.Module):
    """Encode a single observation into a flat representation with a small, randomly initialised ViT.

    The ViT hidden size is ``reprst_H * reprst_W``, so the pooled output can be read as a
    flat vector or reshaped into a ``reprst_H x reprst_W`` grid.
    """

    def __init__(self, reprst_H, reprst_W):
        super().__init__()
        self.reprst_H = reprst_H
        self.reprst_W = reprst_W
        hidden = reprst_H * reprst_W
        # 65x65 two-channel inputs split into 13x13 patches -> 5x5 = 25 patch tokens.
        self.config_ViT = ViTConfig(
            hidden_size=hidden,
            num_hidden_layers=2,
            num_attention_heads=1,
            intermediate_size=4 * hidden,
            image_size=65,
            patch_size=13,
            num_channels=2,
            return_dict=True,
        )
        # Bare ViT (no task head) with random weights built from the config above.
        self.bareViT = ViTModel(self.config_ViT)

    def forward(self, observs):
        """
        Args:
            observs: [B, Ch, H, W] batch of single frames (Ch = 2, H = W = 65 per the config).
        Returns:
            [B, reprst_H * reprst_W]: the ViT ``pooler_output``, i.e. the processed
            hidden state of the first ([CLS]) token.
        """
        return self.bareViT(observs).pooler_output


class Predictor_1dCNN(nn.Module):
    """Predict the next flat representation from the previous one and an action.

    The 2-D action is projected to the representation width and stacked with the previous
    state as a second channel; a three-layer 1-D CNN maps the two channels back to one.
    """

    def __init__(self, reprst_H, reprst_W):
        super().__init__()
        self.reprst_D = reprst_H * reprst_W
        self.action_projector = nn.Sequential(
            nn.Linear(2, self.reprst_D),
            nn.ReLU(),
        )
        # Channel 0: previous state, channel 1: projected action -> one output channel.
        self.cnn = nn.Sequential(
            nn.Conv1d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(16),
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Conv1d(32, 1, kernel_size=3, padding=1),
        )

    def forward(self, prev_states, actions):
        """
        Args:
            prev_states: [B, reprst_H*reprst_W] (encoder output or the previous prediction)
            actions: [B, 2]
        Returns:
            [B, reprst_H*reprst_W] predicted next state.
        """
        channels = torch.stack((prev_states, self.action_projector(actions)), dim=1)  # [B, 2, D]
        return self.cnn(channels).view(-1, self.reprst_D)
    


class Predictor_2dCNN(nn.Module):
    """Predict the next ``reprst_H x reprst_W`` representation grid from the previous one and an action.

    The 2-D action is projected to a grid of the same size and stacked with the previous
    state as a second channel; a three-layer 2-D CNN maps the two channels back to one.
    """

    def __init__(self, reprst_H, reprst_W):
        super().__init__()
        self.reprst_H = reprst_H
        self.reprst_W = reprst_W
        self.action_projector = nn.Sequential(
            nn.Linear(2, reprst_H * reprst_W),
            nn.ReLU(),
        )
        # Channel 0: previous state, channel 1: projected action -> one output channel.
        self.cnn = nn.Sequential(
            nn.Conv2d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
        )

    def forward(self, prev_states, actions):
        """
        Args:
            prev_states: [B, reprst_H, reprst_W]
            actions: [B, 2]
        Returns:
            [B, reprst_H, reprst_W] predicted next state.
        """
        action_map = self.action_projector(actions).view(-1, self.reprst_H, self.reprst_W)
        channels = torch.stack((prev_states, action_map), dim=1)  # [B, 2, H, W]
        return self.cnn(channels).view(-1, self.reprst_H, self.reprst_W)


class JEPAWorldModel(nn.Module):
    """JEPA world model with separate online and target encoders.

    The online ``encoder`` embeds only the first frame; ``predictor`` then rolls that
    embedding forward one action at a time. ``encoder_target`` embeds the later frames so
    the training loss can compare predictions against them.
    """

    def __init__(self, encoder, encoder_target, predictor, device="cuda", repr_dim=None):
        super().__init__()
        self.encoder = encoder
        self.encoder_target = encoder_target
        self.predictor = predictor
        self.device = device
        # Width D of each predicted embedding; the probing evaluator reads `model.repr_dim`.
        # When not given, infer it from encoders exposing reprst_H / reprst_W (Encoder_ViT does).
        if repr_dim is None:
            h = getattr(encoder, "reprst_H", None)
            w = getattr(encoder, "reprst_W", None)
            if h is not None and w is not None:
                repr_dim = h * w
        self.repr_dim = repr_dim

    def forward(self, observs, actions):
        """
        Args:
            observs: [B, T, Ch, H, W]. During training all T frames are given; for
                inference only the initial frame may be given (T == 1).
            actions: [B, T-1, 2] (or [B, n_steps, 2] when T == 1).
        Returns:
            (predictions, targets):
                predictions: [B, n_steps, ...] rolled-out embeddings (per-step shape as
                    returned by ``predictor``), with n_steps = T-1 when T > 1 and
                    actions.shape[1] when T == 1.
                targets: [B, n_steps, ...] target-encoder embeddings of frames
                    1..n_steps, or None when T == 1 (no future frames to encode).
        """
        n_frames = observs.shape[1]
        has_targets = n_frames > 1
        n_steps = n_frames - 1 if has_targets else actions.shape[1]

        pred_states = []
        target_states = []
        state = self.encoder(observs[:, 0])
        for t in range(n_steps):
            # Prediction t is the state after action t, i.e. it corresponds to frame t+1.
            state = self.predictor(state, actions[:, t])
            pred_states.append(state)
            if has_targets:
                target_states.append(self.encoder_target(observs[:, t + 1]))

        preds = torch.stack(pred_states, dim=1)
        targets = torch.stack(target_states, dim=1) if has_targets else None
        return preds, targets


class JEPAWorldModel_enc1(nn.Module):
    """JEPA world model variant that uses one shared encoder for inputs and targets.

    The encoder embeds the first frame, ``predictor`` rolls that embedding forward one
    action at a time, and the same encoder embeds the later frames as targets.
    """

    def __init__(self, encoder, predictor, device="cuda", repr_dim=None):
        super().__init__()
        self.encoder = encoder
        self.predictor = predictor
        self.device = device
        # Width D of each predicted embedding; the probing evaluator reads `model.repr_dim`.
        # When not given, infer it from encoders exposing reprst_H / reprst_W (Encoder_ViT does).
        if repr_dim is None:
            h = getattr(encoder, "reprst_H", None)
            w = getattr(encoder, "reprst_W", None)
            if h is not None and w is not None:
                repr_dim = h * w
        self.repr_dim = repr_dim

    def forward(self, observs, actions):
        """
        Args:
            observs: [B, T, Ch, H, W]. During training all T frames are given; for
                inference only the initial frame may be given (T == 1).
            actions: [B, T-1, 2] (or [B, n_steps, 2] when T == 1).
        Returns:
            (predictions, targets):
                predictions: [B, n_steps, ...] rolled-out embeddings (per-step shape as
                    returned by ``predictor``), with n_steps = T-1 when T > 1 and
                    actions.shape[1] when T == 1.
                targets: [B, n_steps, ...] encoder embeddings of frames 1..n_steps,
                    or None when T == 1 (no future frames to encode).
        """
        n_frames = observs.shape[1]
        has_targets = n_frames > 1
        n_steps = n_frames - 1 if has_targets else actions.shape[1]

        pred_states = []
        target_states = []
        state = self.encoder(observs[:, 0])
        for t in range(n_steps):
            # Prediction t is the state after action t, i.e. it corresponds to frame t+1.
            state = self.predictor(state, actions[:, t])
            pred_states.append(state)
            if has_targets:
                target_states.append(self.encoder(observs[:, t + 1]))

        preds = torch.stack(pred_states, dim=1)
        targets = torch.stack(target_states, dim=1) if has_targets else None
        return preds, targets



class Prober(torch.nn.Module):
    """MLP mapping an embedding to a flattened target of ``prod(output_shape)`` values.

    ``arch`` is a dash-separated list of hidden widths (e.g. ``"256"`` or ``"512-256"``);
    an empty string yields a single linear layer. Hidden layers use ReLU.
    """

    def __init__(
        self,
        embedding: int,
        arch: str,
        output_shape: List[int],
    ):
        super().__init__()
        self.output_dim = np.prod(output_shape)
        self.output_shape = output_shape
        self.arch = arch

        hidden = [int(width) for width in arch.split("-")] if arch != "" else []
        widths = [embedding, *hidden, self.output_dim]
        layers = []
        for d_in, d_out in zip(widths[:-2], widths[1:-1]):
            layers.extend((torch.nn.Linear(d_in, d_out), torch.nn.ReLU(True)))
        layers.append(torch.nn.Linear(widths[-2], widths[-1]))
        self.prober = torch.nn.Sequential(*layers)

    def forward(self, e):
        return self.prober(e)


In [17]:
class BarlowTwinsLoss(nn.Module):
    def __init__(self, lambda_=5e-3, eps=1e-5):
        """
        Barlow Twins Loss Module.

        Args:
            lambda_ (float): Scaling factor for the redundancy reduction term.
            eps (float): Added to each feature's variance before the square root when
                standardising, so constant features do not divide by zero.
        """
        super(BarlowTwinsLoss, self).__init__()
        self.lambda_ = lambda_
        self.eps = eps

    def forward(self, preds, targets):
        """
        Computes the Barlow Twins loss, averaged over timesteps.

        For each timestep both views are standardised per feature across the batch,
        their cross-correlation matrix C (embedding_dim x embedding_dim) is formed, and

            loss_t = sum_i (C_ii - 1)^2 + lambda_ * sum_{i != j} C_ij^2

        Args:
            preds (torch.Tensor): Embeddings from the first view. Shape: (batch_size, T-1, embedding_dim).
            targets (torch.Tensor): Embeddings from the second view. Shape: (batch_size, T-1, embedding_dim).

        Returns:
            Scalar tensor: mean of loss_t over the T-1 timesteps.
        """
        batch_size, traj_length, embedding_dim = preds.shape
        off_diag_mask = ~torch.eye(embedding_dim, dtype=torch.bool, device=preds.device)
        total_loss = 0.0
        for t in range(traj_length):
            z1 = self._standardize(preds[:, t])    # [batch_size, embedding_dim]
            z2 = self._standardize(targets[:, t])  # compare against the *other* view

            c = (z1.T @ z2) / batch_size  # cross-correlation matrix

            # Invariance: correlated features should have correlation 1.
            on_diag = (torch.diagonal(c) - 1).pow(2).sum()
            # Redundancy reduction: only the off-diagonal entries are pushed to 0.
            off_diag = c[off_diag_mask].pow(2).sum()

            total_loss = total_loss + on_diag + self.lambda_ * off_diag
        return total_loss / traj_length

    def _standardize(self, z):
        """Zero-mean, unit-(population-)variance per feature across the batch dimension."""
        z = z - z.mean(dim=0)
        return z / torch.sqrt(z.pow(2).mean(dim=0) + self.eps)

class VICRegLoss(nn.Module):
    """VICReg-style loss: prediction MSE plus variance and covariance regularisers.

    Only the predictions are regularised; the targets enter through the MSE term.
    Inputs are used on whatever device they already live on.
    """

    def __init__(self, lambda_=1e-2):
        super().__init__()
        # Weight of (variance + covariance) relative to the prediction MSE.
        self.lambda_ = lambda_

    def forward(self, predicted_states, target_states):
        """
        Args:
            predicted_states: predicted embeddings, e.g. [B, T-1, D].
            target_states: target embeddings with the same shape and device.
        Returns:
            Scalar tensor ``mse + lambda_ * (variance_loss + covariance_loss)``.
        """
        # 1. Prediction loss: distance between predicted and target states.
        pred_loss = F.mse_loss(predicted_states, target_states)
        # 2. Variance loss: keep each feature's spread across the batch above a floor.
        std_loss = self.variance_loss(predicted_states)
        # 3. Covariance loss: decorrelate representation dimensions.
        cov_loss = self.covariance_loss(predicted_states)
        return pred_loss + self.lambda_ * (std_loss + cov_loss)

    def variance_loss(self, representations, min_std=0.1, eps=1e-4):
        """Hinge penalty pushing each feature's std across the batch (dim 0) above ``min_std``.

        ``eps`` is added to the variance before the square root: the gradient of sqrt is
        infinite at 0, so a collapsed (constant) feature would otherwise give NaN gradients.
        """
        std = torch.sqrt(representations.var(dim=0) + eps)
        return torch.relu(min_std - std).mean()

    def covariance_loss(self, representations):
        """Sum of squared off-diagonal covariances between all non-batch dimensions.

        All dimensions after the batch dimension are flattened into one feature axis
        before the covariance matrix is formed.
        """
        centered = representations - representations.mean(dim=0)
        flat = centered.reshape(centered.shape[0], -1)
        cov_matrix = (flat.T @ flat) / (flat.shape[0] - 1)
        # Zero the diagonal out-of-place (variances are handled by variance_loss).
        off_diag = cov_matrix - torch.diag(torch.diagonal(cov_matrix))
        return (off_diag ** 2).sum()


In [13]:
# Training cell: builds the JEPA world model and trains it with VICRegLoss, saving the
# weights to `pth_name` whenever an epoch's mean training loss improves.
# NOTE(review): no random seed is set, so weight init and shuffling differ between runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Checkpoint path; load_model() in a later cell reads this module-level name.
pth_name = "best_ViT_1dCNN_enc2_rpr128_ep30_lrsched10002.pth"

# Representation grid: Encoder_ViT's hidden size is reprst_H * reprst_W (= 128 here).
# load_model() in a later cell reads these module-level values as well.
reprst_H=8
reprst_W=16
# Two separately constructed (independently initialised) encoders: online and target.
Enc = Encoder_ViT(reprst_H, reprst_W).to(device)
Enc_t = Encoder_ViT(reprst_H, reprst_W).to(device)
Pred = Predictor_1dCNN(reprst_H, reprst_W).to(device)
model = JEPAWorldModel(encoder=Enc, encoder_target=Enc_t, predictor=Pred).to(device)
# model = JEPAWorldModel_enc1(encoder=Enc, predictor=Pred).to(device)
# model.load_state_dict(torch.load(pth_name, weights_only=True))

# criterion = BarlowTwinsLoss()
criterion = VICRegLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Stepped once per batch (see the loop), so T_0 / T_mult count optimizer steps, not epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=1000, T_mult=2, eta_min=1e-6)

# Despite the name, this is a DataLoader; WallDataset already puts each sample on `device`.
dataset = create_wall_dataloader(data_path='./DL24FA/train', device=device, batch_size=64)

num_epochs = 30
min_loss = float('inf')
# step = 0
for epoch in range(num_epochs):
    print("Epoch ", epoch+1)
    model.train()
    total_loss = 0
    # for batch in dataset:
    for batch in tqdm(dataset, desc=""):
        # Batches are already on `device`; these moves are a no-op safeguard.
        observs = batch.states.to(device)
        actions = batch.actions.to(device)
        
        optimizer.zero_grad()
        
        pred_states, target_states = model(observs, actions)
        loss = criterion(pred_states, target_states) # scalar loss over all predicted timesteps
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # if step % 100 == 0:
        #     print(f"training loss {loss.item()}")

        # step += 1
        # print(loss.item())
    # Mean of per-batch losses over the epoch (len(dataset) = number of batches).
    mean_loss = total_loss/len(dataset)
    print(f"Training Loss: {mean_loss: .4f}")

    # Checkpoint on training loss; this loop has no validation split.
    if mean_loss < min_loss:
        min_loss = mean_loss
        torch.save(model.state_dict(), pth_name)

Using device: cuda
Epoch  1


100%|██████████| 2297/2297 [1:22:40<00:00,  2.16s/it]


Training Loss:  333.4233
Epoch  2


100%|██████████| 2297/2297 [04:15<00:00,  9.00it/s]


Training Loss:  93.4474
Epoch  3


100%|██████████| 2297/2297 [04:15<00:00,  9.00it/s]


Training Loss:  86.4680
Epoch  4


100%|██████████| 2297/2297 [04:15<00:00,  9.01it/s]


Training Loss:  49.9601
Epoch  5


100%|██████████| 2297/2297 [04:16<00:00,  8.97it/s]


Training Loss:  43.6193
Epoch  6


100%|██████████| 2297/2297 [04:15<00:00,  8.99it/s]


Training Loss:  30.6412
Epoch  7


100%|██████████| 2297/2297 [04:15<00:00,  8.99it/s]


Training Loss:  48.9906
Epoch  8


100%|██████████| 2297/2297 [04:15<00:00,  9.00it/s]


Training Loss:  47.9562
Epoch  9


100%|██████████| 2297/2297 [04:15<00:00,  8.99it/s]


Training Loss:  41.5153
Epoch  10


100%|██████████| 2297/2297 [04:18<00:00,  8.89it/s]

Training Loss:  26.2358





In [14]:
import argparse
import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import Any, Iterable, Tuple, Union, cast, List

from omegaconf import OmegaConf

# Aliases for annotating dataclass instances / dataclass types; both are plain `Any`.
DataClass = DataClassType = Any


@dataclass
class ConfigBase:
    """Base class for config dataclasses, adding parsing from the command line,
    config files (via OmegaConf) and dicts.

    NOTE(review): ``omegaconf_parse`` and ``DataclassArgParser`` are referenced below but
    are not defined in the visible notebook; presumably they come from the original
    ``configs`` module. Confirm they are in scope before calling these parsers.
    """

    @classmethod
    def parse_from_command_line(cls):
        return omegaconf_parse(cls)

    @classmethod
    def parse_from_file(cls, path: str):
        container = OmegaConf.to_container(OmegaConf.load(path))
        return cls.parse_from_dict(container)

    @classmethod
    def parse_from_command_line_deprecated(cls):
        parser = DataclassArgParser(cls, fromfile_prefix_chars="@")
        result = parser.parse_args_into_dataclasses()
        # Anything beyond the first parsed dataclass means unrecognised arguments.
        if result[1:]:
            raise RuntimeError(
                f"The following arguments were not recognized: {result[1:]}"
            )
        return result[0]

    @classmethod
    def parse_from_dict(cls, inputs):
        inputs_copy = inputs.copy()  # the parser may consume keys; keep caller's dict intact
        return DataclassArgParser._populate_dataclass_from_dict(cls, inputs_copy)

    @classmethod
    def parse_from_flat_dict(cls, inputs):
        inputs_copy = inputs.copy()
        return DataclassArgParser._populate_dataclass_from_flat_dict(cls, inputs_copy)

    def save(self, path: str):
        """Write this config to ``path`` in OmegaConf's YAML format."""
        with open(path, "w") as handle:
            OmegaConf.save(config=self, f=handle)

In [15]:
from typing import List
import numpy as np
from torch import nn
from torch.nn import functional as F
import torch


def build_mlp(layers_dims: List[int]):
    """Build an MLP: (Linear -> BatchNorm1d -> ReLU) per hidden layer, then a final Linear.

    Args:
        layers_dims: widths ``[in, hidden_1, ..., out]``; at least two entries are required.
    """
    modules = []
    for d_in, d_out in zip(layers_dims[:-2], layers_dims[1:-1]):
        modules += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.ReLU(True)]
    modules.append(nn.Linear(layers_dims[-2], layers_dims[-1]))
    return nn.Sequential(*modules)


class MockModel(torch.nn.Module):
    """
    Does nothing. Just for testing.

    ``forward`` ignores its inputs and returns random embeddings of shape
    ``[bs, n_steps, output_dim]``; ``output_dim`` is also exposed as ``repr_dim``.
    """

    def __init__(self, device="cuda", bs=64, n_steps=17, output_dim=256):
        super().__init__()
        self.device = device
        self.bs = bs
        self.n_steps = n_steps
        # Honour the requested embedding width (it used to be hard-coded to 256).
        self.repr_dim = output_dim

    def forward(self, states, actions):
        """
        Args:
            During training:
                states: [B, T, Ch, H, W]
            During inference:
                states: [B, 1, Ch, H, W]
            actions: [B, T-1, 2]

        Output:
            predictions: [bs, n_steps, repr_dim] random values (inputs are ignored)
        """
        return torch.randn((self.bs, self.n_steps, self.repr_dim)).to(self.device)


class Prober(torch.nn.Module):
    """MLP mapping an embedding to a flattened target of ``prod(output_shape)`` values.

    ``arch`` is a dash-separated list of hidden widths (e.g. ``"256"`` or ``"512-256"``);
    an empty string yields a single linear layer. Hidden layers use ReLU.
    """

    def __init__(
        self,
        embedding: int,
        arch: str,
        output_shape: List[int],
    ):
        super().__init__()
        self.output_dim = np.prod(output_shape)
        self.output_shape = output_shape
        self.arch = arch

        hidden = [int(width) for width in arch.split("-")] if arch != "" else []
        widths = [embedding, *hidden, self.output_dim]
        layers = []
        for d_in, d_out in zip(widths[:-2], widths[1:-1]):
            layers.extend((torch.nn.Linear(d_in, d_out), torch.nn.ReLU(True)))
        layers.append(torch.nn.Linear(widths[-2], widths[-1]))
        self.prober = torch.nn.Sequential(*layers)

    def forward(self, e):
        return self.prober(e)


In [16]:
import torch


class Normalizer:
    """Standardise 2-D agent locations with fixed per-axis mean/std statistics.

    The statistics are moved to the device of whatever tensor is being transformed.
    """

    def __init__(self):
        self.location_mean = torch.tensor([31.5863, 32.0618])
        self.location_std = torch.tensor([16.1025, 16.1353])

    def normalize_location(self, location: torch.Tensor) -> torch.Tensor:
        """Return ``(location - mean) / (std + 1e-6)``."""
        mean = self.location_mean.to(location.device)
        std = self.location_std.to(location.device)
        return (location - mean) / (std + 1e-6)

    def unnormalize_location(self, location: torch.Tensor) -> torch.Tensor:
        """Map normalized locations back: ``location * std + mean``."""
        std = self.location_std.to(location.device)
        mean = self.location_mean.to(location.device)
        return location * std + mean

    def unnormalize_mse(self, mse):
        """Rescale a per-axis MSE computed on normalized locations by ``std**2``."""
        return mse * self.location_std.to(mse.device).pow(2)

In [17]:
from enum import auto, Enum
import math


class LRSchedule(Enum):
    """Learning-rate schedules understood by Scheduler (anything but Constant is cosine)."""

    Constant = 1
    Cosine = 2


class Scheduler:
    """Per-step learning-rate schedule: constant, or linear warm-up then cosine decay.

    For the cosine schedule, the first 10% of ``epochs * batch_steps`` steps warm up
    linearly from 0. The remaining steps decay along a cosine to 0.1% of the peak. The
    peak for each param group is its ``base_lr`` entry (falling back to ``self.base_lr``),
    scaled by ``batch_size / 256``.
    """

    def __init__(
        self,
        schedule: str,
        base_lr: float,
        data_loader,
        epochs: int,
        optimizer,
        batch_steps=None,
        batch_size=None,
    ):
        """
        Args:
            schedule: an LRSchedule member; anything other than LRSchedule.Constant
                selects the warm-up + cosine schedule.
            base_lr: default peak learning rate (before batch-size scaling).
            data_loader: used for ``len()`` (steps per epoch) and, when ``batch_size`` is
                None, for the batch size.
            epochs: number of epochs the schedule spans.
            optimizer: optimizer whose param groups get their ``lr`` updated.
            batch_steps: steps per epoch; defaults to ``len(data_loader)``.
            batch_size: defaults to ``data_loader.config.batch_size`` if the loader has a
                ``config``, otherwise ``data_loader.batch_size`` (torch DataLoader).
        """
        self.schedule = schedule
        self.base_lr = base_lr
        self.data_loader = data_loader
        self.epochs = epochs
        self.optimizer = optimizer

        if batch_size is None:
            loader_config = getattr(data_loader, "config", None)
            if loader_config is not None:
                batch_size = loader_config.batch_size
            else:
                batch_size = data_loader.batch_size
        self.batch_size = batch_size

        self.batch_steps = len(data_loader) if batch_steps is None else batch_steps

    def adjust_learning_rate(self, step: int):
        """Set and return the learning rate for optimizer step ``step``.

        For the constant schedule, the optimizer is left untouched and ``base_lr`` is
        returned. Otherwise every param group's ``lr`` is updated, and the value for the
        last group is returned (None if the optimizer has no param groups).
        """
        if self.schedule == LRSchedule.Constant:
            return self.base_lr

        max_steps = self.epochs * self.batch_steps
        warmup_steps = int(0.10 * max_steps)
        decay_steps = max_steps - warmup_steps
        lr = None
        for param_group in self.optimizer.param_groups:
            peak_lr = param_group.get("base_lr", self.base_lr) * self.batch_size / 256
            if step < warmup_steps:
                lr = peak_lr * step / warmup_steps
            else:
                # Use fresh locals for every group: mutating `step` / `max_steps` here
                # would compound across groups and mis-schedule every group after the first.
                q = 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / decay_steps))
                end_lr = peak_lr * 0.001
                lr = peak_lr * q + end_lr * (1 - q)
            param_group["lr"] = lr
        return lr

In [18]:
from typing import NamedTuple, List, Any, Optional, Dict
from itertools import chain
from dataclasses import dataclass
import itertools
import os
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
import numpy as np
from matplotlib import pyplot as plt

#from schedulers import Scheduler, LRSchedule
#from models import Prober, build_mlp
#from configs import ConfigBase

#from dataset import WallDataset
#from normalizer import Normalizer


@dataclass
class ProbingConfig(ConfigBase):
    """Hyper-parameters for training and evaluating the location prober."""

    # NOTE(review): not read by ProbingEvaluator in this notebook, which hard-codes the
    # "locations" batch attribute.
    probe_targets: str = "locations"
    # Adam learning rate for the prober; also passed to Scheduler as base_lr.
    lr: float = 0.0002
    # Prober training epochs (ProbingEvaluator uses 1 when quick_debug is set).
    epochs: int = 20
    # Learning-rate schedule handed to Scheduler.
    schedule: LRSchedule = LRSchedule.Cosine
    # Number of timesteps randomly sampled per trajectory during prober training; only
    # applied when it is smaller than the number of predicted steps.
    sample_timesteps: int = 30
    # Hidden widths of the Prober MLP, dash-separated (e.g. "512-256").
    prober_arch: str = "256"


class ProbeResult(NamedTuple):
    """Container for the outputs of a probing run.

    NOTE(review): not constructed anywhere in the visible notebook; the field meanings
    below are inferred from the names only. Confirm them before relying on them.
    """

    # Presumably the trained prober module.
    model: torch.nn.Module
    # Presumably the mean validation loss.
    average_eval_loss: float
    # Presumably one validation loss per predicted timestep.
    eval_losses_per_step: List[float]
    # Presumably figures produced during evaluation.
    plots: List[Any]


# Default `config` argument of ProbingEvaluator. It is a single shared instance, so
# mutating it affects every evaluator constructed with the default.
default_config = ProbingConfig()


def location_losses(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Squared error averaged over the leading (batch) dimension only.

    For ``[BS, T, loc_dim]`` inputs the result is ``[T, loc_dim]``.
    Raises AssertionError if ``pred`` and ``target`` differ in shape.
    """
    assert pred.shape == target.shape
    squared_error = (pred - target) ** 2
    return squared_error.mean(dim=0)


class ProbingEvaluator:
    """Train and evaluate a location prober on a frozen world model's predicted embeddings.

    The world model is switched to eval mode and never optimised here; only the
    prober's parameters are trained.
    """

    def __init__(
        self,
        device,
        model: torch.nn.Module,
        probe_train_ds,
        probe_val_ds: dict,
        config: ProbingConfig = default_config,
        quick_debug: bool = False,
    ):
        """
        Args:
            device: device for the prober and the location targets (e.g. "cuda").
            model: world model; ``model(states, actions)`` must return
                ``(predictions [BS, T, D], targets)`` and the model must expose ``repr_dim``.
            probe_train_ds: training DataLoader whose batches carry ``locations``.
            probe_val_ds: mapping of name -> validation DataLoader.
            config: probing hyper-parameters.
            quick_debug: run a single epoch and stop each epoch after a few steps.
        """
        self.device = device
        self.config = config

        self.model = model
        self.model.eval()

        self.quick_debug = quick_debug

        self.ds = probe_train_ds
        self.val_ds = probe_val_ds

        self.normalizer = Normalizer()

    def _predict_encodings(self, batch):
        """Run the frozen world model on ``batch`` and return predictions as [T, BS, D].

        Gradients are disabled because the prober is trained on fixed embeddings; there
        is no reason to build an autograd graph through the world model.
        """
        with torch.no_grad():
            pred_encs, _ = self.model(batch.states, batch.actions)  # BS, T, D
        return pred_encs.transpose(0, 1)  # T, BS, D

    def _aligned_targets(self, batch, n_steps):
        """Normalized locations paired with ``n_steps`` predictions, as [BS, n_steps, loc_dim].

        Prediction k is produced after k+1 actions, so it is matched with location frame
        k+1; frame 0 (the observed start) has no prediction.
        """
        target = getattr(batch, "locations").to(self.device)
        target = self.normalizer.normalize_location(target)
        return target[:, 1 : 1 + n_steps]

    @staticmethod
    def _sample_timesteps(pred_encs, target, n_samples):
        """Randomly keep ``n_samples`` of the T timesteps, independently per batch element.

        Args:
            pred_encs: [T, BS, D] predictions.
            target: [BS, T, loc_dim] targets already aligned with ``pred_encs``.
        Returns:
            ([n_samples, BS, D], [BS, n_samples, loc_dim]) with matching timesteps.
        """
        n_steps, bs = pred_encs.shape[:2]
        sampled_pred_encs = torch.empty(
            (n_samples,) + tuple(pred_encs.shape[1:]),
            dtype=pred_encs.dtype,
            device=pred_encs.device,
        )
        sampled_target = torch.empty(
            (bs, n_samples) + tuple(target.shape[2:]),
            dtype=target.dtype,
            device=target.device,
        )
        for i in range(bs):
            indices = torch.randperm(n_steps)[:n_samples]
            sampled_pred_encs[:, i] = pred_encs[indices, i]
            sampled_target[i] = target[i, indices]
        return sampled_pred_encs, sampled_target

    def train_pred_prober(self):
        """
        Probes whether the predicted embeddings capture the future locations.

        Returns:
            The trained Prober.
        """
        config = self.config
        dataset = self.ds
        epochs = 1 if self.quick_debug else config.epochs

        test_batch = next(iter(dataset))
        prober_output_shape = getattr(test_batch, "locations")[0, 0].shape
        prober = Prober(
            self.model.repr_dim,
            config.prober_arch,
            output_shape=prober_output_shape,
        ).to(self.device)

        optimizer_pred_prober = torch.optim.Adam(list(prober.parameters()), config.lr)

        scheduler = Scheduler(
            schedule=config.schedule,
            base_lr=config.lr,
            data_loader=dataset,
            epochs=epochs,
            optimizer=optimizer_pred_prober,
            batch_steps=None,
            batch_size=dataset.batch_size,
        )

        step = 0
        for _ in tqdm(range(epochs), desc="Probe prediction epochs"):
            for batch in tqdm(dataset, desc="Probe prediction step"):
                pred_encs = self._predict_encodings(batch)  # T, BS, D
                n_steps = pred_encs.shape[0]
                # Align before any sampling so prediction k stays paired with frame k+1.
                target = self._aligned_targets(batch, n_steps)

                # Optionally train on a random subset of timesteps (bounds memory).
                if (
                    config.sample_timesteps is not None
                    and config.sample_timesteps < n_steps
                ):
                    pred_encs, target = self._sample_timesteps(
                        pred_encs, target, config.sample_timesteps
                    )

                pred_locs = torch.stack([prober(x) for x in pred_encs], dim=1)  # BS, T, loc_dim
                loss = location_losses(pred_locs, target).mean()

                if step % 100 == 0:
                    print(f"normalized pred locations loss {loss.item()}")

                optimizer_pred_prober.zero_grad()
                loss.backward()
                optimizer_pred_prober.step()

                scheduler.adjust_learning_rate(step)

                step += 1

                if self.quick_debug and step > 2:
                    break

        return prober

    @torch.no_grad()
    def evaluate_all(
        self,
        prober,
    ):
        """
        Evaluates on all the different validation datasets.

        Returns:
            dict mapping each validation-set name to its average unnormalized MSE.
        """
        avg_losses = {}

        for prefix, val_ds in self.val_ds.items():
            avg_losses[prefix] = self.evaluate_pred_prober(
                prober=prober,
                val_ds=val_ds,
                prefix=prefix,
            )

        return avg_losses

    @torch.no_grad()
    def evaluate_pred_prober(
        self,
        prober,
        val_ds,
        prefix="",
    ):
        """Average location MSE of ``prober`` on ``val_ds``, in original (unnormalized) units.

        ``prefix`` is accepted for the caller's bookkeeping but not used here.
        """
        probing_losses = []
        prober.eval()

        for batch in tqdm(val_ds, desc="Eval probe pred"):
            pred_encs = self._predict_encodings(batch)  # T, BS, D
            target = self._aligned_targets(batch, pred_encs.shape[0])
            pred_locs = torch.stack([prober(x) for x in pred_encs], dim=1)
            losses = location_losses(pred_locs, target)  # T, loc_dim
            probing_losses.append(losses.cpu())

        losses_t = torch.stack(probing_losses, dim=0).mean(dim=0)
        losses_t = self.normalizer.unnormalize_mse(losses_t)

        losses_t = losses_t.mean(dim=-1)
        return losses_t.mean().item()

In [21]:
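# The standalone-script version of this cell imported create_wall_dataloader
# (from `dataset`), ProbingEvaluator (from `evaluator`) and MockModel (from
# `models`), plus torch and glob. In this notebook the definitions used below
# live in earlier cells and torch is imported at the top, so no imports are
# needed here.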

def get_device():
    """Choose the compute device: CUDA if a GPU is visible, otherwise CPU.

    The choice is printed and returned as a ``torch.device``.
    """
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
    else:
        chosen = torch.device("cpu")
    print("Using device:", chosen)
    return chosen


def load_data(device, data_path="./DL24FA"):
    """Build the probing train loader and the validation loaders.

    Args:
        device: device forwarded to ``create_wall_dataloader`` for the
            dataset tensors.
        data_path: root directory containing ``probe_normal/`` and
            ``probe_wall/``. Defaults to the local ``./DL24FA`` copy. On the
            cluster, pass ``"/scratch/DL24FA"`` instead of editing the code.

    Returns:
        tuple: ``(probe_train_ds, probe_val_ds)``, where ``probe_val_ds`` maps
        ``"normal"`` and ``"wall"`` to their validation loaders.
    """

    def _probe_loader(split, train):
        # All probing loaders share the device and probing settings; only the
        # split directory and shuffling (train=True) differ.
        return create_wall_dataloader(
            data_path=f"{data_path}/{split}",
            probing=True,
            device=device,
            train=train,
        )

    probe_train_ds = _probe_loader("probe_normal/train", train=True)
    probe_val_ds = {
        "normal": _probe_loader("probe_normal/val", train=False),
        "wall": _probe_loader("probe_wall/val", train=False),
    }
    return probe_train_ds, probe_val_ds


def load_model(device=None):
    """Build the JEPA world model and load its trained weights.

    Args:
        device: device to place the model on. If ``None`` (the default), it is
            resolved with ``get_device()``. Pass the caller's device to avoid
            probing (and printing) it a second time.

    Returns:
        The ``JEPAWorldModel`` with weights loaded from ``pth_name``, in eval
        mode.

    Relies on names defined elsewhere in the notebook: ``Encoder_ViT``,
    ``Predictor_1dCNN``, ``JEPAWorldModel``, ``reprst_H``, ``reprst_W`` and
    ``pth_name``.
    """
    if device is None:
        device = get_device()

    encoder = Encoder_ViT(reprst_H, reprst_W).to(device)
    encoder_target = Encoder_ViT(reprst_H, reprst_W).to(device)
    predictor = Predictor_1dCNN(reprst_H, reprst_W).to(device)
    jepa_model = JEPAWorldModel(
        encoder=encoder, encoder_target=encoder_target, predictor=predictor
    ).to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # weights_only=True restricts unpickling to tensors and plain containers.
    # The file is consumed directly as a state_dict, so nothing else is
    # needed. This avoids executing arbitrary pickled code and silences
    # torch's FutureWarning.
    state_dict = torch.load(pth_name, map_location=device, weights_only=True)
    jepa_model.load_state_dict(state_dict)

    jepa_model.eval()
    return jepa_model


def evaluate_model(device, model, probe_train_ds, probe_val_ds):
    """Train a location prober through ``ProbingEvaluator`` and report losses.

    Prints one ``"<split> loss: <value>"`` line for every validation loader
    in ``probe_val_ds``. Returns nothing.
    """
    probing_eval = ProbingEvaluator(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=False,
    )

    trained_prober = probing_eval.train_pred_prober()
    split_losses = probing_eval.evaluate_all(prober=trained_prober)

    for split_name in split_losses:
        print(f"{split_name} loss: {split_losses[split_name]}")


# Run the whole pipeline at cell top level. A notebook needs no __main__
# guard, and the globals defined here stay available to later cells.
device = get_device()
probe_train_ds, probe_val_ds = load_data(device)
model = load_model()
evaluate_model(
    device=device,
    model=model,
    probe_train_ds=probe_train_ds,
    probe_val_ds=probe_val_ds,
)

Using device: cuda
Using device: cuda


  jepa_model.load_state_dict(torch.load("best_ViT_1dCNN_.pth"))
Probe prediction epochs:   0%|          | 0/20 [00:00<?, ?it/s]
Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s][A
Probe prediction step:   1%|          | 1/156 [00:02<06:44,  2.61s/it][A

normalized pred locations loss 1.0078556537628174



Probe prediction step:   1%|▏         | 2/156 [00:04<06:09,  2.40s/it][A
Probe prediction step:   2%|▏         | 3/156 [00:07<05:59,  2.35s/it][A
Probe prediction step:   3%|▎         | 4/156 [00:09<05:44,  2.26s/it][A
Probe prediction step:   3%|▎         | 5/156 [00:11<05:26,  2.16s/it][A
Probe prediction step:   4%|▍         | 6/156 [00:13<05:41,  2.28s/it][A
Probe prediction step:   4%|▍         | 7/156 [00:15<05:33,  2.24s/it][A
Probe prediction step:   5%|▌         | 8/156 [00:18<05:24,  2.19s/it][A
Probe prediction step:   6%|▌         | 9/156 [00:19<04:58,  2.03s/it][A
Probe prediction step:   6%|▋         | 10/156 [00:21<04:51,  2.00s/it][A
Probe prediction step:   7%|▋         | 11/156 [00:23<04:40,  1.93s/it][A
Probe prediction step:   8%|▊         | 12/156 [00:25<04:37,  1.93s/it][A
Probe prediction step:   8%|▊         | 13/156 [00:27<04:44,  1.99s/it][A
Probe prediction step:   9%|▉         | 14/156 [00:29<05:03,  2.14s/it][A
Probe prediction step:  10%|▉   

normalized pred locations loss 1.0741963386535645



Probe prediction step:  65%|██████▌   | 102/156 [01:15<00:14,  3.64it/s][A
Probe prediction step:  66%|██████▌   | 103/156 [01:15<00:15,  3.38it/s][A
Probe prediction step:  67%|██████▋   | 104/156 [01:15<00:13,  3.82it/s][A
Probe prediction step:  67%|██████▋   | 105/156 [01:15<00:12,  3.98it/s][A
Probe prediction step:  68%|██████▊   | 106/156 [01:16<00:15,  3.19it/s][A
Probe prediction step:  69%|██████▊   | 107/156 [01:16<00:12,  3.90it/s][A
Probe prediction step:  69%|██████▉   | 108/156 [01:16<00:13,  3.52it/s][A
Probe prediction step:  70%|██████▉   | 109/156 [01:17<00:13,  3.55it/s][A
Probe prediction step:  71%|███████   | 110/156 [01:17<00:17,  2.59it/s][A
Probe prediction step:  71%|███████   | 111/156 [01:18<00:17,  2.51it/s][A
Probe prediction step:  72%|███████▏  | 112/156 [01:18<00:16,  2.60it/s][A
Probe prediction step:  72%|███████▏  | 113/156 [01:18<00:15,  2.83it/s][A
Probe prediction step:  73%|███████▎  | 114/156 [01:18<00:11,  3.57it/s][A
Probe predi

normalized pred locations loss 1.0286595821380615


Probe prediction step:  31%|███       | 48/156 [00:03<00:07, 14.94it/s][A
Probe prediction step:  32%|███▏      | 50/156 [00:03<00:07, 14.98it/s][A
Probe prediction step:  33%|███▎      | 52/156 [00:03<00:07, 13.96it/s][A
Probe prediction step:  35%|███▍      | 54/156 [00:03<00:07, 13.66it/s][A
Probe prediction step:  36%|███▌      | 56/156 [00:03<00:07, 14.06it/s][A
Probe prediction step:  37%|███▋      | 58/156 [00:04<00:06, 14.40it/s][A
Probe prediction step:  38%|███▊      | 60/156 [00:04<00:06, 14.63it/s][A
Probe prediction step:  40%|███▉      | 62/156 [00:04<00:06, 14.83it/s][A
Probe prediction step:  41%|████      | 64/156 [00:04<00:06, 14.92it/s][A
Probe prediction step:  42%|████▏     | 66/156 [00:04<00:06, 14.99it/s][A
Probe prediction step:  44%|████▎     | 68/156 [00:04<00:05, 15.10it/s][A
Probe prediction step:  45%|████▍     | 70/156 [00:04<00:06, 13.41it/s][A
Probe prediction step:  46%|████▌     | 72/156 [00:05<00:06, 13.88it/s][A
Probe prediction step:  4

normalized pred locations loss 1.1558361053466797



Probe prediction step:  95%|█████████▍| 148/156 [00:10<00:00, 14.93it/s][A
Probe prediction step:  96%|█████████▌| 150/156 [00:10<00:00, 15.01it/s][A
Probe prediction step:  97%|█████████▋| 152/156 [00:10<00:00, 13.56it/s][A
Probe prediction step:  99%|█████████▊| 154/156 [00:10<00:00, 14.00it/s][A
Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.45it/s][A
Probe prediction epochs:  10%|█         | 2/20 [01:42<13:14, 44.13s/it]
Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s][A
Probe prediction step:   1%|▏         | 2/156 [00:00<00:09, 15.67it/s][A
Probe prediction step:   3%|▎         | 4/156 [00:00<00:09, 15.34it/s][A
Probe prediction step:   4%|▍         | 6/156 [00:00<00:09, 15.31it/s][A
Probe prediction step:   5%|▌         | 8/156 [00:00<00:09, 15.34it/s][A
Probe prediction step:   6%|▋         | 10/156 [00:00<00:09, 15.29it/s][A
Probe prediction step:   8%|▊         | 12/156 [00:00<00:10, 13.41it/s][A
Probe prediction step:   9%|▉      

normalized pred locations loss 0.9558184146881104


[A
Probe prediction step:  60%|██████    | 94/156 [00:06<00:04, 13.24it/s][A
Probe prediction step:  62%|██████▏   | 96/156 [00:06<00:04, 13.75it/s][A
Probe prediction step:  63%|██████▎   | 98/156 [00:06<00:04, 14.18it/s][A
Probe prediction step:  64%|██████▍   | 100/156 [00:06<00:03, 14.49it/s][A
Probe prediction step:  65%|██████▌   | 102/156 [00:07<00:03, 14.69it/s][A
Probe prediction step:  67%|██████▋   | 104/156 [00:07<00:03, 14.89it/s][A
Probe prediction step:  68%|██████▊   | 106/156 [00:07<00:03, 14.96it/s][A
Probe prediction step:  69%|██████▉   | 108/156 [00:07<00:03, 14.88it/s][A
Probe prediction step:  71%|███████   | 110/156 [00:07<00:03, 13.44it/s][A
Probe prediction step:  72%|███████▏  | 112/156 [00:07<00:03, 13.91it/s][A
Probe prediction step:  73%|███████▎  | 114/156 [00:07<00:02, 14.25it/s][A
Probe prediction step:  74%|███████▍  | 116/156 [00:07<00:02, 14.54it/s][A
Probe prediction step:  76%|███████▌  | 118/156 [00:08<00:02, 14.72it/s][A
Probe predi

normalized pred locations loss 1.011620044708252



Probe prediction step:  23%|██▎       | 36/156 [00:02<00:08, 13.48it/s][A
Probe prediction step:  24%|██▍       | 38/156 [00:02<00:08, 13.98it/s][A
Probe prediction step:  26%|██▌       | 40/156 [00:02<00:08, 14.30it/s][A
Probe prediction step:  27%|██▋       | 42/156 [00:02<00:07, 14.55it/s][A
Probe prediction step:  28%|██▊       | 44/156 [00:03<00:07, 14.79it/s][A
Probe prediction step:  29%|██▉       | 46/156 [00:03<00:07, 14.88it/s][A
Probe prediction step:  31%|███       | 48/156 [00:03<00:07, 15.01it/s][A
Probe prediction step:  32%|███▏      | 50/156 [00:03<00:07, 15.08it/s][A
Probe prediction step:  33%|███▎      | 52/156 [00:03<00:07, 13.44it/s][A
Probe prediction step:  35%|███▍      | 54/156 [00:03<00:07, 13.89it/s][A
Probe prediction step:  36%|███▌      | 56/156 [00:03<00:07, 14.21it/s][A
Probe prediction step:  37%|███▋      | 58/156 [00:04<00:06, 14.55it/s][A
Probe prediction step:  38%|███▊      | 60/156 [00:04<00:06, 14.75it/s][A
Probe prediction step:  

normalized pred locations loss 0.8777765035629272



Probe prediction step:  87%|████████▋ | 136/156 [00:09<00:01, 14.00it/s][A
Probe prediction step:  88%|████████▊ | 138/156 [00:09<00:01, 14.31it/s][A
Probe prediction step:  90%|████████▉ | 140/156 [00:09<00:01, 14.58it/s][A
Probe prediction step:  91%|█████████ | 142/156 [00:09<00:00, 14.79it/s][A
Probe prediction step:  92%|█████████▏| 144/156 [00:09<00:00, 14.89it/s][A
Probe prediction step:  94%|█████████▎| 146/156 [00:10<00:00, 14.97it/s][A
Probe prediction step:  95%|█████████▍| 148/156 [00:10<00:00, 15.06it/s][A
Probe prediction step:  96%|█████████▌| 150/156 [00:10<00:00, 13.40it/s][A
Probe prediction step:  97%|█████████▋| 152/156 [00:10<00:00, 13.88it/s][A
Probe prediction step:  99%|█████████▊| 154/156 [00:10<00:00, 14.24it/s][A
Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.45it/s][A
Probe prediction epochs:  20%|██        | 4/20 [02:04<05:47, 21.74s/it]
Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s][A
Probe prediction step:   

normalized pred locations loss 0.8836227655410767


Probe prediction step:  51%|█████▏    | 80/156 [00:05<00:05, 14.12it/s][A
Probe prediction step:  53%|█████▎    | 82/156 [00:05<00:05, 14.41it/s][A
Probe prediction step:  54%|█████▍    | 84/156 [00:05<00:04, 14.66it/s][A
Probe prediction step:  55%|█████▌    | 86/156 [00:05<00:04, 14.69it/s][A
Probe prediction step:  56%|█████▋    | 88/156 [00:06<00:04, 14.86it/s][A
Probe prediction step:  58%|█████▊    | 90/156 [00:06<00:04, 15.00it/s][A
Probe prediction step:  59%|█████▉    | 92/156 [00:06<00:04, 13.46it/s][A
Probe prediction step:  60%|██████    | 94/156 [00:06<00:04, 13.92it/s][A
Probe prediction step:  62%|██████▏   | 96/156 [00:06<00:04, 14.31it/s][A
Probe prediction step:  63%|██████▎   | 98/156 [00:06<00:03, 14.54it/s][A
Probe prediction step:  64%|██████▍   | 100/156 [00:06<00:03, 14.71it/s][A
Probe prediction step:  65%|██████▌   | 102/156 [00:07<00:03, 14.86it/s][A
Probe prediction step:  67%|██████▋   | 104/156 [00:07<00:03, 14.99it/s][A
Probe prediction step:

normalized pred locations loss 0.9607294797897339



Probe prediction step:  16%|█▌        | 25/156 [00:01<00:08, 14.62it/s][A
Probe prediction step:  17%|█▋        | 27/156 [00:01<00:08, 14.78it/s][A
Probe prediction step:  19%|█▊        | 29/156 [00:02<00:08, 14.95it/s][A
Probe prediction step:  20%|█▉        | 31/156 [00:02<00:08, 15.02it/s][A
Probe prediction step:  21%|██        | 33/156 [00:02<00:08, 14.95it/s][A
Probe prediction step:  22%|██▏       | 35/156 [00:02<00:08, 13.45it/s][A
Probe prediction step:  24%|██▎       | 37/156 [00:02<00:08, 13.91it/s][A
Probe prediction step:  25%|██▌       | 39/156 [00:02<00:08, 14.28it/s][A
Probe prediction step:  26%|██▋       | 41/156 [00:02<00:07, 14.59it/s][A
Probe prediction step:  28%|██▊       | 43/156 [00:02<00:07, 14.74it/s][A
Probe prediction step:  29%|██▉       | 45/156 [00:03<00:07, 14.85it/s][A
Probe prediction step:  30%|███       | 47/156 [00:03<00:07, 15.00it/s][A
Probe prediction step:  31%|███▏      | 49/156 [00:03<00:07, 15.05it/s][A
Probe prediction step:  

normalized pred locations loss 0.7957510948181152



Probe prediction step:  80%|████████  | 125/156 [00:08<00:02, 14.88it/s][A
Probe prediction step:  81%|████████▏ | 127/156 [00:08<00:01, 14.98it/s][A
Probe prediction step:  83%|████████▎ | 129/156 [00:08<00:01, 15.06it/s][A
Probe prediction step:  84%|████████▍ | 131/156 [00:09<00:01, 14.82it/s][A
Probe prediction step:  85%|████████▌ | 133/156 [00:09<00:01, 13.66it/s][A
Probe prediction step:  87%|████████▋ | 135/156 [00:09<00:01, 14.08it/s][A
Probe prediction step:  88%|████████▊ | 137/156 [00:09<00:01, 14.44it/s][A
Probe prediction step:  89%|████████▉ | 139/156 [00:09<00:01, 14.65it/s][A
Probe prediction step:  90%|█████████ | 141/156 [00:09<00:01, 14.79it/s][A
Probe prediction step:  92%|█████████▏| 143/156 [00:09<00:00, 14.91it/s][A
Probe prediction step:  93%|█████████▎| 145/156 [00:10<00:00, 15.01it/s][A
Probe prediction step:  94%|█████████▍| 147/156 [00:10<00:00, 15.05it/s][A
Probe prediction step:  96%|█████████▌| 149/156 [00:10<00:00, 13.43it/s][A
Probe predi

normalized pred locations loss 1.1322768926620483


[A
Probe prediction step:  45%|████▍     | 70/156 [00:04<00:05, 14.90it/s][A
Probe prediction step:  46%|████▌     | 72/156 [00:04<00:05, 14.98it/s][A
Probe prediction step:  47%|████▋     | 74/156 [00:05<00:05, 13.67it/s][A
Probe prediction step:  49%|████▊     | 76/156 [00:05<00:05, 13.92it/s][A
Probe prediction step:  50%|█████     | 78/156 [00:05<00:05, 14.29it/s][A
Probe prediction step:  51%|█████▏    | 80/156 [00:05<00:05, 14.54it/s][A
Probe prediction step:  53%|█████▎    | 82/156 [00:05<00:05, 14.78it/s][A
Probe prediction step:  54%|█████▍    | 84/156 [00:05<00:04, 14.89it/s][A
Probe prediction step:  55%|█████▌    | 86/156 [00:05<00:04, 14.98it/s][A
Probe prediction step:  56%|█████▋    | 88/156 [00:06<00:04, 15.10it/s][A
Probe prediction step:  58%|█████▊    | 90/156 [00:06<00:04, 13.98it/s][A
Probe prediction step:  59%|█████▉    | 92/156 [00:06<00:04, 13.76it/s][A
Probe prediction step:  60%|██████    | 94/156 [00:06<00:04, 14.21it/s][A
Probe prediction step

normalized pred locations loss 0.8892500400543213


[A
Probe prediction step:   9%|▉         | 14/156 [00:00<00:09, 15.22it/s][A
Probe prediction step:  10%|█         | 16/156 [00:01<00:09, 14.06it/s][A
Probe prediction step:  12%|█▏        | 18/156 [00:01<00:09, 13.86it/s][A
Probe prediction step:  13%|█▎        | 20/156 [00:01<00:09, 14.24it/s][A
Probe prediction step:  14%|█▍        | 22/156 [00:01<00:09, 14.55it/s][A
Probe prediction step:  15%|█▌        | 24/156 [00:01<00:08, 14.72it/s][A
Probe prediction step:  17%|█▋        | 26/156 [00:01<00:08, 14.88it/s][A
Probe prediction step:  18%|█▊        | 28/156 [00:01<00:08, 15.02it/s][A
Probe prediction step:  19%|█▉        | 30/156 [00:02<00:08, 15.08it/s][A
Probe prediction step:  21%|██        | 32/156 [00:02<00:08, 15.01it/s][A
Probe prediction step:  22%|██▏       | 34/156 [00:02<00:09, 13.46it/s][A
Probe prediction step:  23%|██▎       | 36/156 [00:02<00:08, 13.98it/s][A
Probe prediction step:  24%|██▍       | 38/156 [00:02<00:08, 14.32it/s][A
Probe prediction step

normalized pred locations loss 1.0338062047958374


Probe prediction step:  72%|███████▏  | 112/156 [00:07<00:02, 15.07it/s][A
Probe prediction step:  73%|███████▎  | 114/156 [00:07<00:03, 13.88it/s][A
Probe prediction step:  74%|███████▍  | 116/156 [00:08<00:02, 13.82it/s][A
Probe prediction step:  76%|███████▌  | 118/156 [00:08<00:02, 14.23it/s][A
Probe prediction step:  77%|███████▋  | 120/156 [00:08<00:02, 14.55it/s][A
Probe prediction step:  78%|███████▊  | 122/156 [00:08<00:02, 14.73it/s][A
Probe prediction step:  79%|███████▉  | 124/156 [00:08<00:02, 14.87it/s][A
Probe prediction step:  81%|████████  | 126/156 [00:08<00:01, 15.01it/s][A
Probe prediction step:  82%|████████▏ | 128/156 [00:08<00:01, 15.05it/s][A
Probe prediction step:  83%|████████▎ | 130/156 [00:08<00:01, 14.79it/s][A
Probe prediction step:  85%|████████▍ | 132/156 [00:09<00:01, 13.64it/s][A
Probe prediction step:  86%|████████▌ | 134/156 [00:09<00:01, 14.06it/s][A
Probe prediction step:  87%|████████▋ | 136/156 [00:09<00:01, 14.37it/s][A
Probe predic

normalized pred locations loss 0.9716324806213379



Probe prediction step:  36%|███▌      | 56/156 [00:03<00:06, 14.83it/s][A
Probe prediction step:  37%|███▋      | 58/156 [00:04<00:07, 13.35it/s][A
Probe prediction step:  38%|███▊      | 60/156 [00:04<00:06, 13.88it/s][A
Probe prediction step:  40%|███▉      | 62/156 [00:04<00:06, 14.24it/s][A
Probe prediction step:  41%|████      | 64/156 [00:04<00:06, 14.48it/s][A
Probe prediction step:  42%|████▏     | 66/156 [00:04<00:06, 14.72it/s][A
Probe prediction step:  44%|████▎     | 68/156 [00:04<00:05, 14.85it/s][A
Probe prediction step:  45%|████▍     | 70/156 [00:04<00:05, 14.96it/s][A
Probe prediction step:  46%|████▌     | 72/156 [00:04<00:05, 15.09it/s][A
Probe prediction step:  47%|████▋     | 74/156 [00:05<00:06, 13.39it/s][A
Probe prediction step:  49%|████▊     | 76/156 [00:05<00:05, 13.82it/s][A
Probe prediction step:  50%|█████     | 78/156 [00:05<00:05, 14.24it/s][A
Probe prediction step:  51%|█████▏    | 80/156 [00:05<00:05, 14.55it/s][A
Probe prediction step:  

normalized pred locations loss 1.0036547183990479



Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.44it/s][A
Probe prediction epochs:  45%|████▌     | 9/20 [02:58<02:14, 12.25s/it]
Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s][A
Probe prediction step:   1%|▏         | 2/156 [00:00<00:10, 15.39it/s][A
Probe prediction step:   3%|▎         | 4/156 [00:00<00:09, 15.29it/s][A
Probe prediction step:   4%|▍         | 6/156 [00:00<00:09, 15.33it/s][A
Probe prediction step:   5%|▌         | 8/156 [00:00<00:09, 15.28it/s][A
Probe prediction step:   6%|▋         | 10/156 [00:00<00:09, 15.22it/s][A
Probe prediction step:   8%|▊         | 12/156 [00:00<00:09, 15.23it/s][A
Probe prediction step:   9%|▉         | 14/156 [00:00<00:09, 15.20it/s][A
Probe prediction step:  10%|█         | 16/156 [00:01<00:10, 13.49it/s][A
Probe prediction step:  12%|█▏        | 18/156 [00:01<00:09, 13.96it/s][A
Probe prediction step:  13%|█▎        | 20/156 [00:01<00:09, 14.37it/s][A
Probe prediction step:  14%|█▍        |

normalized pred locations loss 1.0258128643035889



Probe prediction step:  64%|██████▍   | 100/156 [00:06<00:04, 13.93it/s][A
Probe prediction step:  65%|██████▌   | 102/156 [00:07<00:03, 14.26it/s][A
Probe prediction step:  67%|██████▋   | 104/156 [00:07<00:03, 14.53it/s][A
Probe prediction step:  68%|██████▊   | 106/156 [00:07<00:03, 14.65it/s][A
Probe prediction step:  69%|██████▉   | 108/156 [00:07<00:03, 14.73it/s][A
Probe prediction step:  71%|███████   | 110/156 [00:07<00:03, 14.80it/s][A
Probe prediction step:  72%|███████▏  | 112/156 [00:07<00:03, 14.58it/s][A
Probe prediction step:  73%|███████▎  | 114/156 [00:07<00:03, 13.52it/s][A
Probe prediction step:  74%|███████▍  | 116/156 [00:08<00:02, 13.90it/s][A
Probe prediction step:  76%|███████▌  | 118/156 [00:08<00:02, 14.27it/s][A
Probe prediction step:  77%|███████▋  | 120/156 [00:08<00:02, 14.51it/s][A
Probe prediction step:  78%|███████▊  | 122/156 [00:08<00:02, 14.66it/s][A
Probe prediction step:  79%|███████▉  | 124/156 [00:08<00:02, 14.82it/s][A
Probe predi

normalized pred locations loss 1.011622667312622



Probe prediction step:  29%|██▉       | 46/156 [00:03<00:07, 14.73it/s][A
Probe prediction step:  31%|███       | 48/156 [00:03<00:07, 14.92it/s][A
Probe prediction step:  32%|███▏      | 50/156 [00:03<00:07, 15.00it/s][A
Probe prediction step:  33%|███▎      | 52/156 [00:03<00:06, 15.03it/s][A
Probe prediction step:  35%|███▍      | 54/156 [00:03<00:07, 13.58it/s][A
Probe prediction step:  36%|███▌      | 56/156 [00:03<00:07, 13.87it/s][A
Probe prediction step:  37%|███▋      | 58/156 [00:04<00:06, 14.24it/s][A
Probe prediction step:  38%|███▊      | 60/156 [00:04<00:06, 14.57it/s][A
Probe prediction step:  40%|███▉      | 62/156 [00:04<00:06, 14.73it/s][A
Probe prediction step:  41%|████      | 64/156 [00:04<00:06, 14.84it/s][A
Probe prediction step:  42%|████▏     | 66/156 [00:04<00:06, 14.98it/s][A
Probe prediction step:  44%|████▎     | 68/156 [00:04<00:05, 15.08it/s][A
Probe prediction step:  45%|████▍     | 70/156 [00:04<00:06, 14.07it/s][A
Probe prediction step:  

normalized pred locations loss 0.9969220757484436



Probe prediction step:  92%|█████████▏| 144/156 [00:10<00:01, 11.57it/s][A
Probe prediction step:  94%|█████████▎| 146/156 [00:10<00:00, 12.49it/s][A
Probe prediction step:  95%|█████████▍| 148/156 [00:10<00:00, 13.20it/s][A
Probe prediction step:  96%|█████████▌| 150/156 [00:10<00:00, 12.31it/s][A
Probe prediction step:  97%|█████████▋| 152/156 [00:10<00:00, 13.07it/s][A
Probe prediction step:  99%|█████████▊| 154/156 [00:10<00:00, 13.63it/s][A
Probe prediction step: 100%|██████████| 156/156 [00:11<00:00, 14.14it/s][A
Probe prediction epochs:  55%|█████▌    | 11/20 [03:19<01:44, 11.58s/it]
Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s][A
Probe prediction step:   1%|▏         | 2/156 [00:00<00:09, 15.73it/s][A
Probe prediction step:   3%|▎         | 4/156 [00:00<00:09, 15.41it/s][A
Probe prediction step:   4%|▍         | 6/156 [00:00<00:09, 15.31it/s][A
Probe prediction step:   5%|▌         | 8/156 [00:00<00:09, 15.32it/s][A
Probe prediction step:   6%|▋   

normalized pred locations loss 0.9631943106651306


[A
Probe prediction step:  58%|█████▊    | 90/156 [00:06<00:04, 15.04it/s][A
Probe prediction step:  59%|█████▉    | 92/156 [00:06<00:04, 13.49it/s][A
Probe prediction step:  60%|██████    | 94/156 [00:06<00:04, 13.92it/s][A
Probe prediction step:  62%|██████▏   | 96/156 [00:06<00:04, 14.29it/s][A
Probe prediction step:  63%|██████▎   | 98/156 [00:06<00:03, 14.57it/s][A
Probe prediction step:  64%|██████▍   | 100/156 [00:06<00:03, 14.73it/s][A
Probe prediction step:  65%|██████▌   | 102/156 [00:07<00:03, 14.86it/s][A
Probe prediction step:  67%|██████▋   | 104/156 [00:07<00:03, 14.93it/s][A
Probe prediction step:  68%|██████▊   | 106/156 [00:07<00:03, 15.05it/s][A
Probe prediction step:  69%|██████▉   | 108/156 [00:07<00:03, 13.39it/s][A
Probe prediction step:  71%|███████   | 110/156 [00:07<00:03, 13.85it/s][A
Probe prediction step:  72%|███████▏  | 112/156 [00:07<00:03, 14.27it/s][A
Probe prediction step:  73%|███████▎  | 114/156 [00:07<00:02, 14.53it/s][A
Probe predict

normalized pred locations loss 1.0017292499542236



Probe prediction step:  21%|██        | 33/156 [00:02<00:08, 15.06it/s][A
Probe prediction step:  22%|██▏       | 35/156 [00:02<00:09, 13.36it/s][A
Probe prediction step:  24%|██▎       | 37/156 [00:02<00:08, 13.88it/s][A
Probe prediction step:  25%|██▌       | 39/156 [00:02<00:08, 14.24it/s][A
Probe prediction step:  26%|██▋       | 41/156 [00:02<00:07, 14.49it/s][A
Probe prediction step:  28%|██▊       | 43/156 [00:02<00:07, 14.75it/s][A
Probe prediction step:  29%|██▉       | 45/156 [00:03<00:07, 14.88it/s][A
Probe prediction step:  30%|███       | 47/156 [00:03<00:07, 14.97it/s][A
Probe prediction step:  31%|███▏      | 49/156 [00:03<00:07, 15.05it/s][A
Probe prediction step:  33%|███▎      | 51/156 [00:03<00:07, 13.50it/s][A
Probe prediction step:  34%|███▍      | 53/156 [00:03<00:07, 13.94it/s][A
Probe prediction step:  35%|███▌      | 55/156 [00:03<00:07, 14.32it/s][A
Probe prediction step:  37%|███▋      | 57/156 [00:03<00:06, 14.62it/s][A
Probe prediction step:  

normalized pred locations loss 1.0851588249206543



Probe prediction step:  85%|████████▌ | 133/156 [00:09<00:01, 13.52it/s][A
Probe prediction step:  87%|████████▋ | 135/156 [00:09<00:01, 14.00it/s][A
Probe prediction step:  88%|████████▊ | 137/156 [00:09<00:01, 14.36it/s][A
Probe prediction step:  89%|████████▉ | 139/156 [00:09<00:01, 14.60it/s][A
Probe prediction step:  90%|█████████ | 141/156 [00:09<00:01, 14.81it/s][A
Probe prediction step:  92%|█████████▏| 143/156 [00:09<00:00, 14.89it/s][A
Probe prediction step:  93%|█████████▎| 145/156 [00:10<00:00, 14.95it/s][A
Probe prediction step:  94%|█████████▍| 147/156 [00:10<00:00, 15.07it/s][A
Probe prediction step:  96%|█████████▌| 149/156 [00:10<00:00, 13.50it/s][A
Probe prediction step:  97%|█████████▋| 151/156 [00:10<00:00, 13.90it/s][A
Probe prediction step:  98%|█████████▊| 153/156 [00:10<00:00, 14.30it/s][A
Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.44it/s][A
Probe prediction epochs:  65%|██████▌   | 13/20 [03:41<01:18, 11.17s/it]
Probe predicti

normalized pred locations loss 1.0340242385864258



Probe prediction step:  49%|████▊     | 76/156 [00:05<00:05, 13.60it/s][A
Probe prediction step:  50%|█████     | 78/156 [00:05<00:05, 14.08it/s][A
Probe prediction step:  51%|█████▏    | 80/156 [00:05<00:05, 14.36it/s][A
Probe prediction step:  53%|█████▎    | 82/156 [00:05<00:05, 14.57it/s][A
Probe prediction step:  54%|█████▍    | 84/156 [00:05<00:04, 14.80it/s][A
Probe prediction step:  55%|█████▌    | 86/156 [00:05<00:04, 14.96it/s][A
Probe prediction step:  56%|█████▋    | 88/156 [00:06<00:04, 15.04it/s][A
Probe prediction step:  58%|█████▊    | 90/156 [00:06<00:04, 15.09it/s][A
Probe prediction step:  59%|█████▉    | 92/156 [00:06<00:04, 13.57it/s][A
Probe prediction step:  60%|██████    | 94/156 [00:06<00:04, 13.99it/s][A
Probe prediction step:  62%|██████▏   | 96/156 [00:06<00:04, 14.32it/s][A
Probe prediction step:  63%|██████▎   | 98/156 [00:06<00:03, 14.61it/s][A
Probe prediction step:  64%|██████▍   | 100/156 [00:06<00:03, 14.77it/s][A
Probe prediction step: 

normalized pred locations loss 1.0574798583984375



Probe prediction step:  13%|█▎        | 21/156 [00:01<00:09, 14.18it/s][A
Probe prediction step:  15%|█▍        | 23/156 [00:01<00:09, 14.51it/s][A
Probe prediction step:  16%|█▌        | 25/156 [00:01<00:08, 14.71it/s][A
Probe prediction step:  17%|█▋        | 27/156 [00:01<00:08, 14.83it/s][A
Probe prediction step:  19%|█▊        | 29/156 [00:02<00:08, 14.97it/s][A
Probe prediction step:  20%|█▉        | 31/156 [00:02<00:08, 15.02it/s][A
Probe prediction step:  21%|██        | 33/156 [00:02<00:08, 14.45it/s][A
Probe prediction step:  22%|██▏       | 35/156 [00:02<00:08, 13.59it/s][A
Probe prediction step:  24%|██▎       | 37/156 [00:02<00:08, 14.01it/s][A
Probe prediction step:  25%|██▌       | 39/156 [00:02<00:08, 14.32it/s][A
Probe prediction step:  26%|██▋       | 41/156 [00:02<00:07, 14.56it/s][A
Probe prediction step:  28%|██▊       | 43/156 [00:02<00:07, 14.81it/s][A
Probe prediction step:  29%|██▉       | 45/156 [00:03<00:07, 14.92it/s][A
Probe prediction step:  

normalized pred locations loss 0.9633973836898804



Probe prediction step:  78%|███████▊  | 121/156 [00:08<00:02, 14.50it/s][A
Probe prediction step:  79%|███████▉  | 123/156 [00:08<00:02, 14.69it/s][A
Probe prediction step:  80%|████████  | 125/156 [00:08<00:02, 14.82it/s][A
Probe prediction step:  81%|████████▏ | 127/156 [00:08<00:01, 14.98it/s][A
Probe prediction step:  83%|████████▎ | 129/156 [00:08<00:01, 15.04it/s][A
Probe prediction step:  84%|████████▍ | 131/156 [00:09<00:01, 14.99it/s][A
Probe prediction step:  85%|████████▌ | 133/156 [00:09<00:01, 13.50it/s][A
Probe prediction step:  87%|████████▋ | 135/156 [00:09<00:01, 13.94it/s][A
Probe prediction step:  88%|████████▊ | 137/156 [00:09<00:01, 14.29it/s][A
Probe prediction step:  89%|████████▉ | 139/156 [00:09<00:01, 14.57it/s][A
Probe prediction step:  90%|█████████ | 141/156 [00:09<00:01, 14.74it/s][A
Probe prediction step:  92%|█████████▏| 143/156 [00:09<00:00, 14.86it/s][A
Probe prediction step:  93%|█████████▎| 145/156 [00:10<00:00, 15.01it/s][A
Probe predi

normalized pred locations loss 1.022871494293213


Probe prediction step:  41%|████      | 64/156 [00:04<00:06, 14.58it/s][A
Probe prediction step:  42%|████▏     | 66/156 [00:04<00:06, 14.74it/s][A
Probe prediction step:  44%|████▎     | 68/156 [00:04<00:05, 14.86it/s][A
Probe prediction step:  45%|████▍     | 70/156 [00:04<00:05, 14.94it/s][A
Probe prediction step:  46%|████▌     | 72/156 [00:04<00:05, 15.07it/s][A
Probe prediction step:  47%|████▋     | 74/156 [00:05<00:06, 13.65it/s][A
Probe prediction step:  49%|████▊     | 76/156 [00:05<00:05, 13.92it/s][A
Probe prediction step:  50%|█████     | 78/156 [00:05<00:05, 14.31it/s][A
Probe prediction step:  51%|█████▏    | 80/156 [00:05<00:05, 14.56it/s][A
Probe prediction step:  53%|█████▎    | 82/156 [00:05<00:05, 14.71it/s][A
Probe prediction step:  54%|█████▍    | 84/156 [00:05<00:04, 14.90it/s][A
Probe prediction step:  55%|█████▌    | 86/156 [00:05<00:04, 14.99it/s][A
Probe prediction step:  56%|█████▋    | 88/156 [00:06<00:04, 15.03it/s][A
Probe prediction step:  5

normalized pred locations loss 0.9700708389282227


Probe prediction step:   5%|▌         | 8/156 [00:00<00:09, 15.20it/s][A
Probe prediction step:   6%|▋         | 10/156 [00:00<00:09, 15.16it/s][A
Probe prediction step:   8%|▊         | 12/156 [00:00<00:09, 15.22it/s][A
Probe prediction step:   9%|▉         | 14/156 [00:00<00:09, 15.19it/s][A
Probe prediction step:  10%|█         | 16/156 [00:01<00:10, 13.50it/s][A
Probe prediction step:  12%|█▏        | 18/156 [00:01<00:09, 13.85it/s][A
Probe prediction step:  13%|█▎        | 20/156 [00:01<00:09, 14.20it/s][A
Probe prediction step:  14%|█▍        | 22/156 [00:01<00:09, 14.49it/s][A
Probe prediction step:  15%|█▌        | 24/156 [00:01<00:08, 14.74it/s][A
Probe prediction step:  17%|█▋        | 26/156 [00:01<00:08, 14.86it/s][A
Probe prediction step:  18%|█▊        | 28/156 [00:01<00:08, 14.95it/s][A
Probe prediction step:  19%|█▉        | 30/156 [00:02<00:08, 15.05it/s][A
Probe prediction step:  21%|██        | 32/156 [00:02<00:08, 13.85it/s][A
Probe prediction step:  22

normalized pred locations loss 0.9893603920936584



Probe prediction step:  69%|██████▉   | 108/156 [00:07<00:03, 14.34it/s][A
Probe prediction step:  71%|███████   | 110/156 [00:07<00:03, 14.56it/s][A
Probe prediction step:  72%|███████▏  | 112/156 [00:07<00:02, 14.71it/s][A
Probe prediction step:  73%|███████▎  | 114/156 [00:07<00:03, 13.30it/s][A
Probe prediction step:  74%|███████▍  | 116/156 [00:08<00:02, 13.80it/s][A
Probe prediction step:  76%|███████▌  | 118/156 [00:08<00:02, 14.18it/s][A
Probe prediction step:  77%|███████▋  | 120/156 [00:08<00:02, 14.47it/s][A
Probe prediction step:  78%|███████▊  | 122/156 [00:08<00:02, 14.70it/s][A
Probe prediction step:  79%|███████▉  | 124/156 [00:08<00:02, 14.81it/s][A
Probe prediction step:  81%|████████  | 126/156 [00:08<00:02, 14.93it/s][A
Probe prediction step:  82%|████████▏ | 128/156 [00:08<00:01, 15.04it/s][A
Probe prediction step:  83%|████████▎ | 130/156 [00:09<00:01, 13.67it/s][A
Probe prediction step:  85%|████████▍ | 132/156 [00:09<00:01, 13.83it/s][A
Probe predi

normalized pred locations loss 1.0972964763641357


Probe prediction step:  33%|███▎      | 52/156 [00:03<00:06, 14.94it/s][A
Probe prediction step:  35%|███▍      | 54/156 [00:03<00:06, 14.97it/s][A
Probe prediction step:  36%|███▌      | 56/156 [00:03<00:07, 13.56it/s][A
Probe prediction step:  37%|███▋      | 58/156 [00:04<00:07, 13.98it/s][A
Probe prediction step:  38%|███▊      | 60/156 [00:04<00:06, 14.25it/s][A
Probe prediction step:  40%|███▉      | 62/156 [00:04<00:06, 14.52it/s][A
Probe prediction step:  41%|████      | 64/156 [00:04<00:06, 14.76it/s][A
Probe prediction step:  42%|████▏     | 66/156 [00:04<00:06, 14.87it/s][A
Probe prediction step:  44%|████▎     | 68/156 [00:04<00:05, 14.96it/s][A
Probe prediction step:  45%|████▍     | 70/156 [00:04<00:05, 15.06it/s][A
Probe prediction step:  46%|████▌     | 72/156 [00:04<00:06, 13.66it/s][A
Probe prediction step:  47%|████▋     | 74/156 [00:05<00:05, 13.93it/s][A
Probe prediction step:  49%|████▊     | 76/156 [00:05<00:05, 14.32it/s][A
Probe prediction step:  5

normalized pred locations loss 0.9173845052719116



Probe prediction step:  97%|█████████▋| 152/156 [00:10<00:00, 14.86it/s][A
Probe prediction step:  99%|█████████▊| 154/156 [00:10<00:00, 13.35it/s][A
Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.41it/s][A
Probe prediction epochs:  90%|█████████ | 18/20 [04:35<00:21, 10.86s/it]
Probe prediction step:   0%|          | 0/156 [00:00<?, ?it/s][A
Probe prediction step:   1%|▏         | 2/156 [00:00<00:09, 15.42it/s][A
Probe prediction step:   3%|▎         | 4/156 [00:00<00:09, 15.23it/s][A
Probe prediction step:   4%|▍         | 6/156 [00:00<00:09, 15.27it/s][A
Probe prediction step:   5%|▌         | 8/156 [00:00<00:09, 15.20it/s][A
Probe prediction step:   6%|▋         | 10/156 [00:00<00:09, 15.20it/s][A
Probe prediction step:   8%|▊         | 12/156 [00:00<00:09, 15.25it/s][A
Probe prediction step:   9%|▉         | 14/156 [00:00<00:10, 13.54it/s][A
Probe prediction step:  10%|█         | 16/156 [00:01<00:10, 13.98it/s][A
Probe prediction step:  12%|█▏      

normalized pred locations loss 0.9471026659011841



Probe prediction step:  62%|██████▏   | 96/156 [00:06<00:04, 13.49it/s][A
Probe prediction step:  63%|██████▎   | 98/156 [00:06<00:04, 13.93it/s][A
Probe prediction step:  64%|██████▍   | 100/156 [00:06<00:03, 14.24it/s][A
Probe prediction step:  65%|██████▌   | 102/156 [00:07<00:03, 14.56it/s][A
Probe prediction step:  67%|██████▋   | 104/156 [00:07<00:03, 14.73it/s][A
Probe prediction step:  68%|██████▊   | 106/156 [00:07<00:03, 14.86it/s][A
Probe prediction step:  69%|██████▉   | 108/156 [00:07<00:03, 14.97it/s][A
Probe prediction step:  71%|███████   | 110/156 [00:07<00:03, 15.05it/s][A
Probe prediction step:  72%|███████▏  | 112/156 [00:07<00:03, 13.48it/s][A
Probe prediction step:  73%|███████▎  | 114/156 [00:07<00:03, 13.96it/s][A
Probe prediction step:  74%|███████▍  | 116/156 [00:08<00:02, 14.30it/s][A
Probe prediction step:  76%|███████▌  | 118/156 [00:08<00:02, 14.53it/s][A
Probe prediction step:  77%|███████▋  | 120/156 [00:08<00:02, 14.71it/s][A
Probe predict

normalized pred locations loss 1.072992205619812



Probe prediction step:  26%|██▌       | 40/156 [00:02<00:08, 13.80it/s][A
Probe prediction step:  27%|██▋       | 42/156 [00:02<00:08, 14.18it/s][A
Probe prediction step:  28%|██▊       | 44/156 [00:03<00:07, 14.51it/s][A
Probe prediction step:  29%|██▉       | 46/156 [00:03<00:07, 14.70it/s][A
Probe prediction step:  31%|███       | 48/156 [00:03<00:07, 14.81it/s][A
Probe prediction step:  32%|███▏      | 50/156 [00:03<00:07, 14.97it/s][A
Probe prediction step:  33%|███▎      | 52/156 [00:03<00:06, 15.03it/s][A
Probe prediction step:  35%|███▍      | 54/156 [00:03<00:07, 13.53it/s][A
Probe prediction step:  36%|███▌      | 56/156 [00:03<00:07, 14.01it/s][A
Probe prediction step:  37%|███▋      | 58/156 [00:04<00:06, 14.33it/s][A
Probe prediction step:  38%|███▊      | 60/156 [00:04<00:06, 14.55it/s][A
Probe prediction step:  40%|███▉      | 62/156 [00:04<00:06, 14.77it/s][A
Probe prediction step:  41%|████      | 64/156 [00:04<00:06, 14.93it/s][A
Probe prediction step:  

normalized pred locations loss 0.8841239213943481


Probe prediction step:  90%|████████▉ | 140/156 [00:09<00:01, 14.26it/s][A
Probe prediction step:  91%|█████████ | 142/156 [00:09<00:00, 14.55it/s][A
Probe prediction step:  92%|█████████▏| 144/156 [00:09<00:00, 14.76it/s][A
Probe prediction step:  94%|█████████▎| 146/156 [00:10<00:00, 14.88it/s][A
Probe prediction step:  95%|█████████▍| 148/156 [00:10<00:00, 15.02it/s][A
Probe prediction step:  96%|█████████▌| 150/156 [00:10<00:00, 15.06it/s][A
Probe prediction step:  97%|█████████▋| 152/156 [00:10<00:00, 13.48it/s][A
Probe prediction step:  99%|█████████▊| 154/156 [00:10<00:00, 13.98it/s][A
Probe prediction step: 100%|██████████| 156/156 [00:10<00:00, 14.39it/s][A
Probe prediction epochs: 100%|██████████| 20/20 [04:57<00:00, 14.85s/it]
Eval probe pred: 100%|██████████| 62/62 [00:06<00:00, 10.00it/s]
Eval probe pred: 100%|██████████| 62/62 [00:06<00:00,  9.92it/s]

normal loss: 253.65716552734375
wall loss: 207.3687744140625



