In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
from model import SingleImageTransformer
from dataset import BarLinkageDataset 
import matplotlib.pyplot as plt
torch.set_float32_matmul_precision('medium')

from curve_plot import get_pca_inclination, rotate_curve
import scipy.spatial.distance as sciDist
from tqdm import tqdm
import requests
import time
import matplotlib.pyplot as plt
import os

# Prefer the second GPU ("cuda:1") when the host actually has one. Fall back to
# the first GPU on single-GPU machines and to the CPU when CUDA is unavailable,
# so that later `.to(device)` calls do not fail with an invalid device ordinal.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Settings for the headless simulator.
index = 0  # local server index
HEADERS = {
    "Content-Type": "application/json",
}
speedscale = 1
steps = 360
# Shortest simulated trajectory that is still accepted further down:
# 20 poses for every 360 requested steps.
minsteps = (steps * 20) // 360

In [None]:
checkpoint_path = "weights/d2048_h32_n6_bs1024_lr0.0001.pth"
data_dir = "/home/anurizada/Documents/processed_dataset_105"
batch_size = 1

# Upper bound on how many leading samples of the dataset are exposed.
MAX_DATASET_SAMPLES = 1000000

dataset = BarLinkageDataset(data_dir=data_dir)
# Clamp the subset to the real dataset size: indices past the end would give the
# loader a wrong length and raise IndexError once iteration reaches them.
# NOTE(review): assumes BarLinkageDataset implements __len__ — confirm.
dataset = torch.utils.data.Subset(dataset, range(min(len(dataset), MAX_DATASET_SAMPLES)))
# shuffle=False keeps the sample order deterministic between runs.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Restore the checkpoint directly onto the selected device. It is read for two
# entries: the architecture settings and the trained weights.
checkpoint = torch.load(checkpoint_path, map_location=device)
model_config = checkpoint['model_config']

# Build the network from the stored settings. Every hyper-parameter is passed
# through unchanged except vocab_size, which is enlarged by one.
# NOTE(review): the reason for the extra vocabulary slot is not visible here.
model = SingleImageTransformer(
    **{
        key: model_config[key]
        for key in ('tgt_seq_len', 'd_model', 'h', 'N', 'num_labels')
    },
    vocab_size=model_config['vocab_size'] + 1,
).to(device)

# Copy the trained parameters in, then switch to inference mode.
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

In [None]:
import json
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import requests
import time
import os
from curve_plot import get_pca_inclination, rotate_curve

# ===================================
# CONFIGURATION
# ===================================
# The label mapping is stored inside the processed-dataset directory, so derive
# its location from data_dir instead of repeating the absolute path (the two
# copies could otherwise drift apart when the dataset is moved).
label_mapping_path = os.path.join(data_dir, "label_mapping.json")
# JSON is UTF-8 by specification; do not depend on the locale's default encoding.
with open(label_mapping_path, "r", encoding="utf-8") as f:
    label_mapping = json.load(f)
# Lookup table from a label index (stored as a string key) to a mechanism type name.
index_to_label = label_mapping["index_to_label"]

# --- coordinate binning setup ---
class CoordinateBinner:
    """Uniform quantizer over the symmetric interval ``[-kappa, kappa]``.

    The interval is cut into ``num_bins`` equal-width bins, and a bin index is
    decoded to the midpoint of its bin.

    Attributes:
        kappa: half-width of the covered interval.
        num_bins: number of bins.
        bin_edges: NumPy array with the ``num_bins + 1`` bin boundaries.
        bin_centers: NumPy array with the ``num_bins`` bin midpoints.
    """

    def __init__(self, kappa=1.0, num_bins=200):
        self.kappa = kappa
        self.num_bins = num_bins
        # num_bins + 1 evenly spaced boundaries delimit num_bins bins.
        edges = np.linspace(-kappa, kappa, num_bins + 1)
        lower, upper = edges[:-1], edges[1:]
        self.bin_edges = edges
        self.bin_centers = (lower + upper) / 2

    def bin_to_value_torch(self, bin_index_tensor):
        """Decode integer bin indices to the corresponding bin midpoints.

        Args:
            bin_index_tensor: integer tensor of bin indices. Indices outside
                ``[0, num_bins - 1]`` are clamped into that range first.

        Returns:
            A float32 tensor on the same device as the input, holding one
            midpoint per index.
        """
        last_bin = self.num_bins - 1
        safe_indices = bin_index_tensor.clamp(min=0, max=last_bin)
        centers = torch.tensor(
            self.bin_centers,
            dtype=torch.float32,
            device=safe_indices.device,
        )
        return centers[safe_indices]

# Token ids used when trimming decoded sequences.
# NOTE(review): hard-coded here rather than read from
# label_mapping["special_tokens"] — confirm they agree with the dataset.
eos_token, pad_token = 1, 2

# Quantization settings stored in label_mapping.json.
NUM_BINS = label_mapping["num_bins"]
BIN_OFFSET = label_mapping["special_tokens"]["NUM_SPECIAL_TOKENS"]  # usually 3
binner = CoordinateBinner(num_bins=NUM_BINS, kappa=1.0)

print(
    f"Loaded label mapping with {len(index_to_label)} mechanism types.",
    f"Coordinate binning: {NUM_BINS} bins, BIN_OFFSET={BIN_OFFSET}",
    'Started',
    sep="\n",
)
# Reference point for the total run time reported after the simulation loop.
start_time = time.time()

# ===================================
# BATCH INFERENCE
# ===================================
def _truncate_at_token(seq, token):
    """Return ``seq`` cut just before the first occurrence of ``token``."""
    hits = np.flatnonzero(seq == token)
    return seq[: hits[0]] if hits.size else seq


def _label_to_index(label_tensor):
    """Convert one sample's label tensor into a plain ``int`` class index.

    A multi-element tensor is treated as one-hot/score vector and reduced with
    argmax; a single-element tensor is read directly. The result is always an
    ``int`` so that ``str(index)`` yields keys such as "2" rather than "2.0".
    """
    if label_tensor.numel() > 1:
        return int(label_tensor.argmax().item())
    return int(label_tensor.item())


def predict_batch(model, dataloader, max_samples=100, device="cuda"):
    """Run the model over a dataloader and collect trimmed token sequences.

    Each batch is pushed through the model in a single forward pass using the
    decoder input provided by the batch (no autoregressive decoding), and the
    arg-max token is taken at every position. For every sample, positions where
    the target holds ``pad_token`` are removed from both prediction and target,
    and each sequence is then cut before its first ``eos_token``. Both token
    ids are read from module-level names.

    Args:
        model: callable invoked as
            ``model(decoder_input, decoder_mask, images, encoded_labels)`` and
            returning a 3-tuple whose first element holds per-position scores
            over the vocabulary in its last dimension.
        dataloader: iterable of dict batches with the keys
            "decoder_input_discrete", "causal_mask", "images",
            "encoded_labels" and "labels_discrete".
        max_samples: maximum number of samples to collect.
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        Tuple ``(all_predictions, all_targets, all_labels)``: two lists of 1-D
        NumPy token arrays and a list of ``int`` label indices, index-aligned.
    """
    all_predictions, all_targets, all_labels = [], [], []
    samples_processed = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Running batch inference"):
            if samples_processed >= max_samples:
                break

            decoder_input = batch["decoder_input_discrete"].to(device)
            decoder_mask = batch["causal_mask"].to(device)
            images = batch["images"].to(device)
            encoded_labels = batch["encoded_labels"].to(device)
            target_tokens = batch["labels_discrete"].to(device)

            predictions, _, _ = model(decoder_input, decoder_mask, images, encoded_labels)
            pred_tokens = predictions.argmax(dim=-1)

            # Transfer the whole batch to host memory once instead of per sample.
            pred_np = pred_tokens.cpu().numpy()
            target_np = target_tokens.cpu().numpy()

            # Never collect more than the caller asked for.
            remaining = max_samples - samples_processed
            for i in range(min(pred_np.shape[0], remaining)):
                pred_seq, target_seq = pred_np[i], target_np[i]

                # Padding is defined by the target; apply the same mask to both.
                valid_mask = target_seq != pad_token
                pred_seq = _truncate_at_token(pred_seq[valid_mask], eos_token)
                target_seq = _truncate_at_token(target_seq[valid_mask], eos_token)

                all_predictions.append(pred_seq)
                all_targets.append(target_seq)
                # Accept either a one-hot vector or an already-scalar label.
                all_labels.append(_label_to_index(encoded_labels[i]))
                samples_processed += 1

    print(f"\nProcessed {samples_processed} samples total")
    return all_predictions, all_targets, all_labels


# ===================================
# RUN INFERENCE
# ===================================
# Only the first few samples are decoded; each of them is later sent to the
# simulator, so this also bounds the number of simulation requests.
max_samples = 20
predictions, targets, label_indices = predict_batch(
    model=model,
    dataloader=dataloader,
    max_samples=max_samples,
    device=device,
)
# Show the label index collected for every decoded sample.
print(label_indices)

# ===================================
# CONVERT BINS → CONTINUOUS COORDINATES
# ===================================
def bins_to_continuous(seq, binner, bin_offset):
    """Decode a token sequence into continuous values.

    Tokens smaller than ``bin_offset`` are treated as special tokens and
    discarded. The remaining tokens are shifted down by ``bin_offset`` to obtain
    bin indices, which ``binner.bin_to_value_torch`` maps to values.

    Args:
        seq: sequence (list or array) of integer token ids.
        binner: object exposing ``bin_to_value_torch(LongTensor) -> Tensor``.
        bin_offset: id of the first token that encodes a bin.

    Returns:
        NumPy array with one decoded value per non-special token, in order.
    """
    tokens = np.array(seq)
    # Keep only tokens that encode a bin, re-based so the first bin is index 0.
    bin_indices = tokens[tokens >= bin_offset] - bin_offset
    decoded = binner.bin_to_value_torch(torch.tensor(bin_indices, dtype=torch.long))
    return decoded.cpu().numpy()


# ===================================
# SIMULATION LOOP
# ===================================
def _request_poses(joint_params, mech_type):
    """Simulate one mechanism on the local headless simulator.

    Args:
        joint_params: list of ``[x, y]`` joint coordinates, sent as "params".
        mech_type: mechanism type name. Names containing "Type" are routed to
            the 8-bar endpoint, all others to the default endpoint.

    Returns:
        The simulated poses as a NumPy array, or ``None`` when the response
        body cannot be decoded as JSON or the simulator reports no poses.
    """
    if "Type" in mech_type:
        endpoint = "http://localhost:4000/simulation-8bar"
    else:
        endpoint = "http://localhost:4000/simulation"

    # Both endpoints receive a one-element JSON list. The payload is rebuilt on
    # every call so that no request can observe state left by a previous one.
    payload = [{
        "params": joint_params,
        "type": mech_type,
        "speedScale": speedscale,
        "steps": steps,
        "relativeTolerance": 0.1,
    }]

    try:
        response = requests.post(url=endpoint, headers=HEADERS, data=json.dumps(payload)).json()
        time.sleep(0.05)  # short pause between consecutive requests
    except ValueError:
        # Raised by .json() when the response body is not valid JSON.
        return None

    poses = response[0]["poses"]
    return None if poses is None else np.array(poses)


for idx, (pred_seq, target_seq, label_idx) in enumerate(zip(predictions, targets, label_indices)):
    mech_type = index_to_label[str(label_idx)]

    # Convert discrete bins → continuous coords
    pred_cont = bins_to_continuous(pred_seq, binner, BIN_OFFSET)
    target_cont = bins_to_continuous(target_seq, binner, BIN_OFFSET)

    # Drop a trailing unpaired value so the coordinates reshape cleanly to (N, 2)
    if len(pred_cont) % 2 == 1:
        pred_cont = pred_cont[:-1]
    if len(target_cont) % 2 == 1:
        target_cont = target_cont[:-1]

    pred_joints = pred_cont.reshape(-1, 2)
    gt_joints = target_cont.reshape(-1, 2)
    num_joints = gt_joints.shape[0]

    j_points_gt = gt_joints.tolist()
    # The prediction never contributes more joints than the ground truth has.
    j_points_pred = pred_joints[:num_joints].tolist()
    couplerCurveIndex = num_joints - 1  # last joint as coupler

    # --- ORIGINAL MECHANISM SIMULATION ---
    P = _request_poses(j_points_gt, mech_type)
    if P is None:
        continue

    print("Known type: ", mech_type)

    # Reject trajectories that are too short to be a usable curve.
    if P.shape[0] < minsteps:
        continue

    original_x, original_y = P[:, couplerCurveIndex, 0], P[:, couplerCurveIndex, 1]
    original_mean_x, original_mean_y = np.mean(original_x), np.mean(original_y)
    original_denom = np.sqrt(np.var(original_x) + np.var(original_y))
    original_phi = -get_pca_inclination(original_x, original_y)

    # --- PREDICTED MECHANISM SIMULATION ---
    # A prediction that ended early has fewer joints than the ground truth;
    # couplerCurveIndex would then not address a joint of that mechanism.
    if len(j_points_pred) < num_joints:
        continue

    P = _request_poses(j_points_pred, mech_type)
    if P is None or P.shape[0] < minsteps:
        continue

    generated_x, generated_y = P[:, couplerCurveIndex, 0], P[:, couplerCurveIndex, 1]
    if np.isnan(generated_x).any() or np.isinf(generated_x).any() or len(generated_x) < 30:
        continue

    # --- ALIGN CURVES ---
    # Rotate the predicted curve by the difference of the two PCA inclinations.
    generated_phi = -get_pca_inclination(generated_x, generated_y)
    rotation = generated_phi - original_phi
    generated_x, generated_y = rotate_curve(generated_x, generated_y, rotation)

    # Rescale so both curves have the same overall spread.
    generated_denom = np.sqrt(np.var(generated_x) + np.var(generated_y))
    scale_factor = original_denom / generated_denom
    generated_x, generated_y = np.multiply(generated_x, scale_factor), np.multiply(generated_y, scale_factor)

    # Shift so both curves share the same centroid.
    generated_mean_x, generated_mean_y = np.mean(generated_x), np.mean(generated_y)
    translation_x, translation_y = generated_mean_x - original_mean_x, generated_mean_y - original_mean_y
    generated_x, generated_y = np.subtract(generated_x, translation_x), np.subtract(generated_y, translation_y)

    # --- PLOT BOTH CURVES ---
    plt.plot(original_x, original_y, "r", label="original")
    # NOTE(review): drawing of the aligned predicted curve is disabled below, so
    # the saved figure currently contains only the original curve — confirm.
    # plt.plot(generated_x, generated_y, "g", label="predicted")
    plt.title(f"Mechanism: {mech_type}")
    plt.axis("equal")
    plt.legend()

    out_dir = f"results/{idx}"
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/{idx}_{mech_type}_batch_pred.jpg")
    plt.clf()  # reset the shared figure before the next sample

print(f"Finished in {time.time() - start_time:.2f} seconds")


In [None]:
# print('Started')
# start_time = time.time()

# for num, (images, joints, labels) in enumerate(data_loader):
    
#     if num == 1:
#         break
    
#     for num_1, (image, joint, label) in enumerate(zip(images, joints, labels)):
        
#         if num_1 == 20:
#             break
            
#         with torch.no_grad():
#             image = image.repeat(batch_size, 1).to(device)

#             indices = torch.arange(batch_size).unsqueeze(1)
#             indices = indices[torch.randperm(batch_size)]
#             random_labels = indices % 3
#             random_labels = random_labels.to(dtype=torch.float32)
#             # random_labels = torch.ones(batch_size, 1) * label
#             random_labels = random_labels.to(device)
            
#             labels_encoded = model.label_encode(random_labels)
#             # images_encoded = model.image_encode(image)
            
#             conditions = model.condition_cross_attention(labels_encoded, image)
                                
#             z = torch.randn([batch_size, attention_dim]).to(device)
#             output = model.decoder_cross_attention(z, conditions)

#             for decoder_self_attention in model.decoder_self_attentions:
#                 output = decoder_self_attention(output)

#             pred_joints = model.joint_predictor(output)
    
#         pred_joints = pred_joints.cpu().detach().numpy()
#         random_labels = random_labels.cpu().detach().numpy()
#         joint = joint.cpu().detach().numpy()
#         label = label.numpy()
        
#         pred_joints = pred_joints[:100]
#         random_labels = random_labels[:100]
        
#         for num_2, (pred_joint, random_label) in enumerate(zip(pred_joints, random_labels)):  
            
#             j_0, j_1, j_2, j_3, j_4 = [0.0, 0.0], [float(joint[0]), float(joint[1])], [float(joint[2]), float(joint[3])], \
#                           [float(joint[4]), float(joint[5])], [float(joint[6]), float(joint[7])]
            
#             couplerCurveIndex = 4
#             posInit = [j_0, j_1, j_3, j_2, j_4] 
            
#             if label[0] == 0:
#                 mech_type = "RRRR"
            
#             elif label[0] == 1:
#                 mech_type = "RRRP"
                
#             elif label[0] == 2:
#                 mech_type = "RRPR2"              
#                 posInit = [j_0, j_1, j_2, j_3, j_4]     
                                        
#             exampleData = {
#                 'params': posInit, 
#                 'type': mech_type,
#                 'speedScale': speedscale, 
#                 'steps': steps,
#                 'relativeTolerance':0.1
#             }

#             try:
#                 temp = requests.post(url = API_ENDPOINT, headers=HEADERS, data = json.dumps(exampleData)).json()
#                 time.sleep(0.05)

#             except ValueError as v:
#                 for i in range(3):
#                     time.sleep(2)
#                     try:
#                         temp = requests.post(url = API_ENDPOINT, headers=HEADERS, data = json.dumps(exampleData)).json()
#                         break
#                     except ValueError as v2:
#                         plt.clf()
#                         print('wtf')
            
#             if temp is None:
#                 plt.clf()
#                 continue
                
#             P = np.array(temp['poses'])
            
#             original_x, original_y = P[:, couplerCurveIndex, 0], P[:, couplerCurveIndex, 1]
            
#             original_mean_x, original_mean_y = np.mean(original_x), np.mean(original_y)
#             original_denom = np.sqrt(np.var(original_x, axis=0, keepdims=True) + np.var(original_y, axis=0, keepdims=True))
#             original_phi = -get_pca_inclination(original_x, original_y)
            
#             if P.shape[0] >= minsteps:
#                 plt.plot(original_x, original_y, 'r', label='original')
               
#                 j_0, j_1, j_2, j_3, j_4 = [0.0, 0.0], [float(pred_joint[0]), float(pred_joint[1])], [float(pred_joint[2]), float(pred_joint[3])], \
#                                           [float(pred_joint[4]), float(pred_joint[5])], [float(pred_joint[6]), float(pred_joint[7])]
                
#                 posInit = [j_0, j_1, j_3, j_2, j_4] 
#                 couplerCurveIndex = 4
                                             
#                 if random_label[0] == 0.0:
#                     mech_type = "RRRR"

#                 elif random_label[0] == 1.0:
#                     mech_type = "RRRP"
                
#                 elif random_label[0] == 2.0:
#                     mech_type = "RRPR2"
#                     posInit = [j_0, j_1, j_2, j_3, j_4]     
                                
#             exampleData = {
#                 'params': posInit, 
#                 'type': mech_type,
#                 'speedScale': speedscale, 
#                 'steps': steps,
#                 'relativeTolerance':0.1
#             }

#             try:
#                 temp = requests.post(url = API_ENDPOINT, headers=HEADERS, data = json.dumps(exampleData)).json()
#                 time.sleep(0.05)

#             except ValueError as v:
#                 for i in range(3):
#                     time.sleep(2)
#                     try:
#                         temp = requests.post(url = API_ENDPOINT, headers=HEADERS, data = json.dumps(exampleData)).json()
#                         break
#                     except ValueError as v2:
#                         plt.clf()
#                         print('wtf')
            
#             if temp is None:
#                 plt.clf()
#                 continue
                
#             P = np.array(temp['poses'])
            
#             generated_x, generated_y = P[:, couplerCurveIndex, 0], P[:, couplerCurveIndex, 1]
                    
#             if np.isnan(generated_x).any() or np.isinf(generated_x).any() or len(generated_x) < 30:
#                 plt.clf()
#                 continue
            
#             # Rotating
#             generated_phi = -get_pca_inclination(generated_x, generated_y)
#             rotation = generated_phi - original_phi
#             generated_x, generated_y = rotate_curve(generated_x, generated_y, rotation)
            
#             j_0 = rotate_curve(j_0[0], j_0[1], rotation)
#             j_1 = rotate_curve(j_1[0], j_1[1], rotation)
#             j_2 = rotate_curve(j_2[0], j_2[1], rotation)
#             j_3 = rotate_curve(j_3[0], j_3[1], rotation)
#             j_4 = rotate_curve(j_4[0], j_4[1], rotation)
                        
#             # Scaling
#             generated_denom = np.sqrt(np.var(generated_x, axis=0, keepdims=True) + np.var(generated_y, axis=0, keepdims=True))
#             scale_factor = original_denom / generated_denom
#             generated_x, generated_y = np.multiply(generated_x, scale_factor), np.multiply(generated_y, scale_factor)
            
#             j_0 = [np.multiply(j_0[0], scale_factor), np.multiply(j_0[1], scale_factor)]
#             j_1 = [np.multiply(j_1[0], scale_factor), np.multiply(j_1[1], scale_factor)]
#             j_2 = [np.multiply(j_2[0], scale_factor), np.multiply(j_2[1], scale_factor)]
#             j_3 = [np.multiply(j_3[0], scale_factor), np.multiply(j_3[1], scale_factor)]
#             j_4 = [np.multiply(j_4[0], scale_factor), np.multiply(j_4[1], scale_factor)]
            
#             # Translating
#             generated_mean_x, generated_mean_y = np.mean(generated_x), np.mean(generated_y)
#             translation_x, translation_y = generated_mean_x - original_mean_x, generated_mean_y - original_mean_y
#             generated_x, generated_y = np.subtract(generated_x, translation_x), np.subtract(generated_y, translation_y)
            
#             j_0 = [np.subtract(j_0[0], translation_x), np.subtract(j_0[1], translation_y)]
#             j_1 = [np.subtract(j_1[0], translation_x), np.subtract(j_1[1], translation_y)]
#             j_2 = [np.subtract(j_2[0], translation_x), np.subtract(j_2[1], translation_y)]
#             j_3 = [np.subtract(j_3[0], translation_x), np.subtract(j_3[1], translation_y)]
#             j_4 = [np.subtract(j_4[0], translation_x), np.subtract(j_4[1], translation_y)]
            
#             if P.shape[0] >= minsteps:
#                 plt.plot(generated_x, generated_y, 'g', label='predicted')
                
#                 if random_label[0] == 0.0:
#                     plt.plot(j_0[0], j_0[1], marker="x", markersize=10, markeredgecolor='red',
#                          markerfacecolor='red')
#                     plt.plot(j_3[0], j_3[1], marker="x", markersize=10, markeredgecolor='green',
#                              markerfacecolor='green')

#                     x = [j_1[0], j_4[0], j_2[0], j_1[0]]
#                     y = [j_1[1], j_4[1], j_2[1], j_1[1]]
#                     plt.fill(x, y, color='pink')
                    
#                     plt.plot([j_0[0], j_1[0]], [j_0[1], j_1[1]], color='green')
#                     plt.plot([j_2[0], j_3[0]], [j_2[1], j_3[1]], color='green')

#                 elif random_label[0] == 1.0:
#                     plt.plot(j_0[0], j_0[1], marker="x", markersize=10, markeredgecolor='red',
#                              markerfacecolor='red')
#                     plt.plot(j_2[0], j_2[1], marker="x", markersize=10, markeredgecolor='green',
#                              markerfacecolor='green')
#                     plt.plot(j_3[0], j_3[1], marker="x", markersize=10, markeredgecolor='green',
#                              markerfacecolor='green')
    
#                     j_5 = [(j_2[0] + j_3[0]) / 2 , (j_2[1] + j_3[1]) / 2]
                    
#                     x = [j_1[0], j_4[0], j_5[0], j_1[0]]
#                     y = [j_1[1], j_4[1], j_5[1], j_1[1]]
#                     plt.fill(x, y, color='pink')
                    
#                     plt.plot([j_0[0], j_1[0]], [j_0[1], j_1[1]], color='green')
#                     plt.plot([j_2[0], j_3[0]], [j_2[1], j_3[1]], color='green')

#                 elif random_label[0] == 2.0:
#                     plt.plot(j_0[0], j_0[1], marker="x", markersize=10, markeredgecolor='red',
#                              markerfacecolor='red')
#                     plt.plot(j_2[0], j_2[1], marker="x", markersize=10, markeredgecolor='green',
#                              markerfacecolor='green')

#                     x = [j_1[0], j_4[0], j_3[0], j_1[0]]
#                     y = [j_1[1], j_4[1], j_3[1], j_1[1]]
#                     plt.fill(x, y, color='pink')

#                     plt.plot([j_0[0], j_1[0]], [j_0[1], j_1[1]], color='green')
#                     plt.plot([j_2[0], j_3[0]], [j_2[1], j_3[1]], color='green')
                                                        
#             if not os.path.exists('results'):
#                 os.makedirs('results')
#             if not os.path.exists('results/{} {}'.format(num_1, label[0])):
#                 os.makedirs('results/{} {}'.format(num_1, label[0]))
            
#             plt.axis('equal')
#             plt.legend()
#             plt.savefig('results/{} {}/{} {}.jpg'.format(num_1, label[0], num_2, random_label[0]))
#             plt.clf()


In [None]:
# !rm results.zip
# !zip -r results.zip results