In [31]:
import os
import glob
import random
import subprocess

import numpy as np
import pandas as pd
import h5py
import uproot
import awkward as ak

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import knn_graph

import tqdm
from tqdm import tqdm

import os
import os.path as osp  # This defines 'osp'
import glob



def find_highest_branch(path, base_name):
    with uproot.open(path) as f:
        # Find keys that exactly match the base_name (not containing other variations)
        branches = [k for k in f.keys() if k.startswith(base_name + ';')]
        
        # Sort and select the highest-numbered branch
        sorted_branches = sorted(branches, key=lambda x: int(x.split(';')[-1]))
        return sorted_branches[-1] if sorted_branches else None

class CCV1(Dataset):
    r'''
    Loads trackster-level features and associations for positive/negative edge creation.
    '''

    url = '/dummy/'

    def __init__(self, root, transform=None, max_events=1e8, inp='train'):
        super(CCV1, self).__init__(root, transform)
        self.inp = inp
        self.max_events = max_events
        self.fill_data(max_events)

    def fill_data(self, max_events):
        counter = 0
        print("### Loading tracksters data")


        for path in tqdm(self.raw_paths):
            print(path)
            
            tracksters_path = find_highest_branch(path, 'tracksters')
            associations_path = find_highest_branch(path, 'associations')
            simtrack = find_highest_branch(path, 'simtrackstersCP')
            # Load tracksters features in chunks
            for array in uproot.iterate(
                f"{path}:{tracksters_path}",
                [
                    "time", "raw_energy",
                    "barycenter_x", "barycenter_y", "barycenter_z", 
                    "barycenter_eta", "barycenter_phi",
                    "EV1", "EV2", "EV3",
                    "eVector0_x", "eVector0_y", "eVector0_z",
                    "sigmaPCA1", "sigmaPCA2", "sigmaPCA3", "raw_pt", "vertices_time"
                ],
            ):

                tmp_time = array["time"]
                tmp_raw_energy = array["raw_energy"]
                tmp_bx = array["barycenter_x"]
                tmp_by = array["barycenter_y"]
                tmp_bz = array["barycenter_z"]
                tmp_beta = array["barycenter_eta"]
                tmp_bphi = array["barycenter_phi"]
                tmp_EV1 = array["EV1"]
                tmp_EV2 = array["EV2"]
                tmp_EV3 = array["EV3"]
                tmp_eV0x = array["eVector0_x"]
                tmp_eV0y = array["eVector0_y"]
                tmp_eV0z = array["eVector0_z"]
                tmp_sigma1 = array["sigmaPCA1"]
                tmp_sigma2 = array["sigmaPCA2"]
                tmp_sigma3 = array["sigmaPCA3"]
                tmp_pt = array["raw_pt"]
                tmp_vt = array["vertices_time"]
                
                
                vert_array = []
                for vert_chunk in uproot.iterate(
                    f"{path}:{simtrack}",
                    ["barycenter_x"],
                ):
                    vert_array = vert_chunk["barycenter_x"]
                    break  # Since we have a matching chunk, no need to continue
                

                # Now load the associations for the same events/chunk
                # 'tsCLUE3D_recoToSim_CP' gives association arrays like [[1,0],[0,1],...]
                # Make sure we read from the same events
                tmp_array = []
                score_array = []
                for assoc_chunk in uproot.iterate(
                    f"{path}:{associations_path}",
                    ["tsCLUE3D_recoToSim_CP", "tsCLUE3D_recoToSim_CP_score"],
                ):
                    tmp_array = assoc_chunk["tsCLUE3D_recoToSim_CP"]
                    score_array = assoc_chunk["tsCLUE3D_recoToSim_CP_score"]
                    break  # Since we have a matching chunk, no need to continue
                
                
                skim_mask = []
                for e in vert_array:
                    if len(e) >= 2:
                        skim_mask.append(True)
                    elif len(e) == 0:
                        skim_mask.append(False)

                    else:
                        skim_mask.append(False)



                tmp_time = tmp_time[skim_mask]
                tmp_raw_energy = tmp_raw_energy[skim_mask]
                tmp_bx = tmp_bx[skim_mask]
                tmp_by = tmp_by[skim_mask]
                tmp_bz = tmp_bz[skim_mask]
                tmp_beta = tmp_beta[skim_mask]
                tmp_bphi = tmp_bphi[skim_mask]
                tmp_EV1 = tmp_EV1[skim_mask]
                tmp_EV2 = tmp_EV2[skim_mask]
                tmp_EV3 = tmp_EV3[skim_mask]
                tmp_eV0x = tmp_eV0x[skim_mask]
                tmp_eV0y = tmp_eV0y[skim_mask]
                tmp_eV0z = tmp_eV0z[skim_mask]
                tmp_sigma1 = tmp_sigma1[skim_mask]
                tmp_sigma2 = tmp_sigma2[skim_mask]
                tmp_sigma3 = tmp_sigma3[skim_mask]
                tmp_array = tmp_array[skim_mask]
                tmp_pt = tmp_pt[skim_mask]
                tmp_vt = tmp_vt[skim_mask]
                score_array = score_array[skim_mask]
                
                skim_mask = []
                for e in tmp_array:
                    if 2 <= len(e):
                        skim_mask.append(True)

                    elif len(e) == 0:
                        skim_mask.append(False)

                    else:
                        skim_mask.append(False)

                        
                tmp_time = tmp_time[skim_mask]
                tmp_raw_energy = tmp_raw_energy[skim_mask]
                tmp_bx = tmp_bx[skim_mask]
                tmp_by = tmp_by[skim_mask]
                tmp_bz = tmp_bz[skim_mask]
                tmp_beta = tmp_beta[skim_mask]
                tmp_bphi = tmp_bphi[skim_mask]
                tmp_EV1 = tmp_EV1[skim_mask]
                tmp_EV2 = tmp_EV2[skim_mask]
                tmp_EV3 = tmp_EV3[skim_mask]
                tmp_eV0x = tmp_eV0x[skim_mask]
                tmp_eV0y = tmp_eV0y[skim_mask]
                tmp_eV0z = tmp_eV0z[skim_mask]
                tmp_sigma1 = tmp_sigma1[skim_mask]
                tmp_sigma2 = tmp_sigma2[skim_mask]
                tmp_sigma3 = tmp_sigma3[skim_mask]
                tmp_array = tmp_array[skim_mask]
                tmp_pt = tmp_pt[skim_mask]
                tmp_vt = tmp_vt[skim_mask]
                score_array = score_array[skim_mask]

                
                # Concatenate or initialize storage
                if counter == 0:
                    self.time = tmp_time
                    self.raw_energy = tmp_raw_energy
                    self.bx = tmp_bx
                    self.by = tmp_by
                    self.bz = tmp_bz
                    self.beta = tmp_beta
                    self.bphi = tmp_bphi
                    self.EV1 = tmp_EV1
                    self.EV2 = tmp_EV2
                    self.EV3 = tmp_EV3
                    self.eV0x = tmp_eV0x
                    self.eV0y = tmp_eV0y
                    self.eV0z = tmp_eV0z
                    self.sigma1 = tmp_sigma1
                    self.sigma2 = tmp_sigma2
                    self.sigma3 = tmp_sigma3
                    self.assoc = tmp_array
                    self.pt = tmp_pt
                    self.vt = tmp_vt
                    self.score = score_array
                else:
                    self.time = ak.concatenate((self.time, tmp_time))
                    self.raw_energy = ak.concatenate((self.raw_energy, tmp_raw_energy))
                    self.bx = ak.concatenate((self.bx, tmp_bx))
                    self.by = ak.concatenate((self.by, tmp_by))
                    self.bz = ak.concatenate((self.bz, tmp_bz))
                    self.beta = ak.concatenate((self.beta, tmp_beta))
                    self.bphi = ak.concatenate((self.bphi, tmp_bphi))
                    self.EV1 = ak.concatenate((self.EV1, tmp_EV1))
                    self.EV2 = ak.concatenate((self.EV2, tmp_EV2))
                    self.EV3 = ak.concatenate((self.EV3, tmp_EV3))
                    self.eV0x = ak.concatenate((self.eV0x, tmp_eV0x))
                    self.eV0y = ak.concatenate((self.eV0y, tmp_eV0y))
                    self.eV0z = ak.concatenate((self.eV0z, tmp_eV0z))
                    self.sigma1 = ak.concatenate((self.sigma1, tmp_sigma1))
                    self.sigma2 = ak.concatenate((self.sigma2, tmp_sigma2))
                    self.sigma3 = ak.concatenate((self.sigma3, tmp_sigma3))
                    self.assoc = ak.concatenate((self.assoc, tmp_array))
                    self.pt = ak.concatenate((self.pt, tmp_pt))
                    self.vt = ak.concatenate((self.vt, tmp_vt))
                    self.score = ak.concatenate((self.score, score_array))

                counter += len(tmp_bx)
                if counter >= max_events:
                    print(f"Reached {max_events} events!")
                    break
            if counter >= max_events:
                break

    def download(self):
        raise RuntimeError(
            f'Dataset not found. Please download it from {self.url} and move all '
            f'*.root files to {self.raw_dir}')

    def len(self):
        return len(self.time)

    @property
    def raw_file_names(self):
        raw_files = sorted(glob.glob(osp.join(self.raw_dir, '*.root')))
        return raw_files

    @property
    def processed_file_names(self):
        return []



    def get(self, idx):

        def reconstruct_array(grouped_indices):
            # Finds the maximum index and returns a 1D array listing the group for each index.
            max_index = max(max(indices) for indices in grouped_indices.values())
            reconstructed = [-1] * (max_index + 1)
            for value, indices in grouped_indices.items():
                for idx2 in indices:
                    reconstructed[idx2] = value
            return reconstructed

        # Extract per-event arrays
        event_time = self.time[idx]
        event_raw_energy = self.raw_energy[idx]
        event_bx = self.bx[idx]
        event_by = self.by[idx]
        event_bz = self.bz[idx]
        event_beta = self.beta[idx]
        event_bphi = self.bphi[idx]
        event_EV1 = self.EV1[idx]
        event_EV2 = self.EV2[idx]
        event_EV3 = self.EV3[idx]
        event_eV0x = self.eV0x[idx]
        event_eV0y = self.eV0y[idx]
        event_eV0z = self.eV0z[idx]
        event_sigma1 = self.sigma1[idx]
        event_sigma2 = self.sigma2[idx]
        event_sigma3 = self.sigma3[idx]
        event_assoc = self.assoc[idx]      # associations; e.g. [0, 4, 3, 2]
        event_pt = self.pt[idx]
        event_vt = self.vt[idx]
        event_score = self.score[idx]      # scores; e.g. [0.000, 0.281, 1.0, 1.0]

        # Convert each to NumPy
        event_time = np.array(event_time)
        event_raw_energy = np.array(event_raw_energy)
        event_bx = np.array(event_bx)
        event_by = np.array(event_by)
        event_bz = np.array(event_bz)
        event_beta = np.array(event_beta)
        event_bphi = np.array(event_bphi)
        event_EV1 = np.array(event_EV1)
        event_EV2 = np.array(event_EV2)
        event_EV3 = np.array(event_EV3)
        event_eV0x = np.array(event_eV0x)
        event_eV0y = np.array(event_eV0y)
        event_eV0z = np.array(event_eV0z)
        event_sigma1 = np.array(event_sigma1)
        event_sigma2 = np.array(event_sigma2)
        event_sigma3 = np.array(event_sigma3)
        event_assoc = np.array(event_assoc)   # shape (N, ?) with nested arrays
        event_pt = np.array(event_pt)
        event_score = np.array(event_score)     # shape (N, ?) with nested arrays


        # Stack trackster features into x.
        flat_feats = np.column_stack((
            event_bx, event_by, event_bz, event_raw_energy,
            event_beta, event_bphi,
            event_EV1, event_EV2, event_EV3,
            event_eV0x, event_eV0y, event_eV0z,
            event_sigma1, event_sigma2, event_sigma3,
            event_pt
        ))
        x = torch.from_numpy(flat_feats).float()

        # Convert associations & scores to tensors.
        links_tensor = torch.from_numpy(event_assoc.astype(np.int64))
        scores_tensor = torch.from_numpy(event_score).float()

        # --- Truncate or pad each tensor to 4 columns ---
        def ensure_four_columns(tensor):
            if tensor.ndim == 1:
                tensor = tensor.unsqueeze(1)
            nrow, ncol = tensor.shape
            if ncol > 4:
                tensor = tensor[:, :4]
            elif ncol < 4:
                last_col = tensor[:, -1].unsqueeze(1)
                repeat_count = 4 - ncol
                repeated = last_col.repeat(1, repeat_count)
                tensor = torch.cat([tensor, repeated], dim=1)
            return tensor

        scores_tensor = ensure_four_columns(scores_tensor)
        links_tensor = ensure_four_columns(links_tensor)

        

        # Return the Data object with all fields.
        return Data(
            x=x,
            scores=scores_tensor,
            links=links_tensor
        )





In [33]:
ipath = "/vols/cms/mm1221/Data/mix/train/"
vpath = "/vols/cms/mm1221/Data/mix/test/"
data_train = CCV1(ipath, max_events=10000, inp='train')
data_val = CCV1(vpath, max_events=10000, inp='val')

### Loading tracksters data


  0%|                                                     | 0/3 [00:00<?, ?it/s]

/vols/cms/mm1221/Data/mix/train/raw/18k.root


  0%|                                                     | 0/3 [01:15<?, ?it/s]


Reached 10000 events!
### Loading tracksters data


  0%|                                                     | 0/1 [00:00<?, ?it/s]

/vols/cms/mm1221/Data/mix/test/raw/test.root


  0%|                                                     | 0/1 [01:04<?, ?it/s]

Reached 10000 events!





In [62]:
import torch
import torch.nn.functional as F
import random

import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F

def contrastive_loss_fractional(embeddings, groups, scores, temperature=0.1):
    """
    Computes a contrastive loss using "shared-energy" logic.

    For each node pair (i, j):
      1) Compute the shared energy e_ij by summing, over all group IDs that i and j share,
         the minimum of (1 - score[i]) and (1 - score[j]) for that group.
      2) If e_ij >= 0.5, the edge is considered positive with weight w_ij^+ = 2 * (e_ij - 0.5).
         If e_ij < 0.5, the edge is considered negative with weight w_ij^- = 2 * (0.5 - e_ij).
      3) The numerator of the InfoNCE term for anchor i sums only the positive weights:
            numerator_i = ∑_j [ w_ij^+ * exp(sim(i,j)/T) ]
         The denominator sums both positive and negative weights:
            denominator_i = ∑_j [ (w_ij^+ + w_ij^-) * exp(sim(i,j)/T) ]
      4) The loss for anchor i is -log(numerator_i / denominator_i), averaged over i.

    Args:
        embeddings: FloatTensor (N, D). Embeddings for N nodes.
        groups: LongTensor (N, num_slots). Group IDs for each node/slot.
        scores: FloatTensor (N, num_slots). Fractional “negativity” [0,1]; energy = 1 - score.
        temperature: Temperature for the softmax.

    Returns:
        A scalar tensor with the mean loss.
    """
    device = embeddings.device
    N, D = embeddings.shape
    num_slots = groups.size(1)
    
    # 1) Normalize embeddings and compute similarity (N x N).
    norm_emb = F.normalize(embeddings, p=2, dim=1)
    sim_matrix = norm_emb @ norm_emb.t()  # cosine similarity

    # 2) Compute the "energy" = (1 - scores).
    #    We'll broadcast for pairwise slot comparisons (N x 1 x num_slots) vs (1 x N x num_slots).
    energy_i = (1.0 - scores).unsqueeze(1)  # shape (N, 1, num_slots)
    energy_j = (1.0 - scores).unsqueeze(0)  # shape (1, N, num_slots)
    
    # 3) Identify slot matches and take the minimum of the energies for matching groups.
    #    match shape: (N, N, num_slots, num_slots), indicating group_i == group_j
    groups_i = groups.unsqueeze(1)  # (N, 1, num_slots)
    groups_j = groups.unsqueeze(0)  # (1, N, num_slots)
    match = (groups_i.unsqueeze(-1) == groups_j.unsqueeze(-2)).float()

    # min_energy shape: (N, N, num_slots, num_slots)
    # each entry is min(energy_i[k], energy_j[l]) if the groups match in that slot pair,
    # otherwise 0 if there's no match.
    min_energy = torch.min(
        energy_i.unsqueeze(-1),  # (N, 1, num_slots, 1)
        energy_j.unsqueeze(-2)   # (1, N, 1, num_slots)
    )
    
    shared_energy = (match * min_energy).sum(dim=(-1, -2))  # sum over slot-pairs => shape (N, N)
    # 4) Determine which edges are positive vs negative.
    pos_mask = (shared_energy >= 0.5)
    neg_mask = ~pos_mask  # everything else < 0.5
    
    # 5) Compute weights:
    #    If e >= 0.5 => w^+ = 2*(e - 0.5)
    #    If e < 0.5 => w^- = 2*(0.5 - e)
    pos_weight = torch.zeros_like(shared_energy, device=device)
    neg_weight = torch.zeros_like(shared_energy, device=device)
    
    # Positive edges
    pos_weight[pos_mask] = 2.0 * (shared_energy[pos_mask] - 0.5)
    # Negative edges
    neg_weight[neg_mask] = 2.0 * (0.5 - shared_energy[neg_mask])
    
    # Set diagonal = 0 to exclude self-similarity from the contrastive terms
    pos_weight.fill_diagonal_(0)
    neg_weight.fill_diagonal_(0)

    # 6) Compute numerator & denominator using exponentiated sim.
    exp_sim = torch.exp(sim_matrix / temperature)  # shape (N, N)
    
    #   numerator_i = sum_j [ pos_weight(i,j) * exp_sim(i,j) ]
    #   denominator_i = sum_j [ (pos_weight(i,j) + neg_weight(i,j)) * exp_sim(i,j) ]
    numerator = (pos_weight * exp_sim).sum(dim=1)         # shape (N,)
    denominator = ((pos_weight + neg_weight) * exp_sim).sum(dim=1)  # shape (N,)
    
    # 7) Compute the per-anchor loss and average.
    #    -log( numerator_i / denominator_i ).
    #    Since user explicitly said "no need for eps", we omit it. This may produce NaNs if
    #    numerator=0 or denominator=0, so ensure your data has at least some positives per node.
    loss = -torch.log(numerator / denominator)
    return loss.mean()





#################################
# Updated Training and Testing Functions
#################################

def train_new(train_loader, model, optimizer, device, k_value, alpha_unused=None):
    model.train()
    total_loss = 0.0
    n_samples = 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        


        # Build k-NN graph using first 3 features.
        edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        embeddings, _ = model(data.x, edge_index, data.x_batch)
        
        # Partition batch by event.
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        
        loss_event_total = 0.0
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count
            event_embeddings = embeddings[start_idx:end_idx]
            event_scores = data.scores[start_idx:end_idx]
            event_links = data.links[start_idx:end_idx]
            
            loss_event = contrastive_loss_fractional(
                embeddings=event_embeddings,
                groups = event_links,
                scores =event_scores,
                temperature=0.1
            )
            loss_event_total += loss_event
            start_idx = end_idx
        
        loss = loss_event_total / len(counts)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * embeddings.size(0)
        n_samples += embeddings.size(0)
    return total_loss / n_samples

@torch.no_grad()
def test_new(test_loader, model, device, k_value, alpha_unused=None):
    model.eval()
    total_loss = 0.0
    n_samples = 0
    for data in tqdm(test_loader):
        data = data.to(device)
        
        
        edge_index = knn_graph(data.x[:, :3], k=k_value, batch=data.x_batch)
        embeddings, _ = model(data.x, edge_index, data.x_batch)
        
        batch_np = data.x_batch.detach().cpu().numpy()
        _, counts = np.unique(batch_np, return_counts=True)
        
        loss_event_total = 0.0
        start_idx = 0
        for count in counts:
            end_idx = start_idx + count
            event_embeddings = embeddings[start_idx:end_idx]
            event_scores = data.scores[start_idx:end_idx]
            event_links = data.links[start_idx:end_idx]
            loss_event = contrastive_loss_fractional(
                embeddings=event_embeddings,
                groups = event_links,
                scores =event_scores,
                temperature=0.1
            )
            loss_event_total += loss_event
            start_idx = end_idx
        total_loss += loss_event_total / len(counts)
        n_samples += embeddings.size(0)
    return total_loss / n_samples


In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import knn_graph



class CustomStaticEdgeConv(nn.Module):
    def __init__(self, nn_module):
        super(CustomStaticEdgeConv, self).__init__()
        self.nn_module = nn_module

    def forward(self, x, edge_index):
        """
        Args:
            x (torch.Tensor): Node features of shape (N, F).
            edge_index (torch.Tensor): Predefined edges [2, E], where E is the number of edges.

        Returns:
            torch.Tensor: Node features after static edge aggregation.
        """
        row, col = edge_index  # Extract row (source) and col (target) nodes
        x_center = x[row]
        x_neighbor = x[col]

        # Compute edge features (relative)
        edge_features = torch.cat([x_center, x_neighbor - x_center], dim=-1)
        edge_features = self.nn_module(edge_features)

        # Aggregate features back to nodes
        num_nodes = x.size(0)
        node_features = torch.zeros(num_nodes, edge_features.size(-1), device=x.device)
        node_features.index_add_(0, row, edge_features)

        # Normalization (Divide by node degrees)
        counts = torch.bincount(row, minlength=num_nodes).clamp(min=1).view(-1, 1)
        node_features = node_features / counts

        return node_features

class Net(nn.Module):
    def __init__(self, hidden_dim=64, num_layers=4, dropout=0.3, contrastive_dim=8, heads=4):
        """
        Initializes the neural network with alternating StaticEdgeConv and GAT layers.

        Args:
            hidden_dim (int): Dimension of hidden layers.
            num_layers (int): Total number of convolutional layers (both StaticEdgeConv and GAT).
            dropout (float): Dropout rate.
            contrastive_dim (int): Dimension of the contrastive output.
            heads (int): Number of attention heads in GAT layers.
        """
        super(Net, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.contrastive_dim = contrastive_dim
        self.heads = heads

        # Input encoder
        self.lc_encode = nn.Sequential(
            nn.Linear(16, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU()
        )

        # Define the network's convolutional layers, alternating between StaticEdgeConv and GAT
        self.convs = nn.ModuleList()
        for layer_idx in range(num_layers):
            conv = CustomStaticEdgeConv(
                nn.Sequential(
                    nn.Linear(2 * hidden_dim, hidden_dim),
                    nn.ELU(),
                    nn.BatchNorm1d(hidden_dim),
                    nn.Dropout(p=dropout)
                )
            )
            self.convs.append(conv)

        # Output layer
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(64, 32),
            nn.ELU(),
            nn.Dropout(p=dropout),
            nn.Linear(32, contrastive_dim)
        )

    def forward(self, x, edge_index, batch):
        """
        Forward pass of the network.

        Args:
            x (torch.Tensor): Input node features of shape (N, 15).
            edge_index (torch.Tensor): Edge indices of shape (2, E).
            batch (torch.Tensor): Batch vector.

        Returns:
            torch.Tensor: Output features after processing.
            torch.Tensor: Batch vector.
        """
        # Input encoding
        x_lc_enc = self.lc_encode(x)  # Shape: (N, hidden_dim)

        # Apply convolutional layers with residual connections
        feats = x_lc_enc
        for idx, conv in enumerate(self.convs):
            feats = conv(feats, edge_index) + feats  # Residual connection

        # Final output
        out = self.output(feats)
        return out, batch



In [64]:
print("Instantiating model...")
# Instantiate model.

# Set device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Loading data...")
model = Net(
    hidden_dim=128,
    num_layers=3,
    dropout=0.3,
    contrastive_dim=128
).to(device)

k_value = 24
BS = 1

# Setup optimizer and scheduler.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

# Create DataLoaders.
train_loader = DataLoader(data_train, batch_size=BS, shuffle=False, follow_batch=['x'])
val_loader = DataLoader(data_val, batch_size=BS, shuffle=False, follow_batch=['x'])

# Setup output directory.
output_dir = '/vols/cms/mm1221/hgcal/Mixed/Track/Fraction/test/'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
best_val_loss = float('inf')
train_losses = []
val_losses = []
patience = 300
no_improvement_epochs = 0

print("Starting full training with curriculum for hard negative mining...")

epochs = 300
for epoch in range(epochs):
    # For epochs 1 to 150, gradually increase alpha from 0 to 1.
    # From epoch 151 onward, set alpha = 1 (fully hard negatives).


    print(f"Epoch {epoch+1}/{epochs}")
    train_loss = train_new(train_loader, model, optimizer, device, k_value)
    val_loss = test_new(val_loader, model, device, k_value)

    train_losses.append(train_loss.item())
    val_losses.append(val_loss.item())
    scheduler.step()

    # Save best model if validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
        torch.save(model.state_dict(), os.path.join(output_dir, 'best_model.pt'))
    else:
        no_improvement_epochs += 1

    # Save intermediate checkpoint.
    state_dicts = {'model': model.state_dict(),
                   'opt': optimizer.state_dict(),
                   'lr': scheduler.state_dict()}
    torch.save(state_dicts, os.path.join(output_dir, f'epoch-{epoch+1}.pt'))

    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss.item():.8f}, Validation Loss: {val_loss.item():.8f}")
    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered. No improvement for {patience} epochs.")
        break

# Save training history.
import pandas as pd
results_df = pd.DataFrame({
    'epoch': list(range(1, len(train_losses) + 1)),
    'train_loss': train_losses,
    'val_loss': val_losses
})
results_df.to_csv(os.path.join(output_dir, 'continued_training_loss.csv'), index=False)
print(f"Saved loss curves to {os.path.join(output_dir, 'continued_training_loss.csv')}")

# Save final model.
torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pt'))
print("Training complete. Final model saved.")


Instantiating model...
Loading data...
Starting full training with curriculum for hard negative mining...
Epoch 1/300


  0%|                                         | 4/16214 [00:00<07:07, 37.90it/s]

tensor([[1.1582e+00, 0.0000e+00, 1.1138e-01, 0.0000e+00, 9.5310e-03, 0.0000e+00,
         2.7165e-01, 2.1270e-02, 1.0767e-01, 4.1791e-01, 5.7489e-01, 1.3353e-02,
         6.6713e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 1.0009e+00, 1.5807e-04, 1.8090e-04, 7.4488e-04, 9.8153e-01,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 7.4488e-04,
         7.4488e-04, 3.8205e-01, 9.0295e-04, 0.0000e+00, 1.5807e-04, 0.0000e+00,
         7.4488e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 7.4488e-04, 7.4488e-04],
        [1.1138e-01, 1.5807e-04, 1.0214e+00, 8.7701e-01, 3.6135e-02, 0.0000e+00,
         1.1138e-01, 9.0460e-03, 3.1976e-02, 1.1138e-01, 1.1138e-01, 1.4972e-02,
         3.3275e-02, 0.0000e+00, 1.6630e-03, 0.0000e+00, 2.6265e-02, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 1.80

tensor([[1.0003e+00, 2.6917e-04, 1.0394e-02, 1.0000e+00, 1.0000e+00, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 2.5749e-05,
         0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.6917e-04, 1.0004e+00, 4.2753e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.6305e-04,
         3.6305e-04, 3.6305e-04, 3.6305e-04],
        [1.0394e-02, 4.2753e-02, 1.0498e+00, 1.0125e-02, 1.0125e-02, 1.0125e-02,
         1.0125e-02, 1.0125e-02, 1.0125e-02, 1.0125e-02, 1.0125e-02, 9.8723e-01,
         9.8718e-01, 9.8718e-01, 9.8718e-01],
        [1.0000e+00, 0.0000e+00, 1.0125e-02, 1.0000e+00, 1.0000e+00, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 2.5749e-05,
         0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 0.0000e+00, 1.0125e-02, 1.0000e+00, 1.0000e+00, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 2.5749e-05,
      

  0%|                                        | 14/16214 [00:00<07:03, 38.22it/s]

tensor([[1.0154e+00, 1.6399e-02, 3.6958e-02, 1.8359e-03, 7.9340e-01, 3.5604e-03,
         1.6799e-01, 1.1903e-02, 1.5215e-02, 9.7843e-01, 0.0000e+00, 1.7222e-02,
         4.3215e-03, 1.8397e-02, 3.9043e-03, 2.3099e-03, 2.3099e-03, 0.0000e+00,
         2.3099e-03, 1.5944e-03, 4.2626e-03, 0.0000e+00, 2.3099e-03, 2.3099e-03,
         2.3099e-03, 1.4449e-03, 3.9043e-03, 2.3099e-03, 0.0000e+00],
        [1.6399e-02, 1.0164e+00, 1.6399e-02, 2.3295e-03, 3.3536e-01, 1.4609e-03,
         4.4037e-02, 6.9106e-04, 1.5366e-02, 8.9518e-02, 1.5128e-04, 1.4971e-02,
         6.2764e-04, 5.0513e-03, 2.1040e-04, 4.6066e-02, 2.1040e-04, 1.5128e-04,
         5.9128e-05, 1.5128e-04, 8.9025e-04, 1.5128e-04, 5.9128e-05, 5.9128e-05,
         5.9128e-05, 2.1040e-04, 5.9128e-05, 2.1040e-04, 1.5128e-04],
        [3.6958e-02, 1.6399e-02, 1.0249e+00, 2.1408e-03, 2.5206e-02, 1.7101e-03,
         2.5056e-02, 2.0513e-02, 1.5520e-02, 2.5360e-02, 0.0000e+00, 1.5372e-02,
         5.7203e-04, 1.8922e-02, 1.5479e-04, 1.547

tensor([[1.0180e+00, 4.1267e-03, 1.1750e-02, 1.7828e-03, 2.6493e-02, 1.9575e-02,
         1.9575e-02, 2.0432e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [4.1267e-03, 1.0556e+00, 1.0385e+00, 3.2604e-05, 1.8675e-02, 9.8765e-02,
         7.1719e-02, 6.2934e-02, 4.4030e-03, 7.5090e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 2.5044e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.1750e-02, 1.0385e+00, 1.0518e+00, 1.2124e-04, 2.6299e-02, 1.0993e-01,
         8.2889e-02, 7.4104e-02, 6.4613e-03, 7.5296e-01, 1.4670e-

  0%|                                        | 24/16214 [00:00<06:44, 40.05it/s]

tensor([[1.0136, 0.0531, 0.0000,  ..., 0.0252, 0.0252, 0.0252],
        [0.0531, 1.0280, 0.0000,  ..., 0.0000, 0.0056, 0.0000],
        [0.0000, 0.0000, 1.0705,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0252, 0.0000, 0.0000,  ..., 1.0000, 0.9991, 1.0000],
        [0.0252, 0.0056, 0.0000,  ..., 0.9991, 1.0047, 0.9991],
        [0.0252, 0.0000, 0.0000,  ..., 1.0000, 0.9991, 1.0000]])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.9981, 1.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.9981, 0.0000, 0.9981],
        [0.0000, 0.0000, 0.0000,  ..., 1.0000, 0.9981, 0.0000]])
tensor([[0.0000, 0.8938, 1.0000,  ..., 0.9496, 0.9496, 0.9496],
        [0.8938, 0.0000, 1.0000,  ..., 1.0000, 0.9888, 1.0000],
        [1.0000, 1.0000, 0.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,

tensor([[1.0000e+00, 5.1141e-05, 3.5960e-02, 5.1141e-05, 5.1141e-05, 0.0000e+00,
         5.1141e-05, 5.1141e-05, 0.0000e+00, 5.1141e-05, 5.1141e-05, 0.0000e+00,
         0.0000e+00, 5.1141e-05, 0.0000e+00, 0.0000e+00],
        [5.1141e-05, 1.0001e+00, 8.9169e-05, 9.6858e-05, 8.9169e-05, 0.0000e+00,
         8.9169e-05, 8.9169e-05, 0.0000e+00, 8.9169e-05, 8.9169e-05, 0.0000e+00,
         0.0000e+00, 8.9169e-05, 0.0000e+00, 0.0000e+00],
        [3.5960e-02, 8.9169e-05, 1.0686e+00, 1.0006e+00, 9.9872e-01, 3.4014e-02,
         4.1647e-02, 3.5405e-02, 3.4014e-02, 7.8600e-01, 5.6438e-02, 3.4014e-02,
         3.4014e-02, 4.1277e-02, 3.4014e-02, 3.4014e-02],
        [5.1141e-05, 9.6858e-05, 1.0006e+00, 1.0017e+00, 9.9989e-01, 1.8525e-03,
         9.4854e-03, 3.2438e-03, 1.8525e-03, 7.5384e-01, 2.4276e-02, 1.8525e-03,
         1.8525e-03, 9.1151e-03, 1.8525e-03, 1.8525e-03],
        [5.1141e-05, 8.9169e-05, 9.9872e-01, 9.9989e-01, 9.9995e-01, 0.0000e+00,
         7.6329e-03, 1.3913e-03, 0.0000

  0%|                                        | 29/16214 [00:00<06:28, 41.61it/s]

tensor([[1.0017e+00, 1.6817e-03, 1.8106e-03, 1.7888e-03, 1.7888e-03, 1.0729e-04,
         1.7458e-03, 0.0000e+00, 1.6815e-03, 1.6815e-03, 1.7888e-03, 1.6385e-03,
         9.9809e-01, 8.4990e-01, 1.0729e-04, 3.6988e-01, 6.6178e-02, 9.8774e-01,
         4.0994e-01, 0.0000e+00, 9.4841e-02, 7.8077e-02],
        [1.6817e-03, 1.0128e+00, 8.2373e-03, 1.6975e-02, 3.9959e-02, 2.3842e-07,
         6.6453e-03, 0.0000e+00, 9.5991e-01, 8.8117e-01, 1.3031e-02, 1.2470e-02,
         0.0000e+00, 0.0000e+00, 2.3842e-07, 0.0000e+00, 2.3842e-07, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.8106e-03, 8.2373e-03, 1.0512e+00, 7.0050e-02, 4.8832e-02, 1.2606e-04,
         1.0672e-02, 0.0000e+00, 4.1343e-02, 4.4345e-02, 9.6090e-01, 7.6764e-03,
         2.1815e-05, 2.1815e-05, 1.7579e-02, 2.1815e-05, 2.7537e-05, 2.1815e-05,
         2.1815e-05, 0.0000e+00, 2.1815e-05, 2.1815e-05],
        [1.7888e-03, 1.6975e-02, 7.0050e-02, 1.0704e+00, 9.6739e-02, 1.2606e-04,
         1.0672e

  0%|                                        | 34/16214 [00:00<06:23, 42.20it/s]

tensor([[1.0000e+00, 4.3333e-05, 3.9935e-05, 0.0000e+00, 3.9935e-05, 3.9935e-05,
         4.3333e-05, 0.0000e+00, 3.9935e-05, 0.0000e+00, 0.0000e+00, 4.3333e-05,
         4.3333e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [4.3333e-05, 1.0695e+00, 4.9772e-02, 1.7525e-03, 3.4708e-02, 4.7671e-02,
         2.4853e-01, 0.0000e+00, 2.8817e-02, 3.5038e-02, 2.7583e-02, 4.6429e-02,
         3.8674e-02, 2.7583e-02, 2.7583e-02, 5.6505e-02, 2.7583e-02, 2.7583e-02,
         2.7583e-02, 2.7583e-02],
        [3.9935e-05, 4.9772e-02, 1.0300e+00, 3.1899e-03, 4.5420e-02, 8.4365e-02,
         6.6361e-02, 6.8693e-03, 5.3286e-03, 1.1550e-02, 1.0965e-02, 5.7868e-03,
         1.1989e-02, 4.0953e-03, 1.0965e-02, 6.9272e-02, 1.0965e-02, 4.0953e-03,
         1.0965e-02, 4.0953e-03],
        [0.0000e+00, 1.7525e-03, 3.1899e-03, 1.0032e+00, 3.1899e-03, 3.1899e-03,
         1.7525e-03, 1.4374e-03, 0.0000e+00, 1.7525e-03, 1.4374e-03, 1.0362e-03,
       

  0%|                                        | 39/16214 [00:00<06:38, 40.56it/s]

tensor([[1.0553, 0.0411, 0.0000,  ..., 0.0156, 0.0000, 0.0000],
        [0.0411, 0.9997, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0156, 0.0000, 0.0000,  ..., 1.2658, 0.5769, 0.7264],
        [0.0000, 0.0000, 0.0000,  ..., 0.5769, 1.2282, 1.2018],
        [0.0000, 0.0000, 0.0000,  ..., 0.7264, 1.2018, 1.3513]])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.1539, 0.4528],
        [0.0000, 0.0000, 0.0000,  ..., 0.1539, 0.0000, 1.4036],
        [0.0000, 0.0000, 0.0000,  ..., 0.4528, 1.4036, 0.0000]])
tensor([[0.0000, 0.9178, 1.0000,  ..., 0.9689, 1.0000, 1.0000],
        [0.9178, 0.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 0.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,

tensor([[1.0043e+00, 1.2860e-03, 4.6106e-01,  ..., 8.6391e-04, 8.6391e-04,
         8.6391e-04],
        [1.2860e-03, 1.0016e+00, 1.2860e-03,  ..., 8.5115e-04, 9.6869e-04,
         5.3656e-04],
        [4.6106e-01, 1.2860e-03, 1.3779e+00,  ..., 8.0569e-01, 6.7615e-02,
         4.0988e-03],
        ...,
        [8.6391e-04, 8.5115e-04, 8.0569e-01,  ..., 1.2819e+00, 1.0875e-01,
         4.6437e-01],
        [8.6391e-04, 9.6869e-04, 6.7615e-02,  ..., 1.0875e-01, 1.0899e+00,
         2.9254e-02],
        [8.6391e-04, 5.3656e-04, 4.0988e-03,  ..., 4.6437e-01, 2.9254e-02,
         1.0025e+00]])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.6114, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.6114,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
tensor

  0%|                                        | 44/16214 [00:01<07:09, 37.69it/s]

tensor([[1.0054e+00, 4.1753e-04, 3.7991e-03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [4.1753e-04, 1.0465e+00, 5.0791e-03,  ..., 1.3997e-02, 1.3997e-02,
         1.3997e-02],
        [3.7991e-03, 5.0791e-03, 1.0048e+00,  ..., 1.3732e-03, 1.3732e-03,
         1.3732e-03],
        ...,
        [0.0000e+00, 1.3997e-02, 1.3732e-03,  ..., 1.0000e+00, 1.0000e+00,
         1.7325e-01],
        [0.0000e+00, 1.3997e-02, 1.3732e-03,  ..., 1.0000e+00, 1.0000e+00,
         1.7325e-01],
        [0.0000e+00, 1.3997e-02, 1.3732e-03,  ..., 1.7325e-01, 1.7325e-01,
         1.1324e+00]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([[0.0000, 0.9992, 0.9924,  ..., 1.0000, 1.0000, 1.0000],
        [0.9992, 0.0000, 0.9898,  ..., 0.9720, 0.9720, 0.9720],
        [0.9924, 0.989

  0%|                                        | 49/16214 [00:01<06:49, 39.47it/s]

tensor([[0.0000, 0.9714, 0.9656, 0.8918, 0.0000, 0.6286, 0.0000, 0.9059],
        [0.9714, 0.0000, 0.0000, 0.9938, 0.9931, 0.9216, 0.9938, 0.9986],
        [0.9656, 0.0000, 0.0000, 0.9933, 0.9926, 0.9211, 0.9933, 0.9933],
        [0.8918, 0.9938, 0.9933, 0.0000, 0.9266, 0.0000, 0.8552, 0.0000],
        [0.0000, 0.9931, 0.9926, 0.9266, 0.0000, 0.7040, 0.0000, 0.9408],
        [0.6286, 0.9216, 0.9211, 0.0000, 0.7040, 0.0000, 0.6340, 0.0000],
        [0.0000, 0.9938, 0.9933, 0.8552, 0.0000, 0.6340, 0.0000, 0.8693],
        [0.9059, 0.9986, 0.9933, 0.0000, 0.9408, 0.0000, 0.8693, 0.0000]])
tensor([[1.0018e+00, 1.5773e-03, 2.0922e-02, 5.3182e-03, 9.9997e-01, 8.6991e-03,
         0.0000e+00, 1.0462e-03, 1.6548e-03, 1.4061e-03, 2.9206e-06, 1.8458e-03,
         2.9206e-06, 1.4061e-03, 1.4061e-03, 0.0000e+00, 3.9345e-01, 2.9206e-06,
         1.4061e-03, 2.9206e-06],
        [1.5773e-03, 1.0011e+00, 1.0083e-02, 1.2634e-03, 6.8331e-04, 1.1376e-03,
         0.0000e+00, 6.4421e-04, 7.0298e-04, 4.54

  0%|▏                                       | 54/16214 [00:01<06:56, 38.79it/s]

tensor([[1.0001e+00, 6.6143e-04, 2.9492e-04, 1.8454e-03, 1.0138e-03, 5.3173e-04,
         2.9492e-04, 2.9492e-04, 0.0000e+00, 0.0000e+00, 4.0531e-06, 2.9492e-04,
         2.9492e-04, 2.9492e-04, 2.9492e-04, 2.9492e-04, 2.9492e-04, 2.9492e-04,
         2.9492e-04, 0.0000e+00, 2.9492e-04, 2.9492e-04, 2.9492e-04, 2.9492e-04,
         0.0000e+00],
        [6.6143e-04, 1.0153e+00, 9.9768e-01, 9.9805e-01, 4.0046e-02, 4.7610e-03,
         2.4451e-01, 8.2571e-01, 6.7174e-04, 6.7174e-04, 9.6977e-04, 4.8861e-01,
         6.9081e-01, 1.3347e-02, 4.3753e-01, 9.9480e-01, 6.9169e-02, 1.0065e+00,
         8.5561e-01, 1.7205e-02, 9.6140e-01, 1.0115e+00, 9.5944e-01, 8.1200e-01,
         1.7205e-02],
        [2.9492e-04, 9.9768e-01, 9.9987e-01, 9.9935e-01, 3.7853e-02, 3.8525e-03,
         2.2797e-01, 8.0918e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.7207e-01,
         6.7428e-01, 1.2675e-02, 4.2100e-01, 9.7826e-01, 5.2636e-02, 9.9000e-01,
         8.3908e-01, 0.0000e+00, 9.4487e-01, 9.9500e-01, 9.4291e-

  0%|▏                                       | 59/16214 [00:01<06:28, 41.57it/s]

tensor([[1.0085e+00, 7.2557e-03, 3.0428e-03, 1.2840e-02, 1.0000e+00, 9.5501e-01,
         3.9160e-05, 9.9839e-01, 3.0758e-03, 0.0000e+00, 3.0428e-03, 2.6425e-03,
         1.0000e+00, 1.0038e+00, 2.4218e-02, 0.0000e+00, 9.9789e-01, 9.6652e-03],
        [7.2557e-03, 1.0018e+00, 9.9759e-01, 4.2130e-03, 7.2904e-03, 4.2130e-03,
         3.9160e-05, 3.8238e-03, 1.4579e-02, 0.0000e+00, 9.9237e-03, 2.6425e-03,
         4.2130e-03, 4.6852e-03, 9.3211e-01, 0.0000e+00, 3.3231e-03, 3.3231e-03],
        [3.0428e-03, 9.9759e-01, 1.0068e+00, 0.0000e+00, 3.0774e-03, 0.0000e+00,
         3.9160e-05, 0.0000e+00, 2.1053e-02, 0.0000e+00, 1.6430e-02, 2.6425e-03,
         0.0000e+00, 1.9925e-03, 9.2879e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.2840e-02, 4.2130e-03, 0.0000e+00, 1.0046e+00, 3.8015e-02, 1.8808e-01,
         0.0000e+00, 5.7499e-03, 3.3081e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         1.0585e-02, 1.1128e-02, 5.2492e-03, 0.0000e+00, 5.2492e-03, 5.2492e-03],
        [1.0000e+00, 7.2

  0%|▏                                       | 65/16214 [00:01<06:08, 43.85it/s]

tensor([[1.0346e+00, 9.7454e-05, 1.8828e-02, 6.0395e-02, 3.5533e-03, 3.8499e-02,
         0.0000e+00, 6.7711e-03, 0.0000e+00, 9.6293e-01, 3.4969e-02, 7.0322e-04,
         1.8293e-04, 1.8293e-04, 0.0000e+00, 0.0000e+00, 1.8293e-04, 1.8293e-04,
         1.8293e-04, 1.8293e-04, 1.8293e-04, 1.8293e-04, 1.6443e-02, 1.8293e-04,
         1.8293e-04],
        [9.7454e-05, 1.0174e+00, 4.2641e-04, 9.7454e-05, 5.4229e-02, 0.0000e+00,
         9.8558e-01, 0.0000e+00, 6.1094e-02, 9.7454e-05, 9.7454e-05, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.7719e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.8828e-02, 4.2641e-04, 1.0157e+00, 2.2249e-02, 3.8823e-03, 2.5111e-03,
         3.2896e-04, 2.5111e-03, 3.2896e-04, 1.6388e-02, 1.0030e+00, 5.2029e-04,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.5111e-

  0%|▏                                       | 70/16214 [00:01<06:29, 41.45it/s]

tensor([[1.0053e+00, 1.0858e-03, 4.4077e-02, 7.5281e-05, 9.2208e-05, 1.8345e-03,
         9.9995e-01, 1.6145e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.0858e-03, 1.0011e+00, 1.1393e-03, 3.5644e-05, 3.9876e-05, 2.2948e-05,
         1.0825e-03, 2.6851e-03, 1.2696e-05, 1.2696e-05, 1.2696e-05, 1.2696e-05,
         1.2696e-05, 1.2696e-05],
        [4.4077e-02, 1.1393e-03, 1.0433e+00, 5.7108e-03, 5.7277e-03, 7.4700e-03,
         3.8721e-02, 7.3035e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [7.5281e-05, 3.5644e-05, 5.7108e-03, 1.0123e+00, 1.4475e-02, 1.8454e-02,
         0.0000e+00, 9.2414e-03, 4.9567e-04, 4.9567e-04, 4.9567e-04, 4.9567e-04,
         4.9567e-04, 4.9567e-04],
        [9.2208e-05, 3.9876e-05, 5.7277e-03, 1.4475e-02, 1.0141e+00, 3.5928e-02,
         1.6928e-05, 1.0606e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.8345e-03,




KeyboardInterrupt: 