In [22]:
import torch.utils.data as data
import numpy as np
import os, sys
import data_transforms

import random
import os
import json
import torch
import torchvision.transforms as transforms
import pickle
import math
import time
import io123
IO = io123.IO


In [32]:

def rotation_z(pts, theta):
    """Rotate row-vector points by ``theta`` radians in the x-y plane.

    The z coordinate is left untouched. With this matrix the point
    ``(1, 0, 0)`` is mapped to ``(cos(theta), sin(theta), 0)``.

    Args:
        pts: array whose last dimension has length 3 (one point per row).
        theta: rotation angle in radians.

    Returns:
        Array of the same shape as ``pts`` holding the rotated points.
    """
    c, s = np.cos(theta), np.sin(theta)
    # Start from the identity and fill in the 2x2 x-y rotation block.
    rot = np.eye(3)
    rot[0, 0], rot[0, 1] = c, -s
    rot[1, 0], rot[1, 1] = s, c
    # Points are row vectors, hence the multiplication by the transpose.
    return pts @ rot.T


def rotation_y(pts, theta):
    """Rotate row-vector points by ``theta`` radians in the x-z plane.

    The y coordinate is left untouched. With this matrix the point
    ``(1, 0, 0)`` is mapped to ``(cos(theta), 0, sin(theta))``.

    Args:
        pts: array whose last dimension has length 3 (one point per row).
        theta: rotation angle in radians.

    Returns:
        Array of the same shape as ``pts`` holding the rotated points.
    """
    c, s = np.cos(theta), np.sin(theta)
    # Start from the identity and fill in the 2x2 x-z rotation block.
    rot = np.eye(3)
    rot[0, 0], rot[0, 2] = c, -s
    rot[2, 0], rot[2, 2] = s, c
    # Points are row vectors, hence the multiplication by the transpose.
    return pts @ rot.T


def rotation_x(pts, theta):
    """Rotate row-vector points by ``theta`` radians in the y-z plane.

    The x coordinate is left untouched. With this matrix the point
    ``(0, 1, 0)`` is mapped to ``(0, cos(theta), sin(theta))``.

    Args:
        pts: array whose last dimension has length 3 (one point per row).
        theta: rotation angle in radians.

    Returns:
        Array of the same shape as ``pts`` holding the rotated points.
    """
    c, s = np.cos(theta), np.sin(theta)
    # Start from the identity and fill in the 2x2 y-z rotation block.
    rot = np.eye(3)
    rot[1, 1], rot[1, 2] = c, -s
    rot[2, 1], rot[2, 2] = s, c
    # Points are row vectors, hence the multiplication by the transpose.
    return pts @ rot.T


class Shapenet_ViPC(data.Dataset):
    """Dataset of (partial cloud, complete cloud, rendered view) samples.

    Samples are listed in ``<subset>_list2.txt`` (resolved relative to the
    current working directory), one per line, as ``;``-separated fields. The
    first field is used as the taxonomy id, the second as the model id and
    the last as the view id of the partial scan.

    Args:
        subset: name of the split; selects the ``<subset>_list2.txt`` index.
        View_align: when true, the rendered image and ground truth use the
            same view id as the partial scan; otherwise a random view id in
            ``00`` .. ``23`` is drawn for them on every access.
        data_root: optional directory containing the ``Partial``, ``GT`` and
            ``view`` trees. Defaults to ``DEFAULT_DATA_ROOT``.
    """

    # Location used when no ``data_root`` is given (the original hard-coded path).
    DEFAULT_DATA_ROOT = "/home_nfs/fucheng.niu/Data_time_test/data"
    # Random view ids are drawn from 0 .. NUM_VIEWS - 1 and zero-padded to 2 digits.
    NUM_VIEWS = 24

    # def __init__(self, data_root, subset, class_choice = None):
    def __init__(self, subset, View_align, data_root=None):
        root = self.DEFAULT_DATA_ROOT if data_root is None else str(data_root)
        # The paths below are %-format templates, so a literal '%' in the
        # root directory has to be escaped.
        root = root.replace("%", "%%")
        self.partial_points_path = root + "/Partial/%s/%s/%s.dat"
        self.complete_points_path = root + "/GT/%s/%s/%s.dat"
        self.view_path = root + "/view/%s/%s/rendering/%s.png"
        self.category_file = "%s_list2.txt"
        self.npoints = 3500
        self.subset = subset
        self.category = "all"
        self.cat_map = {
            'plane': '02691156',
            'bench': '02828884',
            'cabinet': '02933112',
            'car': '02958343',
            'chair': '03001627',
            'monitor': '03211117',
            'lamp': '03636649',
            'speaker': '03691459',
            'firearm': '04090263',
            'couch': '04256520',
            'table': '04379243',
            'cellphone': '04401088',
            'watercraft': '04530566'
        }
        self.cat = []
        self.key = []
        self.filepath = self.category_file % self.subset
        self.view_align = View_align

        with open(self.filepath, 'r') as f:
            # Raw lines are kept as read (including their line terminator);
            # blank lines are dropped because they cannot describe a sample.
            self.filelist = [line for line in f if line.strip()]

        # ``None`` means "keep every category".
        wanted = None if self.category == 'all' else self.cat_map[self.category]
        for key in self.filelist:
            taxonomy_id = key.split(';')[0]
            if wanted is not None and taxonomy_id != wanted:
                continue
            self.cat.append(taxonomy_id)
            self.key.append(key)
        self.img_transforms = self._img_get_transforms(self.subset)

    def _img_get_transforms(self, subset):
        """Return the image pipeline: resize to 224, then convert to a tensor.

        ``subset`` is accepted for interface compatibility; the same pipeline
        is used for every subset.
        """
        return transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _parse_key(key):
        """Split one index line into ``(taxonomy_id, model_id, view_id)``.

        Surrounding whitespace and any line terminator are removed first, so
        the result does not depend on whether the line ended with a newline.

        Raises:
            ValueError: if the line has fewer than three ``;``-separated fields.
        """
        fields = key.strip().split(';')
        if len(fields) < 3:
            raise ValueError("malformed index line: %r" % (key,))
        return fields[0], fields[1], fields[-1]

    @staticmethod
    def _load_points(path):
        """Load a pickled array from ``path`` and cast it to float32."""
        # NOTE(security): pickle.load can execute arbitrary code; only use it
        # on dataset files from a trusted source.
        with open(path, 'rb') as f:
            return pickle.load(f).astype(np.float32)

    @staticmethod
    def _pad_points(points, npoints):
        """Pad ``points`` (one point per row) up to ``npoints`` rows.

        Clouds that already have at least ``npoints`` rows are returned
        unchanged. Shorter clouds are padded by cycling through the whole
        cloud, so every original point is kept. ``np.repeat`` followed by
        truncation would duplicate the leading rows and drop the trailing
        ones, which is why it is not used here.

        Raises:
            ValueError: if ``points`` has no rows.
        """
        count = points.shape[0]
        if count >= npoints:
            return points
        if count == 0:
            raise ValueError("cannot pad an empty point cloud")
        reps = -(-npoints // count)  # ceil(npoints / count)
        return np.tile(points, (reps, 1))[:npoints]

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(taxonomy_id, model_id, (partial, complete, view), t0, t1)``
            where ``partial`` and ``complete`` are float tensors of points,
            ``view`` is the first three channels of the transformed image,
            and ``t0`` / ``t1`` are constant ``torch.tensor(0)`` placeholders.
        """
        taxonomy_id, model_id, part_view_id = self._parse_key(self.key[idx])

        if self.view_align:
            image_view_id = part_view_id
        else:
            image_view_id = "%02d" % random.randint(0, self.NUM_VIEWS - 1)

        pc_part_path = self.partial_points_path % (taxonomy_id, model_id, part_view_id)
        pc_path = self.complete_points_path % (taxonomy_id, model_id, image_view_id)
        view_path = self.view_path % (taxonomy_id, model_id, image_view_id)

        views = self.img_transforms(IO.get(view_path))
        # Keep the first three channels only (drops a 4th channel if present).
        views = views[:3, :, :]

        # Load the ground truth and the partial scan; clouds with fewer than
        # ``npoints`` points are padded.
        pc = self._load_points(pc_path)
        pc_part = self._pad_points(self._load_points(pc_part_path), self.npoints)

        # The per-model metadata file sits next to the rendered images and is
        # indexed by view id. Columns 0 and 1 are read as angles in degrees
        # (presumably azimuth and elevation -- TODO confirm).
        metadata_path = os.path.join(os.path.dirname(view_path), 'rendering_metadata.txt')
        view_metadata = np.loadtxt(metadata_path, ndmin=2)

        theta_part = math.radians(view_metadata[int(part_view_id), 0])
        phi_part = math.radians(view_metadata[int(part_view_id), 1])

        theta_img = math.radians(view_metadata[int(image_view_id), 0])
        phi_img = math.radians(view_metadata[int(image_view_id), 1])

        # Presumably moves the partial scan out of its own view frame and into
        # the frame of the selected image view -- TODO confirm.
        pc_part = rotation_y(rotation_x(pc_part, -phi_part), np.pi + theta_part)
        pc_part = rotation_x(rotation_y(pc_part, np.pi - theta_img), phi_img)

        # Normalize the partial cloud and the GT with the GT's centroid and
        # the GT's largest point norm so both share the same scale.
        gt_mean = pc.mean(axis=0)
        pc = pc - gt_mean
        pc_L_max = np.max(np.sqrt(np.sum(pc ** 2, axis=-1)))
        pc = pc / pc_L_max

        pc_part = pc_part - gt_mean
        pc_part = pc_part / pc_L_max

        sample = (
            torch.from_numpy(pc_part).float(),
            torch.from_numpy(pc).float(),
            views.float(),
        )
        return taxonomy_id, model_id, sample, torch.tensor(0), torch.tensor(0)

    def __len__(self):
        """Number of index lines kept after category filtering."""
        return len(self.key)






In [33]:
# Build the train and test splits (in that order), both with View_align=True.
train_data, test_data = (
    Shapenet_ViPC(split_name, True) for split_name in ("train", "test")
)

In [None]:
# DataLoader settings shared by the train and test loaders.
bs = 64          # batch size
shuffle = True   # DataLoader documents this flag as a bool
num_workers = 0  # 0 loads batches in the main process (no worker processes)

In [16]:
def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG in a DataLoader worker.

    The new seed is the first word of the current NumPy RNG state plus
    ``worker_id``, so workers that start from the same state end up with
    different seeds.

    Args:
        worker_id: integer id of the worker being initialised.
    """
    # Convert to a Python int before adding: the state word is a uint32, and
    # ``uint32 + int`` may overflow or be promoted depending on the NumPy
    # version. ``np.random.seed`` only accepts values in [0, 2**32 - 1].
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % (2 ** 32))

In [46]:
# Both loaders share every option except drop_last: the train loader discards
# a trailing incomplete batch, the test loader keeps it.
train_dataloader, test_dataloader = (
    torch.utils.data.DataLoader(
        split_dataset,
        batch_size=bs,
        shuffle=shuffle,
        drop_last=discard_incomplete,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
        pin_memory=True,
    )
    for split_dataset, discard_incomplete in ((train_data, True), (test_data, False))
)

In [48]:
max_epoch = 10

# Time spent waiting for each batch, accumulated over ALL epochs.
all_data_times = []
n_batches = len(train_dataloader)

# NOTE: range(0, max_epoch + 1) makes max_epoch + 1 passes (epochs 0..max_epoch).
for epoch in range(0, max_epoch + 1):
    epoch_data_times = []
    # perf_counter is a monotonic clock meant for measuring short intervals.
    batch_start_time = time.perf_counter()
    # The batch payload is bound to `batch`, not `data`: `data` is the
    # torch.utils.data alias that the Shapenet_ViPC class definition needs,
    # and overwriting it breaks re-running that cell.
    for idx, (taxonomy_ids, model_ids, batch, _, _) in enumerate(train_dataloader):
        data_time = time.perf_counter() - batch_start_time

        print('[Epoch %d/%d][Batch %d/%d] DataTime = %.3f (s)' %
              (epoch, max_epoch, idx + 1, n_batches, data_time,))
        epoch_data_times.append(data_time)
        batch_start_time = time.perf_counter()
    # Must run once per epoch; outside the epoch loop only the last epoch's
    # timings would be kept.
    all_data_times.extend(epoch_data_times)
                            

[Epoch 0/10][Batch 1/3] DataTime = 9.044 (s)
[Epoch 0/10][Batch 2/3] DataTime = 0.502 (s)
[Epoch 0/10][Batch 3/3] DataTime = 0.530 (s)
[Epoch 1/10][Batch 1/3] DataTime = 0.653 (s)
[Epoch 1/10][Batch 2/3] DataTime = 0.514 (s)
[Epoch 1/10][Batch 3/3] DataTime = 0.585 (s)
[Epoch 2/10][Batch 1/3] DataTime = 0.652 (s)
[Epoch 2/10][Batch 2/3] DataTime = 0.607 (s)
[Epoch 2/10][Batch 3/3] DataTime = 0.708 (s)
[Epoch 3/10][Batch 1/3] DataTime = 0.630 (s)
[Epoch 3/10][Batch 2/3] DataTime = 0.431 (s)
[Epoch 3/10][Batch 3/3] DataTime = 0.446 (s)
[Epoch 4/10][Batch 1/3] DataTime = 0.342 (s)
[Epoch 4/10][Batch 2/3] DataTime = 0.248 (s)
[Epoch 4/10][Batch 3/3] DataTime = 0.247 (s)
[Epoch 5/10][Batch 1/3] DataTime = 0.362 (s)
[Epoch 5/10][Batch 2/3] DataTime = 0.246 (s)
[Epoch 5/10][Batch 3/3] DataTime = 0.249 (s)
[Epoch 6/10][Batch 1/3] DataTime = 0.352 (s)
[Epoch 6/10][Batch 2/3] DataTime = 0.213 (s)
[Epoch 6/10][Batch 3/3] DataTime = 0.230 (s)
[Epoch 7/10][Batch 1/3] DataTime = 0.355 (s)
[Epoch 7/1

In [49]:
import matplotlib.pyplot as plt

# Plot the per-batch data-loading time together with its mean.
if not all_data_times:
    # Nothing was timed (e.g. the loader produced no batches); the mean
    # below would otherwise divide by zero.
    print("No data times recorded; nothing to plot.")
else:
    mean_data_time = sum(all_data_times) / len(all_data_times)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(all_data_times, label="Data Time per Batch", color='b')
    ax.axhline(mean_data_time, color='r', linestyle='--',
               label=f"Mean Data Time = {mean_data_time:.3f} s")

    ax.set_xlabel('Batch')
    ax.set_ylabel('Data Time (s)')
    ax.set_title('Data Time per Batch across Epochs')
    ax.legend()
    ax.grid(True)
    plt.show()

NameError: name 'plt' is not defined