In [1]:
%load_ext autoreload
%autoreload 2

import os
import json

import argparse
import pathlib
import time
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
import torch.nn.functional as F
    
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything

from datasets.fusiongallery import FusionGalleryDataset
from datasets.mfcad import MFCADDataset
from datasets.mfcad_extended import MFCADPDataset
from datasets.mftest import MFTestDataset

from uvnet.models import Segmentation

class AttrDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes.

    ``cfg.key`` is equivalent to ``cfg["key"]`` for reading, assignment and
    deletion.  Attribute reads only fall back to the stored items when normal
    attribute lookup fails, so regular ``dict`` methods such as ``items`` keep
    working even if a key of the same name is stored.

    Missing keys raise ``AttributeError`` (not ``KeyError``) when accessed as
    attributes.  This is what ``hasattr``, ``getattr(obj, name, default)`` and
    ``copy.deepcopy`` rely on when they probe for optional attributes.
    """

    def __getattr__(self, name):
        # Only called when regular attribute lookup has already failed.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    # Attribute assignment stores the value as a dictionary item.
    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None

Using backend: pytorch


python segmentation.py train --dataset mfcad --dataset_path /home/egor/data/mfcad30 --max_epochs 5 --batch_size 256 --gpus 1 --num_processes 220 --experiment_name segmentation

In [6]:
# Inference settings, gathered in a single attribute-style dictionary.
args = AttrDict(
    batch_size=1,
    # Alternative dataset: '/home/egor/data/machining_features_sprint_1/'
    dataset_path='/home/egor/data/janush_dataset/converted_20/',
    # Previously tried checkpoints:
    #   '/home/egor/UV-Net/results/seg_30sam_10ep/0329/054617/best.ckpt'
    #   '/home/egor/UV-Net/results/segmentation/0329/052330/best.ckpt'
    #   '/home/egor/UV-Net/results/segmentation/0328/210148/best.ckpt'
    #   '/home/egor/UV-Net/results/segmentation/0319/120241/best.ckpt'
    checkpoint='/home/egor/UV-Net/results/seg_new_data_20sam_20ep/0401/122015/best.ckpt',
    random_rotate=False,
    num_workers=30,
)

# GPU on which the model is placed and inference is run.
device = torch.device('cuda:2')

In [7]:
# Dataset class used for inference; the "test" split is read from args.dataset_path.
Dataset = MFTestDataset
test_data = Dataset(root_dir=args.dataset_path, split="test", random_rotate=args.random_rotate)

# One sample per batch, in dataset order.
# NOTE(review): batch_size and num_workers are hard-coded here; args.batch_size
# and args.num_workers are not used by this cell.
test_loader = test_data.get_dataloader(batch_size=1, shuffle=False, num_workers=1)

100%|██████████| 577/577 [00:00<00:00, 596.23it/s]


Done loading 577 files


In [8]:
# Restore the Lightning module from the checkpoint, keep only its inner
# network, move it to the inference device and switch it to eval mode.
model = (
    Segmentation.load_from_checkpoint(args.checkpoint)
    .model.to(device=device)
    .eval()
)

In [9]:
def _predict_labels(model, graph, device):
    """Run ``model`` on one graph and return the arg-max class index per logits row.

    Args:
        model: Callable that takes the prepared graph and returns a logits tensor.
        graph: Graph object exposing ``to(device)``, ``ndata["x"]`` and ``edata["x"]``.
        device: Device the graph is moved to before the forward pass.

    Returns:
        A plain Python list of integer class indices (arg-max over dim 1).
    """
    with torch.no_grad():
        inputs = graph.to(device)
        # Move the last axis of the node features to position 1 and swap the
        # last two axes of the edge features before the forward pass
        # (presumably channels-first for the model's encoders; confirm).
        inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2)
        inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1)

        logits = model(inputs)
        # Softmax is monotonic per row, so arg-max of the raw logits selects
        # the same class; no need to normalise first.
        # Assumes logits are 2-D (rows x classes); TODO confirm.
        return logits.argmax(dim=1).cpu().tolist()


def get_labels(model, dataset, device):
    """Predict labels for every sample obtained by iterating ``dataset`` directly.

    Args:
        model: Callable network; see ``_predict_labels``.
        dataset: Iterable of samples with a ``"filename"`` and a ``"graph"`` entry.
        device: Device on which inference is run.

    Returns:
        A list of ``{"part": filename, "labels": [int, ...]}`` dictionaries,
        one per sample, in iteration order.
    """
    return [
        {
            'part': data["filename"],
            'labels': _predict_labels(model, data["graph"], device),
        }
        for data in dataset
    ]


def get_labels_loader(model, loader, device):
    """Predict labels for every batch produced by ``loader``.

    Only the first filename of each batch is recorded, while labels are
    computed for the whole ``data["graph"]`` of the batch, so the loader is
    expected to yield one sample per batch.

    Args:
        model: Callable network; see ``_predict_labels``.
        loader: Iterable of batches with a ``"filename"`` sequence and a ``"graph"`` entry.
        device: Device on which inference is run.

    Returns:
        A list of ``{"part": filename, "labels": [int, ...]}`` dictionaries,
        one per batch, in iteration order.
    """
    return [
        {
            'part': data["filename"][0],
            'labels': _predict_labels(model, data["graph"], device),
        }
        for data in loader
    ]

In [10]:
# Predict labels twice over the same data: once by iterating the dataset
# directly and once through the DataLoader, so the two paths can be compared.
rzz = get_labels(model, test_data, device)
rzz_loader = get_labels_loader(model, test_loader, device)

In [None]:
# NOTE(review): this displays the entire prediction list; consider a slice
# such as rzz[:3] to keep the notebook output small.
rzz

In [26]:
# First 20 predicted labels of the part at index 50 (dataset path).
# NOTE(review): the recorded run of this cell raised IndexError, i.e. `rzz`
# held fewer than 51 entries at that time; check for stale kernel state.
rzz[50]['labels'][:20]

IndexError: list index out of range

In [27]:
# First 20 predicted labels of the part at index 50 (DataLoader path).
# NOTE(review): the recorded run of this cell raised IndexError, i.e.
# `rzz_loader` held fewer than 51 entries at that time; check for stale state.
rzz_loader[50]['labels'][:20]

IndexError: list index out of range

In [45]:
# Persist the per-part predictions held in `rzz` as JSON in the working directory.
labels_file = 'labels_10ep_model_20samples.json'
with open(labels_file, 'w') as out_file:
    json.dump(rzz, out_file)

In [75]:
# Re-run inference over the dataset; this replaces the earlier contents of `rzz`.
rzz = get_labels(model, test_data, device)


[5, 8, 9, 9, 14, 4, 15, 11, 15, 9, 15, 15, 6, 9, 11, 1, 10, 1, 12, 3]

In [76]:
# First 20 predicted labels of the first part, as produced via the DataLoader.
rzz_loader[0]['labels'][:20]

[5, 7, 9, 4, 4, 9, 9, 9, 15, 9, 15, 15, 6, 9, 14, 1, 2, 12, 1, 2]

In [102]:
# Run the identical dataset inference twice so the two results can be
# compared in the following cells.
rzz1 = get_labels(model, test_data, device)
rzz2 = get_labels(model, test_data, device)

In [103]:
# First 20 predicted labels of the first part from the first run.
rzz1[0]['labels'][:20]

[13, 3, 3, 3, 3, 3, 3, 3, 13, 3, 3, 3, 3, 3, 3, 13, 13, 13, 13, 13]

In [104]:
# First 20 predicted labels of the first part from the second run.
rzz2[0]['labels'][:20]

[13, 3, 3, 3, 3, 3, 3, 3, 13, 3, 3, 3, 3, 3, 3, 13, 13, 13, 13, 13]

In [None]:
# Read the split definition stored next to the MFCAD data.
path = pathlib.Path('/home/egor/mfcad/')
with open(str(path / "split.json"), "r") as read_file:
    filelist = json.load(read_file)

In [None]:
# Create a custom split file that assigns every converted graph
# (<dataset_path>/graphs/*.bin) to the "test" split.
graphs_dir = os.path.join(args.dataset_path, 'graphs')

# splitext removes only the final ".bin", so stems that themselves contain
# dots are kept intact.  Sorting makes the written split independent of the
# arbitrary order in which os.listdir returns directory entries.
test_files = sorted(
    os.path.splitext(name)[0]
    for name in os.listdir(graphs_dir)
    if name.endswith('.bin')
)
split = {'test': test_files}

# os.path.join works whether or not dataset_path ends with a path separator.
with open(os.path.join(args.dataset_path, 'split.json'), 'w') as fp:
    json.dump(split, fp)

In [3]:
# NOTE(review): MFCADPDataset is already imported in the first cell; this
# repeated import is only needed when this cell is run on its own.
from datasets.mfcad_extended import MFCADPDataset

# Load the "train" split of the converted MFCAD++ dataset.
d = MFCADPDataset(root_dir='/home/egor/data/MFCAD++_dataset/converted_20', split="train")


  0%|          | 171/41730 [00:00<00:24, 1709.92it/s]

Loading train data...


100%|██████████| 41730/41730 [00:26<00:00, 1578.88it/s]


Done loading 41730 files


In [9]:
# Load the remaining "test" and "val" splits from the same converted MFCAD++ root.
mfcadpp_root = '/home/egor/data/MFCAD++_dataset/converted_20'
test = MFCADPDataset(root_dir=mfcadpp_root, split="test")
val = MFCADPDataset(root_dir=mfcadpp_root, split="val")

  2%|▏         | 211/8942 [00:00<00:04, 2103.66it/s]

Loading test data...


100%|██████████| 8942/8942 [00:04<00:00, 1983.70it/s]


Done loading 8942 files


  2%|▏         | 207/8941 [00:00<00:04, 2066.69it/s]

Loading val data...


100%|██████████| 8941/8941 [00:04<00:00, 1838.01it/s]


Done loading 8941 files


In [10]:
# Sanity check over the train, test and val splits: every graph must carry a
# non-empty node label tensor "y" with exactly one entry per graph node.
for split_idx, dataset in enumerate([d, test, val]):
    for sample_idx, sample in enumerate(dataset.data):
        graph = sample['graph']
        num_labels = len(graph.ndata['y'])
        # Ask the graph for its node count explicitly instead of relying on
        # len(graph), and report which sample is inconsistent on failure.
        num_nodes = graph.number_of_nodes()
        assert num_labels > 0 and num_labels == num_nodes, (
            f"split #{split_idx}, sample #{sample_idx}: "
            f"{num_labels} labels for {num_nodes} nodes"
        )