In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch
import trimesh
from functools import partial
import ssl
from Solvers import SolverEmbedding,Loss
from Models.EncoderModels import TextEncoder, ShapeEncoder
from config import cfg
from dataEmbedding.dataEmbedding import Read_Load_BuildBatch
from dataEmbedding.dataEmbeddingLoader import GenerateDataLoader,check_dataset,collate_embedding
def _autodetect_device():
    """Return the best available torch device: MPS first, then CUDA, then CPU."""
    # `torch.backends.mps` is missing on older torch builds, so probe defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# The project config is authoritative for the device. Hardware auto-detection is
# only a fallback for when the config does not provide one (previously the
# detection result was computed and then unconditionally overwritten).
device = getattr(cfg, "DEVICE", None)
if device is None:
    device = _autodetect_device()
print(device)

# for mac os fix
# SECURITY(review): this disables TLS certificate verification for every HTTPS
# request made through the default ssl context in this process. Kept because the
# notebook relies on it, but prefer installing the system certificates instead.
ssl._create_default_https_context = ssl._create_unverified_context

cpu


In [2]:
# Build the data object from the folder and batch size given in the project config.
# Later cells read its attributes: wordlens, data_agg_train / data_agg_val /
# data_agg_test, data_dir, dict_word2idx and dict_idx2word.
# The commented-out line below is an alternative load of a 'primitives' dataset.
#primiData = Read_Load_BuildBatch('/primitives',batchSize,'primitives')
stanData=Read_Load_BuildBatch(cfg.EMBEDDING_VOXEL_FOLDER,cfg.EMBEDDING_BATCH_SIZE)


In [3]:
# Report the median and mean of the values in stanData.wordlens.
# Work on a copy so the list owned by stanData is not reordered as a side effect
# of running this cell.
caption_lengths = list(stanData.wordlens)
if caption_lengths:
    # np.median averages the two middle values for an even-sized list, which is
    # what the previous hand-written index arithmetic did.
    res = float(np.median(caption_lengths))
    mean_length = sum(caption_lengths) / len(caption_lengths)
    print("Median of list is : " + str(res))
    print("Mean: " + str(mean_length))
else:
    # Nothing to summarise; avoid an IndexError / ZeroDivisionError.
    print("stanData.wordlens is empty; no length statistics to report")

Median of list is : 14.0
Mean: 16.307937873199133


In [3]:
# Every loader uses twice the configured embedding batch size.
# NOTE(review): presumably because samples come in pairs per shape — confirm
# against collate_embedding.
loader_batch_size = cfg.EMBEDDING_BATCH_SIZE * 2

# One dataset wrapper per split; all share the data directory and vocabulary.
train_dataset = GenerateDataLoader(
    stanData.data_agg_train, stanData.data_dir, stanData.dict_word2idx, 'train'
)
val_dataset = GenerateDataLoader(
    stanData.data_agg_val, stanData.data_dir, stanData.dict_word2idx, 'val'
)
test_dataset = GenerateDataLoader(
    stanData.data_agg_test, stanData.data_dir, stanData.dict_word2idx, 'test'
)

# Loaders keyed by split name; all batches are assembled by collate_embedding.
dataloader = {}

# Train: whether the last batch is dropped is decided by check_dataset.
dataloader['train'] = DataLoader(
    train_dataset,
    batch_size=loader_batch_size,
    drop_last=check_dataset(train_dataset, loader_batch_size),
    collate_fn=collate_embedding,
    num_workers=4,
)

# Validation: same batch size and worker count, no drop_last override.
dataloader['val'] = DataLoader(
    val_dataset,
    batch_size=loader_batch_size,
    collate_fn=collate_embedding,
    num_workers=4,
)

# Test: loaded in the main process (no num_workers given).
dataloader['test'] = DataLoader(
    test_dataset,
    batch_size=loader_batch_size,
    collate_fn=collate_embedding,
)


In [None]:
# Debug aid: print the text tensor of the first six training batches.
# The loop index is named `batch_idx` rather than `iter`, so the builtin iter()
# is not shadowed in the notebook's global namespace after this cell runs.
for batch_idx, (_, labels, texts, _, shapes) in enumerate(dataloader['train']):

    batch_size = shapes.size(0)
    texts = texts.to(device)
    text_labels = labels.to(device)

    # Indices 0, 2, 4, ... : keep every second entry along dim 0. Built once per
    # batch and shared by both selections below.
    even_idx = torch.arange(0, 2 * (batch_size // 2), 2, device=device)
    shapes = shapes.to(device).index_select(0, even_idx)
    # Reuse the labels already moved to the device instead of transferring again.
    shape_labels = text_labels.index_select(0, even_idx)

    print(texts)

    # Stop after six batches (indices 0..5).
    if batch_idx == 5:
        break

In [4]:
# Loss objects handed to the solver, keyed by 'walker', 'visit' and 'metric'.
# Loss.SmoothedMetricLoss(device=device) was the earlier choice for 'metric'.
criterion = dict(
    walker=Loss.RoundTripLoss(device=device),
    visit=Loss.AssociationLoss(device=device),
    metric=Loss.InstanceMetricLoss(),
)

In [5]:
# Instantiate both encoders on the target device. The text encoder is sized by
# the number of entries in the word -> index dictionary.
ShapeModel = ShapeEncoder().to(device)
TextModel = TextEncoder(len(stanData.dict_word2idx)).to(device)

# A single Adam optimizer updates the parameters of both encoders.
joint_parameters = [*ShapeModel.parameters(), *TextModel.parameters()]
optimizer = torch.optim.Adam(
    joint_parameters,
    lr=cfg.EMBEDDING_LR,
    weight_decay=cfg.EMBEDDING_WEIGHT_DC,
)

# Despite its name, `history` holds the Solver instance itself; a later cell
# calls history.train(...) on it.
history = SolverEmbedding.Solver(
    TextModel,
    ShapeModel,
    dataloader,
    optimizer,
    criterion,
    cfg.EMBEDDING_BATCH_SIZE,
    device,
)

In [6]:
# Start training on the Solver stored in `history`, passing the configured
# epoch count and the dictionary stanData.dict_idx2word.
# NOTE(review): presumably the index -> word mapping is used for the evaluation
# step seen in the cell output — confirm in SolverEmbedding.
history.train(cfg.EMBEDDING_EPOCH_NR,stanData.dict_idx2word)

Epoch [1/4] starting...

Training...


0it [00:00, ?it/s]

Validating...



0it [00:00, ?it/s]

epoch [1/4] done...
------------------------summary------------------------
[train] total_loss: 77.177933
[val]   total_loss: 84.423693
[train] walker_loss_tst: 5.545174, walker_loss_sts: 4.852016
[val]   walker_loss_tst: 5.544910, walker_loss_sts: 4.852371
[train] visit_loss_ts: 1.213008, visit_loss_st: 1.386294
[val]   visit_loss_ts: 1.213294, visit_loss_st: 1.386295
[train] metric_loss_st: 42.787445, metric_loss_tt: 21.393993
[val]   metric_loss_st: 42.789116, metric_loss_tt: 21.393992
[train] shape_norm_penalty: 0.000000, text_norm_penalty: 0.000000
[val]   shape_norm_penalty: 7.243717, text_norm_penalty: 0.000000

Evaluating...


0it [00:00, ?it/s]


Number of embedding: 30674
Number of shape embeddings: 1490, number of text embeddings: 29184
Dimensionality of embedding: 128
Evaluation mode: t2t

min embedding val: -0.09974249452352524
max embedding val: 0.1125732958316803
mean embedding (abs) val: 0.03281861741604581
Number of models : 1490

Example model IDs:
933096cbd0f7ef0aa73562d75299fcd8
b11993c9e7d9ab970077d80217bffad
9271bb0cab9365d44b3c42e318f3affc
c5178a8a0da618a25d78ff7fb413274d
9e55b1135ddf93211c8d18742f91c015
63d92bf1f175a75a25ffbad401072b4d
b192cda468f9390aa3f22b4b00de6dfb
c755eeaa4a588fcba9126dd5adc92c1e
738395f54b301d80b1f5d603f931c1aa
62a4f3c24bc69f593eff95e5c4b79279

Using unnormalized cosine distance
Nearest neighbors on block 1
Nearest neighbors on block 2
Nearest neighbors on block 3
Nearest neighbors on block 4
Nearest neighbors on block 5
Nearest neighbors on block 6
Nearest neighbors on block 7
Nearest neighbors on block 8
Nearest neighbors on block 9
Nearest neighbors on block 10
Computing precision recall

0it [00:00, ?it/s]

Validating...



0it [00:00, ?it/s]