<a href="https://colab.research.google.com/github/MohdDilshad-nitk/pytorch/blob/main/skeleton_data_and_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import pandas as pd
import os
from pathlib import Path
from torch.utils.data import Dataset, DataLoader



In [None]:
!unzip /content/Data.zip

Archive:  /content/Data.zip
replace Data/Person001/1.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: Data/Person001/1.txt    
  inflating: Data/Person001/2.txt    
  inflating: Data/Person001/3.txt    
  inflating: Data/Person001/4.txt    
  inflating: Data/Person001/5.txt    
  inflating: Data/Person002/1.txt    
  inflating: Data/Person002/2.txt    
  inflating: Data/Person002/3.txt    
  inflating: Data/Person002/4.txt    
  inflating: Data/Person003/1.txt    
  inflating: Data/Person003/2.txt    
  inflating: Data/Person003/3.txt    
  inflating: Data/Person003/4.txt    
  inflating: Data/Person003/5.txt    
  inflating: Data/Person003/6.txt    
  inflating: Data/Person004/1.txt    
  inflating: Data/Person004/2.txt    
  inflating: Data/Person004/3.txt    
  inflating: Data/Person004/4.txt    
  inflating: Data/Person004/5.txt    
  inflating: Data/Person005/1.txt    
  inflating: Data/Person005/2.txt    
  inflating: Data/Person005/3.txt    
  inflating: Data/Person005/

In [None]:
class SkeletonDataProcessor:
    """Load skeleton recordings from a directory tree and normalize them.

    The expected layout is ``data_dir/<person>/<sequence>.txt``. Every text
    file holds one joint per line with ``;``-separated fields, of which
    fields 1, 2 and 3 are read as the x, y and z coordinates.
    ``num_joints`` consecutive lines form one frame.
    """

    def __init__(self, data_dir, num_joints=20, hip_center_idx=8):
        """
        Initialize the skeleton data processor

        Args:
            data_dir (str): Path to the root directory containing person folders
            num_joints (int): Number of joints (lines) that make up one frame
            hip_center_idx (int): Index of the joint used as the origin during
                normalization; update this based on your joint order
        """
        self.data_dir = Path(data_dir)
        self.person_dirs = sorted(d for d in self.data_dir.iterdir() if d.is_dir())
        self.num_joints = num_joints
        self.hip_center_idx = hip_center_idx

    def read_skeleton_file(self, file_path):
        """
        Read and parse a single skeleton file

        Blank lines are ignored, and a trailing group of fewer than
        ``num_joints`` lines (an incomplete frame) is dropped instead of
        raising an IndexError.

        Args:
            file_path (Path): Path to the skeleton file

        Returns:
            np.ndarray: Array of shape [time_steps, num_joints, 3]

        Raises:
            IndexError: If a used line has fewer than four ';'-separated fields
            ValueError: If a coordinate field cannot be parsed as a float
        """
        with open(file_path, 'r') as file:
            rows = [line.split(';') for line in file.read().splitlines() if line.strip()]

        # Only complete frames are kept
        num_frames = len(rows) // self.num_joints
        usable_rows = rows[:num_frames * self.num_joints]
        if not usable_rows:
            return np.zeros((0, self.num_joints, 3))

        # Field 0 is skipped; fields 1-3 are x, y, z
        coords = np.array(
            [[float(parts[1]), float(parts[2]), float(parts[3])] for parts in usable_rows]
        )
        return coords.reshape(num_frames, self.num_joints, 3)

    def normalize_sequence(self, sequence):
        """
        Normalize a sequence using hip center as origin

        Every frame is translated so that the hip-center joint sits at the
        origin, then each coordinate axis is min-max scaled to [0, 1] using
        the minimum and maximum over the whole sequence.

        Args:
            sequence (np.ndarray): Array of shape [time_steps, num_joints, 3]

        Returns:
            np.ndarray: Normalized sequence (a new array; the input is not modified)
        """
        sequence = np.asarray(sequence)
        if sequence.shape[0] == 0:
            # min()/max() are undefined on an empty array; nothing to normalize
            return sequence.astype(np.float64)

        # Translate all joints relative to the hip center of their own frame
        hip_center = sequence[:, self.hip_center_idx][:, None, :]
        centered = sequence - hip_center

        # Scale to [0, 1] range; the epsilon avoids division by zero on a constant axis
        min_vals = centered.min(axis=(0, 1), keepdims=True)
        max_vals = centered.max(axis=(0, 1), keepdims=True)
        return (centered - min_vals) / (max_vals - min_vals + 1e-7)

    def process_all_data(self):
        """
        Process all skeleton data in the dataset

        Files that do not contain a single complete frame are skipped with a
        printed notice.

        Returns:
            list: List of tuples (person_id, sequence_id, normalized_sequence)
        """
        all_sequences = []

        for person_dir in self.person_dirs:
            person_id = person_dir.name

            # Get all txt files for this person
            for file_path in sorted(person_dir.glob('*.txt')):
                sequence = self.read_skeleton_file(file_path)
                if sequence.shape[0] == 0:
                    print(f"Skipping {file_path}: no complete frames")
                    continue

                # The file name without its extension is the sequence id
                all_sequences.append(
                    (person_id, file_path.stem, self.normalize_sequence(sequence))
                )

        return all_sequences



In [None]:
class SkeletonDataset(Dataset):
    """PyTorch dataset that pads or truncates skeleton sequences to a fixed length."""

    def __init__(self, sequences, max_len=None):
        """
        Create a PyTorch dataset from processed sequences

        Args:
            sequences (list): List of tuples (person_id, sequence_id, normalized_sequence)
            max_len (int, optional): Fixed number of frames per returned sample.
                Shorter sequences are zero-padded, longer ones are truncated.
                Defaults to the longest sequence in ``sequences`` (0 when empty).
        """
        self.sequences = sequences

        # Find max sequence length if not provided
        if max_len is None:
            max_len = max((seq[2].shape[0] for seq in sequences), default=0)
        self.max_len = max_len

        # Map person ids to contiguous class indices in sorted order
        self.person_ids = sorted({seq[0] for seq in sequences})
        self.person_to_idx = {pid: i for i, pid in enumerate(self.person_ids)}

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """
        Return one sample as a dict with keys 'person_id' (class index),
        'sequence_id', 'sequence' ([max_len, num_joints, 3] float tensor) and
        'attention_mask' ([max_len] tensor, 1 for real frames, 0 for padding).
        """
        person_id, sequence_id, sequence = self.sequences[idx]

        # Truncate to max_len so the sequence and its mask always have the
        # same length and samples can be stacked into a batch
        sequence_tensor = torch.tensor(sequence, dtype=torch.float32)[:self.max_len]
        valid_len = sequence_tensor.size(0)

        # Pad sequence if necessary
        if valid_len < self.max_len:
            padding = torch.zeros(self.max_len - valid_len, *sequence_tensor.shape[1:])
            sequence_tensor = torch.cat([sequence_tensor, padding], dim=0)

        # Create attention mask (1 for real data, 0 for padding)
        attention_mask = torch.zeros(self.max_len)
        attention_mask[:valid_len] = 1

        return {
            'person_id': self.person_to_idx[person_id],
            'sequence_id': sequence_id,
            'sequence': sequence_tensor,
            'attention_mask': attention_mask
        }



In [None]:
# def create_data_loaders(data_dir, batch_size=32, train_split=0.8, val_split=0.1):
#     """
#     Create data loaders for train, validation, and test sets

#     Args:
#         data_dir (str): Path to data directory
#         batch_size (int): Batch size for data loaders
#         train_split (float): Proportion of data for training
#         val_split (float): Proportion of data for validation

#     Returns:
#         tuple: Train, validation, and test data loaders
#     """
#     # Process all data
#     processor = SkeletonDataProcessor(data_dir)
#     all_sequences = processor.process_all_data()

#     # Shuffle sequences
#     np.random.shuffle(all_sequences)

#     # Split data
#     n_sequences = len(all_sequences)
#     n_train = int(n_sequences * train_split)
#     n_val = int(n_sequences * val_split)

#     train_sequences = all_sequences[:n_train]
#     val_sequences = all_sequences[n_train:n_train + n_val]
#     test_sequences = all_sequences[n_train + n_val:]

#     # Create datasets
#     train_dataset = SkeletonDataset(train_sequences)
#     val_dataset = SkeletonDataset(val_sequences, max_len=train_dataset.max_len)
#     test_dataset = SkeletonDataset(test_sequences, max_len=train_dataset.max_len)

#     # Create data loaders
#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#     val_loader = DataLoader(val_dataset, batch_size=batch_size)
#     test_loader = DataLoader(test_dataset, batch_size=batch_size)

#     return train_loader, val_loader, test_loader

def _get_distribution_stats(sequences):
    """Helper function to calculate distribution statistics"""
    person_counts = {}
    for person_id, _, _ in sequences:
        person_counts[person_id] = person_counts.get(person_id, 0) + 1

    if not person_counts:
        return (0, 0, 0)

    counts = list(person_counts.values())
    return (min(counts), sum(counts)/len(counts), max(counts))

def create_data_loaders(data_dir, batch_size=32, train_split=0.8, val_split=0.1):
    """
    Create data loaders for train, validation, and test sets with stratified splitting
    ensuring each person has representation in the training set.

    The split is done per person: ``int(n * train_split)`` sequences (at least
    one) go to training, ``int(n * val_split)`` to validation and the rest to
    testing. With few sequences per person the validation share can round
    down to zero; a notice is printed when the validation set ends up empty.

    All three datasets share one padding length (the longest sequence in the
    whole dataset) and one person-to-class-index mapping, so a given class
    index means the same person in every split.

    Args:
        data_dir (str): Path to data directory
        batch_size (int): Batch size for data loaders
        train_split (float): Proportion of data for training
        val_split (float): Proportion of data for validation

    Returns:
        tuple: Train, validation, and test data loaders

    Raises:
        ValueError: If no sequences are found under ``data_dir``
    """
    # Process all data
    processor = SkeletonDataProcessor(data_dir)
    all_sequences = processor.process_all_data()
    if not all_sequences:
        raise ValueError(f"No skeleton sequences found under {data_dir}")

    # Group sequences by person_id
    person_sequences = {}
    for record in all_sequences:
        person_sequences.setdefault(record[0], []).append(record)

    train_sequences = []
    val_sequences = []
    test_sequences = []

    # For each person, split their sequences
    for sequences in person_sequences.values():
        n_sequences = len(sequences)

        # Ensure at least one sequence in training
        n_train = max(1, int(n_sequences * train_split))
        n_val = int(n_sequences * val_split)

        # Shuffle sequences for this person
        np.random.shuffle(sequences)

        # Split sequences for this person
        train_sequences.extend(sequences[:n_train])
        val_sequences.extend(sequences[n_train:n_train + n_val])
        test_sequences.extend(sequences[n_train + n_val:])

    # Shuffle the final sets
    np.random.shuffle(train_sequences)
    np.random.shuffle(val_sequences)
    np.random.shuffle(test_sequences)

    # Pad every split to the longest sequence of the whole dataset, so a
    # validation/test sequence is never longer than the padded length
    max_len = max(record[2].shape[0] for record in all_sequences)

    # Create datasets
    train_dataset = SkeletonDataset(train_sequences, max_len=max_len)
    val_dataset = SkeletonDataset(val_sequences, max_len=max_len)
    test_dataset = SkeletonDataset(test_sequences, max_len=max_len)

    # Each dataset derives its label mapping from its own sequences. The
    # training set contains every person (n_train >= 1 above), so its mapping
    # is reused for validation and test to keep class indices consistent.
    for split_dataset in (val_dataset, test_dataset):
        split_dataset.person_ids = train_dataset.person_ids
        split_dataset.person_to_idx = train_dataset.person_to_idx

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Print distribution statistics
    print("\nData distribution statistics:")
    print(f"Total number of people: {len(person_sequences)}")
    print(f"Training samples per person (min/avg/max): {_get_distribution_stats(train_sequences)}")
    print(f"Total samples - Train: {len(train_sequences)}, Val: {len(val_sequences)}, Test: {len(test_sequences)}")
    if not val_sequences:
        print("Note: validation set is empty (val_split rounds down to 0 sequences per person)")

    return train_loader, val_loader, test_loader



In [None]:
# Example usage: build the data loaders, inspect one batch and write a small
# statistics file. Runs at module level (the __main__ guard is disabled), so
# train_loader / val_loader / test_loader become globals for later cells.

# Seed NumPy (used by the shuffles in create_data_loaders) and PyTorch
np.random.seed(42)
torch.manual_seed(42)

# Directory containing the dataset
data_dir = "/content/Data"
# NOTE(review): `sb` is not referenced anywhere in the visible code
sb = ''
try:
    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(data_dir)

    print(f"Number of training batches: {len(train_loader)}")
    print(f"Number of validation batches: {len(val_loader)}")
    print(f"Number of test batches: {len(test_loader)}")

    # Get a sample batch
    sample_batch = next(iter(train_loader))
    # print(sample_batch)
    print("\nSample batch contents:")
    for key, value in sample_batch.items():
        # Tensors are summarised by shape; other values are printed as-is
        if torch.is_tensor(value):
            print(f"{key} shape: {value.shape}")
        else:
            print(f"{key}: {value}")

    # Save dataset statistics
    stats = {
        'num_training_sequences': len(train_loader.dataset),
        'num_validation_sequences': len(val_loader.dataset),
        'num_test_sequences': len(test_loader.dataset),
        'max_sequence_length': train_loader.dataset.max_len,
        'num_joints': 20,  # hard-coded; not read from the processor
        'num_persons': len(train_loader.dataset.person_ids)
    }

    # One-row CSV written to the current working directory
    pd.DataFrame([stats]).to_csv('dataset_statistics.csv', index=False)

except Exception as e:
    # Best-effort: the error is only printed, so on failure the loader
    # variables above stay undefined for the following cells
    print(f"Error processing dataset: {str(e)}")



Data distribution statistics:
Total number of people: 164
Training samples per person (min/avg/max): (2, 3.9695121951219514, 4)
Total samples - Train: 651, Val: 0, Test: 171
Number of training batches: 21
Number of validation batches: 0
Number of test batches: 6

Sample batch contents:
person_id shape: torch.Size([32])
sequence_id: ['4', '4', '2', '4', '4', '2', '2', '4', '1', '4', '4', '2', '5', '5', '2', '4', '4', '2', '1', '4', '2', '3', '1', '3', '5', '5', '2', '5', '2', '2', '3', '1']
sequence shape: torch.Size([32, 1711, 20, 3])
attention_mask shape: torch.Size([32, 1711])


In [None]:
# # prompt: print a datapoint

# # Access a single data point from the training dataset
# data_point = train_dataset[0]

# # Print the contents of the data point
# print(data_point)


In [None]:
# ... (previous imports remain the same)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np
from typing import Dict, Tuple
# import wandb  # for logging
import random
import os
from datetime import datetime

class TemporalAugmenter:
    """Random temporal crop and mask augmentations for batched sequences.

    Both augmentations operate along dimension 1 (time) of the input and use
    Python's ``random`` module, so they are controlled by ``random.seed``.
    """

    def __init__(
        self,
        crop_ratio_range=(0.8, 0.9),
        mask_ratio_range=(0.1, 0.2),
        min_sequence_length=16
    ):
        """
        Initialize temporal augmentation parameters

        Args:
            crop_ratio_range (tuple): Range for random crop ratio
            mask_ratio_range (tuple): Range for random masking ratio
            min_sequence_length (int): Minimum sequence length after cropping
        """
        self.crop_ratio_range = crop_ratio_range
        self.mask_ratio_range = mask_ratio_range
        self.min_sequence_length = min_sequence_length

    def random_temporal_crop(
        self,
        sequence: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply random temporal cropping to sequence

        One contiguous window is chosen and applied to every sample in the
        batch. The window covers a random fraction (``crop_ratio_range``) of
        the time dimension, at least ``min_sequence_length`` frames, and never
        more than the sequence itself.

        NOTE(review): because the window is shared across the batch, a sample
        whose real frames all lie before the window start ends up with an
        all-zero attention mask.

        Returns:
            Tuple of (cropped sequence, cropped attention mask or None)
        """
        seq_length = sequence.size(1)

        # Determine crop size
        min_ratio, max_ratio = self.crop_ratio_range
        crop_ratio = random.uniform(min_ratio, max_ratio)
        crop_size = max(int(seq_length * crop_ratio), self.min_sequence_length)
        # A sequence shorter than min_sequence_length is kept whole; without
        # this clamp the start range below would be negative and randint fails
        crop_size = min(crop_size, seq_length)

        # Random start point
        start_idx = random.randint(0, seq_length - crop_size)
        end_idx = start_idx + crop_size

        # Apply crop
        cropped_sequence = sequence[:, start_idx:end_idx]
        cropped_mask = None
        if attention_mask is not None:
            cropped_mask = attention_mask[:, start_idx:end_idx]

        return cropped_sequence, cropped_mask

    def random_temporal_mask(
        self,
        sequence: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply random temporal masking to sequence

        Zeroes out ``max(1, int(seq_length * ratio))`` randomly placed
        segments of 1 to 5 frames each (capped at a tenth of the sequence,
        but at least one frame). Segments may overlap. The same frames are
        set to 0 in the attention mask. Inputs are not modified.

        Returns:
            Tuple of (masked sequence, updated attention mask or None)
        """
        seq_length = sequence.size(1)

        # Work on copies so the caller's tensors stay untouched
        masked_sequence = sequence.clone()
        new_attention_mask = attention_mask.clone() if attention_mask is not None else None
        if seq_length == 0:
            return masked_sequence, new_attention_mask

        # Determine number of segments to mask
        min_ratio, max_ratio = self.mask_ratio_range
        mask_ratio = random.uniform(min_ratio, max_ratio)
        num_masks = max(1, int(seq_length * mask_ratio))

        # seq_length // 10 is 0 for sequences under 10 frames; clamp to 1 so
        # randint(1, ...) always has a valid range
        max_mask_length = max(1, min(5, seq_length // 10))

        # Apply random masks
        for _ in range(num_masks):
            mask_length = random.randint(1, max_mask_length)
            start_idx = random.randint(0, seq_length - mask_length)
            end_idx = start_idx + mask_length

            # Apply mask (set to zeros)
            masked_sequence[:, start_idx:end_idx] = 0
            if new_attention_mask is not None:
                new_attention_mask[:, start_idx:end_idx] = 0

        return masked_sequence, new_attention_mask

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.optim import Adam
# import numpy as np
# from typing import Dict, Tuple
# import wandb  # for logging

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Args:
        d_model: Feature dimension of the input (even or odd).
        max_len: Longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns
        pe[0, :, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Buffer: moves with the module across devices but is not trained
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape [batch_size, seq_len, d_model].

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` of the table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional encoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

# class SkeletonEmbedding(nn.Module):
#     def __init__(self, d_model: int):
#         super().__init__()
#         # self.joint_embedding = nn.Linear(3, d_model)  # 3 for x,y,z coordinates
#         self.joint_embedding = nn.Linear(60, d_model)  # 20*3 for x,y,z coordinates of 20 joints

#         # Add a linear layer to reduce the concatenated joints to d_model
#         # self.final_projection = nn.Linear(d_model * 20, d_model)

#     # def forward(self, x: torch.Tensor) -> torch.Tensor:
#     #     # x shape: [batch_size, seq_len, num_joints, 3]
#     #     print(f"skeleton_emb shape before join: {x.shape}")
#     #     batch_size, seq_len, num_joints, _ = x.shape

#     #     # Embed each joint
#     #     x = x.view(batch_size * seq_len, num_joints, 3)
#     #     print(f"skeleton_emb shape after view: {x.shape}")
#     #     x = self.joint_embedding(x)  # [batch_size * seq_len, num_joints, d_model]
#     #     print(f"skeleton_emb shape after embedding: {x.shape}")

#     #     # Reshape back to original shape
#     #     x = x.view(batch_size, seq_len, num_joints, -1)

#     #     # Combine joint embeddings
#     #     x = x.view(batch_size, seq_len, num_joints * x.size(-1))
#     #     print(f"skeleton_emb shape after join: {x.shape}")

#     #     # Project down to d_model dimension
#     #     # x = self.final_projection(x)  # [batch_size, seq_len, d_model]
#     #     return x

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         # x shape: [batch_size, seq_len, num_joints, 3]
#         print(f"skeleton_emb shape before join: {x.shape}")
#         print(x[0][1])
#         batch_size, seq_len, num_joints, coords = x.shape

#         # Combine all joints first
#         x = x.reshape(batch_size, seq_len, num_joints * coords)  # [batch_size, seq_len, 60]
#         print(f"skeleton_emb shape after combining joints: {x.shape}")
#         print(x[0][1])

#         # Project to d_model
#         x = self.joint_embedding(x)  # [batch_size, seq_len, d_model]
#         print(f"skeleton_emb shape after embedding: {x.shape}")
#         print(x[0][1])
#         return x




class SkeletonEmbedding(nn.Module):
    """Flatten all joints of a frame and project them to ``d_model`` features.

    Args:
        d_model: Output feature dimension per frame.
        num_joints: Number of joints per frame (default 20).
        num_coords: Number of coordinates per joint (default 3: x, y, z).
    """

    def __init__(self, d_model: int, num_joints: int = 20, num_coords: int = 3):
        super().__init__()
        # With the defaults: 20 joints * 3 coordinates = 60 input features
        self.joint_embedding = nn.Linear(num_joints * num_coords, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of shape [batch_size, seq_len, num_joints, coords]
        into [batch_size, seq_len, d_model]."""
        batch_size, seq_len, num_joints, coords = x.shape

        # Combine all joints of a frame into one feature vector first
        x = x.reshape(batch_size, seq_len, num_joints * coords)

        # Project to d_model
        return self.joint_embedding(x)

# class SkeletonTransformerTrainer:
#     def __init__(
#         self,
#         model: SkeletonTransformer,
#         train_loader: torch.utils.data.DataLoader,
#         val_loader: torch.utils.data.DataLoader,
#         learning_rate: float = 1e-4,
#         weight_decay: float = 1e-4,
#         use_wandb: bool = True
#     ):
#         self.model = model
#         self.train_loader = train_loader
#         self.val_loader = val_loader
#         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#         self.model = self.model.to(self.device)
#         self.optimizer = Adam(
#             model.parameters(),
#             lr=learning_rate,
#             weight_decay=weight_decay
#         )

#         self.contrastive_loss = NTXentLoss()
#         self.classification_loss = nn.CrossEntropyLoss()
#         self.use_wandb = use_wandb

#     def train_epoch(self) -> Dict[str, float]:
#         self.model.train()
#         total_cont_loss = 0
#         total_cls_loss = 0
#         correct = 0
#         total = 0

#         for batch in self.train_loader:
#             # Get batch data
#             sequence = batch['sequence'].to(self.device)
#             attention_mask = batch['attention_mask'].to(self.device)
#             person_id = batch['person_id'].to(self.device)

#             # Create augmented views
#             aug1 = self.augment_sequence(sequence)
#             aug2 = self.augment_sequence(sequence)

#             # Forward pass
#             proj1, logits1 = self.model(aug1, attention_mask)
#             proj2, logits2 = self.model(aug2, attention_mask)

#             # Calculate losses
#             cont_loss = self.contrastive_loss(proj1, proj2)
#             cls_loss = self.classification_loss(logits1, person_id)

#             # Combined loss
#             loss = cont_loss + cls_loss

#             # Backward pass
#             self.optimizer.zero_grad()
#             loss.backward()
#             self.optimizer.step()

#             # Statistics
#             total_cont_loss += cont_loss.item()
#             total_cls_loss += cls_loss.item()

#             pred = logits1.argmax(dim=1)
#             correct += (pred == person_id).sum().item()
#             total += person_id.size(0)

#         return {
#             'train_cont_loss': total_cont_loss / len(self.train_loader),
#             'train_cls_loss': total_cls_loss / len(self.train_loader),
#             'train_accuracy': correct / total
#         }

#     @torch.no_grad()
#     def validate(self) -> Dict[str, float]:
#         self.model.eval()
#         total_loss = 0
#         correct = 0
#         total = 0

#         for batch in self.val_loader:
#             sequence = batch['sequence'].to(self.device)
#             attention_mask = batch['attention_mask'].to(self.device)
#             person_id = batch['person_id'].to(self.device)

#             _, logits = self.model(sequence, attention_mask)
#             loss = self.classification_loss(logits, person_id)

#             total_loss += loss.item()
#             pred = logits.argmax(dim=1)
#             correct += (pred == person_id).sum().item()
#             total += person_id.size(0)

#         return {
#             'val_loss': total_loss / len(self.val_loader),
#             'val_accuracy': correct / total
#         }

#     def augment_sequence(self, sequence: torch.Tensor) -> torch.Tensor:
#         """Apply random augmentations to skeleton sequence"""
#         # Example augmentations:
#         # 1. Random rotation
#         # 2. Random noise
#         # 3. Random temporal crop/mask
#         return sequence  # TODO: Implement actual augmentations

#     def train(self, num_epochs: int):
#         if self.use_wandb:
#             wandb.init(project='skeleton-transformer')

#         for epoch in range(num_epochs):
#             train_metrics = self.train_epoch()
#             val_metrics = self.validate()

#             metrics = {**train_metrics, **val_metrics}

#             if self.use_wandb:
#                 wandb.log(metrics)

#             print(f"Epoch {epoch+1}/{num_epochs}")
#             for k, v in metrics.items():
#                 print(f"{k}: {v:.4f}")
#             print()

# Example usage
# if __name__ == "__main__":
#     # Assuming you have your data loaders from the previous script
#     data_dir = "data"
#     train_loader, val_loader, test_loader = create_data_loaders(data_dir)

#     # Get number of persons (classes) from the dataset
#     num_classes = len(train_loader.dataset.person_ids)

#     # Create model
#     model = SkeletonTransformer(
#         num_joints=20,
#         d_model=256,
#         nhead=8,
#         num_encoder_layers=6,
#         num_classes=num_classes
#     )

#     # Create trainer
#     trainer = SkeletonTransformerTrainer(
#         model=model,
#         train_loader=train_loader,
#         val_loader=val_loader
#     )

#     # Train model
#     trainer.train(num_epochs=100)













In [None]:
class SkeletonTransformer(nn.Module):
    """Transformer encoder over skeleton sequences with a contrastive
    projection head and an optional classification head.

    Args:
        num_joints: Number of joints per frame (3 coordinates each).
        d_model: Transformer feature dimension.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward block.
        dropout: Dropout probability inside the encoder layers.
        num_classes: Number of output classes; ``None`` disables the
            classification head (``forward`` then returns ``None`` logits).
    """

    def __init__(
        self,
        num_joints: int,
        d_model: int = 256,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        num_classes: int = None
    ):
        super().__init__()

        self.d_model = d_model
        self.num_joints = num_joints
        self.embedding = SkeletonEmbedding(d_model)
        # The embedding flattens each frame to num_joints * 3 features; make
        # sure its projection matches the requested joint count
        expected_features = num_joints * 3
        if self.embedding.joint_embedding.in_features != expected_features:
            self.embedding.joint_embedding = nn.Linear(expected_features, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )

        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_encoder_layers
        )

        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 128)  # 128-dimensional projection
        )

        # Classification head
        if num_classes is not None:
            self.classifier = nn.Linear(d_model, num_classes)
        else:
            self.classifier = None

    def encode(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> torch.Tensor:
        """Encode a batch of sequences into one vector per sequence.

        Args:
            x: Tensor of shape [batch_size, seq_len, num_joints, 3].
            attention_mask: Optional [batch_size, seq_len] tensor that is
                non-zero for frames to attend to and 0 for frames to ignore.

        Returns:
            Tensor of shape [batch_size, d_model]: the encoder output at the
            first time step.
        """
        x = self.embedding(x)  # [batch_size, seq_len, d_model]
        x = self.pos_encoder(x)

        key_padding_mask = None
        if attention_mask is not None:
            # Boolean key-padding mask: True marks positions to ignore
            key_padding_mask = attention_mask == 0
            # A row with every key ignored would make the attention softmax
            # produce NaN; keep position 0 visible for such rows
            fully_masked = key_padding_mask.all(dim=1)
            key_padding_mask[:, 0] = key_padding_mask[:, 0] & ~fully_masked

        encoded = self.transformer_encoder(x, src_key_padding_mask=key_padding_mask)

        # No dedicated [CLS] token is prepended; the output at the first
        # time step serves as the sequence representation
        return encoded[:, 0]

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(projection, logits)`` for a batch of sequences.

        ``projection`` has shape [batch_size, 128]; ``logits`` has shape
        [batch_size, num_classes], or is ``None`` without a classifier.
        """
        sequence_repr = self.encode(x, attention_mask)
        projection = self.projection(sequence_repr)

        logits = None
        if self.classifier is not None:
            logits = self.classifier(sequence_repr)

        return projection, logits


In [None]:
class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    For a batch of N pairs, the 2N embeddings are compared by cosine
    similarity. For each embedding, the other view of the same sample is the
    positive and the remaining 2N - 2 embeddings are negatives.

    Args:
        temperature: Scaling applied to the similarities before the softmax.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        z_i: torch.Tensor,
        z_j: torch.Tensor
    ) -> torch.Tensor:
        """Compute the loss for two views of the same batch.

        Args:
            z_i: Embeddings of the first view, shape [batch_size, dim].
            z_j: Embeddings of the second view, shape [batch_size, dim];
                row k of ``z_j`` is the positive for row k of ``z_i``.

        Returns:
            Scalar tensor: mean loss over all 2 * batch_size embeddings.
        """
        batch_size = z_i.size(0)
        representations = torch.cat(
            [F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)], dim=0
        )

        # Rows are unit-length, so the dot product is the cosine similarity
        logits = representations @ representations.t() / self.temperature

        # Exclude each embedding's similarity with itself from the softmax;
        # otherwise exp(1 / temperature) dominates every denominator
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_mask, float('-inf'))

        # The positive of row k is row k + batch_size, and vice versa
        indices = torch.arange(batch_size, device=logits.device)
        targets = torch.cat([indices + batch_size, indices], dim=0)

        # Cross-entropy = -log(exp(pos) / sum(exp(all others))), averaged over 2N rows
        return F.cross_entropy(logits, targets)


In [None]:
class SkeletonTransformerTrainer:
    """Train a SkeletonTransformer with a combined contrastive and
    classification objective, and save checkpoints after every epoch.

    The model must have a classification head: the logits it returns are fed
    to a cross-entropy loss during training and validation.
    """

    def __init__(
        self,
        model: SkeletonTransformer,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-4,
        # use_wandb: bool = False,
        save_dir: str = 'models'
    ):
        """
        Args:
            model: Model to train; moved to CUDA when available, else CPU.
            train_loader: Loader yielding dicts with 'sequence',
                'attention_mask' and 'person_id' entries.
            val_loader: Loader of the same format used by ``validate``.
            learning_rate: Adam learning rate.
            weight_decay: Adam weight decay.
            save_dir: Directory for checkpoints (created if missing).
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = self.model.to(self.device)
        self.optimizer = Adam(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        self.contrastive_loss = NTXentLoss()
        self.classification_loss = nn.CrossEntropyLoss()

        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

        self.augmenter = TemporalAugmenter()

        # Initialize best metrics for model saving
        self.best_val_accuracy = 0.0
        self.best_epoch = 0

    def augment_sequence(
        self,
        sequence: torch.Tensor,
        attention_mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply temporal augmentations to sequence

        A random temporal crop is applied first, then random temporal
        masking. Returns the augmented sequence and attention mask.
        """
        sequence, attention_mask = self.augmenter.random_temporal_crop(
            sequence, attention_mask
        )
        sequence, attention_mask = self.augmenter.random_temporal_mask(
            sequence, attention_mask
        )
        return sequence, attention_mask

    def save_checkpoint(
        self,
        epoch: int,
        metrics: Dict[str, float],
        is_best: bool = False
    ):
        """Save model checkpoint

        Writes ``checkpoint_epoch_<epoch>_<timestamp>.pt`` into ``save_dir``
        and, when ``is_best`` is set, also ``best_model.pt``. The model
        architecture config is written to ``model_config.pt`` on epoch 0 or
        whenever that file does not exist yet.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            metrics: Metrics to store alongside the weights.
            is_best: Whether this epoch achieved the best validation accuracy.
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': metrics,
            'best_val_accuracy': self.best_val_accuracy,
            'best_epoch': self.best_epoch
        }

        # Save regular checkpoint
        checkpoint_path = os.path.join(
            self.save_dir,
            f'checkpoint_epoch_{epoch}_{timestamp}.pt'
        )
        torch.save(checkpoint, checkpoint_path)

        # Save best model if applicable
        if is_best:
            best_path = os.path.join(self.save_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)

        # Save model architecture config
        config_path = os.path.join(self.save_dir, 'model_config.pt')
        if epoch == 0 or not os.path.exists(config_path):
            first_layer = self.model.transformer_encoder.layers[0]
            config = {
                'd_model': self.model.d_model,
                # The encoder layer does not store nhead itself; the head
                # count lives on its self-attention module
                'nhead': first_layer.self_attn.num_heads,
                'num_encoder_layers': len(self.model.transformer_encoder.layers),
                'dim_feedforward': first_layer.linear1.out_features,
            }
            torch.save(config, config_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint

        Restores model weights, optimizer state and best-metric bookkeeping.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.

        Returns:
            int: The epoch index stored in the checkpoint (the last epoch
            that was completed when it was saved).
        """
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        self.best_val_accuracy = checkpoint['best_val_accuracy']
        self.best_epoch = checkpoint['best_epoch']

        return checkpoint['epoch']

    def train_epoch(self) -> Dict[str, float]:
        """Run one training epoch over ``train_loader``.

        Two augmented views of every batch are encoded; the contrastive loss
        compares their projections and the classification loss uses the
        logits of the first view.

        Returns:
            Dict with 'train_cont_loss', 'train_cls_loss' (per-batch means)
            and 'train_accuracy'. All are 0 when the loader yields no batches.
        """
        self.model.train()
        total_cont_loss = 0.0
        total_cls_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for batch in self.train_loader:
            # Get batch data
            sequence = batch['sequence'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            person_id = batch['person_id'].to(self.device)

            # Create augmented views
            aug1, mask1 = self.augment_sequence(sequence, attention_mask)
            aug2, mask2 = self.augment_sequence(sequence, attention_mask)

            # Forward pass
            proj1, logits1 = self.model(aug1, mask1)
            proj2, _ = self.model(aug2, mask2)

            # Calculate losses
            cont_loss = self.contrastive_loss(proj1, proj2)
            cls_loss = self.classification_loss(logits1, person_id)

            # Combined loss
            loss = cont_loss + cls_loss

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Statistics
            total_cont_loss += cont_loss.item()
            total_cls_loss += cls_loss.item()

            pred = logits1.argmax(dim=1)
            correct += (pred == person_id).sum().item()
            total += person_id.size(0)
            num_batches += 1

        # max(..., 1) avoids a ZeroDivisionError on an empty loader
        return {
            'train_cont_loss': total_cont_loss / max(num_batches, 1),
            'train_cls_loss': total_cls_loss / max(num_batches, 1),
            'train_accuracy': correct / max(total, 1)
        }

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Evaluate classification loss and accuracy on ``val_loader``.

        Sequences are used without augmentation.

        Returns:
            Dict with 'val_loss' (per-batch mean) and 'val_accuracy'. Both
            are NaN when the validation loader yields no batches, so an empty
            validation set is never mistaken for a real score.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for batch in self.val_loader:
            sequence = batch['sequence'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            person_id = batch['person_id'].to(self.device)

            _, logits = self.model(sequence, attention_mask)
            loss = self.classification_loss(logits, person_id)

            total_loss += loss.item()
            pred = logits.argmax(dim=1)
            correct += (pred == person_id).sum().item()
            total += person_id.size(0)
            num_batches += 1

        if num_batches == 0:
            return {'val_loss': float('nan'), 'val_accuracy': float('nan')}

        return {
            'val_loss': total_loss / num_batches,
            'val_accuracy': correct / total
        }

    def train(self, num_epochs: int, resume_path: str = None):
        """Train for ``num_epochs`` epochs, validating and checkpointing each one.

        Args:
            num_epochs: Total number of epochs (epoch indices 0..num_epochs-1).
            resume_path: Optional checkpoint to resume from; training then
                continues with the epoch after the one stored in it.
        """
        # if self.use_wandb:
        #     wandb.init(project='skeleton-transformer')

        start_epoch = 0
        if resume_path is not None:
            # The checkpoint stores the last completed epoch; continue after it
            start_epoch = self.load_checkpoint(resume_path) + 1
            print(f"Resumed training from epoch {start_epoch}")

        for epoch in range(start_epoch, num_epochs):
            train_metrics = self.train_epoch()
            val_metrics = self.validate()

            metrics = {**train_metrics, **val_metrics}

            # Check if this is the best model (a NaN accuracy never compares greater)
            is_best = False
            if val_metrics['val_accuracy'] > self.best_val_accuracy:
                self.best_val_accuracy = val_metrics['val_accuracy']
                self.best_epoch = epoch
                is_best = True

            # Save checkpoint
            self.save_checkpoint(epoch, metrics, is_best)

            # if self.use_wandb:
            #     wandb.log(metrics)

            print(f"Epoch {epoch+1}/{num_epochs}")
            for k, v in metrics.items():
                print(f"{k}: {v:.4f}")
            if is_best:
                print("New best model!")
            print()


In [None]:
# Example usage
# if __name__ == "__main__":
# Create model and trainer
# Size the classification head from the training set's label mapping rather
# than a hard-coded person count, so it stays correct if the data changes.
model = SkeletonTransformer(
    num_joints=20,
    d_model=256,
    nhead=8,
    num_encoder_layers=6,
    num_classes=len(train_loader.dataset.person_ids)
)

# Checkpoints are written to ./skeleton_transformer_models
trainer = SkeletonTransformerTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    save_dir='skeleton_transformer_models'
)



d_model:  256


In [None]:
# Train model
# Run the training loop of the trainer created above for 100 epochs.
trainer.train(
    num_epochs=100,
    resume_path=None  # Set to a checkpoint path to resume training from it
)

Input shape: torch.Size([32, 1482, 256])
Positional Encoding shape: torch.Size([1, 1482, 256])
