In [1]:
import argparse
import pathlib
import time
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
    
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything

from datasets.fusiongallery import FusionGalleryDataset
from datasets.mfcad import MFCADDataset
from uvnet.models import Segmentation

class AttrDict(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

Using backend: pytorch


In [2]:
args = AttrDict({})
args.batch_size = 64
args.dataset_path = '/home/egor/mfcad/'
args.checkpoint = '../results/encoding/0311/130828/best.ckpt'
args.random_rotate = False
args.num_workers = 30

device = torch.device('cuda:2')

In [4]:
model = Segmentation.load_from_checkpoint(args.checkpoint).model.to(device = device)

In [5]:
Dataset = MFCADDataset
test_data = Dataset(
        root_dir=args.dataset_path, split="test", random_rotate=args.random_rotate
    )

test_loader = test_data.get_dataloader(
        batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )

  9%|▉         | 293/3094 [00:00<00:00, 2925.84it/s]

Loading test data...


100%|██████████| 3094/3094 [00:01<00:00, 2825.22it/s]


Done loading 3094 files


In [7]:
def encode(model, loader, device):
    embs_list = []
    with torch.no_grad():  
        for batch in loader:
            inputs = batch["graph"].to(device)
            inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2)
            inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1)
            embs_list.append(model.encode_part(inputs).to(device=torch.device('cpu')))
    return embs_list

In [8]:
p_embs = torch.cat(
    encode(model, test_loader, device),
    dim=0)

In [9]:
p_embs.shape

torch.Size([3072, 128])

In [11]:
torch.save(p_embs, '../embs/embs_0.pt')

In [None]:
torch.load('embs_0.pt')

In [None]:
import json

with open('part_names', 'w') as file:
     file.write(json.dumps(p_names))