In [None]:
from src.dataset_wrapper import *
from src.networks import *
import torch
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pickle 
      
device = 'cuda:0'
# Propagation-network sizes.
net_params = {
    "relation_dim": 23,
    "object_dim": 8,
    "hidden_dim": 256,
}
# Configuration for the 9-object dataset (train/val/test split of 32000 trajectories).
# The num_of_batch_* tuples are (0, 0) here; the commented values are presumably the ones used
# for full training runs.
dataset_params = {
    "device": device,  # listed once; the key was previously duplicated
    "path": 'Data/articulated_joint_mass_9/',
    "num_of_traj": 32000,
    "data_split": (30000, 1000, 1000),
    "batch_size_train": 16,
    "batch_size_val": 32,
    'num_workers': 4,
    "num_of_batch_train": (0, 0), #(0, 10000),
    "num_of_batch_val":(0, 0), #(0, 500),
    'no_angle': False,
    'shuffle': True,
}

In [None]:
# Build the propagation network, then hand it to the wrapper for the 9-object dataset.
Net = PropagationNetwork(**net_params).to(device)
dataset_params.update(Network=Net)
dataset = DatasetWrapper(**dataset_params)

# Persist the dataset's scaler so the same normalization can be reused later.
# NOTE: only unpickle this file from a trusted source -- pickle can run code on load.
with open('scaler.pickle', 'wb') as scaler_file:
    pickle.dump(dataset.scaler, scaler_file)

In [None]:
# strict=False tolerates key mismatches but would hide a wrong or outdated checkpoint,
# so report whatever did not line up instead of ignoring it.
load_result = Net.load_state_dict(torch.load('articulated.pt', map_location=device), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('missing keys:', load_result.missing_keys)
    print('unexpected keys:', load_result.unexpected_keys)
Net.eval()

In [None]:
def _eval_dataset_params(path):
    """Parameters for a 1000-trajectory, test-only dataset at `path`.

    The dataset reuses the 9-object training scaler (so inputs are normalized identically)
    and the trained network. It replaces two copy-pasted dicts that also repeated the
    "device" key.
    """
    return {
        "device": device,
        "path": path,
        "num_of_traj": 1000,
        "data_split": (0, 0, 1000),
        "batch_size_train": 16,
        "batch_size_val": 32,
        'num_workers': 4,
        "num_of_batch_train": (0, 0), #(0, 10000),
        "num_of_batch_val":(0, 0), #(0, 500),
        'shuffle': True,
        'for_br_onestep':False,
        'no_angle': False,
        "scaler": dataset.scaler,
        'Network': Net,
    }


dataset6_params = _eval_dataset_params('Data/articulated_joint_mass_6/')
dataset6 = DatasetWrapper(**dataset6_params)

dataset12_params = _eval_dataset_params('Data/articulated_joint_mass_12/')
dataset12 = DatasetWrapper(**dataset12_params)



## Physics Prediction Experiment

In [None]:
import numpy as np
metric1 = torch.nn.MSELoss(reduction='none')
metric2 = torch.nn.MSELoss()


def error_position(dataset, predicted):
    """Per-trajectory position error (scaled by 100, presumably m -> cm) for growing prefixes.

    Entry k of the returned list is a NumPy array with one value per trajectory. Each value is
    the Euclidean error over channels 0:2, averaged over the first k+1 frames and over
    objects 1.. (object 0 is excluded). Prefix lengths 1..49 are evaluated, so there are
    49 entries.
    """
    truth = dataset.val_tester.test_states
    curves = []
    for prefix in range(1, 50):
        dist = (truth[:, :prefix, 1:, 0:2] - predicted[:, :prefix, 1:, 0:2]).norm(dim=-1)
        curves.append(100 * dist.mean(dim=[1, 2]).cpu().numpy())
    return curves


def error_angle(dataset, predicted):
    """RMSE of channel 2 (angle, converted from radians to degrees) for prefix lengths 1..49.

    Unlike error_position, this includes every object and pools all trajectories into one
    scalar per prefix.
    """
    truth = dataset.val_tester.test_states
    return [
        np.sqrt(metric2(truth[:, :prefix, :, 2:3], predicted[:, :prefix, :, 2:3]).item()) / np.pi * 180
        for prefix in range(1, 50)
    ]

def _pp_errors(ds):
    """Roll out ds's test set once; return (position-error curves, angle-error curve)."""
    rollout = ds.val_tester.test_pp()
    return error_position(ds, rollout), error_angle(ds, rollout)


# A single no_grad context is enough; the rollout tensor is freed when _pp_errors returns.
with torch.no_grad():
    error_pos9, error_angle9 = _pp_errors(dataset)
    error_pos6, error_angle6 = _pp_errors(dataset6)
    error_pos12, error_angle12 = _pp_errors(dataset12)

In [None]:
fig, ax = plt.subplots(3, 2, figsize=(12, 15))
# One histogram bin centred on each integer value of log2(10 * error), i.e. one per tick.
bins_ = [-1.5, -0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5]
xticks = [-1, 0, 1, 2, 3, 4, 5]
xtick_labels = ['< 0.1 cm', '0.1-0.2 cm', '0.2-0.4 cm', '0.4-0.8 cm', '0.8-1.6 cm', '1.6-3.2 cm', '> 3.2 cm']
# (row label, per-trajectory error curves produced by error_position)
rows = [('9', error_pos9), ('6', error_pos6), ('12', error_pos12)]
# (index into the error curve, column title); index k averages over the first k+1 frames.
columns = [(28, 'Timestep 30'), (48, 'Timestep 50')]

for row, (n_label, error_pos) in enumerate(rows):
    # An f-string keeps the space for every count (the old '12'+'Objects' rendered "12Objects").
    ax[row, 0].set_ylabel(f'{n_label} Objects', fontsize=24)
    for col, (curve_idx, title) in enumerate(columns):
        panel = ax[row, col]
        # Clip at both ends so the outermost ticks act as overflow bins ("< 0.1 cm" / "> 3.2 cm").
        # Previously only the low end was clipped, so errors beyond the last bin edge were
        # silently left out of the histogram.
        data = np.clip(np.log2(10 * error_pos[curve_idx]), xticks[0], xticks[-1])
        panel.set_axisbelow(True)
        panel.grid()
        panel.hist(data, bins=bins_, color='gray', rwidth=0.8)
        panel.set_ylim([0, 1050])
        panel.set_xticks(xticks)
        if row == len(rows) - 1:
            panel.set_xticklabels(xtick_labels, rotation=45, fontsize=16)
        else:
            panel.set_xticklabels([])
        if col == 0:
            panel.tick_params(axis='y', labelsize=16)
        else:
            panel.yaxis.set_ticklabels([])
        if row == 0:
            panel.set_title(title, fontsize=20)

fig.tight_layout()
fig.savefig("articulated_pp.png", bbox_inches='tight')

## Belief Regulation 

In [None]:
# Compute the "all fixed joints" baseline relations.
# This operation is slow, so it's not part of the base code.
import collision
def all_joint_criteria(obj_info, obj_info2):
    # Object 1    
    if obj_info['type']==0:
        center_object = collision.Vector(obj_info['position'][0],obj_info['position'][1])
        obj1 =collision.Circle(center_object, (obj_info['shape'][0])/2+0.05)
    elif obj_info['type']==1:
        center_object = collision.Vector(obj_info['position'][0],obj_info['position'][1])
        obj1 =collision.Poly.from_box(center_object, obj_info['shape'][0]+0.05,obj_info['shape'][1]+0.05)
        obj1.angle = obj_info['angle']
    if obj_info2['type']==0:
        center_object = collision.Vector(obj_info2['position'][0],obj_info2['position'][1])
        obj2 =collision.Circle(center_object, (obj_info2['shape'][0])/2+0.05)
    elif obj_info2['type']==1:
        center_object = collision.Vector(obj_info2['position'][0],obj_info2['position'][1])
        obj2 =collision.Poly.from_box(center_object, obj_info2['shape'][0]+0.05,obj_info2['shape'][1]+0.05)
        obj2.angle = obj_info2['angle']

    # Maybe too much bias.
    if obj_info['type'] == 1 and obj_info2['type'] ==1:
        if abs((obj_info['angle']-obj_info2['angle'])%np.pi/2)>1e-5:
            return False 

    return collision.collide(obj1,obj2)
def find_all_fixed_baseline(dataset, max_len=200):
    """Build the "all fixed joints" baseline relations for dataset's test split.

    For every trajectory and timestep, each ordered object pair (i, j), i != j, gets
    channel 1 set (fixed joint) when all_joint_criteria holds. Otherwise it gets channel 0
    (no joint). Frames past the trajectory length repeat the last valid frame.

    Args:
        dataset: object exposing val_tester.test_rels / test_states / test_shapes /
            test_traj_lens / num_of_objects.
        max_len: number of timesteps in the output (previously hard-coded 200).

    Returns:
        Tensor of shape (num_traj, max_len, *test_rels.shape[1:]) with the dtype/device of
        test_rels. num_traj is taken from test_rels instead of the former hard-coded 1000.
    """
    tester = dataset.val_tester
    base_rels = tester.test_rels
    num_traj = base_rels.shape[0]
    n_of_obj = tester.num_of_objects
    # Same shape/dtype/device as stacking test_rels max_len times on dim 1, without copying it.
    rels = base_rels.new_zeros((num_traj, max_len) + tuple(base_rels.shape[1:]))
    rels[..., 0] = 1
    # Relation slot k is the k-th ordered pair (i, j), i != j, in row-major order.
    # NOTE(review): the original also required `i != 0 or j != 0`, which is always true once
    # i != j; if pairs involving object 0 were meant to be skipped, that test needed `and`.
    pairs = [(i, j) for i in range(n_of_obj) for j in range(n_of_obj) if i != j]
    for traj_ind in tqdm(range(num_traj)):
        traj_len = min(max_len, int(tester.test_traj_lens[traj_ind]))
        # Convert tensors once per trajectory / timestep rather than once per pair.
        shapes = tester.test_shapes[traj_ind].tolist()
        for ts in range(traj_len):
            states = tester.test_states[traj_ind, ts].tolist()
            infos = [
                {
                    'type': int(shapes[k][1]),
                    'position': states[k][:2],
                    'angle': states[k][2],
                    'shape': shapes[k][-2:],
                }
                for k in range(n_of_obj)
            ]
            for rel_idx, (i, j) in enumerate(pairs):
                if all_joint_criteria(infos[i], infos[j]):
                    rels[traj_ind, ts, rel_idx, 0] = 0
                    rels[traj_ind, ts, rel_idx, 1] = 1
        if traj_len > 0:
            # Hold the last observed frame's relations for the remaining frames.
            rels[traj_ind, traj_len:] = rels[traj_ind, traj_len - 1]
    return rels
# Slow: one collision test per ordered object pair, per timestep, per trajectory.
test_rels_all_fixed, test_rels_all_fixed6, test_rels_all_fixed12 = (
    find_all_fixed_baseline(ds) for ds in (dataset, dataset6, dataset12)
)


In [None]:
# Cache the expensive baseline so later sessions can skip recomputing it.
for obj_tag, fixed_rels in (('9', test_rels_all_fixed), ('6', test_rels_all_fixed6), ('12', test_rels_all_fixed12)):
    torch.save(fixed_rels, f'test_rels_all_fixed{obj_tag}_2.pt')

In [None]:
# Reload the cached baselines onto the CPU, keyed by object count.
# Note: this rebinds test_rels_all_fixed from a tensor to a dict.
test_rels_all_fixed = {
    count: torch.load(f'test_rels_all_fixed{count}_2.pt', map_location=torch.device('cpu'))
    for count in (9, 6, 12)
}


In [None]:
# Relation results filled in per object count.
test_rels_br = {}
test_rels_oracle = {}

# Evaluation datasets keyed by object count.
datasets = {9: dataset, 6: dataset6, 12: dataset12}

In [None]:
# Run belief regulation on each test set and keep the predicted and ground-truth relations on the CPU.
with torch.no_grad():
    for n_of_obj in (12, 9, 6):
        predicted_br, ground_truth = datasets[n_of_obj].val_tester.test_br(batch_size=12)
        test_rels_br[n_of_obj] = predicted_br['rel_to_predict'].cpu()
        test_rels_oracle[n_of_obj] = ground_truth['rel_to_predict'].cpu()
        del predicted_br, ground_truth

for obj_n in (9, 6, 12):
    torch.save(test_rels_br[obj_n], f'test_rels_{obj_n}.pt')

In [None]:
# Reload every relation set from disk so later cells do not depend on earlier in-memory state.
test_rels_all_fixed, test_rels_all_fixed_prev, test_rels_oracle, test_rels_br = {}, {}, {}, {}

for n_of_obj in [9, 6, 12]:
    tag = str(n_of_obj)
    # NOTE(review): 'test_rels_all_fixed<N>.pt' (no _2 suffix) is not written by any cell in
    # this notebook; it presumably comes from an earlier version of the baseline.
    test_rels_all_fixed_prev[n_of_obj] = torch.load(f'test_rels_all_fixed{tag}.pt', map_location=torch.device('cpu'))
    test_rels_all_fixed[n_of_obj] = torch.load(f'test_rels_all_fixed{tag}_2.pt', map_location=torch.device('cpu'))
    # Oracle relations are static per trajectory, so they are repeated for all 200 timesteps.
    # NOTE(review): not moved to CPU like the others -- confirm test_rels lives on the CPU.
    test_rels_oracle[n_of_obj] = torch.stack([datasets[n_of_obj].val_tester.test_rels] * 200, dim=1)
    test_rels_br[n_of_obj] = torch.load(f'test_rels_{tag}.pt', map_location=torch.device('cpu'))

In [None]:
def getAccuracyOverTime(ypred, ytarget):
    """Fraction of matching argmax classes, averaged over dims 0 and 2, per index of dim 1.

    For (trajectory, timestep, relation, class) inputs this is the per-timestep relation
    accuracy. Returns a NumPy array.
    """
    pred_cls = ypred.max(dim=-1).indices
    true_cls = ytarget.max(dim=-1).indices
    hits = (pred_cls == true_cls).float()
    return hits.mean(dim=[0, 2]).cpu().numpy()
# Per-timestep joint-prediction accuracy against the oracle, for each method and object count.
accuracy_over_time_BR, accuracy_over_time_NJ, accuracy_over_time_AJ, accuracy_over_time_AJ_old = {}, {}, {}, {}

for n_of_obj in [9, 6, 12]:
    oracle_rels = test_rels_oracle[n_of_obj]
    accuracy_over_time_BR[n_of_obj] = getAccuracyOverTime(test_rels_br[n_of_obj], oracle_rels)
    accuracy_over_time_NJ[n_of_obj] = getAccuracyOverTime(datasets[n_of_obj].val_tester.test_rels_no_joint, oracle_rels)
    accuracy_over_time_AJ[n_of_obj] = getAccuracyOverTime(test_rels_all_fixed[n_of_obj], oracle_rels)
    accuracy_over_time_AJ_old[n_of_obj] = getAccuracyOverTime(test_rels_all_fixed_prev[n_of_obj], oracle_rels)



In [None]:
# Class indices (argmax over the last dim) for the 9-object test set.
labels_gt = test_rels_oracle[9].max(dim=-1).indices
labels_predicted = test_rels_br[9].max(dim=-1).indices

In [None]:
# Number of relations labelled class 2 ('Prismatic Joint') at the last timestep.
# (The cell previously began with a stray "(/(" and was a syntax error.)
(labels_gt[:, -1] == 2).sum().item()

In [None]:
# Count of relation slots labelled class 0 ('No Joint') over all trajectories and timesteps.
int((labels_gt == 0).sum())

In [None]:
# Fraction of (trajectory, timestep, relation) slots labelled 'No Joint'. It is computed
# directly instead of dividing a hand-copied count (16076000) by hard-coded dims (1000*200*90).
(labels_gt == 0).sum().item() / labels_gt.numel()

In [None]:
def _safe_ratio(num, den):
    """num / den, or NaN when den is 0 (precision/recall are undefined with no predictions/positives)."""
    return num / den if den else float('nan')


# Timesteps at which precision/recall are sampled; entry k of each array corresponds to PR_TIMESTEPS[k].
PR_TIMESTEPS = range(0, 200, 10)
precisions = dict()
recalls = dict()
for n_of_obj in [9, 6, 12]:
    labels_gt = test_rels_oracle[n_of_obj].max(dim=-1)[1]
    labels_predicted = test_rels_br[n_of_obj].max(dim=-1)[1]
    for relation in range(4):
        prec = np.zeros(len(PR_TIMESTEPS))
        rec = np.zeros(len(PR_TIMESTEPS))
        for ind, ts in enumerate(PR_TIMESTEPS):
            is_true = labels_gt[:, ts] == relation
            is_pred = labels_predicted[:, ts] == relation
            true_pos = torch.logical_and(is_true, is_pred).sum().item()
            # A class never predicted (or never present) at ts previously raised ZeroDivisionError.
            prec[ind] = _safe_ratio(true_pos, is_pred.sum().item())
            rec[ind] = _safe_ratio(true_pos, is_true.sum().item())
        precisions[(n_of_obj, relation)] = prec
        recalls[(n_of_obj, relation)] = rec

In [None]:
relation_type = ['No Joint','Fixed Joint','Prismatic Joint','Revolute Joint']
# precision/recall arrays hold one value per ts in range(0, 200, 10); entry 0 (ts=0) is skipped.
checkpoints = range(10, 200, 10)
fig, ax = plt.subplots(1, 2, figsize=(11, 5), dpi=100)
for relation, name in enumerate(relation_type):
    # Average over the three object counts; nanmean skips timesteps where a metric is undefined.
    precision = np.nanmean([precisions[n, relation] for n in (9, 6, 12)], axis=0)
    recall = np.nanmean([recalls[n, relation] for n in (9, 6, 12)], axis=0)
    ax[0].plot(checkpoints, precision[1:], label=name, lw=4)
    ax[1].plot(checkpoints, recall[1:], label=name, lw=4)
# Explicit axes instead of plt.legend(), which used whichever axes happened to be current.
ax[1].legend(fontsize=18)
for panel in ax:
    panel.set_xticks(range(0, 201, 40))
    panel.tick_params(labelsize=16, labelrotation=45)
    panel.set_ylim([0, 1.05])
    # Previously ax[0] was labelled twice inside this loop and ax[1] never.
    panel.set_xlabel("Number of Observed Timesteps", fontsize=20)
    panel.grid()
ax[0].set_title("Precision", fontsize=20)
ax[1].set_title("Recall", fontsize=20)
ax[1].yaxis.set_ticklabels([])

fig.savefig("Articulated_Prec_recall.png", bbox_inches='tight')

plt.show()

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(12, 5))
# (accuracy dict, legend label, line style) for each relation-inference method.
methods = [
    (accuracy_over_time_BR, 'Belief Regulation', ':'),
    (accuracy_over_time_NJ, 'No Joint', '--'),
    (accuracy_over_time_AJ, 'All Fixed Joints', '-'),
]
for panel, n_of_obj in zip(ax, [9, 6, 12]):
    for curves, label, style in methods:
        panel.plot(curves[n_of_obj], label=label, lw=4, ls=style, color='gray')
    panel.set_title(f"{n_of_obj} Objects", fontsize=16)

# "Timesteps" (plural) to match the other figures; the label previously read "Timestep".
ax[1].set_xlabel('Number of Observed Timesteps', fontsize=16)
ax[0].set_ylabel('Joint Prediction Accuracy', fontsize=16)
ax[2].legend(fontsize=12)
fig.tight_layout()
fig.savefig("Articulated_br_baselines.png", bbox_inches='tight')
plt.show()
import gc
gc.collect()

In [None]:
# Shape of every accuracy curve, keyed by object count. The old cell indexed with the loop
# variable leaked from an earlier cell, which is hidden state.
{n: curve.shape for n, curve in accuracy_over_time_BR.items()}

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=100)
# Skip the first 10 timesteps; x values are the matching timestep indices.
observed_steps = range(10, 200)
for count in [9, 6, 12]:
    ax.plot(observed_steps, accuracy_over_time_BR[count][10:], label=f'{count} Objects', lw=4)
ax.set_title("Joint Prediction Accuracy", fontsize=20)

ax.set_xlabel('Number of Observed Timesteps', fontsize=20)
ax.legend(fontsize=18)
fig.tight_layout()
ax.set_xticks(range(0, 201, 40))
ax.tick_params(labelsize=16, labelrotation=45)
ax.grid()
fig.savefig("articulated_br_n_of_obj.png", bbox_inches='tight')
plt.show()
import gc
gc.collect()

## Coupled Results

In [None]:
import gc
def test_pp_custom(dataset, joint_rels, start_timestep, end_timestep, batch_size=100):
    """Roll the physics model out over [start_timestep, end_timestep) with given joint relations.

    Frame 0 of the window is copied from ground truth for every object. Object 0 keeps its
    ground-truth state in all frames. Objects 1.. are zeroed in frames 1.. and then predicted
    autoregressively with dataset.gp_pp, which fills their first 6 state channels.

    Args:
        dataset: wrapper exposing val_tester.test_states / test_shapes and gp_pp.
        joint_rels: relation tensor indexed per trajectory along dim 0.
        start_timestep, end_timestep: window bounds.
        batch_size: trajectories per forward pass. It was previously declared but ignored in
            favour of a hard-coded 10 x 100 split that assumed exactly 1000 trajectories.

    Returns:
        CPU tensor with the predicted window, shaped like test_states[:, start:end].
    """
    # Slice before cloning so only the window is copied.
    window = dataset.val_tester.test_states[:, start_timestep:end_timestep].clone().to(device)
    window[:, 1:, 1:, :] = 0
    num_traj, num_steps = window.shape[0], window.shape[1]
    for lo in range(0, num_traj, batch_size):
        hi = lo + batch_size
        x = {
            'objects_shape': dataset.val_tester.test_shapes[lo:hi].to(device),
            'relation_info': joint_rels[lo:hi].to(device),
        }
        for t in range(1, num_steps):
            # Feed back the previous frame (ground truth for t == 1, predictions afterwards).
            x['objects_state'] = window[lo:hi, t - 1, :, :]
            window[lo:hi, t, 1:, :6] = dataset.gp_pp(x)
        del x
    result = window.cpu()
    del window
    gc.collect()
    return result



In [None]:
def convertOneHot(probs):
    """Return a float32 one-hot tensor marking the argmax of `probs` along the last dim.

    The result is allocated on probs.device. The legacy torch.FloatTensor(shape) constructor
    always allocated on the CPU, so scatter_ failed for CUDA inputs.
    """
    max_idx = torch.argmax(probs, -1, keepdim=True)
    one_hot = torch.zeros(probs.shape, dtype=torch.float32, device=probs.device)
    one_hot.scatter_(-1, max_idx, 1)
    return one_hot

In [None]:
metric2 = torch.nn.MSELoss()


def rollout_rmse_cm(gt, predicted):
    """RMSE between gt and predicted, scaled by 100 (presumably m -> cm).

    This helper used to be a second `def error_position`. That definition shadowed the
    per-prefix helper of the same name (different signature) used by the
    physics-prediction cell, so re-running that cell afterwards broke.
    """
    return 100 * np.sqrt(metric2(gt, predicted).cpu().numpy())


# Joint relations fed to the rollout for each (baseline, object count).
dataset_rels = dict()
for n_of_obj in [9, 6, 12]:
    dataset_rels['Oracle', n_of_obj] = test_rels_oracle[n_of_obj]
    dataset_rels['BR', n_of_obj] = convertOneHot(test_rels_br[n_of_obj])  # hard (argmax) beliefs
    dataset_rels['No-Joint', n_of_obj] = datasets[n_of_obj].val_tester.test_rels_no_joint
    dataset_rels['All-Fixed', n_of_obj] = test_rels_all_fixed[n_of_obj]

# One list per (baseline, object count), appended to once per observation horizon.
errors_mean = {(baseline, n): [] for baseline in ['Oracle', 'BR', 'No-Joint', 'All-Fixed'] for n in [9, 6, 12]}
errors_standart = {(baseline, n): [] for baseline in ['Oracle', 'BR', 'No-Joint', 'All-Fixed'] for n in [9, 6, 12]}

In [None]:
# For each observation horizon ts, roll out 50 steps using each baseline's relations at ts.
# Record the mean per-trajectory error and its standard error.
with torch.no_grad():
    for ts in tqdm(range(0, 101, 10), desc='observed timesteps'):
        for n_of_obj in [9, 6, 12]:
            gt_traj = datasets[n_of_obj].val_tester.test_states[:, ts:ts + 50, :, :2]
            for baseline in ['Oracle', 'BR', 'No-Joint', 'All-Fixed']:
                predicted = test_pp_custom(datasets[n_of_obj], dataset_rels[baseline, n_of_obj][:, ts], ts, ts + 50)
                # Mean position error (x100, presumably cm) over the window and objects 1.., per trajectory.
                errors = (100 * (gt_traj[:, :, 1:] - predicted[:, :, 1:, :2]).norm(dim=-1)).mean(dim=[1, 2])
                errors_mean[baseline, n_of_obj].append(errors.mean().item())
                # Standard error over the actual number of trajectories (was a hard-coded sqrt(1000)).
                errors_standart[baseline, n_of_obj].append(errors.std().item() / np.sqrt(errors.numel()))
                del predicted, errors

In [None]:
# Observation horizons matching the rollout loop (0, 10, ..., 100).
horizons = range(0, 101, 10)
fig, ax = plt.subplots(1, 3, figsize=(15, 5), dpi=200)
for panel, count in zip(ax, [9, 6, 12]):
    for baseline in ['Oracle', 'BR', 'No-Joint', 'All-Fixed']:
        mean_curve = np.array(errors_mean[baseline, count])
        sem_curve = np.array(errors_standart[baseline, count])
        panel.plot(horizons, mean_curve, label=baseline, lw=4)
        # Shaded band: mean +/- one standard error.
        panel.fill_between(horizons, mean_curve - sem_curve, mean_curve + sem_curve, alpha=0.25)
    panel.set_title(f"{count} Objects", fontsize=20)
    panel.set_ylim([-0.1, 7.2])
for panel in ax[1:]:
    panel.yaxis.set_ticklabels([])

ax[1].set_xlabel('Number of Observed Timesteps', fontsize=20)
ax[0].set_ylabel('50 Timestep Rollout Error (cm)', fontsize=20)
ax[1].legend(fontsize=18)
for panel in ax:
    panel.tick_params(labelsize=16)
    panel.set_axisbelow(True)
    panel.grid()
fig.tight_layout()
fig.savefig("articulated_full.png", bbox_inches='tight')
plt.show()
