In [1]:
import torch
import json
import torch
import os
import numpy as np
from PIL import Image
import random
import math
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader


In [2]:
dataset_root = '/home/rliu/Desktop/cvfiler04/datasets/objaverse/hf-objaverse-v1/views_whole_sphere'

# dataset

In [3]:
class objaverse_sfm(Dataset):
    def __init__(self, root_dir, total_view, train=True, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        with open(os.path.join(root_dir, 'valid_paths.json')) as f:
            self.paths = json.load(f)
        random.shuffle(self.paths)
        self.total_view = total_view
        self.train = train
        total_objects = len(self.paths)
        if train:
            self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training|
        else:
            self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation
        
    def __len__(self):
        return len(self.img_labels)

    def __len__(self):
        return len(self.paths)
        
    def cartesian_to_spherical(self, xyz):
        ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
        xy = xyz[:,0]**2 + xyz[:,1]**2
        z = np.sqrt(xy + xyz[:,2]**2)
        theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
        #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
        azimuth = np.arctan2(xyz[:,1], xyz[:,0])
        return np.array([theta, azimuth, z])

    def get_T(self, target_RT, cond_RT):
        R, T = target_RT[:3, :3], target_RT[:, -1]
        T_target = -R.T @ T

        R, T = cond_RT[:3, :3], cond_RT[:, -1]
        T_cond = -R.T @ T

        theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
        theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
        
        d_theta = theta_target - theta_cond
        d_azimuth = azimuth_target - azimuth_cond
        d_z = z_target - z_cond
        
        d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
        return d_T

    def load_im(self, path):
        '''
        replace background pixel with white in rendering
        '''
        img = plt.imread(path)
        img[img[:, :, -1] == 0.] = [1., 1., 1., 1.]
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img

    def __getitem__(self, index):
        data = {}
        index_target, index_cond = random.sample(range(self.total_view), 2) # without replacement
        filename = os.path.join(self.root_dir, self.paths[index])

        # print(self.paths[index])

#         if self.return_paths:
#             data["path"] = str(filename)
            
        target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target)))
        target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
        cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond)))
        cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond))

        data["image_target"] = target_im
        data["image_cond"] = cond_im
        data["T"] = self.get_T(target_RT, cond_RT)

        return data
    
    def process_im(self, im):
        im = im.convert("RGB")
        im = self.transform(im)
        im = torchvision.transforms.functional.resize(im, 224)
        im = im * 2. - 1.
        return im

In [4]:
dataset = objaverse_sfm(dataset_root, 4, train=True, transform = ToTensor())

In [5]:
dataset[0]['image_target'].shape

torch.Size([3, 224, 224])

In [6]:
train_dataloader = DataLoader(objaverse_sfm(dataset_root, 4, train=True, transform = ToTensor()),\
                              batch_size=16, shuffle=True)
test_dataloader = DataLoader(objaverse_sfm(dataset_root, 4, train=False, transform = ToTensor()),\
                             batch_size=16, shuffle=False)

In [7]:
batch = next(iter(train_dataloader))

In [8]:
batch['image_target'].shape

torch.Size([16, 3, 224, 224])

# model

In [9]:
class sfm(torch.nn.Module):
    def __init__(self):
        super(sfm, self).__init__()
    
        self.cond_cnn = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.DEFAULT)
        self.cond_cnn.fc = torch.nn.Sequential()
        self.target_cnn = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.DEFAULT)
        self.target_cnn.fc = torch.nn.Sequential()
        self.linear1 = torch.nn.Linear(1024, 1024)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(1024, 4)

    def forward(self, cond, target):
        cond_feature = self.cond_cnn(cond)
        target_feature = self.target_cnn(target)
        x = torch.cat([cond_feature, target_feature], dim=-1)
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        return x

In [10]:
model = sfm()



In [12]:
image_target, image_cond, T = batch['image_target'], batch['image_cond'], batch['T']

In [13]:
image_target.shape, image_cond.shape, T.shape

(torch.Size([16, 3, 224, 224]),
 torch.Size([16, 3, 224, 224]),
 torch.Size([16, 4]))

In [14]:
pred = model(image_cond, image_target)

In [15]:
loss = ((pred - T) ** 2).mean()

In [16]:
loss.backward()

# training

In [129]:
train_dataloader = DataLoader(objaverse_sfm(dataset_root, 4, train=True, transform = ToTensor()),\
                              batch_size=16, shuffle=True, num_workers=4)
test_dataloader = DataLoader(objaverse_sfm(dataset_root, 4, train=False, transform = ToTensor()),\
                             batch_size=16, shuffle=False, num_workers=4)

model = sfm().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for epoch in range(100):  # loop over the dataset multiple times

    running_loss = 0.0
    
    for i, batch in enumerate(train_dataloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        image_target, image_cond, T_gt = batch['image_target'].cuda(), batch['image_cond'].cuda(), batch['T'].cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        T_pred = model(image_cond, image_target)
        loss = (T_pred - T_gt).abs().mean()
        loss.backward()
#         print(loss.item())
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 2000 mini-batches
            print(f'[{i + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

[100,   100] loss: 0.566
[200,   200] loss: 0.537
[300,   300] loss: 0.516
[400,   400] loss: 0.517
[500,   500] loss: 0.509
[600,   600] loss: 0.500
[700,   700] loss: 0.491
[800,   800] loss: 0.493
[900,   900] loss: 0.495
[1000,  1000] loss: 0.485
[1100,  1100] loss: 0.492
[1200,  1200] loss: 0.484
[1300,  1300] loss: 0.487
[1400,  1400] loss: 0.482
[1500,  1500] loss: 0.486
[1600,  1600] loss: 0.485
[1700,  1700] loss: 0.484
[1800,  1800] loss: 0.481
[1900,  1900] loss: 0.485
[2000,  2000] loss: 0.484
[2100,  2100] loss: 0.482
[2200,  2200] loss: 0.483
[2300,  2300] loss: 0.482
[2400,  2400] loss: 0.485
[2500,  2500] loss: 0.478
[2600,  2600] loss: 0.483
[2700,  2700] loss: 0.480
[2800,  2800] loss: 0.481
[2900,  2900] loss: 0.479
[3000,  3000] loss: 0.476
[3100,  3100] loss: 0.477
[3200,  3200] loss: 0.482
[3300,  3300] loss: 0.474
[3400,  3400] loss: 0.482
[3500,  3500] loss: 0.479
[3600,  3600] loss: 0.477
[3700,  3700] loss: 0.474
[3800,  3800] loss: 0.479
[3900,  3900] loss: 0

libpng error: IDAT: CRC error
libpng error: failed to read file


ConnectionAbortedError: Caught ConnectionAbortedError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/rliu/anaconda3/lib/python3.9/site-packages/matplotlib/image.py", line 1434, in imread
    return handler(fd)
  File "/home/rliu/anaconda3/lib/python3.9/site-packages/matplotlib/image.py", line 1390, in read_png
    return _png.read_png(*args, **kwargs)
RuntimeError: error setting jump

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/rliu/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/rliu/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/rliu/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_1159826/381479019.py", line 69, in __getitem__
    cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond)))
  File "/tmp/ipykernel_1159826/381479019.py", line 52, in load_im
    img = plt.imread(path)
  File "/home/rliu/anaconda3/lib/python3.9/site-packages/matplotlib/pyplot.py", line 2135, in imread
    return matplotlib.image.imread(fname, format)
  File "/home/rliu/anaconda3/lib/python3.9/site-packages/matplotlib/image.py", line 1434, in imread
    return handler(fd)
ConnectionAbortedError: [Errno 103] Software caused connection abort
