In [2]:
# importing dependencies
import torchvision.models as models
import torch.nn as nn
from timm import create_model
import torch
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from PIL import Image
import cv2
from pathlib import Path
import matplotlib.pyplot as plt

In [3]:
# Define the model's components: the ViT backbone, the transformer encoder, and the heatmap decoder

class ViTBackbone(nn.Module):
    """Pretrained timm ViT that returns its patch tokens as a 2-D feature map.

    The patch size (16) and input resolution (224) are fixed here, so the
    output grid is always ``grid_size x grid_size`` = 14 x 14 regardless of
    ``model_name``; a model with a different patch size or input size will
    not reshape correctly.
    """

    def __init__(self, model_name="vit_base_patch16_224"):
        super().__init__()
        # num_classes=0 and global_pool='' are passed so the model is built
        # without a classification head or pooling.
        self.vit = create_model(model_name, pretrained=True, num_classes=0, global_pool='')
        patch = 16  # each token corresponds to a 16x16 pixel patch
        self.patch_size = patch
        self.grid_size = 224 // patch  # 14 tokens per side

    def forward(self, x):
        """Return features of shape (B, C, grid_size, grid_size)."""
        tokens = self.vit.forward_features(x)
        # Drop the first token (treated as the class token); what remains is
        # reshaped below as a grid_size * grid_size sequence of patch tokens.
        patch_tokens = tokens[:, 1:, :]
        batch, _, channels = patch_tokens.shape
        side = self.grid_size
        # (B, N, C) -> (B, C, N) -> (B, C, side, side)
        return patch_tokens.transpose(1, 2).reshape(batch, channels, side, side)


# Transformer Encoder
class Transformer(nn.Module):
    """Stack of self-attention encoder layers applied to a spatial feature map.

    Args:
        feature_dim: channel count of the incoming feature map (used as d_model).
        num_heads: number of self-attention heads per layer.
        num_layers: number of stacked encoder layers.
    """

    def __init__(self, feature_dim, num_heads=8, num_layers=6):
        super().__init__()
        # A single layer serves as the template that TransformerEncoder replicates.
        layer_template = TransformerEncoderLayer(d_model=feature_dim, nhead=num_heads)
        self.transformer_encoder = TransformerEncoder(layer_template, num_layers=num_layers)

    def forward(self, x):
        """Encode a (B, C, H, W) map and return a tensor of the same shape."""
        batch, channels, height, width = x.shape
        # The encoder is built without batch_first, so it expects
        # (seq_len, batch, feature_dim): flatten H*W into the sequence axis.
        sequence = x.flatten(2).permute(2, 0, 1)
        encoded = self.transformer_encoder(sequence)
        # Undo the permutation so the result is spatially structured again.
        return encoded.permute(1, 2, 0).reshape(batch, channels, height, width)

# Heatmap Decoder (Upsamples to 224x224)
class HeatmapDecoder(nn.Module):
    """Upsample a feature map 16x and project it to one heatmap per joint.

    Four stride-2 transposed convolutions each double the spatial size
    (14 -> 28 -> 56 -> 112 -> 224 for a 14x14 input), followed by a 1x1
    convolution that maps the 32 remaining channels to ``num_joints``.
    """

    def __init__(self, feature_dim, num_joints):
        super().__init__()
        # kernel_size=4, stride=2, padding=1 doubles height and width exactly.
        self.upsample1 = nn.ConvTranspose2d(feature_dim, 256, kernel_size=4, stride=2, padding=1)  # 14 -> 28
        self.upsample2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)  # 28 -> 56
        self.upsample3 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # 56 -> 112
        self.upsample4 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)  # 112 -> 224
        self.conv_map = nn.Conv2d(32, num_joints, kernel_size=1)  # 1x1 conv: one output channel per joint

    def forward(self, x):
        """Return raw (pre-activation) heatmaps of shape (B, num_joints, 16*H, 16*W)."""
        stages = (self.upsample1, self.upsample2, self.upsample3, self.upsample4)
        for stage in stages:
            # ReLU after every upsampling stage.
            x = F.relu(stage(x))
        return self.conv_map(x)

# Instantiate the three components of the full model (constructed in this order).
# feature_dim=768 is the channel count passed to both the encoder and the decoder;
# num_joints=13 is the number of heatmaps the decoder produces.
backbone, transformer, decoder = (
    ViTBackbone(),
    Transformer(feature_dim=768),
    HeatmapDecoder(feature_dim=768, num_joints=13),
)

class closedPoseTransformer(nn.Module):
    """End-to-end pose model: backbone -> transformer encoder -> heatmap decoder.

    The three sub-modules are supplied by the caller and chained in
    ``forward``; the decoder output is squashed to (0, 1) with a sigmoid.
    """

    def __init__(self, backbone, transformer, decoder):
        super().__init__()
        self.backbone = backbone
        self.transformer = transformer
        self.decoder = decoder

    def forward(self, x):
        """Return sigmoid-activated heatmaps for the input batch ``x``."""
        feats = self.backbone(x)          # (1) feature extraction
        refined = self.transformer(feats)  # (2) encoder refines the features
        logits = self.decoder(refined)     # (3) upsampled into per-joint maps
        return torch.sigmoid(logits)

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [4]:
def generate_heatmaps_batch(keypoints, H=224, W=224, sigma=4):
    """
    Build one Gaussian heatmap per keypoint for a whole batch at once.

    keypoints: tensor of shape [B, num_joints, 2] holding (x, y) in pixel space
    H, W: height and width of each generated heatmap
    sigma: standard deviation of the Gaussian, in pixels
    Returns: tensor of shape [B, num_joints, H, W] with peak value 1 at the keypoint.
    A keypoint whose coordinates are exactly (0, 0) is treated as missing and
    gets an all-zero heatmap.
    """
    n_batch, n_joints, _ = keypoints.shape
    dev = keypoints.device

    # Pixel-centre coordinates along each axis, shaped for broadcasting.
    cols = torch.linspace(0, W - 1, W, device=dev).view(1, 1, 1, W)
    rows = torch.linspace(0, H - 1, H, device=dev).view(1, 1, H, 1)

    # Keypoint coordinates as [B, J, 1, 1] so they broadcast over the grid.
    centre_x = keypoints[:, :, 0].reshape(n_batch, n_joints, 1, 1)
    centre_y = keypoints[:, :, 1].reshape(n_batch, n_joints, 1, 1)

    # Squared Euclidean distance of every pixel to its keypoint: [B, J, H, W].
    sq_dist = (cols - centre_x) ** 2 + (rows - centre_y) ** 2
    gaussians = torch.exp(-sq_dist / (2 * sigma**2))

    # Zero the maps of missing keypoints (both coordinates equal to 0).
    is_missing = keypoints.abs().sum(dim=-1) == 0
    keep = (~is_missing).float().unsqueeze(-1).unsqueeze(-1)
    return gaussians * keep



# Dataset, loss, training and evaluation functions
class KeypointDataset(Dataset):
    """Image / keypoint dataset backed by a folder of ``.jpg`` files and ``.txt`` labels.

    Only images that have a label file with the same stem in ``labels_dir``
    are kept. Paths are sorted so that ``max_images`` and any later
    ``random_split`` select the same samples on every machine (``glob`` order
    is otherwise filesystem-dependent).

    Args:
        images_dir: directory containing the ``*.jpg`` images.
        labels_dir: directory containing one ``<stem>.txt`` label per image.
        transform: optional callable applied to the PIL image; when omitted
            the RGB image is converted to a float CHW tensor in [0, 1].
        max_images: optional cap on the number of samples (taken after sorting).
        image_size: factor the label coordinates are multiplied by to obtain
            pixel coordinates (224 by default).
            NOTE(review): this presumes labels are normalised to [0, 1] and
            that ``transform`` resizes images to image_size x image_size - confirm.
    """

    def __init__(self, images_dir, labels_dir, transform=None, max_images=None, image_size=224):
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)
        self.image_paths = sorted(
            p for p in self.images_dir.glob("*.jpg")
            if (self.labels_dir / f"{p.stem}.txt").exists()
        )
        if max_images is not None:
            self.image_paths = self.image_paths[:max_images]
        self.transform = transform
        self.image_size = image_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, keypoints)`` where keypoints has shape [num_joints, 2].

        Raises:
            RuntimeError: if the image file cannot be decoded.
        """
        image_path = self.image_paths[idx]
        img = cv2.imread(str(image_path))
        if img is None:
            raise RuntimeError(f"Failed to load image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            image_tensor = self.transform(Image.fromarray(img))
        else:
            # Without a transform the original code returned an unbound name;
            # fall back to a plain HWC uint8 -> CHW float [0, 1] conversion.
            image_tensor = torch.from_numpy(img).permute(2, 0, 1).float().div(255.0)

        label_path = self.labels_dir / (image_path.stem + ".txt")
        # Only the first line of the label file is used.
        with open(label_path, 'r') as f:
            line = f.readline().strip()
        # The first token is skipped and the rest is read as triples, of which
        # only the first two values (x, y) are kept.
        # NOTE(review): presumably "<class> x y v x y v ..." - confirm format.
        kp_values = line.split()[1:]
        kp_array = np.array(kp_values, dtype=float).reshape(-1, 3)[:, 0:2]
        kp_array *= self.image_size
        label_tensor = torch.tensor(kp_array, dtype=torch.float32)
        return image_tensor, label_tensor

# Weighted MSE loss (foreground / background pixel weighting)
def weighted_mse_loss(pred, target, alpha=55.0, beta=5.0, threshold=0.1):
    """
    Mean squared error with a per-pixel weight chosen from the target value.

    Pixels whose ground-truth value exceeds 'threshold' count as foreground
    and are weighted by alpha; all remaining pixels are weighted by beta.
    Returns the mean of the weighted squared errors as a scalar tensor.
    """
    is_foreground = target > threshold
    fg_weight = torch.tensor(alpha, device=target.device)
    bg_weight = torch.tensor(beta, device=target.device)
    # Weight map with the same shape as target.
    pixel_weights = torch.where(is_foreground, fg_weight, bg_weight)
    squared_error = (pred - target) ** 2
    return (pixel_weights * squared_error).mean()

# Training Function
def _run_epoch(model, loader, device, size, optimizer=None, sigma=4):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_pck)`` per sample.

    When ``optimizer`` is given the model is put in train mode and updated
    after every batch; otherwise it is put in eval mode and gradients are
    disabled. Both metrics are weighted by batch size so a short final batch
    does not skew the epoch average.

    Raises:
        ValueError: if the loader yields no samples.
    """
    H, W = size
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    pck_sum = 0.0
    n_samples = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)  # labels: [B, J, 2]
            if training:
                optimizer.zero_grad()

            heatmap_outputs = model(images)
            heatmap_outputs = F.interpolate(heatmap_outputs, size=(H, W), mode='bilinear', align_corners=False)
            # Ground-truth heatmaps for the whole batch: [B, J, H, W]
            gt_heatmaps = generate_heatmaps_batch(labels, H=H, W=W, sigma=sigma)

            loss = weighted_mse_loss(heatmap_outputs, gt_heatmaps)
            if training:
                loss.backward()
                optimizer.step()

            bsize = images.size(0)
            loss_sum += loss.item() * bsize
            # PCK is a metric only; detach so it never touches the autograd graph.
            pck_sum += compute_pck(heatmap_outputs.detach(), labels) * bsize
            n_samples += bsize

    if n_samples == 0:
        raise ValueError("data loader yielded no samples")
    return loss_sum / n_samples, pck_sum / n_samples


def train(model, train_loader, val_loader, learning_rate=0.00001, num_epochs=150, device='cuda'):
    """Train ``model`` on heatmap regression and report loss / PCK every epoch.

    Uses Adam (weight_decay=1e-4) with a ReduceLROnPlateau scheduler driven by
    the validation loss. Every 10th epoch the first validation batch is
    plotted with ``plot_heatmaps``.

    Args:
        model: network mapping an image batch to per-joint heatmaps.
        train_loader, val_loader: loaders yielding ``(images, keypoints)``.
        learning_rate: initial Adam learning rate.
        num_epochs: number of passes over ``train_loader``.
        device: device the batches are moved to (the model must already be on it).

    Returns:
        The trained model (left in eval mode).

    Raises:
        ValueError: if either loader yields no samples.
    """
    torch.manual_seed(420)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    H, W = 224, 224

    for epoch in range(num_epochs):
        avg_train_loss, avg_train_pck = _run_epoch(model, train_loader, device, (H, W), optimizer=optimizer)
        avg_val_loss, avg_val_pck = _run_epoch(model, val_loader, device, (H, W))
        scheduler.step(avg_val_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {avg_train_loss:.4f}, Train PCK: {avg_train_pck*100:.2f}% | "
              f"Val Loss: {avg_val_loss:.4f}, Val PCK: {avg_val_pck*100:.2f}%")

        if (epoch + 1) % 10 == 0:
            # The model is still in eval mode from the validation pass above.
            with torch.no_grad():
                images, labels = next(iter(val_loader))
                images = images.to(device)
                pred_heatmaps = model(images)
                pred_heatmaps = F.interpolate(pred_heatmaps, size=(H, W), mode='bilinear', align_corners=False)
                gt_heatmaps = generate_heatmaps_batch(labels.to(device), H=H, W=W, sigma=4)
                plot_heatmaps(images, gt_heatmaps, pred_heatmaps)

    return model

def evaluate(model, data_loader, device='cuda', threshold=20):
    """Compute the sample-weighted mean loss and PCK of ``model`` over a loader.

    Args:
        model: network mapping an image batch to per-joint heatmaps.
        data_loader: loader yielding ``(images, keypoints)`` batches.
        device: device the batches are moved to.
        threshold: pixel distance forwarded to ``compute_pck``.

    Returns:
        ``(avg_loss, avg_pck)``; the same values are also printed, labelled
        "Validation" regardless of which split the loader holds.
    """
    model.eval()
    height, width = 224, 224
    loss_sum, pck_sum, n_seen = 0.0, 0.0, 0

    with torch.no_grad():
        for batch_inputs, batch_labels in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            n_batch = batch_inputs.size(0)

            preds = F.interpolate(model(batch_inputs), size=(height, width),
                                  mode='bilinear', align_corners=False)
            targets = generate_heatmaps_batch(batch_labels, H=height, W=width, sigma=4)

            # Both metrics are batch means, so weight them by batch size.
            loss_sum += weighted_mse_loss(preds, targets).item() * n_batch
            pck_sum += compute_pck(preds, batch_labels, threshold) * n_batch
            n_seen += n_batch

    avg_loss = loss_sum / n_seen
    avg_pck = pck_sum / n_seen
    print(f'Validation Loss: {avg_loss:.4f}, PCK: {avg_pck*100:.2f}%')
    return avg_loss, avg_pck

# PCK Evaluation
def compute_pck(outputs, labels, threshold=20):
    """
    Compute the batch-mean PCK, counting only keypoints whose label is not [0, 0].

    outputs: predicted heatmaps, decoded to coordinates with hard_argmax (shape: [B, J, H, W])
    labels: ground truth keypoints as (x, y) pixels (shape: [B, J, 2], or anything reshapeable to it)
    threshold: a keypoint is correct when its prediction is strictly closer than this many pixels
    Returns: Python float, the mean over samples of (correct valid keypoints / valid keypoints).
    A sample with no valid keypoints contributes 0 to that mean.
    """
    coords = hard_argmax(outputs)  # shape: [B, J, 2]
    batch_size = coords.shape[0]
    labels = labels.reshape(batch_size, -1, 2)  # [B, J, 2]

    # A keypoint is valid unless both of its coordinates are exactly 0.
    valid_mask = ~((labels == 0).all(dim=2))  # shape: [B, J]

    # Euclidean distance between prediction and label per keypoint.
    distances = torch.norm(coords - labels, dim=2)  # shape: [B, J]

    # Correct = within threshold AND valid.
    correct = (distances < threshold).float() * valid_mask.float()  # shape: [B, J]

    valid_counts = valid_mask.sum(dim=1).float()  # shape: [B]

    # Vectorised per-sample ratio; clamp only guards the division for samples
    # without valid keypoints, whose result is replaced by 0 below. This avoids
    # a Python loop that would synchronise with the device once per sample.
    ratios = correct.sum(dim=1) / valid_counts.clamp(min=1.0)
    per_sample_pck = torch.where(valid_counts > 0, ratios, torch.zeros_like(ratios))
    return per_sample_pck.mean().item()

import numpy as np
def hard_argmax(heatmaps):
    """Return the (x, y) location of the maximum of every heatmap.

    heatmaps: tensor of shape [B, J, H, W] (does not need to be contiguous)
    Returns: float tensor of shape [B, J, 2] holding (x, y) = (column, row)
    of the largest value in each [H, W] map.
    """
    B, J, H, W = heatmaps.shape
    # flatten copies only when needed, so permuted / channels-last inputs work
    # too (a plain .view would raise on them). Result: [B, J, H*W]
    flat = heatmaps.flatten(2)
    # Index of the maximum within each flattened map.
    indices = flat.argmax(dim=-1)  # shape [B, J]
    # Row-major layout: column is the remainder, row the quotient.
    x = indices % W
    y = indices // W
    coords = torch.stack((x.float(), y.float()), dim=-1)
    return coords
def plot_heatmaps(images, gt_heatmaps, pred_heatmaps, num_samples=2):
    """Show input image, ground-truth heatmap and predicted heatmap side by side.

    One figure is drawn per sample, for at most ``num_samples`` samples of the
    batch. Only the heatmap at channel index 0 is displayed. Images are
    de-normalised with the mean/std constants below before plotting.
    """
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])
    n_shown = min(num_samples, images.size(0))

    for sample in range(n_shown):
        plt.figure(figsize=(15, 5))

        # CHW tensor -> HWC array, undo the normalisation, clamp to [0, 1].
        rgb = images[sample].cpu().permute(1, 2, 0).numpy()
        rgb = np.clip(rgb * channel_std + channel_mean, 0, 1)

        panels = (
            (rgb, "Input Image"),
            (gt_heatmaps[sample, 0].cpu().numpy(), "Ground Truth Heatmap"),
            (pred_heatmaps[sample, 0].cpu().numpy(), "Predicted Heatmap"),
        )
        for column, (picture, caption) in enumerate(panels, start=1):
            plt.subplot(1, 3, column)
            plt.imshow(picture)
            plt.title(caption)
        plt.show()

In [6]:
# Seed every RNG in use so the split and the training run are repeatable.
np.random.seed(420)
torch.manual_seed(420)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == 'cuda':
    torch.cuda.manual_seed_all(420)

# Resize to the 224x224 input the model works at and normalise each channel;
# plot_heatmaps undoes this normalisation with the same constants.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = KeypointDataset('../datasets/train_subset_single/standardized_images',
                            '../datasets/train_subset_single/labels',
                            transform=transform,
                            max_images=2000)

# 80 / 10 / 10 split; the test set absorbs any rounding remainder.
train_len = int(0.8 * len(dataset))
val_len = int(0.1 * len(dataset))
test_len = len(dataset) - train_len - val_len
train_set, val_set, test_set = random_split(dataset, [train_len, val_len, test_len])

train_loader = DataLoader(train_set, batch_size=30, shuffle=True)
val_loader = DataLoader(val_set, batch_size=30, shuffle=False)
test_loader = DataLoader(test_set, batch_size=30, shuffle=False)

backbone = ViTBackbone()
transformer = Transformer(feature_dim=768)
decoder = HeatmapDecoder(feature_dim=768, num_joints=13)
model = closedPoseTransformer(backbone=backbone,transformer=transformer,decoder=decoder)
# Move to the detected device only; an unconditional .cuda() here would raise
# on machines without a GPU even though `device` falls back to "cpu".
model = model.to(device)
model = train(model, train_loader, val_loader, learning_rate=0.005, num_epochs=10, device=device)
evaluate(model, test_loader, device=device)



KeyboardInterrupt: 