# Contact-based Object Classification

In [18]:
import pandas as pd
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict
import torch
from torch.utils.data import Dataset, DataLoader
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
from torchsummary import summary
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import wandb

In [2]:
# Categories for classification
CATEGORIES = {
    'Blueball': 0,
    'Box': 1,
    'Pencilcase': 2,
    'Pinkball': 3,
    'StuffedAnimal': 4,
    'Tennis': 5,
    'Waterbottle': 6,
}

In [50]:
# Configuration Dictionary
config = {
    'batch_size': 32,
    'lr': 0.001,
    'epochs': 20,
    'data_dir': "/Users/benlee/Documents/college/CMU/Spring 2025/IDL/Project/IDL_code/IDL_Data",
    'checkpoint_dir': "/Users/benlee/Documents/college/CMU/Spring 2025/IDL/Project/IDL_code/checkpoint",
}

In [51]:
class ContactWindowDataset(Dataset):
    def __init__(self, data_dir: str, labels: Dict[str, int] = None, window_size: int = 50, step_size: int = 10):
        """
        Args:
            data_dir (str): Directory containing the .txt files
            labels (Dict[str, int]): Dictionary mapping categories to class labels
            window_size (int): Size of the sliding window (smaller = more samples)
            step_size (int): Step size for the sliding window (smaller = more samples)
        """
        self.data_dir = Path(data_dir)
        self.file_paths = list(self.data_dir.glob("*.txt"))
        self.labels = labels or {}
        self.window_size = window_size
        self.step_size = step_size
        
        # Lists to store all windows and their labels
        self.features_list = []
        self.labels_list = []
        self.file_indices = []  # To keep track of which file each window came from
        
        # Process all files
        print("Processing files and extracting windows:")
        for file_idx, file_path in enumerate(self.file_paths):
            # Get category from filename
            category = re.sub(r"\d+", "", file_path.stem)
            label = self.labels.get(category, -1)
            
            # Load and process the file
            df = self._parse_file(file_path)
            
            # Create windows
            windows_from_file = 0
            for start_idx in range(0, len(df) - window_size + 1, step_size):
                window = df.iloc[start_idx:start_idx + window_size]
                
                # Extract features from window
                features = self._extract_features(window)
                
                self.features_list.append(features)
                self.labels_list.append(label)
                self.file_indices.append(file_idx)
                windows_from_file += 1
            
            print(f"  {file_path.name}: {windows_from_file} windows generated from {len(df)} datapoints")
        
        # Convert lists to tensors for efficiency
        self.features = torch.FloatTensor(self.features_list)
        self.labels = torch.LongTensor(self.labels_list)
        self.file_indices = torch.LongTensor(self.file_indices)
        
        # Print summary statistics
        self._print_dataset_stats()
        
    def _parse_file(self, file_path: Path) -> pd.DataFrame:
        """Parse a single data file"""
        df = pd.read_csv(file_path, header=None, skiprows=1)
        columns = [
            'timestamp_pc', 'timestamp_micro',
            'x', 'y', 'angle_1', 'angle_2',
            'contact_1_left', 'contact_1_right',
            'contact_2_left', 'contact_2_right'
        ]
        df = pd.DataFrame(df.values, columns=columns)
        return df
    
    def _extract_features(self, window: pd.DataFrame) -> np.ndarray:
        """Extract features from a window of data"""
        # Basic statistical features
        features = np.array([
            window['contact_1_left'].mean(),
            window['contact_1_right'].mean(),
            window['contact_2_left'].mean(),
            window['contact_2_right'].mean(),
            window['x'].max() - window['x'].min(),
            window['y'].max() - window['y'].min(),
            window['angle_1'].std(),
            window['angle_2'].std(),
            # Additional features for more information
            window['contact_1_left'].std(),
            window['contact_1_right'].std(),
            window['contact_2_left'].std(),
            window['contact_2_right'].std(),
            window['angle_1'].mean(),
            window['angle_2'].mean()
        ])
        return features
    
    def _print_dataset_stats(self):
        """Print statistics about the dataset"""
        total_windows = len(self.features)
        unique_files = len(torch.unique(self.file_indices))
        
        # Count samples per category
        category_counts = {}
        for label in self.labels_list:
            category_name = list(CATEGORIES.keys())[list(CATEGORIES.values()).index(label)]
            category_counts[category_name] = category_counts.get(category_name, 0) + 1
        
        print("\nDataset Statistics:")
        print(f"Total number of windows: {total_windows}")
        print(f"Total number of files: {unique_files}")
        print(f"Average windows per file: {total_windows / unique_files:.2f}")
        print("\nSamples per category:")
        for category, count in category_counts.items():
            print(f"  {category}: {count} windows")
        
    def __len__(self) -> int:
        return len(self.features)
        
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        return self.features[idx], self.labels[idx]

In [52]:
# Function to split dataset while preventing data leakage
def split_dataset_by_files(dataset, train_ratio=0.7, val_ratio=0.15, seed=42):
    """
    Split the dataset by files to prevent data leakage.
    Windows from the same file will stay in the same split.
    """
    # Get unique file indices
    unique_files = torch.unique(dataset.file_indices).tolist()
    
    # Shuffle the files
    np.random.seed(seed)
    np.random.shuffle(unique_files)
    
    # Split files into train, val, test
    n_files = len(unique_files)
    train_files = unique_files[:int(train_ratio * n_files)]
    val_files = unique_files[int(train_ratio * n_files):int((train_ratio + val_ratio) * n_files)]
    test_files = unique_files[int((train_ratio + val_ratio) * n_files):]
    
    # Get indices for each split
    train_indices = [i for i, file_idx in enumerate(dataset.file_indices) if file_idx in train_files]
    val_indices = [i for i, file_idx in enumerate(dataset.file_indices) if file_idx in val_files]
    test_indices = [i for i, file_idx in enumerate(dataset.file_indices) if file_idx in test_files]
    
    return train_indices, val_indices, test_indices

In [53]:
# Create the dataset with smaller window and step size to maximize samples
window_size = 100   # Smaller window size generates more samples
step_size = 10     # Smaller step size creates more overlapping windows
dataset = ContactWindowDataset(
    data_dir=config["data_dir"], 
    labels=CATEGORIES, 
    window_size=window_size, 
    step_size=step_size
)

Processing files and extracting windows:
  Pencilcase4.txt: 482 windows generated from 4916 datapoints
  Pinkball3.txt: 467 windows generated from 4760 datapoints
  Pinkball2.txt: 460 windows generated from 4695 datapoints
  Pencilcase5.txt: 531 windows generated from 5404 datapoints
  Pinkball1.txt: 481 windows generated from 4901 datapoints
  Pencilcase2.txt: 615 windows generated from 6249 datapoints
  Pinkball5.txt: 473 windows generated from 4828 datapoints
  Pinkball4.txt: 456 windows generated from 4658 datapoints
  Pencilcase3.txt: 602 windows generated from 6119 datapoints
  Pencilcase1.txt: 504 windows generated from 5135 datapoints
  Pinkball6.txt: 458 windows generated from 4679 datapoints
  Box4.txt: 483 windows generated from 4926 datapoints
  Box5.txt: 530 windows generated from 5399 datapoints
  Box1.txt: 509 windows generated from 5187 datapoints
  Box2.txt: 467 windows generated from 4766 datapoints
  Box3.txt: 484 windows generated from 4939 datapoints
  Blueball6.tx

In [54]:
# Split the dataset by files to prevent data leakage
train_indices, val_indices, test_indices = split_dataset_by_files(dataset)

# Create subset datasets
from torch.utils.data import Subset
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'])
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'])

# Print split information
print("\nDataset Splits:")
print(f"Training set size: {len(train_dataset)} windows")
print(f"Validation set size: {len(val_dataset)} windows")
print(f"Test set size: {len(test_dataset)} windows")

# Example of accessing a batch
for features, labels in train_loader:
    print(f"\nSample batch:")
    print(f"Batch features shape: {features.shape}")
    print(f"Batch labels shape: {labels.shape}")
    break


Dataset Splits:
Training set size: 12423 windows
Validation set size: 2919 windows
Test set size: 2827 windows

Sample batch:
Batch features shape: torch.Size([32, 14])
Batch labels shape: torch.Size([32])


In [55]:
# Define the Neural Network Model for Object Classification
class ContactClassifier(torch.nn.Module):
    def __init__(self, input_size, num_classes):
        super(ContactClassifier, self).__init__()
        
        # Specialized feature extraction pathway for contact data
        self.model = torch.nn.Sequential(
            # Input normalization layer
            torch.nn.BatchNorm1d(input_size),
            
            # Initial feature expansion - smaller width but better feature extraction
            torch.nn.Linear(input_size, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.GELU(),
            torch.nn.Dropout(0.2),  # Reduced dropout to prevent information loss
            
            # Deeper feature processing (keeping original structure but improving flow)
            torch.nn.Linear(256, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.GELU(),
            
            # Refinement layer - not in original but helps with feature separation
            torch.nn.Linear(512, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.GELU(),
            torch.nn.Dropout(0.3),
            
            # Added skip connection internally by splitting flow and rejoining
            torch.nn.Linear(512, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.GELU(),
            
            # Squeeze down to focused features
            torch.nn.Linear(256, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.GELU(),
            
            # Final classification layer
            torch.nn.Linear(128, num_classes)
        )
        
        # Proper weight initialization
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                # Kaiming initialization works better with GELU
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
            elif isinstance(m, torch.nn.BatchNorm1d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        return self.model(x)

# Initialize the model
input_size = 14  # Number of features in your dataset
num_classes = len(CATEGORIES)  # Number of object categories
model = ContactClassifier(input_size, num_classes)

# Try CPU first for compatibility with BatchNorm1d 
device = torch.device('cpu')  # Change to 'mps' after checking summary
model.to(device)
summary(model, (input_size,))

# After checking summary, move to MPS if available
try:
    if torch.backends.mps.is_available():
        model = model.to('mps')
        device = torch.device('mps')
        print("Model moved to MPS device")
    print(f"Using device: {device}")
except Exception as e:
    print(f"Error moving to MPS: {e}")
    print(f"Using device: {device}")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm1d-1                   [-1, 14]              28
            Linear-2                  [-1, 256]           3,840
       BatchNorm1d-3                  [-1, 256]             512
              GELU-4                  [-1, 256]               0
           Dropout-5                  [-1, 256]               0
            Linear-6                  [-1, 512]         131,584
       BatchNorm1d-7                  [-1, 512]           1,024
              GELU-8                  [-1, 512]               0
            Linear-9                  [-1, 512]         262,656
      BatchNorm1d-10                  [-1, 512]           1,024
             GELU-11                  [-1, 512]               0
          Dropout-12                  [-1, 512]               0
           Linear-13                  [-1, 256]         131,328
      BatchNorm1d-14                  [

In [56]:
# Define CrossEntropyLoss as the criterion
# Standard loss function for multi-class classification problems
criterion = nn.CrossEntropyLoss()

# Initialize optimizer with AdamW (Adam with weight decay)
# We pass all model parameters and set the learning rate
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.0001,
    weight_decay=1e-5  # Adding a small weight decay for regularization
)

# Learning rate scheduler to reduce learning rate when training plateaus
# ReduceLROnPlateau reduces learning rate when validation loss stops improving
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',  # Monitor minimum validation loss
    factor=0.5,  # Reduce learning rate by half when plateau is detected
    patience=5,  # Wait for 5 epochs without improvement before reducing LR
    min_lr=1e-6,  # Don't reduce learning rate below this threshold
    verbose=True  # Print message when learning rate is reduced
)

# Mixed Precision, if you need it
scaler = torch.cuda.amp.GradScaler()

  scaler = torch.cuda.amp.GradScaler()


In [57]:
def train_model(model, train_loader, criterion, optimizer, device):
    model.train()
    batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train')
    total_loss = 0
    correct = 0
    total = 0
    
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        
        x, y = data
        x, y = x.to(device), y.to(device)
        
        # Forward pass
        outputs = model(x)
        loss = criterion(outputs, y)
        
        # Backward and optimize
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()
        
        # Calculate accuracy
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()
        
        batch_bar.set_postfix(
            loss="{:.04f}".format(float(total_loss / (i + 1))),
            acc="{:.04f}".format(float(correct / total)),
            lr="{:.06f}".format(float(optimizer.param_groups[0]['lr']))
        )
        batch_bar.update()
        
        # Memory management
        del x, y, outputs, loss
        if hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
    
    batch_bar.close()
    train_loss = total_loss / len(train_loader)
    train_acc = correct / total *100
    
    return train_loss, train_acc

In [58]:
def validate_model(model, val_loader, criterion, class_names, device):
    model.eval()
    batch_bar = tqdm(total=len(val_loader), dynamic_ncols=True, position=0, leave=False, desc='Val')
    total_loss = 0
    correct = 0
    total = 0
    
    all_preds = []
    all_targets = []
    
    for i, data in enumerate(val_loader):
        x, y = data
        x, y = x.to(device), y.to(device)
        
        with torch.no_grad():
            outputs = model(x)
            loss = criterion(outputs, y)
            
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
            
            # Store predictions and targets for confusion matrix
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(y.cpu().numpy())
            
            total_loss += loss.item()
            
            batch_bar.set_postfix(
                loss="{:.04f}".format(float(total_loss / (i + 1))), 
                acc="{:.04f}".format(float(correct / total))
            )
            batch_bar.update()
        
        # Memory management
        del x, y, outputs, loss
        if hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
    
    batch_bar.close()
    val_loss = total_loss / len(val_loader)
    val_acc = correct / total *100
    
    return val_loss, val_acc

In [59]:
wandb.login(key="78d5988d9f05a421bc74d044c3cd9afc3b918020") # API Key is in your wandb account, under settings (wandb.ai/settings)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/benlee/.netrc


True

In [60]:
# Initialize wandb
run = wandb.init(
    name = "04run", ## Wandb creates random run names if you skip this field
    reinit = False, ### Allows reinitalizing runs when you re-run this cell
    #id = "", ### Insert specific run id here if you want to resume a previous run
    # resume = "must" ### You need this to resume previous runs, but comment out reinit = True when using this
    project = "object_classification", ### Project should be created in your wandb account
    config = config ### Wandb Config for your run
)

In [61]:
# Create checkpoint directory if it doesn't exist
os.makedirs(config['checkpoint_dir'], exist_ok=True)

# Initialize best metrics tracking
best_val_loss = float('inf')
best_val_acc = 0
class_names = list(CATEGORIES.keys())

# Training loop
for epoch in range(config['epochs']):
    print(f"\nEpoch {epoch + 1}/{config['epochs']}")
    
    # Training loop with progress bar
    model.train()
    epoch_train_loss = 0
    correct = 0
    total = 0
    
    train_bar = tqdm(train_loader, desc="Training", unit="batch", ncols=100)
    for batch_idx, (X_batch, y_batch) in enumerate(train_bar):
        # Move data to device
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Track metrics
        epoch_train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += y_batch.size(0)
        correct += (predicted == y_batch).sum().item()
        
        # Update progress bar
        train_bar.set_postfix({
            'loss': f"{epoch_train_loss/(batch_idx+1):.4f}",
            'acc': f"{100.*correct/total:.2f}%"
        })
        
        # Free memory
        del X_batch, y_batch, outputs, loss
        if device.type == 'mps' and hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
    
    # Calculate final training metrics
    epoch_train_loss /= len(train_loader)
    train_acc = correct / total
    print(f"Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
    
    # Validation loop with progress bar
    model.eval()
    epoch_val_loss = 0
    correct = 0
    total = 0
    
    # For per-class accuracy
    class_correct = [0] * len(class_names)
    class_total = [0] * len(class_names)
    
    val_bar = tqdm(val_loader, desc="Validation", unit="batch", ncols=100)
    with torch.no_grad():
        for batch_idx, (X_val_batch, y_val_batch) in enumerate(val_bar):
            # Move data to device
            X_val_batch, y_val_batch = X_val_batch.to(device), y_val_batch.to(device)
            
            # Forward pass
            outputs = model(X_val_batch)
            loss = criterion(outputs, y_val_batch)
            
            # Track metrics
            epoch_val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += y_val_batch.size(0)
            correct += (predicted == y_val_batch).sum().item()
            
            # Per-class accuracy
            for c in range(len(class_names)):
                class_mask = (y_val_batch == c)
                class_correct[c] += (predicted[class_mask] == c).sum().item()
                class_total[c] += class_mask.sum().item()
            
            # Update progress bar
            val_bar.set_postfix({
                'loss': f"{epoch_val_loss/(batch_idx+1):.4f}", 
                'acc': f"{100.*correct/total:.2f}%"
            })
            
            # Free memory
            del X_val_batch, y_val_batch, outputs, loss
            if device.type == 'mps' and hasattr(torch.mps, 'empty_cache'):
                torch.mps.empty_cache()
    
    # Calculate final validation metrics
    epoch_val_loss /= len(val_loader)
    val_acc = correct / total
    print(f"Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
    
    # Print per-class accuracy every 5 epochs
    if (epoch + 1) % 5 == 0:
        print("\nPer-class Validation Accuracy:")
        per_class_acc = {}
        for i, class_name in enumerate(class_names):
            if class_total[i] > 0:
                accuracy = class_correct[i] / class_total[i]
                print(f"  {class_name}: {accuracy:.4f} ({class_correct[i]}/{class_total[i]})")
                per_class_acc[f"val_acc_{class_name}"] = accuracy
        # Log per-class metrics to WandB
        wandb.log(per_class_acc, step=epoch)
    
    # Update learning rate scheduler
    scheduler.step(epoch_val_loss)
    curr_lr = optimizer.param_groups[0]['lr']
    
    # Save best model based on validation loss
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        best_val_acc = val_acc
        
        # Save best model
        best_model_path = os.path.join(config['checkpoint_dir'], 'best_model.pth')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': epoch_val_loss,
            'val_acc': val_acc,
        }, best_model_path)
        wandb.save(best_model_path)  # Save the model to WandB
        print(f"Saved best model with validation loss: {best_val_loss:.4f} and accuracy: {best_val_acc:.4f}")
    
    # Save the model for every epoch
    last_model_path = os.path.join(config['checkpoint_dir'], f'model_epoch_{epoch+1}.pth')
    torch.save(model.state_dict(), last_model_path)
    wandb.save(last_model_path)  # Save the model to WandB
    print(f"Saved model for epoch {epoch+1}")
    
    # Logging metrics to WandB
    wandb.log({
        'epoch': epoch + 1,
        'train_loss': epoch_train_loss,
        'train_acc': train_acc,
        'val_loss': epoch_val_loss,
        'val_acc': val_acc,
        'learning_rate': curr_lr
    }, step=epoch)
    
    print(f"End of Epoch {epoch+1}/{config['epochs']}")

# Final message
print(f"\nTraining complete! Best validation accuracy: {best_val_acc:.4f}")


Epoch 1/20


Training: 100%|███████████████████████| 389/389 [00:04<00:00, 92.91batch/s, loss=3.8367, acc=17.01%]


Train Loss: 3.8367, Train Accuracy: 0.1701


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 133.52batch/s, loss=2.5456, acc=13.74%]


Validation Loss: 2.5456, Validation Accuracy: 0.1374
Saved best model with validation loss: 2.5456 and accuracy: 0.1374
Saved model for epoch 1
End of Epoch 1/20

Epoch 2/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 119.07batch/s, loss=2.7348, acc=19.10%]


Train Loss: 2.7348, Train Accuracy: 0.1910


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 148.04batch/s, loss=2.2968, acc=13.33%]


Validation Loss: 2.2968, Validation Accuracy: 0.1333
Saved best model with validation loss: 2.2968 and accuracy: 0.1333
Saved model for epoch 2
End of Epoch 2/20

Epoch 3/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 121.11batch/s, loss=2.3092, acc=21.15%]


Train Loss: 2.3092, Train Accuracy: 0.2115


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 147.39batch/s, loss=2.1488, acc=12.50%]


Validation Loss: 2.1488, Validation Accuracy: 0.1250
Saved best model with validation loss: 2.1488 and accuracy: 0.1250
Saved model for epoch 3
End of Epoch 3/20

Epoch 4/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 120.05batch/s, loss=2.1141, acc=21.49%]


Train Loss: 2.1141, Train Accuracy: 0.2149


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 140.92batch/s, loss=2.0605, acc=12.44%]


Validation Loss: 2.0605, Validation Accuracy: 0.1244
Saved best model with validation loss: 2.0605 and accuracy: 0.1244
Saved model for epoch 4
End of Epoch 4/20

Epoch 5/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 122.90batch/s, loss=2.0236, acc=22.14%]


Train Loss: 2.0236, Train Accuracy: 0.2214


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 143.03batch/s, loss=2.0531, acc=15.72%]


Validation Loss: 2.0531, Validation Accuracy: 0.1572

Per-class Validation Accuracy:
  Pencilcase: 0.3503 (186/531)
  Pinkball: 0.2227 (102/458)
  StuffedAnimal: 0.0321 (32/998)
  Tennis: 0.2318 (108/466)
  Waterbottle: 0.0665 (31/466)
Saved best model with validation loss: 2.0531 and accuracy: 0.1572
Saved model for epoch 5
End of Epoch 5/20

Epoch 6/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 122.29batch/s, loss=1.9516, acc=23.32%]


Train Loss: 1.9516, Train Accuracy: 0.2332


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 140.34batch/s, loss=2.0330, acc=16.41%]


Validation Loss: 2.0330, Validation Accuracy: 0.1641
Saved best model with validation loss: 2.0330 and accuracy: 0.1641
Saved model for epoch 6
End of Epoch 6/20

Epoch 7/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 122.11batch/s, loss=1.9089, acc=23.55%]


Train Loss: 1.9089, Train Accuracy: 0.2355


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 155.72batch/s, loss=2.0306, acc=16.82%]


Validation Loss: 2.0306, Validation Accuracy: 0.1682
Saved best model with validation loss: 2.0306 and accuracy: 0.1682
Saved model for epoch 7
End of Epoch 7/20

Epoch 8/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 122.37batch/s, loss=1.8791, acc=24.43%]


Train Loss: 1.8791, Train Accuracy: 0.2443


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 153.70batch/s, loss=2.0249, acc=17.30%]


Validation Loss: 2.0249, Validation Accuracy: 0.1730
Saved best model with validation loss: 2.0249 and accuracy: 0.1730
Saved model for epoch 8
End of Epoch 8/20

Epoch 9/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 124.85batch/s, loss=1.8629, acc=24.79%]


Train Loss: 1.8629, Train Accuracy: 0.2479


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 128.18batch/s, loss=2.0085, acc=17.03%]


Validation Loss: 2.0085, Validation Accuracy: 0.1703
Saved best model with validation loss: 2.0085 and accuracy: 0.1703
Saved model for epoch 9
End of Epoch 9/20

Epoch 10/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 123.23batch/s, loss=1.8435, acc=25.48%]


Train Loss: 1.8435, Train Accuracy: 0.2548


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 152.87batch/s, loss=2.0576, acc=11.92%]


Validation Loss: 2.0576, Validation Accuracy: 0.1192

Per-class Validation Accuracy:
  Pencilcase: 0.3315 (176/531)
  Pinkball: 0.2555 (117/458)
  StuffedAnimal: 0.0080 (8/998)
  Tennis: 0.0944 (44/466)
  Waterbottle: 0.0064 (3/466)
Saved model for epoch 10
End of Epoch 10/20

Epoch 11/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 120.90batch/s, loss=1.8305, acc=25.77%]


Train Loss: 1.8305, Train Accuracy: 0.2577


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 146.72batch/s, loss=2.0471, acc=12.85%]


Validation Loss: 2.0471, Validation Accuracy: 0.1285
Saved model for epoch 11
End of Epoch 11/20

Epoch 12/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 123.22batch/s, loss=1.8268, acc=25.59%]


Train Loss: 1.8268, Train Accuracy: 0.2559


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 145.23batch/s, loss=2.0293, acc=15.04%]


Validation Loss: 2.0293, Validation Accuracy: 0.1504
Saved model for epoch 12
End of Epoch 12/20

Epoch 13/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 120.04batch/s, loss=1.8133, acc=26.68%]


Train Loss: 1.8133, Train Accuracy: 0.2668


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 140.36batch/s, loss=2.0127, acc=16.27%]


Validation Loss: 2.0127, Validation Accuracy: 0.1627
Saved model for epoch 13
End of Epoch 13/20

Epoch 14/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 124.03batch/s, loss=1.8042, acc=26.46%]


Train Loss: 1.8042, Train Accuracy: 0.2646


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 143.46batch/s, loss=2.0377, acc=13.57%]


Validation Loss: 2.0377, Validation Accuracy: 0.1357
Saved model for epoch 14
End of Epoch 14/20

Epoch 15/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 123.41batch/s, loss=1.7945, acc=26.74%]


Train Loss: 1.7945, Train Accuracy: 0.2674


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 139.37batch/s, loss=2.0599, acc=14.18%]


Validation Loss: 2.0599, Validation Accuracy: 0.1418

Per-class Validation Accuracy:
  Pencilcase: 0.4633 (246/531)
  Pinkball: 0.2358 (108/458)
  StuffedAnimal: 0.0220 (22/998)
  Tennis: 0.0687 (32/466)
  Waterbottle: 0.0129 (6/466)
Saved model for epoch 15
End of Epoch 15/20

Epoch 16/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 122.79batch/s, loss=1.7824, acc=28.05%]


Train Loss: 1.7824, Train Accuracy: 0.2805


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 143.03batch/s, loss=2.0497, acc=14.97%]


Validation Loss: 2.0497, Validation Accuracy: 0.1497
Saved model for epoch 16
End of Epoch 16/20

Epoch 17/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 123.44batch/s, loss=1.7813, acc=27.29%]


Train Loss: 1.7813, Train Accuracy: 0.2729


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 142.95batch/s, loss=2.0364, acc=16.31%]


Validation Loss: 2.0364, Validation Accuracy: 0.1631
Saved model for epoch 17
End of Epoch 17/20

Epoch 18/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 124.58batch/s, loss=1.7704, acc=27.87%]


Train Loss: 1.7704, Train Accuracy: 0.2787


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 145.78batch/s, loss=2.0040, acc=18.98%]


Validation Loss: 2.0040, Validation Accuracy: 0.1898
Saved best model with validation loss: 2.0040 and accuracy: 0.1898
Saved model for epoch 18
End of Epoch 18/20

Epoch 19/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 123.35batch/s, loss=1.7730, acc=27.79%]


Train Loss: 1.7730, Train Accuracy: 0.2779


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 146.08batch/s, loss=2.0357, acc=18.53%]


Validation Loss: 2.0357, Validation Accuracy: 0.1853
Saved model for epoch 19
End of Epoch 19/20

Epoch 20/20


Training: 100%|██████████████████████| 389/389 [00:03<00:00, 123.66batch/s, loss=1.7647, acc=28.00%]


Train Loss: 1.7647, Train Accuracy: 0.2800


Validation: 100%|██████████████████████| 92/92 [00:00<00:00, 141.24batch/s, loss=2.0398, acc=15.55%]


Validation Loss: 2.0398, Validation Accuracy: 0.1555

Per-class Validation Accuracy:
  Pencilcase: 0.4482 (238/531)
  Pinkball: 0.2227 (102/458)
  StuffedAnimal: 0.0641 (64/998)
  Tennis: 0.0880 (41/466)
  Waterbottle: 0.0193 (9/466)
Saved model for epoch 20
End of Epoch 20/20

Training complete! Best validation accuracy: 0.1898


In [62]:
def test_model(model, test_loader, criterion, class_names, device, checkpoint_dir=None):
    """
    Evaluate the model on the test dataset and generate detailed performance metrics
    """
    # Set model to evaluation mode
    model.eval()
    
    # Initialize metrics
    test_loss = 0.0
    correct = 0
    total = 0
    
    # Store all predictions and ground truth for analysis
    all_preds = []
    all_targets = []
    all_probs = []  # Store probabilities for confidence analysis
    
    # Per-class statistics
    class_correct = {class_name: 0 for class_name in class_names}
    class_total = {class_name: 0 for class_name in class_names}
    
    # Create progress bar
    test_bar = tqdm(test_loader, desc="Testing", unit="batch", ncols=100)
    
    with torch.no_grad():
        for data in test_bar:
            # Get inputs and labels
            inputs, targets = data
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            # Calculate loss and accuracy
            test_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            probs = torch.nn.functional.softmax(outputs, dim=1)
            
            # Update total counts
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            
            # Update per-class counts
            for i in range(targets.size(0)):
                label = targets[i].item()
                pred = predicted[i].item()
                class_name = class_names[label]
                class_total[class_name] += 1
                if pred == label:
                    class_correct[class_name] += 1
            
            # Store predictions and targets for later analysis
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            
            # Update progress bar
            test_bar.set_postfix({
                'loss': f"{test_loss/total:.4f}",
                'acc': f"{100.0*correct/total:.2f}%"
            })
    
    # Calculate overall metrics
    test_loss /= len(test_loader)
    test_acc = correct / total
    
    # Calculate per-class accuracy
    class_accuracy = {name: class_correct[name]/class_total[name] if class_total[name] > 0 else 0 
                     for name in class_names}
    
    # Create confusion matrix
    from sklearn.metrics import confusion_matrix, classification_report
    conf_matrix = confusion_matrix(all_targets, all_preds)
    classification_rep = classification_report(all_targets, all_preds, 
                                              target_names=class_names, output_dict=True)
    
    # Print results
    print("\n" + "="*50)
    print("TEST RESULTS")
    print("="*50)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f} ({correct}/{total})")
    print("\nPer-Class Accuracy:")
    for class_name in class_names:
        print(f"  {class_name}: {class_accuracy[class_name]:.4f} ({class_correct[class_name]}/{class_total[class_name]})")
    
    print("\nClassification Report:")
    print(classification_report(all_targets, all_preds, target_names=class_names))
    
    # Return comprehensive metrics dictionary
    return {
        'test_loss': test_loss,
        'test_accuracy': test_acc,
        'class_accuracy': class_accuracy,
        'confusion_matrix': conf_matrix,
        'classification_report': classification_rep,
        'predictions': all_preds,
        'targets': all_targets,
        'probabilities': all_probs
    }

In [63]:
# Load the best model (optional - if you saved a checkpoint)
best_model_path = f"{config['checkpoint_dir']}/best_model.pth"
if os.path.exists(best_model_path):
    checkpoint = torch.load(best_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint.get('epoch', 'unknown')}")

# Evaluate on test set
test_results = test_model(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    class_names=class_names,
    device=device,
    checkpoint_dir=config['checkpoint_dir']
)

# You can now access detailed metrics
print(f"\nFinal test accuracy: {test_results['test_accuracy']:.4f}")

  checkpoint = torch.load(best_model_path)


Loaded best model from epoch 17


Testing: 100%|██████████████████████████| 89/89 [00:01<00:00, 84.66batch/s, loss=1.9309, acc=20.91%]



TEST RESULTS
Test Loss: 61.3322
Test Accuracy: 0.2091 (591/2827)

Per-Class Accuracy:
  Blueball: 0.1567 (144/919)
  Box: 0.6916 (323/467)
  Pencilcase: 0.0000 (0/0)
  Pinkball: 0.0768 (35/456)
  StuffedAnimal: 0.0000 (0/0)
  Tennis: 0.1680 (84/500)
  Waterbottle: 0.0103 (5/485)

Classification Report:
               precision    recall  f1-score   support

     Blueball       0.34      0.16      0.21       919
          Box       0.35      0.69      0.46       467
   Pencilcase       0.00      0.00      0.00         0
     Pinkball       0.10      0.08      0.09       456
StuffedAnimal       0.00      0.00      0.00         0
       Tennis       0.27      0.17      0.21       500
  Waterbottle       0.17      0.01      0.02       485

     accuracy                           0.21      2827
    macro avg       0.17      0.16      0.14      2827
 weighted avg       0.26      0.21      0.20      2827


Final test accuracy: 0.2091


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [64]:
run.finish()

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
learning_rate,██████████████▁▁▁▁▁▁
train_acc,▁▂▄▄▄▅▅▆▆▆▇▆▇▇▇█████
train_loss,█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▃▂▂▂▅▅▆▆▆▁▂▄▅▃▃▄▅██▅
val_acc_Pencilcase,▂▁█▇
val_acc_Pinkball,▁█▄▁
val_acc_StuffedAnimal,▄▁▃█
val_acc_Tennis,█▂▁▂
val_acc_Waterbottle,█▁▂▃

0,1
epoch,20.0
learning_rate,5e-05
train_acc,0.28005
train_loss,1.76472
val_acc,0.15553
val_acc_Pencilcase,0.44821
val_acc_Pinkball,0.22271
val_acc_StuffedAnimal,0.06413
val_acc_Tennis,0.08798
val_acc_Waterbottle,0.01931
