In [1]:
# Local utilities.
# NOTE(review): `torch` is used below before any explicit `import torch` in this
# notebook, so it presumably arrives through this wildcard import, together with
# `environment_check` — confirm against util.py.
from util import *

environment_check()
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

CUDA is available
Tensor on GPU: tensor([1., 2., 3.], device='cuda:0')

PyTorch3D is using CUDA


In [24]:
import torch
import json
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torch.optim import Adam
import logging

# Configure the root logger: timestamp, level name, then the message.
# logging.basicConfig() does nothing when the root logger already has handlers,
# so re-running this cell in a live kernel will not change an earlier setup.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

class PoseRefinementDataset(torch.utils.data.Dataset):
    """Dataset pairing silhouette images with 3x4 RT (pose) matrices.

    The JSON file is parsed once at construction time. Each entry is read
    through the keys ``'silhouette_path'`` (image file to open) and ``'RT'``
    (matrix values with either 12 or 16 elements).
    """

    def __init__(self, data_json_filepath):
        """Load the sample index and build the image preprocessing pipeline.

        Args:
            data_json_filepath: Path of the JSON file describing the samples.
        """
        # Explicit UTF-8 so parsing does not depend on the platform's default
        # text encoding.
        with open(data_json_filepath, 'r', encoding='utf-8') as f:
            self.data_json = json.load(f)
        # Resize to 224x224, convert to a tensor, then normalise per channel.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        logging.info("Dataset loaded with {} samples".format(len(self.data_json)))

    def __len__(self):
        """Return the number of entries loaded from the JSON file."""
        return len(self.data_json)

    def __getitem__(self, idx):
        """Return ``(image, rt_matrix)`` for the sample at ``idx``.

        Returns:
            image: The silhouette after ``self.transform`` has been applied.
            rt_matrix: ``torch.float32`` tensor of shape (3, 4).

        Raises:
            RuntimeError: From ``Tensor.view`` when ``'RT'`` holds neither 12
                nor 16 values.
        """
        entry = self.data_json[idx]
        # Image.open keeps the underlying file open; the context manager closes
        # it as soon as the converted image has been produced, instead of
        # leaving the handle to the garbage collector.
        with Image.open(entry['silhouette_path']) as source:
            # Collapse to a single grey channel, then expand back to three
            # channels as the transform's 3-channel normalisation expects.
            image = source.convert('L').convert('RGB')
        image = self.transform(image)

        rt_matrix = torch.tensor(entry['RT'], dtype=torch.float32)
        if rt_matrix.numel() == 16:
            # 4x4 input: the last row is assumed to be [0, 0, 0, 1] and dropped.
            rt_matrix = rt_matrix.view(4, 4)[:-1]
        else:
            # Normal case: the values already describe a 3x4 matrix.
            rt_matrix = rt_matrix.view(3, 4)
        logging.debug(f"Image and RT matrix loaded for index {idx}")
        return image, rt_matrix

class ViTImageToRTNetwork(nn.Module):
    """ViT-B/16 backbone followed by two linear heads that output a 3x4 RT matrix."""

    def __init__(self):
        """Build the backbone with its default weights and attach the regressors."""
        super().__init__()
        backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        # Swap the classification head for a pass-through so the backbone
        # returns its features instead of class scores.
        backbone.heads = nn.Identity()
        self.vit = backbone
        # Both heads read the same 768-wide feature vector.
        self.rotation_regressor = nn.Linear(768, 9)     # reshaped to 3x3 in forward()
        self.translation_regressor = nn.Linear(768, 3)  # reshaped to 3x1 in forward()
        logging.info("ViT model initialized with removed classifier head and regressors added")

    def forward(self, x):
        """Map a batch of inputs to a batch of RT matrices of shape (batch, 3, 4).

        The rotation head fills the first three columns and the translation
        head fills the last one.
        """
        embedding = self.vit(x)
        rot = self.rotation_regressor(embedding).view(-1, 3, 3)
        trans = self.translation_regressor(embedding).view(-1, 3, 1)
        # Append the translation column to the right of the rotation block.
        pose = torch.cat([rot, trans], dim=-1)
        logging.debug("Model forward pass completed")
        return pose

def train_model(dataset, epochs=999, batch_size=32, learning_rate=0.001):
    """Train a ``ViTImageToRTNetwork`` on ``dataset`` and return the trained model.

    Args:
        dataset: Dataset yielding ``(image, rt_matrix)`` pairs.
        epochs: Number of passes over the dataset.
        batch_size: Number of samples per optimisation step.
        learning_rate: Learning rate passed to Adam.

    Returns:
        The model after training, on the device used for training (CUDA when
        available, otherwise CPU). Call ``.eval()`` on it before inference.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Batch count is constant across epochs; compute it once and fail early
    # rather than dividing by zero when the average loss is reported.
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train_model needs a non-empty dataset")

    model = ViTImageToRTNetwork().to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for images, rt_matrices in loader:
            images, rt_matrices = images.to(device), rt_matrices.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # Element-wise MSE between predicted and target RT matrices.
            loss = criterion(outputs, rt_matrices)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        logging.info(f"Epoch {epoch + 1}: Avg Loss = {total_loss / num_batches}")

    # Hand the trained network back; previously it was dropped when the
    # function returned, so the result of training could not be used.
    return model

# Example usage: load the sample index from disk, then train in batches of 4.
dataset = PoseRefinementDataset(
    data_json_filepath="./pose_refine_dataset/dataset_info.json",
)
train_model(dataset=dataset, epochs=999, batch_size=4)


2024-04-13 02:13:19,049 - INFO - Dataset loaded with 12 samples
2024-04-13 02:13:19,713 - INFO - ViT model initialized with removed classifier head and regressors added
2024-04-13 02:13:20,427 - INFO - Epoch 1: Avg Loss = 1.4270548621813457
2024-04-13 02:13:20,668 - INFO - Epoch 2: Avg Loss = 0.8098077178001404
2024-04-13 02:13:20,902 - INFO - Epoch 3: Avg Loss = 0.45564375321070355
2024-04-13 02:13:21,146 - INFO - Epoch 4: Avg Loss = 0.43488171696662903
2024-04-13 02:13:21,385 - INFO - Epoch 5: Avg Loss = 0.35984830061594647
2024-04-13 02:13:21,621 - INFO - Epoch 6: Avg Loss = 0.38802435000737506
2024-04-13 02:13:21,861 - INFO - Epoch 7: Avg Loss = 0.3601399064064026
2024-04-13 02:13:22,098 - INFO - Epoch 8: Avg Loss = 0.37292688091595966
2024-04-13 02:13:22,331 - INFO - Epoch 9: Avg Loss = 0.339798907438914
2024-04-13 02:13:22,566 - INFO - Epoch 10: Avg Loss = 0.3555435339609782
2024-04-13 02:13:22,801 - INFO - Epoch 11: Avg Loss = 0.34649600585301715
2024-04-13 02:13:23,043 - INFO -

KeyboardInterrupt: 