In [None]:
# Mount Google Drive so the project checkout and data stored on Drive are
# reachable from this Colab runtime; force_remount re-mounts even if a mount
# already exists.
from google.colab import drive

drive.mount('/content/gdrive', force_remount=True)

# Enter the project repository on Drive and update it to the latest commit.
%cd '/content/gdrive/My Drive/ma_proj'
!git pull

In [None]:
import torch
print('PyTorch version: {}'.format(torch.__version__))

!pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-cluster==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-spline-conv==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
!pip install torch-geometric

%cd ..
%cd '/content/gdrive/My Drive/'
!ls
# overwrite coalesce.py and storage.py in torch_sparse, such that torch.short edge indices are supported to safe memory when loading data
!cp torch_sparse/coalesce.py /usr/local/lib/python3.6/dist-packages/torch_sparse/coalesce.py
!cp torch_sparse/storage.py /usr/local/lib/python3.6/dist-packages/torch_sparse/storage.py
!cp torch_sparse/tensor.py /usr/local/lib/python3.6/dist-packages/torch_sparse/tensor.py
!cp torch_geometric/data/data.py /usr/local/lib/python3.6/dist-packages/torch_geometric/data/data.py

!pip install openmesh

cd mesh

!BOOST_INCLUDE_DIRS=/usr/include/boost make all

%cd ../ma_proj/spiralnet_plus/

In [None]:
import pickle
import argparse
import os
import os.path as osp
from easydict import EasyDict
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch_geometric.transforms as T
from psbody.mesh import Mesh

from reconstruction import AE, VAE, run, eval_error, test
from datasets import MeshData
from utils import utils, writer, DataLoader, mesh_sampling

# Experiment configuration (an EasyDict stands in for argparse in the notebook).
args = EasyDict()
args.exp_name = 'interpolation_exp'
args.dataset = 'CoMA'
args.split = 'interpolation'
args.test_exp = 'bareteeth'
args.n_threads = 4
args.device_idx = 0

# network hyperparameters; seq_length and dilation are looked up once per
# level of the mesh hierarchy when the spiral indices are built below
args.out_channels = [32, 32, 32, 64]
args.latent_channels = 16
args.in_channels = 3
args.seq_length = [9, 9, 9, 9]
args.dilation = [1, 1, 1, 1]

# mesh down-sampling factors; only used when transform.pkl must be generated
args.ds_factors = [4, 4, 4, 4]

# optimizer hyperparameters
args.lr = 1e-3
args.optimizer = 'Adam'
args.lr_decay = 0.99
args.decay_step = 1
args.weight_decay = 0

# training hyperparameters
args.batch_size = 32
args.epochs = 300

# others
args.seed = 1

args.work_dir = '/content/gdrive/My Drive/ma_proj/spiralnet_plus/reconstruction'
args.data_fp = osp.join(args.work_dir, '..', 'data', args.dataset)
args.out_dir = osp.join(args.work_dir, 'out', args.exp_name)
args.checkpoints_dir = osp.join(args.out_dir, 'checkpoints')
print(args)

# spiral settings are indexed per hierarchy level, so their lengths must agree
assert len(args.seq_length) == len(args.dilation), \
    'seq_length and dilation must have one entry per mesh level'

utils.makedirs(args.out_dir)
utils.makedirs(args.checkpoints_dir)

# NOTE: this rebinds `writer` from the utils.writer module to a Writer
# instance. Re-running this whole cell is still safe because the import above
# restores the module before this line runs.
writer = writer.Writer(args)
device = (torch.device('cuda', args.device_idx) if torch.cuda.is_available()
          else torch.device('cpu'))
torch.set_num_threads(args.n_threads)

# deterministic: seed torch and numpy, and disable cuDNN autotuning
torch.manual_seed(args.seed)
np.random.seed(args.seed)
cudnn.benchmark = False
cudnn.deterministic = True

template_fp = osp.join(args.data_fp, 'template', 'template.obj')
print('Template mesh: {}'.format(template_fp))

# generate/load the mesh hierarchy (vertices, faces, down/up-sampling matrices)
transform_fp = osp.join(args.data_fp, 'transform.pkl')
if not osp.exists(transform_fp):
    print('Generating transform matrices...')
    mesh = Mesh(filename=template_fp)
    _, _, D, U, F, V = mesh_sampling.generate_transform_matrices(
        mesh, args.ds_factors)
    mesh_hierarchy = {
        'vertices': V,
        'face': F,
        'down_transform': D,
        'up_transform': U,
    }

    with open(transform_fp, 'wb') as fp:
        pickle.dump(mesh_hierarchy, fp)
    print('Done!')
    print('Transform matrices are saved in \'{}\''.format(transform_fp))
else:
    # NOTE(review): pickle.load can execute arbitrary code — only load a
    # transform.pkl that this project generated itself.
    with open(transform_fp, 'rb') as f:
        mesh_hierarchy = pickle.load(f, encoding='latin1')

# Spiral indices are built for every level except the last one; a cached
# transform.pkl made with different down-sampling settings would not line up
# with seq_length/dilation, so fail early with a clear message.
n_levels = len(mesh_hierarchy['face']) - 1
assert n_levels == len(args.seq_length), (
    '{} describes {} levels but seq_length/dilation have {} entries; delete it '
    'to regenerate'.format(transform_fp, n_levels, len(args.seq_length)))

spiral_indices_list = [
    utils.preprocess_spiral(mesh_hierarchy['face'][idx], args.seq_length[idx],
                            mesh_hierarchy['vertices'][idx],
                            args.dilation[idx]).to(device)
    for idx in range(n_levels)
]
# release the raw mesh data as soon as it has been converted
del mesh_hierarchy['face']
del mesh_hierarchy['vertices']
down_transform_list = [
    utils.to_sparse(down_transform).to(device)
    for down_transform in mesh_hierarchy['down_transform']
]
del mesh_hierarchy['down_transform']
up_transform_list = [
    utils.to_sparse(up_transform).to(device)
    for up_transform in mesh_hierarchy['up_transform']
]
del mesh_hierarchy



In [None]:

# Build the VAE from the mesh hierarchy prepared in the setup cell.
# The transform lists are deliberately NOT deleted afterwards: the model is
# built from them, and deleting the names made this cell fail with NameError
# when re-run.
model = VAE(args.in_channels, args.out_channels, args.latent_channels,
            spiral_indices_list, down_transform_list,
            up_transform_list).to(device)
print('Number of parameters: {}'.format(utils.count_parameters(model)))
print(model)

# Instantiate the optimizer class named in args.optimizer (e.g. 'Adam'), so the
# configuration and the code cannot drift apart.
optimizer_cls = getattr(torch.optim, args.optimizer, None)
if optimizer_cls is None:
    raise ValueError('Unknown optimizer {!r} in args.optimizer'.format(args.optimizer))
optimizer = optimizer_cls(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            args.decay_step,
                                            gamma=args.lr_decay)

# load dataset (template_fp is defined in the setup cell)
print('Creating MeshData obj')
meshdata = MeshData(args.data_fp,
                    template_fp,
                    split=args.split,
                    test_exp=args.test_exp)

# Training is disabled: weights are restored from a checkpoint in the next
# cell. Set TRAIN = True to retrain from scratch.
TRAIN = False
if TRAIN:
    print('creating training DataLoader')
    train_loader = DataLoader(meshdata.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)

print('creating testing DataLoader')
test_loader = DataLoader(meshdata.test_dataset, batch_size=args.batch_size)

if TRAIN:
    run(model, train_loader, test_loader, args.epochs, optimizer, scheduler,
        writer, device)


In [None]:
# Restore the trained model weights from checkpoint_300.pt (300 matches
# args.epochs). Tensors are mapped onto the CPU first; load_state_dict then
# copies them into the model's existing parameters.
ckpt_path = osp.join(args.checkpoints_dir, 'checkpoint_300.pt')
model.load_state_dict(
    torch.load(ckpt_path, map_location=torch.device('cpu'))['model_state_dict'])


In [None]:
# Show the dataset's std and mean (presumably the vertex normalization
# statistics — TODO confirm), one per line, then the test-split loss.
print(meshdata.std, meshdata.mean, sep='\n')
test_loss = test(model, test_loader, device)
print(test_loss)