In [29]:

# VIT6D.ipynb refactored for Minty (Linux Mint)
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
from tqdm import tqdm
import timm
import time

In [45]:

# Constants and environment setup
BATCH_NUM = 3
BATCH_SIZE = 32
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4
IMG_SIZE = 224  # Required input size for ViT

# Paths assume Linux-style forward slashes
BASE_DIR = os.path.expanduser("~/SKRIPSI/SCRIPTS")  # Refactor path to be absolute
DATASET_DIR = os.path.join(BASE_DIR, f"dataset/batch{BATCH_NUM}")
MODEL_SAVE_PATH = os.path.join(BASE_DIR, f"ViT6DP_batch{BATCH_NUM}.pth")

# Use CUDA if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [46]:
class PoseDataset(Dataset):
    def __init__(self, image_dir, label_csv, transform=None):
        self.image_dir = image_dir
        self.labels = pd.read_csv(label_csv)
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        row = self.labels.iloc[idx]
        img_path = os.path.join(self.image_dir, row['image_name'])
        image = Image.open(img_path).convert('RGB')
        label = torch.tensor(row[1:].values.astype('float32'))  # x, y, z, pitch, roll, yaw
        if self.transform:
            image = self.transform(image)
        return image, label

In [47]:
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3)
])

In [48]:
def get_dataloader(split):
    image_dir = os.path.join(DATASET_DIR, split, 'images')
    label_csv = os.path.join(DATASET_DIR, split, 'labels.csv')
    dataset = PoseDataset(image_dir, label_csv, transform)
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=(split == 'train'))

train_loader = get_dataloader('train')
val_loader = get_dataloader('val')

In [49]:
class ViT6DP(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model('vit_base_patch16_224', pretrained=True)
        self.backbone.head = nn.Sequential(
            nn.Linear(self.backbone.head.in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 6)
        )

    def forward(self, x):
        return self.backbone(x)

In [50]:
def normalize_pose(pose):
    translation = pose[:, :3]                     # [x, y, z]
    rotation = pose[:, 3:] / 360.0                # Normalize rotation degrees to [0, 1] divide by 1 rotation / 2π / 360deg
    return torch.cat([translation, rotation], dim=1)

In [51]:
def compute_rmse(pred, target):
    pred_trans, pred_rot = pred[:, :3], pred[:, 3:]
    target_trans, target_rot = target[:, :3], target[:, 3:]

    trans_rmse = torch.sqrt(nn.MSELoss()(pred_trans, target_trans))
    rot_rmse = torch.sqrt(nn.MSELoss()(pred_rot / 360.0, target_rot / 360.0))

    return trans_rmse.item(), rot_rmse.item()

In [52]:
def combined_loss(pred, target, alpha=1.0, beta=1.0):
    pred_trans, pred_rot = pred[:, :3], pred[:, 3:]
    target_trans, target_rot = target[:, :3], target[:, 3:]

    # Optionally normalize rotation
    pred_rot = pred_rot / 360.0
    target_rot = target_rot / 360.0

    trans_loss = nn.MSELoss()(pred_trans, target_trans)
    rot_loss = nn.MSELoss()(pred_rot, target_rot)
    
    return alpha * trans_loss + beta * rot_loss


In [53]:
model = ViT6DP().to(DEVICE)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
def train(validate=True):
    now = []
    now.append(time.time())
    for epoch in range(NUM_EPOCHS):
        print("\n\n")
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader):
            
            images, labels = images.to(DEVICE), normalize_pose(labels.to(DEVICE))
            outputs = normalize_pose(model(images))

            loss = combined_loss(outputs, labels, alpha=1.0, beta=1.0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        
        now.append(time.time())
        avg_train_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {avg_train_loss:.4f}")
        print(f"Time took per epoch : {epoch+1}: {int(now[epoch+1]-now[epoch])}s")
        if validate:
            # Validation
            model.eval()
            val_loss = 0.0
            total_trans_rmse, total_rot_rmse = 0.0, 0.0

            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(DEVICE), normalize_pose(labels.to(DEVICE))
                    outputs = normalize_pose(model(images))
                    
                    loss = combined_loss(outputs, labels)
                    val_loss += loss.item()

                    trans_rmse, rot_rmse = compute_rmse(outputs, labels)
                    total_trans_rmse += trans_rmse
                    total_rot_rmse += rot_rmse

            avg_val_loss = val_loss / len(val_loader)
            print(f"Val Loss: {avg_val_loss:.4f}")
            print(f"RMSE - Translation: {total_trans_rmse / len(val_loader):.4f}, "
                  f"Rotation: {total_rot_rmse / len(val_loader):.4f}")
        else:
            print("Skipping validation for this epoch.")

In [56]:
train(validate=True)

100%|██████████| 30/30 [00:19<00:00,  1.56it/s]


Epoch 1/20, Train Loss: 0.0031
Time took per epoch : 1: 19s

           Val Loss: 0.0027
           RMSE - Translation: 0.0522, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.56it/s]


Epoch 2/20, Train Loss: 0.0028
Time took per epoch : 2: 20s

           Val Loss: 0.0041
           RMSE - Translation: 0.0641, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.56it/s]


Epoch 3/20, Train Loss: 0.0025
Time took per epoch : 3: 20s

           Val Loss: 0.0019
           RMSE - Translation: 0.0436, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 4/20, Train Loss: 0.0021
Time took per epoch : 4: 20s

           Val Loss: 0.0021
           RMSE - Translation: 0.0459, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.56it/s]


Epoch 5/20, Train Loss: 0.0017
Time took per epoch : 5: 20s

           Val Loss: 0.0016
           RMSE - Translation: 0.0398, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 6/20, Train Loss: 0.0015
Time took per epoch : 6: 20s

           Val Loss: 0.0015
           RMSE - Translation: 0.0391, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.56it/s]


Epoch 7/20, Train Loss: 0.0015
Time took per epoch : 7: 20s

           Val Loss: 0.0014
           RMSE - Translation: 0.0375, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.56it/s]


Epoch 8/20, Train Loss: 0.0013
Time took per epoch : 8: 20s

           Val Loss: 0.0015
           RMSE - Translation: 0.0384, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 9/20, Train Loss: 0.0012
Time took per epoch : 9: 20s

           Val Loss: 0.0014
           RMSE - Translation: 0.0371, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 10/20, Train Loss: 0.0011
Time took per epoch : 10: 21s

           Val Loss: 0.0017
           RMSE - Translation: 0.0415, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 11/20, Train Loss: 0.0012
Time took per epoch : 11: 21s

           Val Loss: 0.0013
           RMSE - Translation: 0.0356, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.54it/s]


Epoch 12/20, Train Loss: 0.0009
Time took per epoch : 12: 21s

           Val Loss: 0.0011
           RMSE - Translation: 0.0329, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 13/20, Train Loss: 0.0009
Time took per epoch : 13: 20s

           Val Loss: 0.0017
           RMSE - Translation: 0.0413, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 14/20, Train Loss: 0.0012
Time took per epoch : 14: 20s

           Val Loss: 0.0011
           RMSE - Translation: 0.0337, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 15/20, Train Loss: 0.0007
Time took per epoch : 15: 21s

           Val Loss: 0.0015
           RMSE - Translation: 0.0381, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 16/20, Train Loss: 0.0007
Time took per epoch : 16: 21s

           Val Loss: 0.0012
           RMSE - Translation: 0.0351, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 17/20, Train Loss: 0.0006
Time took per epoch : 17: 21s

           Val Loss: 0.0012
           RMSE - Translation: 0.0339, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 18/20, Train Loss: 0.0005
Time took per epoch : 18: 21s

           Val Loss: 0.0015
           RMSE - Translation: 0.0381, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.54it/s]


Epoch 19/20, Train Loss: 0.0005
Time took per epoch : 19: 21s

           Val Loss: 0.0013
           RMSE - Translation: 0.0353, Rotation: 0.0000


100%|██████████| 30/30 [00:19<00:00,  1.55it/s]


Epoch 20/20, Train Loss: 0.0004
Time took per epoch : 20: 21s

           Val Loss: 0.0012
           RMSE - Translation: 0.0345, Rotation: 0.0000


In [57]:
torch.save(model.state_dict(), MODEL_SAVE_PATH)

In [58]:
test_loader = get_dataloader('test')

In [59]:
def test_model(model_path):
    model = ViT6DP().to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    total_loss = 0.0
    total_trans_rmse, total_rot_rmse = 0.0, 0.0
    preds = []
    gts = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images, labels = images.to(DEVICE), normalize_pose(labels.to(DEVICE))
            outputs = normalize_pose(model(images))

            loss = combined_loss(outputs, labels)
            total_loss += loss.item()

            trans_rmse, rot_rmse = compute_rmse(outputs, labels)
            total_trans_rmse += trans_rmse
            total_rot_rmse += rot_rmse

            preds.extend(outputs.cpu().numpy())
            gts.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(test_loader)
    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Test RMSE - Translation: {total_trans_rmse / len(test_loader):.4f}, "
          f"Rotation: {total_rot_rmse / len(test_loader):.4f}")

    return preds, gts


In [60]:
predictions, ground_truths = test_model(MODEL_SAVE_PATH)

100%|██████████| 9/9 [00:03<00:00,  2.69it/s]


Test Loss: 0.0010
Test RMSE - Translation: 0.0316, Rotation: 0.0000


In [61]:
def validate_model(model_path):
    model = ViT6DP().to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    total_loss = 0.0
    total_trans_rmse, total_rot_rmse = 0.0, 0.0
    preds = []
    gts = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images, labels = images.to(DEVICE), normalize_pose(labels.to(DEVICE))
            outputs = normalize_pose(model(images))

            loss = combined_loss(outputs, labels)
            total_loss += loss.item()

            trans_rmse, rot_rmse = compute_rmse(outputs, labels)
            total_trans_rmse += trans_rmse
            total_rot_rmse += rot_rmse

            preds.extend(outputs.cpu().numpy())
            gts.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(val_loader)
    print(f"Validation Loss: {avg_loss:.4f}")
    print(f"Validation RMSE - Translation: {total_trans_rmse / len(val_loader):.4f}, "
          f"Rotation: {total_rot_rmse / len(val_loader):.4f}")

    return preds, gts


In [62]:
val_predictions, val_ground_truths = validate_model(MODEL_SAVE_PATH)

100%|██████████| 5/5 [00:01<00:00,  3.02it/s]

Validation Loss: 0.0012
Validation RMSE - Translation: 0.0345, Rotation: 0.0000



