In [1]:
from deepset_attention import *
# Build the DeepSet classifier (5 input features per hit, 2 output classes).
# Original guidance: mean pooling for the full dataset, max pooling for the small dataset.
# NOTE(review): the data directories used later are "small_PKL_*" while pool="mean" — confirm intended.
MODEL = DeepSet(in_features=5, feats=[80, 120, 70, 50, 8], n_class=2, pool="mean")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to() moves parameters in place and returns the same object, so `model is MODEL`.
model = MODEL.to(device)


pool:  mean
pool:  mean
pool:  mean
pool:  mean
pool:  mean
pool:  mean


In [2]:
import torch
from torch.utils.data import Dataset
import pickle
import os
import numpy as np
import random
import pandas as pd

class StreamingHcaDataset(Dataset):
    """One-event-per-file dataset of proton (label 0) and pion (label 1) point clouds.

    Every ``*.pkl`` file directly inside ``proton_dir`` / ``pion_dir`` is expected to hold a
    single pandas DataFrame of hits. Sample ``idx`` is the ``idx``-th file of the combined
    list: all proton files (sorted), then all pion files (sorted).

    Args:
        proton_dir: directory containing proton event files (label 0).
        pion_dir: directory containing pion event files (label 1).
        features: DataFrame columns stacked, in this order, into the per-hit feature vector.
            Defaults to ``["x", "y", "z", "total_energy", "mean_time"]``.
        energy_threshold: only hits with ``total_energy`` strictly greater than this are kept.

    Security:
        Files are read with ``pd.read_pickle``, which can execute arbitrary code embedded in a
        pickle. Only point this dataset at trusted directories.
    """

    DEFAULT_FEATURES = ("x", "y", "z", "total_energy", "mean_time")

    def __init__(self, proton_dir, pion_dir, features=None, energy_threshold=5):
        super().__init__()

        self.proton_files = self._list_pickles(proton_dir)
        self.pion_files = self._list_pickles(pion_dir)

        # Build a fresh list per instance instead of sharing one mutable default list.
        self.features = list(self.DEFAULT_FEATURES) if features is None else features
        self.energy_threshold = energy_threshold
        self.all_files = self.proton_files + self.pion_files
        self.labels = [0] * len(self.proton_files) + [1] * len(self.pion_files)

    @staticmethod
    def _list_pickles(directory):
        """Return the sorted full paths of the ``.pkl`` files directly inside ``directory``."""
        return sorted(
            os.path.join(directory, fname)
            for fname in os.listdir(directory)
            if fname.endswith(".pkl")
        )

    def __len__(self):
        """Number of event files (one sample per file)."""
        return len(self.all_files)

    def _load_file(self, file_path, label):
        """Load one event file and return its hits as a point cloud.

        Args:
            file_path: path of the pickled DataFrame.
            label: integer class label attached to the sample.

        Returns:
            dict with ``"part"`` (float32 tensor, shape ``(n_hits, n_features)``),
            ``"label"`` (0-d long tensor) and ``"seq_length"`` (0-d long tensor holding n_hits).
            NaN and +/-Inf feature values are replaced by 0.
        """
        df = pd.read_pickle(file_path)
        df = df[df["total_energy"] > self.energy_threshold]

        # Cast to float first so the NaN/Inf scrub also works for integer/object columns.
        part_feat = df[self.features].to_numpy(dtype=np.float64)
        part_feat = np.nan_to_num(part_feat, nan=0.0, posinf=0.0, neginf=0.0)

        return {
            "part": torch.tensor(part_feat, dtype=torch.float32),
            "label": torch.tensor(label, dtype=torch.long),
            "seq_length": torch.tensor(part_feat.shape[0], dtype=torch.long),
        }

    def __getitem__(self, idx):
        """Return the event stored in file ``idx``.

        The index must be honoured (rather than drawing a random file). Otherwise
        ``random_split`` subsets all sample the whole dataset, so train/validation/test
        overlap and evaluation is not reproducible.
        """
        return self._load_file(self.all_files[idx], self.labels[idx])


In [3]:
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Merge per-event samples into one zero-padded batch.

    Args:
        batch: list of dicts with keys "part" (``(n_hits, n_features)`` tensor),
            "label" and "seq_length".

    Returns:
        dict with "part" padded to ``(batch_size, max_n_hits, n_features)`` using 0.0,
        plus "label" and "seq_length" as 1-D long tensors in batch order.
    """
    point_clouds, label_list, length_list = [], [], []
    for sample in batch:
        point_clouds.append(sample["part"])
        label_list.append(sample["label"])
        length_list.append(sample["seq_length"])

    label_tensor = torch.tensor(label_list, dtype=torch.long)
    length_tensor = torch.tensor(length_list, dtype=torch.long)
    # Shorter events are right-padded with zeros up to the longest event in the batch.
    padded = pad_sequence(point_clouds, batch_first=True, padding_value=0.0)

    return {"part": padded, "label": label_tensor, "seq_length": length_tensor}


In [4]:

from torch.utils.data import random_split, DataLoader

# Run configuration: beam energy (GeV) and detector granularity, as they appear in directory names.
energy = "25"
granularity = "200"
# NOTE(review): machine-specific absolute path; consider making this configurable.
DATA_ROOT = f"/mnt/c/Users/hnayak/Documents/{energy}GeV"
# Fixed seed so the train/val/test partition is identical across kernel restarts.
# Without it, a reloaded checkpoint would be evaluated on a fresh split that can contain its training events.
SPLIT_SEED = 42
NUM_WORKERS = 32

pion_dir = f"{DATA_ROOT}/small_PKL_pion_{energy}GeV_{granularity}"
proton_dir = f"{DATA_ROOT}/small_PKL_proton_{energy}GeV_{granularity}"
# Run tag used in model / log / plot file names, e.g. "25GeV_200".
name = f"{energy}GeV_{granularity}"
print(name)

# Full dataset (all proton + pion event files); split below.
train_dataset = StreamingHcaDataset(proton_dir=proton_dir, pion_dir=pion_dir)

# 80 / 10 / 10 split; the test set absorbs the rounding remainder so sizes sum to the total.
total_size = len(train_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

train_set, val_set, test_set = random_split(
    train_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_set, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=NUM_WORKERS)

print(f"Train: {train_size}, Validation: {val_size}, Test: {test_size}")



25GeV_200
Train: 160000, Validation: 20000, Test: 20000


In [5]:
from tqdm import tqdm
from sklearn.preprocessing import QuantileTransformer


def test_model(model, test_loader, criterion=None, device=None):
    """Evaluate ``model`` on ``test_loader`` and report the mean loss and accuracy.

    Each batch's hits are normalised with a QuantileTransformer (normal output) fitted on
    that batch alone, mirroring the preprocessing in ``train_model``.

    Args:
        model: classifier returning logits of shape ``(batch, n_class)``.
        test_loader: yields dicts with "part" ``(batch, seq, feat)`` and "label" ``(batch,)``.
        criterion: optional loss; when None the reported loss is 0.
        device: device to run on; defaults to CUDA when available, else CPU.

    Returns:
        ``(avg_loss, accuracy)``: mean per-batch loss (0 without ``criterion``) and accuracy in percent.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    model.eval()  # Set model to evaluation mode

    total_loss = 0.0
    correct = 0
    total_samples = 0
    n_batches = 0

    progress = tqdm(test_loader, total=len(test_loader), desc="Testing")

    with torch.no_grad():
        for batch in progress:
            # The quantile fit runs on CPU/NumPy, so only move the normalised tensor to the device.
            parts_cpu = batch["part"].cpu()
            batch_size, seq_len, feat_dim = parts_cpu.shape
            flat = parts_cpu.numpy().reshape(-1, feat_dim)
            # sklearn caps n_quantiles at n_samples anyway (with a warning); capping here is
            # numerically identical and silences one warning per event.
            qt = QuantileTransformer(
                n_quantiles=max(1, min(1000, flat.shape[0])),
                output_distribution='normal',
                random_state=42,
            )
            flat = qt.fit_transform(flat)
            parts = torch.as_tensor(flat, dtype=torch.float32).reshape(batch_size, seq_len, feat_dim).to(device)
            labels = batch["label"].to(device)

            outputs = model(parts)
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            n_batches += 1

            if criterion is not None:
                loss_value = criterion(outputs, labels).item()
                total_loss += loss_value
                progress.set_postfix(loss=loss_value)

    if total_samples == 0:
        raise ValueError("test_loader produced no samples")

    avg_loss = total_loss / n_batches if criterion is not None else 0
    accuracy = correct / total_samples * 100

    print(f"Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import pandas as pd
from sklearn.preprocessing import QuantileTransformer

def train_model(model, train_loader, val_loader, num_epochs=2, learning_rate=5e-4, device=None, save_path=f"./Models/Z_{name}.pth", log_path=f"./Logs/log_summary_Z_{name}.csv"):
    """Train ``model`` with Adam + cross-entropy, keeping the checkpoint with the lowest val loss.

    Each batch's hits are normalised with a QuantileTransformer fitted on that batch alone
    (the same preprocessing ``test_model`` applies).

    Args:
        model: classifier returning logits ``(batch, n_class)``.
        train_loader / val_loader: yield dicts with "part" ``(batch, seq, feat)`` and "label".
        num_epochs: number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: defaults to CUDA when available, else CPU.
        save_path: where the best ``state_dict`` is written. The default is built from the
            notebook-global ``name`` at the moment this cell is executed.
        log_path: CSV file receiving one row of metrics per epoch.

    Returns:
        pandas DataFrame of the per-epoch log (also written to ``log_path``).
    """
    import os

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_loss = float("inf")
    log_data = []  # one dict per epoch, turned into a DataFrame at the end

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        n_batches = 0
        progress = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in progress:
            # Quantile normalisation runs in NumPy; move only the result to the device.
            parts_cpu = batch["part"].cpu()
            batch_size, seq_len, feat_dim = parts_cpu.shape
            flat = parts_cpu.numpy().reshape(-1, feat_dim)
            # Same cap sklearn applies internally; avoids one warning per small event.
            qt = QuantileTransformer(
                n_quantiles=max(1, min(1000, flat.shape[0])),
                output_distribution='normal',
                random_state=42,
            )
            flat = qt.fit_transform(flat)
            parts = torch.as_tensor(flat, dtype=torch.float32).reshape(batch_size, seq_len, feat_dim).to(device)

            labels = batch["label"].to(device)

            optimizer.zero_grad()
            outputs = model(parts)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value
            n_batches += 1
            progress.set_postfix(loss=loss_value)

        avg_train_loss = running_loss / max(n_batches, 1)
        val_loss, val_accuracy = test_model(model, val_loader, criterion, device)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

        log_data.append({
            "Epoch": epoch + 1,
            "Train Loss": avg_train_loss,
            "Val Loss": val_loss,
            "Accuracy": val_accuracy,  # validation accuracy in percent
        })

        # Keep only the checkpoint with the best validation loss so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f"Model saved at epoch {epoch+1} with val loss {val_loss:.4f}")

    # Actually persist the log so the message below is true.
    df_log = pd.DataFrame(log_data)
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    df_log.to_csv(log_path, index=False)
    print(f"Training log saved to {log_path}")
    print("Training complete!")
    return df_log


In [7]:

# if __name__ == "__main__":
#     #Train the model
#     train_model(model, train_loader,val_loader=val_loader, num_epochs=60, learning_rate=5e-4, device=device)

In [8]:
# NOTE: MODEL, model and model_test are the same nn.Module object, so loading the
# checkpoint here also overwrites the weights seen through `model`.
model_test = MODEL
model_name = f"./Models/Z_{name}.pth"
print(model_name)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on GPU also loads on a CPU-only host.
model_test.load_state_dict(torch.load(model_name, map_location=device, weights_only=True))
model_test.to(device)

./Models/Z_25GeV_200.pth


DeepSet(
  (sequential): ModuleList(
    (0): DeepSetLayer(
      (Gamma): Linear(in_features=5, out_features=80, bias=True)
      (Lambda): Linear(in_features=5, out_features=80, bias=True)
      (bn): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    )
    (1): LeakyReLU(negative_slope=0.01)
    (2): DeepSetLayer(
      (Gamma): Linear(in_features=80, out_features=120, bias=True)
      (Lambda): Linear(in_features=80, out_features=120, bias=True)
      (bn): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    )
    (3): LeakyReLU(negative_slope=0.01)
    (4): DeepSetLayer(
      (Gamma): Linear(in_features=120, out_features=70, bias=True)
      (Lambda): Linear(in_features=120, out_features=70, bias=True)
      (bn): BatchNorm1d(70, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    )
    (5): LeakyReLU(negative_slope=0.01)
    (6): DeepSetLayer(
      (Gamma): Linear(in_features=70, out_features=5

In [9]:
# Count every scalar parameter (trainable and frozen) in the network.
total_params = 0
for param in MODEL.parameters():
    total_params += param.numel()
print(total_params)

45292


In [10]:
from sklearn.metrics import confusion_matrix, roc_curve, auc
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from scipy.stats import beta
from matplotlib.font_manager import FontProperties

bold_font = FontProperties(weight='bold', size=14)

def evaluate_model(model, data_loader, criterion, device, name1="model", return_accuracy=False):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    data_loader_tqdm = tqdm(enumerate(data_loader), desc="Testing", total=len(data_loader))
    
    true_labels = []
    pred_labels = []
    all_scores = []

    with torch.no_grad():
        for i, batch in data_loader_tqdm:
            parts = batch["part"].to(device)
            batch_size, seq_len, feat_dim = parts.shape
            parts_np = parts.cpu().numpy().reshape(-1, feat_dim)

            qt = QuantileTransformer(output_distribution='normal', random_state=42)
            parts_np = qt.fit_transform(parts_np)
            parts = torch.tensor(parts_np).reshape(batch_size, seq_len, feat_dim).to(device)

            labels = batch["label"].to(device)
            outputs = model(parts)
            preds = F.softmax(outputs, dim=1)

            # Collect soft scores for the positive class (class 1)
            all_scores.extend(preds[:, 1].detach().cpu().numpy())
            pred_labels.extend(torch.argmax(preds, dim=-1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

            loss = criterion(outputs, labels)
            total_loss += loss.item()

            if return_accuracy:
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

    avg_loss = total_loss / len(data_loader)
    accuracy = 100 * correct / total if return_accuracy else None

    # Confusion Matrix
    cm = confusion_matrix(true_labels, pred_labels)
    plt.figure(figsize=(10, 8))
    colors = ["#cce5ff", "#004c99"]
    cmap = LinearSegmentedColormap.from_list("Custom Blue", colors)
    sns.heatmap(cm, annot=True, fmt='d', cmap=cmap,
                xticklabels=["Proton", "Pion"], yticklabels=["Proton", "Pion"],
                annot_kws={"size": 18, "weight": "bold"})
    
    TP, FN = cm[0, 0], cm[0, 1]
    FP, TN = cm[1, 0], cm[1, 1]
    n_total = TP + TN + FP + FN
    n_correct = TP + TN
    alpha = 0.32

    lower_bound = beta.ppf(alpha / 2, n_correct, n_total - n_correct + 1)
    upper_bound = beta.ppf(1 - alpha / 2, n_correct + 1, n_total - n_correct)
    
    title = f"Accuracy: {accuracy:.2f}%, 68% CI: [{lower_bound*100:.2f}%, {upper_bound*100:.2f}%]"
    plt.title(title, fontsize=15, weight='bold')
    plt.xlabel('Predicted', size=14, weight='bold')
    plt.ylabel('True', size=14, weight='bold')
    plt.xticks(size=14, weight='bold')
    plt.yticks(size=14, weight='bold')
    print(f"./Plots/confusion_matrix_{name}.pdf")
    # plt.savefig(f"./Plots/confusion_matrix_{name}.pdf", dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Confusion Matrix:\n{cm}")
    print(f"68% Confidence Interval: [{lower_bound:.4f}, {upper_bound:.4f}]")

    # ROC Curve
    y_true = np.array(true_labels)
    y_scores = np.array(all_scores)

    # np.savez(f"./Scores/scores_{name}.npz", array1=y_true, array2=y_scores)
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)
    jscore = tpr - fpr
    j_index = np.argmax(jscore)
    threshold = thresholds[j_index]
    print(f"Optimal Threshold: {threshold:.4f} (J-Index: {j_index})")

    N_x = sum(y_true == 0)
    N_y = sum(y_true == 1)
    sigma_fpr = np.sqrt(fpr * (1 - fpr) / N_x)
    sigma_tpr = np.sqrt(tpr * (1 - tpr) / N_y)

    plt.figure()
    plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f} )')
    plt.fill_between(fpr, tpr - 5 * sigma_tpr, tpr + 5 * sigma_tpr,
                     color='blue', alpha=0.25, label='1-sigma region (PiPR) [5x]')
    plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--', label='Random Guess')
    plt.xlabel('False Positive Rate', fontsize=16,weight='bold')
    plt.ylabel('True Positive Rate', fontsize=16,weight='bold')
    plt.xticks(fontsize=14,weight='bold')
    plt.yticks(fontsize=14,weight='bold')
    plt.title('Receiver Operating Characteristic (ROC)', fontsize=17,weight='bold')
    # plt.legend(loc='lower right', fontsize=14, weight='bold')
    plt.legend(title="",prop={'weight': 'bold', 'size': 14},title_fontproperties=bold_font, loc='lower right')
    plt.grid()
    # plt.savefig(f"./Plots/roc_curve_{name}.pdf")

    return (avg_loss, f"{accuracy:.2f}%") if return_accuracy else avg_loss



In [11]:
# Final evaluation of the loaded checkpoint on the held-out test split.
eval_criterion = nn.CrossEntropyLoss()
evaluate_model(model_test, test_loader, eval_criterion, device, return_accuracy=True)



Set: tensor([[[ 0.0482, -0.0117, -5.1993, -0.3789, -5.1993],
         [ 0.0482,  0.2434, -3.0787, -0.4127, -3.0787],
         [ 0.0482,  0.2434, -2.8659, -0.4212, -2.8659],
         ...,
         [-0.3469, -1.4480, -0.4355, -1.4555, -0.4965],
         [-0.3469, -1.3413, -0.3677, -0.5716, -0.4212],
         [-2.4433,  1.0558,  0.2568, -1.6163,  0.2126]]], device='cuda:0') shape:  torch.Size([1, 963, 5])
Set: tensor([[[ 1.0544e+01, -2.2799e-02,  1.5721e-01,  ..., -1.2333e-01,
          -1.0967e-02,  8.8862e+00],
         [ 6.6944e+00, -5.5024e-03,  3.2012e-01,  ..., -7.9900e-02,
          -1.5694e-03,  4.7348e+00],
         [ 6.2440e+00, -4.0294e-03,  3.2688e-01,  ..., -7.5254e-02,
          -1.2708e-03,  4.3276e+00],
         ...,
         [-2.7352e-02,  1.1192e+00,  3.2934e-01,  ..., -9.7325e-04,
          -5.0089e-02, -2.3040e-03],
         [-2.7007e-02, -6.9602e-03, -6.2435e-03,  ...,  1.7644e-01,
          -4.9380e-02,  4.0005e-01],
         [ 3.3422e+00,  6.9518e+00,  1.3409e+00,  

Testing:   0%|          | 5/20000 [00:00<45:45,  7.28it/s]  

tensor([[[-2.0587e-01,  6.0141e+00,  5.7892e+00,  ..., -4.2032e-01,
          -3.8221e-01,  1.5230e+01],
         [-9.3404e-02,  4.6684e+00,  3.6298e+00,  ..., -2.4168e-01,
          -2.2773e-01,  8.0573e+00],
         [-8.4747e-02,  4.4920e+00,  3.3693e+00,  ..., -2.2225e-01,
          -2.0964e-01,  7.0910e+00],
         ...,
         [-7.9693e-02, -1.1407e-01, -5.5249e-02,  ...,  3.5826e+00,
          -2.3394e-02, -8.1025e-02],
         [-8.2516e-02, -2.3027e-02, -8.5398e-02,  ..., -2.3226e-02,
          -1.1218e-02, -8.5985e-02],
         [ 1.8096e+01, -3.1746e-01, -2.5403e-01,  ...,  9.2410e+00,
          -1.4218e-01,  4.7674e+00]]], device='cuda:0') shape:  torch.Size([1, 963, 120])
Set: tensor([[[-6.2695e-01,  3.3734e+01, -1.6909e+00,  ...,  7.8009e+01,
           2.4442e+01, -4.2321e+00],
         [-3.3906e-01,  1.7553e+00, -7.0530e-01,  ...,  2.5963e+01,
           2.1390e+01, -2.2977e+00],
         [-3.0613e-01, -1.3074e-02, -6.2189e-01,  ...,  1.9722e+01,
           2.0568e+0



Set: tensor([[[ 0.1115, -0.7004, -5.1993, -0.8128, -5.1993],
         [ 0.1115, -0.7004, -2.9615, -1.2874, -3.0842],
         [ 0.1115, -0.7004, -2.6901, -1.2254, -2.6452],
         ...,
         [-0.2533, -0.4242, -2.6901, -0.5391, -2.7412],
         [-0.8182, -0.7004, -2.5688, -1.7651, -2.5050],
         [-1.2200, -0.1554, -2.3793,  0.2402, -2.2827]]], device='cuda:0') shape:  torch.Size([1, 981, 5])
Set: tensor([[[ 8.8102e+00, -2.3500e-02,  2.6178e-01,  ..., -1.1716e-01,
          -2.5819e-02,  8.8729e+00],
         [ 4.3292e+00, -2.0604e-05,  7.7281e-01,  ..., -7.0766e-02,
          -2.2142e-02,  4.4635e+00],
         [ 3.4452e+00,  1.2937e-01,  7.0197e-01,  ..., -6.2705e-02,
          -2.1397e-02,  3.6915e+00],
         ...,
         [ 4.4726e+00, -8.6142e-03, -4.1691e-04,  ..., -6.1011e-02,
          -2.6353e-02,  4.2951e+00],
         [ 3.7557e+00,  2.0341e+00,  1.0207e+00,  ..., -5.0655e-02,
          -4.5528e-02,  2.7353e+00],
         [ 4.7328e+00, -1.3494e-02, -1.0901e-02,  

Testing:   0%|          | 18/20000 [00:01<14:23, 23.13it/s]

tensor([[[ 1.0010e+01, -1.8726e-02,  6.9983e-01,  ..., -1.2888e-01,
           1.1084e-01,  8.6811e+00],
         [ 4.2475e+00, -2.0413e-02,  1.4770e-02,  ..., -7.4741e-02,
          -9.6415e-03,  5.2925e+00],
         [ 3.7409e+00,  1.5561e+00,  1.9031e+00,  ..., -7.4837e-02,
           2.8715e-02,  3.4611e+00],
         ...,
         [-2.8877e-02, -2.4453e-02, -7.8042e-04,  ..., -3.0858e-02,
           1.0074e+00,  1.5841e+00],
         [ 1.4787e+00,  6.3064e+00,  3.8760e+00,  ..., -5.3705e-02,
           7.9586e+00, -3.1225e-02],
         [-2.5199e-02, -3.0875e-02, -4.1375e-03,  ..., -3.1839e-02,
           1.7250e+00,  1.7344e+00]]], device='cuda:0') shape:  torch.Size([1, 1011, 80])
Set: tensor([[[-2.5901e-01,  2.7199e+00,  6.2907e+00,  ..., -3.9484e-01,
          -3.7924e-01,  1.6663e+01],
         [-2.2194e-01,  3.7220e+00, -7.4490e-03,  ..., -2.2421e-01,
          -1.6981e-01,  4.1466e-01],
         [-2.1267e-01, -1.7521e-01,  5.9350e+00,  ..., -9.1356e-03,
          -1.6951e-0

Testing:   0%|          | 26/20000 [00:01<11:47, 28.23it/s]

tensor([[[ 1.0251e+01, -1.9667e-02,  3.5822e-01,  ..., -1.2306e-01,
          -1.2260e-02,  8.7475e+00],
         [ 5.7187e+00,  1.0706e+00,  1.2011e+00,  ..., -7.7549e-02,
          -7.5866e-03,  3.9647e+00],
         [ 5.2326e+00,  4.5597e-01,  8.2293e-01,  ..., -7.1919e-02,
          -8.0946e-03,  3.8661e+00],
         ...,
         [ 2.2130e+00,  1.0532e+00,  8.0001e-01,  ..., -3.9311e-02,
           4.0336e+00, -7.6290e-04],
         [ 1.8550e+00,  3.4884e+00,  2.1205e+00,  ..., -4.1205e-02,
           4.9136e+00, -1.3304e-02],
         [ 4.2213e+00, -4.7255e-03,  2.4327e-02,  ..., -5.5800e-02,
          -1.5980e-02,  3.4105e+00]]], device='cuda:0') shape:  torch.Size([1, 820, 80])
Set: tensor([[[-2.2116e-01,  4.8005e+00,  5.6293e+00,  ..., -4.0772e-01,
          -3.7740e-01,  1.5396e+01],
         [-1.3316e-01, -5.1739e-02,  7.1011e+00,  ..., -1.2982e-01,
          -2.0819e-01,  7.5997e+00],
         [-1.2393e-01, -9.0059e-03,  4.9463e+00,  ..., -1.5070e-01,
          -1.8083e-01

Testing:   0%|          | 35/20000 [00:01<09:55, 33.51it/s]

tensor([[[-0.1342, -0.5282, -5.1993, -1.2873, -5.1993],
         [-0.1342, -0.3029, -3.0585, -0.8250, -3.0585],
         [-0.1342, -0.3029, -2.8444, -0.4267, -2.8444],
         ...,
         [ 0.3175,  1.8043,  0.8369,  0.1979,  0.9360],
         [-0.1342, -0.5282, -1.8410, -0.0404, -2.1846],
         [-0.1342, -0.5282, -2.4456, -0.6912, -2.4743]]], device='cuda:0') shape:  torch.Size([1, 900, 5])
Set: tensor([[[ 9.4339e+00, -8.6103e-03,  8.3608e-01,  ..., -1.1780e-01,
          -2.7016e-02,  8.2840e+00],
         [ 5.4225e+00, -2.9670e-03,  4.0961e-01,  ..., -7.2212e-02,
          -1.9560e-02,  4.5775e+00],
         [ 4.9338e+00, -1.0783e-02, -5.2104e-04,  ..., -6.6422e-02,
          -2.0246e-02,  4.5505e+00],
         ...,
         [ 1.9363e+00,  2.6131e+00,  6.0047e-01,  ..., -1.2820e-02,
           4.9902e+00, -3.0345e-02],
         [ 2.8265e+00, -1.9177e-02, -6.5645e-03,  ..., -4.4651e-02,
          -2.6763e-02,  3.8141e+00],
         [ 3.5955e+00, -5.8756e-03,  1.0961e-01,  ..., 

Testing:   0%|          | 43/20000 [00:01<09:17, 35.82it/s]

tensor([[[ 0.3348, -0.1241, -5.1993, -0.9310, -5.1993],
         [ 0.3348, -0.1241, -3.0524, -0.5091, -3.0524],
         [ 0.3348, -0.1241, -2.8380, -1.0669, -2.8380],
         ...,
         [ 0.5953,  0.9114,  0.3696, -2.2418,  0.1817],
         [ 0.0441, -0.3318,  0.2340,  0.1759,  0.0441],
         [ 0.5953, -0.5750, -1.3498, -1.7810, -1.2381]]], device='cuda:0') shape:  torch.Size([1, 882, 5])
Set: tensor([[[ 1.0141e+01, -1.3660e-02,  8.1823e-01,  ..., -1.2682e-01,
          -4.5933e-03,  8.4853e+00],
         [ 5.5525e+00, -1.0402e-02,  3.0178e-01,  ..., -7.8548e-02,
          -2.8124e-03,  4.8542e+00],
         [ 5.1466e+00,  3.6285e-01,  9.4078e-01,  ..., -7.5371e-02,
          -1.1801e-03,  3.9274e+00],
         ...,
         [ 1.2432e+00,  6.2229e+00,  2.9537e+00,  ..., -2.5575e-02,
           3.9009e+00, -3.3560e-02],
         [-1.4813e-02, -8.8044e-03, -7.5881e-03,  ..., -9.6138e-04,
          -1.4146e-02, -2.4643e-03],
         [ 5.5984e-01,  2.1280e+00,  1.5560e+00,  ..., 



Set: tensor([[[ 0.1575,  0.3942, -5.1993, -2.7340, -5.1993],
         [-0.1880,  0.3942, -3.0778, -0.6902, -3.0778],
         [ 0.1575,  0.0628, -2.8649, -0.8372, -2.8649],
         ...,
         [-0.5130,  0.3942, -1.7026, -1.9420, -1.6545],
         [ 0.1575,  0.3942, -1.7430, -1.9965, -1.7312],
         [ 0.1575,  0.0628, -1.7737, -1.0313, -1.7674]]], device='cuda:0') shape:  torch.Size([1, 960, 5])
Set: tensor([[[ 1.1708e+01,  3.6338e+00,  3.1496e+00,  ..., -1.3584e-01,
           8.1725e-01,  6.4687e+00],
         [ 7.2379e+00,  4.5667e-01,  6.5427e-01,  ..., -7.9716e-02,
          -3.5474e-03,  4.3415e+00],
         [ 5.7580e+00,  2.1057e-01,  7.3181e-01,  ..., -7.5495e-02,
          -1.8819e-03,  4.0561e+00],
         ...,
         [ 4.5387e+00,  4.3793e+00,  1.9716e+00,  ..., -4.8692e-02,
          -7.5326e-03,  4.2777e-01],
         [ 4.2946e+00,  4.0327e+00,  2.2513e+00,  ..., -5.7827e-02,
           1.0936e+00,  6.5520e-01],
         [ 3.4499e+00,  1.3137e+00,  9.3982e-01,  



Set: tensor([[[ 0.5279,  0.3749, -5.1993, -0.4208, -5.1993],
         [ 0.5279,  0.3749, -3.0644, -0.0314, -3.0644],
         [ 0.5279,  0.6197, -2.8507, -0.4327, -2.8507],
         ...,
         [-0.2023,  0.7530, -1.8890, -0.2471, -1.8423],
         [ 0.5279,  0.6197, -2.2236, -0.9755, -2.1924],
         [ 0.1537,  0.3749, -2.6679, -0.7066, -2.6227]]], device='cuda:0') shape:  torch.Size([1, 918, 5])
Set: tensor([[[ 1.1236e+01, -1.9119e-02,  5.9515e-01,  ..., -1.3351e-01,
           1.2302e+00,  8.7295e+00],
         [ 6.6757e+00, -1.5189e-02,  1.1633e-01,  ..., -8.5589e-02,
           1.4146e+00,  5.0902e+00],
         [ 6.8748e+00, -1.1047e-03,  7.2342e-01,  ..., -8.4889e-02,
           2.1680e+00,  4.1766e+00],
         ...,
         [ 5.5067e+00,  7.3227e-01,  3.4118e-01,  ..., -5.5945e-02,
           5.9036e-01,  2.2383e+00],
         [ 5.5361e+00,  1.5267e+00,  1.3400e+00,  ..., -7.2255e-02,
           2.3916e+00,  2.4299e+00],
         [ 6.0355e+00,  5.2745e-01,  7.6778e-01,  



tensor([[[-3.5500e-01,  3.1999e-01,  5.0737e+00,  ..., -3.6957e-01,
          -3.6112e-01,  1.4629e+01],
         [-1.8470e-01,  3.0774e-01,  3.3215e+00,  ..., -1.7203e-01,
          -1.7306e-01,  7.4864e+00],
         [-1.7534e-01,  1.8181e+00,  2.6427e+00,  ..., -1.7747e-01,
          -1.6270e-01,  5.7307e+00],
         ...,
         [-1.2020e-01, -5.2923e-01,  4.9082e+00,  ...,  2.4580e+01,
          -2.3314e-01,  3.4162e+00],
         [-1.2369e-01, -3.8683e-02, -8.5761e-03,  ..., -1.2739e-01,
          -2.3816e-01, -1.1592e-02],
         [-1.1977e-01, -1.3569e-01, -6.9877e-03,  ..., -5.1414e-02,
          -2.5027e-01, -7.8683e-03]]], device='cuda:0') shape:  torch.Size([1, 817, 120])
Set: tensor([[[-8.7432e-01,  6.3653e+01, -1.5486e+00,  ...,  6.4272e+01,
           3.4016e+01, -4.1419e+00],
         [-5.6048e-01,  2.7754e+01, -7.3602e-01,  ...,  2.1231e+01,
           2.0765e+01, -2.0431e+00],
         [-5.0383e-01,  2.1366e+01, -6.7279e-01,  ...,  1.2780e+01,
           2.5063e+0

Testing:   0%|          | 72/20000 [00:02<09:40, 34.32it/s]

Set: tensor([[[-0.1300, -0.1010, -5.1993, -1.1889, -5.1993],
         [-0.1300, -0.1010, -3.0486, -0.6090, -3.0486],
         [-0.1300, -0.1010, -2.8339, -0.5815, -2.8339],
         ...,
         [-1.0488, -0.3745, -0.8623, -0.0663, -1.2621],
         [ 0.1796,  0.2237, -1.4749, -1.5191, -1.7095],
         [ 0.5047,  0.2237, -0.8623, -0.2653, -1.2370]]], device='cuda:0') shape:  torch.Size([1, 871, 5])
Set: tensor([[[ 1.0495e+01, -4.4896e-03,  9.7715e-01,  ..., -1.2244e-01,
          -1.6221e-02,  8.1386e+00],
         [ 5.8848e+00, -4.8193e-03,  2.7869e-01,  ..., -7.3650e-02,
          -1.4819e-02,  4.6490e+00],
         [ 5.4272e+00, -4.1584e-03,  2.4393e-01,  ..., -6.8863e-02,
          -1.4605e-02,  4.2721e+00],
         ...,
         [ 1.7848e+00, -4.9681e-03, -8.5115e-03,  ..., -1.5047e-02,
          -4.6947e-02,  1.7920e+00],
         [ 3.6860e+00,  2.6753e+00,  1.5946e+00,  ..., -5.1697e-02,
           5.5621e-01,  1.1722e+00],
         [ 2.3372e+00, -1.1832e-03,  2.4409e-01,  



Set: tensor([[[-0.1516, -0.2086, -5.1993, -0.4663, -5.1993],
         [-0.1516,  0.0498, -3.0252, -0.1832, -3.0252],
         [-0.1516,  0.0498, -2.8090,  0.3117, -2.8090],
         ...,
         [-0.1516,  0.3003, -1.9521, -0.5155, -1.9216],
         [-0.4594,  0.1974, -2.0436, -0.0545, -2.0828],
         [-0.4594,  0.3003, -2.0436, -0.5190, -2.0312]]], device='cuda:0') shape:  torch.Size([1, 806, 5])
Set: tensor([[[ 1.0176e+01, -2.2467e-02,  7.4749e-02,  ..., -1.1899e-01,
          -2.1313e-02,  8.8737e+00],
         [ 6.1929e+00, -1.2042e-02, -1.2600e-03,  ..., -7.3549e-02,
          -1.2530e-02,  4.9181e+00],
         [ 5.6913e+00, -2.2049e-02, -6.9879e-03,  ..., -6.7448e-02,
          -1.3448e-02,  4.9778e+00],
         ...,
         [ 4.5227e+00,  5.9128e-01,  3.9353e-01,  ..., -5.3513e-02,
          -3.7993e-03,  2.3954e+00],
         [ 4.7271e+00, -5.2097e-03, -2.9964e-03,  ..., -5.0323e-02,
          -1.6394e-02,  3.1297e+00],
         [ 4.9334e+00,  7.2677e-01,  2.9871e-01,  



Set: tensor([[[-7.2824e-01,  3.8812e+01, -1.5783e+00,  ...,  7.9932e+01,
           1.8230e+01, -3.9442e+00],
         [-4.4257e-01,  7.1268e+00, -7.1734e-01,  ...,  5.2117e+00,
          -2.3375e-01, -1.0798e+00],
         [-2.9075e-01, -6.6142e-02, -9.6341e-01,  ...,  1.9431e+01,
           1.6011e+01, -2.2529e+00],
         ...,
         [-1.9626e-01, -4.2562e-01, -3.3128e-01,  ..., -2.2735e-01,
          -2.6312e-02, -8.3841e-01],
         [-1.7208e-01, -4.0421e-01, -2.6141e-01,  ..., -2.2473e-01,
           1.9486e-01, -7.7129e-01],
         [-3.6037e-01,  4.5446e+00, -5.9785e-01,  ...,  6.7291e+00,
           2.2305e+01, -1.8529e+00]]], device='cuda:0') shape:  torch.Size([1, 915, 70])
Set: tensor([[[-2.3832e+00,  4.2523e+01,  1.3582e+02,  ...,  1.3204e+02,
          -2.0088e+00, -1.3882e+00],
         [-5.6608e-01, -3.5143e-01,  3.9625e+00,  ...,  6.3966e-01,
          -5.7002e-01, -3.9158e-01],
         [-1.3532e+00, -2.4987e-01, -3.4157e-01,  ...,  2.8565e+01,
          -8.525



Set: tensor([[[ 0.0636, -0.1836, -5.1993, -0.7045, -5.1993],
         [ 0.0636, -0.1836, -3.0114, -0.8547, -3.0114],
         [ 0.0636, -0.1836, -2.7943, -0.9477, -2.7943],
         ...,
         [-1.3795, -0.4159, -0.0065,  1.1448,  0.3788],
         [-1.4693, -0.4159, -0.0065, -0.3440,  0.4427],
         [-1.4693, -0.4159,  0.1111,  1.0000,  0.4936]]], device='cuda:0') shape:  torch.Size([1, 770, 5])
Set: tensor([[[ 1.0128e+01, -1.8027e-02,  4.3206e-01,  ..., -1.2236e-01,
          -1.4185e-02,  8.6778e+00],
         [ 5.5026e+00, -1.4522e-03,  5.7361e-01,  ..., -7.4756e-02,
          -1.0963e-02,  4.4317e+00],
         [ 5.0505e+00,  1.9756e-01,  6.7751e-01,  ..., -7.0247e-02,
          -1.0455e-02,  3.9371e+00],
         ...,
         [-1.4644e-02, -2.0899e-02, -2.3697e-02,  ...,  2.0388e+00,
          -5.6433e-02, -2.2243e-03],
         [-1.3962e-02,  1.4102e+00, -6.8313e-03,  ...,  1.8073e+00,
          -5.5097e-02, -1.7626e-02],
         [-1.6423e-02, -1.6336e-02, -2.2339e-02,  



tensor([[[-5.3442e-01,  2.3140e+01, -1.3416e+00,  ...,  6.8821e+01,
           1.4374e+01, -3.5086e+00],
         [-2.0850e-01, -1.2045e-01, -4.1011e-01,  ...,  1.8460e+01,
           5.1719e+00, -1.6965e+00],
         [-2.2063e-01, -1.4761e-01, -3.4491e-01,  ...,  1.1693e+01,
          -8.0971e-02, -1.1533e+00],
         ...,
         [-1.5104e-02, -4.6337e-02, -9.5703e-01,  ..., -3.6356e-02,
           7.5291e+00, -1.3359e+00],
         [-4.3292e-02, -2.5607e-01,  9.8233e+00,  ..., -8.2601e-01,
          -8.8432e-01, -9.8066e-02],
         [ 1.1104e+00, -2.5767e-01,  6.5616e+00,  ..., -2.8998e-01,
          -1.4800e-01, -4.5399e-01]]], device='cuda:0') shape:  torch.Size([1, 910, 70])
Set: tensor([[[-1.9723e+00,  3.4532e+01,  1.0574e+02,  ...,  1.2041e+02,
          -1.7109e+00, -1.1775e+00],
         [-8.7870e-01, -3.7457e-01,  9.5036e+00,  ...,  4.5862e+01,
          -1.2074e+00, -6.9641e-01],
         [-5.6641e-01, -3.9614e-01, -1.2126e-02,  ...,  2.8744e+01,
          -7.9916e-01



Set: tensor([[[ 1.2583e-03,  2.4012e-01, -5.1993e+00, -1.1827e+00, -5.1993e+00],
         [ 1.2583e-03,  2.4012e-01, -2.4823e+00, -4.1971e-01, -3.0890e+00],
         [ 1.2583e-03,  2.4012e-01, -2.1819e+00,  1.5200e+00, -2.6507e+00],
         ...,
         [ 1.4063e+00, -4.7253e-01,  1.5612e+00,  8.5678e-01,  1.3700e+00],
         [ 1.4063e+00, -2.5959e-01,  1.5612e+00,  1.1980e+00,  1.3830e+00],
         [ 1.4063e+00, -1.1983e-01,  1.5321e+00,  1.3700e+00,  1.3962e+00]]],
       device='cuda:0') shape:  torch.Size([1, 997, 5])
Set: tensor([[[ 1.1275e+01, -4.1214e-04,  1.2182e+00,  ..., -1.2794e-01,
          -3.8777e-03,  7.9813e+00],
         [ 6.5455e+00, -5.5019e-03,  2.8513e-01,  ..., -7.3224e-02,
          -4.1388e-03,  4.7638e+00],
         [ 5.4895e+00, -4.7108e-02, -1.9486e-02,  ..., -5.9721e-02,
          -8.0041e-03,  5.7598e+00],
         ...,
         [-5.5254e-02, -2.7091e-02, -1.2075e-02,  ...,  1.6122e+00,
           1.9883e+00, -1.7441e-02],
         [-5.0420e-02, -3.16

Testing:   1%|          | 116/20000 [00:03<08:59, 36.88it/s]

tensor([[[-0.1987,  5.7532,  7.2421,  ..., -0.4298, -0.4146, 18.4608],
         [-0.0893, -0.1583, 10.2966,  ..., -0.0683, -0.2672, 12.1989],
         [-0.0791,  3.4745,  3.5020,  ..., -0.2170, -0.2175,  7.5651],
         ...,
         [ 4.1962, -0.0822,  5.9529,  ..., -0.1221, -0.2630,  8.2590],
         [ 4.2682, -0.0408,  5.1933,  ..., -0.1633, -0.2534,  7.6785],
         [-0.0676,  5.6175,  2.6109,  ..., -0.2100, -0.1917,  5.5523]]],
       device='cuda:0') shape:  torch.Size([1, 955, 120])
Set: tensor([[[-7.2594e-01,  3.9063e+01, -1.7712e+00,  ...,  9.1457e+01,
           1.9312e+01, -4.4269e+00],
         [-3.6489e-01, -8.7653e-03, -6.8795e-01,  ..., -1.1962e-01,
          -5.9361e-01, -6.7816e-01],
         [-3.0958e-01, -5.3720e-02, -5.8833e-01,  ...,  2.1408e+01,
           1.4512e+01, -2.0073e+00],
         ...,
         [-4.1889e-01,  1.6916e+01, -2.8194e-01,  ..., -2.5197e-01,
           1.1573e+01, -1.7818e+00],
         [-3.0179e-01,  2.6547e+01, -3.6104e-01,  ..., -2.176



tensor([[[-0.1305, -0.3129, -5.1993, -0.9744, -5.1993],
         [-0.1305,  0.1203, -3.0476, -0.5684, -3.0476],
         [-0.1305,  0.1203, -2.8328, -0.4659, -2.8328],
         ...,
         [-0.1305, -0.3129, -1.5949, -2.1407, -1.6053],
         [-0.4419,  0.4530, -1.8481, -1.0317, -1.8481],
         [-0.4419, -0.3129, -2.3338, -0.5214, -2.3123]]], device='cuda:0') shape:  torch.Size([1, 868, 5])
Set: tensor([[[ 9.9419e+00, -1.2537e-02,  6.0285e-01,  ..., -1.1938e-01,
          -2.2181e-02,  8.4593e+00],
         [ 6.4363e+00, -2.4590e-03,  3.6678e-01,  ..., -7.6155e-02,
          -9.2398e-03,  4.5645e+00],
         [ 5.9721e+00, -3.5103e-03,  2.4572e-01,  ..., -7.1162e-02,
          -9.2078e-03,  4.2580e+00],
         ...,
         [ 2.4209e+00,  3.5693e+00,  1.8936e+00,  ..., -4.3596e-02,
          -1.4679e-02,  6.2040e-01],
         [ 4.9582e+00,  2.2235e+00,  9.8349e-01,  ..., -5.1462e-02,
          -6.6062e-03,  1.6323e+00],
         [ 3.9740e+00, -3.5670e-03, -6.0079e-04,  ..., 

Testing:   1%|          | 128/20000 [00:04<08:57, 36.95it/s]

tensor([[[-2.3983e+00,  2.5760e+01,  1.1215e+02,  ...,  1.1512e+02,
          -2.0793e+00, -1.3133e+00],
         [-1.0113e+00, -2.9132e-01, -8.3690e-01,  ...,  1.5186e+01,
          -2.7645e-01, -5.6537e-02],
         [-8.9019e-01, -4.4829e-01, -4.2439e-01,  ..., -1.4895e-01,
          -5.6434e-01, -2.9185e-01],
         ...,
         [-4.9959e-01, -4.4959e-01, -4.8174e-01,  ..., -4.9729e-01,
          -7.1652e-01, -3.8327e-01],
         [-3.6690e-01, -8.0802e-01, -9.7981e-01,  ..., -2.8561e-01,
          -2.8808e-01, -4.9142e-03],
         [-3.3975e-01, -6.8011e-01, -7.5823e-01,  ..., -3.3568e-01,
          -2.7378e-01,  2.0967e+00]]], device='cuda:0') shape:  torch.Size([1, 713, 50])
Set: tensor([[[ 140.6132,  -13.0611,   -7.3422,  ..., -136.8756,   37.7941,
           140.4220],
         [   7.4616,    3.1918,   -2.6725,  ...,   -7.6983,   -3.1810,
            11.2570],
         [  -8.4005,   -1.1930,    2.0703,  ...,   -3.3435,   -1.5327,
             9.9970],
         ...,
      

Testing:   1%|          | 132/20000 [00:04<12:27, 26.59it/s]

tensor([[[-0.7290, 39.8104, -1.5816,  ..., 69.6512, 28.8808, -4.2454],
         [-0.4823,  8.1092, -0.6400,  ..., 15.2980, 12.8241, -1.8222],
         [-0.4348,  4.4111, -0.5673,  ...,  7.9611, 16.0038, -1.7432],
         ...,
         [-0.2997, -0.0841, -0.6966,  ...,  4.6311, 21.5418, -1.9394],
         [-0.3406,  2.2677, -0.3757,  ..., -0.0760,  0.8489, -1.1393],
         [-0.2768, -0.1369, -0.3770,  ..., -0.4152, -0.4932, -0.0814]]],
       device='cuda:0') shape:  torch.Size([1, 870, 70])
Set: tensor([[[-2.6445e+00,  3.4867e+01,  1.1320e+02,  ...,  1.2379e+02,
          -1.9731e+00, -1.3040e+00],
         [-1.2054e+00, -4.1454e-01,  3.1277e+00,  ...,  2.9371e+01,
          -1.4262e+00, -6.8700e-01],
         [-1.1721e+00, -5.0285e-01, -6.1491e-02,  ...,  2.4867e+01,
          -1.4055e+00, -6.6627e-01],
         ...,
         [-1.3738e+00, -4.8131e-01, -5.2998e-01,  ...,  5.5480e+00,
          -1.0336e+00, -3.6048e-01],
         [-6.0197e-01, -3.6576e-01, -1.5155e-01,  ...,  1.0107



tensor([[[-7.7408e-01,  4.2812e+01, -1.6094e+00,  ...,  7.7037e+01,
           2.1971e+01, -4.1249e+00],
         [-4.8565e-01,  1.7280e+01, -7.7868e-01,  ...,  2.6021e+01,
           2.4852e+01, -2.3413e+00],
         [-3.8366e-01,  5.3397e+00, -9.2991e-01,  ...,  2.4959e+01,
           2.4161e+01, -2.5463e+00],
         ...,
         [-2.0782e-01, -8.7708e-02, -3.6791e-01,  ..., -4.9217e-02,
           1.9605e+01, -1.4767e+00],
         [-1.6070e-01, -1.7855e-01, -5.1578e-01,  ..., -8.1462e-02,
           1.7588e+01, -1.4776e+00],
         [-4.3905e-01, -9.7723e-02, -7.8922e-02,  ..., -1.1056e+00,
          -1.6229e+00, -1.4388e-01]]], device='cuda:0') shape:  torch.Size([1, 731, 70])
Set: tensor([[[ -2.5261,  37.0674, 125.0045,  ..., 126.3762,  -2.0625,  -1.3606],
         [ -1.5686,  -0.4515,   2.4631,  ...,  23.4934,  -1.6779,  -0.7324],
         [ -1.5946,  -0.2951,  -0.2650,  ...,  23.6240,  -1.3027,  -0.5938],
         ...,
         [ -0.8869,  -0.6067,  -0.1876,  ...,  30.8799

Testing:   1%|          | 136/20000 [00:04<18:16, 18.12it/s]

tensor([[[ 1.3398e+02,  7.9151e+00,  2.1844e+00,  ..., -1.4530e+02,
           3.0623e+01,  1.3688e+02],
         [ 1.1332e+02, -9.3737e+01, -2.2573e+01,  ..., -4.9517e+01,
           3.8249e+01,  1.2938e+02],
         [ 1.0039e+02, -7.5617e+01, -1.7072e+01,  ..., -3.0156e+01,
           2.7093e+01,  1.1216e+02],
         ...,
         [-4.7367e+00, -2.3490e+00, -3.4641e+00,  ..., -1.1877e+00,
          -5.4104e+00,  8.8848e+00],
         [-1.3917e+01, -8.5134e+00,  7.0424e+00,  ...,  8.3603e-02,
          -7.8485e+00, -1.0860e+01],
         [-1.2426e+01, -1.0293e+00, -1.3163e-01,  ..., -1.3673e+00,
          -5.9398e+00,  1.8639e+00]]], device='cuda:0') shape:  torch.Size([1, 921, 8])
Set: tensor([[[ 0.0365, -0.1696, -5.1993, -0.4464, -5.1993],
         [ 0.0365,  0.1944, -3.0323, -0.4905, -3.0323],
         [ 0.0365,  0.1944, -2.8165, -1.2437, -2.8165],
         ...,
         [ 1.3235, -2.1061,  1.3162,  1.1157,  1.2706],
         [ 1.3843, -2.1061,  1.3162,  1.2774,  1.2913],
      



tensor([[[-2.3803e+00,  2.9056e+01,  1.1433e+02,  ...,  1.1807e+02,
          -1.9322e+00, -1.3279e+00],
         [-5.7881e-01, -3.1807e-01, -1.2570e-02,  ..., -1.9818e-02,
          -5.1771e-01, -3.3952e-01],
         [-8.5968e-01, -4.4603e-01, -4.6461e-02,  ...,  3.3508e+01,
          -1.0853e+00, -6.4453e-01],
         ...,
         [-3.7885e-01, -4.6713e-01, -5.5156e-01,  ..., -5.9224e-01,
          -3.9702e-01, -4.3431e-02],
         [-3.1990e-01, -4.2048e-01, -4.6001e-01,  ..., -6.9740e-01,
          -4.4271e-01, -7.6942e-02],
         [-3.2913e-01, -4.5666e-01, -4.7058e-01,  ..., -7.3033e-01,
          -4.5300e-01, -1.3332e-01]]], device='cuda:0') shape:  torch.Size([1, 824, 50])
Set: tensor([[[ 1.3243e+02,  1.6587e-01, -3.1302e+00,  ..., -1.3562e+02,
           3.7423e+01,  1.3538e+02],
         [-9.6440e-01,  1.0306e+01, -1.2392e+00,  ...,  2.5964e+00,
           6.8804e+00,  1.0857e+01],
         [ 1.1472e+02, -5.2815e+01, -2.9261e+01,  ..., -3.2581e+01,
           2.9302e+01



tensor([[[-1.9966e+00,  3.0861e+01,  1.2094e+02,  ...,  1.1250e+02,
          -1.9095e+00, -1.3036e+00],
         [-1.2631e+00, -4.4497e-01, -1.1407e-01,  ...,  2.5424e+01,
          -1.3301e+00, -6.0042e-01],
         [-1.1663e+00, -5.2138e-01, -1.7531e-01,  ...,  2.1241e+01,
          -1.2536e+00, -5.9679e-01],
         ...,
         [-9.4336e-01, -4.0278e-01, -3.3403e-01,  ...,  1.3828e+01,
          -7.9325e-01, -2.8561e-01],
         [ 6.2992e+01,  3.0246e+01,  7.2945e+01,  ..., -1.4470e+00,
          -5.6047e-01, -1.2830e+00],
         [-9.8891e-01, -5.9246e-01, -8.3249e-02,  ...,  3.0922e+01,
          -1.3135e+00, -7.6321e-01]]], device='cuda:0') shape:  torch.Size([1, 895, 50])
Set: tensor([[[ 190.8667,   -2.8899,  -26.4989,  ..., -128.0964,   40.8037,
           196.6759],
         [ 113.2562,  -74.5487,  -34.0338,  ...,  -40.0781,   39.8794,
           123.2198],
         [ 121.0154,  -69.2052,  -32.5932,  ...,  -32.6269,   36.6484,
           127.6864],
         ...,
      



tensor([[[-2.8370e-01,  2.8954e+00,  3.3415e+00,  ..., -3.8132e-01,
          -3.4277e-01,  1.0481e+01],
         [-1.4660e-01,  1.4851e+00,  1.7133e+00,  ..., -1.8402e-01,
          -1.7137e-01, -2.7677e-03],
         [-1.8847e-01, -3.7880e-02,  3.5187e+00,  ..., -1.1954e-01,
          -1.6563e-01,  4.9584e-02],
         ...,
         [-1.0768e-01, -1.0654e-01,  4.9359e+00,  ..., -2.1600e-02,
          -1.3584e-01, -2.4562e-02],
         [-1.1116e-01,  3.0268e+00, -5.4158e-03,  ..., -1.3005e-01,
          -1.0743e-01, -4.9733e-02],
         [-1.4225e-01, -9.8930e-02,  4.1080e+00,  ..., -1.7126e-02,
          -1.2236e-01, -2.6740e-02]]], device='cuda:0') shape:  torch.Size([1, 775, 120])
Set: tensor([[[-6.0627e-01,  3.5406e+01, -1.3783e+00,  ...,  5.5039e+01,
           3.0413e+01, -3.7957e+00],
         [-2.3644e-01, -4.5255e-02, -4.0414e-01,  ..., -1.5543e-02,
           1.4466e+01, -1.4032e+00],
         [-3.0086e-01,  1.9229e+00, -4.1780e-01,  ..., -6.2306e-02,
          -7.7035e-0



tensor([[[-2.4188e-01, -1.1069e-02,  1.2106e+01,  ..., -3.8541e-01,
          -4.7570e-01,  2.8879e+01],
         [-1.4788e-01, -2.4351e-01,  1.4271e+01,  ..., -1.7991e-02,
          -3.3219e-01,  1.9260e+01],
         [-1.0119e-01,  6.1991e+00,  5.4259e+00,  ..., -3.1102e-01,
          -2.7535e-01,  9.6983e+00],
         ...,
         [-1.0136e-01,  6.1324e+00,  5.2649e+00,  ..., -3.0701e-01,
          -2.5595e-01,  1.0941e+01],
         [-9.0958e-02,  6.0150e+00,  5.7018e+00,  ..., -3.1499e-01,
          -2.6070e-01,  9.3497e+00],
         [-8.4708e-02,  5.4887e+00,  5.6988e+00,  ..., -3.0011e-01,
          -2.6957e-01,  1.0092e+01]]], device='cuda:0') shape:  torch.Size([1, 908, 120])
Set: tensor([[[ -1.1411,  94.5566,  -1.9188,  ..., 118.6973,  25.8465,  -4.9324],
         [ -0.6381,  48.9092,  -0.9809,  ...,  -0.1285,  -0.7694,  -1.2599],
         [ -0.4847,  45.6821,  -1.1845,  ...,  39.5890,  28.9658,  -3.2670],
         ...,
         [ -0.4952,  66.0502,  -1.2512,  ...,  36.640



tensor([[[ 9.8741e+00, -2.0844e-02,  4.5434e-01,  ..., -1.2462e-01,
          -8.4864e-03,  8.7764e+00],
         [ 5.4846e+00, -1.1762e-02,  6.7344e-02,  ..., -7.3163e-02,
          -1.3890e-02,  4.9707e+00],
         [ 5.0581e+00, -3.4763e-03,  4.1733e-01,  ..., -6.9256e-02,
          -1.2884e-02,  4.2835e+00],
         ...,
         [-1.2954e-03,  5.6717e+00,  9.8666e-01,  ...,  2.3863e+00,
          -3.8632e-02, -4.2469e-02],
         [ 3.6634e+00, -1.5500e-03,  3.0654e-01,  ..., -5.4063e-02,
           4.8349e-01,  2.5306e+00],
         [ 3.0270e+00, -5.6792e-03,  4.5468e-03,  ..., -4.8144e-02,
          -1.2335e-02,  2.8174e+00]]], device='cuda:0') shape:  torch.Size([1, 971, 80])
Set: tensor([[[-2.5224e-01,  4.0984e+00,  5.3548e+00,  ..., -4.0449e-01,
          -3.7598e-01,  1.5817e+01],
         [-1.3426e-01,  5.4710e+00,  1.6130e+00,  ..., -2.3104e-01,
          -1.8931e-01,  4.0660e+00],
         [-1.2516e-01,  3.2193e+00,  2.3848e+00,  ..., -1.8543e-01,
          -1.6995e-01



tensor([[[ 0.3704,  0.2096, -5.1993, -1.6741, -5.1993],
         [ 0.1629, -0.0426, -2.5603, -0.0935, -3.0098],
         [ 0.1629, -0.0426, -2.3095, -0.4488, -2.5603],
         ...,
         [-2.7196, -0.6269,  1.1511,  1.5321,  1.0129],
         [ 0.3704, -0.0426, -2.5603, -2.4161, -2.7926],
         [-0.1265, -0.0426, -2.5603, -0.7088, -2.6587]]], device='cuda:0') shape:  torch.Size([1, 766, 5])
Set: tensor([[[ 1.1021e+01,  8.0224e-01,  1.8837e+00,  ..., -1.3311e-01,
           6.6946e-01,  7.6064e+00],
         [ 5.5898e+00, -1.7617e-02, -2.0536e-03,  ..., -7.0704e-02,
          -7.3395e-03,  5.1009e+00],
         [ 4.7296e+00, -6.6913e-03,  2.0519e-01,  ..., -6.3881e-02,
          -5.5090e-03,  3.9158e+00],
         ...,
         [-2.7366e-02, -2.0905e-02, -3.4099e-02,  ...,  5.8709e+00,
          -9.9979e-02, -1.1801e-02],
         [ 5.2760e+00,  3.5562e+00,  2.5414e+00,  ..., -7.6941e-02,
           4.7191e-01,  2.5423e+00],
         [ 5.1810e+00,  5.6900e-02,  4.1835e-01,  ..., 

Testing:   1%|          | 149/20000 [00:06<34:50,  9.49it/s]

tensor([[[ 189.1541,   -6.4793,  -17.9910,  ..., -149.9098,   45.9284,
           187.6411],
         [   3.0106,   -1.5515,   -6.5509,  ...,    2.5746,    7.4496,
            14.6215],
         [  -1.3062,   -2.3478,   -2.8019,  ...,    1.0615,    3.2940,
            12.5202],
         ...,
         [ 103.9292,   40.2705,    1.1321,  ...,   51.2802,    2.6409,
            21.0508],
         [ -11.3780,   -1.3953,   -0.4319,  ...,    0.3194,   -3.4061,
             0.6624],
         [  84.1409,  -38.0128,  -10.4156,  ...,  -23.3151,   12.9861,
            77.6857]]], device='cuda:0') shape:  torch.Size([1, 807, 8])
Set: tensor([[[ 0.0457, -0.1124, -5.1993,  2.2185, -2.8465],
         [-0.1096, -0.0138, -5.1993,  1.9346, -5.1993],
         [ 0.0457, -0.0138, -5.1993,  2.4767, -2.7149],
         ...,
         [-0.1096, -0.6529, -0.4757,  0.7035, -0.6581],
         [-0.1096, -0.5694, -0.4757, -1.4081, -0.6823],
         [ 0.2777, -0.3756, -0.8941,  0.0874, -1.0838]]], device='cuda:0') sha



tensor([[[-6.7359e-01,  3.6496e+01, -1.6883e+00,  ...,  6.8544e+01,
           2.8864e+01, -4.2533e+00],
         [-2.9336e-01, -6.3042e-02, -8.2539e-01,  ...,  4.5537e+00,
           2.2018e+01, -1.9860e+00],
         [-3.2661e-01,  5.3572e+00, -1.0573e+00,  ...,  1.8893e+01,
           2.2269e+01, -2.4196e+00],
         ...,
         [-3.7138e-02, -4.6502e-01, -7.8388e-01,  ..., -1.4661e-01,
          -2.5169e-01, -1.0065e+00],
         [-1.1595e-01, -4.9707e-01, -9.1546e-01,  ..., -1.1013e-01,
          -3.0299e-01, -9.3603e-01],
         [-1.1909e-01, -5.0421e-01, -9.6438e-01,  ..., -7.5901e-02,
          -3.3205e-01, -9.2329e-01]]], device='cuda:0') shape:  torch.Size([1, 897, 70])
Set: tensor([[[ -2.5789,  38.4481,  98.9546,  ..., 118.3990,  -1.7009,  -1.2808],
         [ -1.2164,  -0.4081,  -0.6367,  ...,   5.7565,  -0.6997,  -0.3542],
         [ -1.2668,  -0.2445,  -0.7190,  ...,  15.2062,  -0.5267,  -0.4648],
         ...,
         [ -0.2822,  -0.4776,  -0.4177,  ...,  -0.3957

Testing:   1%|          | 153/20000 [00:06<35:13,  9.39it/s]

tensor([[[-1.8882e+00,  4.0178e+01,  1.3326e+02,  ...,  1.2896e+02,
          -1.9125e+00, -1.3215e+00],
         [-9.9394e-01, -4.6853e-01,  1.1375e+01,  ...,  4.7334e+01,
          -1.4366e+00, -8.4822e-01],
         [-9.1031e-01, -5.3470e-01,  3.1211e+00,  ...,  4.1205e+01,
          -1.3322e+00, -8.0583e-01],
         ...,
         [-7.1092e-01, -3.2763e-01, -5.9331e-01,  ..., -3.0901e-01,
          -6.3688e-01, -6.2358e-02],
         [-7.0387e-01, -3.3664e-01, -5.8428e-01,  ..., -3.3861e-01,
          -6.7442e-01, -6.2191e-02],
         [-5.5416e-01, -4.1841e-01, -1.4172e-01,  ...,  9.6780e+00,
          -6.5593e-01, -2.4412e-01]]], device='cuda:0') shape:  torch.Size([1, 853, 50])
Set: tensor([[[ 2.0373e+02,  3.4897e+00, -3.5122e+01,  ..., -1.4228e+02,
           4.8964e+01,  2.1359e+02],
         [ 1.8529e+02, -8.4542e+01, -4.5427e+01,  ..., -4.8884e+01,
           5.4665e+01,  1.9745e+02],
         [ 1.7792e+02, -7.4490e+01, -4.3618e+01,  ..., -4.0085e+01,
           5.0483e+01

Testing:   1%|          | 155/20000 [00:07<36:06,  9.16it/s]

tensor([[[ -2.4845,  39.8389, 129.0441,  ..., 134.3156,  -2.0203,  -1.3903],
         [ -1.4146,  -0.3081,  -0.1451,  ...,  37.9287,  -1.1888,  -0.5727],
         [ -1.0252,  -0.4279,   6.5722,  ...,  41.2519,  -1.4064,  -0.7459],
         ...,
         [ -0.7190,  -0.3647,  -0.5906,  ...,   3.3802,  -0.4524,  -0.1489],
         [ -0.7891,  -0.7778,  -0.9207,  ...,  -0.4690,  -0.6918,   8.0359],
         [ -0.8162,  -0.7840,  -0.9290,  ...,  -0.3367,  -0.5720,  18.4005]]],
       device='cuda:0') shape:  torch.Size([1, 951, 50])
Set: tensor([[[ 1.4362e+02,  8.2665e+00, -3.5782e+00,  ..., -1.5032e+02,
           3.7289e+01,  1.3972e+02],
         [ 7.3415e+01, -5.4974e+01, -2.6405e+01,  ..., -3.9192e+01,
           3.0709e+01,  8.5013e+01],
         [ 1.5846e+02, -7.1536e+01, -3.8458e+01,  ..., -4.6785e+01,
           3.9705e+01,  1.6602e+02],
         ...,
         [-8.9588e+00, -4.9243e-02, -6.7418e-01,  ..., -7.9959e-01,
          -2.4829e+00,  9.7554e-01],
         [-1.9793e+01,  3.

Testing:   1%|          | 157/20000 [00:07<36:57,  8.95it/s]

tensor([[[-7.8102e-01,  4.2210e+01, -1.6088e+00,  ...,  8.3144e+01,
           1.6954e+01, -3.9425e+00],
         [-4.4938e-01,  1.1321e+01, -7.3487e-01,  ...,  1.7129e+01,
           2.6071e+01, -2.1676e+00],
         [-4.1366e-01,  8.0246e+00, -6.6154e-01,  ...,  1.2699e+01,
           1.7133e+01, -1.7881e+00],
         ...,
         [-2.8075e-01, -2.8541e-01, -5.6245e-01,  ..., -9.1896e-01,
          -1.0033e+00, -6.5938e-01],
         [ 1.2094e+01, -2.4450e-02, -9.0559e-01,  ..., -1.2704e-01,
           8.0843e+00, -1.1262e+00],
         [-3.5761e-01,  5.4984e+00, -6.4558e-01,  ...,  1.2197e+01,
           2.7106e+01, -2.0604e+00]]], device='cuda:0') shape:  torch.Size([1, 776, 70])
Set: tensor([[[-2.4059e+00,  3.5436e+01,  1.2906e+02,  ...,  1.2704e+02,
          -2.0307e+00, -1.4365e+00],
         [-1.5048e+00, -5.1841e-01, -8.2649e-02,  ...,  1.7483e+01,
          -1.5401e+00, -6.6093e-01],
         [-1.2091e+00, -4.8924e-01, -9.1646e-02,  ...,  1.7243e+01,
          -1.2593e+00

Testing:   1%|          | 159/20000 [00:07<37:36,  8.79it/s]

tensor([[[-1.8493e-01,  4.3949e+00,  5.7568e+00,  ..., -3.9350e-01,
          -3.7588e-01,  1.4874e+01],
         [-9.2286e-02,  2.0845e+00,  3.7227e+00,  ..., -1.9528e-01,
          -1.9532e-01,  5.7453e+00],
         [-6.3145e-02,  2.0672e+00,  2.9213e+00,  ..., -1.9220e-01,
          -1.9864e-01,  6.1435e+00],
         ...,
         [-1.5915e-01, -3.4570e-02, -2.3676e-01,  ..., -1.6189e-02,
          -1.1409e-01, -3.9793e-03],
         [-1.7735e-01, -2.5017e-02, -2.1930e-01,  ..., -2.3257e-02,
          -8.6417e-02,  9.2872e-01],
         [ 1.8981e+00,  5.3777e+00, -2.2098e-02,  ..., -1.8670e-01,
          -1.3359e-01, -3.4487e-02]]], device='cuda:0') shape:  torch.Size([1, 863, 120])
Set: tensor([[[-5.7139e-01,  2.4427e+01, -1.4943e+00,  ...,  7.1310e+01,
           1.6137e+01, -3.7599e+00],
         [-2.9838e-01, -4.0987e-02, -5.7686e-01,  ...,  8.4348e+00,
           1.2472e+01, -1.6568e+00],
         [-2.3266e-01, -1.1105e-01, -4.7503e-01,  ...,  1.0895e+01,
           8.1612e+0



tensor([[[ 1.0464e+01, -6.2537e-03,  1.2466e+00,  ..., -1.3064e-01,
           3.0612e-01,  8.1837e+00],
         [ 5.9020e+00,  3.5857e-01,  1.0619e+00,  ..., -8.3171e-02,
           5.5400e-01,  4.2828e+00],
         [ 5.3984e+00, -7.9227e-03,  4.1406e-01,  ..., -7.6930e-02,
           4.4620e-01,  4.4073e+00],
         ...,
         [-2.6692e-03, -2.5591e-02, -1.3745e-02,  ..., -7.5457e-03,
           6.5899e+00, -1.3805e-02],
         [ 4.4157e+00, -7.6172e-03,  1.3709e-01,  ..., -6.2011e-02,
           1.2002e+00,  3.1197e+00],
         [ 3.8385e+00,  3.2036e-02,  5.9357e-01,  ..., -6.1179e-02,
           5.8585e-01,  2.8552e+00]]], device='cuda:0') shape:  torch.Size([1, 884, 80])
Set: tensor([[[-0.2454,  0.0225,  8.4471,  ..., -0.3596, -0.3870, 20.1139],
         [-0.1594, -0.0215,  6.4335,  ..., -0.1656, -0.2153,  8.6823],
         [-0.1463,  3.6983,  3.4405,  ..., -0.2110, -0.1910,  5.9819],
         ...,
         [-0.0779, -0.0336,  4.1129,  ..., -0.1655, -0.0911,  1.9180],
 

Testing:   1%|          | 163/20000 [00:07<37:53,  8.72it/s]

tensor([[[-1.9975e-01,  6.2165e+00,  5.1408e+00,  ..., -4.2043e-01,
          -3.7375e-01,  1.3693e+01],
         [-6.2196e-02,  1.3172e+00,  2.5311e+00,  ..., -1.9916e-01,
          -2.0371e-01,  6.9733e+00],
         [-8.2539e-02, -2.4702e-03,  3.5430e+00,  ..., -1.5497e-01,
          -1.7682e-01,  7.1960e+00],
         ...,
         [-8.5116e-02, -6.9668e-03, -3.0844e-01,  ...,  1.9811e-01,
          -1.6900e-01,  9.4831e-01],
         [-8.7741e-02,  1.3045e-01, -3.0962e-01,  ..., -1.0203e-02,
          -1.5215e-01,  2.2148e+00],
         [-8.4567e-02, -1.3253e-02, -3.1897e-01,  ...,  8.6538e-01,
          -1.9137e-01, -3.2829e-03]]], device='cuda:0') shape:  torch.Size([1, 832, 120])
Set: tensor([[[-5.8057e-01,  2.5438e+01, -1.6506e+00,  ...,  7.4323e+01,
           2.2668e+01, -4.1223e+00],
         [-2.6125e-01, -1.1738e-01, -4.6005e-01,  ...,  2.0405e+01,
           7.5695e+00, -1.7453e+00],
         [-3.4490e-01, -3.8154e-02, -5.1683e-01,  ...,  1.5905e+01,
           4.8991e+0



Set: tensor([[[ 0.2766,  0.5465, -5.1993, -0.7468, -5.1993],
         [ 0.0805,  0.5465, -3.0330, -0.6995, -3.0330],
         [ 0.0805,  0.5465, -2.8173, -0.4688, -2.8173],
         ...,
         [-2.3623, -2.8173,  1.3661, -0.9466,  1.6367],
         [-2.4752, -3.0330,  1.4222, -0.0243,  1.6849],
         [-2.5869, -5.1993,  1.4970, -0.4352,  1.7374]]], device='cuda:0') shape:  torch.Size([1, 827, 5])
Set: tensor([[[ 1.1847e+01, -7.5356e-03,  9.9259e-01,  ..., -1.3354e-01,
           1.0559e+00,  8.2777e+00],
         [ 7.3669e+00,  5.6339e-01,  8.4350e-01,  ..., -8.3622e-02,
           7.8607e-01,  4.2190e+00],
         [ 6.8895e+00,  1.6592e-01,  5.7483e-01,  ..., -7.8256e-02,
           7.5831e-01,  4.0313e+00],
         ...,
         [-9.4478e-02,  4.9775e-01, -1.7557e-02,  ...,  8.3040e+00,
          -1.3865e-01, -3.4098e-02],
         [-1.0109e-01, -1.8299e-02, -2.9856e-02,  ...,  9.0546e+00,
          -1.4949e-01, -2.5358e-02],
         [-1.5585e-01, -4.0051e-02, -3.8616e-02,  



tensor([[[ 2.0376e+02,  9.1950e+00, -2.8746e+01,  ..., -1.5718e+02,
           4.5062e+01,  2.1467e+02],
         [ 1.7896e+02, -7.1507e+01, -4.9286e+01,  ..., -5.8151e+01,
           4.6089e+01,  1.9524e+02],
         [ 1.6896e+02, -6.9662e+01, -4.5801e+01,  ..., -4.7515e+01,
           4.3300e+01,  1.8366e+02],
         ...,
         [-1.3183e+01, -1.7341e+00, -2.4463e-01,  ..., -3.3513e-01,
          -5.4267e+00,  1.4012e+00],
         [-1.3099e+01, -1.8238e+00, -1.5823e-01,  ..., -3.0670e-01,
          -5.5440e+00,  1.5037e+00],
         [-7.3888e+00, -2.7669e+00, -3.4618e+00,  ...,  2.5856e+00,
          -3.2364e+00,  1.1122e+01]]], device='cuda:0') shape:  torch.Size([1, 799, 8])
Set: tensor([[[ 0.1068, -0.5277, -5.1993, -0.6157, -5.1993],
         [ 0.1068, -0.2392, -3.0188, -1.0160, -3.0188],
         [ 0.1068, -0.2392, -2.8022, -0.4226, -2.8022],
         ...,
         [ 2.0347,  3.0188,  0.9923,  0.0573,  1.2801],
         [ 2.0347,  3.0188,  1.0133, -0.7027,  1.3247],
      

Testing:   1%|          | 168/20000 [00:08<37:36,  8.79it/s]

tensor([[[ 105.7546,    2.9279,    8.2044,  ..., -132.1786,   33.5471,
           100.2271],
         [ 122.8342,  -64.1906,  -28.2998,  ...,  -32.7682,   32.2398,
           126.8329],
         [  18.7339,   -0.1391,   -9.5709,  ...,   -1.5050,    7.7555,
            21.4436],
         ...,
         [ -10.4720,   -1.8342,   -0.9153,  ...,   -1.7870,   -4.3990,
             0.9884],
         [ -10.3264,   -2.0900,   -0.4187,  ...,   -1.1895,   -3.3685,
             2.0568],
         [  -9.3127,   -1.8840,   -0.7113,  ...,   -1.3304,   -4.0129,
             2.2089]]], device='cuda:0') shape:  torch.Size([1, 894, 8])
Set: tensor([[[-0.0527, -0.2059, -5.1993,  0.2445, -5.1993],
         [-0.0527, -0.2059, -2.8778,  0.4790, -2.8792],
         [-0.0527, -0.0464, -2.8778,  1.0144, -2.8729],
         ...,
         [ 0.2034, -0.2059, -2.2900, -0.2576, -2.2566],
         [ 0.4211, -0.0464, -2.1831, -0.0786, -2.1748],
         [-0.0527, -0.0464, -2.2900, -0.0919, -2.3100]]], device='cuda:0') sha



tensor([[[  80.5811,   10.0657,   13.2355,  ..., -125.3835,   30.2052,
            78.8511],
         [ 165.8758,  -90.2341,  -45.2182,  ...,  -51.1702,   50.2784,
           182.5580],
         [   6.7963,   -3.4502,    3.2381,  ...,    3.6888,   -9.4917,
            13.9954],
         ...,
         [ -13.9828,   -1.7866,   -0.2105,  ...,   -2.7014,   -4.8201,
             2.0474],
         [ -11.3684,   -0.3822,    5.1179,  ...,   -5.7424,    5.2774,
            19.6687],
         [ -11.2432,   -1.2066,   -1.4537,  ...,   -3.0089,   -2.6091,
             0.1949]]], device='cuda:0') shape:  torch.Size([1, 903, 8])
Set: tensor([[[ 0.5818, -0.3459, -5.1993, -1.6847, -5.1993],
         [ 0.5818,  0.0052, -2.9900, -0.4812, -2.9900],
         [ 0.5818, -0.3459, -2.7716, -0.5450, -2.7716],
         ...,
         [ 0.2689,  0.0052, -2.2398, -2.0081, -2.1988],
         [-0.0175,  0.3182, -2.1264, -2.2850, -2.1264],
         [-0.2328,  0.8337, -2.1264, -0.8842, -2.0636]]], device='cuda:0') sha

Testing:   1%|          | 172/20000 [00:08<37:05,  8.91it/s]

tensor([[[-2.3275e+00,  2.6486e+01,  1.1347e+02,  ...,  1.1816e+02,
          -1.9306e+00, -1.3053e+00],
         [-7.4071e-01, -4.0762e-01, -2.6906e-02,  ...,  2.6094e+01,
          -9.2291e-01, -5.4657e-01],
         [-1.1272e+00, -5.5806e-01, -1.1069e-01,  ...,  2.1685e+01,
          -1.3004e+00, -6.1600e-01],
         ...,
         [-7.2656e-01, -1.6313e-01, -3.5330e-01,  ..., -8.6700e-01,
          -1.0180e+00, -4.6442e-01],
         [-9.1843e-01, -6.4331e-01, -8.9932e-01,  ..., -1.2064e+00,
          -1.1837e+00, -5.3763e-01],
         [-6.3592e-01, -9.1616e-02, -2.7776e-01,  ..., -4.6605e-01,
          -9.4612e-01, -5.3569e-01]]], device='cuda:0') shape:  torch.Size([1, 930, 50])
Set: tensor([[[ 136.1880,   -3.8267,   -5.9559,  ..., -134.9178,   37.8032,
           135.1707],
         [  74.9211,  -31.5331,  -23.5623,  ...,  -22.9354,   23.3082,
            79.0208],
         [ 113.2525,  -70.0596,  -25.1138,  ...,  -30.9160,   32.7993,
           119.0877],
         ...,
      

Testing:   1%|          | 174/20000 [00:09<40:02,  8.25it/s]

tensor([[[-1.2894e-01,  2.6332e+00,  6.1249e+00,  ..., -4.0691e-01,
          -4.2785e-01,  1.9716e+01],
         [-3.7814e-02,  3.4016e+00,  2.0486e+00,  ..., -2.4685e-01,
          -2.5035e-01,  8.3595e+00],
         [-2.1592e-02,  4.8534e+00,  7.9267e-01,  ..., -2.5069e-01,
          -2.3069e-01,  5.6401e+00],
         ...,
         [-1.7398e-02,  1.0001e+00,  3.1708e+00,  ..., -1.1413e-01,
          -6.0284e-02, -2.5112e-02],
         [-9.8114e-03,  1.6528e+00,  2.4340e+00,  ..., -1.1627e-01,
          -9.6311e-02, -4.4998e-02],
         [ 4.7242e+00,  1.1089e+00, -2.7725e-02,  ..., -1.6660e-01,
          -1.8673e-01,  3.4559e+00]]], device='cuda:0') shape:  torch.Size([1, 898, 120])
Set: tensor([[[-5.7833e-01,  2.4210e+01, -1.3833e+00,  ...,  9.7910e+01,
           4.1045e+00, -4.1345e+00],
         [-2.4238e-01, -9.3800e-02, -5.1977e-01,  ...,  3.2932e+01,
           7.2630e+00, -2.3218e+00],
         [-1.7788e-01, -1.7560e-01, -5.3688e-01,  ...,  3.0715e+01,
           5.2091e+0



Set: tensor([[[ 0.2238,  0.5488, -5.1993,  1.7266, -2.5491],
         [ 0.2238,  0.5488, -1.8107,  2.7225, -2.7225],
         [ 0.2238,  0.8463, -5.1993,  1.2087, -2.8538],
         ...,
         [ 1.7768,  1.0337,  0.6161,  1.5599,  0.5488],
         [ 1.7768,  0.8463,  0.6161,  1.1490,  0.5582],
         [ 0.8328,  2.6260,  0.7246,  0.2797,  1.1335]]], device='cuda:0') shape:  torch.Size([1, 927, 5])
Set: tensor([[[ 6.8444e+00, -4.6016e-02, -1.8069e-02,  ..., -9.5310e-02,
           1.1781e+00,  5.5375e+00],
         [ 6.0189e+00, -7.1778e-02, -3.0948e-02,  ..., -5.9339e-02,
           2.2874e-01,  6.9410e+00],
         [ 8.1908e+00, -3.1767e-02, -1.0386e-02,  ..., -1.0376e-01,
           1.9685e+00,  5.4696e+00],
         ...,
         [-2.5753e-03, -2.7704e-02, -9.7239e-03,  ..., -2.3032e-02,
           6.6307e+00, -2.8216e-03],
         [-7.1107e-03, -2.1036e-02, -6.1261e-03,  ..., -2.1838e-02,
           6.2526e+00, -5.8290e-03],
         [ 3.3393e+00,  3.4905e+00,  1.1700e+00,  

Testing:   1%|          | 177/20000 [00:09<42:39,  7.75it/s]

tensor([[[-2.1130e+00,  3.5327e+01,  1.3062e+02,  ...,  1.2477e+02,
          -2.0150e+00, -1.3508e+00],
         [-8.7463e-01, -3.7920e-01,  1.3191e+01,  ...,  4.0694e+01,
          -1.2159e+00, -7.0710e-01],
         [ 9.9072e+01,  6.2718e+01,  1.0785e+02,  ..., -2.0606e+00,
          -1.0034e+00, -1.8433e+00],
         ...,
         [-3.8697e-01, -2.0548e-01, -6.0168e-02,  ..., -3.3775e-01,
          -1.9134e-01, -2.8716e-01],
         [-3.8454e-01, -2.9360e-01, -1.1633e-01,  ..., -3.6044e-01,
          -2.0706e-01, -2.8793e-01],
         [-5.3482e-01, -7.5046e-01, -6.1987e-01,  ..., -7.6198e-01,
          -5.4751e-01, -4.9060e-01]]], device='cuda:0') shape:  torch.Size([1, 860, 50])
Set: tensor([[[ 2.0075e+02, -1.7012e+00, -2.8420e+01,  ..., -1.4106e+02,
           4.6896e+01,  2.0680e+02],
         [ 1.4545e+02, -5.3575e+01, -4.4170e+01,  ..., -3.5201e+01,
           3.6023e+01,  1.5944e+02],
         [ 6.9339e+01, -1.9927e+00,  8.4227e+00,  ...,  4.2817e+01,
           4.9641e+01



tensor([[[-2.0884e-01,  5.6837e+00,  5.9003e+00,  ..., -4.1966e-01,
          -3.8477e-01,  1.5716e+01],
         [-9.7181e-02,  1.0091e+00,  5.2975e+00,  ..., -1.9859e-01,
          -2.2613e-01,  9.9404e+00],
         [-8.8347e-02,  3.4838e+00,  4.1647e+00,  ..., -2.0598e-01,
          -2.0786e-01,  7.7498e+00],
         ...,
         [ 3.2140e-01,  8.9285e-01, -6.5680e-02,  ..., -1.6123e-02,
           2.7304e-01, -5.1136e-02],
         [ 3.4591e-02, -6.0085e-02, -5.3492e-02,  ...,  2.9517e+00,
          -1.6962e-02, -5.4225e-02],
         [-3.5876e-03,  1.5499e+00, -7.1505e-02,  ..., -2.2444e-02,
           1.3055e+00, -3.6187e-02]]], device='cuda:0') shape:  torch.Size([1, 901, 120])
Set: tensor([[[-6.6178e-01,  3.5200e+01, -1.6661e+00,  ...,  7.7376e+01,
           2.4139e+01, -4.2201e+00],
         [-4.2139e-01,  5.2740e+00, -7.1648e-01,  ...,  2.5885e+01,
           1.2291e+01, -1.9716e+00],
         [-3.5352e-01,  1.2671e+00, -6.3870e-01,  ...,  1.7837e+01,
           1.8769e+0

Testing:   1%|          | 181/20000 [00:10<37:18,  8.85it/s]

tensor([[[-2.2170e-01,  5.1598e+00,  5.8748e+00,  ..., -4.1409e-01,
          -3.8428e-01,  1.5989e+01],
         [-8.9454e-02,  3.5590e+00,  4.1174e+00,  ..., -2.1445e-01,
          -2.1365e-01,  7.5915e+00],
         [-7.8256e-02,  6.1584e+00,  2.7100e+00,  ..., -2.2854e-01,
          -1.9435e-01,  4.5178e+00],
         ...,
         [-4.4456e-02,  6.5335e+00,  1.4430e+00,  ..., -1.7348e-01,
          -1.2740e-01, -9.6093e-05],
         [-7.7503e-02,  1.3827e+00,  3.5092e+00,  ..., -1.2527e-01,
          -1.3565e-01,  2.5638e+00],
         [-8.3312e-02,  6.8777e+00,  1.3194e+00,  ..., -2.0208e-01,
          -1.4183e-01, -1.5505e-03]]], device='cuda:0') shape:  torch.Size([1, 778, 120])
Set: tensor([[[-6.7983e-01,  3.7981e+01, -1.6378e+00,  ...,  7.5481e+01,
           2.3359e+01, -4.1565e+00],
         [-3.2244e-01, -5.1185e-03, -6.5189e-01,  ...,  1.8333e+01,
           1.5550e+01, -1.9782e+00],
         [-2.4543e-01, -6.6540e-02, -6.4431e-01,  ...,  1.2806e+01,
           2.0345e+0

Testing:   1%|          | 183/20000 [00:10<36:40,  9.01it/s]

tensor([[[ 9.1800e+00, -2.7313e-02,  6.7746e-02,  ..., -1.1873e-01,
          -2.1304e-02,  9.0331e+00],
         [ 4.8247e+00, -5.3877e-03,  2.5568e-01,  ..., -6.7873e-02,
          -2.7649e-02,  4.6226e+00],
         [ 4.3456e+00, -9.7595e-03, -3.3007e-04,  ..., -6.2459e-02,
          -2.7969e-02,  4.4511e+00],
         ...,
         [ 4.0181e+00,  3.8003e+00,  1.2425e+00,  ..., -2.5368e-02,
           5.1726e+00, -2.4413e-02],
         [ 3.4228e+00,  7.6091e+00,  3.7310e+00,  ..., -4.1983e-02,
           6.8987e+00, -3.7861e-02],
         [ 3.3569e+00,  3.4279e+00,  2.2955e+00,  ..., -5.9811e-02,
           1.2062e+01, -2.8410e-02]]], device='cuda:0') shape:  torch.Size([1, 825, 80])
Set: tensor([[[-2.8058e-01,  4.9360e+00,  3.0821e+00,  ..., -4.0648e-01,
          -3.4911e-01,  1.0685e+01],
         [-1.4389e-01,  2.2756e+00,  1.8685e+00,  ..., -1.9680e-01,
          -1.7659e-01,  1.3472e+00],
         [-1.3542e-01,  4.1245e+00,  5.4439e-01,  ..., -1.9791e-01,
          -1.5911e-01



tensor([[[ 0.0929,  0.2114, -5.1993, -0.8199, -5.1993],
         [ 0.0929,  0.2114, -3.0514, -0.6980, -3.0514],
         [ 0.0929,  0.2114, -2.8369, -1.2485, -2.8369],
         ...,
         [-1.8007,  1.7864,  0.8119,  1.3201,  0.8730],
         [-1.8699,  1.7864,  0.8119, -0.4708,  0.8898],
         [-1.8699,  1.7016,  0.8119,  2.1181,  0.9243]]], device='cuda:0') shape:  torch.Size([1, 879, 5])
Set: tensor([[[ 1.1122e+01, -9.7107e-03,  8.1559e-01,  ..., -1.2773e-01,
          -2.9300e-03,  8.3557e+00],
         [ 6.5574e+00,  4.0724e-02,  6.4446e-01,  ..., -8.0256e-02,
          -4.2051e-04,  4.4406e+00],
         [ 6.1507e+00,  1.4274e+00,  1.2752e+00,  ..., -7.7058e-02,
           1.1944e-01,  3.5205e+00],
         ...,
         [ 3.1796e+00,  1.3227e+00, -1.3931e-02,  ...,  1.3868e+00,
          -1.1924e-02, -2.2749e-02],
         [ 3.3469e+00,  5.4690e+00,  6.4659e-01,  ...,  9.9446e-01,
          -9.4368e-03, -4.0049e-02],
         [ 2.8444e+00, -5.4890e-03, -2.3842e-02,  ..., 

Testing:   1%|          | 186/20000 [00:10<37:46,  8.74it/s]

tensor([[[  51.4920,   19.1725,   38.1467,  ..., -116.8731,   20.3826,
            50.7663],
         [  -5.2569,   10.7268,    2.8190,  ...,    3.1806,    7.8799,
             7.4279],
         [  70.7956,  -66.5270,  -11.5648,  ...,  -33.5522,   24.3643,
            79.8653],
         ...,
         [ -11.0256,    0.1686,    2.1308,  ...,   -1.1508,   -1.8607,
             1.8149],
         [ -10.6872,   -2.0858,   -5.7315,  ...,   -3.1116,   -2.9674,
             1.9783],
         [  -9.8323,    1.7575,   -0.1371,  ...,   -0.3212,   -3.0149,
            -1.5849]]], device='cuda:0') shape:  torch.Size([1, 829, 8])
Set: tensor([[[ 0.0327, -0.1509, -5.1993, -1.1818, -5.1993],
         [ 0.0327, -0.1509, -3.0106, -0.0933, -3.0106],
         [ 0.0327, -0.1509, -2.7934,  0.2124, -2.7934],
         ...,
         [-0.4010,  0.5588, -2.2061, -0.3277, -2.2657],
         [-0.6900,  0.6941, -2.0363, -2.1213, -2.0912],
         [ 0.0327, -0.1509, -2.1060, -1.4352, -2.1213]]], device='cuda:0') sha

Testing:   1%|          | 189/20000 [00:10<33:45,  9.78it/s]

tensor([[[ 0.1791,  0.4529, -5.1993, -0.3589, -5.1993],
         [ 0.1791,  0.4529, -3.1432, -0.0437, -3.1636],
         [ 0.1791,  0.8161, -2.8938, -1.9678, -2.8941],
         ...,
         [-0.1119,  0.8161, -1.3250,  0.5760, -1.3341],
         [-0.1119,  0.8161, -1.3100,  1.2487, -1.3226],
         [-0.1119,  0.8161, -1.2896,  0.4280, -1.3115]]], device='cuda:0') shape:  torch.Size([1, 1099, 5])
Set: tensor([[[ 1.1636e+01, -1.7188e-02,  4.5715e-01,  ..., -1.3024e-01,
           4.5082e-01,  8.6728e+00],
         [ 7.2861e+00, -1.2174e-02,  6.4202e-02,  ..., -8.4492e-02,
           6.3631e-01,  5.1509e+00],
         [ 7.8046e+00,  3.8806e+00,  2.4961e+00,  ..., -8.8348e-02,
           2.0740e+00,  2.6342e+00],
         ...,
         [ 4.4459e+00, -8.1077e-03, -5.4939e-03,  ..., -4.3714e-02,
           8.5612e-01,  2.0418e+00],
         [ 4.3616e+00, -2.3410e-02, -1.3242e-02,  ..., -4.1582e-02,
           6.9386e-01,  2.6529e+00],
         [ 4.4070e+00, -4.5976e-03, -3.7984e-03,  ...,



tensor([[[-0.6930,  0.7523, -5.1993, -0.8427, -5.1993],
         [-0.6930,  0.7523, -2.9765, -0.3739, -2.9765],
         [-0.6930,  0.7523, -2.7571, -1.1847, -2.7571],
         ...,
         [-1.8313, -0.7718, -1.1380,  0.9635, -0.9635],
         [-1.8313, -0.9320, -1.1038,  1.6213, -0.9405],
         [-0.6930,  0.7523, -1.9274, -0.3700, -1.9158]]], device='cuda:0') shape:  torch.Size([1, 687, 5])
Set: tensor([[[ 1.2946e+01,  3.8107e-01,  9.1111e-01,  ..., -1.2507e-01,
          -1.0728e-02,  7.8815e+00],
         [ 8.1927e+00,  6.4540e-01,  3.3960e-01,  ..., -7.5007e-02,
          -8.9612e-03,  4.1521e+00],
         [ 7.7985e+00,  2.6294e+00,  1.2698e+00,  ..., -7.2416e-02,
          -6.7091e-03,  2.9782e+00],
         ...,
         [ 7.0047e-01, -2.7348e-02, -2.5133e-02,  ...,  2.0435e-01,
          -7.9762e-02,  2.2429e+00],
         [ 1.8686e-01, -4.4621e-02, -3.3683e-02,  ...,  6.3593e-01,
          -8.5460e-02,  2.9073e+00],
         [ 5.9471e+00,  1.2743e+00,  3.2034e-01,  ..., 



Set: tensor([[[ 0.1689, -0.0105, -5.1993, -0.9876, -5.1993],
         [ 0.1689, -0.0105, -3.0363, -0.8162, -3.0363],
         [ 0.1689, -0.0105, -2.6881,  1.0481, -2.8208],
         ...,
         [ 0.1689, -1.0771,  0.4318, -0.5821,  0.3476],
         [ 1.4280,  1.3456, -0.0375, -0.8037,  0.2658],
         [ 1.5073,  1.4621, -0.0375, -0.4252,  0.2876]]], device='cuda:0') shape:  torch.Size([1, 836, 5])
Set: tensor([[[ 1.0528e+01, -9.6510e-03,  8.9784e-01,  ..., -1.2641e-01,
          -6.1365e-03,  8.3370e+00],
         [ 5.9272e+00, -5.7292e-04,  6.6952e-01,  ..., -7.8470e-02,
          -3.7273e-03,  4.4401e+00],
         [ 5.2662e+00, -4.1944e-02, -1.4851e-02,  ..., -6.7251e-02,
          -8.2212e-03,  5.7940e+00],
         ...,
         [-3.9818e-02, -1.4690e-03, -2.9748e-03,  ...,  9.8926e-01,
          -2.7404e-02, -1.0947e-02],
         [ 1.6859e+00,  3.1323e+00,  1.8488e+00,  ..., -3.9552e-02,
           7.0683e+00, -2.2548e-02],
         [ 1.8597e+00,  2.4048e+00,  1.5099e+00,  



Set: tensor([[[ 0.2213, -0.2695, -5.1993, -0.5696, -5.1993],
         [ 0.2213, -0.2695, -3.0252, -0.2954, -3.0252],
         [ 0.2213, -0.2695, -2.8090, -0.7895, -2.8090],
         ...,
         [-1.0899, -0.7874, -1.6789, -1.8311, -1.7117],
         [-0.4199, -0.4473, -1.3756, -0.9783, -1.7536],
         [-0.0686, -0.2695, -1.9845, -1.0404, -2.0312]]], device='cuda:0') shape:  torch.Size([1, 806, 5])
Set: tensor([[[ 9.8092e+00, -2.3363e-02,  2.7785e-01,  ..., -1.2283e-01,
          -1.2329e-02,  8.8844e+00],
         [ 5.1759e+00, -1.6568e-02, -6.9048e-04,  ..., -7.4365e-02,
          -1.0155e-02,  5.0634e+00],
         [ 4.7607e+00, -3.9805e-03,  4.9670e-01,  ..., -7.0975e-02,
          -8.6746e-03,  4.1932e+00],
         ...,
         [ 1.9922e+00,  2.6978e+00,  9.4034e-01,  ..., -2.8277e-02,
          -5.4257e-02,  1.1821e+00],
         [ 2.3566e+00,  7.8696e-01,  3.6809e-01,  ..., -3.4955e-02,
          -2.9828e-02,  2.0215e+00],
         [ 3.2920e+00,  8.2279e-01,  6.7857e-01,  

Testing:   1%|          | 196/20000 [00:11<38:49,  8.50it/s]

tensor([[[ 169.5415,   -0.7358,  -15.0812,  ..., -146.2430,   41.9948,
           170.8607],
         [ 112.5092,  -81.0517,  -38.0595,  ...,  -51.4552,   39.8195,
           131.1310],
         [ 145.9123,  -77.4464,  -40.0661,  ...,  -46.8367,   43.8253,
           160.2460],
         ...,
         [  48.3247,   62.5101,  -23.2437,  ...,  -73.5839,   59.8580,
            36.0099],
         [ 166.9809,   35.2467,    7.3920,  ...,   78.2287,    4.8883,
           122.0618],
         [  49.5140,  -12.5666,  -15.8671,  ...,   -5.0444,   20.9557,
            49.0581]]], device='cuda:0') shape:  torch.Size([1, 867, 8])
Set: tensor([[[-0.4256, -0.2168, -5.1993, -0.5144, -5.1993],
         [-0.4256, -0.2168, -3.0249, -0.7467, -3.0249],
         [-0.4256, -0.2168, -2.8086, -1.3830, -2.8086],
         ...,
         [-0.8253, -0.2168, -1.4205, -0.0624, -1.4248],
         [-0.8253, -0.2168, -1.4036, -0.3500, -1.4078],
         [-0.7243, -0.2168, -1.4466,  0.2714, -1.4510]]], device='cuda:0') sha

Testing:   1%|          | 198/20000 [00:11<38:47,  8.51it/s]

tensor([[[-2.3806e-01,  3.1068e+00,  4.8914e+00,  ..., -3.8357e-01,
          -3.5348e-01,  1.3494e+01],
         [-1.2402e-01,  5.4606e+00,  9.3519e-01,  ..., -2.2790e-01,
          -1.7536e-01,  3.6498e-01],
         [-1.1389e-01, -2.6335e-02,  4.6476e+00,  ..., -1.3455e-01,
          -1.6941e-01,  3.0676e+00],
         ...,
         [ 3.8585e+00, -2.1607e-03, -2.0862e-01,  ..., -1.0381e-03,
          -9.2244e-02, -4.6952e-02],
         [ 3.8932e+00,  1.3111e-01, -2.1047e-01,  ..., -4.9757e-03,
          -8.9594e-02, -3.6636e-02],
         [-8.5977e-02, -1.6746e-02,  3.9428e+00,  ..., -9.4599e-02,
          -1.2664e-01,  1.7724e+00]]], device='cuda:0') shape:  torch.Size([1, 849, 120])
Set: tensor([[[-6.3747e-01,  3.0859e+01, -1.4264e+00,  ...,  6.5996e+01,
           2.4564e+01, -3.8077e+00],
         [-2.2177e-01, -1.1779e-01, -6.0483e-01,  ...,  8.4308e+00,
           2.2258e+01, -1.9645e+00],
         [-2.8768e-01, -5.8581e-02, -4.4184e-01,  ..., -2.9375e-03,
          -9.6770e-0



tensor([[[ 1.0334e+01, -2.1511e-02, -4.2684e-04,  ..., -1.1599e-01,
          -2.8533e-02,  8.8510e+00],
         [ 5.8938e+00,  3.7690e-01,  5.6005e-01,  ..., -7.1040e-02,
          -2.4435e-02,  4.3624e+00],
         [ 5.4332e+00,  2.4229e-01,  4.2483e-01,  ..., -6.6061e-02,
          -2.4436e-02,  4.0715e+00],
         ...,
         [-7.3564e-02, -3.7139e-02, -1.2828e-02,  ...,  6.1270e-01,
          -3.4893e-02,  3.9800e-01],
         [ 2.6040e+00,  8.9344e-01,  8.6230e-01,  ..., -4.7936e-02,
           4.6403e+00,  1.8278e-01],
         [-6.8231e-03,  3.0075e+00,  2.4727e+00,  ..., -4.3850e-02,
           8.2584e+00, -2.4818e-02]]], device='cuda:0') shape:  torch.Size([1, 975, 80])
Set: tensor([[[-0.1750,  5.2160,  4.5573,  ..., -0.4055, -0.3621, 12.7066],
         [-0.0808,  0.7343,  3.4411,  ..., -0.1917, -0.1937,  5.7915],
         [-0.0711,  1.6649,  2.6230,  ..., -0.1836, -0.1759,  4.3480],
         ...,
         [-0.2617, -0.1233, -0.1401,  ...,  0.3119, -0.0795, -0.1641],
 

Testing:   1%|          | 201/20000 [00:12<40:51,  8.08it/s]

tensor([[[ 102.5320,   10.6149,   11.9941,  ..., -127.3851,   27.1690,
           101.1940],
         [  77.6672,  -66.5684,  -21.8546,  ...,  -43.7235,   24.3625,
           104.4066],
         [   9.6167,   -1.8120,   -7.4576,  ...,   -5.0007,    5.5166,
            21.4040],
         ...,
         [  62.8180,   33.0822,   -3.6552,  ...,   37.5599,   -7.4046,
           -32.4453],
         [  51.8832,   54.8486,  -23.9917,  ...,   10.2349,   -4.3840,
           -46.5482],
         [  24.1926,   47.3731,  -15.1022,  ...,    7.8726,  -15.1312,
           -59.2327]]], device='cuda:0') shape:  torch.Size([1, 975, 8])
Set: tensor([[[ 0.2835, -0.3938, -5.1993, -1.6109, -5.1993],
         [ 0.2835, -0.3938, -3.1187, -0.2497, -3.1379],
         [ 0.2835, -0.3938, -2.9084, -1.0844, -2.9027],
         ...,
         [-1.5296,  0.6075,  0.1384,  2.3260,  0.1174],
         [-1.7987, -1.1744,  0.4654,  0.7066,  0.5676],
         [-0.6690, -1.1071, -0.1031, -0.5139, -0.1518]]], device='cuda:0') sha

Testing:   1%|          | 203/20000 [00:12<37:54,  8.70it/s]

tensor([[[-1.8395e+00,  4.4089e+01,  1.4065e+02,  ...,  1.3611e+02,
          -1.8902e+00, -1.4490e+00],
         [-1.2102e+00, -3.0742e-01,  1.2498e+00,  ...,  3.8386e+01,
          -1.3307e+00, -6.2110e-01],
         [-5.8407e-01, -3.9833e-01, -6.6395e-04,  ...,  1.9421e+01,
          -7.3439e-01, -4.3506e-01],
         ...,
         [-8.2378e-01, -4.7791e-01, -6.7729e-01,  ..., -5.1453e-01,
          -7.6320e-01, -6.2987e-02],
         [-8.2044e-01, -5.3411e-01, -6.6552e-01,  ..., -7.5248e-01,
          -7.9696e-01, -8.9381e-02],
         [-9.2770e-01, -4.3760e-01, -6.6003e-01,  ..., -4.1862e-01,
          -7.1175e-01, -1.1466e-01]]], device='cuda:0') shape:  torch.Size([1, 604, 50])
Set: tensor([[[ 203.6295,   19.3388,  -48.4792,  ..., -156.1564,   48.3735,
           239.7900],
         [  92.6665,  -67.9955,  -38.5867,  ...,  -59.3229,   33.6736,
           125.1407],
         [  41.3512,  -12.7148,  -25.3401,  ...,  -24.1034,   10.1908,
            61.0323],
         ...,
      



tensor([[[-2.5602e+00,  6.1403e+01,  9.2894e+01,  ...,  1.3364e+02,
          -1.3238e+00, -1.1910e+00],
         [-1.0477e+00, -3.2365e-01,  1.5165e+01,  ...,  3.6358e+01,
          -1.3111e+00, -6.8459e-01],
         [-1.4106e+00, -5.0030e-01, -4.7352e-02,  ...,  3.2072e+01,
          -1.6696e+00, -8.0309e-01],
         ...,
         [-4.0095e-01, -2.8286e-01, -3.7557e-01,  ..., -1.4941e-01,
          -6.2931e-01, -5.1057e-01],
         [-6.9211e-01, -6.5391e-01, -7.7822e-01,  ..., -5.1491e-01,
          -8.5518e-01, -1.3489e-01],
         [-3.2697e-01, -4.4644e-01, -5.3499e-01,  ...,  1.1411e+01,
          -3.9873e-01, -1.2974e-01]]], device='cuda:0') shape:  torch.Size([1, 858, 50])
Set: tensor([[[ 1.5906e+01,  4.1620e+01,  2.0752e+01,  ..., -1.1165e+02,
           3.6508e+01,  3.9876e+01],
         [ 1.3053e+02, -5.2702e+01, -3.5725e+01,  ..., -3.8373e+01,
           2.8765e+01,  1.4655e+02],
         [ 1.6607e+02, -1.0368e+02, -3.2988e+01,  ..., -5.0478e+01,
           4.5173e+01

Testing:   1%|          | 208/20000 [00:12<33:18,  9.90it/s]

Set: tensor([[[ 0.0724, -0.3573, -5.1993, -0.4372, -5.1993],
         [ 0.0724, -0.3573, -3.0413, -1.4708, -3.0413],
         [ 0.0724, -0.3573, -2.8261, -1.0382, -2.8261],
         ...,
         [-0.2674, -0.3573, -1.7640, -0.3542, -1.7781],
         [ 0.0724,  0.9745, -1.6386, -1.0182, -1.6735],
         [ 0.0724,  0.0384, -2.1469,  0.1467, -2.1619]]], device='cuda:0') shape:  torch.Size([1, 850, 5])
Set: tensor([[[ 9.6633e+00, -2.6766e-02,  2.2046e-02,  ..., -1.1967e-01,
          -1.9048e-02,  9.0258e+00],
         [ 5.1783e+00,  9.8108e-01,  1.1810e+00,  ..., -7.5150e-02,
          -1.3724e-02,  4.0053e+00],
         [ 4.6843e+00,  1.2197e-01,  6.7986e-01,  ..., -6.9242e-02,
          -1.4491e-02,  4.0085e+00],
         ...,
         [ 2.6010e+00, -5.9765e-03, -2.3334e-03,  ..., -4.0329e-02,
          -2.4187e-02,  2.6184e+00],
         [ 5.5847e+00,  2.7430e+00,  1.4449e+00,  ..., -5.9246e-02,
           2.1070e+00,  1.1294e+00],
         [ 4.1652e+00, -1.5997e-02, -4.5483e-03,  



Set: tensor([[[ 0.3263, -0.3322, -5.1993, -0.4035, -5.1993],
         [ 0.1984, -0.3322, -2.9362, -0.4988, -3.0598],
         [ 0.1984, -0.3322, -2.7142, -0.3765, -2.8458],
         ...,
         [-2.2017,  1.4775,  1.4775,  0.3646,  1.7270],
         [-2.2017,  1.4775,  1.5706,  0.9652,  1.7522],
         [-2.3072,  1.4775,  1.7148,  0.6598,  1.8988]]], device='cuda:0') shape:  torch.Size([1, 904, 5])
Set: tensor([[[ 9.5735e+00, -2.8770e-02,  8.1873e-02,  ..., -1.2282e-01,
          -1.1463e-02,  9.0947e+00],
         [ 5.0837e+00, -1.3003e-02,  1.1442e-01,  ..., -7.3381e-02,
          -1.2207e-02,  4.9706e+00],
         [ 4.6169e+00, -1.4515e-02, -2.9719e-04,  ..., -6.8265e-02,
          -1.2239e-02,  4.6844e+00],
         ...,
         [ 9.6887e-01,  3.8230e+00, -6.1574e-03,  ...,  3.6293e+00,
          -2.7142e-02, -4.6918e-02],
         [ 8.4187e-01,  2.4623e+00, -1.3102e-02,  ...,  3.9198e+00,
          -2.8696e-02, -4.1717e-02],
         [ 6.2145e-01,  3.3149e+00, -9.9481e-03,  

Testing:   1%|          | 212/20000 [00:13<33:32,  9.83it/s]

tensor([[[ 0.3032, -0.1702, -5.1993, -1.1896, -5.1993],
         [ 0.0603, -0.1702, -3.1072, -0.3148, -3.1072],
         [ 0.0603, -0.1702, -2.8962, -0.7422, -2.8940],
         ...,
         [ 0.3032,  1.2669,  1.7987, -1.0762,  1.5196],
         [ 0.5178,  0.1195, -1.4320, -0.3415, -1.4363],
         [ 0.3032,  0.3229, -2.2258, -0.7107, -2.2234]]], device='cuda:0') shape:  torch.Size([1, 1060, 5])
Set: tensor([[[ 1.0064e+01, -8.2451e-03,  1.0763e+00,  ..., -1.2658e-01,
          -6.0471e-03,  8.2613e+00],
         [ 5.6965e+00, -1.4130e-02, -3.9469e-04,  ..., -7.5491e-02,
          -1.2140e-02,  5.1104e+00],
         [ 5.2825e+00, -3.0844e-03,  4.4952e-01,  ..., -7.2006e-02,
          -1.0821e-02,  4.3085e+00],
         ...,
         [-6.5615e-03,  5.0695e+00,  1.7078e+00,  ...,  7.1203e-01,
           3.8787e+00, -5.0214e-02],
         [ 2.6119e+00, -2.0752e-03,  2.9063e-01,  ..., -4.7509e-02,
           1.0158e+00,  1.8815e+00],
         [ 4.9559e+00,  6.0092e-01,  7.8083e-01,  ...,



tensor([[[ 9.2498e-02,  4.7776e-03, -5.1993e+00, -1.6369e+00, -5.1993e+00],
         [ 9.2498e-02,  4.7776e-03, -2.8933e+00, -6.3795e-01, -3.0184e+00],
         [ 9.2498e-02,  4.7776e-03, -2.6161e+00, -1.1906e+00, -2.8018e+00],
         ...,
         [-5.5238e-01,  1.6876e+00,  1.0616e+00, -9.5311e-01,  9.7850e-01],
         [-4.6358e-01,  1.6876e+00,  1.0616e+00,  1.8742e-01,  1.0259e+00],
         [-4.6358e-01,  1.7216e+00,  1.1496e+00,  2.2351e+00,  1.0700e+00]]],
       device='cuda:0') shape:  torch.Size([1, 788, 5])
Set: tensor([[[ 1.0668e+01,  5.9229e-01,  1.6299e+00,  ..., -1.2748e-01,
          -6.2635e-03,  7.7046e+00],
         [ 5.9183e+00, -3.8680e-03,  4.4349e-01,  ..., -7.5579e-02,
          -6.0781e-03,  4.5572e+00],
         [ 5.4885e+00,  1.0026e+00,  1.0742e+00,  ..., -7.1710e-02,
          -4.5718e-03,  3.6336e+00],
         ...,
         [ 2.1112e+00,  5.6405e+00,  1.5602e+00,  ..., -1.6385e-03,
           2.5365e+00, -4.2986e-02],
         [ 1.8728e+00,  3.0099e+0



tensor([[[ 9.9519e+00, -1.0454e-02,  6.2250e-01,  ..., -1.1759e-01,
          -2.6695e-02,  8.3798e+00],
         [ 6.1253e+00, -3.9984e-03,  2.3740e-01,  ..., -7.3181e-02,
          -1.7913e-02,  4.6834e+00],
         [ 4.9832e+00, -4.8274e-03,  1.6780e-01,  ..., -6.5485e-02,
          -2.4541e-02,  4.3555e+00],
         ...,
         [ 2.3981e+00, -1.4402e-02, -9.3184e-03,  ..., -3.1203e-02,
          -3.4104e-02,  2.6930e+00],
         [ 3.1704e+00,  2.1710e+00,  8.4013e-01,  ..., -3.7984e-02,
          -2.3522e-02,  1.1784e+00],
         [ 2.6970e+00, -5.6665e-03, -4.5298e-03,  ..., -3.5186e-02,
          -3.3225e-02,  2.5307e+00]]], device='cuda:0') shape:  torch.Size([1, 978, 80])
Set: tensor([[[-2.1291e-01,  2.4541e+00,  5.8636e+00,  ..., -3.7171e-01,
          -3.5715e-01,  1.4699e+01],
         [-9.0975e-02,  4.0093e+00,  2.8462e+00,  ..., -2.2336e-01,
          -1.9848e-01,  5.5565e+00],
         [-1.1393e-01,  3.7472e+00,  1.9503e+00,  ..., -1.9852e-01,
          -1.6707e-01



tensor([[[-5.5389e-01,  2.4970e+01, -1.4751e+00,  ...,  8.9386e+01,
           1.0441e+01, -4.0592e+00],
         [-2.5018e-01, -8.0993e-02, -5.1735e-01,  ...,  1.2631e+01,
           1.1585e+01, -1.7282e+00],
         [-2.3834e-01, -9.7872e-02, -4.7578e-01,  ..., -2.2930e-01,
          -4.2696e-01, -4.2238e-01],
         ...,
         [ 1.1791e+01, -4.8277e-01, -1.0007e+00,  ..., -5.7318e-02,
          -2.7468e-01, -1.1807e+00],
         [ 1.9494e+01, -6.7727e-01, -8.1029e-01,  ..., -4.4515e-01,
          -1.0596e+00, -6.3562e-01],
         [-2.5480e-01, -8.5694e-01, -1.4104e+00,  ..., -1.6390e-01,
          -8.4178e-01, -1.9526e+00]]], device='cuda:0') shape:  torch.Size([1, 528, 70])
Set: tensor([[[-2.2553e+00,  4.1328e+01,  1.1850e+02,  ...,  1.3506e+02,
          -1.9209e+00, -1.3786e+00],
         [-1.0734e+00, -5.1473e-01, -2.7667e-02,  ...,  3.4049e+01,
          -1.3767e+00, -7.0935e-01],
         [-2.6572e-01, -3.8475e-01, -1.5886e-01,  ..., -3.6524e-01,
          -3.4666e-01

Testing:   1%|          | 219/20000 [00:14<32:22, 10.18it/s]

tensor([[[-6.3223e-01,  4.5021e+01, -1.6014e+00,  ...,  5.3142e+01,
           4.1642e+01, -4.3565e+00],
         [-2.9282e-01,  6.4697e-01, -5.0334e-01,  ..., -1.5961e-03,
           2.4907e+01, -1.8475e+00],
         [-3.1517e-01, -4.4141e-02, -4.1352e-01,  ..., -4.9763e-01,
          -5.2702e-01, -1.9113e-02],
         ...,
         [-2.1448e-01, -2.9582e-01, -6.8343e-01,  ..., -1.4899e-01,
          -2.2298e-01, -9.9112e-01],
         [-2.2223e-01, -3.4155e-01, -6.7878e-01,  ..., -1.5247e-01,
          -2.6461e-01, -9.9802e-01],
         [-2.2217e-01, -3.6119e-01, -7.0190e-01,  ..., -1.5324e-01,
          -2.6757e-01, -9.6225e-01]]], device='cuda:0') shape:  torch.Size([1, 874, 70])
Set: tensor([[[-2.5483e+00,  3.2499e+01,  8.9694e+01,  ...,  1.0967e+02,
          -1.5581e+00, -1.1571e+00],
         [-1.1263e+00, -4.7581e-01, -1.2641e-01,  ...,  2.2643e+01,
          -1.1913e+00, -5.3971e-01],
         [-6.3222e-02, -1.2715e-01,  3.0889e+00,  ..., -1.0221e+00,
          -4.7487e-01



tensor([[[-0.2534,  2.6260,  6.0113,  ..., -0.3893, -0.3709, 16.5926],
         [-0.0817, -0.0386,  5.7464,  ..., -0.1558, -0.2197,  9.8671],
         [-0.1098,  1.8890,  4.8036,  ..., -0.1845, -0.1961,  8.3832],
         ...,
         [30.9117, -0.4603, -0.6077,  ...,  4.8320, -0.3731, 14.8879],
         [-0.1063, -0.3203,  2.6082,  ..., 16.1733, -0.0524, -0.0707],
         [-0.0535,  6.0024,  1.0320,  ..., -0.1975, -0.1527,  2.3102]]],
       device='cuda:0') shape:  torch.Size([1, 707, 120])
Set: tensor([[[-7.6324e-01,  4.2566e+01, -1.5479e+00,  ...,  7.5891e+01,
           2.5005e+01, -4.0780e+00],
         [-3.5131e-01, -3.7117e-02, -5.9924e-01,  ...,  2.3553e+01,
          -8.2305e-02, -1.4988e+00],
         [-4.2703e-01,  8.0870e+00, -6.6684e-01,  ...,  2.0351e+01,
           1.8411e+01, -1.9626e+00],
         ...,
         [-3.3282e-01, -1.3223e+00,  5.2409e+01,  ..., -3.1957e-01,
          -1.2917e+00, -1.5459e+00],
         [-4.4378e-01, -3.8883e-01, -5.2337e-03,  ..., -1.056

Testing:   1%|          | 225/20000 [00:14<30:25, 10.83it/s]

tensor([[[-7.3897e-01,  4.7284e+01, -1.4143e+00,  ...,  6.2138e+01,
           2.4925e+01, -3.6288e+00],
         [-2.6512e-01,  2.8538e+00, -5.6168e-01,  ...,  1.0018e+00,
           2.8216e+01, -1.9215e+00],
         [-2.4097e-01,  9.8615e-01, -4.6677e-01,  ..., -4.4473e-02,
           2.5063e+01, -1.6686e+00],
         ...,
         [-2.5545e-01, -5.1787e-01, -3.0995e-01,  ..., -2.5622e-01,
           1.0106e+01, -9.7574e-01],
         [-1.5802e-01, -1.4357e-01,  2.2579e+00,  ..., -6.6479e-01,
          -2.7173e-01, -5.1361e-01],
         [-1.9752e-01,  1.3856e+01, -2.1523e-01,  ..., -3.4589e-01,
           1.7221e+01, -9.9898e-01]]], device='cuda:0') shape:  torch.Size([1, 886, 70])
Set: tensor([[[-2.1371e+00,  2.6328e+01,  1.1274e+02,  ...,  1.0921e+02,
          -1.7419e+00, -1.2343e+00],
         [-1.2294e+00, -4.5599e-01, -2.0376e-01,  ...,  1.7471e+01,
          -1.1182e+00, -4.3726e-01],
         [-1.1004e+00, -5.3407e-01, -2.6376e-01,  ...,  1.3983e+01,
          -1.0362e+00

Testing:   1%|          | 227/20000 [00:14<31:33, 10.44it/s]

tensor([[[-0.1005, -0.0411, -5.1993, -0.7929, -5.1993],
         [-0.3772, -0.0411, -3.0086, -0.0099, -3.0086],
         [-0.1005, -0.0411, -2.7913, -0.0461, -2.7913],
         ...,
         [-1.3703,  1.2393,  1.7737,  0.8624,  1.1648],
         [-0.3772, -0.0411, -1.9500,  0.4671, -1.9841],
         [-0.3772, -0.0411, -2.0468, -0.0560, -2.0336]]], device='cuda:0') shape:  torch.Size([1, 763, 5])
Set: tensor([[[ 1.0592e+01, -1.2865e-02,  5.6687e-01,  ..., -1.2241e-01,
          -1.4850e-02,  8.4836e+00],
         [ 6.0438e+00, -1.5849e-02, -4.5571e-03,  ..., -6.9014e-02,
          -2.1524e-02,  5.0548e+00],
         [ 5.4212e+00, -1.5463e-02, -3.2748e-03,  ..., -6.7515e-02,
          -1.3516e-02,  4.6682e+00],
         ...,
         [ 7.5026e-01,  1.4221e+00, -1.0913e-02,  ...,  2.7562e+00,
          -1.3916e-02, -2.8365e-02],
         [ 3.8195e+00, -2.0612e-02, -1.0209e-02,  ..., -4.4869e-02,
          -2.1409e-02,  3.5826e+00],
         [ 3.9848e+00, -8.9325e-03, -4.1612e-03,  ..., 



tensor([[[ 1.0240e+01,  5.2631e-01,  1.8391e+00,  ..., -1.3085e-01,
           1.5809e-01,  7.7018e+00],
         [ 6.1272e+00,  1.6751e-01,  9.4024e-01,  ..., -8.3537e-02,
           7.9524e-01,  4.3341e+00],
         [ 5.7557e+00,  2.5382e+00,  2.0665e+00,  ..., -8.1494e-02,
           1.0612e+00,  3.0071e+00],
         ...,
         [ 1.9984e+00, -3.0811e-03, -1.8498e-04,  ..., -3.0248e-02,
           7.7486e+00, -1.1897e-02],
         [ 1.9898e+00,  6.9583e-02,  1.6851e-01,  ..., -3.0440e-02,
           7.7889e+00, -1.3619e-02],
         [ 2.0417e+00, -1.8303e-02, -8.0314e-03,  ..., -2.8582e-02,
           7.7446e+00, -6.0492e-03]]], device='cuda:0') shape:  torch.Size([1, 828, 80])
Set: tensor([[[-2.6340e-01, -6.3946e-02,  1.0633e+01,  ..., -2.9092e-01,
          -3.8690e-01,  2.0024e+01],
         [-1.4229e-01,  2.3247e-01,  5.7624e+00,  ..., -1.8761e-01,
          -2.1632e-01,  9.0007e+00],
         [-1.3288e-01, -1.7139e-01,  1.1283e+01,  ..., -2.3343e-02,
          -2.3550e-01



tensor([[[-5.7670e-01,  2.4842e+01, -1.5142e+00,  ...,  7.2197e+01,
           1.7571e+01, -3.8606e+00],
         [-3.6888e-01, -2.4773e-02, -6.1250e-01,  ...,  1.2080e+01,
          -4.1686e-02, -1.2880e+00],
         [-2.6467e-01, -7.5716e-02, -5.1937e-01,  ...,  5.1730e+00,
           1.4611e+01, -1.6582e+00],
         ...,
         [-7.6718e-02, -2.7640e-01, -2.5478e-02,  ..., -6.8357e-01,
          -9.0972e-01,  7.8497e+00],
         [-1.5083e-01, -1.8951e-01, -2.6487e-01,  ..., -3.0703e-03,
           7.1660e+00, -1.2771e+00],
         [-3.6623e-01,  1.7086e+01, -5.2959e-01,  ...,  7.3632e+00,
          -1.2650e-01, -1.3490e+00]]], device='cuda:0') shape:  torch.Size([1, 994, 70])
Set: tensor([[[-2.2650e+00,  3.8229e+01,  1.1358e+02,  ...,  1.3129e+02,
          -1.8401e+00, -1.2846e+00],
         [-6.7317e-01, -4.3063e-01,  6.9177e+00,  ...,  3.4536e+01,
          -9.6282e-01, -6.4569e-01],
         [-1.0160e+00, -5.4911e-01, -1.5356e-02,  ...,  3.8764e+01,
          -1.3486e+00



tensor([[[-2.6252e+00,  5.0412e+01,  1.0264e+02,  ...,  1.2763e+02,
          -1.4333e+00, -1.1793e+00],
         [-5.8466e-01, -3.1799e-01, -2.6195e-03,  ...,  4.7397e+00,
          -5.2674e-01, -2.8348e-01],
         [-1.3424e+00, -4.6973e-01, -2.9825e-01,  ...,  1.3242e+01,
          -1.1565e+00, -4.1044e-01],
         ...,
         [-1.6414e-01, -2.7579e-01, -1.6113e-01,  ..., -1.9501e-01,
          -1.3231e-01, -1.8394e-01],
         [-9.1192e-01, -6.7610e-01, -7.3646e-01,  ..., -5.8604e-01,
          -6.8206e-01, -2.7920e-02],
         [-1.0695e+00, -6.5714e-01, -6.2275e-01,  ..., -2.3680e-01,
          -5.1636e-01, -5.8598e-02]]], device='cuda:0') shape:  torch.Size([1, 502, 50])
Set: tensor([[[ 3.1192e+01,  3.0713e+01,  3.8775e+01,  ..., -1.1575e+02,
           2.5501e+01,  3.9330e+01],
         [-4.3390e+00,  6.3357e+00,  1.5528e+00,  ...,  5.5264e+00,
           6.9213e+00,  3.6300e+00],
         [ 6.7142e+01, -6.0600e+01, -1.1851e+01,  ..., -3.5793e+01,
           1.8277e+01

Testing:   1%|          | 235/20000 [00:15<32:34, 10.11it/s]

tensor([[[ 163.3713,    4.5813,   -6.2554,  ..., -157.0938,   41.6544,
           163.1249],
         [ 158.9049, -100.9984,  -32.4511,  ...,  -59.3554,   49.1511,
           174.6962],
         [ 131.9342,  -68.5202,  -31.8087,  ...,  -36.2024,   35.1619,
           145.3678],
         ...,
         [ -18.7800,    0.5739,    1.8621,  ...,    2.2807,   -9.5221,
            -3.9806],
         [ -13.9867,   -1.5945,    3.3730,  ...,    3.3383,   -7.4031,
             0.5456],
         [ -12.4733,   -1.6427,    0.5330,  ...,    0.6034,   -4.5897,
             2.6638]]], device='cuda:0') shape:  torch.Size([1, 989, 8])
Set: tensor([[[ 0.2212, -1.0885, -5.1993, -0.8917, -5.1993],
         [ 0.2212, -0.9286, -3.0833, -1.0566, -3.0833],
         [ 0.2212, -0.9286, -2.8708, -1.1073, -2.8708],
         ...,
         [ 0.0423,  1.1097, -0.7013, -1.0085, -0.6753],
         [ 0.0423, -0.6014, -1.0236, -0.7079, -1.1362],
         [ 1.1097, -1.4943, -1.1917, -1.5182, -1.2456]]], device='cuda:0') sha



tensor([[[-0.2709,  4.0130,  4.5064,  ..., -0.4015, -0.3605, 13.2240],
         [-0.1172,  3.5319,  2.6848,  ..., -0.1947, -0.1756,  4.0923],
         [-0.1089,  4.0876,  1.8260,  ..., -0.1776, -0.1516,  2.6674],
         ...,
         [-0.0527, -0.1387, -0.1751,  ...,  4.1888, -0.0460,  3.6253],
         [-0.0555, -0.0153, -0.2105,  ..., -0.0288, -0.0640,  3.7498],
         [-0.0467, -0.0298, -0.2263,  ..., -0.0156, -0.1098,  0.6777]]],
       device='cuda:0') shape:  torch.Size([1, 673, 120])
Set: tensor([[[-6.9441e-01,  3.8789e+01, -1.5358e+00,  ...,  6.4021e+01,
           2.9027e+01, -4.0754e+00],
         [-3.2521e-01, -3.5325e-02, -5.7899e-01,  ...,  8.5190e+00,
           1.8175e+01, -1.7990e+00],
         [-2.9286e-01, -5.5502e-02, -4.9358e-01,  ...,  1.8634e+00,
           2.0099e+01, -1.6585e+00],
         ...,
         [-7.3882e-02, -4.5564e-01,  2.3093e+00,  ..., -5.7341e-01,
          -3.5598e-01, -2.9503e-01],
         [-1.0872e-01, -4.6496e-01, -5.1395e-02,  ..., -3.242



tensor([[[-1.9048e-01,  1.8543e+00,  7.3728e+00,  ..., -3.6566e-01,
          -3.7790e-01,  1.7170e+01],
         [-9.9856e-02,  4.6506e+00,  2.8903e+00,  ..., -2.2475e-01,
          -1.9672e-01,  4.6630e+00],
         [-5.0852e-02, -4.5205e-03,  3.1707e+00,  ..., -1.7350e-01,
          -1.9479e-01,  6.4252e+00],
         ...,
         [-8.0760e-02,  1.1288e+00, -2.0449e-02,  ..., -9.1223e-02,
           1.5469e+00, -1.0279e-01],
         [-7.3737e-02,  1.1085e+00, -2.1519e-02,  ..., -8.3361e-02,
           2.1451e+00, -9.9120e-02],
         [-6.8002e-02,  3.8978e+00, -6.3096e-03,  ..., -9.9406e-02,
          -5.7191e-03, -7.3808e-02]]], device='cuda:0') shape:  torch.Size([1, 874, 120])
Set: tensor([[[-6.7381e-01,  2.7966e+01, -1.4688e+00,  ...,  7.7171e+01,
           1.0794e+01, -3.6023e+00],
         [-2.8463e-01, -5.9429e-02, -6.2987e-01,  ...,  1.1363e+01,
           1.8920e+01, -1.9805e+00],
         [-2.3848e-01, -1.5288e-01, -3.9697e-01,  ...,  1.3050e+01,
           4.6720e-0



tensor([[[ 0.1725, -0.3607, -5.1993, -0.4276, -5.1993],
         [ 0.1725, -0.3607, -3.0558, -0.4775, -3.0558],
         [ 0.1725, -0.3607, -2.8416, -0.8586, -2.8416],
         ...,
         [-0.2705, -0.0478, -1.6965,  0.5125, -1.7085],
         [-0.2705, -0.3607, -1.7460,  0.2603, -1.7460],
         [-0.2705, -0.3607, -1.6676, -1.3216, -1.6965]]], device='cuda:0') shape:  torch.Size([1, 892, 5])
Set: tensor([[[ 9.5967e+00, -2.7614e-02,  4.3230e-02,  ..., -1.2078e-01,
          -1.6346e-02,  9.0595e+00],
         [ 5.0565e+00, -1.3598e-02,  6.9881e-02,  ..., -7.3875e-02,
          -1.3425e-02,  4.9912e+00],
         [ 4.6356e+00, -3.6056e-03,  5.0550e-01,  ..., -7.0218e-02,
          -1.2222e-02,  4.2310e+00],
         ...,
         [ 3.1628e+00, -2.0692e-02, -1.0439e-02,  ..., -4.0088e-02,
          -1.8288e-02,  3.1337e+00],
         [ 2.4795e+00, -1.9793e-02, -9.4190e-03,  ..., -3.8035e-02,
          -2.5730e-02,  3.1396e+00],
         [ 2.5038e+00,  1.6629e+00,  8.7734e-01,  ..., 



tensor([[[ 0.1455, -0.1935, -5.1993, -0.5150, -5.1993],
         [ 0.1455, -0.1935, -3.0312, -0.4907, -3.0312],
         [ 0.1455, -0.1935, -2.6305, -0.3701, -2.8154],
         ...,
         [ 0.8846, -0.5325, -1.5033,  0.6716, -1.4025],
         [ 0.8846, -0.5325, -1.4755,  0.6376, -1.3551],
         [ 0.9838, -0.5325, -1.4755, -0.1579, -1.3325]]], device='cuda:0') shape:  torch.Size([1, 822, 5])
Set: tensor([[[ 1.0040e+01, -2.2995e-02,  2.3512e-01,  ..., -1.2267e-01,
          -1.2616e-02,  8.8779e+00],
         [ 5.4415e+00, -1.0526e-02,  1.7608e-01,  ..., -7.5019e-02,
          -9.8430e-03,  4.8331e+00],
         [ 4.9160e+00, -1.2088e-02,  2.6635e-02,  ..., -6.8029e-02,
          -1.0209e-02,  4.5490e+00],
         ...,
         [ 6.2286e-01, -3.6934e-02, -1.1469e-02,  ..., -4.1605e-02,
           1.3622e-01,  3.2010e+00],
         [ 5.3146e-01, -3.5861e-02, -1.1076e-02,  ..., -4.0863e-02,
           1.5444e-01,  3.0793e+00],
         [ 5.0090e-01, -1.8165e-02, -1.5911e-03,  ..., 



tensor([[[ 0.4833, -0.1328, -5.1993, -1.0442, -5.1993],
         [ 0.4833,  0.1540, -3.0141, -0.1655, -3.0141],
         [ 0.4833,  0.1540, -2.6636, -1.2641, -2.7972],
         ...,
         [-0.9563,  1.4646,  2.1417, -1.8673,  1.2429],
         [-0.7558,  1.2823,  1.9578, -0.6543,  1.1758],
         [ 0.4833, -0.1328, -2.0037, -3.0141, -2.0159]]], device='cuda:0') shape:  torch.Size([1, 777, 5])
Set: tensor([[[ 1.0035e+01, -1.2139e-02,  9.8977e-01,  ..., -1.2868e-01,
          -4.6566e-04,  8.4146e+00],
         [ 6.0484e+00, -1.4831e-02,  1.2045e-01,  ..., -8.1704e-02,
           7.6179e-01,  4.9825e+00],
         [ 5.6431e+00,  1.1498e+00,  1.3767e+00,  ..., -7.8566e-02,
           1.0308e+00,  3.5479e+00],
         ...,
         [ 1.0549e+00,  7.7655e+00,  2.3088e+00,  ...,  1.7425e+00,
           9.4647e-01, -5.5798e-02],
         [ 5.4840e-01,  4.5635e+00,  8.7274e-01,  ...,  1.7917e+00,
           7.5170e-01, -4.1773e-02],
         [ 3.4542e+00,  5.1983e+00,  3.2114e+00,  ..., 



tensor([[[-1.9922e+00,  4.0108e+01,  1.3142e+02,  ...,  1.3449e+02,
          -1.9415e+00, -1.3541e+00],
         [-8.2289e-01, -3.2197e-01,  1.3337e+01,  ...,  4.0614e+01,
          -1.1106e+00, -5.5399e-01],
         [-9.0140e-01, -3.2194e-01,  3.5861e+00,  ...,  3.6864e+01,
          -1.0678e+00, -4.4552e-01],
         ...,
         [ 1.2811e+01, -9.8525e-01, -1.0854e+00,  ...,  2.7649e+01,
           3.8494e+01, -1.4798e-01],
         [ 9.6998e+00, -1.0098e+00, -1.0167e+00,  ...,  4.0894e+00,
           2.4054e+01, -1.3306e-01],
         [ 4.1578e+00, -1.1324e+00, -7.9798e-01,  ...,  3.2342e+01,
           7.8747e+01, -3.1109e-01]]], device='cuda:0') shape:  torch.Size([1, 822, 50])
Set: tensor([[[ 199.2489,   -1.3098,  -33.9781,  ..., -149.6421,   52.4280,
           208.9193],
         [ 102.4255,  -41.4361,  -41.9424,  ...,  -54.6170,   27.6710,
           118.2337],
         [  76.2249,  -45.4576,  -30.7809,  ...,  -46.1923,   25.3881,
            88.5629],
         ...,
      



tensor([[[-2.7564e+00,  4.4196e+01,  1.1796e+02,  ...,  1.2938e+02,
          -1.8593e+00, -1.3689e+00],
         [-1.6203e+00, -4.2585e-01,  6.3315e-01,  ...,  2.9180e+01,
          -1.7495e+00, -7.3690e-01],
         [-1.3030e+00, -4.7998e-01, -6.8758e-03,  ...,  2.5520e+01,
          -1.5222e+00, -6.7699e-01],
         ...,
         [-1.2116e+00, -2.7943e-01, -6.2180e-01,  ...,  1.8114e+01,
          -5.7736e-01, -3.3792e-01],
         [-1.2111e+00, -5.0318e-01, -1.8805e-01,  ...,  1.8447e+01,
          -1.2773e+00, -5.2331e-01],
         [-3.9745e-01, -4.2499e-01, -4.0831e-01,  ..., -4.7576e-01,
          -3.6912e-01, -1.8842e-01]]], device='cuda:0') shape:  torch.Size([1, 867, 50])
Set: tensor([[[  80.3698,   21.5391,   25.1579,  ..., -136.6589,   18.5429,
            80.6448],
         [ 142.7363, -104.8754,  -27.5290,  ...,  -57.4478,   39.0583,
           157.7456],
         [ 132.9174,  -80.8944,  -24.5529,  ...,  -37.6665,   31.3481,
           140.3229],
         ...,
      



tensor([[[-1.4367e+00,  2.4349e+01, -7.4926e-02,  ...,  4.6394e+01,
          -3.1803e-01, -3.2316e-01],
         [-2.4801e+00,  1.1939e+02,  9.3795e+01,  ...,  1.9770e+02,
          -2.2350e-01, -1.1319e+00],
         [-1.4271e+00,  4.1765e+01, -9.2016e-02,  ...,  7.0734e+01,
          -8.4835e-02, -2.4082e-01],
         ...,
         [-3.1179e-01, -5.8546e-01, -7.4236e-01,  ..., -3.6684e-01,
          -1.9927e-01, -1.3604e-01],
         [-3.2210e-01, -4.7162e-01, -5.4306e-01,  ..., -5.0620e-01,
          -2.9749e-01, -4.6125e-02],
         [-3.1189e-01, -4.8057e-01, -5.4452e-01,  ..., -5.5192e-01,
          -3.4539e-01, -5.4532e-02]]], device='cuda:0') shape:  torch.Size([1, 847, 50])
Set: tensor([[[ -6.2657,  11.1705,  11.5345,  ..., -25.6862,   7.9079, -11.9670],
         [ -2.9976,  99.4399,  -6.1395,  ..., -96.1308,  98.7652,  29.3614],
         [ -6.6415,  32.5446,   5.6236,  ..., -23.8666,  28.8285, -24.6576],
         ...,
         [ -7.7105,  -2.0160,  -0.2163,  ...,  -1.1256



tensor([[[-6.8827e-01,  4.0813e+01, -1.5663e+00,  ...,  6.6752e+01,
           3.0579e+01, -4.1388e+00],
         [-2.9627e-01, -1.9902e-02, -6.2217e-01,  ...,  8.3106e+00,
           2.3661e+01, -2.0122e+00],
         [-2.7993e-01, -3.0538e-02, -5.1990e-01,  ...,  2.4529e+00,
           2.0242e+01, -1.7265e+00],
         ...,
         [-7.1200e-01, -1.0698e+00, -6.4345e-01,  ..., -2.4401e-01,
          -3.3220e-02, -1.8413e+00],
         [-2.9956e-01, -5.2688e-02, -4.7495e-01,  ..., -3.1150e-01,
          -4.5704e-01, -2.9980e-01],
         [-3.9435e-01,  1.2787e+01, -5.3956e-01,  ..., -6.3046e-01,
          -8.8408e-01, -8.1112e-02]]], device='cuda:0') shape:  torch.Size([1, 961, 70])
Set: tensor([[[-2.5194e+00,  3.3422e+01,  1.1673e+02,  ...,  1.2009e+02,
          -1.9278e+00, -1.2882e+00],
         [-1.2337e+00, -4.4595e-01, -4.9785e-02,  ...,  3.2781e+01,
          -1.4420e+00, -6.8391e-01],
         [-1.0828e+00, -5.4617e-01, -8.1569e-02,  ...,  2.8954e+01,
          -1.3808e+00

Testing:   1%|▏         | 257/20000 [00:17<31:18, 10.51it/s]

tensor([[[-2.2435e+00,  4.0266e+01,  1.3995e+02,  ...,  1.3330e+02,
          -2.0766e+00, -1.4205e+00],
         [-8.2265e-01, -3.5948e-01,  1.0958e+01,  ...,  3.2996e+01,
          -1.1289e+00, -6.4424e-01],
         [-8.8246e-01, -4.1082e-01,  1.2334e+01,  ...,  4.3949e+01,
          -1.3354e+00, -7.9806e-01],
         ...,
         [-8.1948e-01, -3.9673e-01, -6.3780e-01,  ..., -1.1793e+00,
          -9.3594e-01, -1.2070e-01],
         [-8.0309e-01, -3.8069e-01, -6.2753e-01,  ..., -1.1912e+00,
          -9.1628e-01, -1.7305e-01],
         [-7.8350e-01, -5.4096e-01, -6.6253e-01,  ..., -5.7637e-01,
          -7.8524e-01, -6.5990e-02]]], device='cuda:0') shape:  torch.Size([1, 965, 50])
Set: tensor([[[ 198.0292,    0.4408,  -18.2006,  ..., -151.1326,   43.1455,
           196.2929],
         [ 114.3979,  -41.5692,  -40.4347,  ...,  -36.7072,   25.6795,
           135.8424],
         [ 163.6280,  -65.9547,  -41.0946,  ...,  -47.8750,   39.9175,
           174.1117],
         ...,
      



tensor([[[ 0.0944, -0.0251, -5.1993, -1.3708, -5.1993],
         [ 0.0944, -0.0251, -2.7335, -1.7642, -3.0226],
         [ 0.0944, -0.0251, -2.2809, -0.9775, -2.8062],
         ...,
         [-1.1398, -1.0620,  1.1428,  0.6884,  1.1015],
         [-0.9750, -1.0620,  1.1428, -1.9379,  1.1611],
         [-0.9750, -1.2053,  1.2151,  0.8947,  1.3018]]], device='cuda:0') shape:  torch.Size([1, 799, 5])
Set: tensor([[[ 1.0572e+01, -6.3701e-04,  1.3072e+00,  ..., -1.2646e-01,
          -7.6208e-03,  7.9696e+00],
         [ 5.9020e+00,  2.1266e+00,  1.7172e+00,  ..., -7.6765e-02,
          -4.3804e-03,  3.5274e+00],
         [ 5.3012e+00,  4.4612e-01,  7.9883e-01,  ..., -6.7403e-02,
          -6.4600e-03,  3.8707e+00],
         ...,
         [-4.8558e-02, -1.7457e-02, -2.1830e-02,  ...,  4.4258e+00,
          -6.5132e-02, -1.5853e-02],
         [-4.8324e-02,  4.1895e+00,  8.9538e-01,  ...,  3.5858e+00,
          -5.4024e-02, -4.1360e-02],
         [-5.7183e-02, -2.4065e-02, -2.4516e-02,  ..., 

Testing:   1%|▏         | 261/20000 [00:18<32:39, 10.07it/s]

tensor([[[ 9.3537e+00, -9.0795e-03,  5.0508e-01,  ..., -1.1003e-01,
          -4.4618e-02,  8.3266e+00],
         [ 5.0874e+00, -1.3100e-02, -4.2877e-03,  ..., -6.1351e-02,
          -3.8699e-02,  4.8946e+00],
         [ 4.6120e+00, -1.4745e-02, -5.8107e-03,  ..., -5.6191e-02,
          -3.8727e-02,  4.6057e+00],
         ...,
         [ 1.7649e+00,  2.8098e+00,  1.3932e+00,  ..., -3.4359e-02,
          -1.6593e-03, -6.1637e-04],
         [ 2.5999e+00,  1.9241e+00,  9.1742e-01,  ..., -4.0762e-02,
          -3.9498e-02,  1.9488e+00],
         [ 2.2066e+00, -7.6726e-03, -4.5356e-03,  ..., -3.5550e-02,
          -4.3633e-02,  2.9331e+00]]], device='cuda:0') shape:  torch.Size([1, 703, 80])
Set: tensor([[[-2.0748e-01,  1.3808e+00,  5.0729e+00,  ..., -3.5541e-01,
          -3.4689e-01,  1.1489e+01],
         [-8.4418e-02,  4.5013e+00, -1.2938e-03,  ..., -2.1059e-01,
          -1.7681e-01, -1.8675e-03],
         [-7.4242e-02,  4.6220e+00, -6.6457e-03,  ..., -1.9459e-01,
          -1.5859e-01



tensor([[[-0.3054, -0.4347, -5.1993, -0.2772, -5.1993],
         [-0.3054, -0.4347, -2.5738, -0.6387, -2.9809],
         [-0.3054, -0.4347, -2.2736, -0.8273, -2.7619],
         ...,
         [-0.5802, -1.0451,  0.5973,  0.7392,  0.5112],
         [-0.5802, -0.9560,  0.5973,  1.0389,  0.5360],
         [ 0.6745, -1.3031,  0.4867,  1.5764,  0.5611]]], device='cuda:0') shape:  torch.Size([1, 697, 5])
Set: tensor([[[ 9.6768e+00, -2.9144e-02, -3.3157e-03,  ..., -1.1397e-01,
          -3.1831e-02,  9.1475e+00],
         [ 4.8787e+00, -7.7840e-03,  3.6335e-02,  ..., -6.2047e-02,
          -2.8840e-02,  4.6620e+00],
         [ 4.4055e+00, -2.2042e-03,  2.4702e-01,  ..., -5.6912e-02,
          -2.8253e-02,  4.0769e+00],
         ...,
         [-3.9204e-02, -2.5436e-02, -2.0457e-02,  ...,  2.5379e+00,
          -5.0250e-02, -3.2368e-03],
         [-3.7668e-02, -3.0797e-02, -2.3361e-02,  ...,  2.5435e+00,
          -4.8610e-02, -1.3870e-03],
         [-5.4486e-02, -5.5919e-02, -2.7531e-02,  ..., 

Testing:   1%|▏         | 265/20000 [00:18<34:11,  9.62it/s]

tensor([[[ 2.2010e+02,  1.1097e+01, -3.2923e+01,  ..., -1.2660e+02,
           3.6476e+01,  2.3100e+02],
         [ 1.9113e+02, -7.8211e+01, -4.7136e+01,  ..., -5.3709e+01,
           4.7097e+01,  2.0573e+02],
         [ 1.8533e+02, -8.4801e+01, -4.2881e+01,  ..., -4.9089e+01,
           4.7770e+01,  1.9751e+02],
         ...,
         [ 1.3657e+01, -4.7670e+00, -2.4482e+00,  ...,  8.6385e+00,
           2.9282e+00,  2.4483e+01],
         [-1.2746e+01,  7.1136e-01,  1.2792e+00,  ...,  1.4374e+00,
          -6.0490e+00, -4.8311e+00],
         [ 3.2741e+00,  2.8039e+00, -1.7638e+00,  ...,  7.7106e+00,
           2.2944e-01,  1.1678e+01]]], device='cuda:0') shape:  torch.Size([1, 873, 8])
Set: tensor([[[ 0.4348,  0.1561, -5.1993, -0.6616, -5.1993],
         [ 0.4348,  0.3538, -2.9688, -1.0915, -2.9688],
         [ 0.4348,  0.3538, -2.7489, -0.9408, -2.7489],
         ...,
         [ 1.5822, -0.2171, -0.4328,  0.9886,  0.1259],
         [ 1.5822, -0.0244, -0.4328,  0.1032,  0.1561],
      



Set: tensor([[[ 0.2842, -0.4909, -5.1993, -0.2445, -5.1993],
         [ 0.2842, -0.4909, -3.0301, -1.0999, -3.0301],
         [ 0.2842, -0.4909, -2.8142, -0.9259, -2.8142],
         ...,
         [-5.1993, -0.7414,  0.8247,  0.3666,  0.7886],
         [-1.5354,  0.4737,  0.1725, -1.7756,  0.1693],
         [ 0.0767, -0.2540, -2.8142, -0.7515, -2.6812]]], device='cuda:0') shape:  torch.Size([1, 819, 5])
Set: tensor([[[ 9.1836e+00, -3.4469e-02, -2.0999e-03,  ..., -1.2004e-01,
          -1.7063e-02,  9.3265e+00],
         [ 4.6593e+00, -1.8963e-03,  7.4359e-01,  ..., -7.4785e-02,
          -1.2156e-02,  4.4525e+00],
         [ 4.1864e+00, -4.5751e-03,  5.4017e-01,  ..., -6.9570e-02,
          -1.2296e-02,  4.2112e+00],
         ...,
         [-9.3513e-03,  1.8401e+00, -2.9297e-02,  ...,  7.9184e+00,
          -1.6857e-01, -2.2840e-02],
         [ 1.4513e+00,  5.8555e+00,  1.4696e+00,  ...,  3.0664e-01,
          -3.1814e-02, -3.0809e-02],
         [ 4.6516e+00, -2.8597e-03,  4.1914e-01,  



tensor([[[ 0.0548,  0.2360, -5.1993, -0.9825, -5.1993],
         [ 0.0548,  0.2360, -3.0406, -0.5869, -3.0406],
         [ 0.0548,  0.2360, -2.8254, -1.6845, -2.8254],
         ...,
         [-1.1817,  0.7755, -1.1049, -0.7856, -0.8733],
         [ 0.4645, -1.1214, -0.9242, -1.4524, -1.1159],
         [ 0.0548,  0.2360, -2.1459, -2.6929, -2.1610]]], device='cuda:0') shape:  torch.Size([1, 848, 5])
Set: tensor([[[ 1.1218e+01, -5.4119e-03,  1.0039e+00,  ..., -1.2800e-01,
          -2.9890e-03,  8.1810e+00],
         [ 6.6062e+00, -1.4822e-03,  5.1753e-01,  ..., -7.9541e-02,
          -1.1290e-03,  4.5030e+00],
         [ 6.2459e+00,  2.4886e+00,  1.7780e+00,  ..., -7.7827e-02,
           1.8132e-01,  3.0672e+00],
         ...,
         [ 4.1984e+00,  3.2012e+00,  6.5031e-01,  ..., -2.7135e-02,
          -1.7655e-02, -3.0373e-03],
         [-1.1238e-02,  7.0170e-01,  7.8975e-01,  ..., -2.6359e-02,
          -2.0404e-02,  9.1864e-01],
         [ 4.9208e+00,  5.1907e+00,  2.9287e+00,  ..., 



tensor([[[ 1.0519e+01, -1.4119e-02,  6.0008e-01,  ..., -1.2418e-01,
          -1.0579e-02,  8.5275e+00],
         [ 6.0228e+00, -8.5609e-03,  6.1922e-02,  ..., -7.3299e-02,
          -1.4375e-02,  4.7880e+00],
         [ 5.5399e+00, -1.4073e-02, -2.8412e-03,  ..., -6.7753e-02,
          -1.4816e-02,  4.6638e+00],
         ...,
         [ 3.5319e+00, -2.1693e-02, -8.0503e-03,  ..., -4.8419e-02,
          -9.5621e-03,  3.5955e+00],
         [ 4.8172e+00,  5.6413e-01,  7.4326e-01,  ..., -6.4987e-02,
          -6.8075e-03,  3.3395e+00],
         [ 3.8018e+00, -6.6804e-03,  1.0448e-01,  ..., -5.7576e-02,
          -1.3943e-02,  3.5538e+00]]], device='cuda:0') shape:  torch.Size([1, 836, 80])
Set: tensor([[[-2.1095e-01,  3.7610e+00,  6.6497e+00,  ..., -3.9192e-01,
          -3.8229e-01,  1.6785e+01],
         [-9.4783e-02,  5.2844e+00,  2.6249e+00,  ..., -2.3113e-01,
          -1.9755e-01,  4.4449e+00],
         [-8.5337e-02,  6.1085e+00,  1.9980e+00,  ..., -2.2411e-01,
          -1.7675e-01

Testing:   1%|▏         | 273/20000 [00:19<34:08,  9.63it/s]

tensor([[[ 1.0255e+01,  3.9613e-01,  1.6142e+00,  ..., -1.2700e-01,
          -7.3339e-03,  7.7698e+00],
         [ 5.7760e+00, -3.5238e-03,  4.3190e-01,  ..., -7.5340e-02,
          -1.2410e-02,  4.6195e+00],
         [ 5.3182e+00,  2.6574e+00,  2.0104e+00,  ..., -7.6579e-02,
          -3.3672e-03,  3.0209e+00],
         ...,
         [-1.3061e-01, -6.2400e-02, -2.5702e-02,  ...,  5.7616e+00,
          -1.8821e-02, -1.7385e-02],
         [-1.4056e-01, -7.7828e-02, -3.3055e-02,  ...,  6.3417e+00,
          -2.9012e-02, -1.1721e-02],
         [ 2.3327e+00,  1.8052e+00,  1.1459e+00,  ..., -4.3967e-02,
          -3.4503e-03,  1.1401e+00]]], device='cuda:0') shape:  torch.Size([1, 931, 80])
Set: tensor([[[-2.3943e-01, -4.1705e-02,  9.9891e+00,  ..., -3.0765e-01,
          -3.8191e-01,  2.0106e+01],
         [-1.2424e-01,  3.1602e+00,  3.5191e+00,  ..., -2.0649e-01,
          -1.9517e-01,  5.8541e+00],
         [-1.3509e-01, -1.8062e-01,  1.1109e+01,  ..., -1.1605e-02,
          -2.2451e-01

Testing:   1%|▏         | 275/20000 [00:19<34:54,  9.42it/s]

tensor([[[ 1.0385e+01, -3.1823e-02, -3.6864e-03,  ..., -1.1987e-01,
          -1.7467e-02,  9.2591e+00],
         [ 6.5386e+00, -3.5529e-03,  3.5804e-01,  ..., -7.7857e-02,
          -7.7045e-03,  4.6860e+00],
         [ 6.1049e+00,  2.5197e-01,  5.9637e-01,  ..., -7.3766e-02,
          -6.9139e-03,  4.0883e+00],
         ...,
         [-3.0216e-02, -1.3571e-02, -2.4360e-02,  ...,  4.1250e+00,
          -7.8749e-02, -1.2630e-02],
         [-3.0056e-02, -1.9667e-02, -2.8711e-02,  ...,  4.4648e+00,
          -8.4464e-02, -1.0595e-02],
         [-3.0895e-02, -2.9079e-02, -3.4059e-02,  ...,  4.7245e+00,
          -8.7357e-02, -7.4750e-03]]], device='cuda:0') shape:  torch.Size([1, 1025, 80])
Set: tensor([[[-2.0487e-01,  6.6704e+00,  5.0919e+00,  ..., -4.2866e-01,
          -3.7066e-01,  1.2910e+01],
         [-9.5146e-02,  3.9128e+00,  3.5386e+00,  ..., -2.2965e-01,
          -2.1798e-01,  7.4386e+00],
         [-8.6534e-02,  1.9286e+00,  3.8885e+00,  ..., -1.9158e-01,
          -1.9989e-0



tensor([[[ 0.1618, -0.2825, -5.1993, -0.6161, -5.1993],
         [ 0.1618,  0.0379, -3.0673, -0.8463, -3.0673],
         [ 0.1618,  0.0379, -2.8538, -0.1851, -2.8538],
         ...,
         [-0.1508, -0.2825, -1.9535, -0.9059, -1.9627],
         [-0.1508, -0.4762, -1.9535, -0.0054, -1.9269],
         [-0.3740, -0.4762, -1.9015, -0.0487, -1.8933]]], device='cuda:0') shape:  torch.Size([1, 927, 5])
Set: tensor([[[ 9.8145e+00, -2.2100e-02,  3.0351e-01,  ..., -1.2209e-01,
          -1.4196e-02,  8.8367e+00],
         [ 6.1221e+00,  7.1926e-02,  7.3177e-01,  ..., -7.9716e-02,
          -2.6322e-03,  4.4433e+00],
         [ 5.6117e+00, -1.3103e-02, -3.2410e-04,  ..., -7.3218e-02,
          -3.9564e-03,  4.6644e+00],
         ...,
         [ 3.1615e+00,  5.9280e-01,  4.8891e-01,  ..., -4.8181e-02,
          -1.7912e-02,  2.4304e+00],
         [ 2.5297e+00, -1.7283e-02, -6.6400e-03,  ..., -4.3013e-02,
          -2.4949e-02,  3.3155e+00],
         [ 2.5898e+00, -1.4694e-02, -6.8768e-03,  ..., 



tensor([[[ 9.1754e+00, -3.2350e-02,  2.6894e-02,  ..., -1.2330e-01,
          -9.7727e-03,  9.2257e+00],
         [ 4.6657e+00, -4.7996e-03,  7.3243e-01,  ..., -7.7799e-02,
          -5.4095e-03,  4.5842e+00],
         [ 4.1721e+00, -1.3441e-02,  2.2890e-01,  ..., -7.1892e-02,
          -6.1829e-03,  4.5898e+00],
         ...,
         [-1.4756e-02,  4.3668e-01, -6.2479e-03,  ...,  1.6343e+00,
           4.1822e+00, -3.7082e-02],
         [-1.3291e-02,  1.6817e+00, -2.5127e-04,  ...,  1.4707e+00,
           4.5206e+00, -4.3041e-02],
         [-1.5709e-02,  2.1898e+00,  3.8270e-01,  ...,  1.1600e+00,
           5.3382e+00, -4.6276e-02]]], device='cuda:0') shape:  torch.Size([1, 863, 80])
Set: tensor([[[-0.3065,  5.5988,  3.1553,  ..., -0.4217, -0.3588, 11.0813],
         [-0.2162, -0.0164,  3.3120,  ..., -0.1632, -0.1884,  3.2875],
         [-0.2017,  2.7592,  0.8081,  ..., -0.1951, -0.1652,  1.0359],
         ...,
         [-0.1648, -0.0653, -0.0506,  ..., -0.1233, -0.1078,  4.7000],
 

Testing:   1%|▏         | 281/20000 [00:20<34:23,  9.56it/s]

tensor([[[ 1.0041e+01, -8.5997e-03,  1.0047e+00,  ..., -1.2514e-01,
          -9.3847e-03,  8.2801e+00],
         [ 5.4490e+00,  6.7983e-01,  1.0927e+00,  ..., -7.7803e-02,
          -6.3001e-03,  4.1117e+00],
         [ 5.7688e+00,  6.5419e-01,  9.6598e-01,  ..., -7.6065e-02,
           1.4756e-01,  3.7888e+00],
         ...,
         [ 1.3628e+00,  1.8300e+00, -5.4266e-03,  ...,  8.9361e-01,
          -7.0601e-03, -2.3839e-02],
         [ 1.6017e+00,  2.8046e+00, -1.2353e-03,  ...,  8.1600e-01,
          -5.4181e-03, -2.8508e-02],
         [ 1.5096e+00,  1.3889e+00, -9.0715e-03,  ...,  1.1507e+00,
          -9.2522e-03, -2.3415e-02]]], device='cuda:0') shape:  torch.Size([1, 817, 80])
Set: tensor([[[-2.4642e-01,  6.4146e-01,  6.9446e+00,  ..., -3.5900e-01,
          -3.6863e-01,  1.7156e+01],
         [-1.5974e-01, -4.2676e-02,  6.1911e+00,  ..., -1.3780e-01,
          -2.0104e-01,  6.4374e+00],
         [-1.1098e-01, -1.1550e-02,  5.9230e+00,  ..., -1.5358e-01,
          -1.9846e-01

Testing:   1%|▏         | 283/20000 [00:20<34:41,  9.47it/s]

tensor([[[ 1.0752e+01, -2.7516e-03,  1.3272e+00,  ..., -1.3026e-01,
           1.7145e-01,  8.0502e+00],
         [ 6.2205e+00, -1.5573e-03,  4.9780e-01,  ..., -7.6934e-02,
          -5.1940e-03,  4.4410e+00],
         [ 5.8220e+00,  1.6188e+00,  1.3230e+00,  ..., -7.4130e-02,
          -3.1653e-03,  3.3562e+00],
         ...,
         [-4.8860e-03,  5.2676e+00,  2.1905e+00,  ..., -2.9137e-03,
           3.4183e+00, -4.1113e-02],
         [-5.8124e-03,  4.5585e+00,  1.9184e+00,  ..., -3.7166e-03,
           3.7103e+00, -3.8191e-02],
         [-1.3974e-03,  1.5588e+00,  1.0702e+00,  ..., -2.8244e-02,
          -1.4605e-02,  5.7740e-01]]], device='cuda:0') shape:  torch.Size([1, 748, 80])
Set: tensor([[[-2.2632e-01,  8.6107e-02,  8.8088e+00,  ..., -3.5211e-01,
          -3.9175e-01,  2.0387e+01],
         [-1.0183e-01,  3.5953e+00,  3.7714e+00,  ..., -2.1026e-01,
          -2.0610e-01,  6.5725e+00],
         [-9.5238e-02, -6.8890e-02,  7.2025e+00,  ..., -1.0176e-01,
          -2.0542e-01



tensor([[[ 7.4730e+00, -2.6188e-02, -4.7409e-03,  ..., -9.5266e-02,
          -7.6452e-02,  9.0019e+00],
         [ 6.5295e+00,  3.6541e+00,  1.0404e+00,  ..., -4.8225e-02,
          -5.7886e-02,  2.3711e+00],
         [ 5.8222e+00, -1.0840e-02, -1.4249e-02,  ..., -3.6976e-02,
          -6.2750e-02,  3.9213e+00],
         ...,
         [ 1.2694e+01,  5.1169e+00,  3.0121e+00,  ..., -1.0609e-01,
           1.8377e+01, -1.4330e-02],
         [ 6.3518e+00,  8.2079e+00,  3.4717e+00,  ..., -4.8180e-02,
           4.0301e+00, -2.6754e-02],
         [ 5.9736e+00,  5.8400e+00,  2.6653e+00,  ..., -5.2761e-02,
           6.0381e+00, -1.9066e-02]]], device='cuda:0') shape:  torch.Size([1, 231, 80])
Set: tensor([[[-2.6891e-01,  2.4110e+00,  2.2365e-01,  ..., -3.6450e-01,
          -3.2508e-01,  2.6606e+00],
         [ 1.5846e+01, -1.5849e-01, -7.3142e-02,  ..., -1.0045e-01,
          -2.3512e-01,  5.9767e+00],
         [ 1.7048e+01,  2.9067e-01, -1.2892e-01,  ..., -1.8777e-01,
          -1.9532e-01



tensor([[[-2.1664e+00,  5.2706e+01,  1.3784e+02,  ...,  1.4648e+02,
          -1.9339e+00, -1.4308e+00],
         [-8.5046e-01, -3.3422e-01,  1.8388e+01,  ...,  3.5298e+01,
          -1.2060e+00, -6.5184e-01],
         [-4.7193e-01, -3.9720e-01, -1.1615e-01,  ..., -1.6628e-02,
          -4.5599e-01, -3.3145e-01],
         ...,
         [-7.4841e-01, -6.1191e-01, -5.7665e-01,  ..., -1.4299e-01,
          -5.4199e-01, -7.1109e-02],
         [-3.0039e-01, -1.4280e+00, -8.6337e-01,  ..., -7.3346e-01,
           2.8031e+00, -6.5988e-01],
         [-1.4025e-01, -6.7836e-02,  1.8420e+00,  ..., -6.0678e-01,
          -3.5116e-01, -3.8051e-01]]], device='cuda:0') shape:  torch.Size([1, 833, 50])
Set: tensor([[[ 1.4828e+02,  1.5141e+01, -3.1389e+01,  ..., -1.6483e+02,
           5.2608e+01,  1.7650e+02],
         [ 1.2812e+02, -4.6616e+01, -4.1896e+01,  ..., -4.3454e+01,
           2.4897e+01,  1.4933e+02],
         [-6.4444e+00, -1.5869e+00, -1.8825e+00,  ...,  1.5126e-01,
           2.2267e+00

Testing:   1%|▏         | 289/20000 [00:21<31:38, 10.38it/s]

tensor([[[ 133.6420,   24.7060,    4.3575,  ..., -160.2421,   31.7338,
           134.5034],
         [ 160.8874,  -85.4227,  -42.8641,  ...,  -63.8772,   47.8252,
           182.0806],
         [  97.5117,  -39.5129,  -44.8386,  ...,  -42.6881,   22.6244,
           137.3666],
         ...,
         [ -11.5967,   -1.2565,   -0.4657,  ...,    0.4160,   -3.4801,
             1.5375],
         [ 196.9754,   41.8980,  -45.8139,  ...,   40.5541,   78.9631,
            -0.9965],
         [ 154.5717,  -23.9461,   -8.7596,  ...,  -35.9937,    0.6767,
           108.7198]]], device='cuda:0') shape:  torch.Size([1, 966, 8])
Set: tensor([[[-0.0951, -0.2641, -5.1993, -0.9666, -5.1993],
         [-0.0951, -0.2641, -3.0399, -0.2810, -3.0399],
         [-0.0951, -0.2641, -2.8246, -1.0052, -2.8246],
         ...,
         [-0.0951,  0.4102, -2.2437, -0.8629, -2.2257],
         [-0.0951,  0.6698, -2.0892, -0.5398, -2.1026],
         [-0.0951,  0.6698, -2.0393,  0.4492, -2.0513]]], device='cuda:0') sha



Set: tensor([[[ 0.4530, -0.0632, -5.1993, -0.5484, -5.1993],
         [ 0.4530, -0.0632, -3.0759, -0.6524, -3.0759],
         [ 0.4530, -0.0632, -2.7915, -0.3812, -2.7320],
         ...,
         [ 0.4530, -0.0632, -1.5788, -0.2730, -1.5834],
         [ 0.4530, -0.0632, -1.5564, -0.2976, -1.5653],
         [ 0.1107, -0.0632, -2.7915, -1.0228, -2.8629]]], device='cuda:0') shape:  torch.Size([1, 954, 5])
Set: tensor([[[ 1.0190e+01, -2.2233e-02,  4.5237e-01,  ..., -1.2782e-01,
          -6.9888e-04,  8.8350e+00],
         [ 5.6970e+00, -7.1021e-03,  5.4159e-01,  ..., -8.1507e-02,
           2.3262e-01,  4.7535e+00],
         [ 4.9626e+00, -1.1198e-02,  2.2690e-01,  ..., -7.3835e-02,
           2.2324e-01,  4.3608e+00],
         ...,
         [ 2.4982e+00, -6.8078e-03,  8.3195e-02,  ..., -4.7666e-02,
           3.3464e-01,  2.3104e+00],
         [ 2.4605e+00, -6.1377e-03,  1.1115e-01,  ..., -4.7290e-02,
           3.4217e-01,  2.2533e+00],
         [ 5.4600e+00,  4.7604e-01,  8.5117e-01,  

Testing:   1%|▏         | 293/20000 [00:21<33:03,  9.93it/s]

tensor([[[-1.4164e+02,  1.4284e+02, -1.6915e+01,  ..., -1.4318e+02,
           1.9069e+02, -2.7559e+01],
         [-6.2516e+00,  3.6778e+01,  6.0205e+00,  ..., -3.6022e+01,
           2.9255e+01, -2.1484e+01],
         [ 1.2734e+01, -6.0771e+00,  1.0720e+01,  ..., -3.2407e+01,
          -3.6416e+00, -2.9152e+00],
         ...,
         [ 5.7583e+00, -1.4614e+01, -2.8121e+00,  ..., -1.4050e+01,
          -2.5901e+01,  2.3855e+01],
         [-1.1390e+01, -2.2148e+00, -1.4751e-01,  ..., -1.3866e+00,
          -3.8453e+00,  2.7637e+00],
         [-1.0500e+01, -2.5688e-01, -6.3146e-01,  ..., -4.5566e-01,
          -8.7657e-01, -4.0911e+00]]], device='cuda:0') shape:  torch.Size([1, 866, 8])
Set: tensor([[[ 0.1287,  0.1560, -2.2321, -0.3302, -5.1993],
         [ 0.1287,  0.1560, -1.9970, -1.2382, -3.0005],
         [ 0.1287, -0.0710, -1.8382, -1.2310, -2.7827],
         ...,
         [ 1.5959,  0.4006,  0.1372, -0.0879,  0.0405],
         [ 1.6527,  0.4006,  0.1372, -0.8103,  0.0778],
      



tensor([[[  66.3024,   12.6481,   27.6279,  ..., -115.1938,   22.3633,
            62.9688],
         [  72.6863,  -64.9731,  -12.7924,  ...,  -31.9489,   24.5203,
            78.8883],
         [  97.6192,  -76.6862,  -14.5567,  ...,  -30.5138,   27.0784,
           103.3840],
         ...,
         [ -10.8013,   -1.3218,    0.1449,  ...,   -1.9053,   -4.3403,
             2.2310],
         [ -10.5629,   -1.3493,    0.1251,  ...,   -1.8596,   -4.2825,
             2.2534],
         [ -11.1917,   -1.3618,    0.2538,  ...,   -2.1753,   -4.1780,
             1.5090]]], device='cuda:0') shape:  torch.Size([1, 718, 8])
Set: tensor([[[ 0.2645, -0.1843, -5.1993, -0.4101, -5.1993],
         [-0.0094, -0.1843, -3.0686, -0.7780, -3.0686],
         [-0.0094, -0.1843, -2.8552, -0.8455, -2.8552],
         ...,
         [ 0.9914,  1.7532, -1.2001, -0.2981, -0.5306],
         [ 1.0527,  1.3289, -1.2001, -0.5399, -0.9717],
         [ 0.9273,  1.3289, -0.7835, -1.2340, -0.6527]]], device='cuda:0') sha



tensor([[[-8.0659e-01,  5.5783e+01, -1.5363e+00,  ...,  6.7716e+01,
           3.0630e+01, -3.9950e+00],
         [-4.1910e-01,  2.0020e+01, -7.1157e-01,  ...,  6.1673e+00,
           3.7229e+01, -2.3303e+00],
         [-4.3098e-01,  1.7041e+01, -5.7926e-01,  ...,  6.7996e+00,
           1.8198e+01, -1.7170e+00],
         ...,
         [ 8.9957e+00, -1.2157e-01, -3.5829e-01,  ..., -3.7612e-01,
          -1.1557e-01, -8.2504e-01],
         [-2.7308e-02, -2.3024e-01,  1.3514e+01,  ..., -2.1707e-01,
           2.7846e+00, -7.6427e-01],
         [-1.8893e-01, -1.1234e-01, -1.8512e-01,  ..., -3.1514e-01,
           1.5324e+01, -8.6782e-01]]], device='cuda:0') shape:  torch.Size([1, 847, 70])
Set: tensor([[[-2.4507e+00,  3.3590e+01,  1.2781e+02,  ...,  1.1747e+02,
          -1.9077e+00, -1.3088e+00],
         [-1.5988e+00, -4.1346e-01, -1.3664e-01,  ...,  1.6156e+01,
          -1.3944e+00, -5.5763e-01],
         [-1.2283e+00, -4.2371e-01, -8.9620e-02,  ...,  9.3068e+00,
          -1.1589e+00



tensor([[[-0.7569, 40.3320, -1.5121,  ..., 97.0640,  5.0780, -3.9148],
         [-0.3927,  6.2969, -0.7198,  ..., 21.3122, -0.2436, -1.5296],
         [-0.3065,  3.2602, -0.6279,  ..., 25.7972, 18.0422, -2.3177],
         ...,
         [21.9162, -0.8267, -0.5064,  ..., -0.8766, -0.8269, -1.1231],
         [36.1280, -0.8036, -0.5364,  ..., -0.5083, -0.1328, -1.7502],
         [39.0158, -1.1818, -1.1708,  ..., -0.6034, 26.3513, -3.7599]]],
       device='cuda:0') shape:  torch.Size([1, 993, 70])
Set: tensor([[[-2.1166e+00,  5.0336e+01,  1.4373e+02,  ...,  1.3703e+02,
          -1.8703e+00, -1.4950e+00],
         [-6.5713e-01, -3.0362e-01, -1.3973e-03,  ...,  1.3889e+01,
          -6.4250e-01, -5.5022e-01],
         [-1.4029e+00, -3.8050e-01, -3.6289e-02,  ...,  3.0816e+01,
          -1.4934e+00, -6.9478e-01],
         ...,
         [-6.0418e-02, -2.1967e-01,  7.4402e+00,  ..., -1.4881e+00,
          -1.3743e+00, -9.8035e-01],
         [-4.3837e-01, -1.8215e-01, -1.1526e-01,  ..., -7.8470

Testing:   2%|▏         | 303/20000 [00:22<30:46, 10.67it/s]

tensor([[[ 158.0324,   54.3350,  -36.9209,  ..., -152.7617,   27.4002,
           199.7203],
         [  -8.7493,    0.8441,    3.2469,  ...,   -3.9433,    1.0197,
             7.4684],
         [  -8.3966,    1.3573,    3.2631,  ...,   -4.1931,   -0.9344,
             5.3771],
         ...,
         [ 113.1512,  -17.3107,   24.7399,  ...,   28.3470,    1.6530,
            58.7568],
         [  -6.6687,   65.3985,   21.8352,  ...,   39.9173,    7.3917,
           -62.7302],
         [ -10.5308,   -0.6317,    0.2465,  ...,   -0.6002,   -3.4419,
            -1.2386]]], device='cuda:0') shape:  torch.Size([1, 508, 8])
Set: tensor([[[ 0.4634,  0.2732, -5.1993, -1.1126, -5.1993],
         [ 0.1766,  0.2732, -2.4459, -0.5904, -2.9792],
         [ 0.1766,  0.2732, -2.2714, -0.7113, -2.7600],
         ...,
         [ 3.1855, -1.1295,  1.8160,  1.2999,  1.7975],
         [ 3.1855, -1.1295,  1.9311,  0.7980,  1.8160],
         [ 0.6974,  0.2732, -0.0145, -1.2198, -0.3589]]], device='cuda:0') sha



Set: tensor([[[ 0.3371,  0.3210, -5.1993, -0.2339, -5.1993],
         [ 0.1269,  0.5021, -2.9070, -0.1964, -3.0315],
         [ 0.1269,  0.5021, -2.6829, -0.1840, -2.8158],
         ...,
         [-0.1902,  0.6555, -1.0148,  0.8006, -1.4196],
         [-0.1902,  0.7532, -1.0148, -0.1284, -1.3870],
         [-1.2795,  0.7532,  0.0549,  0.7512,  0.0091]]], device='cuda:0') shape:  torch.Size([1, 823, 5])
Set: tensor([[[ 1.1198e+01, -2.3005e-02,  2.8523e-01,  ..., -1.3017e-01,
           5.1887e-01,  8.8956e+00],
         [ 7.1418e+00, -6.8942e-03,  2.4743e-01,  ..., -8.0932e-02,
           6.5324e-01,  4.7268e+00],
         [ 6.6806e+00, -5.8864e-03,  2.2970e-01,  ..., -7.6077e-02,
           6.7681e-01,  4.3340e+00],
         ...,
         [ 4.1258e+00, -1.5869e-02, -9.4480e-03,  ..., -3.8087e-02,
           8.6312e-02,  2.5035e+00],
         [ 4.3942e+00,  7.0204e-01,  1.8472e-01,  ..., -4.1411e-02,
           5.7282e-01,  1.5133e+00],
         [ 2.1088e+00,  2.3645e-01, -1.1878e-02,  

Testing:   2%|▏         | 306/20000 [00:22<33:51,  9.69it/s]

tensor([[[-2.5114e+00,  5.4933e+01,  9.0853e+01,  ...,  1.2970e+02,
          -1.3876e+00, -1.1941e+00],
         [-1.3939e+00, -4.2208e-01, -8.7770e-02,  ...,  2.9187e+01,
          -1.4664e+00, -6.3318e-01],
         [ 4.2761e+01,  2.4281e+01,  6.0255e+01,  ..., -1.5971e+00,
          -6.6700e-01, -1.2304e+00],
         ...,
         [-8.6429e-01, -8.1651e-01, -1.0026e+00,  ..., -7.0954e-01,
          -1.0384e+00,  5.5300e+00],
         [-6.7913e-01, -9.9623e-01, -1.0990e+00,  ..., -9.6493e-01,
          -1.2246e+00, -2.0779e-01],
         [-7.2625e-01, -8.9480e-01, -1.0055e+00,  ..., -9.6215e-01,
          -1.2246e+00, -1.0037e-01]]], device='cuda:0') shape:  torch.Size([1, 1065, 50])
Set: tensor([[[  24.1343,   37.2790,   23.3616,  ..., -112.5772,   33.0718,
            41.8472],
         [ 108.7108,  -83.1184,  -31.0171,  ...,  -51.6959,   40.7092,
           121.6154],
         [  27.0105,   -5.2662,   19.1743,  ...,   28.0162,   21.9941,
           -10.9026],
         ...,
     



tensor([[[-6.4923e-01,  3.5340e+01, -1.6090e+00,  ...,  8.5666e+01,
           1.7833e+01, -4.0939e+00],
         [-3.2386e-01,  1.6472e+00, -6.9341e-01,  ...,  2.1486e+01,
           2.1379e+01, -2.1690e+00],
         [-3.2026e-01, -8.0215e-03, -6.1307e-01,  ...,  1.6497e+01,
           1.4939e+01, -1.7966e+00],
         ...,
         [-5.4400e-02, -2.0666e-01, -1.1380e+00,  ..., -1.0857e-01,
          -1.2369e-01, -1.2373e+00],
         [-3.8234e-02, -1.8953e-01, -1.2649e+00,  ..., -7.8052e-02,
          -9.3901e-02, -1.2296e+00],
         [-2.4841e-02, -1.9573e-01, -1.4264e+00,  ..., -4.3075e-02,
          -9.2796e-02, -1.2001e+00]]], device='cuda:0') shape:  torch.Size([1, 884, 70])
Set: tensor([[[-2.4648e+00,  4.0737e+01,  1.2698e+02,  ...,  1.3477e+02,
          -1.9565e+00, -1.4110e+00],
         [-1.3972e+00, -4.4574e-01,  3.0617e+00,  ...,  3.6229e+01,
          -1.6326e+00, -7.7669e-01],
         [-1.1640e+00, -4.9689e-01, -1.2543e-02,  ...,  3.2379e+01,
          -1.4591e+00



tensor([[[-2.1037e+00,  4.4875e+01,  1.3830e+02,  ...,  1.3790e+02,
          -2.0315e+00, -1.4280e+00],
         [-9.9125e-01, -3.5456e-01,  1.6665e+01,  ...,  4.7050e+01,
          -1.4210e+00, -8.4669e-01],
         [-5.0339e-01, -3.4288e-01, -1.2123e-01,  ..., -9.7297e-03,
          -4.6692e-01, -3.8186e-01],
         ...,
         [-5.2066e-01, -1.6722e-01, -6.4586e-01,  ..., -4.5128e-01,
          -6.3199e-01, -9.6127e-02],
         [-5.2959e-01, -5.8007e-01, -3.5641e-01,  ..., -8.0566e-01,
          -5.1627e-01, -2.3648e-01],
         [-3.4746e-01, -3.4646e-01, -2.6293e-01,  ..., -2.0638e-01,
          -2.6552e-01, -2.2699e-01]]], device='cuda:0') shape:  torch.Size([1, 965, 50])
Set: tensor([[[ 1.9415e+02, -4.4329e-01, -3.2986e+01,  ..., -1.5851e+02,
           5.4735e+01,  2.0957e+02],
         [ 1.7217e+02, -7.2314e+01, -4.4511e+01,  ..., -5.6962e+01,
           4.5177e+01,  1.8675e+02],
         [ 1.3444e+00,  2.2561e+00, -6.8168e+00,  ..., -4.7339e+00,
           4.5087e+00



tensor([[[ 157.7711,   18.2888,  -10.5836,  ..., -151.3076,   36.4483,
           166.5574],
         [  29.0966,    4.2071,   -6.7899,  ...,  -12.0098,    6.5483,
            40.1841],
         [  93.7157,  -48.3715,  -25.9951,  ...,  -34.6916,   17.0630,
           118.7355],
         ...,
         [ -10.8028,    7.8203,    1.6003,  ...,    6.2193,    3.3557,
             1.4520],
         [  -3.9841,   -1.3342,   -0.7984,  ...,   -1.4130,   -4.5769,
             8.5152],
         [   7.4311,  -17.0363,   -4.7124,  ...,  -14.7886,   -4.7574,
            24.6202]]], device='cuda:0') shape:  torch.Size([1, 910, 8])
Set: tensor([[[-0.2112, -0.1249, -5.1993, -0.5472, -5.1993],
         [-0.2112, -0.1249, -3.0326, -0.5578, -3.0326],
         [-0.2112, -0.1249, -2.8169, -1.1934, -2.8169],
         ...,
         [-0.9821,  0.9945, -0.3601, -1.1221, -0.4075],
         [-0.8351, -0.1249, -0.5261, -1.0508, -0.7180],
         [-0.2112, -0.6248, -0.2003, -0.9995, -0.4009]]], device='cuda:0') sha

Testing:   2%|▏         | 315/20000 [00:23<32:04, 10.23it/s]

Set: tensor([[[-0.0489, -0.3495, -5.1993, -0.5150, -5.1993],
         [-0.0489, -0.3495, -3.1100, -1.0145, -3.1105],
         [-0.0489, -0.3495, -2.8991,  0.1829, -2.9091],
         ...,
         [ 1.0315, -2.4320, -0.1727,  1.0707,  0.5235],
         [ 1.0685, -2.4320, -0.1727,  1.7274,  0.5697],
         [-0.0489, -0.3495, -1.8804, -0.4894, -1.8713]]], device='cuda:0') shape:  torch.Size([1, 1070, 5])
Set: tensor([[[ 9.7605e+00, -2.4114e-02,  7.7382e-02,  ..., -1.1860e-01,
          -2.2008e-02,  8.9258e+00],
         [ 5.3754e+00, -1.5525e-04,  6.2244e-01,  ..., -7.4128e-02,
          -1.8071e-02,  4.5374e+00],
         [ 4.8408e+00, -2.6307e-02, -7.5932e-03,  ..., -6.6327e-02,
          -2.0729e-02,  5.2858e+00],
         ...,
         [-8.1886e-02, -6.3300e-02, -2.7123e-02,  ...,  1.6288e+00,
          -4.0652e-02,  1.0197e+00],
         [-8.3516e-02, -7.8236e-02, -3.4552e-02,  ...,  1.8197e+00,
          -4.1074e-02,  1.5563e+00],
         [ 2.7052e+00, -4.7038e-03,  4.8554e-04, 

Testing:   2%|▏         | 317/20000 [00:23<30:43, 10.67it/s]

tensor([[[ 0.0552, -0.3508, -5.1993, -0.5585, -5.1993],
         [-0.1524, -0.0151, -3.1203,  0.0082, -3.1203],
         [-0.1524, -0.0151, -2.9101, -0.6417, -2.9101],
         ...,
         [-1.0887, -0.3508,  0.3442, -0.7407,  0.1192],
         [-1.0887, -0.0151,  0.3442, -0.2626,  0.1359],
         [ 0.0552, -0.0151, -2.0533, -1.6741, -2.0509]]], device='cuda:0') shape:  torch.Size([1, 1108, 5])
Set: tensor([[[ 9.7015e+00, -2.3783e-02,  1.6136e-01,  ..., -1.1992e-01,
          -1.9042e-02,  8.9062e+00],
         [ 6.2143e+00, -1.7950e-02, -3.8497e-03,  ..., -7.4323e-02,
          -1.4817e-02,  5.3116e+00],
         [ 5.8254e+00, -1.8394e-03,  3.6018e-01,  ..., -7.1492e-02,
          -1.2967e-02,  4.3061e+00],
         ...,
         [-9.4566e-03,  1.9419e+00, -8.5132e-04,  ...,  1.1742e+00,
          -4.3661e-02, -1.4616e-02],
         [-1.7140e-03,  1.3588e+00, -4.3186e-03,  ...,  9.2829e-01,
          -3.6152e-02, -1.2299e-02],
         [ 3.9709e+00,  2.5598e+00,  1.6037e+00,  ...,

Testing:   2%|▏         | 319/20000 [00:23<32:05, 10.22it/s]

tensor([[[-5.9468e-01,  3.0436e+01, -1.5530e+00,  ...,  6.6971e+01,
           2.6844e+01, -4.0081e+00],
         [-3.9414e-01,  7.8205e-01, -6.7597e-01,  ...,  1.8586e+01,
          -3.6974e-02, -1.4307e+00],
         [-2.6752e-01, -9.0400e-02, -6.9886e-01,  ...,  1.3323e+01,
           2.1912e+01, -2.0736e+00],
         ...,
         [-1.7602e-01, -1.4176e-01, -4.3795e-01,  ..., -4.1034e-02,
           2.1726e+01, -1.5510e+00],
         [-1.3560e-01, -2.1572e-01, -4.3541e-01,  ...,  4.9650e+00,
           1.3989e+01, -1.6396e+00],
         [-1.0956e-01, -2.5977e-01, -1.5895e-01,  ...,  6.1858e-01,
          -1.2580e-01, -9.1917e-01]]], device='cuda:0') shape:  torch.Size([1, 877, 70])
Set: tensor([[[-2.4219e+00,  3.3228e+01,  1.1004e+02,  ...,  1.2098e+02,
          -1.8535e+00, -1.2681e+00],
         [-7.7875e-01, -3.9131e-01,  1.1807e+01,  ...,  3.4157e+01,
          -1.0644e+00, -6.7725e-01],
         [-1.3770e+00, -4.6430e-01, -9.3770e-02,  ...,  2.4983e+01,
          -1.4136e+00



tensor([[[-6.7280e-01,  3.7965e+01, -1.6208e+00,  ...,  7.1585e+01,
           2.6994e+01, -4.1639e+00],
         [-4.4446e-01,  8.9311e+00, -6.9550e-01,  ...,  1.4726e+01,
          -3.8889e-03, -1.5095e+00],
         [-4.4576e-01,  1.1787e+01, -6.7100e-01,  ...,  1.0188e+01,
           2.1106e+01, -1.9097e+00],
         ...,
         [-3.1922e-01, -5.4943e-02, -9.7176e-01,  ..., -1.8802e-01,
           2.0865e+00, -1.7265e+00],
         [-7.6315e-01, -4.4676e-01, -4.9073e-01,  ..., -1.6495e+00,
          -1.2118e+00, -1.1731e+00],
         [-1.9886e-01, -1.2910e-01, -3.6994e-01,  ..., -3.8702e-01,
          -4.1076e-01, -1.3650e-01]]], device='cuda:0') shape:  torch.Size([1, 978, 70])
Set: tensor([[[-2.5290e+00,  3.4723e+01,  1.1198e+02,  ...,  1.2388e+02,
          -1.8741e+00, -1.3161e+00],
         [-8.5594e-01, -3.4337e-01,  8.0115e+00,  ...,  2.3436e+01,
          -1.0026e+00, -5.2755e-01],
         [-1.3336e+00, -5.1002e-01, -4.2558e-02,  ...,  1.6962e+01,
          -1.4565e+00



tensor([[[ 139.6194,  -10.7405,   -4.9459,  ..., -139.3570,   37.4465,
           137.2037],
         [  66.7462,  -57.7003,  -24.2209,  ...,  -42.2443,   18.8854,
            93.7289],
         [  -9.6742,   -2.0254,    0.3144,  ...,   -2.3148,   -3.0876,
             5.9110],
         ...,
         [ -11.9017,   -2.5514,   -0.6171,  ...,   -1.6106,   -4.7147,
             2.7633],
         [  22.3909,   -4.1918,   -3.9411,  ...,   15.0394,    4.5888,
            18.1711],
         [  37.0381,   44.7808,   -8.6082,  ...,   15.2602,  -34.8833,
            -4.8058]]], device='cuda:0') shape:  torch.Size([1, 707, 8])
Set: tensor([[[ 0.3325,  0.1988, -5.1993, -1.1651, -5.1993],
         [ 0.1294,  0.1988, -3.0184, -1.9978, -3.0184],
         [ 0.1294,  0.1988, -2.8018, -0.7728, -2.8018],
         ...,
         [ 0.3325, -1.3962,  0.6536, -0.0207,  2.3700],
         [ 0.3325, -1.3962,  0.5975, -0.0845,  2.5702],
         [ 0.7857, -1.4580,  0.7120,  1.2169,  0.6302]]], device='cuda:0') sha



Set: tensor([[[ 0.3371,  0.2320, -5.1993, -0.5139, -5.1993],
         [ 0.3371,  0.4684, -2.7843, -0.8625, -3.0693],
         [ 0.3371,  0.4684, -2.5879, -0.2417, -2.8559],
         ...,
         [-1.0611, -2.0458,  1.2951,  0.6779,  1.1688],
         [ 0.1716, -0.9560,  1.1715, -0.4865,  0.9308],
         [ 0.3371, -0.9560,  1.1715,  0.9308,  0.9603]]], device='cuda:0') shape:  torch.Size([1, 933, 5])
Set: tensor([[[ 1.0995e+01, -1.7900e-02,  5.5298e-01,  ..., -1.2986e-01,
           3.5847e-01,  8.6836e+00],
         [ 7.0169e+00,  6.2090e-01,  1.0552e+00,  ..., -8.3894e-02,
           1.2729e+00,  4.2402e+00],
         [ 6.5156e+00, -6.6829e-03,  3.3819e-01,  ..., -7.7686e-02,
           1.1535e+00,  4.4229e+00],
         ...,
         [-7.5559e-02, -3.1926e-02, -2.7474e-02,  ...,  5.7368e+00,
          -8.8268e-02, -1.1558e-02],
         [-4.9742e-02,  1.5773e-01, -3.4898e-03,  ...,  2.3161e+00,
          -2.3989e-02, -2.1573e-02],
         [-5.2501e-02, -3.1655e-02, -1.9263e-02,  



tensor([[[ 1.0586e+01, -6.6307e-03,  8.8089e-01,  ..., -1.2297e-01,
          -1.4631e-02,  8.2259e+00],
         [ 5.9187e+00,  6.7338e-01,  8.5630e-01,  ..., -7.4674e-02,
          -1.1741e-02,  4.0895e+00],
         [ 5.4203e+00, -1.4340e-03,  3.7562e-01,  ..., -6.8753e-02,
          -1.2462e-02,  4.0706e+00],
         ...,
         [ 4.1294e-01,  6.8157e-01,  3.3337e-01,  ..., -1.2814e-02,
           1.0785e+01, -3.6386e-02],
         [ 4.1087e-01,  3.2156e+00,  1.6025e+00,  ..., -1.3333e-02,
           1.1017e+01, -4.7043e-02],
         [ 2.0569e-01,  6.2275e-01,  3.3487e-01,  ..., -1.0740e-02,
           1.0967e+01, -3.6976e-02]]], device='cuda:0') shape:  torch.Size([1, 736, 80])
Set: tensor([[[-1.9972e-01,  1.7194e+00,  7.5155e+00,  ..., -3.6380e-01,
          -3.7923e-01,  1.7519e+01],
         [-1.0702e-01, -9.4989e-03,  5.4048e+00,  ..., -1.6418e-01,
          -1.9814e-01,  7.0334e+00],
         [-9.7683e-02,  3.3704e+00,  3.2340e+00,  ..., -1.8552e-01,
          -1.7442e-01



tensor([[[ 0.0871,  0.1207, -5.1993, -1.2025, -5.1993],
         [ 0.3478,  0.3803, -2.8317, -1.7188, -2.8317],
         [ 0.3478,  0.3803, -2.6995, -0.6068, -2.6995],
         ...,
         [ 1.3009,  0.3803, -1.3426,  0.9861, -1.1392],
         [ 0.0871,  0.1207, -2.2170, -1.1281, -2.2343],
         [ 0.3478, -0.1265, -5.1993, -0.5791, -3.0466]]], device='cuda:0') shape:  torch.Size([1, 865, 5])
Set: tensor([[[ 1.0922e+01, -2.2823e-03,  1.1966e+00,  ..., -1.2757e-01,
          -4.5305e-03,  8.0468e+00],
         [ 6.4463e+00,  2.5944e+00,  1.9986e+00,  ..., -8.3076e-02,
           1.3668e+00,  3.0278e+00],
         [ 6.0687e+00,  1.3324e-01,  7.1650e-01,  ..., -7.7126e-02,
           1.1145e+00,  3.8253e+00],
         ...,
         [ 2.1158e+00, -3.1589e-02, -8.2256e-03,  ..., -5.1552e-02,
           3.6019e+00,  2.5773e+00],
         [ 4.6238e+00,  1.3820e+00,  1.0677e+00,  ..., -6.2118e-02,
          -8.7203e-04,  2.5548e+00],
         [ 6.1931e+00, -7.7159e-03,  4.6983e-01,  ..., 



tensor([[[ 1.1644e+01, -1.1136e-02,  8.5305e-01,  ..., -1.3308e-01,
           1.0124e+00,  8.4180e+00],
         [ 4.3697e+00, -4.4299e-02, -1.9114e-02,  ..., -6.1214e-02,
           1.0970e+00,  4.0764e+00],
         [ 5.6194e+00, -5.2327e-02, -2.0848e-02,  ..., -6.4873e-02,
           7.2452e-01,  5.6973e+00],
         ...,
         [-1.8692e-02, -4.8479e-03, -1.2853e-02,  ...,  1.7201e+00,
          -6.1643e-02, -2.8193e-03],
         [-7.0775e-03,  2.4422e+00,  1.4068e+00,  ..., -2.2657e-02,
           2.3287e+00, -1.6022e-02],
         [-4.6673e-02,  9.9178e-01,  1.0110e+00,  ..., -8.4647e-03,
          -2.1770e-02, -7.0492e-03]]], device='cuda:0') shape:  torch.Size([1, 776, 80])
Set: tensor([[[-1.9310e-01,  3.6691e+00,  8.2507e+00,  ..., -4.1156e-01,
          -4.2786e-01,  2.1482e+01],
         [-2.3071e-02,  4.9830e+00,  2.8255e+00,  ..., -2.5796e-01,
          -2.5415e-01, -3.3859e-02],
         [-6.8644e-02,  6.7980e+00,  3.4572e+00,  ..., -2.7271e-01,
          -1.6885e-01



tensor([[[-2.0905e-01,  6.7042e+00,  4.8465e+00,  ..., -4.2945e-01,
          -3.6771e-01,  1.2733e+01],
         [-9.2984e-02,  6.6747e+00,  1.0101e+00,  ..., -2.3744e-01,
          -1.7671e-01,  1.6337e+00],
         [-1.1080e-01,  6.2722e+00,  1.7040e+00,  ..., -2.2779e-01,
          -1.7152e-01,  2.0933e+00],
         ...,
         [ 1.8690e+01, -3.6888e-01, -3.7137e-01,  ...,  1.2816e+01,
          -2.0116e-01,  6.9054e+00],
         [ 2.9995e+00, -1.7765e-02, -1.6111e-01,  ..., -5.2170e-02,
          -9.5288e-02, -2.4368e-02],
         [-2.4231e-04, -2.9944e-02, -1.5640e-01,  ..., -2.8271e-02,
          -6.9112e-02, -2.7615e-03]]], device='cuda:0') shape:  torch.Size([1, 893, 120])
Set: tensor([[[-5.9008e-01,  2.9540e+01, -1.8396e+00,  ...,  7.6262e+01,
           2.1749e+01, -4.3349e+00],
         [-2.4312e-01, -9.2639e-02, -8.3246e-01,  ...,  2.0672e+01,
           1.9047e+01, -2.2607e+00],
         [-2.7278e-01, -4.7671e-02, -7.2737e-01,  ...,  8.5193e+00,
           2.4509e+0



tensor([[[-3.0008e-01,  6.2301e+00,  2.5956e+00,  ..., -4.7736e-01,
          -3.2513e-01,  4.7520e-01],
         [-1.3956e-01,  2.8876e+00, -2.1163e-02,  ..., -4.2926e-01,
          -4.6536e-01, -2.4160e-01],
         [-1.3335e-01,  1.0058e+00,  2.6964e+00,  ..., -3.6693e-01,
          -4.2774e-01, -8.0264e-02],
         ...,
         [-1.4159e-01, -1.5276e-02, -3.5683e-01,  ..., -2.2762e-03,
          -1.9687e-01,  5.4095e+00],
         [-1.4938e-01, -1.0857e-02, -3.5904e-01,  ..., -5.9261e-03,
          -1.9042e-01,  5.9284e+00],
         [-1.4844e-01, -1.6794e-02, -3.6195e-01,  ...,  6.7329e-01,
          -2.1722e-01,  3.9594e+00]]], device='cuda:0') shape:  torch.Size([1, 655, 120])
Set: tensor([[[-4.3722e-01,  3.8443e+01, -2.8015e+00,  ...,  1.0733e+02,
          -1.6562e-01, -5.0173e+00],
         [ 1.2490e+01, -4.2658e-01, -2.8560e+00,  ...,  1.1091e+02,
          -1.3228e+00, -3.8272e+00],
         [ 1.3289e+01, -2.1410e-01, -1.5604e+00,  ...,  2.5038e+01,
          -4.5489e-0



Set: tensor([[[-0.0586,  0.4394, -5.1993,  0.3997, -5.1993],
         [-0.0586,  0.4394, -3.0203, -0.3053, -3.0203],
         [-0.0586,  0.4394, -2.8038,  0.0380, -2.8038],
         ...,
         [ 2.7311, -1.4895,  0.0079,  0.5121,  0.0222],
         [ 3.0203, -2.2375,  0.0681,  0.9725,  0.1269],
         [ 0.1557,  0.1413, -1.8408, -5.1993, -1.8408]]], device='cuda:0') shape:  torch.Size([1, 793, 5])
Set: tensor([[[ 1.1682e+01, -3.3165e-02, -4.9900e-03,  ..., -1.2532e-01,
          -4.1771e-03,  9.3482e+00],
         [ 7.1241e+00, -3.9686e-03,  2.8137e-01,  ..., -7.9430e-02,
           3.7782e-02,  4.5972e+00],
         [ 6.6351e+00, -1.0512e-02, -1.1700e-03,  ..., -7.3738e-02,
          -1.7189e-04,  4.5139e+00],
         ...,
         [-5.9195e-02, -5.0586e-02, -9.6608e-03,  ..., -1.9766e-02,
           2.9900e+00,  1.2798e+00],
         [-8.2260e-02, -7.3376e-02, -1.8548e-02,  ..., -1.1154e-02,
           1.7797e+00,  1.9858e+00],
         [ 4.1718e+00,  1.0915e+01,  5.7875e+00,  

Testing:   2%|▏         | 341/20000 [00:26<31:13, 10.49it/s]

tensor([[[ 9.8447e+00, -1.4114e-02,  6.2417e-01,  ..., -1.2109e-01,
          -1.7916e-02,  8.5126e+00],
         [ 5.2946e+00, -8.6018e-03,  2.2571e-01,  ..., -7.3403e-02,
          -1.5905e-02,  4.8114e+00],
         [ 4.8448e+00, -6.5542e-03,  2.6118e-01,  ..., -6.8809e-02,
          -1.5545e-02,  4.3793e+00],
         ...,
         [ 3.7331e+00, -1.2617e-02, -4.6552e-03,  ..., -4.9396e-02,
          -1.7365e-02,  3.3688e+00],
         [ 3.5809e+00, -3.2614e-02, -1.6094e-02,  ..., -4.3208e-02,
          -2.4234e-02,  4.0653e+00],
         [ 3.7907e+00, -2.3333e-02, -1.0643e-02,  ..., -4.6905e-02,
          -1.4931e-02,  3.6208e+00]]], device='cuda:0') shape:  torch.Size([1, 924, 80])
Set: tensor([[[-2.4461e-01,  2.5501e+00,  5.7427e+00,  ..., -3.7715e-01,
          -3.6098e-01,  1.5306e+01],
         [-1.5570e-01,  3.7128e+00,  2.2160e+00,  ..., -2.1547e-01,
          -1.8558e-01,  3.8527e+00],
         [-1.4639e-01,  3.3776e+00,  2.1123e+00,  ..., -1.9415e-01,
          -1.6841e-01



tensor([[[-0.0535,  6.4473,  0.4875,  ..., -0.2579, -0.1906,  4.0952],
         [-0.0495, -0.0385,  2.9654,  ..., -0.1354, -0.1896,  9.3548],
         [-0.0456,  4.2781,  0.3218,  ..., -0.1962, -0.1741,  5.5023],
         ...,
         [-0.1021, -0.0221, -0.0679,  ..., -0.0667,  0.8095, -0.0802],
         [-0.0653,  6.2223,  0.2987,  ..., -0.1835, -0.1366,  2.3748],
         [-0.1916,  3.2043,  3.3922,  ..., -0.3111, -0.2289, 19.3589]]],
       device='cuda:0') shape:  torch.Size([1, 995, 120])
Set: tensor([[[-2.5898e-01, -1.1342e-01, -9.1627e-01,  ...,  3.8089e+01,
           7.3645e+00, -2.5870e+00],
         [-3.0695e-01, -3.8013e-02, -4.5739e-01,  ...,  1.4343e+01,
          -7.5762e-02, -1.4166e+00],
         [-2.2430e-01, -7.6413e-02, -4.1376e-01,  ...,  1.4057e+01,
           1.7123e+01, -1.9048e+00],
         ...,
         [-3.9495e-01, -3.0053e-01, -9.9023e-01,  ..., -1.2254e-01,
           8.3639e+00, -1.0940e+00],
         [-2.3102e-01, -7.3446e-02, -4.7106e-01,  ...,  5.982



tensor([[[-6.8925e-01,  3.0770e+01, -1.5707e+00,  ...,  7.8800e+01,
           1.6547e+01, -3.9730e+00],
         [-3.8627e-01, -1.5385e-02, -6.5835e-01,  ...,  1.5162e+01,
           1.3626e+01, -1.8625e+00],
         [-3.6271e-01, -4.2085e-02, -5.8052e-01,  ...,  1.1765e+01,
           5.0648e+00, -1.4832e+00],
         ...,
         [-3.1210e-01, -4.7323e-01, -2.6427e-01,  ..., -4.2252e-01,
          -2.2402e-01, -7.6653e-01],
         [-3.8166e-01, -6.1300e-01, -5.5883e-02,  ..., -7.2605e-01,
          -7.2423e-01, -5.4753e-01],
         [-3.7057e-01, -8.1102e-01,  4.4995e+00,  ..., -5.9813e-01,
          -8.6317e-01, -7.4176e-01]]], device='cuda:0') shape:  torch.Size([1, 758, 70])
Set: tensor([[[-2.4312e+00,  3.9365e+01,  1.3079e+02,  ...,  1.3549e+02,
          -2.1254e+00, -1.4045e+00],
         [-1.1527e+00, -4.9727e-01,  7.0748e+00,  ...,  4.1880e+01,
          -1.6273e+00, -9.0707e-01],
         [-8.1034e-01, -4.5545e-01,  6.7936e+00,  ...,  4.3869e+01,
          -1.2567e+00

Testing:   2%|▏         | 347/20000 [00:26<32:07, 10.20it/s]

tensor([[[-2.0876e+00,  2.4025e+01,  7.8382e+01,  ...,  1.0376e+02,
          -1.4772e+00, -1.0443e+00],
         [-1.0686e+00, -4.3353e-01, -3.2156e-01,  ...,  1.6613e+01,
          -8.8859e-01, -3.5168e-01],
         [-1.1155e+00, -4.7774e-01, -3.7173e-01,  ...,  1.1575e+01,
          -9.0065e-01, -3.4313e-01],
         ...,
         [-7.1310e-01, -5.5924e-01, -6.5059e-01,  ..., -1.1853e-01,
          -4.0172e-01,  8.2575e+00],
         [-4.6659e-01, -4.4386e-01, -7.1444e-01,  ..., -7.7840e-01,
          -6.3243e-01, -7.3152e-02],
         [-1.6119e+00, -5.1106e-02, -3.6632e-01,  ..., -9.8918e-01,
          -8.8010e-01, -6.2448e-01]]], device='cuda:0') shape:  torch.Size([1, 953, 50])
Set: tensor([[[  75.1775,    5.2275,    5.3070,  ..., -104.1944,   35.2965,
            74.9618],
         [  47.7635,  -35.8599,  -15.6381,  ...,  -27.0198,   15.9156,
            50.8072],
         [  52.2674,  -39.5027,  -16.5199,  ...,  -25.4842,   14.1236,
            55.5935],
         ...,
      

Testing:   2%|▏         | 349/20000 [00:26<33:22,  9.81it/s]

tensor([[[-2.1140e+00,  3.5948e+01,  1.3211e+02,  ...,  1.2395e+02,
          -2.0375e+00, -1.4017e+00],
         [-1.2799e+00, -4.2895e-01,  6.2290e+00,  ...,  3.9896e+01,
          -1.6021e+00, -8.3014e-01],
         [-7.8302e-01, -3.8892e-01,  5.0919e+00,  ...,  3.4087e+01,
          -1.0858e+00, -6.3887e-01],
         ...,
         [-3.1503e-01, -7.4656e-01, -9.5166e-01,  ..., -2.7476e-01,
          -3.7833e-01,  1.9708e+00],
         [-2.9519e-01, -7.9940e-01, -1.0523e+00,  ..., -2.1421e-01,
          -3.2624e-01, -7.0168e-03],
         [-3.4700e-01, -8.2871e-01, -1.0993e+00,  ..., -3.1356e-01,
          -5.3655e-01,  2.5363e+00]]], device='cuda:0') shape:  torch.Size([1, 831, 50])
Set: tensor([[[ 2.0898e+02, -4.4598e+00, -2.6762e+01,  ..., -1.4044e+02,
           4.3033e+01,  2.1513e+02],
         [ 1.6024e+02, -9.3172e+01, -3.8826e+01,  ..., -6.1351e+01,
           4.6547e+01,  1.7890e+02],
         [ 1.0574e+02, -4.4741e+01, -3.5026e+01,  ..., -3.4167e+01,
           2.7022e+01



tensor([[[-2.5148e+00,  3.8459e+01,  9.7251e+01,  ...,  1.1755e+02,
          -1.6516e+00, -1.1856e+00],
         [-8.7823e-01, -3.8562e-01,  9.7955e+00,  ...,  4.2835e+01,
          -1.2351e+00, -7.4075e-01],
         [-7.2175e-01, -3.9495e-01,  5.7993e+00,  ...,  4.2142e+01,
          -1.0489e+00, -6.7543e-01],
         ...,
         [-8.4053e-01, -3.0060e-01, -7.3501e-01,  ..., -1.9808e-01,
          -4.0431e-01, -7.5069e-02],
         [-9.2815e-01, -3.7402e-01, -6.3518e-01,  ...,  1.6226e+00,
          -3.9399e-01, -1.1071e-01],
         [-9.4314e-01, -4.0333e-01, -5.0303e-01,  ...,  4.1859e+00,
          -5.7775e-01, -1.4022e-01]]], device='cuda:0') shape:  torch.Size([1, 951, 50])
Set: tensor([[[  59.0415,   16.4265,   22.4454,  ..., -119.7416,   25.4050,
            62.5935],
         [ 145.1402,  -55.2016,  -43.8181,  ...,  -42.5357,   37.6500,
           161.1238],
         [ 121.1429,  -44.1801,  -38.6766,  ...,  -37.1751,   32.8701,
           132.5976],
         ...,
      



tensor([[[ 126.4361,   18.3572,    0.9323,  ..., -136.7404,   29.6191,
           132.8083],
         [  -6.0695,   -5.0226,    5.1655,  ...,  -10.5340,    7.2056,
            -5.8207],
         [  15.0801,  -13.7356,  -14.7649,  ...,  -11.8416,   11.0700,
            16.4325],
         ...,
         [  87.8368,  -45.0797,  -10.6894,  ...,  -29.6516,    5.6377,
            76.1838],
         [ 172.9260,  -21.6411,   -5.6171,  ...,  -37.7788,    6.2932,
           112.3105],
         [ -30.9660,   12.8763,   12.3456,  ...,   17.9671,    4.9072,
           -81.4096]]], device='cuda:0') shape:  torch.Size([1, 792, 8])
Set: tensor([[[-0.3299,  0.2200, -5.1993, -0.6147, -5.1993],
         [-0.3299,  0.6585, -2.9976, -1.1696, -2.9976],
         [-0.3299,  0.6585, -2.7796, -0.4723, -2.7796],
         ...,
         [-0.8687,  1.9344, -0.2586,  1.2182,  0.5580],
         [-2.9976, -0.4457, -0.6564,  0.7571,  0.9811],
         [-0.6064,  0.2200, -0.5303,  1.0858, -0.6863]]], device='cuda:0') sha

Testing:   2%|▏         | 355/20000 [00:27<33:49,  9.68it/s]

tensor([[[ 1.3657e+02,  1.8789e+00,  2.1219e-01,  ..., -1.4556e+02,
           3.8710e+01,  1.3297e+02],
         [ 1.8026e+02, -7.5235e+01, -4.7915e+01,  ..., -6.0683e+01,
           4.9736e+01,  1.9459e+02],
         [ 1.2304e+02, -4.3233e+01, -3.7987e+01,  ..., -3.8669e+01,
           3.0375e+01,  1.3192e+02],
         ...,
         [ 2.0673e+01,  1.0121e+01, -8.5969e+00,  ...,  3.1341e+01,
           1.9933e+01, -5.6280e+01],
         [-9.2720e+00, -1.2619e+00,  5.3357e-02,  ..., -2.1408e-01,
          -3.7073e+00,  1.6859e+00],
         [-8.0359e+00, -1.3626e+00, -1.2931e-01,  ..., -3.6474e-01,
          -3.9305e+00,  2.5114e+00]]], device='cuda:0') shape:  torch.Size([1, 986, 8])
Set: tensor([[[ 2.3061e-01, -1.5550e-03, -5.1993e+00, -1.2048e+00, -5.1993e+00],
         [ 2.3061e-01,  1.5770e-01, -3.0256e+00, -1.3148e+00, -3.0256e+00],
         [ 2.3061e-01,  1.5770e-01, -2.8094e+00, -1.1549e+00, -2.8094e+00],
         ...,
         [-7.0806e-01,  1.2787e+00,  9.2763e-01,  9.2048e-

Testing:   2%|▏         | 358/20000 [00:27<29:11, 11.21it/s]

tensor([[[-2.3534e+00,  4.3830e+01,  1.3947e+02,  ...,  1.4126e+02,
          -2.1135e+00, -1.4777e+00],
         [-1.3611e+00, -4.2559e-01,  5.4887e+00,  ...,  4.0867e+01,
          -1.7050e+00, -8.2924e-01],
         [-1.2868e+00, -5.0659e-01, -2.7403e-02,  ...,  3.4498e+01,
          -1.6191e+00, -7.9652e-01],
         ...,
         [-9.9854e-01, -4.0213e-01, -4.1773e-01,  ..., -6.7637e-02,
          -8.6120e-01, -1.2501e+00],
         [-5.1896e-01, -3.7878e-01, -1.4987e-01,  ..., -4.8525e-01,
          -4.6778e-01, -3.5863e-01],
         [-5.3961e-01, -6.5454e-01, -6.8335e-01,  ...,  2.1081e+01,
          -7.3973e-02, -2.2341e-01]]], device='cuda:0') shape:  torch.Size([1, 849, 50])
Set: tensor([[[ 181.2007,    4.3768,  -17.2645,  ..., -161.7785,   44.8263,
           187.1967],
         [ 173.8221, -100.1989,  -37.1428,  ...,  -62.5193,   48.4003,
           187.8458],
         [ 169.7624,  -98.5184,  -33.9226,  ...,  -50.6530,   43.9784,
           180.4374],
         ...,
      



tensor([[[-2.7415e-01,  3.2324e+00,  3.7625e+00,  ..., -3.8512e-01,
          -3.5171e-01,  1.1847e+01],
         [-1.8326e-01,  9.1275e-01,  1.5410e+00,  ..., -1.7014e-01,
          -1.6776e-01,  1.7050e+00],
         [-1.7486e-01, -5.1171e-03,  2.1195e+00,  ..., -1.4294e-01,
          -1.5802e-01,  1.0932e+00],
         ...,
         [ 1.3654e+01, -1.6647e-01, -1.3293e-01,  ...,  4.6197e+00,
          -9.4337e-02,  9.5461e-02],
         [ 2.4975e+01, -2.6768e-01, -3.2502e-01,  ...,  4.2699e+00,
          -2.0922e-01,  5.2100e+00],
         [ 4.1202e+01, -3.5709e-01, -5.7402e-01,  ...,  1.4721e+00,
          -3.9720e-01,  1.0776e+01]]], device='cuda:0') shape:  torch.Size([1, 738, 120])
Set: tensor([[[-6.2300e-01,  3.7539e+01, -1.4354e+00,  ...,  5.6696e+01,
           2.9909e+01, -3.8843e+00],
         [-3.5578e-01,  8.2569e+00, -5.1317e-01,  ...,  2.9423e+00,
           1.8208e+01, -1.6086e+00],
         [-3.1647e-01,  5.4058e+00, -4.3799e-01,  ..., -1.0673e-02,
           1.0671e+0



tensor([[[-2.5759e+00,  3.8741e+01,  1.2606e+02,  ...,  1.2651e+02,
          -1.9235e+00, -1.3899e+00],
         [-1.4973e+00, -3.8283e-01,  7.4326e-01,  ...,  3.2028e+01,
          -1.5896e+00, -7.1160e-01],
         [-1.2912e+00, -4.8050e-01, -2.5830e-02,  ...,  2.7832e+01,
          -1.4989e+00, -7.0115e-01],
         ...,
         [-8.0716e-01, -8.7867e-01, -1.0763e+00,  ..., -6.3650e-01,
          -7.6181e-01, -8.6827e-02],
         [-3.6967e-01, -6.1271e-01, -7.8837e-01,  ...,  1.6830e+01,
          -1.7647e-01, -4.9655e-02],
         [-1.1199e+00, -3.9804e-01, -9.1499e-01,  ..., -9.4132e-02,
          -2.8189e-01, -1.2827e-01]]], device='cuda:0') shape:  torch.Size([1, 919, 50])
Set: tensor([[[ 1.0975e+02,  2.0101e+01,  1.3540e+01,  ..., -1.3810e+02,
           2.4333e+01,  1.0721e+02],
         [ 1.2421e+02, -9.1817e+01, -3.0328e+01,  ..., -5.8095e+01,
           3.6638e+01,  1.4348e+02],
         [ 1.3167e+02, -8.3637e+01, -2.6793e+01,  ..., -4.1579e+01,
           3.2403e+01

Testing:   2%|▏         | 364/20000 [00:28<31:47, 10.29it/s]

tensor([[[ 1.9428e+02,  3.2061e+01, -3.3356e+01,  ..., -1.2729e+02,
           2.6839e+01,  2.1828e+02],
         [ 1.0051e+02, -4.6266e+01, -2.9795e+01,  ..., -2.9987e+01,
           2.0028e+01,  1.2004e+02],
         [ 4.8063e+01, -1.7174e+01, -1.9818e+01,  ..., -1.5850e+01,
           1.1626e+01,  5.9292e+01],
         ...,
         [-1.0236e+01, -1.1354e+00, -1.9528e-01,  ..., -1.3799e+00,
          -4.1735e+00,  1.6627e+00],
         [-9.4047e+00, -1.0538e+00, -2.6387e-01,  ..., -1.5403e+00,
          -4.1582e+00,  2.1936e+00],
         [-1.0095e+01, -1.1383e+00, -2.0479e-01,  ..., -1.5611e+00,
          -4.1848e+00,  1.8913e+00]]], device='cuda:0') shape:  torch.Size([1, 967, 8])
Set: tensor([[[-5.2313e-02,  2.6619e-01, -5.1993e+00, -7.6876e-02, -5.1993e+00],
         [-2.9338e-01,  2.6619e-01, -2.6073e+00, -1.1063e+00, -3.0106e+00],
         [-2.9338e-01,  4.3312e-01, -2.2875e+00, -7.3632e-01, -2.7934e+00],
         ...,
         [ 4.9131e-01,  1.8093e+00, -1.7243e-01, -1.7166e+

Testing:   2%|▏         | 366/20000 [00:28<33:25,  9.79it/s]

tensor([[[-2.4445e+00,  3.0989e+01,  8.7507e+01,  ...,  1.1400e+02,
          -1.5262e+00, -1.1657e+00],
         [-1.1838e+00, -5.4436e-01, -1.0118e-01,  ...,  2.2433e+01,
          -1.2849e+00, -5.9670e-01],
         [-1.1720e+00, -6.0672e-01, -2.2868e-01,  ...,  1.6533e+01,
          -1.2151e+00, -5.3370e-01],
         ...,
         [-5.7022e-01, -4.9341e-01, -1.4186e-01,  ..., -3.4612e-01,
          -6.2776e-01, -1.8272e-01],
         [-1.0521e+00, -1.0063e+00, -1.5497e+00,  ..., -1.1562e-01,
          -6.7754e-01,  4.3824e+00],
         [-1.0413e+00, -1.0020e+00, -1.4488e+00,  ..., -4.5868e-02,
          -5.4682e-01,  7.8869e+00]]], device='cuda:0') shape:  torch.Size([1, 944, 50])
Set: tensor([[[ 5.7495e+01,  1.4282e+01,  2.3553e+01,  ..., -1.0991e+02,
           2.4016e+01,  5.4820e+01],
         [ 1.0560e+02, -6.8256e+01, -2.4073e+01,  ..., -3.6298e+01,
           2.9087e+01,  1.0881e+02],
         [ 1.0637e+02, -6.4167e+01, -2.4495e+01,  ..., -3.2036e+01,
           2.7803e+01

Testing:   2%|▏         | 368/20000 [00:28<36:58,  8.85it/s]

tensor([[[-6.2352e-01,  3.4437e+01, -1.6319e+00,  ...,  7.5818e+01,
           2.0121e+01, -4.0279e+00],
         [-4.4672e-01,  1.0801e+01, -7.2545e-01,  ..., -2.0820e-01,
          -4.9318e-01, -5.3842e-01],
         [-3.9968e-01,  1.1072e+01, -5.8572e-01,  ...,  2.3423e+00,
          -2.6959e-02, -1.1943e+00],
         ...,
         [-2.2408e-01, -7.2353e-01, -1.4112e-01,  ..., -1.2472e-01,
          -4.8996e-01, -8.8641e-01],
         [-2.0617e-01, -6.0470e-01,  2.0478e+00,  ..., -2.8991e-01,
          -3.3021e-01, -7.7162e-01],
         [-2.2695e-01, -6.3214e-01,  3.5538e+00,  ..., -2.9989e-01,
          -3.7261e-01, -8.0850e-01]]], device='cuda:0') shape:  torch.Size([1, 891, 70])
Set: tensor([[[-2.3975e+00,  4.5236e+01,  1.1859e+02,  ...,  1.3097e+02,
          -1.7318e+00, -1.3037e+00],
         [-4.1006e-01, -2.4347e-01,  2.9943e+00,  ..., -6.6439e-01,
          -4.9436e-01, -5.5682e-01],
         [-6.3632e-01, -3.1040e-01,  1.0227e+00,  ...,  7.8784e+00,
          -6.1765e-01

Testing:   2%|▏         | 369/20000 [00:28<25:41, 12.74it/s]


tensor([[[ 0.0476, -0.3656, -5.1993, -1.6869, -5.1993],
         [-0.1698, -0.3656, -3.0094, -0.7637, -3.0094],
         [-0.1698, -0.1267, -2.7922, -0.5214, -2.7922],
         ...,
         [-2.2861, -1.6043,  1.2338,  0.5365,  1.3229],
         [-2.2236, -1.6043,  1.2338,  1.0286,  1.3887],
         [-2.1687, -1.6043,  1.1315,  1.7148,  1.4707]]], device='cuda:0') shape:  torch.Size([1, 765, 5])
Set: tensor([[[ 9.7673e+00,  1.8416e-01,  1.4481e+00,  ..., -1.2272e-01,
          -1.6898e-02,  7.8513e+00],
         [ 5.1734e+00, -4.7232e-03,  2.8342e-01,  ..., -6.9639e-02,
          -2.2276e-02,  4.5676e+00],
         [ 5.2940e+00, -5.4092e-03,  1.4593e-01,  ..., -6.7027e-02,
          -1.6449e-02,  4.2557e+00],
         ...,
         [-5.9566e-02, -1.3385e-02, -2.7076e-02,  ...,  6.6932e+00,
          -1.0979e-01, -2.0769e-02],
         [-6.1560e-02, -2.4591e-02, -3.2522e-02,  ...,  6.8324e+00,
          -1.0905e-01, -1.7278e-02],
         [-6.3654e-02, -4.0019e-02, -4.0181e-02,  ..., 

KeyboardInterrupt: 