In [None]:
import torch
from torch.utils.data import DataLoader
from dataset import SingleDataset
import torchvision.transforms as transforms
from attention_cvae import cVAE
import numpy as np
# Presumably provides the hyper-parameters used below (batch_size, attention_dim,
# learning_rate, num_joints, beta_cvae, n_attention_blocks, n_heads, num_classes) — confirm.
from constants import *
import matplotlib.pyplot as plt
# Allow faster, lower-precision float32 matrix multiplications.
torch.set_float32_matmul_precision('medium')

from curve_plot import get_pca_inclination, rotate_curve
import scipy.spatial.distance as sciDist
from tqdm import tqdm
import requests
import time
import json  # simulator request payloads are serialised with json.dumps
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Settings for the headless simulator (NOT THE LS VERSION).
index = 0  # which local simulator server to talk to; selects port 400<index>
API_ENDPOINT = f'http://localhost:400{index}/simulation'
HEADERS = {"Content-Type": "application/json"}

speedscale = 1
steps = 360                       # simulation steps requested per mechanism
minsteps = int(steps * 20 / 360)  # minimum number of poses for a curve to be plotted

In [None]:
# Evaluation data: each image is converted to a tensor, resized to 32x32 and
# flattened into a 1-D vector before being handed to the model.
dataset = SingleDataset(
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((32, 32), antialias=True),
            transforms.Lambda(torch.flatten),
        ]
    )
)

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# Load the trained cVAE from its checkpoint and switch it to eval mode. The extra
# keyword arguments are passed by load_from_checkpoint to the model constructor.
# map_location + .to(device) put the weights on the same device the inputs are
# moved to elsewhere in this notebook, and let a GPU-saved checkpoint load on a
# CPU-only machine.
model = cVAE.load_from_checkpoint(
    'weights/beta-25.ckpt',
    map_location=device,
    learning_rate=learning_rate,
    encoder_dim=attention_dim,
    num_joints=num_joints,
    beta=beta_cvae,
    n_blocks=n_attention_blocks,
    n_heads=n_heads,
    batch_size=batch_size,
    num_classes=num_classes,
).to(device).eval()

In [None]:
import torch
import cv2
from torchvision import transforms

# Preprocessing for a single image: to tensor, resize to 32x32, then flatten
# into a 1-D vector.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((32, 32), antialias=True),
        transforms.Lambda(torch.flatten),
    ]
)

# Define the function to be timed: one full sampling pass, from reading the
# image on disk to the joint predictions for a whole batch.
def process_image(path='example.jpg'):
    """Run one end-to-end sampling pass on a single grayscale image.

    Args:
        path: image file to read (defaults to 'example.jpg').

    Returns:
        The output of ``model.joint_predictor`` for ``batch_size`` samples.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read image: {path}')
    img = cv2.bitwise_not(img)  # invert pixel intensities

    img_tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        image = img_tensor.repeat(batch_size, 1).to(device)

        # One random class label in {0, 1, 2} per sample, shape (batch_size, 1).
        indices = torch.arange(batch_size).unsqueeze(1)
        indices = indices[torch.randperm(batch_size)]
        random_labels = (indices % 3).to(device=device, dtype=torch.float32)

        labels_encoded = model.label_encode(random_labels)
        conditions = model.condition_cross_attention(labels_encoded, image)

        z = torch.randn([batch_size, attention_dim]).to(device)
        output = model.decoder_cross_attention(z, conditions)
        for decoder_self_attention in model.decoder_self_attentions:
            output = decoder_self_attention(output)

        pred_joints = model.joint_predictor(output)

    if pred_joints.is_cuda:
        # CUDA kernels run asynchronously; wait for them so %timeit measures the
        # actual compute rather than only the kernel launches.
        torch.cuda.synchronize(pred_joints.device)
    return pred_joints

# Benchmark one end-to-end call of process_image() with IPython's %timeit magic
# (image read + preprocessing + one batched sampling pass through the model).
%timeit process_image()

In [None]:
import json  # needed for json.dumps in simulate()

# --- Evaluation settings -------------------------------------------------------
NUM_BATCHES = 1             # DataLoader batches to evaluate
IMAGES_PER_BATCH = 20       # images evaluated per batch
SAMPLES_PER_IMAGE = 100     # generated mechanisms plotted per image
MIN_GENERATED_POINTS = 30   # generated curves with fewer poses are discarded
SIM_RETRIES = 3             # extra simulator attempts after an undecodable reply
RESULTS_DIR = 'results'
COUPLER_CURVE_INDEX = 4     # index (axis 1 of the poses array) of the traced point
MECH_TYPES = {0: "RRRR", 1: "RRRP", 2: "RRPR2"}  # class label -> simulator type


def joint_points(flat_joints):
    """Split a flat joint vector [x1, y1, ..., x4, y4] into five [x, y] points.

    The first point j_0 is fixed at the origin; j_1..j_4 come from the vector.
    """
    return [[0.0, 0.0]] + [
        [float(flat_joints[k]), float(flat_joints[k + 1])] for k in range(0, 8, 2)
    ]


def simulate(points, mech_type):
    """Simulate one mechanism on the headless simulator.

    ``points`` are the joints j_0..j_4; they are reordered into the parameter
    order used for ``mech_type`` ("RRPR2" keeps j_2 before j_3, the other
    types swap them). Returns the decoded JSON reply, or None when the reply
    could not be decoded after SIM_RETRIES retries.
    """
    j_0, j_1, j_2, j_3, j_4 = points
    if mech_type == "RRPR2":
        params = [j_0, j_1, j_2, j_3, j_4]
    else:
        params = [j_0, j_1, j_3, j_2, j_4]
    payload = json.dumps({
        'params': params,
        'type': mech_type,
        'speedScale': speedscale,
        'steps': steps,
        'relativeTolerance': 0.1,
    })
    for attempt in range(SIM_RETRIES + 1):
        if attempt:
            time.sleep(2)  # back off before retrying
        try:
            reply = requests.post(url=API_ENDPOINT, headers=HEADERS, data=payload).json()
        except ValueError as err:  # the reply body was not valid JSON
            print(f'Simulator reply could not be decoded (attempt {attempt + 1}): {err}')
            continue
        time.sleep(0.05)  # short pause between simulator requests
        return reply
    return None


def coupler_curve(poses):
    """Return the x and y coordinates of the traced point over all poses."""
    return poses[:, COUPLER_CURVE_INDEX, 0], poses[:, COUPLER_CURVE_INDEX, 1]


def curve_spread(x, y):
    """Square root of the summed per-axis variances; used as the curve's scale."""
    return np.sqrt(np.var(x, axis=0, keepdims=True) + np.var(y, axis=0, keepdims=True))


def align_to_reference(x, y, points, ref_phi, ref_spread, ref_mean_x, ref_mean_y):
    """Rotate, scale and translate a generated curve and its joints onto a reference.

    Returns (x, y, points), or None when the generated curve has no spread
    (scaling it would divide by zero).
    """
    # Rotating by the difference of the negated PCA inclinations.
    rotation = -get_pca_inclination(x, y) - ref_phi
    x, y = rotate_curve(x, y, rotation)
    points = [rotate_curve(px, py, rotation) for px, py in points]

    # Scaling to the reference spread.
    spread = curve_spread(x, y)
    if not np.all(spread > 0):
        return None
    scale = ref_spread / spread
    x, y = np.multiply(x, scale), np.multiply(y, scale)
    points = [[np.multiply(px, scale), np.multiply(py, scale)] for px, py in points]

    # Translating so both curves have the same mean.
    shift_x, shift_y = np.mean(x) - ref_mean_x, np.mean(y) - ref_mean_y
    x, y = np.subtract(x, shift_x), np.subtract(y, shift_y)
    points = [[np.subtract(px, shift_x), np.subtract(py, shift_y)] for px, py in points]
    return x, y, points


def draw_mechanism(points, mech_type):
    """Draw joint markers, two links and a pink-filled polygon for a mechanism."""
    j_0, j_1, j_2, j_3, j_4 = points
    if mech_type == "RRRR":
        green_joints, third_vertex = [j_3], j_2
    elif mech_type == "RRRP":
        green_joints = [j_2, j_3]
        third_vertex = [(j_2[0] + j_3[0]) / 2, (j_2[1] + j_3[1]) / 2]
    else:  # "RRPR2"
        green_joints, third_vertex = [j_2], j_3

    plt.plot(j_0[0], j_0[1], marker="x", markersize=10, markeredgecolor='red',
             markerfacecolor='red')
    for joint_xy in green_joints:
        plt.plot(joint_xy[0], joint_xy[1], marker="x", markersize=10,
                 markeredgecolor='green', markerfacecolor='green')
    plt.fill([j_1[0], j_4[0], third_vertex[0], j_1[0]],
             [j_1[1], j_4[1], third_vertex[1], j_1[1]], color='pink')
    plt.plot([j_0[0], j_1[0]], [j_0[1], j_1[1]], color='green')
    plt.plot([j_2[0], j_3[0]], [j_2[1], j_3[1]], color='green')


print('Started')
start_time = time.time()

for num, (images, joints, labels) in enumerate(data_loader):
    if num == NUM_BATCHES:
        break

    for num_1, (image, joint, label) in enumerate(zip(images, joints, labels)):
        if num_1 == IMAGES_PER_BATCH:
            break

        # Sample batch_size mechanisms for this image with random class labels.
        with torch.no_grad():
            image = image.repeat(batch_size, 1).to(device)

            indices = torch.arange(batch_size).unsqueeze(1)
            indices = indices[torch.randperm(batch_size)]
            random_labels = (indices % 3).to(device=device, dtype=torch.float32)

            labels_encoded = model.label_encode(random_labels)
            conditions = model.condition_cross_attention(labels_encoded, image)

            z = torch.randn([batch_size, attention_dim]).to(device)
            output = model.decoder_cross_attention(z, conditions)
            for decoder_self_attention in model.decoder_self_attentions:
                output = decoder_self_attention(output)

            pred_joints = model.joint_predictor(output)

        pred_joints = pred_joints.cpu().detach().numpy()[:SAMPLES_PER_IMAGE]
        random_labels = random_labels.cpu().detach().numpy()[:SAMPLES_PER_IMAGE]
        joint = joint.cpu().detach().numpy()
        label = label.numpy()

        # The reference mechanism depends only on (joint, label), so it is
        # simulated once per image rather than once per generated sample.
        original_type = MECH_TYPES.get(float(label[0]))
        if original_type is None:
            continue
        reply = simulate(joint_points(joint), original_type)
        if reply is None:
            continue
        P = np.array(reply['poses'])
        if P.ndim != 3 or P.shape[0] < minsteps:
            # No usable reference curve to align the generated mechanisms to.
            continue

        original_x, original_y = coupler_curve(P)
        original_mean_x, original_mean_y = np.mean(original_x), np.mean(original_y)
        original_spread = curve_spread(original_x, original_y)
        original_phi = -get_pca_inclination(original_x, original_y)

        for num_2, (pred_joint, random_label) in enumerate(zip(pred_joints, random_labels)):
            generated_type = MECH_TYPES.get(float(random_label[0]))
            if generated_type is None:
                continue
            points = joint_points(pred_joint)
            reply = simulate(points, generated_type)
            if reply is None:
                continue
            P = np.array(reply['poses'])
            if P.ndim != 3 or len(P) < MIN_GENERATED_POINTS:
                continue
            generated_x, generated_y = coupler_curve(P)
            if not (np.isfinite(generated_x).all() and np.isfinite(generated_y).all()):
                continue

            aligned = align_to_reference(generated_x, generated_y, points, original_phi,
                                         original_spread, original_mean_x, original_mean_y)
            if aligned is None:
                continue
            generated_x, generated_y, points = aligned

            plt.plot(original_x, original_y, 'r', label='original')
            if P.shape[0] >= minsteps:
                plt.plot(generated_x, generated_y, 'g', label='predicted')
                draw_mechanism(points, generated_type)

            out_dir = '{}/{} {}'.format(RESULTS_DIR, num_1, label[0])
            os.makedirs(out_dir, exist_ok=True)

            plt.axis('equal')
            plt.legend()
            plt.savefig('{}/{} {}.jpg'.format(out_dir, num_2, random_label[0]))
            plt.clf()

print(f'Finished in {time.time() - start_time:.1f} s')


In [None]:
# Bundle the saved figures into results.zip, keeping the "results/" prefix on
# every entry. shutil.make_archive always writes a fresh archive (so no separate
# delete step is needed) and does not depend on a `zip` binary being installed.
import shutil

archive_path = shutil.make_archive('results', 'zip', root_dir='.', base_dir='results')