In [1]:
import torch
import torch.optim as optim

import argparse
import time
import random
import dsacstar
import os

from network import Network
import datasets
from utils import tr, reverse_tr
import pickle
from torch.utils.tensorboard import SummaryWriter

In [2]:
dataset = datasets.SevenScenesDataset(f'/mundus/mrahman527/projects/homography-loss-function/datasets/7-Scenes/fire', 0.025, 0.975)


Loading seq-01


100%|███████████████████████████████████████| 1000/1000 [01:07<00:00, 14.73it/s]


Loading seq-02


100%|███████████████████████████████████████| 1000/1000 [01:10<00:00, 14.10it/s]


Loading seq-03


100%|███████████████████████████████████████| 1000/1000 [01:05<00:00, 15.16it/s]


Loading seq-04


100%|███████████████████████████████████████| 1000/1000 [01:05<00:00, 15.21it/s]


Sorting depths, this may take a while...


In [3]:
def compute_ABC(w_t_c, c_R_w, w_t_chat, chat_R_w, c_n, eye):
    """
    Computes A, B, and C matrix given estimated and ground truth poses
    and normal vector n.
    `w_t_c` and `w_t_chat` must have shape (batch_size, 3, 1).
    `c_R_w` and `chat_R_w` must have shape (batch_size, 3, 3).
    `n` must have shape (3, 1).
    `eye` is the (3, 3) identity matrix on the proper device.
    """
    chat_t_c = chat_R_w @ (w_t_c - w_t_chat)
#     print(f"in abc chatRW={chat_R_w.shape} and transpose={c_R_w.transpose(1,2).shape}")
    chat_R_c = chat_R_w @ c_R_w.transpose(1, 2)

    A = eye - chat_R_c
    C = c_n @ chat_t_c.transpose(1, 2)
    B = C @ A
    A = A @ A.transpose(1, 2)
    B = B + B.transpose(1, 2)
    C = C @ C.transpose(1, 2)

    return A, B, C


class LocalHomographyLoss(torch.nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()

        # `c_n` is the normal vector of the plane inducing the homographies in the ground-truth camera frame
        self.c_n = torch.tensor([0, 0, -1], dtype=torch.float32, device=device).view(3, 1)

        # `eye` is the (3, 3) identity matrix
        self.eye = torch.eye(3, device=device)

    def __call__(self, batch):
        A, B, C = compute_ABC(batch['w_t_c'], batch['c_R_w'], batch['w_t_chat'], batch['chat_R_w'], self.c_n, self.eye)

        xmin = batch['xmin'].view(-1, 1, 1)
        xmax = batch['xmax'].view(-1, 1, 1)
        B_weight = torch.log(xmax / xmin) / (xmax - xmin)
        C_weight = xmin * xmax

        error = A + B * B_weight + C / C_weight
        error = error.diagonal(dim1=1, dim2=2).sum(dim=1).mean()
        return error


In [4]:

train_dataset = datasets.RelocDataset(dataset.train_data)
test_dataset = datasets.RelocDataset(dataset.test_data)

trainset_loader = torch.utils.data.DataLoader(train_dataset, shuffle=False, num_workers=6, batch_size=1)
testset_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, num_workers=6, batch_size=1)

# load network
network = Network(torch.zeros((3)), False)
with_init=False
network = network.cuda()
network.train()


optimizer = torch.optim.Adam(network.parameters(),lr=0.000001)
iteration = 2

if with_init:
    writer_folder = 'with_init'
else:
    writer_folder = 'without_init'
    
    
writer = SummaryWriter(os.path.join('logs',os.path.basename(os.path.normpath('7-Scenes')),'fire',writer_folder))


if with_init:
        checkpoint_folder = f'our_checkpoints/{opt.dataset_name}/{opt.scene_name}_with_init'
        os.mkdir(checkpoint_folder)
else:
    checkpoint_folder = f"our_checkpoints/{'7-Scenes'}/{'fire'}_without_init"
    if os.path.isdir(checkpoint_folder):
        checkpoint_folder = checkpoint_folder+'_1'

    os.makedirs(checkpoint_folder, exist_ok=True)


In [5]:
len(trainset_loader)

2000

In [None]:

def train(network = network,trainset_loader=trainset_loader,testset_laoder=testset_loader,optimizer=optimizer, iteration=iteration, with_init=with_init, writer=writer,checkpoint_folder=checkpoint_folder):
    
    for epoch in range(iteration):
        network.train()
        print(f'epoch:{epoch}\n')
        
        running_loss = 0
        it = 0
        for data in trainset_loader:
            break
            it+=1
            with torch.no_grad():
                optimizer.zero_grad()

            focal_length = data['K'][0][0][0]
            file = data['image_file']
            image = data['image'].cuda()
#             start_time = time.time()
            wtc, crw = data['w_t_c'], data['c_R_w']
            
            # predict scene coordinates and neural guidance
            scene_coordinates = network(image)
            scene_coordinates_gradients = torch.zeros(scene_coordinates.size())
            gt_pose = reverse_tr(crw, wtc)[0]
            # print(f"shape pose={gt_pose.shape}")
            # print(f"xmin = {data['xmin']} shape {data['xmin'].shape}")
            # print(f"xmax = {data['xmax']} shape {data['xmax'].shape}")

            # pose from RGB
            loss = dsacstar.backward_rgb(
                scene_coordinates.cpu(),
                scene_coordinates_gradients,
                gt_pose, 
                64, 
                10,
                focal_length, 
                float(image.size(3) / 2), #principal point assumed in image center
                float(image.size(2) / 2),
                1.0,
                100.0,
                100,
                100,
                100,
                network.OUTPUT_SUBSAMPLE,
                random.randint(0,1000000), #used to initialize random number generator in cpp
                data['xmin'].item(),
                data['xmax'].item()
            )
    
            
            running_loss += loss
            torch.autograd.backward((scene_coordinates),(scene_coordinates_gradients.cuda()))
            optimizer.step()
            if it%10==0 and it!=0:
                writer.add_scalar('train loss',running_loss/it)
                
        
        writer.add_scalar('per_epoch_training_loss',running_loss/len(trainset_loader),epoch)
        

#         if epoch%opt.save_every==0:
#             checkpoint_path = os.path.join(checkpoint_folder,f'check_point_epoch_{epoch}.pt')
#             torch.save(
#                 {
#                     'epoch': epoch,
#                     'model_state_dict': network.state_dict(),
#                     'optimizer_state_dict': optimizer.state_dict(),

#                 }, checkpoint_path
#             )
       
        print(f"loss: {running_loss}")
        criterion = LocalHomographyLoss()
        #test
        network.eval()
        it = 0
        running_test_loss = 0
        with torch.no_grad():
            for data in testset_loader:
                it+=1
                focal_length = data['K'][0][0][0]
                file = data['image_file']
                image = data['image'].cuda()
                wtc, crw = data['w_t_c'], data['c_R_w']

                # predict scene coordinates and neural guidance
                scene_coordinates = network(image)
                gt_pose = reverse_tr(crw, wtc)[0]
                out_pose = torch.zeros((4,4))
                print('here')
                dsacstar.forward_rgb(
                    scene_coordinates,
                    out_pose,
                    64,
                    10,
                    focal_length,
                    float(image.size(3)/2),
                    float(image.size(2)/2),
                    100,
                    100,
                    network.OUTPUT_SUBSAMPLE
                )
                print('here 2')
                
                batch={}
                batch['w_t_c'] = data['w_t_c']
                batch['c_R_w'] = data['c_R_w']
                
                batch['w_t_chat'],batch['chat_R_w'] = tr(out_pose)
                batch['w_t_chat'] = batch['w_t_chat'].unsqueeze(0)
                batch['chat_R_w'] = batch['chat_R_w'].unsqueeze(0)
                batch['xmin'] = data['xmin']
                batch['xmax'] = data['xmax']
                
                print('here 3')
                loss = criterion(batch)
                running_test_loss+=loss.item()
                print(running_test_loss)

                
    

train(network = network, optimizer=optimizer, iteration=iteration)



epoch:0

loss: 0


In [36]:
!nvidia-smi


Mon May  8 17:18:52 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090         On | 00000000:1B:00.0 Off |                  N/A |
|  0%   40C    P8               25W / 350W|   2781MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                         