In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

import trimesh

scene = trimesh.Scene()

In [2]:
from my_code.datasets.surreal_dataset_3dc import TemplateSurrealDataset3DC
import torch
from tqdm import tqdm
import time
import my_code.diffusion_training_sign_corr.data_loading as data_loading


train_dataset = data_loading.get_val_dataset(
    'FAUST_a', 'train', 128, preload=True, canonicalize_fmap=None
    )[1]

test_datasets = {
    'FAUST_a': data_loading.get_val_dataset(
        'FAUST_a', 'train', 128, preload=False, canonicalize_fmap=None
        )[1],
    'FAUST_orig train': data_loading.get_val_dataset(
        'FAUST_orig', 'train', 128, preload=False, canonicalize_fmap=None
        )[1],
    'FAUST_orig test': data_loading.get_val_dataset(
        'FAUST_orig', 'test', 128, preload=False, canonicalize_fmap=None
        )[1],
    'FAUST_r train': data_loading.get_val_dataset(
        'FAUST_r', 'train', 128, preload=False, canonicalize_fmap=None
        )[1],
    'FAUST_r test': data_loading.get_val_dataset(
        'FAUST_r', 'test', 128, preload=False, canonicalize_fmap=None
        )[1]
}

# test_dataset = data_loading.get_val_dataset(
#     'FAUST_orig', 'test', 200, canonicalize_fmap=None
# )[1]

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


Loading base dataset: 100%|█████████████████████████████████████████████████████████████████████████████████| 20/20 [00:04<00:00,  4.12it/s]


In [3]:
import networks.diffusion_network as diffusion_network

condition_dim = 0
start_dim = 0

feature_dim = 32
evecs_per_support = 4


device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = diffusion_network.DiffusionNet(
    in_channels=feature_dim,
    out_channels=feature_dim // evecs_per_support,
    cache_dir=None,
    input_type='wks',
    k_eig=128,
    n_block=6
    ).to(device)

In [4]:
net.load_state_dict(torch.load('/home/s94zalek_hpc/shape_matching/my_code/experiments/sign_double_start_0_feat_32_6block_factor4_dataset_SURREAL_train_rot_180_180_180_normal_True_noise_0.0_-0.05_0.05_lapl_mesh_scale_0.9_1.1_wks/40000.pth'))

<All keys matched successfully>

In [5]:
from tqdm import tqdm
import my_code.sign_canonicalization.training as sign_training


def test_sign_correction(net, test_dataset):
    
    tqdm._instances.clear()

    n_epochs = 5
        
    # iterator = tqdm(total=len(test_dataset) * n_epochs)
    incorrect_signs_list = torch.tensor([])
    curr_iter = 0

        
    for _ in range(n_epochs):
        for curr_idx in range(len(test_dataset)):

            ##############################################
            # Select a shape
            ##############################################

            train_shape = test_dataset[curr_idx]['second']

            # train_shape = double_shape['second']
            verts = train_shape['verts'].unsqueeze(0).to(device)
            faces = train_shape['faces'].unsqueeze(0).to(device)    

            evecs_orig = train_shape['evecs'].unsqueeze(0)[:, :, start_dim:start_dim+feature_dim].to(device)
            
            mass_mat = torch.diag_embed(
                torch.ones_like(train_shape['mass'].unsqueeze(0))
                ).to(device)

            ##############################################
            # Set the signs on shape 0
            ##############################################

            # create a random combilation of +1 and -1, length = feature_dim
            sign_gt_0 = torch.randint(0, 2, (feature_dim,)).float().to(device)
            
            sign_gt_0[sign_gt_0 == 0] = -1
            sign_gt_0 = sign_gt_0.float().unsqueeze(0)

            # multiply evecs [6890 x 16] by sign_flip [16]
            evecs_flip_0 = evecs_orig * sign_gt_0
            
            # predict the sign change
            with torch.no_grad():
                sign_pred_0, supp_vec_0, _ = sign_training.predict_sign_change(
                    net, verts, faces, evecs_flip_0, 
                    mass_mat=mass_mat, input_type=net.input_type,
                    
                    mass=train_shape['mass'].unsqueeze(0), L=train_shape['L'].unsqueeze(0),
                    evals=train_shape['evals'].unsqueeze(0), evecs=train_shape['evecs'].unsqueeze(0),
                    gradX=train_shape['gradX'].unsqueeze(0), gradY=train_shape['gradY'].unsqueeze(0)
                    )
            
            ##############################################
            # Set the signs on shape 1
            ##############################################
            
            # create a random combilation of +1 and -1, length = feature_dim
            sign_gt_1 = torch.randint(0, 2, (feature_dim,)).float().to(device)
            
            sign_gt_1[sign_gt_1 == 0] = -1
            sign_gt_1 = sign_gt_1.float().unsqueeze(0)
            
            # multiply evecs [6890 x 16] by sign_flip [16]
            evecs_flip_1 = evecs_orig * sign_gt_1
            
            # predict the sign change
            with torch.no_grad():
                sign_pred_1, supp_vec_1, _ = sign_training.predict_sign_change(
                    net, verts, faces, evecs_flip_1, 
                    mass_mat=mass_mat, input_type=net.input_type,
                    
                    mass=train_shape['mass'].unsqueeze(0), L=train_shape['L'].unsqueeze(0),
                    evals=train_shape['evals'].unsqueeze(0), evecs=train_shape['evecs'].unsqueeze(0),
                    gradX=train_shape['gradX'].unsqueeze(0), gradY=train_shape['gradY'].unsqueeze(0)
                    )
            
            ##############################################
            # Calculate the loss
            ##############################################
            
            # calculate the ground truth sign difference
            sign_diff_gt = sign_gt_1 * sign_gt_0
            
            # calculate the sign difference between predicted evecs
            sign_diff_pred = sign_pred_1 * sign_pred_0
            
            sign_correct = sign_diff_pred.sign() * sign_diff_gt.sign() 
            
            
            # count the number of incorrect signs
            count_incorrect_signs = (sign_correct < 0).int().sum()
                
            # incorrect_signs_list.append(count_incorrect_signs)
            incorrect_signs_list = torch.cat([incorrect_signs_list, torch.tensor([count_incorrect_signs])])
            
            
            # iterator.set_description(f'Mean incorrect signs {incorrect_signs_list.float().mean():.2f} / {feature_dim}, max {incorrect_signs_list.max()}')
            # iterator.update(1)
            # if count_incorrect_signs > 7:
            #     raise ValueError('Too many incorrect signs')
        
    return incorrect_signs_list.float().mean(), incorrect_signs_list.max()

# Train

In [6]:
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(
    opt, start_factor=1, end_factor=0.1, 
    total_iters=1000)

In [7]:
from tqdm import tqdm
import my_code.sign_canonicalization.training as sign_training
import pandas as pd


tqdm._instances.clear()

loss_fn = torch.nn.MSELoss()
losses = torch.tensor([])
# train_iterator = tqdm(range(1000))
train_iterator = range(1000)     
        
curr_iter = 0
for epoch in range(len(train_iterator) // len(train_dataset)):
    
    # train_shapes_shuffled = train_shapes.copy()
    # np.random.shuffle(train_shapes)
    
    
    for curr_idx in range(len(train_dataset)):

        ##############################################
        # Select a shape
        ##############################################
        # curr_idx = np.random.randint(0, len(train_shapes))
    
        train_shape = train_dataset[curr_idx]['second']

        # train_shape = double_shape['second']
        verts = train_shape['verts'].unsqueeze(0).to(device)
        faces = train_shape['faces'].unsqueeze(0).to(device)    

        evecs_orig = train_shape['evecs'].unsqueeze(0)[:, :, start_dim:start_dim+feature_dim].to(device)
        
        mass_mat = torch.diag_embed(
            torch.ones_like(train_shape['mass'].unsqueeze(0))
            ).to(device)

        ##############################################
        # Set the signs on shape 0
        ##############################################

        # create a random combilation of +1 and -1, length = feature_dim
        sign_gt_0 = torch.randint(0, 2, (feature_dim,)).float().to(device)
        
        sign_gt_0[sign_gt_0 == 0] = -1
        sign_gt_0 = sign_gt_0.float().unsqueeze(0)

        # multiply evecs [6890 x 16] by sign_flip [16]
        evecs_flip_0 = evecs_orig * sign_gt_0
        
        # predict the sign change
        sign_pred_0 = sign_training.predict_sign_change(
            net, verts, faces, evecs_flip_0, 
            mass_mat=mass_mat, input_type=net.input_type,
            
            mass=train_shape['mass'].unsqueeze(0), L=train_shape['L'].unsqueeze(0),
            evals=train_shape['evals'].unsqueeze(0), evecs=train_shape['evecs'].unsqueeze(0),
            gradX=train_shape['gradX'].unsqueeze(0), gradY=train_shape['gradY'].unsqueeze(0)
            )[0]
        
        ##############################################
        # Set the signs on shape 1
        ##############################################
        
        # create a random combilation of +1 and -1, length = feature_dim
        sign_gt_1 = torch.randint(0, 2, (feature_dim,)).float().to(device)
        
        sign_gt_1[sign_gt_1 == 0] = -1
        sign_gt_1 = sign_gt_1.float().unsqueeze(0)
        
        # multiply evecs [6890 x 16] by sign_flip [16]
        evecs_flip_1 = evecs_orig * sign_gt_1
        
        # predict the sign change
        sign_pred_1 = sign_training.predict_sign_change(
            net, verts, faces, evecs_flip_1, 
            mass_mat=mass_mat, input_type=net.input_type,
            
            mass=train_shape['mass'].unsqueeze(0), L=train_shape['L'].unsqueeze(0),
            evals=train_shape['evals'].unsqueeze(0), evecs=train_shape['evecs'].unsqueeze(0),
            gradX=train_shape['gradX'].unsqueeze(0), gradY=train_shape['gradY'].unsqueeze(0)
            )[0]
        
        ##############################################
        # Calculate the loss
        ##############################################
        
        # calculate the ground truth sign difference
        sign_diff_gt = sign_gt_1 * sign_gt_0
        
        # calculate the sign difference between predicted evecs
        sign_diff_pred = sign_pred_1 * sign_pred_0
        
        # calculate the loss
        loss = loss_fn(
            sign_diff_pred.reshape(sign_diff_pred.shape[0], -1),
            sign_diff_gt.reshape(sign_diff_gt.shape[0], -1)
            )

        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()
        
        losses = torch.cat([losses, torch.tensor([loss.item()])])
        
        # print mean of last 10 losses
        # train_iterator.set_description(f'loss={torch.mean(losses[-10:]):.3f}')
        
        # plot the losses every 1000 iterations
        if curr_iter == 0 or curr_iter % (len(train_iterator) // 10) == 0:
            
            print(f'{curr_iter}')
            
            for test_dataset_name in test_datasets.keys():
                mean_incorrect, max_incorrect = test_sign_correction(net, test_datasets[test_dataset_name])
                
                print(f'{test_dataset_name}: mean = {mean_incorrect:.2f} / {feature_dim}, max = {int(max_incorrect)}')
            
            # mean_incorrect, max_incorrect = test_sign_correction(net, train_dataset)
            
            # print(f'{curr_iter}: mean = {mean_incorrect:.2f} / {feature_dim}, max = {int(max_incorrect)}')
            
        curr_iter += 1
        # train_iterator.update(1)
        
        

0
FAUST_a: mean = 1.17 / 32, max = 5
FAUST_orig train: mean = 0.25 / 32, max = 5
FAUST_orig test: mean = 0.42 / 32, max = 3
FAUST_r train: mean = 0.59 / 32, max = 6
FAUST_r test: mean = 0.89 / 32, max = 9
100
FAUST_a: mean = 0.22 / 32, max = 3
FAUST_orig train: mean = 0.63 / 32, max = 5
FAUST_orig test: mean = 0.70 / 32, max = 5
FAUST_r train: mean = 0.83 / 32, max = 7
FAUST_r test: mean = 0.99 / 32, max = 6
200
FAUST_a: mean = 0.09 / 32, max = 2
FAUST_orig train: mean = 0.64 / 32, max = 7
FAUST_orig test: mean = 0.65 / 32, max = 3
FAUST_r train: mean = 0.97 / 32, max = 8
FAUST_r test: mean = 1.01 / 32, max = 9
300
FAUST_a: mean = 0.21 / 32, max = 2
FAUST_orig train: mean = 0.61 / 32, max = 5
FAUST_orig test: mean = 0.62 / 32, max = 7
FAUST_r train: mean = 0.72 / 32, max = 7
FAUST_r test: mean = 0.91 / 32, max = 6
400
FAUST_a: mean = 0.07 / 32, max = 1
FAUST_orig train: mean = 0.62 / 32, max = 5
FAUST_orig test: mean = 0.73 / 32, max = 4
FAUST_r train: mean = 0.70 / 32, max = 6
FAUST_r