In [6]:
%load_ext autoreload
%autoreload 2

import os
import torch
import torch.optim as optim
import torch.nn.functional as F
import json
import numpy as np


from datasets.fusiongallery import FusionGalleryDataset
from datasets.mfcad import MFCADDataset
from datasets.mfcad_extended import MFCADPDataset
from datasets.mftest import MFTestDataset

from uvnet.models import UVNetSegmenter

from evaluation.jaccard import get_mf_jaccard
from tqdm import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
class AttrDict(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
    
args = AttrDict({})
args.batch_size = 512
args.random_rotate = True
args.num_workers = 100
args.dataset = "mfcadps"
args.crv_in_channels = 6
args.max_epochs = 10

# args.dataset_path = '/home/egor/data/machining_features_sprint_1/'
# args.dataset_path = '/home/egor/data/mftest20/'


args.dataset_path = "/home/egor/data/MFCAD++_dataset/converted_10"
if args.dataset == "mfcad":
    Dataset = MFCADDataset
elif args.dataset == "fusiongallery":
    Dataset = FusionGalleryDataset
elif args.dataset == "mfcadp":
    Dataset = MFCADPDataset

In [8]:
fnm_list = os.listdir('/home/egor/data/SolidLetters/graph_with_eattr')
N_ITEMS_PER_CLASS = 100

# creating loaders for SolidLetters dataset quering
test_loaders = []
for case in ('lower', 'upper'):
    ncl, fnm_labels = sample_from_letter(fnm_list, N_ITEMS_PER_CLASS, case)
    dset = RankingDataset('/home/egor/data/SolidLetters/graph_with_eattr/', 
                           fnm_labels, 
                           ncl)
    test_loaders.append(dset.get_dataloader(batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers))

  4%|▍         | 99/2600 [00:00<00:02, 985.59it/s]

/home/egor/data/SolidLetters/graph_with_eattr/a_Overpass Mono SemiBold_lower.bin True


100%|██████████| 2600/2600 [00:02<00:00, 1019.92it/s]
  0%|          | 0/2600 [00:00<?, ?it/s]

Done loading 2600 files
/home/egor/data/SolidLetters/graph_with_eattr/a_Inconsolata_upper.bin True


100%|██████████| 2600/2600 [00:02<00:00, 1050.89it/s]


Done loading 2600 files


In [9]:
class Segmentation:
    """
    module to train/test the segmenter (per-face classifier).
    """

    def __init__(self, num_classes, crv_in_channels=6):
        """
        Args:
            num_classes (int): Number of per-face classes in the dataset
        """
        self.device = torch.device('cuda:0') 
        
        self.model = UVNetSegmenter(num_classes, crv_in_channels=crv_in_channels)   
        self.model = self.model.to(device = self.device)
        
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay= 0.0007)
        

    def forward(self, batched_graph):
        logits = self.model(batched_graph)
        return logits

    def training_step(self, batch, batch_idx):
        inputs = batch["graph"].to(self.device)
        inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2)
        inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1)
        labels = inputs.ndata["y"]
        logits = self.model(inputs)
        loss = F.cross_entropy(logits, labels, reduction="mean")
        # self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        preds = F.softmax(logits, dim=-1)
        # self.train_iou(preds, labels)
        # self.train_accuracy(preds, labels)
        return loss
    
    @torch.no_grad()
    def validation_step(self, batch, valid_preds):
        inputs = batch["graph"].to(self.device)
        inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2)
        inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1)
        logits = self.model(inputs)
        
        softmax = F.softmax(logits, dim=-1)  
        preds = softmax.argmax(dim=1).cpu().numpy().tolist()
        valid_preds[batch["filename"]] = preds


    def train(self, data_loader):
        self.model.train()
    
        for batch_idx, batch  in tqdm(enumerate(data_loader), desc="Training", total=len(data_loader)):  
            self.optimizer.zero_grad()
            
            loss = self.training_step(batch, batch_idx)

            loss.backward()
            self.optimizer.step()    
        print(loss.item())

    @torch.no_grad()
    def valid(self, dataset):  
        self.model.eval()  
        valid_preds = dict()
        for batch_idx, batch  in tqdm(enumerate(dataset), desc="Validation", total=len(dataset)):  
            self.validation_step(batch, valid_preds)
        
        jaccards = []
        for sample in dataset:
            flnm = sample['filename']
            if flnm in valid_preds:
                jaccards.append(get_mf_jaccard(sample=sample, 
                                           labels=valid_preds[flnm]))
        print('jaccared', np.mean(jaccards))
        return np.mean(jaccards)

In [10]:
train_data = Dataset(
    root_dir=args.dataset_path, split="train", random_rotate=args.random_rotate
)

  1%|          | 284/41730 [00:00<00:14, 2836.76it/s]

Loading train data...


100%|██████████| 41730/41730 [00:14<00:00, 2810.35it/s]


Done loading 41730 files


In [11]:
with open("/home/egor/data/janush_dataset/non_duplicated_filenames.json", 'r') as f:
    allowed = set(json.load(f))

val_data = MFTestDataset(
    root_dir="/home/egor/data/janush_dataset/converted_10/", split="test",  random_rotate=False, allow_list=allowed
)

100%|██████████| 469/469 [01:40<00:00,  4.69it/s]


Done loading 469 files


In [12]:
model = Segmentation(
    num_classes=Dataset.num_classes(), 
    crv_in_channels=args.crv_in_channels
)

In [13]:
train_loader = train_data.get_dataloader(
    batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
)
val_loader = val_data.get_dataloader(
    batch_size=1, shuffle=False, num_workers=1
)

In [None]:
jaccards = {}
for epoch in range(1, args.max_epochs + 1):
    print(f"epoch {epoch}")
    model.train(train_loader)
    
    #map_metr = cals_map_all(test_loaders, model.model, model.device)
    #print('map@k ', map_metr)
    jaccards[epoch] = model.valid(val_data)

epoch 1


Training: 100%|██████████| 81/81 [01:31<00:00,  1.13s/it]
Validation:   4%|▎         | 17/469 [00:00<00:02, 164.61it/s]

0.07699595391750336


Validation: 100%|██████████| 469/469 [00:03<00:00, 149.63it/s]


jaccared 0.508222300710451
epoch 2


Training: 100%|██████████| 81/81 [01:31<00:00,  1.13s/it]
Validation:   3%|▎         | 15/469 [00:00<00:03, 149.19it/s]

0.07329627871513367


Validation: 100%|██████████| 469/469 [00:03<00:00, 140.40it/s]


jaccared 0.5273431388133738
epoch 3


Training: 100%|██████████| 81/81 [01:31<00:00,  1.13s/it]
Validation:   4%|▎         | 17/469 [00:00<00:02, 162.69it/s]

0.0722552090883255


Validation: 100%|██████████| 469/469 [00:03<00:00, 146.51it/s]


jaccared 0.5181583892940564
epoch 4


Training: 100%|██████████| 81/81 [01:31<00:00,  1.13s/it]
Validation:   3%|▎         | 16/469 [00:00<00:02, 155.40it/s]

0.06665526330471039


Validation: 100%|██████████| 469/469 [00:03<00:00, 139.89it/s]


jaccared 0.5197461007093894
epoch 5


Training: 100%|██████████| 81/81 [01:31<00:00,  1.13s/it]
Validation:   3%|▎         | 15/469 [00:00<00:03, 147.67it/s]

0.07357190549373627


Validation: 100%|██████████| 469/469 [00:03<00:00, 137.67it/s]


jaccared 0.5146356545490425
epoch 6


Training: 100%|██████████| 81/81 [01:31<00:00,  1.13s/it]
Validation:   3%|▎         | 16/469 [00:00<00:02, 157.23it/s]

0.05732521414756775


Validation: 100%|██████████| 469/469 [00:03<00:00, 145.01it/s]


jaccared 0.5189247198054591
epoch 7


Training: 100%|██████████| 81/81 [01:31<00:00,  1.13s/it]
Validation:   3%|▎         | 15/469 [00:00<00:03, 149.81it/s]

0.06131221354007721


Validation: 100%|██████████| 469/469 [00:03<00:00, 142.30it/s]


jaccared 0.565045913071092
epoch 8


Training: 100%|██████████| 81/81 [01:31<00:00,  1.13s/it]
Validation:   4%|▎         | 17/469 [00:00<00:02, 163.63it/s]

0.05702909454703331


Validation: 100%|██████████| 469/469 [00:03<00:00, 147.48it/s]


jaccared 0.5129363038937852
epoch 9


Training: 100%|██████████| 81/81 [01:31<00:00,  1.13s/it]
Validation:   4%|▎         | 17/469 [00:00<00:02, 164.99it/s]

0.05182715877890587


Validation: 100%|██████████| 469/469 [00:03<00:00, 149.64it/s]


jaccared 0.4955406058920413
epoch 10


In [16]:
model.valid(val_data)

Validation: 100%|██████████| 469/469 [00:03<00:00, 140.80it/s]


jaccared 0.5681343692312089


0.5681343692312089