# IMPORTS

# In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import json
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score



import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
import cv2
from torchvision import transforms
from typing import List, Tuple, Dict, Optional, Union

# DATASET/DATALOADERS

# In [3]:
class VideoDataset(Dataset):
    """Dataset that decodes a fixed number of frames from each video file.

    Each item is a ``(video_tensor, label)`` pair where ``video_tensor`` is the
    stack of the per-frame transform outputs along a new leading time axis.
    """

    def __init__(
        self,
        video_paths: List[str],
        labels: List[int],
        num_frames: int = 64,
        transform=None,
        temporal_sample_method: str = "uniform",
        frame_size: Tuple[int, int] = (224, 224),
    ):
        """
        Args:
            video_paths: List of paths to video files
            labels: List of class labels for each video
            num_frames: Number of frames to extract from each video
            transform: Optional transforms to apply to frames
            temporal_sample_method: Method to sample frames ("uniform" or "random")
            frame_size: Size to resize frames to (height, width)
        """
        self.video_paths = video_paths
        self.labels = labels
        self.num_frames = num_frames
        self.transform = transform
        self.temporal_sample_method = temporal_sample_method
        self.frame_size = frame_size

        # Default transform if none provided
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self):
        return len(self.video_paths)

    def _sample_indices(self, total_frames: int) -> np.ndarray:
        """Return ``num_frames`` frame indices in ascending (temporal) order.

        Args:
            total_frames: Number of frames reported for the video (> 0).

        Raises:
            ValueError: If ``temporal_sample_method`` is not supported.
        """
        if self.temporal_sample_method == "uniform":
            # Evenly spaced positions; repeats indices when the video is short.
            return np.linspace(0, total_frames - 1, self.num_frames, dtype=int)

        if self.temporal_sample_method == "random":
            picked = random.sample(range(total_frames), min(self.num_frames, total_frames))
            shortfall = self.num_frames - len(picked)
            if shortfall > 0:
                # Video is shorter than the clip: pad by re-drawing existing frames.
                extra = np.random.choice(picked, shortfall)
                picked = np.concatenate([picked, extra])
            # Sort AFTER padding so repeated frames stay in temporal order
            # instead of being appended at the end of the clip.
            return np.sort(np.asarray(picked, dtype=int))

        raise ValueError(f"Unsupported temporal sampling method: {self.temporal_sample_method}")

    def _load_video(self, video_path: str) -> np.ndarray:
        """Decode the sampled frames of ``video_path`` as RGB arrays.

        Frames that cannot be read are replaced with black frames of
        ``frame_size`` so the clip length stays constant.

        Raises:
            ValueError: If the reported frame count is not positive, or the
                sampling method is unsupported.
        """
        height, width = self.frame_size[0], self.frame_size[1]
        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                raise ValueError(f"Failed to load video: {video_path}")

            for idx in self._sample_indices(total_frames):
                # Cast to a plain int: indices are numpy integers.
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if not ret:
                    # If frame reading fails, create a black frame
                    frame = np.zeros((height, width, 3), dtype=np.uint8)
                else:
                    # OpenCV decodes to BGR; the normalisation stats are RGB.
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    # cv2.resize takes (width, height)
                    frame = cv2.resize(frame, (width, height))
                frames.append(frame)
        finally:
            # Always release the capture handle, including on the error paths.
            cap.release()

        return np.array(frames)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return the transformed frames and the label for item ``idx``."""
        frames = self._load_video(self.video_paths[idx])

        # NOTE(review): the transform runs independently per frame, so random
        # augmentations (crop/flip) differ between frames of the same clip.
        transformed_frames = [self.transform(frame) for frame in frames]

        # Stack frames along a new leading time dimension: (T, C, H, W) when
        # the transform yields (C, H, W) tensors.
        video_tensor = torch.stack(transformed_frames)
        return video_tensor, self.labels[idx]


def create_dataloader(
    data_dir: str,
    batch_size: int = 8,
    num_frames: int = 64,
    num_workers: int = 4,
    train_ratio: float = 0.8,
    frame_size: Tuple[int, int] = (224, 224),
    video_extensions: Union[List[str], str] = "*.mp4",
) -> Tuple[DataLoader, DataLoader, Dict[int, str]]:
    """
    Create train and validation dataloaders from a directory of videos.

    The directory must contain one sub-directory per class; the sub-directory
    names (sorted alphabetically) define the class indices.

    Args:
        data_dir: Directory containing class folders with videos
        batch_size: Batch size for dataloaders
        num_frames: Number of frames to sample from each video
        num_workers: Number of worker processes for dataloaders
        train_ratio: Ratio of data to use for training
        frame_size: Size to resize frames to (height, width)
        video_extensions: Glob pattern(s) to look for (e.g., "*.mp4" or ["*.mp4", "*.avi"])

    Returns:
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        idx_to_class: Dictionary mapping class indices to class names

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist.
        ValueError: If no class directories or no video files are found.
    """
    data_dir = Path(data_dir)

    # Check if data directory exists
    if not data_dir.exists():
        raise FileNotFoundError(f"Data directory not found: {data_dir}")

    # Sort class folders by name: iterdir() order is arbitrary, and the class
    # index assignment must be stable across runs and machines.
    class_dirs = sorted((d for d in data_dir.iterdir() if d.is_dir()), key=lambda d: d.name)

    if not class_dirs:
        raise ValueError(f"No class directories found in {data_dir}. Please make sure your data is organized in subdirectories, with each subdirectory representing a class.")

    print(f"Found {len(class_dirs)} class directories: {[d.name for d in class_dirs]}")

    class_to_idx = {cls.name: i for i, cls in enumerate(class_dirs)}

    video_paths = []
    labels = []

    # Handle both string and list of extensions
    if isinstance(video_extensions, str):
        video_extensions = [video_extensions]

    # Collect video paths and labels
    for class_dir in class_dirs:
        class_idx = class_to_idx[class_dir.name]

        # Use a set so a file matched by several patterns (e.g. "*.mov" and
        # "*.MOV" on a case-insensitive filesystem) is only counted once.
        class_videos = sorted({path for ext in video_extensions for path in class_dir.glob(ext)})

        if not class_videos:
            print(f"Warning: No videos found in {class_dir} with extensions {video_extensions}")
            continue

        print(f"Found {len(class_videos)} videos in class '{class_dir.name}'")

        for video_file in class_videos:
            video_paths.append(str(video_file))
            labels.append(class_idx)

    if not video_paths:
        raise ValueError(f"No video files found in the class directories. "
                         f"Checked extensions: {video_extensions}. "
                         f"Please make sure your videos have the correct extensions.")

    print(f"Total videos found: {len(video_paths)}")

    # Create train/val split (uses the global `random` state; seed it for
    # a reproducible split).
    indices = list(range(len(video_paths)))
    random.shuffle(indices)
    split = int(train_ratio * len(indices))

    if split == 0:
        split = 1  # Ensure at least one sample in training set

    train_indices = indices[:split]
    # Ensure at least one validation sample; with a single usable split this
    # reuses a training sample.
    val_indices = indices[split:] if split < len(indices) else [indices[0]]

    print(f"Training samples: {len(train_indices)}, Validation samples: {len(val_indices)}")

    # Build the transforms from the full (height, width) pair so non-square
    # frame sizes are honoured; for square sizes this matches the int form.
    height, width = frame_size[0], frame_size[1]
    normalize = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop((height, width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**normalize)
    ])

    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((height + 32, width + 32)),
        transforms.CenterCrop((height, width)),
        transforms.ToTensor(),
        transforms.Normalize(**normalize)
    ])

    # Create datasets
    train_dataset = VideoDataset(
        [video_paths[i] for i in train_indices],
        [labels[i] for i in train_indices],
        num_frames=num_frames,
        transform=train_transform,
        temporal_sample_method="random",
        frame_size=frame_size
    )

    val_dataset = VideoDataset(
        [video_paths[i] for i in val_indices],
        [labels[i] for i in val_indices],
        num_frames=num_frames,
        transform=val_transform,
        temporal_sample_method="uniform",
        frame_size=frame_size
    )

    # Pinned memory only helps host->GPU copies.
    pin_memory = torch.cuda.is_available()

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=min(batch_size, len(train_dataset)),  # Ensure batch size isn't larger than dataset
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # Only drop the ragged last batch when at least one full batch remains.
        drop_last=len(train_dataset) > batch_size
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=min(batch_size, len(val_dataset)),  # Ensure batch size isn't larger than dataset
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    # Invert class_to_idx for easier interpretation
    idx_to_class = {v: k for k, v in class_to_idx.items()}

    return train_loader, val_loader, idx_to_class


# MODEL

# In [4]:
class LearnablePositionalEmbedding(nn.Module):
    """Learnable additive position embedding for 4-D token tensors.

    With ``spatial=True`` the table has shape (1, max_len, 1, d_model) and is
    indexed by dim 1 of the input; with ``spatial=False`` it has shape
    (1, 1, max_len, d_model) and is indexed by dim 2. The remaining axes are
    broadcast. Dropout is applied to the sum.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100, spatial=True):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.spatial = spatial  # True: positions run along dim 1; False: along dim 2

        # The position axis is dim 1 for the spatial table, dim 2 otherwise.
        table_shape = (1, max_len, 1, d_model) if spatial else (1, 1, max_len, d_model)
        self.position_embeddings = nn.Parameter(torch.zeros(*table_shape))

        nn.init.trunc_normal_(self.position_embeddings, std=0.02)

    def forward(self, x):
        """
        x: 4-D tensor whose last dim is d_model. The embedding is added along
        dim 1 when ``spatial`` is True and along dim 2 otherwise; that axis
        must not be longer than ``max_len``.
        """
        if self.spatial:
            offsets = self.position_embeddings[:, :x.size(1), :, :]
        else:
            offsets = self.position_embeddings[:, :, :x.size(2), :]
        return self.dropout(x + offsets)



class TransformerEncoderBlock(nn.Module):
    """
    Single post-norm Transformer encoder layer ("Attention Is All You Need").

    Computation, in order:
    - multi-head self-attention, dropout, residual add, LayerNorm
    - two-layer ReLU feedforward network, dropout, residual add, LayerNorm

    Args:
        input_dim (int): Dimension of the input embeddings/features.
        num_heads (int): Number of attention heads. Default is 8.
        ff_hidden_dim (int): Hidden width of the feedforward network. Default is 2048.
        dropout (float): Dropout rate used inside attention, inside the
            feedforward network and on both residual branches. Default is 0.1.

    Example:
        >>> encoder_block = TransformerEncoderBlock(input_dim=512, num_heads=8)
        >>> x = torch.randn(32, 10, 512)  # (batch_size, sequence_length, input_dim)
        >>> output = encoder_block(x)

    Shape:
        - Input: (B, L, C) where B = batch size, L = sequence length, and C = input_dim.
        - Output: (B, L, C)
    """

    def __init__(self, input_dim, num_heads=8, ff_hidden_dim=2048, dropout=0.1):
        super().__init__()

        # Attention sub-layer (batch-first: inputs are (B, L, C)).
        self.self_attention = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(input_dim)

        # Position-wise feedforward sub-layer: expand, ReLU, dropout, project back.
        expand = nn.Linear(input_dim, ff_hidden_dim)
        project = nn.Linear(ff_hidden_dim, input_dim)
        self.ffn = nn.Sequential(expand, nn.ReLU(), nn.Dropout(dropout), project)
        self.norm2 = nn.LayerNorm(input_dim)

        # Shared dropout applied to each residual branch.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Run the encoder layer.

        Args:
            x (torch.Tensor): Input tensor of shape (B, L, C).
            mask (torch.Tensor, optional): Forwarded unchanged as ``attn_mask``
                to ``nn.MultiheadAttention``; it is NOT a key-padding mask.

        Returns:
            torch.Tensor: Output tensor of shape (B, L, C).
        """
        # Attention branch + residual, then normalise.
        attended = self.self_attention(x, x, x, attn_mask=mask)[0]
        hidden = self.norm1(x + self.dropout(attended))

        # Feedforward branch + residual, then normalise.
        return self.norm2(hidden + self.dropout(self.ffn(hidden)))



class TimeSFormer(nn.Module):
    """Video classifier with factorised spatial / temporal attention.

    Pipeline (B = batch, T = frames, P = patches per frame, D = embed dim):

    - INPUT: (B, T, 3, 224, 224)
    - PATCH: (B, T, P, D) where P = 224*224/(16*16) = 196 patches and D = 768
    - CLS: prepend a CLS token to every frame -> (B, T, P+1, D)
    - POSITION: permute to (B, P+1, T, D); add the spatial embedding along the
      token axis (dim 1) and the temporal embedding along the frame axis (dim 2)
    - SPATIAL: reshape to (B*T, P+1, D) and attend over the tokens of each frame
    - TEMPORAL: take the per-frame CLS tokens (B, T, D) and attend across frames
    - HEAD: average the CLS tokens over time, dropout, linear classifier

    Args:
        number_of_frames: Exact number of frames every input clip must have.
        num_classes: Number of output classes.
    """

    def __init__(self, number_of_frames: int, num_classes: int = 10):
        super().__init__()

        self._INPUT_DIM = 224
        self._INPUT_CHANNEL = 3
        self._PATCH_SIZE = 16
        self._EMBED_DIM = 768
        self._PATCH_NUM = (self._INPUT_DIM // self._PATCH_SIZE) ** 2  # 196
        self._FRAME_NUM = number_of_frames

        # Positional encoding. The +1 in `self._PATCH_NUM + 1` is for the CLS
        # token. The +1 in `self._FRAME_NUM + 1` is not needed by forward() but
        # is kept so existing checkpoints keep their parameter shapes.
        self.spatial_encoding = LearnablePositionalEmbedding(self._EMBED_DIM, 0.15, self._PATCH_NUM + 1, spatial=True)
        self.temporal_encoding = LearnablePositionalEmbedding(self._EMBED_DIM, 0.15, self._FRAME_NUM + 1, spatial=False)

        # Patch embedding layer (convert flattened patches to embeddings)
        self.patch_embedding = nn.Linear(
            self._PATCH_SIZE * self._PATCH_SIZE * self._INPUT_CHANNEL, self._EMBED_DIM
        )

        # Attention. The attribute name `spartial_encoder` (sic) is kept
        # because it is part of the state_dict keys.
        self.spartial_encoder = TransformerEncoderBlock(self._EMBED_DIM, 8, 1024, 0.1)
        self.temporal_encoder = TransformerEncoderBlock(self._EMBED_DIM, 16, 2048, 0.1)

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self._EMBED_DIM))
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Classification head
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self._EMBED_DIM, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of clips.

        Args:
            x: Tensor of shape (B, T, 3, 224, 224) with T == number_of_frames.

        Returns:
            Logits of shape (B, num_classes).

        Raises:
            ValueError: If the frame count or the per-frame shape does not
                match what the model was built for.
        """
        B, T, C, H, W = x.shape

        if T != self._FRAME_NUM:
            raise ValueError(f"The number of frames in input is {T}, mismatch with the model's expected number of frames {self._FRAME_NUM}")

        if (C != self._INPUT_CHANNEL) or (H != self._INPUT_DIM) or (W != self._INPUT_DIM):
            raise ValueError(f"The input spartial dimension is ({C}, {H}, {W}), mismatch with the model's expected spartial dimension ({self._INPUT_CHANNEL}, {self._INPUT_DIM}, {self._INPUT_DIM})")

        num_tokens = self._PATCH_NUM + 1  # patches + CLS

        # Extract non-overlapping patches: (B, T, C, H, W) -> (B, T, P, 16*16*3)
        x = x.reshape(B * T, C, H, W)  # Flatten batch & time
        x = x.unfold(2, self._PATCH_SIZE, self._PATCH_SIZE).unfold(3, self._PATCH_SIZE, self._PATCH_SIZE)
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()  # (B*T, 14, 14, 3, 16, 16)
        x = x.reshape(B, T, self._PATCH_NUM, -1)  # (B, T, P, 16*16*3)

        # Apply patch embedding
        patch_embeds = self.patch_embedding(x)  # (B, T, P, D)
        patch_embeds = patch_embeds.reshape(B * T, self._PATCH_NUM, -1)  # (B*T, P, D)

        # Prepend one CLS token to every frame
        cls_tokens = self.cls_token.expand(B * T, -1, -1)  # (B*T, 1, D)
        tokens = torch.cat([cls_tokens, patch_embeds], dim=1)  # (B*T, P+1, D)

        # Move the token axis to dim 1 BEFORE adding positions: the spatial
        # embedding indexes dim 1 and the temporal embedding indexes dim 2.
        # (Applying the spatial embedding to a (B, T, P+1, D) tensor would add
        # it along the frame axis instead of the patch axis.)
        tokens = tokens.reshape(B, T, num_tokens, -1).permute(0, 2, 1, 3)  # (B, P+1, T, D)
        tokens = self.spatial_encoding(tokens)   # per-token offsets
        tokens = self.temporal_encoding(tokens)  # per-frame offsets

        # Restore frame-major layout and fold time into the batch for spatial attention
        tokens = tokens.permute(0, 2, 1, 3).reshape(B * T, num_tokens, -1)  # (B*T, P+1, D)
        spatial_out = self.spartial_encoder(tokens)  # (B*T, P+1, D)

        # One CLS summary per frame
        cls_per_frame = spatial_out[:, 0, :].reshape(B, T, self._EMBED_DIM)  # (B, T, D)

        # Temporal attention across the frame summaries, so cross-frame
        # information reaches the classifier and the temporal encoder is
        # trained. The encoder is residual, so the per-frame CLS content is
        # preserved.
        cls_per_frame = self.temporal_encoder(cls_per_frame)  # (B, T, D)

        # Average over time and classify
        global_cls = cls_per_frame.mean(dim=1)  # (B, D)
        output = self.classifier(self.dropout(global_cls))  # (B, num_classes)

        return output

# TRAIN/VALIDATE

# In [None]:
# Run configuration consumed by main(); keys are grouped as
# data / training / miscellaneous.
CONFIG = dict(
    # --- Data ---
    data_dir="dataset",      # root folder with one sub-folder per class
    output_dir="output",     # where logs and checkpoints are written
    num_frames=64,           # frames sampled per video
    train_ratio=0.8,         # fraction of videos used for training
    video_extensions=["*.mp4", "*.avi", "*.mov", "*.mkv", "*.MOV"],  # glob patterns for video files

    # --- Training ---
    batch_size=4,            # samples per batch
    epochs=30,               # number of passes over the training set
    learning_rate=1e-4,      # initial learning rate
    min_lr=1e-6,             # floor of the cosine schedule
    weight_decay=1e-4,       # AdamW weight decay
    save_every=5,            # periodic checkpoint interval, in epochs

    # --- Miscellaneous ---
    seed=42,                 # random seed
    num_workers=0,           # DataLoader worker processes
    no_cuda=False,           # force CPU even when CUDA is available
)


def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Train the model for one epoch.

    Args:
        model: Module called as ``model(videos)`` to produce class logits.
        dataloader: Iterable of ``(videos, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to
            return the mean loss over the batch.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (only used for the progress-bar label).

    Returns:
        ``(train_loss, train_acc)``: mean loss per sample and accuracy in percent.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.train()
    loss_sum = 0.0  # sum of per-sample losses (batch mean * batch size)
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1} [Train]')

    for videos, labels in progress_bar:
        # Move data to device
        videos = videos.to(device)  # Expects shape (B, T, C, H, W)
        labels = labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(videos)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Track metrics, weighting each batch by its size
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

        # Use our own counters for the running averages: tqdm refreshes
        # `progress_bar.n` lazily, so it is not a reliable batch count.
        progress_bar.set_postfix({
            'loss': loss_sum / total,
            'acc': 100. * correct / total
        })

    if total == 0:
        raise ValueError("Training dataloader produced no samples")

    train_loss = loss_sum / total
    train_acc = 100. * correct / total

    return train_loss, train_acc


def validate(model, dataloader, criterion, device):
    """Evaluate the model on the validation set.

    Args:
        model: Module called as ``model(videos)`` to produce class logits.
        dataloader: Iterable of ``(videos, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to
            return the mean loss over the batch.
        device: Device the batches are moved to.

    Returns:
        Dict of plain Python floats with keys ``val_loss`` (mean loss per
        sample), ``val_acc`` (percent), and macro-averaged ``precision``,
        ``recall`` and ``f1``.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()
    loss_sum = 0.0  # sum of per-sample losses (batch mean * batch size)
    total = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc='Validation')

        for videos, labels in progress_bar:
            # Move data to device
            videos = videos.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(videos)
            loss = criterion(outputs, labels)

            # Weight by batch size so a short last batch does not skew the mean
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            total += batch_size
            _, predicted = outputs.max(1)

            # Save predictions and targets for metric computation
            all_predictions.extend(predicted.cpu().tolist())
            all_targets.extend(labels.cpu().tolist())

            # tqdm refreshes `progress_bar.n` lazily, so use our own counter.
            progress_bar.set_postfix({'loss': loss_sum / total})

    if total == 0:
        raise ValueError("Validation dataloader produced no samples")

    # Compute metrics; cast to float so they serialise cleanly in checkpoints
    # and logs instead of carrying numpy scalar types.
    metrics = {
        'val_loss': loss_sum / total,
        'val_acc': float(accuracy_score(all_targets, all_predictions) * 100),
        'precision': float(precision_score(all_targets, all_predictions, average='macro', zero_division=0)),
        'recall': float(recall_score(all_targets, all_predictions, average='macro', zero_division=0)),
        'f1': float(f1_score(all_targets, all_predictions, average='macro', zero_division=0))
    }

    return metrics


def save_checkpoint(model, optimizer, scheduler, epoch, metrics, save_dir):
    """Write a training checkpoint to ``save_dir/checkpoint_epoch_{epoch}.pth``.

    The file holds the epoch, the model and optimizer state dicts, the
    scheduler state dict (``None`` when no scheduler is given) and the
    metrics. ``save_dir`` is created if it does not exist.
    """
    target_dir = Path(save_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    scheduler_state = scheduler.state_dict() if scheduler else None
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler_state,
        metrics=metrics,
    )

    torch.save(payload, os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pth'))


def main(config):
    """Run the full training loop described by ``config``.

    Builds the dataloaders and the model, trains for ``config["epochs"]``
    epochs, logs metrics to TensorBoard under ``<output_dir>/logs`` and writes
    checkpoints under ``<output_dir>/checkpoints``. Errors raised during
    training are printed with a traceback and swallowed.

    Args:
        config: Mapping with the keys used by ``CONFIG`` (data_dir, output_dir,
            num_frames, train_ratio, video_extensions, batch_size, epochs,
            learning_rate, min_lr, weight_decay, save_every, seed,
            num_workers, no_cuda).

    Returns:
        None. Returns early (after printing a message) when the data
        directory does not exist.
    """
    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() and not config["no_cuda"] else 'cpu')
    print(f"Using device: {device}")

    # Seed every RNG the pipeline uses: the train/val split relies on
    # `random`, frame sampling on `random` and `np.random`, and the model
    # initialisation on torch.
    random.seed(config["seed"])
    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config["seed"])

    # Create output directories
    output_dir = Path(config["output_dir"])
    checkpoint_dir = output_dir / 'checkpoints'
    log_dir = output_dir / 'logs'

    for directory in [output_dir, checkpoint_dir, log_dir]:
        directory.mkdir(parents=True, exist_ok=True)

    # Check if data directory exists
    data_dir = Path(config["data_dir"])
    if not data_dir.exists():
        print(f"Error: Data directory '{data_dir}' does not exist.")
        print(f"Current working directory: {Path.cwd()}")
        print(f"Please create the directory or update the CONFIG['data_dir'] value.")
        return

    # Initialize TensorBoard writer
    writer = SummaryWriter(log_dir)

    try:
        # Create data loaders
        train_loader, val_loader, idx_to_class = create_dataloader(
            config["data_dir"],
            batch_size=config["batch_size"],
            num_frames=config["num_frames"],
            num_workers=config["num_workers"],
            train_ratio=config["train_ratio"],
            frame_size=(224, 224),
            video_extensions=config["video_extensions"]  # Pass the video extensions from CONFIG
        )

        # Save class mapping (JSON turns the integer keys into strings)
        with open(output_dir / 'class_mapping.json', 'w') as f:
            json.dump(idx_to_class, f, indent=4)

        print(f"Number of training batches: {len(train_loader)}")
        print(f"Number of validation batches: {len(val_loader)}")
        print(f"Class mapping: {idx_to_class}")

        # Initialize model
        num_classes = len(idx_to_class)
        model = TimeSFormer(number_of_frames=config["num_frames"], num_classes=num_classes)
        model = model.to(device)

        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

        # Learning rate scheduler
        scheduler = CosineAnnealingLR(optimizer, T_max=config["epochs"], eta_min=config["min_lr"])

        # Track best model
        best_val_acc = 0.0
        val_metrics = None  # stays None when config["epochs"] == 0

        # Training loop
        for epoch in range(config["epochs"]):
            # Train for one epoch
            train_loss, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, device, epoch
            )

            # Validate
            val_metrics = validate(model, val_loader, criterion, device)

            # Update learning rate
            scheduler.step()

            # Log metrics
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Accuracy/train', train_acc, epoch)
            writer.add_scalar('Loss/val', val_metrics['val_loss'], epoch)
            writer.add_scalar('Accuracy/val', val_metrics['val_acc'], epoch)
            writer.add_scalar('Precision/val', val_metrics['precision'], epoch)
            writer.add_scalar('Recall/val', val_metrics['recall'], epoch)
            writer.add_scalar('F1/val', val_metrics['f1'], epoch)
            writer.add_scalar('LearningRate', scheduler.get_last_lr()[0], epoch)

            # Print metrics
            print(f"Epoch {epoch+1}/{config['epochs']}:")
            print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"  Val Loss: {val_metrics['val_loss']:.4f}, Val Acc: {val_metrics['val_acc']:.2f}%")
            print(f"  Precision: {val_metrics['precision']:.4f}, Recall: {val_metrics['recall']:.4f}, F1: {val_metrics['f1']:.4f}")

            # Save checkpoint if it's the best model so far
            if val_metrics['val_acc'] > best_val_acc:
                best_val_acc = val_metrics['val_acc']
                save_checkpoint(
                    model, optimizer, scheduler, epoch, val_metrics,
                    os.path.join(checkpoint_dir, 'best_model')
                )
                print(f"  New best model saved with val acc: {best_val_acc:.2f}%")

            # Regularly save checkpoints
            if (epoch + 1) % config["save_every"] == 0:
                save_checkpoint(
                    model, optimizer, scheduler, epoch, val_metrics, checkpoint_dir
                )

        # Save final model (skipped when no epoch ran, since there are no
        # metrics to store).
        if val_metrics is not None:
            save_checkpoint(
                model, optimizer, scheduler, config["epochs"] - 1, val_metrics, checkpoint_dir
            )

        print(f"Training completed. Best validation accuracy: {best_val_acc:.2f}%")

    except Exception as e:
        # Top-level boundary: report the failure instead of crashing the
        # notebook/script.
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Flush and close the TensorBoard writer on success and on failure.
        writer.close()


# Script entry point: train with the module-level CONFIG when executed directly.
if __name__ == "__main__":
    main(CONFIG)
