# Contact-based Object Classification

In [48]:
import pandas as pd
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict
import torch
from torch.utils.data import Dataset, DataLoader
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
from torchsummary import summary
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import wandb
import torch.nn.functional as F

In [49]:
# Categories for classification
CATEGORIES = {
    'Blueball': 0,
    'Box': 1,
    'Pencilcase': 2,
    'Pinkball': 3,
    'StuffedAnimal': 4,
    'Tennis': 5,
    'Waterbottle': 6,
}

In [50]:
# Configuration Dictionary
config = {
    'batch_size': 32,
    'lr': 0.001,
    'epochs': 20,
    'data_dir': "/Users/benlee/Documents/college/CMU/Spring 2025/IDL/Project/IDL_code/IDL_Data",
    'checkpoint_dir': "/Users/benlee/Documents/college/CMU/Spring 2025/IDL/Project/IDL_code/checkpoint",
}

In [51]:
class ContactWindowDataset(Dataset):
    def __init__(self, data_dir: str, labels: Dict[str, int] = None, window_size: int = 50, step_size: int = 10):
        """
        Args:
            data_dir (str): Directory containing the .txt files
            labels (Dict[str, int]): Dictionary mapping categories to class labels
            window_size (int): Size of the sliding window (smaller = more samples)
            step_size (int): Step size for the sliding window (smaller = more samples)
        """
        self.data_dir = Path(data_dir)
        self.file_paths = list(self.data_dir.glob("*.txt"))
        self.labels = labels or {}
        self.window_size = window_size
        self.step_size = step_size
        
        # Lists to store all windows and their labels
        self.features_list = []
        self.labels_list = []
        self.file_indices = []  # To keep track of which file each window came from
        
        # Process all files
        print("Processing files and extracting windows:")
        for file_idx, file_path in enumerate(self.file_paths):
            # Get category from filename
            category = re.sub(r"\d+", "", file_path.stem)
            label = self.labels.get(category, -1)
            
            # Load and process the file
            df = self._parse_file(file_path)
            
            # Create windows
            windows_from_file = 0
            for start_idx in range(0, len(df) - window_size + 1, step_size):
                window = df.iloc[start_idx:start_idx + window_size]
                
                # Extract features from window
                features = self._extract_features(window)
                
                self.features_list.append(features)
                self.labels_list.append(label)
                self.file_indices.append(file_idx)
                windows_from_file += 1
            
            print(f"  {file_path.name}: {windows_from_file} windows generated from {len(df)} datapoints")
        
        # Convert lists to tensors for efficiency
        self.features = torch.FloatTensor(self.features_list)
        self.labels = torch.LongTensor(self.labels_list)
        self.file_indices = torch.LongTensor(self.file_indices)
        
        # Print summary statistics
        self._print_dataset_stats()
        
    def _parse_file(self, file_path: Path) -> pd.DataFrame:
        """Parse a single data file"""
        df = pd.read_csv(file_path, header=None, skiprows=1)
        columns = [
            'timestamp_pc', 'timestamp_micro',
            'x', 'y', 'angle_1', 'angle_2',
            'contact_1_left', 'contact_1_right',
            'contact_2_left', 'contact_2_right'
        ]
        df = pd.DataFrame(df.values, columns=columns)
        return df
    
    def _extract_features(self, window: pd.DataFrame) -> np.ndarray:
        """Extract features from a window of data"""
        # Basic statistical features
        features = np.array([
            window['contact_1_left'].mean(),
            window['contact_1_right'].mean(),
            window['contact_2_left'].mean(),
            window['contact_2_right'].mean(),
            window['x'].max() - window['x'].min(),
            window['y'].max() - window['y'].min(),
            window['angle_1'].std(),
            window['angle_2'].std(),
            # Additional features for more information
            window['contact_1_left'].std(),
            window['contact_1_right'].std(),
            window['contact_2_left'].std(),
            window['contact_2_right'].std(),
            window['angle_1'].mean(),
            window['angle_2'].mean()
        ])
        return features
    
    def _print_dataset_stats(self):
        """Print statistics about the dataset"""
        total_windows = len(self.features)
        unique_files = len(torch.unique(self.file_indices))
        
        # Count samples per category
        category_counts = {}
        for label in self.labels_list:
            category_name = list(CATEGORIES.keys())[list(CATEGORIES.values()).index(label)]
            category_counts[category_name] = category_counts.get(category_name, 0) + 1
        
        print("\nDataset Statistics:")
        print(f"Total number of windows: {total_windows}")
        print(f"Total number of files: {unique_files}")
        print(f"Average windows per file: {total_windows / unique_files:.2f}")
        print("\nSamples per category:")
        for category, count in category_counts.items():
            print(f"  {category}: {count} windows")
        
    def __len__(self) -> int:
        return len(self.features)
        
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        return self.features[idx], self.labels[idx]

In [52]:
def split_dataset_by_files_robust(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
    """
    Split the dataset by files while trying to ensure all classes are represented in each split.
    This function handles edge cases where some classes have very few examples.
    
    Parameters:
    -----------
    dataset : ContactWindowDataset
        The dataset to split
    train_ratio : float
        Ratio of data for training
    val_ratio : float
        Ratio of data for validation
    test_ratio : float
        Ratio of data for testing
    random_state : int
        Random seed for reproducibility
    
    Returns:
    --------
    tuple
        (train_indices, val_indices, test_indices)
    """
    import numpy as np
    import random
    
    # Set random seed for reproducibility
    random.seed(random_state)
    np.random.seed(random_state)
    
    # Verify ratios sum to 1
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-10, "Ratios must sum to 1"
    
    # Get unique files and their labels
    unique_file_indices = torch.unique(dataset.file_indices).numpy()
    
    # Create mapping from file index to most common label in that file
    file_to_label = {}
    for file_idx in unique_file_indices:
        # Get all windows from this file
        mask = dataset.file_indices == file_idx
        window_indices = torch.where(mask)[0]
        
        # Get the labels for these windows
        labels = dataset.labels[window_indices]
        
        # Use the most common label for this file
        unique_labels, counts = torch.unique(labels, return_counts=True)
        most_common_label = unique_labels[torch.argmax(counts)].item()
        file_to_label[file_idx] = most_common_label
    
    # Group files by label
    label_to_files = {}
    for file_idx, label in file_to_label.items():
        if label not in label_to_files:
            label_to_files[label] = []
        label_to_files[label].append(file_idx)
    
    # Print how many files per label
    print("\nFiles per label:")
    for label, files in label_to_files.items():
        class_name = list(CATEGORIES.keys())[list(CATEGORIES.values()).index(label)]
        print(f"  {class_name}: {len(files)} files")
    
    # Manual split to ensure all classes are in all splits
    train_files = []
    val_files = []
    test_files = []
    
    for label, files in label_to_files.items():
        # Shuffle files for this label
        random.shuffle(files)
        
        # Calculate how many files should go to each split
        num_files = len(files)
        
        if num_files >= 3:
            # If we have enough files, use the ratios
            num_train = max(1, int(num_files * train_ratio))
            num_val = max(1, int(num_files * val_ratio))
            num_test = max(1, num_files - num_train - num_val)
            
            # If we don't have enough files for all splits, prioritize train
            if num_train + num_val + num_test > num_files:
                # If we have 2 files, put one in train and one in test
                if num_files == 2:
                    num_train, num_val, num_test = 1, 0, 1
                else:
                    num_train = max(1, num_files - 2)
                    num_val = 1
                    num_test = 1
            
            # Split files
            train_files.extend(files[:num_train])
            val_files.extend(files[num_train:num_train+num_val])
            test_files.extend(files[num_train+num_val:])
        
        elif num_files == 2:
            # For classes with only 2 files, put one in train and one in val
            train_files.append(files[0])
            val_files.append(files[1])
            # We'll handle test set separately
        
        elif num_files == 1:
            # For classes with only 1 file, duplicate windows
            # Put the file in the train set
            train_files.append(files[0])
            print(f"Warning: Class {label} has only one file, putting it in train set")
    
    # Now map back to window indices
    train_indices = []
    val_indices = []
    test_indices = []
    
    for idx in range(len(dataset)):
        file_idx = dataset.file_indices[idx].item()
        if file_idx in train_files:
            train_indices.append(idx)
        elif file_idx in val_files:
            val_indices.append(idx)
        elif file_idx in test_files:
            test_indices.append(idx)
    
    # Special handling for classes with only one or two files
    # For these classes, create synthetic test examples
    for label, files in label_to_files.items():
        if len(files) <= 2:
            class_indices = []
            for idx in range(len(dataset)):
                if dataset.labels[idx].item() == label:
                    class_indices.append(idx)
            
            if len(files) == 2 and len(test_indices) == 0:
                # If we have 2 files but none in test, move some val windows to test
                class_val_indices = [idx for idx in class_indices if idx in val_indices]
                if class_val_indices:
                    # Move half of val windows to test
                    half_point = len(class_val_indices) // 2
                    for idx in class_val_indices[half_point:]:
                        val_indices.remove(idx)
                        test_indices.append(idx)
            
            elif len(files) == 1:
                # If we have only 1 file, duplicate some train windows to val and test
                class_train_indices = [idx for idx in class_indices if idx in train_indices]
                
                # Select random indices to duplicate (without removing from train)
                random.shuffle(class_train_indices)
                third_point = max(1, len(class_train_indices) // 3)
                
                # Add to val and test
                val_indices.extend(class_train_indices[:third_point])
                test_indices.extend(class_train_indices[third_point:2*third_point])
    
    # Print class distribution in each split
    print("\nClass distribution:")
    for split_name, indices in [
        ("Train", train_indices), 
        ("Validation", val_indices), 
        ("Test", test_indices)
    ]:
        class_counts = {}
        for idx in indices:
            label = dataset.labels[idx].item()
            class_name = list(CATEGORIES.keys())[list(CATEGORIES.values()).index(label)]
            class_counts[class_name] = class_counts.get(class_name, 0) + 1
        
        print(f"\n{split_name} set:")
        for class_name, count in sorted(class_counts.items()):
            print(f"  {class_name}: {count} windows")
    
    return train_indices, val_indices, test_indices

In [53]:
# Create the dataset with smaller window and step size to maximize samples
window_size = 100   # Smaller window size generates more samples
step_size = 10     # Smaller step size creates more overlapping windows
dataset = ContactWindowDataset(
    data_dir=config["data_dir"], 
    labels=CATEGORIES, 
    window_size=window_size, 
    step_size=step_size
)

Processing files and extracting windows:
  Pencilcase4.txt: 482 windows generated from 4916 datapoints
  Pinkball3.txt: 467 windows generated from 4760 datapoints
  Pinkball2.txt: 460 windows generated from 4695 datapoints
  Pencilcase5.txt: 531 windows generated from 5404 datapoints
  Pinkball1.txt: 481 windows generated from 4901 datapoints
  Pencilcase2.txt: 615 windows generated from 6249 datapoints
  Pinkball5.txt: 473 windows generated from 4828 datapoints
  Pinkball4.txt: 456 windows generated from 4658 datapoints
  Pencilcase3.txt: 602 windows generated from 6119 datapoints
  Pencilcase1.txt: 504 windows generated from 5135 datapoints
  Pinkball6.txt: 458 windows generated from 4679 datapoints
  Box4.txt: 483 windows generated from 4926 datapoints
  Box5.txt: 530 windows generated from 5399 datapoints
  Box1.txt: 509 windows generated from 5187 datapoints
  Box2.txt: 467 windows generated from 4766 datapoints
  Box3.txt: 484 windows generated from 4939 datapoints
  Blueball6.tx

In [54]:
# Split the dataset by files to prevent data leakage
train_indices, val_indices, test_indices = split_dataset_by_files_robust(dataset)  # Use _robust instead of _stratified

# Create subset datasets
from torch.utils.data import Subset
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'])
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'])

# Print split information
print("\nDataset Splits:")
print(f"Training set size: {len(train_dataset)} windows")
print(f"Validation set size: {len(val_dataset)} windows")
print(f"Test set size: {len(test_dataset)} windows")

# Example of accessing a batch
for features, labels in train_loader:
    print(f"\nSample batch:")
    print(f"Batch features shape: {features.shape}")
    print(f"Batch labels shape: {labels.shape}")
    break


Files per label:
  Pencilcase: 5 files
  Pinkball: 6 files
  Box: 5 files
  Blueball: 6 files
  Waterbottle: 5 files
  StuffedAnimal: 5 files
  Tennis: 5 files

Class distribution:

Train set:
  Blueball: 1846 windows
  Box: 1506 windows
  Pencilcase: 1748 windows
  Pinkball: 1877 windows
  StuffedAnimal: 1481 windows
  Tennis: 1459 windows
  Waterbottle: 1439 windows

Validation set:
  Blueball: 469 windows
  Box: 483 windows
  Pencilcase: 504 windows
  Pinkball: 458 windows
  StuffedAnimal: 491 windows
  Tennis: 631 windows
  Waterbottle: 538 windows

Test set:
  Blueball: 470 windows
  Box: 484 windows
  Pencilcase: 482 windows
  Pinkball: 460 windows
  StuffedAnimal: 352 windows
  Tennis: 524 windows
  Waterbottle: 467 windows

Dataset Splits:
Training set size: 11356 windows
Validation set size: 3574 windows
Test set size: 3239 windows

Sample batch:
Batch features shape: torch.Size([32, 14])
Batch labels shape: torch.Size([32])


In [None]:
class Temporal1DBlock(nn.Module):
    """Residual block for 1D time series data"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
        super(Temporal1DBlock, self).__init__()
        # Calculate padding to maintain temporal dimension
        padding = dilation * (kernel_size - 1) // 2
        
        # First convolutional layer
        self.conv1 = nn.Conv1d(in_channels, out_channels, 
                              kernel_size=kernel_size, 
                              stride=stride, 
                              padding=padding,
                              dilation=dilation)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        # Second convolutional layer
        self.conv2 = nn.Conv1d(out_channels, out_channels, 
                              kernel_size=kernel_size, 
                              padding=padding,
                              dilation=dilation)
        self.bn2 = nn.BatchNorm1d(out_channels)
        
        # Shortcut connection
        self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.shortcut_bn = nn.BatchNorm1d(out_channels)
        
    def forward(self, x):
        identity = self.shortcut(x)
        identity = self.shortcut_bn(identity)
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        out += identity
        out = self.relu(out)
        
        return out

class TimeSeriesClassifier(nn.Module):
    def __init__(self, input_features, num_classes, window_size=50):
        """
        Time Series Classifier using ResNet architecture
        
        Args:
            input_features: Number of features in each time step
            num_classes: Number of output classes
            window_size: Size of the time window
        """
        super(TimeSeriesClassifier, self).__init__()
        
        self.input_features = input_features
        self.window_size = window_size
        
        # Feature enrichment module - expand features
        self.feature_embedding = nn.Sequential(
            nn.Linear(input_features, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True)
        )
        
        # Initial convolution layer
        self.conv1 = nn.Sequential(
            nn.Conv1d(64, 128, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        )
        
        # Residual blocks with increasing dilation for larger receptive field
        self.res1 = Temporal1DBlock(128, 256, dilation=1)
        self.res2 = Temporal1DBlock(256, 512, dilation=2)  # Increased dilation
        self.res3 = Temporal1DBlock(512, 1024, dilation=4)  # Further increased dilation
        
        # Global pooling
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        
        # Fully connected classifier
        self.dropout = nn.Dropout(0.4)
        self.fc1 = nn.Linear(1024, 1024)
        self.bn = nn.BatchNorm1d(1024)
        self.fc2 = nn.Linear(1024, num_classes)
        
    def forward(self, x):
        # Input handling: [batch_size, features, time_steps] or [batch_size, features]
        if len(x.shape) == 2:
            # If input is [batch_size, features], reshape to [batch_size, features, 1]
            x = x.unsqueeze(-1)
            
            # For contact data, we assume features are measurements at a point in time
            # Here we need to enrich the feature representation
            batch_size = x.shape[0]
            x = x.transpose(1, 2)  # [batch_size, 1, features]
            x = x.view(batch_size, -1)  # [batch_size, features]
            
            # Apply feature embedding
            x = self.feature_embedding(x)  # [batch_size, 64]
            x = x.unsqueeze(-1)  # [batch_size, 64, 1]
            
            # Repeat the features across the window to simulate a time series
            x = x.repeat(1, 1, self.window_size)  # [batch_size, 64, window_size]
        
        # Normal case when x is already [batch_size, features, time_steps]
        elif len(x.shape) == 3:
            # Apply feature embedding to each time step
            batch_size, features, time_steps = x.shape
            x = x.transpose(1, 2)  # [batch_size, time_steps, features]
            x = x.reshape(-1, features)  # [batch_size*time_steps, features]
            x = self.feature_embedding(x)  # [batch_size*time_steps, 64]
            x = x.view(batch_size, time_steps, 64)  # [batch_size, time_steps, 64]
            x = x.transpose(1, 2)  # [batch_size, 64, time_steps]
        
        # Convolutional feature extraction
        x = self.conv1(x)
        
        # Residual blocks
        x = self.res1(x)
        x = self.res2(x)
        x = self.res3(x)
        
        # Global pooling
        x = self.global_pool(x).squeeze(-1)  # [batch_size, 1024]
        
        # Fully connected layers
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.bn(x)
        feats = x  # Store features
        out = self.fc2(x)
        
        # Return both features and output
        return {"feats": feats, "out": out}

# Initialize the model (usage example)
input_features = 14  # Number of features in your dataset
window_size = 50  # Size of time window
num_classes = len(CATEGORIES)
model = TimeSeriesClassifier(input_features, num_classes, window_size)

# Try CPU first for compatibility with BatchNorm1d 
device = torch.device('cpu')  # Change to 'mps' after checking summary
model.to(device)
summary(model, (input_features,))

# After checking summary, move to MPS if available
try:
    if torch.backends.mps.is_available():
        model = model.to('mps')
        device = torch.device('mps')
        print("Model moved to MPS device")
    print(f"Using device: {device}")
except Exception as e:
    print(f"Error moving to MPS: {e}")
    print(f"Using device: {device}")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                   [-1, 64]             960
       BatchNorm1d-2                   [-1, 64]             128
              ReLU-3                   [-1, 64]               0
            Conv1d-4              [-1, 128, 50]          57,472
       BatchNorm1d-5              [-1, 128, 50]             256
              ReLU-6              [-1, 128, 50]               0
         MaxPool1d-7              [-1, 128, 25]               0
            Conv1d-8              [-1, 256, 25]          33,024
       BatchNorm1d-9              [-1, 256, 25]             512
           Conv1d-10              [-1, 256, 25]          98,560
      BatchNorm1d-11              [-1, 256, 25]             512
             ReLU-12              [-1, 256, 25]               0
           Conv1d-13              [-1, 256, 25]         196,864
      BatchNorm1d-14              [-1, 

In [82]:
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [83]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = min(max(topk), output.size()[1])
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]

In [84]:
def train_model(model, train_loader, criterion, optimizer, device):
    model.train()
    # Metric meters
    loss_m = AverageMeter()
    acc_m = AverageMeter()
    
    # Progress Bar
    batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train')
    
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()  # Zero gradients
        
        # Get the data
        x, y = data
        x, y = x.to(device), y.to(device)
        
        # Forward pass
        outputs = model(x)
        loss = criterion(outputs['out'], y)
        
        # Backward and optimize
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()
        
        # Calculate accuracy
        acc = accuracy(outputs['out'], y)[0].item()
        
        # Update meters
        loss_m.update(loss.item())
        acc_m.update(acc)
        
        # Update progress bar
        batch_bar.set_postfix(
            loss="{:.04f}".format(float(loss_m.avg)),
            acc="{:.04f}%".format(float(acc_m.avg)),
            lr="{:.06f}".format(float(optimizer.param_groups[0]['lr']))
        )
        batch_bar.update()
        
        # Memory management
        del x, y, outputs, loss
        if hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
    
    batch_bar.close()
    return acc_m.avg, loss_m.avg

In [85]:
@torch.no_grad()
def validate_model(model, val_loader, criterion, class_names, device):
    model.eval()
    # Metric meters
    loss_m = AverageMeter()
    acc_m = AverageMeter()
    
    # Progress Bar
    batch_bar = tqdm(total=len(val_loader), dynamic_ncols=True, position=0, leave=False, desc='Val')
    
    all_preds = []
    all_targets = []
    
    for i, data in enumerate(val_loader):
        x, y = data
        x, y = x.to(device), y.to(device)
        
        # Forward pass
        outputs = model(x)
        loss = criterion(outputs['out'], y)
        
        # Calculate accuracy
        acc = accuracy(outputs['out'], y)[0].item()
        
        # Store predictions and targets for confusion matrix
        _, predicted = torch.max(outputs['out'], 1)
        all_preds.extend(predicted.cpu().numpy())
        all_targets.extend(y.cpu().numpy())
        
        # Update meters
        loss_m.update(loss.item())
        acc_m.update(acc)
        
        # Update progress bar
        batch_bar.set_postfix(
            loss="{:.04f}".format(float(loss_m.avg)),
            acc="{:.04f}%".format(float(acc_m.avg))
        )
        batch_bar.update()
        
        # Memory management
        del x, y, outputs, loss
        if hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
    
    batch_bar.close()
    
    # Calculate per-class accuracy
    if class_names:
        print("\nPer-class Validation Accuracy:")
        per_class_acc = {}
        for i, class_name in enumerate(class_names):
            class_mask = (np.array(all_targets) == i)
            if np.sum(class_mask) > 0:
                class_correct = np.sum((np.array(all_preds)[class_mask] == i))
                class_total = np.sum(class_mask)
                accuracy = 100 * class_correct / class_total
                print(f"  {class_name}: {accuracy:.4f}% ({class_correct}/{class_total})")
                per_class_acc[f"val_acc_{class_name}"] = accuracy
    
    return acc_m.avg, loss_m.avg

In [86]:
def train_model(model, train_loader, criterion, optimizer, device):
    model.train()
    batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train')
    total_loss = 0
    correct = 0
    total = 0
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        x, y = data
        x, y = x.to(device), y.to(device)
        
        # Forward pass
        outputs = model(x)
        
        # Handle dictionary output format
        if isinstance(outputs, dict) and 'out' in outputs:
            logits = outputs['out']  # Extract the classification logits
        else:
            logits = outputs
            
        # Calculate loss with the extracted logits
        loss = criterion(logits, y)
        
        # Backward and optimize
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()
        
        # Calculate accuracy
        total_loss += loss.item()
        _, predicted = torch.max(logits, 1)  # Use the extracted logits
        total += y.size(0)
        correct += (predicted == y).sum().item()
        
        batch_bar.set_postfix(
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            acc="{:.04f}".format(float(correct / total)),
            lr="{:.06f}".format(float(optimizer.param_groups[0]['lr']))
        )
        batch_bar.update()
        
        # Memory management
        del x, y, outputs, loss
        if hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
    
    batch_bar.close()
    train_loss = total_loss / len(train_loader)
    train_acc = correct / total * 100
    return train_loss, train_acc

In [90]:
def validate_model(model, val_loader, criterion, class_names, device):
    model.eval()
    batch_bar = tqdm(total=len(val_loader), dynamic_ncols=True, position=0, leave=False, desc='Val')
    total_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    
    for i, data in enumerate(val_loader):
        x, y = data
        x, y = x.to(device), y.to(device)
        
        with torch.no_grad():
            outputs = model(x)
            
            # Handle dictionary output format
            if isinstance(outputs, dict) and 'out' in outputs:
                logits = outputs['out']  # Extract the classification logits
            else:
                logits = outputs
                
            loss = criterion(logits, y)
            
            # Calculate accuracy
            _, predicted = torch.max(logits, 1)  # Use the extracted logits
            total += y.size(0)
            correct += (predicted == y).sum().item()
            
            # Store predictions and targets for confusion matrix
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(y.cpu().numpy())
            
            total_loss += loss.item()
            
            batch_bar.set_postfix(
                loss="{:.04f}".format(float(total_loss / (i + 1))),
                acc="{:.04f}".format(float(correct / total))
            )
            batch_bar.update()
            
            # Memory management
            del x, y, outputs, loss
            if hasattr(torch.mps, 'empty_cache'):
                torch.mps.empty_cache()
    
    batch_bar.close()
    val_loss = total_loss / len(val_loader)
    val_acc = correct / total * 100
    return val_loss, val_acc

In [91]:
def save_model(model, optimizer, scheduler, metrics, epoch, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'metrics': metrics
    }, path)

In [92]:
# Define CrossEntropyLoss as the criterion
criterion = nn.CrossEntropyLoss()

# Initialize optimizer with AdamW
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['lr'],
    weight_decay=1e-5
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
    verbose=True
)



In [65]:
wandb.login(key="78d5988d9f05a421bc74d044c3cd9afc3b918020") # API Key is in your wandb account, under settings (wandb.ai/settings)



True

In [69]:
# Initialize wandb
run = wandb.init(
    name = "06run", ## Wandb creates random run names if you skip this field
    reinit = False, ### Allows reinitalizing runs when you re-run this cell
    #id = "", ### Insert specific run id here if you want to resume a previous run
    # resume = "must" ### You need this to resume previous runs, but comment out reinit = True when using this
    project = "object_classification", ### Project should be created in your wandb account
    config = config ### Wandb Config for your run
)

In [93]:
# Create checkpoint directory if it doesn't exist
os.makedirs(config['checkpoint_dir'], exist_ok=True)

# Initialize best metrics tracking
best_val_loss = float('inf')
best_val_acc = 0
class_names = list(CATEGORIES.keys())

# Training loop
for epoch in range(config['epochs']):
    print(f"\nEpoch {epoch + 1}/{config['epochs']}")
    
    # Training phase
    train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, device)
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")
    
    # Validation phase
    val_loss, val_acc = validate_model(model, val_loader, criterion, class_names, device)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")
    
    # Update learning rate scheduler
    scheduler.step(val_loss)
    curr_lr = optimizer.param_groups[0]['lr']
    
    # Save best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_acc = val_acc
        
        # Save best model
        best_model_path = os.path.join(config['checkpoint_dir'], 'best_model.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
        }, best_model_path)
        wandb.save(best_model_path)  # Save the model to WandB
        print(f"Saved best model with validation loss: {best_val_loss:.4f} and accuracy: {best_val_acc:.2f}%")
    
    # Save the model for every epoch
    last_model_path = os.path.join(config['checkpoint_dir'], f'model_epoch_{epoch+1}.pth')
    torch.save(model.state_dict(), last_model_path)
    wandb.save(last_model_path)  # Save the model to WandB
    print(f"Saved model for epoch {epoch+1}")
    
    # Logging metrics to WandB
    wandb.log({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc,
        'learning_rate': curr_lr
    }, step=epoch)
    
    print(f"End of Epoch {epoch+1}/{config['epochs']}")

# Final message
print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.2f}%")


Epoch 1/20


                                                                                              

Train Loss: 2.0493, Train Accuracy: 20.16%


                                                                                

Validation Loss: 2.4744, Validation Accuracy: 11.89%
Saved best model with validation loss: 2.4744 and accuracy: 11.89%
Saved model for epoch 1
End of Epoch 1/20

Epoch 2/20


                                                                                              

Train Loss: 1.9608, Train Accuracy: 20.57%


                                                                                

Validation Loss: 2.5482, Validation Accuracy: 13.54%
Saved model for epoch 2
End of Epoch 2/20

Epoch 3/20


                                                                                              

Train Loss: 1.9379, Train Accuracy: 23.36%


                                                                                

Validation Loss: 2.0957, Validation Accuracy: 15.70%
Saved best model with validation loss: 2.0957 and accuracy: 15.70%
Saved model for epoch 3
End of Epoch 3/20

Epoch 4/20


                                                                                              

Train Loss: 1.8857, Train Accuracy: 25.94%


                                                                                

Validation Loss: 2.1225, Validation Accuracy: 19.70%
Saved model for epoch 4
End of Epoch 4/20

Epoch 5/20


                                                                                              

Train Loss: 1.8327, Train Accuracy: 26.66%


                                                                                

Validation Loss: 2.5162, Validation Accuracy: 18.49%
Saved model for epoch 5
End of Epoch 5/20

Epoch 6/20


                                                                                              

Train Loss: 1.7841, Train Accuracy: 27.91%


                                                                                

Validation Loss: 2.2188, Validation Accuracy: 20.03%
Saved model for epoch 6
End of Epoch 6/20

Epoch 7/20


                                                                                              

Train Loss: 1.7663, Train Accuracy: 29.73%


                                                                                

Validation Loss: 2.8273, Validation Accuracy: 21.77%
Saved model for epoch 7
End of Epoch 7/20

Epoch 8/20


                                                                                              

Train Loss: 1.7370, Train Accuracy: 30.64%


                                                                                

Validation Loss: 2.8580, Validation Accuracy: 24.29%
Saved model for epoch 8
End of Epoch 8/20

Epoch 9/20


                                                                                              

Train Loss: 1.7109, Train Accuracy: 31.81%


                                                                                

Validation Loss: 2.2679, Validation Accuracy: 21.77%
Saved model for epoch 9
End of Epoch 9/20

Epoch 10/20


                                                                                              

Train Loss: 1.6022, Train Accuracy: 35.07%


                                                                                

Validation Loss: 2.2971, Validation Accuracy: 19.73%
Saved model for epoch 10
End of Epoch 10/20

Epoch 11/20


                                                                                              

Train Loss: 1.5815, Train Accuracy: 35.29%


                                                                                

Validation Loss: 2.2729, Validation Accuracy: 20.15%
Saved model for epoch 11
End of Epoch 11/20

Epoch 12/20


                                                                                              

Train Loss: 1.5595, Train Accuracy: 36.03%


                                                                                

Validation Loss: 2.3722, Validation Accuracy: 20.03%
Saved model for epoch 12
End of Epoch 12/20

Epoch 13/20


                                                                                              

Train Loss: 1.5504, Train Accuracy: 36.23%


                                                                                

Validation Loss: 2.6182, Validation Accuracy: 19.39%
Saved model for epoch 13
End of Epoch 13/20

Epoch 14/20


                                                                                              

Train Loss: 1.5322, Train Accuracy: 36.43%


                                                                                

Validation Loss: 2.3061, Validation Accuracy: 22.16%
Saved model for epoch 14
End of Epoch 14/20

Epoch 15/20


                                                                                              

Train Loss: 1.5440, Train Accuracy: 36.46%


                                                                                

Validation Loss: 2.4020, Validation Accuracy: 17.60%
Saved model for epoch 15
End of Epoch 15/20

Epoch 16/20


                                                                                              

Train Loss: 1.4812, Train Accuracy: 38.68%


                                                                                

Validation Loss: 2.3702, Validation Accuracy: 21.04%
Saved model for epoch 16
End of Epoch 16/20

Epoch 17/20


                                                                                              

Train Loss: 1.4518, Train Accuracy: 39.81%


                                                                                

Validation Loss: 2.4648, Validation Accuracy: 21.15%
Saved model for epoch 17
End of Epoch 17/20

Epoch 18/20


                                                                                              

Train Loss: 1.4504, Train Accuracy: 39.92%


                                                                                

Validation Loss: 2.4121, Validation Accuracy: 20.51%
Saved model for epoch 18
End of Epoch 18/20

Epoch 19/20


                                                                                              

Train Loss: 1.4291, Train Accuracy: 40.52%


                                                                                

Validation Loss: 2.4300, Validation Accuracy: 20.57%
Saved model for epoch 19
End of Epoch 19/20

Epoch 20/20


                                                                                              

Train Loss: 1.4332, Train Accuracy: 40.66%


                                                                                

Validation Loss: 2.5615, Validation Accuracy: 19.75%
Saved model for epoch 20
End of Epoch 20/20

Training complete! Best validation accuracy: 15.70%


In [77]:
def test_model(model, test_loader, criterion, class_names, device, checkpoint_dir=None):
    """
    Evaluate the model on the test dataset and generate detailed performance metrics
    """
    # Set model to evaluation mode
    model.eval()
    # Initialize metrics
    test_loss = 0.0
    correct = 0
    total = 0
    # Store all predictions and ground truth for analysis
    all_preds = []
    all_targets = []
    all_probs = [] # Store probabilities for confidence analysis
    # Per-class statistics
    class_correct = {class_name: 0 for class_name in class_names}
    class_total = {class_name: 0 for class_name in class_names}
    # Create progress bar
    test_bar = tqdm(test_loader, desc="Testing", unit="batch", ncols=100)
    
    with torch.no_grad():
        for data in test_bar:
            # Get inputs and labels
            inputs, targets = data
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            
            # Handle different output formats (dictionary vs. tensor)
            if isinstance(outputs, dict) and 'out' in outputs:
                outputs_for_loss = outputs['out']
            else:
                outputs_for_loss = outputs
                
            loss = criterion(outputs_for_loss, targets)
            
            # Calculate loss and accuracy
            test_loss += loss.item() * inputs.size(0)
            
            # Get predictions and probabilities
            if isinstance(outputs, dict) and 'out' in outputs:
                probs = torch.nn.functional.softmax(outputs['out'], dim=1)
                _, predicted = torch.max(outputs['out'], 1)
            else:
                probs = torch.nn.functional.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)
            
            # Update total counts
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            
            # Update per-class counts
            for i in range(targets.size(0)):
                label = targets[i].item()
                pred = predicted[i].item()
                class_name = class_names[label]
                class_total[class_name] += 1
                if pred == label:
                    class_correct[class_name] += 1
            
            # Store predictions and targets for later analysis
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            
            # Update progress bar
            test_bar.set_postfix({
                'loss': f"{test_loss/total:.4f}",
                'acc': f"{100.0*correct/total:.2f}%"
            })
    
    # Calculate overall metrics
    test_loss /= len(test_loader.dataset)
    test_acc = correct / total
    
    # Calculate per-class accuracy
    class_accuracy = {name: class_correct[name]/class_total[name] if class_total[name] > 0 else 0
                    for name in class_names}
    
    # Print results
    print("\n" + "="*50)
    print("TEST RESULTS")
    print("="*50)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f} ({correct}/{total})")
    print("\nPer-Class Accuracy:")
    for class_name in class_names:
        print(f" {class_name}: {class_accuracy[class_name]:.4f} ({class_correct[class_name]}/{class_total[class_name]})")

    # Return comprehensive metrics dictionary
    return {
        'test_loss': test_loss,
        'test_accuracy': test_acc,
        'class_accuracy': class_accuracy,
        'predictions': all_preds,
        'targets': all_targets,
        'probabilities': all_probs
    }

In [94]:
# # Load the best model (optional - if you saved a checkpoint)
# best_model_path = f"{config['checkpoint_dir']}/best_model.pth"
# if os.path.exists(best_model_path):
#     checkpoint = torch.load(best_model_path)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     print(f"Loaded best model from epoch {checkpoint.get('epoch', 'unknown')}")

# Evaluate on test set
test_results = test_model(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    class_names=class_names,
    device=device,
    checkpoint_dir=config['checkpoint_dir']
)

Testing: 100%|████████████████████████| 102/102 [00:01<00:00, 55.47batch/s, loss=2.4158, acc=23.16%]


TEST RESULTS
Test Loss: 2.4158
Test Accuracy: 0.2316 (750/3239)

Per-Class Accuracy:
 Blueball: 0.1745 (82/470)
 Box: 0.2521 (122/484)
 Pencilcase: 0.2427 (117/482)
 Pinkball: 0.3478 (160/460)
 StuffedAnimal: 0.0170 (6/352)
 Tennis: 0.3588 (188/524)
 Waterbottle: 0.1606 (75/467)





In [95]:
run.finish()

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
learning_rate,██████▄▄▄▄▄▄▂▂▂▂▂▂▁▂
train_acc,▁▂▃▃▄▄▅▅▆▆▆▆▇▇█████▇
train_loss,█▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
val_acc,▃▁▃▃▄▇▆▃▅▆▅▆▇▆▅█▇█▇▅
val_loss,▁▃▂▂▄▃▄▅▃▄▄▃▄▅▅█▆▆▇▆

0,1
epoch,20.0
learning_rate,0.00025
train_acc,40.65692
train_loss,1.43315
val_acc,19.75378
val_loss,2.56151
