# Video partitionning using mobilenetV2 CNN and Bi-LSTM

## imports

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from torch.cuda.amp import autocast, GradScaler
import os
import cv2
import json
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random
from datetime import datetime
import time

## constants

In [8]:
FRAME_SIZE = 160  # Resize frames to 160x160
FPS = 4  # Sample 8 frames per second
SEQUENCE_LENGTH = 120  # 30 seconds * 4 fps
BATCH_SIZE = 16  # Will be adjusted based on memory
NUM_CLASSES = 3  # intro, regular, outro
FEATURE_DIM = 512  # Reduced feature dimension
HIDDEN_DIM = 512  # LSTM hidden dimension
ACCUMULATION_STEPS = 4  # Effective batch size of 4

## dataset

In [9]:
class AnimeDataset(Dataset):
    def __init__(self, video_dir, json_dir, clip_length=30, mode='train', train_split=0.7, val_split=0.15):
        """
        Dataset for loading anime episodes and their annotations.
        
        Args:
            video_dir: Directory containing video files
            json_dir: Directory containing annotation JSON files
            clip_length: Length of clips in seconds
            mode: 'train', 'val', or 'test'
            train_split: Proportion of data for training
            val_split: Proportion of data for validation
        """
        self.video_dir = Path(video_dir)
        self.json_dir = Path(json_dir)
        self.clip_length = clip_length
        self.frame_interval = 1.0 / FPS

         # Debug print
        print(f"Initializing dataset with video_dir: {video_dir}, json_dir: {json_dir}")
        
        # Get all video files
        self.video_files = sorted(list(self.video_dir.glob('*.mp4')))
        print(f"Found {len(self.video_files)} video files")
        
        # Split data
        n_videos = len(self.video_files)
        train_idx = int(n_videos * train_split)
        val_idx = int(n_videos * (train_split + val_split))
        print(f"Total videos: {n_videos}, Train idx: {train_idx}, Val idx: {val_idx}")
        
        if mode == 'train':
            self.video_files = self.video_files[:train_idx]
        elif mode == 'val':
            self.video_files = self.video_files[train_idx:val_idx]
        else:  # test
            self.video_files = self.video_files[val_idx:]
            
        # Load annotations with debug prints
        self.annotations = {}
        for video_file in self.video_files:
            json_file = self.json_dir / f"{video_file.stem}.json"
            print(f"Loading annotation: {json_file}")
            if not json_file.exists():
                print(f"Warning: Missing annotation file for {video_file.stem}")
                continue
            with open(json_file, 'r') as f:
                self.annotations[video_file.stem] = json.load(f)
        
        print(f"Loaded {len(self.annotations)} annotations")

        # Print first annotation as example
        if self.annotations:
            first_key = next(iter(self.annotations))
            print(f"Sample annotation: {self.annotations[first_key]}")
                
        # Create clips with balanced sampling
        self.clips = self._create_clips()
        print(f"Created {len(self.clips)} clips")
       
         # Updated transforms with augmentation
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((FRAME_SIZE, FRAME_SIZE)),
            transforms.RandomHorizontalFlip(p=0.3),  # New
            transforms.ColorJitter(brightness=0.2, contrast=0.2),  # New
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ]) if mode == 'train' else transforms.Compose([  # Different transform for val/test
            transforms.ToTensor(),
            transforms.Resize((FRAME_SIZE, FRAME_SIZE)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])
        
    def _create_clips(self):
        """
        Create clips with balanced sampling between intro/regular/outro content.
        Returns list of (video_path, start_time) tuples.
        """
        clips = []
        intro_clips = []
        regular_clips = []
        outro_clips = []
        
        for video_file in self.video_files:
            if video_file.stem not in self.annotations:
                print(f"Skipping {video_file.stem} - no annotation")
                continue


            ann = self.annotations[video_file.stem]



            try:
                video_length = self._get_video_length(video_file)
                print(f"Video {video_file.stem} length: {video_length}")

                # Add debug prints for intro/outro times
                print(f"Intro: {ann.get('intro_start', 'None')} to {ann.get('intro_end', 'None')}")
                print(f"Outro: {ann.get('outro_start', 'None')} to {ann.get('outro_end', 'None')}")
            
                # Add intro clips
                if 'intro_start' in ann and 'intro_end' in ann:
                    intro_start = ann['intro_start']
                    intro_end = ann['intro_end']
                    # Sample clips that contain intro
                    possible_starts = np.arange(
                        max(0, intro_start - self.clip_length + 30),
                        min(intro_end, video_length - self.clip_length),
                        30  # Sample every 30 seconds
                    )
                    for start in possible_starts:
                        intro_clips.append((video_file, start))
                
                # Add outro clips
                if 'outro_start' in ann and 'outro_end' in ann:
                    outro_start = ann['outro_start']
                    outro_end = ann['outro_end']
                    possible_starts = np.arange(
                        max(0, outro_start - self.clip_length + 30),
                        min(outro_end, video_length - self.clip_length),
                        30
                    )
                    for start in possible_starts:
                        outro_clips.append((video_file, start))
                
                # Add regular clips
                regular_regions = self._get_regular_regions(ann, video_length)
                for start, end in regular_regions:
                    possible_starts = np.arange(
                        start,
                        end - self.clip_length,
                        self.clip_length
                    )
                    for start in possible_starts:
                        regular_clips.append((video_file, start))
            
                # Balance dataset by upsampling intro and outro clips
                n_regular = len(regular_clips)
                intro_clips = intro_clips * (n_regular // len(intro_clips) + 1)
                outro_clips = outro_clips * (n_regular // len(outro_clips) + 1)
                
                # Combine and shuffle
                clips = intro_clips[:n_regular//3] + outro_clips[:n_regular//3] + regular_clips[:n_regular//3]
                random.shuffle(clips)


            except Exception as e:
                print(f"Error processing {video_file.stem}: {str(e)}")
                continue

        
         # Print clip counts
        print(f"Found clips - Intro: {len(intro_clips)}, Regular: {len(regular_clips)}, Outro: {len(outro_clips)}")

        return clips
    
    def _get_video_length(self, video_path):
        """Get video length in seconds with debug info."""
        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            print(f"Failed to open video: {video_path}")
            return 0
            
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        
        if fps == 0:
            print(f"Warning: Zero FPS for video {video_path}")
            return 0
            
        length = frame_count / fps
        print(f"Video {video_path}: {frame_count} frames at {fps} FPS = {length} seconds")
        return length
    
    def _get_regular_regions(self, ann, video_length):
        """Get time regions that are neither intro nor outro."""
        regions = []
        intro_end = ann.get('intro_end', 0)
        outro_start = ann.get('outro_start', video_length)
        
        if intro_end < outro_start:
            regions.append((intro_end, outro_start))
            
        return regions
    
    def _get_frame_labels(self, video_file, start_time, n_frames):
        """Generate frame-level labels."""
        ann = self.annotations[video_file.stem]
        times = np.linspace(start_time, start_time + self.clip_length, n_frames)
        labels = np.zeros((n_frames, NUM_CLASSES))
        
        for i, t in enumerate(times):
            if ann['intro_start'] <= t <= ann['intro_end']:
                labels[i, 0] = 1  # intro
            elif ann.get('outro_start', float('inf')) <= t <= ann.get('outro_end', float('inf')):
                labels[i, 2] = 1  # outro
            else:
                labels[i, 1] = 1  # regular
                
        return labels
    
    def _sample_frames(self, video_path, start_time):
        """Sample frames from video at specified FPS."""
        frames = []
        cap = cv2.VideoCapture(str(video_path))
        
        # Set starting position
        cap.set(cv2.CAP_PROP_POS_MSEC, start_time * 1000)
        
        # Sample frames
        for _ in range(SEQUENCE_LENGTH):
            ret, frame = cap.read()
            if not ret:
                break
                
            # Convert BGR to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = self.transform(frame)
            frames.append(frame)
            
            # Skip frames to maintain desired FPS
            for _ in range(int(cap.get(cv2.CAP_PROP_FPS) / FPS) - 1):
                cap.read()
                
        cap.release()
        
        # Pad sequence if needed
        while len(frames) < SEQUENCE_LENGTH:
            frames.append(torch.zeros_like(frames[0]))
            
        return torch.stack(frames)
    
    def __len__(self):
        return len(self.clips)
    
    def __getitem__(self, idx):
        video_file, start_time = self.clips[idx]
        frames = self._sample_frames(video_file, start_time)
        labels = self._get_frame_labels(video_file, start_time, SEQUENCE_LENGTH)
        return frames, torch.FloatTensor(labels)

In [10]:
def verify_labels(dataset):
    """Verify label distribution and transitions in the dataset."""
    transitions = 0
    total_frames = 0
    label_counts = [0, 0, 0]
    
    for i in range(min(100, len(dataset))):  # Check first 100 items
        _, labels = dataset[i]
        total_frames += labels.shape[0]
        
        # Count label distributions
        for j in range(3):
            label_counts[j] += (labels[:, j] == 1).sum().item()
            
        # Count transitions
        for t in range(1, labels.shape[0]):
            if not torch.equal(labels[t], labels[t-1]):
                transitions += 1
    
    print(f"\nLabel Distribution Analysis:")
    print(f"Total frames checked: {total_frames}")
    print(f"Label counts: Intro: {label_counts[0]}, Regular: {label_counts[1]}, Outro: {label_counts[2]}")
    print(f"Label percentages: Intro: {label_counts[0]/total_frames*100:.1f}%, "
          f"Regular: {label_counts[1]/total_frames*100:.1f}%, "
          f"Outro: {label_counts[2]/total_frames*100:.1f}%")
    print(f"Number of transitions: {transitions}")
    print(f"Transitions per sequence: {transitions/100:.2f}")

## model defenition

In [11]:

# class AnimeClassifier(nn.Module):
#     def __init__(self):
#         super().__init__()
        
#         # Feature extraction using MobileNetV2 (still efficient but full version)
#         mobilenet = models.mobilenet_v2(pretrained=True)
        
#         # Adjust batch norm momentum for better training
#         for m in mobilenet.modules():
#             if isinstance(m, nn.BatchNorm2d):
#                 m.momentum = 0.1
#                 m.eps = 1e-3
#         self.features = nn.Sequential(*list(mobilenet.children())[:-1])
        
#         # Feature dimension reduction
#         self.feature_reduction = nn.Sequential(
#             nn.AdaptiveAvgPool2d(1),
#             nn.Flatten(),
#             nn.Linear(1280, FEATURE_DIM),
#             nn.BatchNorm1d(FEATURE_DIM),
#             nn.ReLU(),
#             nn.Dropout(0.3)
#         )
        
#         # Bidirectional LSTM layers
#         self.lstm1 = nn.LSTM(FEATURE_DIM, HIDDEN_DIM, bidirectional=True, 
#                             batch_first=True, num_layers=2, dropout=0.3)
#         self.lstm2 = nn.LSTM(HIDDEN_DIM * 2, HIDDEN_DIM, bidirectional=True, 
#                             batch_first=True, num_layers=2, dropout=0.3)
        
#         # Residual connection
#         self.residual_projection = nn.Sequential(
#             nn.Linear(FEATURE_DIM, HIDDEN_DIM * 2),
#             nn.ReLU()
#         )
        
#         # Classification head (outputs logits for BCE with logits)
#         self.classifier = nn.Sequential(
#             nn.Linear(HIDDEN_DIM * 2, HIDDEN_DIM),
#             nn.ReLU(),
#             nn.Dropout(0.3),
#             nn.Linear(HIDDEN_DIM, NUM_CLASSES)
#         )
#         self._initialize_weights()

#     def _initialize_weights(self):
#         # Initialize LSTM weights
#         for lstm in [self.lstm1, self.lstm2]:
#             for name, param in lstm.named_parameters():
#                 if 'weight_ih' in name:
#                     nn.init.xavier_uniform_(param.data)
#                 elif 'weight_hh' in name:
#                     nn.init.orthogonal_(param.data)
#                 elif 'bias' in name:
#                     param.data.fill_(0)
        
#         # Initialize linear layers
#         for m in self.feature_reduction.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#                 if m.bias is not None:
#                     nn.init.constant_(m.bias, 0)
        
#         for m in self.residual_projection.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#                 if m.bias is not None:
#                     nn.init.constant_(m.bias, 0)
        
#         for m in self.classifier.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
#                 if m.bias is not None:
#                     nn.init.constant_(m.bias, 0)
        
#     def forward(self, x):
#         batch_size, seq_len, c, h, w = x.size()
        
#         # Process each frame through CNN
#         x = x.view(batch_size * seq_len, c, h, w)
#         features = self.features(x)
#         features = self.feature_reduction(features)
#         features = features.view(batch_size, seq_len, -1)
        
#         # Save for residual connection
#         residual = self.residual_projection(features)
        
#         # Process through LSTM layers
#         lstm_out1, _ = self.lstm1(features)
#         lstm_out2, _ = self.lstm2(lstm_out1)
        
#         # Add residual connection
#         lstm_out = lstm_out2 + residual
        
#         # Classification
#         return self.classifier(lstm_out)  # Returns logits
class AnimeClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Use MobileNetV2 but freeze early layers
        mobilenet = models.mobilenet_v2(pretrained=True)
        # Freeze first 10 layers
        for i, param in enumerate(mobilenet.features[:5].parameters()):
            param.requires_grad = False
            
        self.features = nn.Sequential(*list(mobilenet.children())[:-1])
        
        # Wider feature reduction
        self.feature_reduction = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1280, FEATURE_DIM),
            nn.LayerNorm(FEATURE_DIM),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(FEATURE_DIM, FEATURE_DIM),
            nn.LayerNorm(FEATURE_DIM),
            nn.GELU(),
            nn.Dropout(0.2)
        )
        
        # Single LSTM layer with higher capacity
        self.lstm = nn.LSTM(FEATURE_DIM, HIDDEN_DIM, 
                           bidirectional=True, 
                           batch_first=True,
                           num_layers=2,  # Single layer but higher dimension
                           dropout=0.2)
        
        # # Simplified classification head
        # self.classifier = nn.Sequential(
        #     nn.Linear(HIDDEN_DIM * 2, HIDDEN_DIM),
        #     nn.GELU(),
        #     nn.Dropout(0.2),
        #     nn.Linear(HIDDEN_DIM, NUM_CLASSES)
        # )
        # Wider classification head
        self.classifier = nn.Sequential(
            nn.Linear(HIDDEN_DIM * 2, HIDDEN_DIM * 2),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(HIDDEN_DIM * 2, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(HIDDEN_DIM, NUM_CLASSES)
        )
        
    def forward(self, x):
        batch_size, seq_len, c, h, w = x.size()
        
        # Process each frame through CNN
        x = x.view(batch_size * seq_len, c, h, w)
        features = self.features(x)
        features = self.feature_reduction(features)
        features = features.view(batch_size, seq_len, -1)
        
        # Process through LSTM layer
        lstm_out, _ = self.lstm(features)
        
        # Classification
        return self.classifier(lstm_out)

## Loss function + training epoch + validation

In [12]:
def temporal_consistency_loss(predictions, alpha=0.5):
    """
    Calculate temporal consistency loss to penalize rapid changes in predictions.
    """
    temp_loss = torch.mean(torch.abs(predictions[:, 1:] - predictions[:, :-1]))
    return alpha * temp_loss

In [13]:

def train_epoch(model, train_loader, criterion, optimizer, scaler, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc='Training')
    
    for frames, labels in progress_bar:
        frames, labels = frames.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        with autocast():
            logits = model(frames)
            loss = criterion(torch.sigmoid(logits), labels)  # Apply sigmoid here
            temp_loss = temporal_consistency_loss(torch.sigmoid(logits))
            loss = loss + temp_loss
        
        # scaler.scale(loss).backward()
        scaler.scale(loss).backward()
        # Add gradient clipping
        if scaler.is_enabled():  # Check if we're using mixed precision
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    
    return total_loss / len(train_loader)

In [14]:
def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for frames, labels in tqdm(val_loader, desc='Validation'):
            frames, labels = frames.to(device), labels.to(device)
            
            predictions = model(frames)
            bce_loss = criterion(predictions, labels)
            temp_loss = temporal_consistency_loss(predictions)
            loss = bce_loss + temp_loss
            
            total_loss += loss.item()
            all_preds.append(predictions.cpu())
            all_labels.append(labels.cpu())
    
    all_preds = torch.cat(all_preds, 0)
    all_labels = torch.cat(all_labels, 0)
    
    # Calculate metrics
    accuracy = ((all_preds.argmax(dim=-1) == all_labels.argmax(dim=-1)).float().mean()).item()
    
    return total_loss / len(val_loader), accuracy

## training function

In [15]:


def train_model(model, train_loader, val_loader, num_epochs=15, patience=5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Use BCEWithLogitsLoss instead of BCELoss for better numeric stability
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=2e-3, weight_decay=0.01, betas=(0.9, 0.999))
    scaler = GradScaler()
    
   # Better learning rate scheduling with warmup
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=2e-2,  # Peak learning rate
        epochs=num_epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.2,  # 10% warmup
        div_factor=10,  # Initial lr = max_lr/10
        final_div_factor=100,  # Final lr = initial_lr/100
        anneal_strategy='cos'
    )
    
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_accuracy': [],
        'epoch_times': [],
        'learning_rates': []
    }
    
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    best_model_path = f'best_model_{current_time}.pth'
    
    
    try:
        for epoch in range(num_epochs):
            start_time = time.time()
            
            train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device)
            val_loss, val_accuracy = validate(model, val_loader, criterion, device)
            
            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']
            # Save periodic checkpoint (every 5 epochs)
            if (epoch + 1) % 5 == 0:
                checkpoint_path = f'checkpoint_epoch_{epoch+1}_{current_time}.pth'
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_loss': val_loss,
                    'history': history
                }, checkpoint_path)

            


            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_loss': val_loss,
                }, best_model_path)
            else:
                epochs_without_improvement += 1


             # Record history
            epoch_time = time.time() - start_time
            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['val_accuracy'].append(val_accuracy)
            history['epoch_times'].append(epoch_time)
            history['learning_rates'].append(current_lr)
            
            # Print epoch results
            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
                  f'Val Accuracy: {val_accuracy:.4f}, LR: {current_lr:.6f}')
            print(f'Epoch time: {epoch_time:.2f}s')
            
                
            if epochs_without_improvement >= patience:
                print(f'\nEarly stopping after {epoch + 1} epochs')
                break
                
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        
    finally:
        # Always try to save final state
        try:
            final_path = f'final_model_{current_time}.pth'
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history
            }, final_path)
            print(f"\nFinal model saved to {final_path}")
        except Exception as e:
            print(f"Error saving final model: {str(e)}")
    
    return history

## visualisation

In [16]:
def plot_training_history(history):
    """
    Plot training metrics over time.
    Shows loss curves and validation accuracy.
    """
    plt.figure(figsize=(15, 5))
    
    # Plot losses
    plt.subplot(1, 3, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Loss over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    # Plot accuracy
    plt.subplot(1, 3, 2)
    plt.plot(history['val_accuracy'], label='Validation Accuracy')
    plt.title('Validation Accuracy over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    
    # Plot epoch times
    plt.subplot(1, 3, 3)
    plt.plot(history['epoch_times'], label='Epoch Time')
    plt.title('Training Time per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Time (seconds)')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.show()

## inference

In [17]:
def predict_video(model, video_path, json_path=None, window_size=60, overlap=30):
    """
    Predict intro/outro segments for an entire video using overlapping windows.
    Returns frame-level predictions and optionally compares with ground truth.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = 'cpu'
    model.eval()
    
    # Video properties
    cap = cv2.VideoCapture(str(video_path))
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = total_frames / video_fps
    cap.release()
    
    # Calculate number of windows with overlap
    stride = window_size - overlap
    n_windows = int(np.ceil((duration - window_size) / stride)) + 1
    
    # Initialize arrays for predictions
    all_predictions = np.zeros((int(duration * FPS), NUM_CLASSES))
    counts = np.zeros(int(duration * FPS))
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((FRAME_SIZE, FRAME_SIZE)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Process each window
    with torch.no_grad():
        for i in tqdm(range(n_windows), desc='Processing video'):
            start_time = i * stride
            
            # Sample frames for this window
            cap = cv2.VideoCapture(str(video_path))
            cap.set(cv2.CAP_PROP_POS_MSEC, start_time * 1000)
            
            frames = []
            for _ in range(SEQUENCE_LENGTH):
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = transform(frame)
                frames.append(frame)
            cap.release()
            
            # If we don't have enough frames, pad with zeros
            while len(frames) < SEQUENCE_LENGTH:
                frames.append(torch.zeros_like(frames[0]))
            
            # Make prediction
            frames_tensor = torch.stack(frames).unsqueeze(0).to(device)
            predictions = model(frames_tensor).cpu().numpy()[0]
            
            # Add predictions to the appropriate time slots
            start_idx = int(start_time * FPS)
            end_idx = start_idx + predictions.shape[0]
            all_predictions[start_idx:end_idx] += predictions
            counts[start_idx:end_idx] += 1
    
    # Average overlapping predictions
    mask = counts > 0
    all_predictions[mask] /= counts[mask, np.newaxis]
    
    # Load ground truth if available
    if json_path:
        with open(json_path, 'r') as f:
            gt = json.load(f)
            
        gt_labels = np.zeros((int(duration * FPS), NUM_CLASSES))
        times = np.arange(0, duration, 1/FPS)
        
        for i, t in enumerate(times):
            if gt['intro_start'] <= t <= gt['intro_end']:
                gt_labels[i, 0] = 1
            elif gt.get('outro_start', float('inf')) <= t <= gt.get('outro_end', float('inf')):
                gt_labels[i, 2] = 1
            else:
                gt_labels[i, 1] = 1
    else:
        gt_labels = None
    
    return all_predictions, gt_labels, times

## Visualizes model predictions against ground truth

In [18]:
def plot_predictions(predictions, ground_truth, times):
    """
    Visualize predictions against ground truth if available.
    """
    plt.figure(figsize=(15, 5))
    
    # Plot predictions
    plt.plot(times, predictions[:, 0], 'b-', label='Intro Pred', alpha=0.7)
    plt.plot(times, predictions[:, 2], 'r-', label='Outro Pred', alpha=0.7)
    
    if ground_truth is not None:
        # Plot ground truth
        plt.plot(times, ground_truth[:, 0], 'b--', label='Intro GT')
        plt.plot(times, ground_truth[:, 2], 'r--', label='Outro GT')
    
    plt.xlabel('Time (seconds)')
    plt.ylabel('Probability')
    plt.title('Intro/Outro Predictions')
    plt.legend()
    plt.grid(True)
    plt.show()

# MAIN

In [19]:


# Create datasets
train_dataset = AnimeDataset('/teamspace/studios/this_studio/100anime', '/teamspace/studios/this_studio/100 anime', mode='train')
val_dataset = AnimeDataset('/teamspace/studios/this_studio/100anime', '/teamspace/studios/this_studio/100 anime', mode='val')

    # Test dataset first
try:
    test_dataset = AnimeDataset('D:/CS/ML/100anime', 'D:/CS/ML/animev2/data/100 anime', mode='train')

    print(f"Dataset size: {len(test_dataset)}")
    # Test single item loading
    frames, labels = test_dataset[0]
    print(f"Sample shapes - Frames: {frames.shape}, Labels: {labels.shape}")
    
    verify_labels(test_dataset)
except Exception as e:
    print(f"Dataset error: {str(e)}")

    # Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                            num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=4, pin_memory=True)

 # Add this debug code before training
for batch_idx, (frames, labels) in enumerate(train_loader):
    if batch_idx == 0:  # Check first batch
        print(f"Frames shape: {frames.shape}")
        print(f"Labels shape: {labels.shape}")
        print(f"Labels min/max: {labels.min()}, {labels.max()}")
        print(f"Labels distribution:\n{labels.sum(dim=0)}")  # Sum for each class
        break                         
    
    # Create and train model
model = AnimeClassifier()
# history = train_model(model, train_loader, val_loader)
    
    # Plot training history
plot_training_history(history)
    
    # Example inference on a single video
# video_path = 'path/to/test_video.mp4'
# json_path = 'path/to/test_video.json'
# predictions, ground_truth, times = predict_video(model, video_path, json_path)
# plot_predictions(predictions, ground_truth, times)

Initializing dataset with video_dir: /teamspace/studios/this_studio/100anime, json_dir: /teamspace/studios/this_studio/100 anime
Found 103 video files
Total videos: 103, Train idx: 72, Val idx: 87
Loading annotation: /teamspace/studios/this_studio/100 anime/1.json
Loading annotation: /teamspace/studios/this_studio/100 anime/10.json
Loading annotation: /teamspace/studios/this_studio/100 anime/100.json
Loading annotation: /teamspace/studios/this_studio/100 anime/101.json
Loading annotation: /teamspace/studios/this_studio/100 anime/102.json
Loading annotation: /teamspace/studios/this_studio/100 anime/103.json
Loading annotation: /teamspace/studios/this_studio/100 anime/11.json
Loading annotation: /teamspace/studios/this_studio/100 anime/12.json
Loading annotation: /teamspace/studios/this_studio/100 anime/13.json
Loading annotation: /teamspace/studios/this_studio/100 anime/14.json
Loading annotation: /teamspace/studios/this_studio/100 anime/15.json
Loading annotation: /teamspace/studios/th



NameError: name 'history' is not defined