In [7]:
import os

# PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK when `torch` is first imported, so it
# must be placed in the process environment *before* any torch import below.
# (Assigning a plain Python variable named PYTORCH_ENABLE_MPS_FALLBACK has no
# effect.) If torch was already imported in this kernel, restart the kernel for
# this setting to apply. setdefault keeps any value the user exported already.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch
import trimesh
from functools import partial
import ssl
from Solvers import SolverEmbedding,Loss
from Models.EncoderModels import TextEncoder, ShapeEncoder
from config import cfg
from dataEmbedding.dataEmbedding import Read_Load_BuildBatch
from dataEmbedding.dataEmbeddingLoader import GenerateDataLoader,check_dataset,collate_embedding

# Best locally available accelerator: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    auto_device = torch.device("mps")
elif torch.cuda.is_available():
    auto_device = torch.device("cuda")
else:
    auto_device = torch.device("cpu")

# cfg.DEVICE takes precedence (as before, when it unconditionally overwrote the
# detected device); auto-detection is now used only when cfg.DEVICE is missing
# or empty instead of being silently discarded.
device = getattr(cfg, "DEVICE", None) or auto_device
print(device)

# WARNING(review): this disables TLS certificate verification for every HTTPS
# connection the stdlib (urllib / http.client) opens without an explicit
# context, for the whole kernel. It was added as a macOS workaround (presumably
# missing CA certificates); prefer installing certificates and removing it.
ssl._create_default_https_context = ssl._create_unverified_context

cpu


In [8]:
# Load the embedding dataset from the configured voxel folder. Later cells use
# its data_agg_train / data_agg_val / data_agg_test splits, data_dir,
# dict_word2idx / dict_idx2word vocabularies and wordlens.
# (An earlier variant loaded the 'primitives' dataset through the same helper.)
stanData = Read_Load_BuildBatch(
    cfg.EMBEDDING_VOXEL_FOLDER,
    cfg.EMBEDDING_BATCH_SIZE,
)


In [3]:
# Summary statistics of stanData.wordlens (presumably tokens per caption).
# Work on a sorted *copy*: calling stanData.wordlens.sort() mutated the dataset
# object in place, a hidden side effect that persists for every later cell.
sorted_wordlens = sorted(stanData.wordlens)
if not sorted_wordlens:
    print("wordlens is empty; no statistics to report")
else:
    mid = len(sorted_wordlens) // 2
    # ~mid == -mid - 1 mirrors `mid` from the end: for odd lengths both indices
    # hit the middle element, for even lengths they are the two central
    # elements, so their average is the median in both cases.
    res = (sorted_wordlens[mid] + sorted_wordlens[~mid]) / 2
    print("Median of list is : " + str(res))
    print("Mean: " + str(sum(sorted_wordlens) / len(sorted_wordlens)))

Median of list is : 14.0
Mean: 16.307937873199133


In [9]:
# One dataset object per split; all splits share the same data directory and
# word->index vocabulary.
train_dataset, val_dataset, test_dataset = (
    GenerateDataLoader(split_records, stanData.data_dir, stanData.dict_word2idx)
    for split_records in (
        stanData.data_agg_train,
        stanData.data_agg_val,
        stanData.data_agg_test,
    )
)

# Every loader batch carries 2 * EMBEDDING_BATCH_SIZE items (presumably two
# captions per shape -- TODO confirm against collate_embedding).
loader_batch_size = cfg.EMBEDDING_BATCH_SIZE * 2

# NOTE(review): shuffle is off for 'train' (DataLoader default). Confirm how
# collate_embedding pairs items before enabling it.
dataloader = {
    'train': DataLoader(
        train_dataset,
        batch_size=loader_batch_size,
        # check_dataset decides whether the trailing partial batch is dropped
        drop_last=check_dataset(train_dataset, loader_batch_size),
        collate_fn=collate_embedding,
        num_workers=4,
    ),
    'val': DataLoader(
        val_dataset,
        batch_size=loader_batch_size,
        collate_fn=collate_embedding,
        num_workers=4,
    ),
    # test loading stays in the main process (num_workers left at its default)
    'test': DataLoader(
        test_dataset,
        batch_size=loader_batch_size,
        collate_fn=collate_embedding,
    ),
}


In [4]:
# Smoke-test the training loader: pull the first few batches, move them to
# `device` and print the text tensors.
PREVIEW_BATCHES = 6  # the previous version stopped after iteration 5, i.e. 6 batches

# zip() checks the range first, so it stops without requesting an extra batch.
# The loop deliberately no longer binds the name `iter`, which shadowed the
# builtin for every later cell in the kernel.
for _, (_, labels, texts, _, shapes) in zip(range(PREVIEW_BATCHES), dataloader['train']):
    texts = texts.to(device)
    text_labels = labels.to(device)

    if cfg.EMBEDDING_SHAPE_ENCODER:
        batch_size = shapes.size(0)
        # Keep every other item (indices 0, 2, 4, ...), i.e. one entry per
        # adjacent pair; presumably two captions share one shape -- TODO confirm
        # against collate_embedding. The index is built once and reused.
        pair_idx = torch.arange(0, 2 * (batch_size // 2), 2).to(device)
        shapes = shapes.to(device).index_select(0, pair_idx)
        # text_labels is already labels on `device`; no second transfer needed.
        shape_labels = text_labels.index_select(0, pair_idx)

    print(texts)

tensor([[141,  38, 134,  ...,   9, 363,   4],
        [ 35, 141,  48,  ...,   0,   0,   0],
        [387,   3,  17,  ...,  10,   0,   0],
        ...,
        [  2,   3,   4,  ..., 285,  83,  10],
        [ 48,  21,  38,  ...,   0,   0,   0],
        [  2,   3,   4,  ...,  37,  35, 217]], device='mps:0')
tensor([[  35,  190,   17,  ..., 1822,   88,   12],
        [  48,   21,  482,  ...,    0,    0,    0],
        [ 106,  125,   21,  ...,    0,    0,    0],
        ...,
        [  18,    4,  572,  ..., 1245,   17,  285],
        [  35,   12,   17,  ...,    0,    0,    0],
        [ 181,   30,  493,  ...,    0,    0,    0]], device='mps:0')
tensor([[ 12,  13,  14,  ...,   0,   0,   0],
        [ 35, 239,  34,  ...,  14,   2,  12],
        [ 18, 318,  35,  ...,   9,   0,   0],
        ...,
        [ 35,   5, 108,  ...,   0,   0,   0],
        [ 82,  83, 101,  ...,  10,   0,   0],
        [ 48,  21,  82,  ...,  37,  73,   9]], device='mps:0')
tensor([[ 120,   48,   38,  ...,    0,    0,  

In [10]:
# Loss terms handed to the solver. Only one metric loss can live under the
# 'metric' key: the earlier version listed both TripletLoss and NPairLoss under
# the same key (and was missing a comma, so the cell did not parse; with the
# comma, the later NPairLoss would have silently won). Pick it explicitly here.
# Commented-out alternatives in the earlier version:
#   Loss.SmoothedMetricLoss(device=device), Loss.InstanceMetricLoss()
METRIC_LOSS = 'npair'  # one of: 'triplet', 'npair'

metric_loss_factories = {
    'triplet': Loss.TripletLoss,
    'npair': Loss.NPairLoss,
}
if METRIC_LOSS not in metric_loss_factories:
    raise ValueError(
        f"METRIC_LOSS must be one of {sorted(metric_loss_factories)}, got {METRIC_LOSS!r}"
    )

criterion = {
    'walker': Loss.RoundTripLoss(device=device),
    'visit': Loss.AssociationLoss(device=device),
    'metric': metric_loss_factories[METRIC_LOSS](),
}

In [11]:
# Build both encoders on `device` and optimise their parameters jointly with a
# single Adam optimiser.
ShapeModel = ShapeEncoder().to(device)
# TextEncoder is sized by the number of entries in the word->index vocabulary.
TextModel = TextEncoder(len(stanData.dict_word2idx)).to(device)

joint_params = [*ShapeModel.parameters(), *TextModel.parameters()]
optimizer = torch.optim.Adam(
    joint_params,
    lr=cfg.EMBEDDING_LR,
    weight_decay=cfg.EMBEDDING_WEIGHT_DC,
)

# NOTE: despite its name, `history` holds the Solver object; later cells call
# history.train(...), so the name is kept.
history = SolverEmbedding.Solver(
    TextModel, ShapeModel, dataloader, optimizer, criterion,
    cfg.EMBEDDING_BATCH_SIZE, device,
)

In [12]:
# Run training; EMBEDDING_EPOCH_NR is presumably the number of epochs, and the
# index->word vocabulary is passed through to the solver (TODO confirm its use).
n_epochs = cfg.EMBEDDING_EPOCH_NR
idx_to_word = stanData.dict_idx2word
history.train(n_epochs, idx_to_word)