In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from model import Transformer
from dataset import TransformerDataset
import matplotlib.pyplot as plt
import os 
import numpy as np
import shutil
from sim import simulate_mechanism
from utils import preprocess_curves
import random 

# Must be set before the OpenMP-linked libraries start their runtime.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort
# seen with some torch/numpy builds — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [None]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_ddp_model(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Extract the state dict if it's stored under "model_state_dict"
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Remove DDP wrapper key, if present
    new_state_dict = {}
    for k, v in state_dict.items():
        new_key = k.replace("module.", "") if k.startswith("module.") else k
        new_state_dict[new_key] = v

    model.load_state_dict(new_state_dict)
    return model

# Build the network and restore the trained weights in one step.
# d_model / h below match the values encoded in the checkpoint file name.
model = load_ddp_model(
    Transformer(output_size=2, tgt_seq_len=10, d_model=1024, h=32, N=6),
    "weights/clip/d1024_h32_bs512_lr0.0001_best.pth",
)
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module repr.
model.eval()


In [None]:
# Dataset location. The TRANSFORMER_DATA_DIR environment variable overrides the
# machine-specific default so the notebook can run outside the original
# author's home directory; when it is unset the original path is used.
dataset = TransformerDataset(
        data_dir=os.environ.get(
            "TRANSFORMER_DATA_DIR",
            '/home/anurizada/Documents/nobari_10_transformer',
        ),
    )
# batch_size=1: the decoding cell below builds its start tokens for a single sequence.
# shuffle=True and no seed is set in the visible cells, so the curves visited
# differ from run to run.
data_loader = DataLoader(dataset, shuffle=True, batch_size=1)

In [None]:
def _greedy_decode_branch(model, decoder, positional_encoding, projection_norm,
                          projection, encoder_output, max_steps, eos_token):
    """Greedily roll out one decoder branch, feeding each new token back as input.

    Returns the decoder input sequence with a leading batch axis of size 1: the
    start token followed by every generated token that was fed back. A token
    recognised as EOS stops the loop and is NOT appended.
    """
    # Start token: one position whose two coordinates are both -2.0.
    decoder_input = torch.ones(1, 1, 2).to(device) * -2.0

    while decoder_input.size(1) < max_steps:
        # Causal mask sized to the tokens produced so far.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(decoder_input).to(device)

        decoder_output = model.decode(
            decoder,
            encoder_output,
            None,
            decoder_input,
            decoder_mask,
            positional_encoding,
        )
        projected = projection(projection_norm(decoder_output))

        # Newest position only; slicing with `-1:` keeps the sequence axis so the
        # token can be concatenated directly.
        next_token = projected[:, -1:]

        # Stop once the prediction is within 0.1 of the EOS marker.
        if torch.allclose(next_token.squeeze(), eos_token, atol=1e-1):
            break

        decoder_input = torch.cat([decoder_input, next_token], dim=1)

    return decoder_input


def greedy_decode_conditional(model, source, adj_type, max_len, eos_token=torch.tensor([1.0, 1.0])):
    """Greedy-decode both decoder branches of ``model`` for a single input.

    The source is encoded once, conditioned on ``adj_type``; the first and the
    second decoder are then rolled out independently from the same encoder output.

    Args:
        model: Network exposing ``encode``, ``decode``, ``decoder_first/second``,
            ``decoder_positional_encoding_first/second``,
            ``projection_norm_first/second`` and ``projection_first/second``.
        source: Encoder input; moved to the module-level ``device``.
        adj_type: Conditioning input passed to ``model.encode``; moved to ``device``.
        max_len: Combined length budget. Each branch stops once its sequence
            (start token included) reaches ``max_len // 2`` positions.
        eos_token: Value that terminates a branch when the prediction is within
            ``atol=0.1`` of it.

    Returns:
        ``(first_seq, second_seq, curve_emb, adj_emb)``. Each sequence still has
        the start token as its first row; ``curve_emb`` and ``adj_emb`` are
        returned unchanged from ``model.encode``.

    Note:
        Only batch size 1 is supported (the start token is built with a batch
        axis of 1). No ``torch.no_grad()`` is applied here; callers that do not
        need gradients should wrap the call themselves.
    """
    model, source = model.to(device), source.to(device)
    adj_type = adj_type.to(device)
    # Moved once here instead of on every decoding step.
    eos_token = eos_token.to(device)

    # Encode the source once, conditioned on adj_type.
    encoder_output, curve_emb, adj_emb = model.encode(source, adj_type)

    max_steps = max_len // 2

    decoder_input_first = _greedy_decode_branch(
        model,
        model.decoder_first,
        model.decoder_positional_encoding_first,
        model.projection_norm_first,
        model.projection_first,
        encoder_output,
        max_steps,
        eos_token,
    )
    decoder_input_second = _greedy_decode_branch(
        model,
        model.decoder_second,
        model.decoder_positional_encoding_second,
        model.projection_norm_second,
        model.projection_second,
        encoder_output,
        max_steps,
        eos_token,
    )

    return decoder_input_first.squeeze(0), decoder_input_second.squeeze(0), curve_emb, adj_emb


def causal_mask(size):
    """Return a boolean mask of shape (1, size, size), True on and below the diagonal.

    Entry ``[0, i, j]`` is True exactly when ``j <= i``.
    """
    strictly_upper = torch.ones((1, size, size)).triu(diagonal=1)
    return strictly_upper.to(torch.int).eq(0)

In [None]:
import os
import shutil
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from contextlib import contextmanager

# Configuration
class Config:
    """Static settings for the evaluation run, read as class attributes."""

    # Output folder; wiped and recreated by prepare_results_dir().
    results_dir = '1024_dim_h_32_results_clip_bs_512'
    # Passed to greedy_decode_conditional as its max_len argument.
    max_mech_size = 10
    # Number of different adjacency conditions to test per curve.
    num_conditions = 200
    # Two fixed joint rows prepended to every ground-truth / predicted joint list.
    prefix_rows = np.array([(0.5, 0.5), (0.6, 0.5)], dtype=np.float32)
    # Keyword arguments forwarded to ax.scatter / ax.plot, keyed by what is drawn.
    plt_style = dict(
        truth_joints=dict(color='red', marker='o', s=60, label='Truth Joints'),
        pred_joints=dict(color='blue', marker='x', s=60, label='Pred Joints'),
        truth_curve=dict(color='magenta', linestyle='-', linewidth=3, label='Truth Curve'),
        pred_curve=dict(color='cyan', linestyle='-', linewidth=2, label='Predicted Curve'),
        cond_curve=dict(color='green', linestyle='--', linewidth=2, label='Condition Curve'),
    )

# Joint rows that are dropped (by exact match) before simulation and plotting.
INVALID_JOINTS = np.asarray(
    [(-1.0, -1.0), (1.0, 1.0)],
    dtype=np.float32,
)

# Setup: pick the compute device, then place the model on it in float precision.
# `model` is expected to exist already (it is created in an earlier cell).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)
model = model.float()

@contextmanager
def managed_figure(figsize=(10, 8)):
    """Yield the Axes of a newly created figure and close that figure on exit.

    The figure is closed even when the ``with`` body raises, so figures do not
    accumulate while many plots are generated in a loop.
    """
    figure, axes = plt.subplots(figsize=figsize)
    try:
        yield axes
    finally:
        plt.close(figure)

def prepare_results_dir():
    """Recreate ``Config.results_dir`` as an empty directory.

    Anything already at that path is deleted first, so results from a previous
    run are lost.
    """
    out_dir = Config.results_dir
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)

def process_adjacency(adj_tensor):
    """Turn a raw adjacency tensor into the ``(adjacency, node_types)`` pair used for simulation.

    Steps:
      1. Squeeze size-1 axes and convert to a float NumPy matrix.
      2. Drop every node whose row is entirely zero (presumably padding — confirm
         against the dataset), removing both its row and its column.
      3. Read the diagonal as per-node boolean flags (non-zero -> True).
      4. Zero the diagonal of the returned matrix.

    Returns:
        ``(adj, node_types)``: ``adj`` is float32; ``node_types`` is a bool array
        with one entry per remaining node. The input tensor is not modified.
    """
    dense = adj_tensor.detach().cpu().squeeze().float().numpy()
    keep = np.any(dense != 0, axis=1)
    # Boolean indexing copies, so the in-place diagonal reset below never
    # touches memory shared with the input tensor.
    trimmed = dense[keep][:, keep]
    node_types = np.diag(trimmed).astype(bool)
    np.fill_diagonal(trimmed, 0)
    return trimmed.astype(np.float32), node_types

def simulate(adj, joints, node_types):
    """Simulate a mechanism and return its preprocessed curve, or ``None`` on failure.

    ``adj`` and ``joints`` are converted to float32 before being handed to
    ``simulate_mechanism``. The last element of the simulator's result is turned
    into a float32 tensor with a leading batch axis, passed through
    ``preprocess_curves`` and returned as a squeezed NumPy array.

    Returns:
        The processed curve, or ``None`` when the simulator returns ``None`` or
        any exception is raised (the exception message is printed).
    """
    try:
        # astype(..., copy=False) hands back the same array when the dtype already matches.
        adj32 = adj.astype(np.float32, copy=False)
        joints32 = joints.astype(np.float32, copy=False)

        result = simulate_mechanism(adj32, joints32, node_types)
        if result is not None:
            last_entry = torch.tensor(result[-1], dtype=torch.float32)
            processed = preprocess_curves(last_entry.unsqueeze(0))
            return processed.detach().cpu().numpy().squeeze()
    except Exception as e:
        print(f"Simulation failed: {str(e)}")
    return None

def get_joints(pred_first, pred_second):
    """Assemble the full joint array from the two decoded sequences.

    The first row of each sequence (the decoder start token) is discarded, the
    remainder is moved to CPU as float, and the result is stacked below
    ``Config.prefix_rows``: prefix rows, then first-branch rows, then
    second-branch rows.
    """
    decoded_rows = [
        sequence[1:].detach().cpu().float().numpy()
        for sequence in (pred_first, pred_second)
    ]
    return np.concatenate([Config.prefix_rows, *decoded_rows], axis=0)

def save_plot(ax, curve_dir, filename, title=None):
    """Finalize the figure that owns ``ax``, write it to disk, and close it.

    Args:
        ax: Matplotlib Axes to decorate (legend, grid, equal axis scaling).
        curve_dir: Directory the image is written into.
        filename: File name, joined onto ``curve_dir``.
        title: Optional title; when falsy the Axes' existing title is left as is.

    The figure is closed even if decorating or saving raises.
    """
    # Work on the figure that owns `ax` instead of pyplot's implicit "current
    # figure": the two differ whenever another figure was created or activated
    # after `ax`, in which case the wrong figure would be saved and closed.
    fig = ax.figure
    try:
        if title:
            ax.set_title(title)
        ax.legend()
        ax.grid(True)
        ax.axis('equal')
        fig.tight_layout()
        fig.savefig(os.path.join(curve_dir, filename))
    finally:
        plt.close(fig)

def filter_invalid_joints(joints: np.ndarray, invalid_values: np.ndarray) -> np.ndarray:
    """Return the rows of ``joints`` that do not exactly equal any row of ``invalid_values``.

    The comparison is exact element-wise equality (no tolerance). Row order is
    preserved and ``joints`` itself is not modified.
    """
    drop = np.zeros(len(joints), dtype=bool)
    for bad_row in invalid_values:
        # A row is dropped only when every coordinate matches this invalid row.
        drop |= np.all(joints == bad_row, axis=1)
    return joints[~drop]

def create_curve_plot(ax, truth_joints, pred_joints, gt_curve=None, pred_curve=None, 
                     cond_curve=None, title=None):
    """Draw joints and any supplied curves onto ``ax`` using ``Config.plt_style``.

    Curves are drawn first, in the order truth, predicted, condition, skipping
    any that are ``None``; the two joint sets are then scattered on top. Column 0
    is used as x and column 1 as y throughout.

    Any exception raised while drawing is printed and swallowed, so a bad input
    leaves a partially drawn Axes instead of aborting the caller.
    """
    try:
        curve_layers = (
            ('truth_curve', gt_curve),
            ('pred_curve', pred_curve),
            ('cond_curve', cond_curve),
        )
        for style_key, curve in curve_layers:
            if curve is None:
                continue
            ax.plot(curve[:, 0], curve[:, 1], **Config.plt_style[style_key])

        joint_layers = (
            ('truth_joints', truth_joints),
            ('pred_joints', pred_joints),
        )
        for style_key, joints in joint_layers:
            ax.scatter(joints[:, 0], joints[:, 1], **Config.plt_style[style_key])

        if title:
            ax.set_title(title)
        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
    except Exception as e:
        print(f"Plotting error: {str(e)}")

# Main processing
# For each of the first 100 batches drawn from data_loader:
#   1. decode joints conditioned on the batch's own (ground-truth) adjacency,
#      simulate both the ground-truth and the predicted mechanism, and save a
#      comparison plot;
#   2. repeat the decoding for Config.num_conditions adjacencies taken from the
#      dataset and save one plot per condition whose simulation succeeds.
# All output goes under Config.results_dir, which is wiped first.
prepare_results_dir()

for batch_idx, batch in enumerate(tqdm(data_loader)):
    # Hard stop after 100 curves; the tqdm total still reflects the full loader.
    if batch_idx == 100:
        break

    # Create subfolder for this curve
    curve_dir = os.path.join(Config.results_dir, f'curve_{batch_idx}')
    os.makedirs(curve_dir, exist_ok=True)
    
    # Prepare data
    curve_data = batch["curve_numerical"].to(device).float()
    gt_adj = batch["adjacency"].to(device).float()
    
    # Get ground truth joints: each label is reshaped to rows of two values and
    # its last row is dropped (presumably the end-of-sequence row — confirm
    # against the dataset), then the fixed prefix rows are prepended.
    truth_joints = torch.cat([
        batch["label_first"].view(-1, 2)[:-1].detach().cpu().float(),
        batch["label_second"].view(-1, 2)[:-1].detach().cpu().float()
    ], dim=0)
    truth_joints = np.concatenate([Config.prefix_rows, truth_joints.numpy()], axis=0)
    
    # Process ground truth adjacency
    gt_adj_processed, gt_node_types = process_adjacency(gt_adj)
    
    # Get prediction with ground truth adjacency
    with torch.no_grad():
        pred_first, pred_second, curve_emb, adj_emb = greedy_decode_conditional(
            model, curve_data, gt_adj, Config.max_mech_size
        )
    pred_joints = get_joints(pred_first, pred_second)

    # Filter out invalid joints (exact matches against INVALID_JOINTS only)
    pred_joints = filter_invalid_joints(pred_joints, INVALID_JOINTS)
    truth_joints = filter_invalid_joints(truth_joints, INVALID_JOINTS)

    # Compute cosine similarity between the curve and adjacency embeddings
    # returned by the encoder, averaged to a single scalar.
    cos_sim = F.cosine_similarity(curve_emb, adj_emb, dim=-1)
    sim_val = cos_sim.mean().item()

    # Run simulations. Both use the ground-truth adjacency; only the joint
    # positions differ. Either result may be None if the simulation failed.
    gt_curve = simulate(gt_adj_processed, truth_joints, gt_node_types)
    pred_curve = simulate(gt_adj_processed, pred_joints, gt_node_types)
    
    # Plot 1: Ground Truth vs Prediction with GT Adjacency
    # (saved even when one or both curves are None; only the joints are drawn then)
    with managed_figure() as ax:
        create_curve_plot(
            ax, truth_joints, pred_joints,
            gt_curve=gt_curve, pred_curve=pred_curve,
            title=f'Curve {batch_idx}: GT vs Pred (sim={sim_val:.4f})'
        )
        save_plot(ax, curve_dir, f'gt_vs_pred_sim_{sim_val:.4f}.png')
    
    # Plot each condition separately
    for cond_idx in range(Config.num_conditions):
        # Condition adjacency is read straight from the dataset by index.
        # data_loader shuffles, so batch_idx is NOT the dataset index of the
        # current curve; the index below is just a deterministic offset.
        cond_adj = dataset[(batch_idx + cond_idx) % len(dataset)]["adjacency"].to(device).float()
        
        # Dataset items carry no batch axis, hence the unsqueeze(0).
        with torch.no_grad():
            cond_first, cond_second, cond_curve_emb, cond_adj_emb = greedy_decode_conditional(
                model, curve_data, cond_adj.unsqueeze(0), Config.max_mech_size
            )
        cond_joints = get_joints(cond_first, cond_second)
        cond_joints = filter_invalid_joints(cond_joints, INVALID_JOINTS)
        cond_adj_processed, cond_node_types = process_adjacency(cond_adj)
        cond_curve = simulate(cond_adj_processed, cond_joints, cond_node_types)
        
        # Skip this condition if cond_curve is not generated
        if cond_curve is None:
            # NOTE(review): no figure has been opened yet in this iteration, so
            # this close looks like a defensive no-op.
            plt.close()  # Ensure any existing figure is closed
            continue  # Skip to next condition
            
        # Compute cosine similarity for this condition
        cond_cos_sim = F.cosine_similarity(cond_curve_emb, cond_adj_emb, dim=-1)
        cond_sim_val = cond_cos_sim.mean().item()

        # The conditioned joints go through the `pred_joints` slot, so they are
        # drawn with the 'Pred Joints' style and legend label.
        with managed_figure() as ax:
            create_curve_plot(
                ax, truth_joints, cond_joints,
                gt_curve=gt_curve, cond_curve=cond_curve,
                title=f'Curve {batch_idx} Cond {cond_idx} (sim={cond_sim_val:.4f})'
            )
            save_plot(ax, curve_dir, f'condition_{cond_idx}_sim_{cond_sim_val:.4f}.png')