In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
from tqdm import tqdm
import timm  # for ViT

In [None]:
# --- Experiment configuration ------------------------------------------------
BATCH_NUM = 1  # dataset batch index: selects dataset/batch<N>/ and tags the checkpoint file name
BATCH_SIZE = 32  # samples per mini-batch for every DataLoader built below
NUM_EPOCHS = 20  # number of full passes over the training split
LEARNING_RATE = 1e-4  # step size handed to the Adam optimizer
IMG_SIZE = 224  # Resizing the images to 224x224 for ViT input
SEED = 42  # fixed RNG seed so shuffling and weight initialisation are repeatable
DATASET_DIR = f"dataset/batch{BATCH_NUM}/"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global generator once, before any model or DataLoader is
# created, so that a fresh "Restart & Run All" reproduces the same run.
torch.manual_seed(SEED)

In [None]:
class PoseDataset(Dataset):
    """Image / pose-label pairs backed by an image folder and a CSV file.

    The CSV must have an ``image_name`` column holding file names relative to
    ``image_dir``. Every column after the first one is cast to ``float32`` and
    returned as the regression target, so ``image_name`` is expected to be the
    first column.

    Args:
        image_dir: Directory containing the image files.
        label_csv: Path to the CSV file with one row per image.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_dir, label_csv, transform=None):
        self.image_dir = image_dir
        self.labels = pd.read_csv(label_csv)
        self.transform = transform

    def __len__(self):
        """Return the number of labelled samples (rows in the CSV)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        ``image`` is the RGB image after ``transform`` (if one was given) and
        ``label`` is a 1-D ``float32`` tensor built from all columns after the
        first.
        """
        row = self.labels.iloc[idx]
        img_path = os.path.join(self.image_dir, row['image_name'])
        # Image.open is lazy and keeps the file handle open; the context manager
        # releases it deterministically. convert() returns an independent copy,
        # so the image stays usable after the file is closed.
        with Image.open(img_path) as source:
            image = source.convert('RGB')
        # Explicit positional slice: skip the first column (image_name).
        label = torch.tensor(row.iloc[1:].values.astype('float32'))  # x, y, z, pitch, roll, yaw
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
# Preprocessing applied to every image: resize to the square ViT input size,
# convert to a tensor, then shift/scale each channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
def get_dataloader(split):
    """Build the DataLoader for one dataset split.

    Reads ``<DATASET_DIR>/<split>/images`` and ``<DATASET_DIR>/<split>/labels.csv``
    through ``PoseDataset`` using the module-level ``transform``. Batches have
    ``BATCH_SIZE`` samples and are shuffled only when ``split`` is ``'train'``.
    """
    split_root = os.path.join(DATASET_DIR, split)
    split_dataset = PoseDataset(
        image_dir=os.path.join(split_root, 'images'),
        label_csv=os.path.join(split_root, 'labels.csv'),
        transform=transform,
    )
    is_train = split == 'train'
    return DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=is_train)

# Loaders for the two splits used during training (only 'train' is shuffled).
train_loader = get_dataloader(split='train')
val_loader = get_dataloader(split='val')

In [None]:
class ViT6DP(nn.Module):
    """ViT backbone with a small MLP head that regresses a pose vector.

    The timm ``vit_base_patch16_224`` model is created and its ``head`` is
    replaced by ``Linear -> ReLU -> Linear``. With the default arguments the
    module (and therefore its ``state_dict`` layout) is the same as before the
    parameters were added.

    Args:
        pretrained: Whether timm should load pretrained backbone weights.
            Pass ``False`` when a checkpoint is loaded right afterwards, to
            avoid fetching weights that would be overwritten anyway.
        hidden_dim: Width of the hidden layer in the regression head.
        num_outputs: Size of the predicted vector.
    """

    def __init__(self, pretrained=True, hidden_dim=512, num_outputs=6):
        super().__init__()
        self.backbone = timm.create_model('vit_base_patch16_224', pretrained=pretrained)
        # Read the feature width from the stock head before replacing it.
        in_features = self.backbone.head.in_features
        self.backbone.head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_outputs)  # Output: 3 for position (x, y, z), 3 for orientation (pitch, roll, yaw)
        )

    def forward(self, x):
        """Run the backbone (including the replaced head) on a batch ``x``."""
        return self.backbone(x)

In [None]:
# Training objects: the network on the selected device, a mean-squared-error
# objective over the predicted vector, and Adam over all model parameters.
model = ViT6DP()
model = model.to(DEVICE)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
def train():
    """Train the global ``model`` for ``NUM_EPOCHS`` epochs, validating after each.

    Uses the module-level ``model``, ``criterion``, ``optimizer``,
    ``train_loader``, ``val_loader`` and ``DEVICE``. After every epoch the mean
    training loss and the mean validation loss are printed.

    The means are sample-weighted: each batch loss (a per-batch mean, as
    produced by ``nn.MSELoss``'s default reduction) is multiplied by the batch
    size before accumulation, so a shorter final batch is not over-weighted.
    An empty loader yields ``nan`` instead of raising ``ZeroDivisionError``.
    """
    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss_sum, train_seen = 0.0, 0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_len = labels.size(0)
            train_loss_sum += loss.item() * batch_len
            train_seen += batch_len

        train_loss = train_loss_sum / train_seen if train_seen else float('nan')
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss:.4f}")

        # Validation: no gradient tracking, model in eval mode.
        model.eval()
        val_loss_sum, val_seen = 0.0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                loss = criterion(outputs, labels)

                batch_len = labels.size(0)
                val_loss_sum += loss.item() * batch_len
                val_seen += batch_len

        val_loss = val_loss_sum / val_seen if val_seen else float('nan')
        print(f"           Val Loss: {val_loss:.4f}")

In [None]:
# Run the training loop for NUM_EPOCHS epochs; train and validation losses are
# printed after every epoch.
train()

In [None]:
# Persist only the learned parameters (state_dict); the file name carries the
# dataset batch number so checkpoints from different batches do not collide.
torch.save(obj=model.state_dict(), f=f"ViT6DP_batch{BATCH_NUM}.pth")

TESTING

In [None]:
# Loader over the held-out test split (not shuffled, since split != 'train').
test_loader = get_dataloader(split='test')

In [None]:
def test_model(model_path):
    model = ViT6DP().to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    total_loss = 0.0
    preds = []
    gts = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            preds.extend(outputs.cpu().numpy())
            gts.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(test_loader)
    print(f"Test MSE Loss: {avg_loss:.4f}")

    return preds, gts

In [None]:
# Score the checkpoint saved for this dataset batch on the test split.
predictions, ground_truths = test_model(model_path=f"ViT6DP_batch{BATCH_NUM}.pth")

In [None]:
# Rebuild the validation loader (the same call is made earlier in the notebook),
# rebinding val_loader for the evaluation below.
val_loader = get_dataloader(split='val')

In [None]:
def validate_model(model_path, loader=None):
    """Evaluate a saved checkpoint on the validation data.

    A fresh ``ViT6DP`` is built, the ``state_dict`` stored at ``model_path`` is
    loaded into it, and the model is run in eval mode without gradients.

    Args:
        model_path: Path to a ``state_dict`` file written by ``torch.save``.
        loader: Optional iterable of ``(images, labels)`` batches. Defaults to
            the module-level ``val_loader``.

    Returns:
        ``(preds, gts)``: two lists with one NumPy array per sample, holding
        the model outputs and the matching ground-truth labels.

    The printed loss is the sample-weighted mean of the per-batch ``criterion``
    values, so a shorter final batch is not over-weighted; an empty loader
    prints ``nan`` instead of raising ``ZeroDivisionError``.
    """
    loader = val_loader if loader is None else loader

    # NOTE: ViT6DP() builds the backbone with pretrained weights, which the
    # checkpoint loaded on the next line then overwrites.
    model = ViT6DP().to(DEVICE)
    # NOTE(review): torch.load unpickles the file - only load checkpoints from
    # a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    loss_sum, seen = 0.0, 0
    preds = []
    gts = []

    with torch.no_grad():
        for images, labels in tqdm(loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_len = labels.size(0)
            loss_sum += loss.item() * batch_len
            seen += batch_len

            preds.extend(outputs.cpu().numpy())
            gts.extend(labels.cpu().numpy())

    avg_loss = loss_sum / seen if seen else float('nan')
    print(f"Validation MSE Loss: {avg_loss:.4f}")

    return preds, gts


In [None]:
# Score the same saved checkpoint on the validation split.
val_predictions, val_ground_truths = validate_model(model_path=f"ViT6DP_batch{BATCH_NUM}.pth")