In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from model import SingleTransformer
from dataset import SingleTransformerDataset
import matplotlib.pyplot as plt
import os 
import numpy as np
import shutil
from sim import simulate_mechanism
from utils import preprocess_curves
import random 

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

2025-08-18 16:28:10.622293: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755548890.640250 3311990 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755548890.646056 3311990 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755548890.660992 3311990 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755548890.661017 3311990 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755548890.661019 3311990 computation_placer.cc:177] computation placer alr

In [2]:
class CLIPContrastiveLoss(nn.Module):
    def __init__(self, init_scale=1/0.07):
        super().__init__()
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(init_scale)))

    def forward(self, image_embeddings, text_embeddings):
        image_embeddings = F.normalize(image_embeddings.squeeze(), p=2, dim=1)
        text_embeddings = F.normalize(text_embeddings.squeeze(), p=2, dim=1)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_embeddings @ text_embeddings.t()

        N = logits.shape[0]
        targets = torch.arange(N, device=logits.device)

        loss_i2t = F.cross_entropy(logits, targets)
        loss_t2i = F.cross_entropy(logits.t(), targets)
        return (loss_i2t + loss_t2i) / 2

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_ddp_model(model, checkpoint_path, clip_loss_fn=None, strict=True, verbose=True):
    """
    Load model weights from checkpoint, handling DDP wrapping if present.
    
    Args:
        model: The model to load weights into
        checkpoint_path: Path to the checkpoint file
        clip_loss_fn: Optional CLIP loss function to load weights for
        strict: Whether to strictly enforce matching keys between model and checkpoint
        verbose: Whether to print debugging information
        
    Returns:
        Tuple of (model, checkpoint) with loaded weights
    """
    try:
        # Load checkpoint with error handling
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if verbose:
            print(f"✅ Checkpoint loaded from {checkpoint_path}")
            print(f"Checkpoint keys: {list(checkpoint.keys())}")
            # print(checkpoint['model_state_dict'])
        
        # Extract state dict
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        
        # Remove DDP wrapper prefixes
        new_state_dict = {}
        for k, v in state_dict.items():
            new_key = k.replace("module.", "") if k.startswith("module.") else k
            new_state_dict[new_key] = v
            if verbose and k != new_key:
                print(f"⚠️ Renamed DDP key: {k} -> {new_key}")
        
        # Verify model architecture matches checkpoint
        model_keys = set(model.state_dict().keys())
        ckpt_keys = set(new_state_dict.keys())
        
        if verbose:
            print("\n🔍 Model vs Checkpoint Key Analysis:")
            print(f"Model has {len(model_keys)} keys")
            print(f"Checkpoint has {len(ckpt_keys)} keys")
            
            missing_in_model = ckpt_keys - model_keys
            missing_in_ckpt = model_keys - ckpt_keys
            
            if missing_in_model:
                print(f"\n❌ Keys in checkpoint but not in model: {missing_in_model}")
            if missing_in_ckpt:
                print(f"\n❌ Keys in model but not in checkpoint: {missing_in_ckpt}")
        
        # Load model weights with error handling
        try:
            model.load_state_dict(new_state_dict, strict=strict)
            if verbose:
                print("\n✅ Model weights loaded successfully!")
        except RuntimeError as e:
            print(f"\n❌ Error loading model weights: {str(e)}")
            if not strict:
                print("⚠️ Attempting partial load (strict=False)")
                model.load_state_dict(new_state_dict, strict=False)
        
        # Handle CLIP loss function if provided
        if clip_loss_fn is not None:
            if "clip_loss_state_dict" in checkpoint:
                clip_state_dict = checkpoint["clip_loss_state_dict"]
                new_clip_state_dict = {}
                for k, v in clip_state_dict.items():
                    new_key = k.replace("module.", "") if k.startswith("module.") else k
                    new_clip_state_dict[new_key] = v
                
                try:
                    clip_loss_fn.load_state_dict(new_clip_state_dict)
                    if verbose:
                        print("✅ CLIP loss function weights loaded successfully!")
                except RuntimeError as e:
                    print(f"❌ Error loading CLIP weights: {str(e)}")
            elif verbose:
                print("⚠️ No CLIP loss weights found in checkpoint")
        
        return model, checkpoint
    
    except Exception as e:
        print(f"❌ Failed to load checkpoint: {str(e)}")
        raise

# Model Configuration
model_config = {
    'output_size': 2,
    'tgt_seq_len': 10,
    'd_model': 1024,
    'h': 32,
    'N': 6
}

# Initialize Model
model = SingleTransformer(
    output_size=model_config['output_size'],
    tgt_seq_len=model_config['tgt_seq_len'],
    d_model=model_config['d_model'],
    h=model_config['h'],
    N=model_config['N']
).to(device)

clip_loss_fn = CLIPContrastiveLoss().to(device)

model, checkpoint = load_ddp_model(model, f"weights/new_model/d{model_config['d_model']}_h{model_config['h']}_n{model_config['N']}_bs1024_lr0.0002_best.pth", clip_loss_fn)
model.eval()


✅ Checkpoint loaded from weights/new_model/d1024_h32_n6_bs1024_lr0.0002_best.pth
Checkpoint keys: ['model_state_dict', 'optimizer_state_dict', 'clip_loss_state_dict', 'epoch', 'best_loss', 'batch_size', 'learning_rate', 'model_config']

🔍 Model vs Checkpoint Key Analysis:
Model has 807 keys
Checkpoint has 807 keys

✅ Model weights loaded successfully!
✅ CLIP loss function weights loaded successfully!


SingleTransformer(
  (encoder_positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (decoder_positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (tgt_embed): InputEmbeddings(
    (embedding): Linear(in_features=2, out_features=1024, bias=False)
  )
  (encoder): ModuleList(
    (0-5): 6 x Encoder(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
      )
      (feed_forward): FeedForwardBlock(
        (w1): Linear(in_features=1024, out_features=4096, bias=False)
        (w2): Linear(in_features=4096, out_features=1024, bias=False)
        (w3): Linear(in_features=1024, out_features=4096, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (final_norm): RMSNorm()
    )
  )
  (decoder): ModuleList(
    (0-5): 6 x Decoder(
      (self_atten

In [4]:
dataset = SingleTransformerDataset(
        data_dir='/home/anurizada/Documents/nobari_10_transformer',
    )
data_loader = DataLoader(dataset, shuffle=True, batch_size=1)

In [5]:
def greedy_decode_conditional(model, source, adj_type, max_len, eos_token=torch.tensor([1.0, 1.0])):
    model, source = model.to(device), source.to(device)
    adj_type = adj_type.to(device)

    # Encode the source once
    encoder_output, curve_emb, adj_emb = model.encode(source, adj_type)

    # Initialize decoder input with start token
    decoder_input = torch.ones(1, 1, 2).to(device) * -2.0  # Start token
    
    while decoder_input.size(1) < max_len:
        # Build causal mask (0 for allowed, -inf for blocked)
        decoder_mask = causal_mask(decoder_input.size(1), device)

        # Decode
        decoder_output = model.decode(
            encoder_output,
            None,
            decoder_input,
            decoder_mask
        )
        
        proj_output = model.projection(model.projection_norm(decoder_output))
        
        # Get next token
        next_token = proj_output[:, -1].unsqueeze(1)
        
        # EOS check
        if torch.allclose(next_token.squeeze(), eos_token.to(device), atol=1e-1):
            break

        # Append token
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

    return decoder_input.squeeze(0), curve_emb, adj_emb


def causal_mask(size, device):
    # (n, n) upper-triangular boolean mask
    mask = torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()
    # Add batch dimension -> (1, n, n)
    return mask.unsqueeze(0)




In [None]:
import os
import shutil
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from contextlib import contextmanager

# Configuration
class Config:
    results_dir = '1024_dim_h_32_results_clip_bs_1024_new'
    max_mech_size = 10
    num_conditions = 200  # Number of different adjacency conditions to test per curve
    prefix_rows = np.array([[0.5, 0.5], [0.6, 0.5]], dtype=np.float32)
    plt_style = {
        'truth_joints': {'color': 'red', 'marker': 'o', 's': 60, 'label': 'Truth Joints'},
        'pred_joints': {'color': 'blue', 'marker': 'x', 's': 60, 'label': 'Pred Joints'},
        'truth_curve': {'color': 'magenta', 'linestyle': '-', 'linewidth': 3, 'label': 'Truth Curve'},
        'pred_curve': {'color': 'cyan', 'linestyle': '-', 'linewidth': 2, 'label': 'Predicted Curve'},
        'cond_curve': {'color': 'green', 'linestyle': '--', 'linewidth': 2, 'label': 'Condition Curve'}
    }

# Invalid joint values to remove
INVALID_JOINTS = np.array([
    [-1.0, -1.0],
    [1.0,  1.0],
], dtype=np.float32)

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).float()  # assume `model` is already defined elsewhere

@contextmanager
def managed_figure(figsize=(10, 8)):
    """Context manager to ensure figures are properly cleaned up"""
    fig, ax = plt.subplots(figsize=figsize)
    try:
        yield ax
    finally:
        plt.close(fig)

def prepare_results_dir():
    """Prepare clean results directory"""
    if os.path.exists(Config.results_dir):
        shutil.rmtree(Config.results_dir)
    os.makedirs(Config.results_dir, exist_ok=True)

def process_adjacency(adj_tensor):
    """Convert raw adjacency to simulation format (handles both tensors and numpy arrays)"""
    if isinstance(adj_tensor, torch.Tensor):
        adj = adj_tensor.detach().cpu().squeeze().float().numpy()
    elif isinstance(adj_tensor, np.ndarray):
        adj = adj_tensor.squeeze().astype(np.float32)
    else:
        raise TypeError(f"Input must be torch.Tensor or numpy.ndarray, got {type(adj_tensor)}")

    valid_mask = ~np.all(adj == 0, axis=1)
    adj = adj[valid_mask][:, valid_mask]
    node_types = np.diag(adj).astype(bool)
    np.fill_diagonal(adj, 0)
    return adj.astype(np.float32), node_types

def simulate(adj, joints, node_types):
    """Run simulation with error handling"""
    try:
        if adj.dtype != np.float32:
            adj = adj.astype(np.float32)
        if joints.dtype != np.float32:
            joints = joints.astype(np.float32)
        
        result = simulate_mechanism(adj, joints, node_types)
        if result is None:
            return None
        
        trajectory = torch.tensor(result[-1], dtype=torch.float32).unsqueeze(0)
        return preprocess_curves(trajectory).detach().cpu().numpy().squeeze()
    except Exception as e:
        return None

def get_joints(pred_sequence):
    """Prepare joint positions from single decoder output"""
    # Remove start token and EOS token if present
    joints = pred_sequence[1:].detach().cpu().float()
    return np.concatenate([Config.prefix_rows, joints.numpy()], axis=0)

def save_plot(ax, curve_dir, filename, title=None):
    """Helper function to save plots with guaranteed cleanup"""
    try:
        if title:
            ax.set_title(title)
        ax.legend()
        ax.grid(True)
        ax.axis('equal')
        plt.tight_layout()
        plt.savefig(os.path.join(curve_dir, filename))
    finally:
        plt.close()

def filter_invalid_joints(joints: np.ndarray, invalid_values: np.ndarray) -> np.ndarray:
    """Remove any row in `joints` that exactly matches any row in `invalid_values`."""
    mask = np.ones(len(joints), dtype=bool)
    for invalid in invalid_values:
        mask &= ~np.all(joints == invalid, axis=1)
    return joints[mask]

def create_curve_plot(ax, truth_joints, pred_joints, gt_curve=None, pred_curve=None, 
                     cond_curve=None, title=None):
    """Create a standardized plot with joint enumeration"""
    try:
        if gt_curve is not None:
            ax.plot(gt_curve[:, 0], gt_curve[:, 1], **Config.plt_style['truth_curve'])
        if pred_curve is not None:
            ax.plot(pred_curve[:, 0], pred_curve[:, 1], **Config.plt_style['pred_curve'])
        if cond_curve is not None:
            ax.plot(cond_curve[:, 0], cond_curve[:, 1], **Config.plt_style['cond_curve'])
        
        # Scatter points
        ax.scatter(truth_joints[:, 0], truth_joints[:, 1], **Config.plt_style['truth_joints'])
        ax.scatter(pred_joints[:, 0], pred_joints[:, 1], **Config.plt_style['pred_joints'])

        # Annotate truth joints
        for idx, (x, y) in enumerate(truth_joints):
            ax.annotate(str(idx), (x, y), textcoords="offset points", xytext=(5,5), fontsize=8, color="red")

        # Annotate predicted joints
        for idx, (x, y) in enumerate(pred_joints):
            ax.annotate(str(idx), (x, y), textcoords="offset points", xytext=(5,-10), fontsize=8, color="blue")

        if title:
            ax.set_title(title)
        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
    except Exception as e:
        pass


# Main processing
prepare_results_dir()
unique_adjs = np.load("unique_adjacency_matrices.npy")
print(f"Loaded {len(unique_adjs)} unique adjacency matrices")

for batch_idx, batch in enumerate(tqdm(data_loader)):
    if batch_idx == 10:
        break

    # Create subfolder for this curve
    curve_dir = os.path.join(Config.results_dir, f'curve_{batch_idx}')
    os.makedirs(curve_dir, exist_ok=True)
    
    # Prepare data
    curve_data = batch["curve_numerical"].to(device).float()
    gt_adj = batch["adjacency"].to(device).float()

    # Get ground truth joints from single label sequence
    truth_joints = batch["label"].view(-1, 2)[:-1].detach().cpu().float()  # Remove EOS token
    truth_joints = np.concatenate([Config.prefix_rows, truth_joints.numpy()], axis=0)
    
    # Process ground truth adjacency
    gt_adj_processed, gt_node_types = process_adjacency(gt_adj)

    # Get prediction with ground truth adjacency
    with torch.no_grad():
        pred_sequence, curve_emb, adj_emb = greedy_decode_conditional(
            model, curve_data, gt_adj, Config.max_mech_size
        )
    pred_joints = get_joints(pred_sequence)

    # Filter out invalid joints
    pred_joints = filter_invalid_joints(pred_joints, INVALID_JOINTS)
    truth_joints = filter_invalid_joints(truth_joints, INVALID_JOINTS)

    curve_emb = F.normalize(curve_emb, p=2, dim=-1)
    adj_emb   = F.normalize(adj_emb, p=2, dim=-1)

    # similarity = (curve_emb[0] @ adj_emb[0].t()).diagonal()
    sim_val = F.cosine_similarity(curve_emb.view(1, -1), adj_emb.view(1, -1)).item()

    # Run simulations
    gt_curve = simulate(gt_adj_processed, truth_joints, gt_node_types)
    pred_curve = simulate(gt_adj_processed, pred_joints, gt_node_types)
    
    # Plot 1: Ground Truth vs Prediction with GT Adjacency
    with managed_figure() as ax:
        create_curve_plot(
            ax, truth_joints, pred_joints,
            gt_curve=gt_curve, pred_curve=pred_curve,
            title=f'Curve {batch_idx}: GT vs Pred (sim={sim_val:.4f})'
        )
        save_plot(ax, curve_dir, f'gt_vs_pred_sim_{sim_val:.4f}.png')
    
    for cond_idx, cond_adj in enumerate(unique_adjs[:Config.num_conditions]):
        cond_adj_tensor = torch.from_numpy(cond_adj).to(device).float().unsqueeze(0)
        
        with torch.no_grad():
            cond_sequence, cond_curve_emb, cond_adj_emb = greedy_decode_conditional(
                model, curve_data, cond_adj_tensor, Config.max_mech_size
            )

        cond_joints = get_joints(cond_sequence)
        cond_joints = filter_invalid_joints(cond_joints, INVALID_JOINTS)
        cond_adj_processed, cond_node_types = process_adjacency(cond_adj)
        cond_curve = simulate(cond_adj_processed, cond_joints, cond_node_types)
        
        if cond_curve is None:
            plt.close()
            continue

        cond_curve_emb = F.normalize(cond_curve_emb, p=2, dim=-1)
        cond_adj_emb = F.normalize(cond_adj_emb, p=2, dim=-1)

        cond_sim_val = F.cosine_similarity(cond_curve_emb.view(1, -1), cond_adj_emb.view(1, -1)).item()
        # cond_sim_val = similarity.mean().item()

        with managed_figure() as ax:
            create_curve_plot(
                ax, truth_joints, cond_joints,
                gt_curve=gt_curve, cond_curve=cond_curve,
                title=f'Curve {batch_idx} Cond {cond_idx} (sim={cond_sim_val:.4f})'
            )
            save_plot(ax, curve_dir, f'condition_{cond_idx}_sim_{cond_sim_val:.4f}.png')

Loaded 54507 unique adjacency matrices


  0%|          | 10/2257251 [04:01<15119:19:47, 24.11s/it]


In [7]:
# import torch
# from tqdm import tqdm
# import numpy as np
# from collections import defaultdict

# # Track unique matrices and their counts
# unique_matrices = {}  # {hash_key: (matrix, count)}
# counts = defaultdict(int)

# for batch in tqdm(data_loader):
#     gt_adj = batch["adjacency"].float().cpu()  # Original matrices
    
#     for matrix in gt_adj:
#         # Convert to numpy and create hashable key
#         np_matrix = matrix.numpy()
#         matrix_key = tuple(np_matrix.round(4).flatten())  # Round to avoid float precision issues
        
#         # Store first occurrence and count
#         if matrix_key not in unique_matrices:
#             unique_matrices[matrix_key] = np_matrix
#         counts[matrix_key] += 1

# # Save all unique matrices in one file
# all_matrices = np.stack(list(unique_matrices.values()))
# np.save("unique_adjacency_matrices.npy", all_matrices)

# print(f"Saved {len(unique_matrices)} unique matrices")
# print(f"Output shape: {all_matrices.shape}")  # e.g. (25, 10, 10)