In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from typing import Tuple, Optional
import os 
import glob
from tqdm import tqdm

In [2]:

class TCAM(nn.Module):
    """
    Temporal Class-Activation Map Network for Arrow of Time detection.
    
    Architecture:
    - Input: Optical flow frames (multiple temporal chunks)
    - Processing: Temporal chunking + Parallel VGG-16 backbones
    - Fusion: Late temporal fusion of conv5 features
    - Classification: 3 conv layers + Global Average Pooling + Logistic Regression
    
    Args:
        num_temporal_chunks (int): Number of temporal chunks (T). Default: 2
        frames_per_chunk (int): Frames per temporal chunk. Default: 10
        num_flow_channels (int): Input channels (optical flow components). Default: 2 (u, v)
        pretrained (bool): Whether to use pretrained VGG16 weights. Default: False
    """
    
    def __init__(
        self,
        num_temporal_chunks: int = 2,
        frames_per_chunk: int = 10,
        num_flow_channels: int = 2,
        pretrained: bool = False
    ):
        super(TCAM, self).__init__()
        
        self.num_temporal_chunks = num_temporal_chunks
        self.frames_per_chunk = frames_per_chunk
        self.num_flow_channels = num_flow_channels
        self.total_frames = num_temporal_chunks * frames_per_chunk
        self.total_input_channels = num_flow_channels * frames_per_chunk
        
        # ==================== Temporal Feature Fusion Stage ====================
        # Modified VGG-16 backbone to accept optical flow input
        # Expand conv1 filters to accept stacked optical flow frames
        
        self.features = self._build_vgg16_backbone(pretrained)
        
        # ==================== Classification Stage ====================
        # Replace FC layers with conv layers for better interpretability
        
        # Three convolutional layers: 3×3×1024 with BatchNorm
        # NEW: 512 * num_temporal_chunks
        self.class_conv1 = nn.Conv2d(
            512 * self.num_temporal_chunks,  # 512*T
            1024,
            kernel_size=3,
            stride=1,
            padding=1
        )

        self.bn1 = nn.BatchNorm2d(1024)
        self.relu1 = nn.ReLU(inplace=True)
        
        self.class_conv2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(1024)
        self.relu2 = nn.ReLU(inplace=True)
        
        self.class_conv3 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(1024)
        self.relu3 = nn.ReLU(inplace=True)
        
        # Global Average Pooling
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        
        # Binary logistic regression for forward/backward classification
        self.fc = nn.Linear(1024, 1)
        self.sigmoid = nn.Sigmoid()
        
    def _build_vgg16_backbone(self, pretrained: bool = False) -> nn.Sequential:
        """
        Build VGG16 backbone modified for optical flow input.
        
        Args:
            pretrained (bool): Use pretrained ImageNet weights
            
        Returns:
            nn.Sequential: VGG16 feature extraction layers (up to conv5)
        """
        # Standard VGG16 configuration: conv layer depths for each block
        vgg16_config = [
            64, 64, 'M',                    # Block 1
            128, 128, 'M',                  # Block 2
            256, 256, 256, 'M',             # Block 3
            512, 512, 512, 'M',             # Block 4
            512, 512, 512                   # Block 5 (no final pooling)
        ]
        
        layers = []
        in_channels = self.total_input_channels  # Stacked optical flow input
        
        for v in vgg16_config:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
                layers.append(nn.BatchNorm2d(v))
                layers.append(nn.ReLU(inplace=True))
                in_channels = v
        
        # Load pretrained weights if requested (from torchvision VGG16)
        if pretrained:
            try:
                import torchvision.models as models
                vgg16_pretrained = models.vgg16(pretrained=True)
                # Copy weights, skipping first conv layer due to channel mismatch
                for i, layer in enumerate(layers):
                    if isinstance(layer, nn.Conv2d) and i > 0:
                        layer.load_state_dict(
                            vgg16_pretrained.features[i].state_dict(),
                            strict=False
                        )
            except Exception as e:
                print(f"Warning: Could not load pretrained weights: {e}")
                print("Proceeding with random initialization")
        
        return nn.Sequential(*layers)
    
    def _extract_temporal_chunks(self, x: torch.Tensor) -> list:
        """
        Extract temporal chunks from input batch.
        
        Args:
            x (torch.Tensor): Input tensor of shape 
                (batch_size, total_input_channels, height, width)
                where total_input_channels = num_flow_channels * total_frames
                
        Returns:
            list: List of tensors, each of shape (batch_size, chunk_input_channels, height, width)
        """
        batch_size, total_channels, height, width = x.shape
        chunk_input_channels = self.num_flow_channels * self.frames_per_chunk
        
        chunks = []
        for chunk_idx in range(self.num_temporal_chunks):
            start_channel = chunk_idx * chunk_input_channels
            end_channel = start_channel + chunk_input_channels
            chunk = x[:, start_channel:end_channel, :, :]
            chunks.append(chunk)
        
        return chunks
    
    def _temporal_fusion(self, chunk_features: list) -> torch.Tensor:
        """
        Late temporal fusion: concatenate conv5 features from all chunks.
        
        Args:
            chunk_features (list): List of conv5 feature maps from each chunk,
                each of shape (batch_size, 512, H, W)
                
        Returns:
            torch.Tensor: Concatenated features of shape (batch_size, 512*T, H, W)
                where T = num_temporal_chunks
        """
        # Concatenate along channel dimension
        fused_features = torch.cat(chunk_features, dim=1)  # (B, 512*T, H, W)
        return fused_features
    
    def extract_cam_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract features for Class Activation Map visualization.
        
        Args:
            x (torch.Tensor): Input optical flow tensor
            
        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - conv5_features: Feature maps before classification head (B, 512*T, H, W)
                - output: Model output logits (B, 1)
        """
        # Extract temporal chunks
        chunks = self._extract_temporal_chunks(x)
        
        # Process each chunk through VGG16 backbone
        chunk_features = []
        for chunk in chunks:
            # Extract conv5 features by running through backbone
            feat = self.features(chunk)
            chunk_features.append(feat)
        
        # Temporal fusion
        fused_features = self._temporal_fusion(chunk_features)
        
        # Save fused features for CAM computation
        conv5_features = fused_features
        
        # Classification head
        x = self.class_conv1(fused_features)
        x = self.bn1(x)
        x = self.relu1(x)
        
        x = self.class_conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        
        x = self.class_conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        
        # Get features before GAP for CAM
        pre_gap_features = x  # (B, 1024, H, W)
        
        # Global Average Pooling
        x = self.gap(x)  # (B, 1024, 1, 1)
        x = x.view(x.size(0), -1)  # (B, 1024)
        
        # Classification output
        output = self.fc(x)  # (B, 1)
        
        return {
            'conv5_features': conv5_features,
            'pre_gap_features': pre_gap_features,
            'output': output
        }
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of T-CAM model.
        
        Args:
            x (torch.Tensor): Input optical flow tensor of shape
                (batch_size, num_flow_channels * num_temporal_chunks * frames_per_chunk, H, W)
                Example: (B, 2*2*10=40, H, W) for default configuration
                
        Returns:
            torch.Tensor: Output logits of shape (batch_size, 1)
                Values close to 1.0 indicate forward-playing video
                Values close to 0.0 indicate backward-playing video
        """
        # Extract temporal chunks
        chunks = self._extract_temporal_chunks(x)
        
        # Process each chunk through VGG16 backbone
        chunk_features = []
        for chunk in chunks:
            feat = self.features(chunk)
            chunk_features.append(feat)
        
        # Temporal fusion (late fusion)
        fused_features = self._temporal_fusion(chunk_features)  # (B, 512*T, H, W)
        
        # Classification head
        x = self.class_conv1(fused_features)
        x = self.bn1(x)
        x = self.relu1(x)
        
        x = self.class_conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        
        x = self.class_conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        
        # Global Average Pooling
        x = self.gap(x)  # (B, 1024, 1, 1)
        x = x.view(x.size(0), -1)  # (B, 1024)
        
        # Binary classification logits
        output = self.fc(x)  # (B, 1)
        
        return output
    
    def get_classification_probabilities(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get classification probabilities (0-1 range).
        
        Args:
            x (torch.Tensor): Input optical flow tensor
            
        Returns:
            torch.Tensor: Probabilities of shape (batch_size, 1) in range [0, 1]
        """
        logits = self.forward(x)
        probs = torch.sigmoid(logits)
        return probs
    
    def compute_class_activation_map(
        self,
        x: torch.Tensor,
        target_class: int = 1,
        normalize: bool = True
    ) -> torch.Tensor:
        """
        Compute Class Activation Map (CAM) for visualization.
        
        This method visualizes which spatial regions contribute most to the
        forward/backward classification decision.
        
        Args:
            x (torch.Tensor): Input optical flow tensor
            target_class (int): Target class (0 for backward, 1 for forward). Default: 1
            normalize (bool): Normalize CAM to [0, 1]. Default: True
            
        Returns:
            torch.Tensor: Class Activation Map of shape (batch_size, 1, H, W)
                High values indicate regions important for predicting target_class
        """
        # Extract features and get model output
        features_dict = self.extract_cam_features(x)
        pre_gap_features = features_dict['pre_gap_features']  # (B, 1024, H, W)
        
        # Get weights from fc layer for the target class
        fc_weights = self.fc.weight[0]  # (1024,)
        
        # Compute CAM as weighted sum of feature maps
        batch_size, num_channels, height, width = pre_gap_features.shape
        
        # Reshape weights for broadcasting
        weights = fc_weights.view(1, num_channels, 1, 1)
        
        # Compute weighted feature maps
        cam = (pre_gap_features * weights).sum(dim=1, keepdim=True)  # (B, 1, H, W)
        
        # Apply ReLU to focus on positive activations
        cam = F.relu(cam)
        
        # Normalize CAM
        if normalize:
            # Normalize each sample independently
            for i in range(batch_size):
                cam_min = cam[i].min()
                cam_max = cam[i].max()
                if cam_max - cam_min > 0:
                    cam[i] = (cam[i] - cam_min) / (cam_max - cam_min)
                else:
                    cam[i] = torch.zeros_like(cam[i])
        
        return cam

In [3]:


# ==================== Utility Functions ====================

def create_tcam_model(
    num_temporal_chunks: int = 2,
    frames_per_chunk: int = 10,
    num_flow_channels: int = 2,
    pretrained: bool = False,
    device: Optional[torch.device] = None
) -> TCAM:
    """
    Create a T-CAM model instance.
    
    Args:
        num_temporal_chunks (int): Number of temporal chunks. Default: 2
        frames_per_chunk (int): Frames per chunk. Default: 10
        num_flow_channels (int): Input flow channels. Default: 2 (u, v)
        pretrained (bool): Use pretrained weights. Default: False
        device (torch.device): Device to move model to. Default: None
        
    Returns:
        TCAM: Instantiated model
    """
    model = TCAM(
        num_temporal_chunks=num_temporal_chunks,
        frames_per_chunk=frames_per_chunk,
        num_flow_channels=num_flow_channels,
        pretrained=pretrained
    )
    
    if device is not None:
        model = model.to(device)
    
    return model


def count_parameters(model: nn.Module) -> int:
    """Count total number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def print_model_summary(model: TCAM) -> None:
    """Print model architecture summary."""
    print("=" * 80)
    print("T-CAM Model Summary")
    print("=" * 80)
    print(model)
    print(f"\nTotal Parameters: {count_parameters(model):,}")
    print(f"Model Size: {count_parameters(model) * 4 / (1024**2):.2f} MB")
    print("=" * 80)


In [4]:


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create model
model = create_tcam_model(
    num_temporal_chunks=2,
    frames_per_chunk=10,
    num_flow_channels=2,
    pretrained=False,
    device=device
)

# Print model summary
print_model_summary(model)

# Example forward pass
batch_size = 4
height, width = 224, 224
num_flow_channels = 2
total_frames = 20  # 2 chunks * 10 frames

# Create dummy optical flow input
# Shape: (batch_size, num_flow_channels * total_frames, height, width)
x = torch.randn(
    batch_size,
    num_flow_channels * total_frames,
    height,
    width,
    device=device
)

print(f"\nInput shape: {x.shape}")

# Forward pass
output = model(x)
print(f"Output shape: {output.shape}")
print(f"Output logits (raw): {output.squeeze()}")

# Get probabilities
probs = model.get_classification_probabilities(x)
print(f"Output probabilities: {probs.squeeze()}")
print(f"Predictions (1=Forward, 0=Backward): {(probs > 0.5).long().squeeze()}")

# Compute CAM
cam = model.compute_class_activation_map(x, target_class=1)
print(f"\nClass Activation Map shape: {cam.shape}")
print(f"CAM value range: [{cam.min():.4f}, {cam.max():.4f}]")

print("\n✓ Model created and tested successfully!")

Using device: cuda
T-CAM Model Summary
TCAM(
  (features): Sequential(
    (0): Conv2d(20, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, c

In [5]:
# ==================== OPTICAL FLOW EXTRACTION ====================

class OpticalFlowExtractor:
    """
    Extract optical flow using TV-L1 method (as used in the paper).
    Requires OpenCV with contrib modules.
    """
    
    def __init__(self, backend: str = 'opencv'):
        """
        Initialize optical flow extractor.
        
        Args:
            backend (str): 'opencv' for OpenCV DualTVL1, 'farneback' for Farneback
        """
        self.backend = backend
        self._setup_backend()
    
    def _setup_backend(self):
        """Setup the optical flow backend."""
        if self.backend == 'opencv':
            try:
                self.optical_flow = cv2.optflow.createOptFlow_DualTVL1()
            except AttributeError:
                print("Warning: DualTVL1 not available. Using Farneback instead.")
                print("Install opencv-contrib-python: pip install opencv-contrib-python")
                self.backend = 'farneback'
        
        if self.backend == 'farneback':
            self.optical_flow = None  # Use cv2.calcOpticalFlowFarneback
    
    def extract_flow(self, frame1: np.ndarray, frame2: np.ndarray) -> np.ndarray:
        """
        Extract optical flow between two consecutive frames.
        
        Args:
            frame1 (np.ndarray): Previous frame (grayscale or RGB)
            frame2 (np.ndarray): Current frame (grayscale or RGB)
            
        Returns:
            np.ndarray: Optical flow of shape (H, W, 2) containing (u, v) components
        """
        # Convert to grayscale if needed
        if len(frame1.shape) == 3:
            frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY)
        if len(frame2.shape) == 3:
            frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY)
        
        if self.backend == 'opencv':
            flow = self.optical_flow.calc(frame1, frame2, None)
        else:  # Farneback
            flow = cv2.calcOpticalFlowFarneback(
                frame1, frame2,
                None,
                pyr_scale=0.5,
                levels=3,
                winsize=15,
                iterations=3,
                n8=False,
                poly_n=5,
                poly_sigma=1.2,
                flags=0
            )
        
        return flow.astype(np.float32)
    
    @staticmethod
    def extract_flow_sequence(
        frames: np.ndarray,
        method: str = 'consecutive'
    ) -> np.ndarray:
        """
        Extract optical flow for a sequence of frames.
        
        Args:
            frames (np.ndarray): Array of frames, shape (T, H, W, 3)
            method (str): 'consecutive' for frame-to-frame flow
            
        Returns:
            np.ndarray: Optical flow sequence, shape (T-1, H, W, 2)
        """
        extractor = OpticalFlowExtractor(backend='opencv')
        flows = []
        
        for i in range(len(frames) - 1):
            flow = extractor.extract_flow(frames[i], frames[i + 1])
            flows.append(flow)
        
        return np.stack(flows, axis=0)
    
    @staticmethod
    def visualize_flow(
        flow: np.ndarray,
        colorize: bool = True
    ) -> np.ndarray:
        """
        Visualize optical flow.
        
        Args:
            flow (np.ndarray): Optical flow (H, W, 2)
            colorize (bool): Use HSV colorization
            
        Returns:
            np.ndarray: Visualization image (H, W, 3)
        """
        if colorize:
            mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
            hsv[..., 0] = ang * 180 / np.pi / 2
            hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            hsv[..., 2] = 255
            rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
            return rgb
        else:
            # Magnitude visualization
            mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            mag = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            return cv2.applyColorMap(mag, cv2.COLORMAP_JET)


In [6]:
class ReverseFilmFlowDataset(Dataset):
    """
    Uses precomputed TV-L1 flow from ./flow_11/stab_10.
    labels_* file lines:
        "<Movie>/<Clip> 0"      -> original order
        "<Movie>/<Clip>_rev 1"  -> time-reversed
    """

    def __init__(
        self,
        flow_root,
        label_file,
        num_temporal_chunks=2,
        frames_per_chunk=10,
        image_size=(224, 224),
    ):
        self.flow_root = flow_root
        self.num_temporal_chunks = num_temporal_chunks
        self.frames_per_chunk = frames_per_chunk
        self.total_frames = num_temporal_chunks * frames_per_chunk
        self.image_size = image_size

        self.samples = []
        with open(label_file, "r") as f:
            for line in f:
                rel_path, lab = line.strip().split()
                self.samples.append((rel_path, int(lab)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        rel_path, label = self.samples[idx]
        # handle _rev
        is_rev = rel_path.endswith("_rev")
        if is_rev:
            rel_path = rel_path[:-4]  # strip "_rev"

        movie, clip = rel_path.split("/")
        flow_path = os.path.join(self.flow_root, movie, f"{clip}.npy")

        flows = np.load(flow_path)  # (T_raw, H0, W0, 2)
        if is_rev:
            flows = flows[::-1].copy()  # reverse time

        T_raw, H0, W0, _ = flows.shape

        # ensure exactly total_frames flows
        if T_raw >= self.total_frames:
            flows = flows[:self.total_frames]
        else:
            last = flows[-1]
            while flows.shape[0] < self.total_frames:
                flows = np.concatenate([flows, last[None]], axis=0)

        H, W = self.image_size
        if (H0, W0) != (H, W):
            resized = []
            for f in flows:
                fx = cv2.resize(f[..., 0], (W, H))
                fy = cv2.resize(f[..., 1], (W, H))
                resized.append(np.stack([fx, fy], axis=-1))
            flows = np.stack(resized, axis=0)

        # per-clip normalization
        mu, sigma = flows.mean(), flows.std()
        flows = (flows - mu) / (sigma + 1e-6)

        # (T, H, W, 2) -> (2*T, H, W)
        T, H, W, C = flows.shape
        flows = flows.transpose(0, 3, 1, 2).reshape(T * C, H, W)

        flow_tensor = torch.from_numpy(flows).float()
        label = torch.tensor(label, dtype=torch.long)

        return flow_tensor, label


In [7]:
class TrainingConfig:
    """Configuration for model training."""
    
    def __init__(
        self,
        batch_size: int = 4,
        num_epochs: int = 50,
        learning_rate: float = 0.001,
        weight_decay: float = 0.0005,
        num_workers: int = 0,
        device: str = 'cuda',
        checkpoint_dir: str = './checkpoints',
        log_interval: int = 100
    ):
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.num_workers = num_workers
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.checkpoint_dir = checkpoint_dir
        self.log_interval = log_interval


In [8]:
class Trainer:
    """Trainer class for T-CAM model."""
    
    def __init__(
        self,
        model: TCAM,
        config: TrainingConfig,
        criterion: nn.Module = None,
        optimizer: optim.Optimizer = None
    ):
        """
        Initialize trainer.
        
        Args:
            model (TCAM): T-CAM model instance
            config (TrainingConfig): Training configuration
            criterion (nn.Module): Loss function. Default: BCEWithLogitsLoss
            optimizer (optim.Optimizer): Optimizer. Default: SGD
        """
        self.model = model.to(config.device)
        for p in self.model.features.parameters():
            p.requires_grad = False

        self.config = config
        self.device = config.device
        
        # Loss function: Binary classification with logits
        if criterion is None:
            criterion = nn.BCEWithLogitsLoss()
        self.criterion = criterion
        
        # Optimizer: SGD with momentum as in paper
        if optimizer is None:
            optimizer = optim.SGD(
                model.parameters(),
                lr=config.learning_rate,
                momentum=0.9,
                weight_decay=config.weight_decay
            )
        self.optimizer = optimizer
        
        # Metrics tracking
        self.train_losses = []
        self.val_accuracies = []
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.1)

    
    def train_epoch(self, train_loader: DataLoader) -> float:
        """
        Train for one epoch.
        
        Args:
            train_loader (DataLoader): Training data loader
            
        Returns:
            float: Average training loss
        """
        self.model.train()
        total_loss = 0.0
        
        for batch_idx, (flows, labels) in enumerate(tqdm(train_loader, desc="Train", leave=False)):
            flows = flows.to(self.device)
            labels = labels.to(self.device).float().unsqueeze(1)
            
            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(flows)
            loss = self.criterion(outputs, labels)
            
            # Backward pass
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            
            # Logging
            if (batch_idx + 1) % self.config.log_interval == 0:
                print(f"Batch [{batch_idx + 1}/{len(train_loader)}], "
                      f"Loss: {loss.item():.4f}")
            
        avg_loss = total_loss / len(train_loader)
        self.train_losses.append(avg_loss)
        return avg_loss
    
    def validate(self, val_loader: DataLoader) -> Tuple[float, float]:
        """
        Validate model.
        
        Args:
            val_loader (DataLoader): Validation data loader
            
        Returns:
            Tuple[float, float]: (Accuracy, Loss)
        """
        self.model.eval()
        correct = 0
        total = 0
        total_loss = 0.0
        
        with torch.no_grad():
            for flows, labels in tqdm(val_loader, desc="Val", leave=False):
                flows = flows.to(self.device)
                labels = labels.to(self.device).float().unsqueeze(1)
                
                outputs = self.model(flows)
                loss = self.criterion(outputs, labels)
                total_loss += loss.item()
                
                # Compute accuracy
                predictions = (torch.sigmoid(outputs) > 0.5).long()
                correct += (predictions == labels.long()).sum().item()
                total += labels.size(0)
        
        accuracy = correct / total
        avg_loss = total_loss / len(val_loader)
        self.val_accuracies.append(accuracy)
        
        return accuracy, avg_loss
    
    def train(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None
    ):
        """
        Full training loop.
        
        Args:
            train_loader (DataLoader): Training data loader
            val_loader (Optional[DataLoader]): Validation data loader
        """
        print(f"Starting training for {self.config.num_epochs} epochs...")
        print(f"Device: {self.device}")
        
        for epoch in range(self.config.num_epochs):
            print(f"\n--- Epoch {epoch + 1}/{self.config.num_epochs} ---")
            
            # Train
            train_loss = self.train_epoch(train_loader)
            print(f"Average Training Loss: {train_loss:.4f}")
            
            # Validate
            if val_loader is not None:
                val_acc, val_loss = self.validate(val_loader)
                print(f"Validation Accuracy: {val_acc:.4f}, Loss: {val_loss:.4f}")
            
            # Save checkpoint
            self._save_checkpoint(epoch)
            self.scheduler.step()

    
    def _save_checkpoint(self, epoch: int):
        """Save model checkpoint."""
        import os
        os.makedirs(self.config.checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(
            self.config.checkpoint_dir,
            f"tcam_epoch_{epoch + 1}.pth"
        )
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_accuracies': self.val_accuracies
        }, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")
    
    def load_checkpoint(self, checkpoint_path: str):
        """Load model from checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_losses = checkpoint.get('train_losses', [])
        self.val_accuracies = checkpoint.get('val_accuracies', [])
        print(f"Checkpoint loaded from: {checkpoint_path}")


In [9]:

# ==================== INFERENCE UTILITIES ====================

class Inference:
    """Inference utilities for arrow of time detection."""
    
    def __init__(self, model: TCAM, device: torch.device):
        """
        Initialize inference engine.
        
        Args:
            model (TCAM): Trained model
            device (torch.device): Device to run inference on
        """
        self.model = model.to(device)
        self.model.eval()
        self.device = device
    
    def predict(self, flows: torch.Tensor) -> Tuple[float, str]:
        """
        Make prediction on optical flow.
        
        Args:
            flows (torch.Tensor): Optical flow tensor
            
        Returns:
            Tuple[float, str]: (Probability, Direction) where Direction is 'Forward' or 'Backward'
        """
        with torch.no_grad():
            flows = flows.to(self.device)
            if flows.dim() == 3:
                flows = flows.unsqueeze(0)
            
            probs = self.model.get_classification_probabilities(flows)
            prob = probs.item()
            direction = 'Forward' if prob > 0.5 else 'Backward'
        
        return prob, direction
    
    def predict_with_cam(self, flows: torch.Tensor) -> Tuple[float, str, np.ndarray]:
        """
        Make prediction and compute Class Activation Map.
        
        Args:
            flows (torch.Tensor): Optical flow tensor
            
        Returns:
            Tuple[float, str, np.ndarray]: (Probability, Direction, CAM)
        """
        with torch.no_grad():
            flows = flows.to(self.device)
            if flows.dim() == 3:
                flows = flows.unsqueeze(0)
            
            # Get prediction
            probs = self.model.get_classification_probabilities(flows)
            prob = probs.item()
            direction = 'Forward' if prob > 0.5 else 'Backward'
            
            # Get CAM
            cam = self.model.compute_class_activation_map(flows, target_class=1)
            cam = cam[0, 0].cpu().numpy()  # (H, W)
        
        return prob, direction, cam


In [10]:


print("T-CAM Training Utilities")
print("=" * 80)

# Configuration
config = TrainingConfig(
    batch_size=4,
    num_epochs=30,
    learning_rate=1e-4,
    device='cuda'
)

# Create model
model = create_tcam_model(device=config.device)

# Create trainer
trainer = Trainer(model, config)

print("\nTrainer initialized successfully!")
print(f"Device: {config.device}")
print(f"Batch size: {config.batch_size}")
print(f"Learning rate: {config.learning_rate}")


T-CAM Training Utilities

Trainer initialized successfully!
Device: cuda
Batch size: 4
Learning rate: 0.0001


In [11]:
# flow_root = "./flow_11/stab_10"
# labels = "./labels_bidirectional.txt"
flow_root = os.path.join(os.getcwd(),'flow_11','stab_10')
labels = os.path.join(os.getcwd(),'labels_bidirectional.txt')
full_dataset = ReverseFilmFlowDataset(
    flow_root=flow_root,
    label_file=labels,
    num_temporal_chunks=2,
    frames_per_chunk=10,
    image_size=(224, 224),
)

print("Train/val split...")
from collections import defaultdict
movies = defaultdict(list)
for idx, (rel_path, lab) in enumerate(full_dataset.samples):
    movie = rel_path.split("/")[0].replace("_rev", "")
    movies[movie].append(idx)

movie_names = sorted(movies.keys())
n_train_movies = int(0.8 * len(movie_names))
train_movies = set(movie_names[:n_train_movies])
val_movies = set(movie_names[n_train_movies:])

train_indices = [i for m in train_movies for i in movies[m]]
val_indices   = [i for m in val_movies  for i in movies[m]]

train_ds = torch.utils.data.Subset(full_dataset, train_indices)
val_ds   = torch.utils.data.Subset(full_dataset, val_indices)

# sanity check
train_labels = [full_dataset.samples[i][1] for i in train_ds.indices]
val_labels = [full_dataset.samples[i][1] for i in val_ds.indices]
print("Train labels set:", set(train_labels))
print("Val labels set:", set(val_labels))

train_loader = DataLoader(
    train_ds,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True,
)


trainer.train(train_loader, val_loader)


Train/val split...
Train labels set: {0, 1}
Val labels set: {0, 1}
Starting training for 30 epochs...
Device: cuda

--- Epoch 1/30 ---


                                                      

Average Training Loss: 0.6988


                                                  

Validation Accuracy: 0.5000, Loss: 0.6950
Checkpoint saved: ./checkpoints\tcam_epoch_1.pth

--- Epoch 2/30 ---


                                                      

Average Training Loss: 0.6957


                                                  

Validation Accuracy: 0.5000, Loss: 0.6851
Checkpoint saved: ./checkpoints\tcam_epoch_2.pth

--- Epoch 3/30 ---


                                                      

Average Training Loss: 0.6895


                                                  

Validation Accuracy: 0.5000, Loss: 0.6872
Checkpoint saved: ./checkpoints\tcam_epoch_3.pth

--- Epoch 4/30 ---


                                                      

Average Training Loss: 0.6912


                                                  

Validation Accuracy: 0.5833, Loss: 0.6835
Checkpoint saved: ./checkpoints\tcam_epoch_4.pth

--- Epoch 5/30 ---


                                                      

Average Training Loss: 0.6836


                                                  

Validation Accuracy: 0.4167, Loss: 0.6900
Checkpoint saved: ./checkpoints\tcam_epoch_5.pth

--- Epoch 6/30 ---


                                                      

Average Training Loss: 0.6785


                                                  

Validation Accuracy: 0.4167, Loss: 0.6865
Checkpoint saved: ./checkpoints\tcam_epoch_6.pth

--- Epoch 7/30 ---


                                                      

Average Training Loss: 0.6860


                                                  

Validation Accuracy: 0.5000, Loss: 0.6862
Checkpoint saved: ./checkpoints\tcam_epoch_7.pth

--- Epoch 8/30 ---


                                                      

Average Training Loss: 0.6641


                                                  

Validation Accuracy: 0.5000, Loss: 0.6867
Checkpoint saved: ./checkpoints\tcam_epoch_8.pth

--- Epoch 9/30 ---


                                                      

Average Training Loss: 0.6709


                                                  

Validation Accuracy: 0.5000, Loss: 0.6854
Checkpoint saved: ./checkpoints\tcam_epoch_9.pth

--- Epoch 10/30 ---


                                                      

Average Training Loss: 0.6714


                                                  

Validation Accuracy: 0.4167, Loss: 0.6914
Checkpoint saved: ./checkpoints\tcam_epoch_10.pth

--- Epoch 11/30 ---


                                                      

Average Training Loss: 0.6718


                                                  

Validation Accuracy: 0.5000, Loss: 0.6910
Checkpoint saved: ./checkpoints\tcam_epoch_11.pth

--- Epoch 12/30 ---


                                                      

Average Training Loss: 0.6795


                                                  

Validation Accuracy: 0.5000, Loss: 0.6888
Checkpoint saved: ./checkpoints\tcam_epoch_12.pth

--- Epoch 13/30 ---


                                                      

Average Training Loss: 0.6769


                                                  

Validation Accuracy: 0.5000, Loss: 0.6912
Checkpoint saved: ./checkpoints\tcam_epoch_13.pth

--- Epoch 14/30 ---


                                                      

Average Training Loss: 0.6622


                                                  

Validation Accuracy: 0.5833, Loss: 0.6957
Checkpoint saved: ./checkpoints\tcam_epoch_14.pth

--- Epoch 15/30 ---


                                                      

Average Training Loss: 0.6626


                                                  

Validation Accuracy: 0.5000, Loss: 0.6930
Checkpoint saved: ./checkpoints\tcam_epoch_15.pth

--- Epoch 16/30 ---


                                                      

Average Training Loss: 0.6721


                                                  

Validation Accuracy: 0.5000, Loss: 0.6911
Checkpoint saved: ./checkpoints\tcam_epoch_16.pth

--- Epoch 17/30 ---


                                                      

Average Training Loss: 0.6854


                                                  

Validation Accuracy: 0.5000, Loss: 0.6914
Checkpoint saved: ./checkpoints\tcam_epoch_17.pth

--- Epoch 18/30 ---


                                                      

Average Training Loss: 0.6675


                                                  

Validation Accuracy: 0.5000, Loss: 0.6932
Checkpoint saved: ./checkpoints\tcam_epoch_18.pth

--- Epoch 19/30 ---


                                                      

Average Training Loss: 0.6665


                                                  

Validation Accuracy: 0.5000, Loss: 0.6928
Checkpoint saved: ./checkpoints\tcam_epoch_19.pth

--- Epoch 20/30 ---


                                                      

Average Training Loss: 0.6715


                                                  

Validation Accuracy: 0.5000, Loss: 0.6945
Checkpoint saved: ./checkpoints\tcam_epoch_20.pth

--- Epoch 21/30 ---


                                                      

Average Training Loss: 0.6706


                                                  

Validation Accuracy: 0.5000, Loss: 0.6908
Checkpoint saved: ./checkpoints\tcam_epoch_21.pth

--- Epoch 22/30 ---


                                                      

Average Training Loss: 0.6740


                                                  

Validation Accuracy: 0.5000, Loss: 0.6943
Checkpoint saved: ./checkpoints\tcam_epoch_22.pth

--- Epoch 23/30 ---


                                                      

Average Training Loss: 0.6701


                                                  

Validation Accuracy: 0.5833, Loss: 0.6890
Checkpoint saved: ./checkpoints\tcam_epoch_23.pth

--- Epoch 24/30 ---


                                                      

Average Training Loss: 0.6696


                                                  

Validation Accuracy: 0.5000, Loss: 0.6956
Checkpoint saved: ./checkpoints\tcam_epoch_24.pth

--- Epoch 25/30 ---


                                                      

Average Training Loss: 0.6593


                                                  

Validation Accuracy: 0.5000, Loss: 0.6924
Checkpoint saved: ./checkpoints\tcam_epoch_25.pth

--- Epoch 26/30 ---


                                                      

Average Training Loss: 0.6563


                                                  

Validation Accuracy: 0.5000, Loss: 0.6935
Checkpoint saved: ./checkpoints\tcam_epoch_26.pth

--- Epoch 27/30 ---


                                                      

Average Training Loss: 0.6757


                                                  

Validation Accuracy: 0.5000, Loss: 0.6915
Checkpoint saved: ./checkpoints\tcam_epoch_27.pth

--- Epoch 28/30 ---


                                                      

Average Training Loss: 0.6584


                                                  

Validation Accuracy: 0.5000, Loss: 0.6918
Checkpoint saved: ./checkpoints\tcam_epoch_28.pth

--- Epoch 29/30 ---


                                                      

Average Training Loss: 0.6546


                                                  

Validation Accuracy: 0.5000, Loss: 0.6950
Checkpoint saved: ./checkpoints\tcam_epoch_29.pth

--- Epoch 30/30 ---


                                                      

Average Training Loss: 0.6646


                                                  

Validation Accuracy: 0.5000, Loss: 0.6902
Checkpoint saved: ./checkpoints\tcam_epoch_30.pth
