In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch
import trimesh
from functools import partial
import ssl
import pickle
from Solvers import SolverEmbedding,Loss
from Models.EncoderModels import TextEncoder, ShapeEncoder
from config import cfg
from dataEmbedding.dataEmbedding import Read_Load_BuildBatch
from dataEmbedding.dataEmbeddingLoader import GenerateDataLoader,check_dataset,collate_embedding
from dataEmbedding.generateEmbedding import build_embeedings_CWGAN
import os

# PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK from the process *environment*; the
# previous bare assignment `PYTORCH_ENABLE_MPS_FALLBACK=1` only created an unused
# Python variable. NOTE(review): torch may read this flag when it initialises its
# MPS backend, so setting it after `import torch` can be too late for this
# process — to be safe, move this line above `import torch` or export the
# variable before launching Jupyter.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")


def _autodetect_device():
    """Return the best available torch device, preferring MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# cfg.DEVICE acts as an explicit override. Previously the autodetected device was
# computed and then unconditionally overwritten by cfg.DEVICE, so detection was
# dead code; now detection is used only when cfg.DEVICE is unset/empty.
device = getattr(cfg, "DEVICE", None) or _autodetect_device()
print(device)

# for mac os fix
# NOTE(review): this disables TLS certificate verification for every HTTPS
# connection in this process that uses the stdlib default context (e.g. urllib,
# http.client). It is a security risk; presumably it works around a missing CA
# bundle on macOS — TODO confirm, then install certificates (e.g. certifi or the
# python.org "Install Certificates.command") and remove this line.
ssl._create_default_https_context = ssl._create_unverified_context

cpu


In [2]:
# Load the dataset stored under cfg.EMBEDDING_VOXEL_FOLDER via Read_Load_BuildBatch,
# using the configured embedding batch size.
# Disabled alternative (primitives dataset):
#   Read_Load_BuildBatch('/primitives', batchSize, 'primitives')
stanData = Read_Load_BuildBatch(
    cfg.EMBEDDING_VOXEL_FOLDER,
    cfg.EMBEDDING_BATCH_SIZE,
)


In [3]:
# Summary statistics of stanData.wordlens (presumably per-caption token counts —
# TODO confirm). The list was previously sorted in place, which silently reordered
# the dataset object's attribute and made this cell non-idempotent;
# statistics.median sorts its own copy instead. An empty list is reported rather
# than raising IndexError / ZeroDivisionError.
import statistics

word_lengths = list(stanData.wordlens)
if word_lengths:
    # float() preserves the previous output format: the old "(a + b) / 2" formula
    # always yielded a float (e.g. "14.0"), even for odd-length lists.
    print("Median of list is : " + str(float(statistics.median(word_lengths))))
    print("Mean: " + str(sum(word_lengths) / len(word_lengths)))
else:
    print("No word lengths to summarise.")

Median of list is : 14.0
Mean: 16.307937873199133


In [4]:
# Wrap each aggregated split in a GenerateDataLoader dataset sharing the same
# data directory and word->index vocabulary.
train_dataset = GenerateDataLoader(stanData.data_agg_train, stanData.data_dir, stanData.dict_word2idx)
val_dataset = GenerateDataLoader(stanData.data_agg_val, stanData.data_dir, stanData.dict_word2idx)
test_dataset = GenerateDataLoader(stanData.data_agg_test, stanData.data_dir, stanData.dict_word2idx)

# DataLoader batches hold 2 x cfg.EMBEDDING_BATCH_SIZE items; the reason for the
# factor of two is not visible here (TODO confirm — the Solver separately
# receives cfg.EMBEDDING_BATCH_SIZE).
loader_batch_size = cfg.EMBEDDING_BATCH_SIZE * 2

# Only the training loader may drop its last incomplete batch (as decided by
# check_dataset); the test loader loads in the main process (no workers).
dataloader = {
    'train': DataLoader(
        train_dataset,
        batch_size=loader_batch_size,
        drop_last=check_dataset(train_dataset, loader_batch_size),
        collate_fn=collate_embedding,
        num_workers=4,
    ),
    'val': DataLoader(
        val_dataset,
        batch_size=loader_batch_size,
        collate_fn=collate_embedding,
        num_workers=4,
    ),
    'test': DataLoader(
        test_dataset,
        batch_size=loader_batch_size,
        collate_fn=collate_embedding,
    ),
}


In [5]:
# Loss modules handed to the Solver, keyed by role. Walker and visit losses are
# constructed with the selected device; the metric loss takes no arguments here.
# Alternatives for 'metric' that were commented out in this cell:
#   Loss.SmoothedMetricLoss(device=device), Loss.InstanceMetricLoss(), Loss.NPairLoss()
criterion = dict(
    walker=Loss.RoundTripLoss(device=device),
    visit=Loss.AssociationLoss(device=device),
    metric=Loss.TripletLoss(),
)

In [6]:
# Instantiate both encoders directly on `device` (nn.Module.to returns the module
# itself). The text encoder is sized by the dataset's word->index vocabulary.
ShapeModel = ShapeEncoder().to(device)
TextModel = TextEncoder(len(stanData.dict_word2idx)).to(device)

# One Adam optimiser updates the shape and text encoder parameters jointly.
joint_params = [*ShapeModel.parameters(), *TextModel.parameters()]
optimizer = torch.optim.Adam(
    joint_params,
    lr=cfg.EMBEDDING_LR,
    weight_decay=cfg.EMBEDDING_WEIGHT_DC,
)

# Despite its name, `history` is the Solver object that drives training
# (the following cell calls history.train(...)).
history = SolverEmbedding.Solver(
    TextModel,
    ShapeModel,
    dataloader,
    optimizer,
    criterion,
    cfg.EMBEDDING_BATCH_SIZE,
    device,
)

In [7]:
# Run cfg.EMBEDDING_EPOCH_NR training epochs. The index->word vocabulary is also
# passed in; how Solver.train uses it is internal (presumably for decoding
# captions during evaluation — TODO confirm).
history.train(cfg.EMBEDDING_EPOCH_NR, stanData.dict_idx2word)

Epoch [1/4] starting...

Training...


0it [00:00, ?it/s]

Validating...



0it [00:00, ?it/s]

epoch [1/4] done...
------------------------summary------------------------
[train] total_loss: 0.971750
[val]   total_loss: 0.820018
[train] metric_loss_tt: 0.971750
[val]  metric_loss_tt: 0.820018
Evaluating...


0it [00:00, ?it/s]


Number of embedding: 31861
Number of shape embeddings: 1463, number of text embeddings: 30398
Dimensionality of embedding: 128
Evaluation mode: t2t

min embedding val: -0.2281501293182373
max embedding val: 0.2419724464416504
mean embedding (abs) val: 0.04498983075216001
Number of models : 1463

Example model IDs:
c07c9ca0cfbb531359c956f09c934d51
f9d9ef770e04c5772b3242897b354191
74ade89963828a37d94ed55f750426f
40e2ccbc74d0aae3b398a1cfd1079875
f37348b116d83408febad4f49b26ec52
5141810a02a145ad55f46d55537192b6
b94ea1b7a715f5052b151d8b52c53b90
795d4213e1dac276f9814818e8ac1c35
511e6440fad9bfa81fc8b86678ea0c8b
1fcc1a3a879b2a037d43e094da89ace

Using unnormalized cosine distance
Nearest neighbors on block 1
Nearest neighbors on block 2
Nearest neighbors on block 3
Nearest neighbors on block 4
Nearest neighbors on block 5
Nearest neighbors on block 6
Nearest neighbors on block 7
Nearest neighbors on block 8
Nearest neighbors on block 9
Nearest neighbors on block 10
Nearest neighbors on block 1

0it [00:00, ?it/s]

Validating...



0it [00:00, ?it/s]

epoch [2/4] done...
------------------------summary------------------------
[train] total_loss: 0.789786
[val]   total_loss: 0.786474
[train] metric_loss_tt: 0.789786
[val]  metric_loss_tt: 0.786474
Evaluating...


0it [00:00, ?it/s]


Number of embedding: 31861
Number of shape embeddings: 1463, number of text embeddings: 30398
Dimensionality of embedding: 128
Evaluation mode: t2t

min embedding val: -0.24237410724163055
max embedding val: 0.2649449110031128
mean embedding (abs) val: 0.03803766831673353
Number of models : 1463

Example model IDs:
c07c9ca0cfbb531359c956f09c934d51
f9d9ef770e04c5772b3242897b354191
74ade89963828a37d94ed55f750426f
40e2ccbc74d0aae3b398a1cfd1079875
f37348b116d83408febad4f49b26ec52
5141810a02a145ad55f46d55537192b6
b94ea1b7a715f5052b151d8b52c53b90
795d4213e1dac276f9814818e8ac1c35
511e6440fad9bfa81fc8b86678ea0c8b
1fcc1a3a879b2a037d43e094da89ace

Using unnormalized cosine distance
Nearest neighbors on block 1
Nearest neighbors on block 2
Nearest neighbors on block 3
Nearest neighbors on block 4
Nearest neighbors on block 5
Nearest neighbors on block 6
Nearest neighbors on block 7
Nearest neighbors on block 8
Nearest neighbors on block 9
Nearest neighbors on block 10
Nearest neighbors on block 

0it [00:00, ?it/s]

Validating...



0it [00:00, ?it/s]

epoch [3/4] done...
------------------------summary------------------------
[train] total_loss: 0.777568
[val]   total_loss: 0.788385
[train] metric_loss_tt: 0.777568
[val]  metric_loss_tt: 0.788385
Evaluating...


0it [00:00, ?it/s]


Number of embedding: 31861
Number of shape embeddings: 1463, number of text embeddings: 30398
Dimensionality of embedding: 128
Evaluation mode: t2t

min embedding val: -0.28095173835754395
max embedding val: 0.2969523072242737
mean embedding (abs) val: 0.062007955682598696
Number of models : 1463

Example model IDs:
c07c9ca0cfbb531359c956f09c934d51
f9d9ef770e04c5772b3242897b354191
74ade89963828a37d94ed55f750426f
40e2ccbc74d0aae3b398a1cfd1079875
f37348b116d83408febad4f49b26ec52
5141810a02a145ad55f46d55537192b6
b94ea1b7a715f5052b151d8b52c53b90
795d4213e1dac276f9814818e8ac1c35
511e6440fad9bfa81fc8b86678ea0c8b
1fcc1a3a879b2a037d43e094da89ace

Using unnormalized cosine distance
Nearest neighbors on block 1
Nearest neighbors on block 2
Nearest neighbors on block 3
Nearest neighbors on block 4
Nearest neighbors on block 5
Nearest neighbors on block 6
Nearest neighbors on block 7
Nearest neighbors on block 8
Nearest neighbors on block 9
Nearest neighbors on block 10
Nearest neighbors on block

0it [00:00, ?it/s]

Validating...



0it [00:00, ?it/s]

epoch [4/4] done...
------------------------summary------------------------
[train] total_loss: 0.773256
[val]   total_loss: 0.786554
[train] metric_loss_tt: 0.773256
[val]  metric_loss_tt: 0.786554
Evaluating...


0it [00:00, ?it/s]


Number of embedding: 31861
Number of shape embeddings: 1463, number of text embeddings: 30398
Dimensionality of embedding: 128
Evaluation mode: t2t

min embedding val: -0.32634779810905457
max embedding val: 0.4207761287689209
mean embedding (abs) val: 0.059469626737170905
Number of models : 1463

Example model IDs:
c07c9ca0cfbb531359c956f09c934d51
f9d9ef770e04c5772b3242897b354191
74ade89963828a37d94ed55f750426f
40e2ccbc74d0aae3b398a1cfd1079875
f37348b116d83408febad4f49b26ec52
5141810a02a145ad55f46d55537192b6
b94ea1b7a715f5052b151d8b52c53b90
795d4213e1dac276f9814818e8ac1c35
511e6440fad9bfa81fc8b86678ea0c8b
1fcc1a3a879b2a037d43e094da89ace

Using unnormalized cosine distance
Nearest neighbors on block 1
Nearest neighbors on block 2
Nearest neighbors on block 3
Nearest neighbors on block 4
Nearest neighbors on block 5
Nearest neighbors on block 6
Nearest neighbors on block 7
Nearest neighbors on block 8
Nearest neighbors on block 9
Nearest neighbors on block 10
Nearest neighbors on block

In [21]:
# Rebuild per-split datasets/loaders and run build_embeedings_CWGAN over them with a
# freshly constructed TextEncoder, saving results to cfg.EMBEDDING_SAVE_PATH.
# NOTE(review): this cell reads stanData.train/test/val, while the training cell
# used stanData.data_agg_train/val/test — confirm which view is intended here.
# NOTE(review): `loader` holds datasets (not DataLoaders), and reassigning
# `dataloader` below replaces the training loaders from the earlier cell; re-running
# the Solver cell after this one would silently train on these loaders instead.
loader = {
    'train': GenerateDataLoader(stanData.train, stanData.data_dir, stanData.dict_word2idx),
    'test': GenerateDataLoader(stanData.test, stanData.data_dir, stanData.dict_word2idx),
    'val': GenerateDataLoader(stanData.val, stanData.data_dir, stanData.dict_word2idx),
}

export_batch_size = cfg.EMBEDDING_BATCH_SIZE * 2
dataloader = {
    'train': DataLoader(
        loader['train'],
        batch_size=export_batch_size,
        # NOTE(review): when check_dataset returns True, DataLoader drops the last
        # incomplete batch, so those training samples get no exported embedding.
        # Confirm that is acceptable for export (it is usual for training only).
        drop_last=check_dataset(loader['train'], export_batch_size),
        collate_fn=collate_embedding,
        num_workers=4,
    ),
    'val': DataLoader(
        loader['val'],
        batch_size=export_batch_size,
        collate_fn=collate_embedding,
        num_workers=4,
    ),
    'test': DataLoader(
        loader['test'],
        batch_size=export_batch_size,
        collate_fn=collate_embedding,
    ),
}

# Pass the notebook-wide `device` (the one training used) instead of re-reading
# cfg.DEVICE, so any device selection made in the setup cell applies here too.
build_embeedings_CWGAN(
    cfg.EMBEDDING_TEXT_MODELS_PATH,
    TextEncoder(len(stanData.dict_word2idx)),
    dataloader,
    cfg.EMBEDDING_SAVE_PATH,
    device,
)


100%|██████████| 233/233 [00:52<00:00,  4.44it/s]
100%|██████████| 30/30 [00:03<00:00,  9.25it/s]
100%|██████████| 30/30 [00:27<00:00,  1.08it/s]
