In [1]:
# Local utilities
from util import *
environment_check()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CUDA is available
Tensor on GPU: tensor([1., 2., 3.], device='cuda:0')

PyTorch3D is using CUDA


In [2]:
import torch
import torch.nn as nn
import torchvision.models as models

class PoseRefinementNetwork(nn.Module):
    def __init__(self):
        super(PoseRefinementNetwork, self).__init__()
        # Load a pretrained ResNet model
        self.feature_extractor = models.resnet18(pretrained=True)
        
        # Modify the first convolutional layer to accept 1-channel input
        self.feature_extractor.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
        # Remove the final fully connected layer
        self.feature_extractor.fc = nn.Identity()
        
        # Transformer encoder for processing the features
        # Adjust d_model to match the concatenated feature size (512 * 2 = 1024)
        encoder_layer = nn.TransformerEncoderLayer(d_model=1024, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        
        # Output layers for pose updates
        # Adjust input features to 1024 to match the output of the transformer encoder
        self.fc_translation = nn.Linear(1024, 3)  # Adjusted for translation update
        self.fc_rotation = nn.Linear(1024, 4)  # Adjusted for rotation update (quaternion representation)

    def forward(self, rendered_img, real_img_cropped):
        # Extract features from both images
        rendered_features = self.feature_extractor(rendered_img)
        real_features = self.feature_extractor(real_img_cropped)
        
        # Concatenate features and prepare for transformer
        combined_features = torch.cat((rendered_features, real_features), dim=1)
        combined_features = combined_features.unsqueeze(0)  # Add batch dimension for transformer
        
        # Pass through transformer encoder
        transformed_features = self.transformer_encoder(combined_features)
        
        # Predict translation and rotation updates
        translation_update = self.fc_translation(transformed_features.squeeze(0))
        #rotation_update = self.fc_rotation(transformed_features.squeeze(0))
        
        return translation_update, rotation_update
    

import torch
import torch.nn as nn
import torchvision.models as models

class PoseRefinementNetworkWithTransformer(nn.Module):
    def __init__(self):
        super(PoseRefinementNetworkWithTransformer, self).__init__()
        # Load a pretrained ResNet model
        self.feature_extractor = models.resnet18(pretrained=True)
        
        # Modify the first convolutional layer of the ResNet model to accept 1-channel input instead of the default 3
        self.feature_extractor.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
        # Remove the final fully connected layer to use ResNet as a feature extractor
        self.feature_extractor.fc = nn.Identity()
        
        # Reduce the combined feature size from 1024 (512*2) to 512 to match the transformer's expected input size
        self.feature_size_reducer = nn.Linear(1024, 512)
        
        # Define Transformer Encoder Layer with d_model=512
        encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        
        # Output layers for pose updates
        self.fc_translation = nn.Linear(512, 3)
        #self.fc_rotation = nn.Linear(512, 4)

    def forward(self, rendered_img, real_img_cropped):
        # Extract features from both images using the modified ResNet as the feature extractor
        rendered_features = self.feature_extractor(rendered_img).flatten(start_dim=1)
        real_features = self.feature_extractor(real_img_cropped).flatten(start_dim=1)
        
        # Concatenate features from both images
        combined_features = torch.cat((rendered_features, real_features), dim=1)
        
        # Reduce the combined feature size to match the transformer's expected input size
        reduced_features = self.feature_size_reducer(combined_features)
        
        # Add an extra dimension for the transformer
        reduced_features = reduced_features.unsqueeze(1)
        
        # Transformer encoder with residual connection
        # Adding the original reduced features to its transformed version
        transformed_features = self.transformer_encoder(reduced_features)
        # Ensure the original reduced_features is broadcastable to the transformed_features shape
        residual_connection = reduced_features + transformed_features
        
        # Remove the sequence dimension
        final_features = residual_connection.squeeze(1)
        
        # Predict translation and rotation updates
        translation_update = self.fc_translation(final_features)
        # rotation_update = self.fc_rotation(final_features)
        
        return translation_update  #rotation_update

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vit_b_16

class PoseRefinementNetworkWithViT(nn.Module):
    def __init__(self):
        super(PoseRefinementNetworkWithViT, self).__init__()
        
        # Initialize a Vision Transformer model with pretrained weights
        self.feature_extractor = vit_b_16(pretrained=True)
        
        # Remove the classifier head to use the ViT as a feature extractor
        self.feature_extractor.head = nn.Identity()
        
        # Adjust the input dimensions of the linear layers to match the actual feature size
        feature_size = 1000  # Adjusted based on diagnostic output
        
        # Output layers for pose updates
        self.fc_translation = nn.Linear(feature_size, 3)
        #self.fc_rotation = nn.Linear(feature_size, 4)

    def forward(self, rendered_img, real_img_cropped):
        # Process images and extract features as before
        rendered_img = self.prepare_image(rendered_img)
        real_img_cropped = self.prepare_image(real_img_cropped)
        
        rendered_features = self.feature_extractor(rendered_img)
        real_features = self.feature_extractor(real_img_cropped)
        
        # Now, correctly aggregate and predict updates based on adjusted feature size
        translation_update = (self.fc_translation(rendered_features) + self.fc_translation(real_features)) / 2
        #rotation_update = (self.fc_rotation(rendered_features) + self.fc_rotation(real_features)) / 2
        
        return translation_update#, rotation_update
    
    def prepare_image(self, img):
        # Ensure the image has 3 channels by repeating the single channel
        img_3ch = img.repeat(1, 3, 1, 1)
        # Resize the image to match the expected input size of the ViT model (224x224)
        img_resized = F.interpolate(img_3ch, size=(224, 224), mode='bilinear', align_corners=False)
        return img_resized

import torch
import torch.nn as nn
import torchvision.models as models

class PoseRefinementNetworkSimple(nn.Module):
    def __init__(self):
        super(PoseRefinementNetworkSimple, self).__init__()
        # Load a pretrained ResNet model
        self.feature_extractor = models.resnet18(pretrained=True)
        
        # Modify the first convolutional layer to accept 1-channel input
        self.feature_extractor.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
        # Remove the final fully connected layer
        self.feature_extractor.fc = nn.Identity()
        
        # Instead of using a transformer, directly concatenate the features and use a linear layer
        # Assuming that the feature size is 512 for each image from ResNet18 and we concatenate them
        self.fc_combined = nn.Linear(512 * 2, 512)  # Combined feature layer
        
        # Output layers for pose updates, directly from the combined features
        self.fc_translation = nn.Linear(512, 3)  # Adjusted for translation update
        self.fc_rotation = nn.Linear(512, 4)  # Adjusted for rotation update (quaternion representation)

    def forward(self, rendered_img, real_img_cropped):
        # Extract features from both images
        rendered_features = self.feature_extractor(rendered_img).flatten(start_dim=1)
        real_features = self.feature_extractor(real_img_cropped).flatten(start_dim=1)
        
        # Concatenate features
        combined_features = torch.cat((rendered_features, real_features), dim=1)
        
        # Pass concatenated features through a combined feature layer
        combined_features = self.fc_combined(combined_features)
        
        # Predict translation and rotation updates
        translation_update = self.fc_translation(combined_features)
        rotation_update = self.fc_rotation(combined_features)
        
        return translation_update, rotation_update
    
import os
import json
import PIL.Image as Image
import torch 

from torchvision import transforms

class PoseRefinementDataset(torch.utils.data.Dataset):
    def __init__(self, data_json_filepath):
        with open(data_json_filepath) as f:
            self.data_json = json.load(f)
        
        # Define a transform to convert PIL images to tensors
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # Converts PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        ])

    def __len__(self):
        return len(self.data_json)

    def __getitem__(self, idx):
        entry = self.data_json[idx]

        # Load rendered image as grayscale
        rendered_img_path = entry['silhouette_path']  # Directly use the full path
        rendered_img = Image.open(rendered_img_path).convert('L')

        # Convert PIL images to tensors
        rendered_img = self.transform(rendered_img)

        # Ground truth pose 
        ground_truth_rt_delta = torch.tensor(entry['RT']) 

        return rendered_img, ground_truth_rt_delta

    
dataset = PoseRefinementDataset('./pose_refine_dataset/dataset_info.json')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)   
 
model = PoseRefinementNetworkWithViT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 

from torchvision import transforms
# Define the transformation pipeline
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts PIL Image or numpy.ndarray to tensor.
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1] for each channel.
])


 
reference_img_path = "./pose_refine_dataset/silhouettes/first_silhouette_0.png"

# Load reference image (grayscle)
reference_img = Image.open(reference_img_path).convert('L') 
reference_img = reference_img
# Convert the reference image to a tensor and add a batch dimension
reference_img_tensor = transform(reference_img).unsqueeze(0)  # Now shape is [1, 1, H, W]

print(f'Reference image tensor shape: {reference_img.size}')

import kornia.geometry.conversions as conversions
def rotation_matrix_to_quaternion(R):
    """
    Convert a rotation matrix to a quaternion.
    Assumes R is a batch of 3x3 rotation matrices.
    
    Parameters:
    R (torch.Tensor): The input tensor containing batches of 3x3 rotation matrices.

    Returns:
    torch.Tensor: The output tensor containing the corresponding quaternions.
    """
    # Use kornia's function to convert rotation matrix to quaternion
    quaternion = conversions.rotation_matrix_to_quaternion(R)
    
    return quaternion

 




Reference image tensor shape: (256, 256)


In [3]:
# Adjust the path as necessary
#model_path = "./pose_refine_example/pose_refine_translate.pth"

# Load the model weights
#state_dict = torch.load(model_path, map_location=device)

# Apply the weights to your model instance
#model.load_state_dict(state_dict)


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

def safe_acos(x):
    # Clamp input to the range [-1 + epsilon, 1 - epsilon]
    eps = 1e-7
    x = torch.clamp(x, -1 + eps, 1 - eps)
    return torch.acos(x)

def loss_fn(pred_translation,  gt_translation):
    # Translation loss
    translation_loss = nn.MSELoss()(pred_translation, gt_translation)
    
    # Ensure the predicted rotation quaternion is normalized
    # pred_rotation = F.normalize(pred_rotation, p=2, dim=-1)
    
    # Rotation loss using quaternion angular distance approximation
    # dot_product = (pred_rotation * gt_rotation).sum(dim=-1)
    # dot_product = torch.clamp(dot_product, -1.0 + 1e-7, 1.0 - 1e-7)
    # angular_distance = 2 * torch.acos(torch.abs(dot_product))
    # rotation_loss = angular_distance.mean()

    # Combine losses
    total_loss = translation_loss  # rotation_loss
    return total_loss



from torchvision import transforms
# Define the transformation pipeline
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts PIL Image or numpy.ndarray to tensor.
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1] for each channel.
])

num_epochs = 10000

for epoch in range(num_epochs):
    for i, (rendered_img, ground_truth_rt_delta) in enumerate(dataloader):
        rendered_img = rendered_img.to(device)
        ground_truth_rt_delta = ground_truth_rt_delta.to(device)
        
        batch_size = rendered_img.shape[0]
        reference_img_batched = reference_img_tensor.repeat(batch_size, 1, 1, 1)
        
        translation_update = model(rendered_img.to(device), reference_img_batched.to(device))

        # Assuming your model outputs and ground_truth_rt_delta are structured correctly
        translation_vector = ground_truth_rt_delta[:, :3, 3]
        # Example conversion, replace with your actual conversion function
         #quaternion = rotation_matrix_to_quaternion(ground_truth_rt_delta[:, :3, :3])

        loss = loss_fn(translation_update, translation_vector)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {i}, Loss: {loss.item():.4f}')

# After training completes
model_save_path = './pose_refine_example/pose_refine_translate.pth'
torch.save(model.state_dict(), model_save_path)

Epoch 1, Batch 0, Loss: 0.4129
Epoch 2, Batch 0, Loss: 2.0261
Epoch 3, Batch 0, Loss: 1.8677
Epoch 4, Batch 0, Loss: 1.4348
Epoch 5, Batch 0, Loss: 0.3224
Epoch 6, Batch 0, Loss: 1.0744
Epoch 7, Batch 0, Loss: 0.5163
Epoch 8, Batch 0, Loss: 0.9598
Epoch 9, Batch 0, Loss: 0.8229
Epoch 10, Batch 0, Loss: 3.5334
Epoch 11, Batch 0, Loss: 0.7338
Epoch 12, Batch 0, Loss: 0.4635
Epoch 13, Batch 0, Loss: 0.4346
Epoch 14, Batch 0, Loss: 0.6101
Epoch 15, Batch 0, Loss: 1.1971
Epoch 16, Batch 0, Loss: 0.5585
Epoch 17, Batch 0, Loss: 0.3604
Epoch 18, Batch 0, Loss: 0.5580
Epoch 19, Batch 0, Loss: 0.6496
Epoch 20, Batch 0, Loss: 0.3987
Epoch 21, Batch 0, Loss: 0.5975
Epoch 22, Batch 0, Loss: 0.6741
Epoch 23, Batch 0, Loss: 0.3838
Epoch 24, Batch 0, Loss: 1.0974
Epoch 25, Batch 0, Loss: 0.9105
Epoch 26, Batch 0, Loss: 0.6273
Epoch 27, Batch 0, Loss: 0.7929
Epoch 28, Batch 0, Loss: 0.8985
Epoch 29, Batch 0, Loss: 0.7687
Epoch 30, Batch 0, Loss: 0.7529
Epoch 31, Batch 0, Loss: 0.7260
Epoch 32, Batch 0

KeyboardInterrupt: 