In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import torch
import trimesh
from functools import partial
import ssl
from Solvers import SolverEmbedding,Loss
from Models.EncoderModels import TextEncoder, ShapeEncoder
from config import cfg
from dataEmbedding.dataEmbedding import Read_Load_BuildBatch
from dataEmbedding.dataEmbeddingLoader import GenerateDataLoader,check_dataset,collate_embedding
# Select the compute device.
# An explicit cfg.DEVICE takes precedence. This matches the previous behaviour,
# where `device = cfg.DEVICE` unconditionally overwrote the detected device and
# left the auto-detection below as dead code.
# Auto-detection (MPS -> CUDA -> CPU) is now used only when cfg.DEVICE is None or absent.
_configured_device = getattr(cfg, "DEVICE", None)
if _configured_device is not None:
    device = _configured_device
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    # torch.backends.mps only exists in newer torch builds; guard before calling it.
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# FIXME(security): this replaces the default HTTPS context with an *unverified*
# one, disabling TLS certificate verification for every HTTPS request made via the
# stdlib in this kernel. It was added as a macOS workaround; prefer fixing the
# local certificate store and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context

cpu


In [2]:
# Load the embedding data from cfg.EMBEDDING_VOXEL_FOLDER. Later cells use the
# returned object's train/val/test aggregates (data_agg_*), data_dir, vocabularies
# (dict_word2idx / dict_idx2word) and wordlens.
# The primitives dataset can be loaded the same way if needed (the old commented
# call referenced an undefined `batchSize`):
#     primiData = Read_Load_BuildBatch('/primitives', cfg.EMBEDDING_BATCH_SIZE, 'primitives')
stanData = Read_Load_BuildBatch(cfg.EMBEDDING_VOXEL_FOLDER, cfg.EMBEDDING_BATCH_SIZE)


In [3]:
# Median and mean of stanData.wordlens (presumably per-caption token counts).
# Work on a sorted *copy*: sorting stanData.wordlens in place would silently
# reorder data owned by stanData and make this cell non-idempotent.
word_lengths = sorted(stanData.wordlens)
if not word_lengths:
    raise ValueError("stanData.wordlens is empty; cannot compute median/mean")

mid = len(word_lengths) // 2
# For an odd count both indices point at the middle element; for an even count
# they are the two central elements, so averaging gives the median either way.
median_len = (word_lengths[mid] + word_lengths[len(word_lengths) - 1 - mid]) / 2
mean_len = sum(word_lengths) / len(word_lengths)

print("Median of list is : " + str(median_len))
print("Mean: " + str(mean_len))

Median of list is : 14.0
Mean: 16.307937873199133


In [3]:
# One dataset per split. GenerateDataLoader's result is used as the `dataset`
# argument of torch's DataLoader below, one call per (aggregate, split name) pair.
train_dataset, val_dataset, test_dataset = (
    GenerateDataLoader(split_data, stanData.data_dir, stanData.dict_word2idx, split_name)
    for split_data, split_name in (
        (stanData.data_agg_train, 'train'),
        (stanData.data_agg_val, 'val'),
        (stanData.data_agg_test, 'test'),
    )
)

# All loaders use twice cfg.EMBEDDING_BATCH_SIZE (the Solver itself is given the
# un-doubled value).
# NOTE(review): the train loader does not shuffle — confirm this is intentional.
loader_batch_size = cfg.EMBEDDING_BATCH_SIZE * 2
dataloader = {
    # Only the training split decides drop_last, via check_dataset.
    'train': DataLoader(
        train_dataset,
        batch_size=loader_batch_size,
        drop_last=check_dataset(train_dataset, loader_batch_size),
        collate_fn=collate_embedding,
        num_workers=4,
    ),
    'val': DataLoader(
        val_dataset,
        batch_size=loader_batch_size,
        collate_fn=collate_embedding,
        num_workers=4,
    ),
    # num_workers left at DataLoader's default (load in the main process).
    'test': DataLoader(
        test_dataset,
        batch_size=loader_batch_size,
        collate_fn=collate_embedding,
    ),
}


In [None]:
# Debug preview of the training loader. Move the first few batches to `device`,
# take every other shape/label (indices 0, 2, 4, ...) and print the text tensors.
# NOTE(review): this cell was never executed in the saved notebook (In [None]).
N_PREVIEW_BATCHES = 6  # batches 0..5, same as the old `if iter == 5: break`

for batch_idx, (_, labels, texts, _, shapes) in enumerate(dataloader['train']):
    batch_size = shapes.size(0)
    texts = texts.to(device)
    labels = labels.to(device)  # transfer once; reused for text and shape labels
    text_labels = labels

    # int64 indices 0, 2, 4, ... (batch_size // 2 of them), built once per batch.
    even_idx = torch.arange(batch_size // 2, device=device) * 2
    shapes = shapes.to(device).index_select(0, even_idx)
    shape_labels = labels.index_select(0, even_idx)

    print(texts)

    # Stop right after the last preview batch so no extra batch is fetched.
    if batch_idx + 1 >= N_PREVIEW_BATCHES:
        break

In [4]:
# Loss modules handed to the Solver, keyed 'walker', 'visit' and 'metric'
# (presumably the names SolverEmbedding.Solver looks up — confirm there).
# Alternative metric loss kept for reference:
#     Loss.SmoothedMetricLoss(device=device)
criterion = dict(
    walker=Loss.RoundTripLoss(device=device),
    visit=Loss.AssociationLoss(device=device),
    metric=Loss.InstanceMetricLoss(),
)

In [5]:
# Build both encoders on `device`. The shape encoder is constructed first,
# keeping the original order so random weight initialisation is unchanged.
ShapeModel = ShapeEncoder().to(device)
TextModel = TextEncoder(len(stanData.dict_word2idx)).to(device)

# A single Adam optimiser over shape-encoder then text-encoder parameters.
optimizer = torch.optim.Adam(
    [*ShapeModel.parameters(), *TextModel.parameters()],
    lr=cfg.EMBEDDING_LR,
    weight_decay=cfg.EMBEDDING_WEIGHT_DC,
)

# Despite its name, `history` holds the Solver instance (its .train() is called next).
history = SolverEmbedding.Solver(
    TextModel,
    ShapeModel,
    dataloader,
    optimizer,
    criterion,
    cfg.EMBEDDING_BATCH_SIZE,
    device,
)

In [6]:
# Run training for cfg.EMBEDDING_EPOCH_NR epochs. The index->word vocabulary is
# passed along (presumably to decode captions during evaluation — confirm in Solver.train).
history.train(
    cfg.EMBEDDING_EPOCH_NR,
    stanData.dict_idx2word,
)

epoch [1/4] starting...



  0%|          | 0/944 [00:00<?, ?it/s]

77.17841339111328


  0%|          | 0/944 [00:21<?, ?it/s]
100%|██████████| 114/114 [00:59<00:00,  1.92it/s]


evaluating...

Number of embedding: 30674
Number of shape embeddings: 1490, number of text embeddings: 29184
Dimensionality of embedding: 128
Evaluation mode: t2t

min embedding val: -0.11785503476858139
max embedding val: 0.10447129607200623
mean embedding (abs) val: 0.038610791471488645
Number of models : 1490

Example model IDs:
c356393b27c3fbca34ee3fb22432c207
445528514535ca621d5ccc40b510e4bd
66c791cf5f1e61a09753496ba23f2183
780809a0d1b68f4a8ef4ac3a24abb05b
cc3f1a06508f2ebd1aed2875db0a8711
3aebb428c4f378174078a3e6d5ee40f4
63f6ff0ad9cf9d17adb532bf77da46c2
a98482ce1ac411406b2cda27b9d80e15
4fe364b3390d3e158afe76b3d612e00b
325d922d3092f7bfc3bd24f986301745

Using unnormalized cosine distance
Nearest neighbors on block 1
Nearest neighbors on block 2
Nearest neighbors on block 3
Nearest neighbors on block 4
Nearest neighbors on block 5
Nearest neighbors on block 6
Nearest neighbors on block 7
Nearest neighbors on block 8
Nearest neighbors on block 9
Nearest neighbors on block 10
Computing

  0%|          | 0/944 [00:00<?, ?it/s]

78.42781829833984


  0%|          | 0/944 [00:21<?, ?it/s]
  2%|▏         | 2/114 [00:01<00:59,  1.87it/s]


KeyboardInterrupt: 