In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from model import SingleImageTransformer
from dataset import SingleImageTransformerDataset
import os 
import numpy as np
import shutil
from sim import simulate_mechanism
from utils import preprocess_curves
import matplotlib.pyplot as plt
from tqdm import tqdm
from contextlib import contextmanager

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

2025-08-24 00:13:09.191545: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756008789.202824 2295047 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756008789.206479 2295047 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1756008789.216468 2295047 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1756008789.216476 2295047 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1756008789.216478 2295047 computation_placer.cc:177] computation placer alr

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_ddp_model(model, checkpoint_path, strict=True, verbose=True):
    """
    Load model weights from checkpoint, handling DDP wrapping if present.
    
    Args:
        model: The model to load weights into
        checkpoint_path: Path to the checkpoint file
        clip_loss_fn: Optional CLIP loss function to load weights for
        strict: Whether to strictly enforce matching keys between model and checkpoint
        verbose: Whether to print debugging information
        
    Returns:
        Tuple of (model, checkpoint) with loaded weights
    """
    try:
        # Load checkpoint with error handling
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if verbose:
            print(f"✅ Checkpoint loaded from {checkpoint_path}")
            print(f"Checkpoint keys: {list(checkpoint.keys())}")
            # print(checkpoint['model_state_dict'])
        
        # Extract state dict
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        
        # Remove DDP wrapper prefixes
        new_state_dict = {}
        for k, v in state_dict.items():
            new_key = k.replace("module.", "") if k.startswith("module.") else k
            new_state_dict[new_key] = v
            if verbose and k != new_key:
                print(f"⚠️ Renamed DDP key: {k} -> {new_key}")
        
        # Verify model architecture matches checkpoint
        model_keys = set(model.state_dict().keys())
        ckpt_keys = set(new_state_dict.keys())
        
        if verbose:
            print("\n🔍 Model vs Checkpoint Key Analysis:")
            print(f"Model has {len(model_keys)} keys")
            print(f"Checkpoint has {len(ckpt_keys)} keys")
            
            missing_in_model = ckpt_keys - model_keys
            missing_in_ckpt = model_keys - ckpt_keys
            
            if missing_in_model:
                print(f"\n❌ Keys in checkpoint but not in model: {missing_in_model}")
            if missing_in_ckpt:
                print(f"\n❌ Keys in model but not in checkpoint: {missing_in_ckpt}")
        
        # Load model weights with error handling
        try:
            model.load_state_dict(new_state_dict, strict=strict)
            if verbose:
                print("\n✅ Model weights loaded successfully!")
        except RuntimeError as e:
            print(f"\n❌ Error loading model weights: {str(e)}")
            if not strict:
                print("⚠️ Attempting partial load (strict=False)")
                model.load_state_dict(new_state_dict, strict=False)
        
        return model, checkpoint
    
    except Exception as e:
        print(f"❌ Failed to load checkpoint: {str(e)}")
        raise

# Model Configuration
model_config = {
    'output_size': 2,
    'tgt_seq_len': 10,
    'd_model': 1024,
    'h': 32,
    'N': 6
}
# Hyperparameters
batch_size = 1024
lr = 1e-4

# Initialize Model
model = SingleImageTransformer(
    output_size=model_config['output_size'],
    tgt_seq_len=model_config['tgt_seq_len'],
    d_model=model_config['d_model'],
    h=model_config['h'],
    N=model_config['N']
).to(device)


model, checkpoint = load_ddp_model(model, f"weights/image_weights/d{model_config['d_model']}_h{model_config['h']}_n{model_config['N']}_bs{batch_size}_lr{lr}_clip1.0_best.pth")
model.eval()


✅ Checkpoint loaded from weights/image_weights/d1024_h32_n6_bs1024_lr0.0001_clip1.0_best.pth
Checkpoint keys: ['model_state_dict', 'optimizer_state_dict', 'clip_loss_state_dict', 'epoch', 'best_loss', 'batch_size', 'learning_rate', 'model_config', 'global_step']

🔍 Model vs Checkpoint Key Analysis:
Model has 807 keys
Checkpoint has 807 keys

✅ Model weights loaded successfully!


SingleImageTransformer(
  (encoder_positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (decoder_positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (tgt_embed): InputEmbeddings(
    (embedding): Linear(in_features=2, out_features=1024, bias=False)
  )
  (encoder): ModuleList(
    (0-5): 6 x Encoder(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
      )
      (feed_forward): FeedForwardBlock(
        (w1): Linear(in_features=1024, out_features=4096, bias=False)
        (w2): Linear(in_features=4096, out_features=1024, bias=False)
        (w3): Linear(in_features=1024, out_features=4096, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (final_norm): RMSNorm()
    )
  )
  (decoder): ModuleList(
    (0-5): 6 x Decoder(
      (self_

In [3]:
dataset = SingleImageTransformerDataset(
        data_dir='/home/anurizada/Documents/nobari_10_transformer',
    )
data_loader = DataLoader(dataset, shuffle=True, batch_size=1)

In [4]:
def greedy_decode_conditional(model, source, adj_type, max_len, eos_token=torch.tensor([1.0, 1.0])):
    model, source = model.to(device), source.to(device)
    adj_type = adj_type.to(device)

    # Encode the source once
    encoder_output, curve_emb, adj_emb = model.encode(source, adj_type)

    # Initialize decoder input with start token
    decoder_input = torch.ones(1, 1, 2).to(device) * -2.0  # Start token
    
    while decoder_input.size(1) < max_len:
        # Build causal mask (0 for allowed, -inf for blocked)
        decoder_mask = causal_mask(decoder_input.size(1), device)

        # Decode
        decoder_output = model.decode(
            encoder_output,
            None,
            decoder_input,
            decoder_mask
        )
        
        proj_output = model.projection(model.projection_norm(decoder_output))
        
        # Get next token
        next_token = proj_output[:, -1].unsqueeze(1)
        
        # EOS check
        if torch.allclose(next_token.squeeze(), eos_token.to(device), atol=1e-1):
            break

        # Append token
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

    return decoder_input.squeeze(0), curve_emb, adj_emb


def causal_mask(size, device):
    # (n, n) upper-triangular boolean mask
    mask = torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()
    # Add batch dimension -> (1, n, n)
    return mask.unsqueeze(0)

In [5]:
# Configuration
class Config:
    results_dir = f"image_d{model_config['d_model']}_h{model_config['h']}_n{model_config['N']}_bs{batch_size}_lr{lr}"
    max_mech_size = 10
    num_conditions = 200  # Number of different adjacency conditions to test per curve
    prefix_rows = np.array([[0.5, 0.5], [0.6, 0.5]], dtype=np.float32)
    plt_style = {
        'truth_joints': {'color': 'red', 'marker': 'o', 's': 60, 'label': 'Truth Joints'},
        'pred_joints': {'color': 'blue', 'marker': 'x', 's': 60, 'label': 'Pred Joints'},
        'truth_curve': {'color': 'magenta', 'linestyle': '-', 'linewidth': 3, 'label': 'Truth Curve'},
        'pred_curve': {'color': 'cyan', 'linestyle': '-', 'linewidth': 2, 'label': 'Predicted Curve'},
        'cond_curve': {'color': 'green', 'linestyle': '--', 'linewidth': 2, 'label': 'Condition Curve'}
    }

# Invalid joint values to remove
INVALID_JOINTS = np.array([
    [-1.0, -1.0],
    [1.0,  1.0],
], dtype=np.float32)

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).float()  # assume `model` is already defined elsewhere

@contextmanager
def managed_figure(figsize=(10, 8)):
    """Context manager to ensure figures are properly cleaned up"""
    fig, ax = plt.subplots(figsize=figsize)
    try:
        yield ax
    finally:
        plt.close(fig)

def prepare_results_dir():
    """Prepare clean results directory"""
    if os.path.exists(Config.results_dir):
        shutil.rmtree(Config.results_dir)
    os.makedirs(Config.results_dir, exist_ok=True)

def process_adjacency(adj_tensor):
    """Convert raw adjacency to simulation format (handles both tensors and numpy arrays)"""
    if isinstance(adj_tensor, torch.Tensor):
        adj = adj_tensor.detach().cpu().squeeze().float().numpy()
    elif isinstance(adj_tensor, np.ndarray):
        adj = adj_tensor.squeeze().astype(np.float32)
    else:
        raise TypeError(f"Input must be torch.Tensor or numpy.ndarray, got {type(adj_tensor)}")

    valid_mask = ~np.all(adj == 0, axis=1)
    adj = adj[valid_mask][:, valid_mask]
    node_types = np.diag(adj).astype(bool)
    np.fill_diagonal(adj, 0)
    return adj.astype(np.float32), node_types

def simulate(adj, joints, node_types):
    """Run simulation with error handling"""
    try:
        if adj.dtype != np.float32:
            adj = adj.astype(np.float32)
        if joints.dtype != np.float32:
            joints = joints.astype(np.float32)
        
        result = simulate_mechanism(adj, joints, node_types)
        if result is None:
            return None
        
        trajectory = torch.tensor(result[-1], dtype=torch.float32).unsqueeze(0)
        return preprocess_curves(trajectory).detach().cpu().numpy().squeeze()
    except Exception as e:
        return None

def get_joints(pred_sequence):
    """Prepare joint positions from single decoder output"""
    # Remove start token and EOS token if present
    joints = pred_sequence[1:].detach().cpu().float()
    return np.concatenate([Config.prefix_rows, joints.numpy()], axis=0)

def save_plot(ax, curve_dir, filename, title=None):
    """Helper function to save plots with guaranteed cleanup"""
    try:
        if title:
            ax.set_title(title)
        ax.legend()
        ax.grid(True)
        ax.axis('equal')
        plt.tight_layout()
        plt.savefig(os.path.join(curve_dir, filename))
    finally:
        plt.close()

def filter_invalid_joints(joints: np.ndarray, invalid_values: np.ndarray) -> np.ndarray:
    """Remove any row in `joints` that exactly matches any row in `invalid_values`."""
    mask = np.ones(len(joints), dtype=bool)
    for invalid in invalid_values:
        mask &= ~np.all(joints == invalid, axis=1)
    return joints[mask]

def create_curve_plot(ax, truth_joints, pred_joints, gt_curve=None, pred_curve=None, 
                     cond_curve=None, title=None):
    """Create a standardized plot with joint enumeration"""
    try:
        if gt_curve is not None:
            ax.plot(gt_curve[:, 0], gt_curve[:, 1], **Config.plt_style['truth_curve'])
        if pred_curve is not None:
            ax.plot(pred_curve[:, 0], pred_curve[:, 1], **Config.plt_style['pred_curve'])
        if cond_curve is not None:
            ax.plot(cond_curve[:, 0], cond_curve[:, 1], **Config.plt_style['cond_curve'])
        
        # Scatter points
        ax.scatter(truth_joints[:, 0], truth_joints[:, 1], **Config.plt_style['truth_joints'])
        ax.scatter(pred_joints[:, 0], pred_joints[:, 1], **Config.plt_style['pred_joints'])

        # Annotate truth joints
        for idx, (x, y) in enumerate(truth_joints):
            ax.annotate(str(idx), (x, y), textcoords="offset points", xytext=(5,5), fontsize=8, color="red")

        # Annotate predicted joints
        for idx, (x, y) in enumerate(pred_joints):
            ax.annotate(str(idx), (x, y), textcoords="offset points", xytext=(5,-10), fontsize=8, color="blue")

        if title:
            ax.set_title(title)
        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
    except Exception as e:
        pass

# ---------------------------
# Utility functions
# ---------------------------
def compute_all_adj_embeddings(model, unique_adjs, curve_data, device):
    """Compute and cache embeddings for all candidate adjacencies."""
    all_adj_embs = []
    with torch.no_grad():
        for cond_adj in tqdm(unique_adjs, desc="Caching adjacency embeddings"):
            cond_adj_tensor = torch.from_numpy(cond_adj).to(device).float().unsqueeze(0)
            _, _, cond_adj_emb = model.encode(curve_data, cond_adj_tensor)

            # (1,1,1024) -> (1024,)
            cond_adj_emb = cond_adj_emb.squeeze(0).squeeze(0)
            all_adj_embs.append(cond_adj_emb)

    # [M, 1024]
    all_adj_embs = torch.stack(all_adj_embs, dim=0)
    all_adj_embs = F.normalize(all_adj_embs, p=2, dim=-1)
    return all_adj_embs


# ---------------------------
# Main processing
# ---------------------------
prepare_results_dir()
unique_adjs = np.load("unique_adjacency_matrices.npy")
print(f"Loaded {len(unique_adjs)} unique adjacency matrices")

for batch_idx, batch in enumerate(tqdm(data_loader)):
    if batch_idx == 50:
        break

    # Create subfolder for this curve
    curve_dir = os.path.join(Config.results_dir, f'curve_{batch_idx}')
    os.makedirs(curve_dir, exist_ok=True)

    # Prepare data
    curve_data = batch["curve_numerical"].to(device).float()
    gt_adj = batch["adjacency"].to(device).float()

    # Get ground truth joints
    truth_joints = batch["label"].view(-1, 2)[:-1].detach().cpu().float()  # Remove EOS token
    truth_joints = np.concatenate([Config.prefix_rows, truth_joints.numpy()], axis=0)

    # Process ground truth adjacency
    gt_adj_processed, gt_node_types = process_adjacency(gt_adj)

    # Get prediction with ground truth adjacency
    with torch.no_grad():
        print(curve_data.shape, gt_adj.shape)
        pred_sequence, curve_emb, gt_adj_emb = greedy_decode_conditional(
            model, curve_data, gt_adj, Config.max_mech_size
        )

    # (1,1,1024) -> (1,1024)
    curve_emb = curve_emb.squeeze(0).squeeze(0).unsqueeze(0)  # [1,1024]
    curve_emb = F.normalize(curve_emb, p=2, dim=-1)

    # Prediction joints
    pred_joints = get_joints(pred_sequence)
    pred_joints = filter_invalid_joints(pred_joints, INVALID_JOINTS)
    truth_joints = filter_invalid_joints(truth_joints, INVALID_JOINTS)

    gt_curve = simulate(gt_adj_processed, truth_joints, gt_node_types)
    pred_curve = simulate(gt_adj_processed, pred_joints, gt_node_types)

    # Plot GT vs Pred
    with managed_figure() as ax:
        create_curve_plot(
            ax, truth_joints, pred_joints,
            gt_curve=gt_curve, pred_curve=pred_curve,
            title=f'Curve {batch_idx}: GT vs Pred'
        )
        save_plot(ax, curve_dir, f'gt_vs_pred_sim.png')

    # # ---------------------------
    # # Step 2: Compare against all adjacencies
    # # ---------------------------
    # all_adj_embs = compute_all_adj_embeddings(model, unique_adjs[:Config.num_conditions], curve_data, device)
    # # Similarity [1,1024] @ [1024,M] -> [M]
    # sims = (curve_emb @ all_adj_embs.T).squeeze(0)

    # # Top-k candidates
    # topk_vals, topk_idx = sims.topk(100)

    # for rank, (sim_val, cond_idx) in enumerate(zip(topk_vals.tolist(), topk_idx.tolist())):
    #     cond_adj = unique_adjs[cond_idx]
    #     cond_adj_tensor = torch.from_numpy(cond_adj).to(device).float().unsqueeze(0)

    #     with torch.no_grad():
    #         cond_sequence, _, _ = greedy_decode_conditional(
    #             model, curve_data, cond_adj_tensor, Config.max_mech_size
    #         )

    #     cond_joints = get_joints(cond_sequence)
    #     cond_joints = filter_invalid_joints(cond_joints, INVALID_JOINTS)
    #     cond_adj_processed, cond_node_types = process_adjacency(cond_adj)
    #     cond_curve = simulate(cond_adj_processed, cond_joints, cond_node_types)

    #     if cond_curve is None:
    #         continue

    #     with managed_figure() as ax:
    #         create_curve_plot(
    #             ax, truth_joints, cond_joints,
    #             gt_curve=gt_curve, cond_curve=cond_curve,
    #             title=f'Curve {batch_idx} CondRank {rank} (sim={sim_val:.4f})'
    #         )
    #         save_plot(ax, curve_dir, f'cond_rank_{rank}_sim_{sim_val:.4f}.png')



Loaded 54507 unique adjacency matrices


  0%|          | 0/2257251 [00:00<?, ?it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 1/2257251 [00:00<442:15:05,  1.42it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 2/2257251 [00:00<264:33:41,  2.37it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 3/2257251 [00:01<220:40:31,  2.84it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 4/2257251 [00:01<190:30:08,  3.29it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 5/2257251 [00:01<191:32:43,  3.27it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 6/2257251 [00:02<195:36:20,  3.21it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 7/2257251 [00:02<192:41:18,  3.25it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 8/2257251 [00:02<176:50:50,  3.55it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 9/2257251 [00:02<168:40:24,  3.72it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 10/2257251 [00:03<173:01:13,  3.62it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 11/2257251 [00:03<174:43:36,  3.59it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 12/2257251 [00:03<177:27:12,  3.53it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 13/2257251 [00:03<169:34:09,  3.70it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 14/2257251 [00:04<164:08:36,  3.82it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 15/2257251 [00:04<172:37:09,  3.63it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 16/2257251 [00:04<166:09:28,  3.77it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 17/2257251 [00:04<164:20:29,  3.82it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 18/2257251 [00:05<170:26:17,  3.68it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 19/2257251 [00:05<166:48:40,  3.76it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 20/2257251 [00:05<159:57:10,  3.92it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 21/2257251 [00:06<169:55:31,  3.69it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 22/2257251 [00:06<168:41:29,  3.72it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 23/2257251 [00:06<163:41:57,  3.83it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 24/2257251 [00:07<227:02:00,  2.76it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 25/2257251 [00:07<212:33:13,  2.95it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 26/2257251 [00:07<191:09:57,  3.28it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 27/2257251 [00:07<177:44:31,  3.53it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 28/2257251 [00:08<166:27:38,  3.77it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 29/2257251 [00:08<166:17:33,  3.77it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 30/2257251 [00:08<157:21:28,  3.98it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 31/2257251 [00:08<156:57:01,  3.99it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 32/2257251 [00:09<151:19:36,  4.14it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 33/2257251 [00:09<153:32:28,  4.08it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 34/2257251 [00:09<160:34:45,  3.90it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 35/2257251 [00:09<162:19:41,  3.86it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 36/2257251 [00:10<159:28:37,  3.93it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 37/2257251 [00:10<163:25:07,  3.84it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 38/2257251 [00:10<166:05:04,  3.78it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 39/2257251 [00:10<165:20:57,  3.79it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 40/2257251 [00:11<159:32:46,  3.93it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 41/2257251 [00:11<165:34:26,  3.79it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 42/2257251 [00:11<168:53:42,  3.71it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 43/2257251 [00:12<177:15:05,  3.54it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 44/2257251 [00:12<178:02:57,  3.52it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 45/2257251 [00:12<179:47:25,  3.49it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 46/2257251 [00:12<176:03:51,  3.56it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 47/2257251 [00:13<182:51:39,  3.43it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 48/2257251 [00:13<175:28:07,  3.57it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 49/2257251 [00:13<174:07:27,  3.60it/s]

torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 20, 20])


  0%|          | 50/2257251 [00:14<175:46:48,  3.57it/s]


In [None]:
# import torch
# from tqdm import tqdm
# import numpy as np
# from collections import defaultdict

# # Track unique matrices and their counts
# unique_matrices = {}  # {hash_key: (matrix, count)}
# counts = defaultdict(int)

# for batch in tqdm(data_loader):
#     gt_adj = batch["adjacency"].float().cpu()  # Original matrices
    
#     for matrix in gt_adj:
#         # Convert to numpy and create hashable key
#         np_matrix = matrix.numpy()
#         matrix_key = tuple(np_matrix.round(4).flatten())  # Round to avoid float precision issues
        
#         # Store first occurrence and count
#         if matrix_key not in unique_matrices:
#             unique_matrices[matrix_key] = np_matrix
#         counts[matrix_key] += 1

# # Save all unique matrices in one file
# all_matrices = np.stack(list(unique_matrices.values()))
# np.save("unique_adjacency_matrices.npy", all_matrices)

# print(f"Saved {len(unique_matrices)} unique matrices")
# print(f"Output shape: {all_matrices.shape}")  # e.g. (25, 10, 10)

: 