In [None]:
import json
import os
import random
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import wandb
from PIL import Image
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
# Training Configuration
# Single place for every tunable value; the inline notes say which way to move
# a knob when a specific problem shows up.
CONFIG = dict(
    # --- Data ---
    FRAME_SIZE=224,          # lower on GPU memory pressure, raise when underfitting
    FPS=8,                   # frames sampled per second of video
    SEQUENCE_LENGTH=240,     # 30 seconds * 8 fps
    NUM_CLASSES=3,           # intro, regular, outro

    # --- Model architecture ---
    HIDDEN_DIM=256,          # GRU hidden size; raise when underfitting
    NUM_LAYERS=2,            # stacked GRU layers; raise for more capacity
    DROPOUT=0.3,             # raise when overfitting, lower when underfitting

    # --- Training ---
    BATCH_SIZE=2,            # lower on GPU OOM, raise when the GPU is underutilized
    ACCUMULATION_STEPS=4,    # effective batch size = BATCH_SIZE * ACCUMULATION_STEPS
    NUM_EPOCHS=30,           # upper bound on training epochs
    PATIENCE=5,              # early-stopping patience
    LEARNING_RATE=1e-4,      # lower when training is unstable
    WEIGHT_DECAY=1e-2,       # regularization strength; raise when overfitting
    NUM_WORKERS=4,           # data-loading workers; set to the CPU count

    # --- Scheduler ---
    WARMUP_EPOCHS=2,         # warmup epochs
    MIN_LR=1e-6,             # minimum learning rate

    # --- Validation / checkpointing ---
    VAL_FREQ=1,              # validate every N epochs
    SAVE_FREQ=5,             # save a checkpoint every N epochs
)

In [None]:
class AnimeFrameDataset(Dataset):
    """Fixed-length video clips with per-frame intro / regular / outro labels.

    Every ``*.mp4`` in ``video_dir`` is paired with ``<stem>.json`` in
    ``json_dir``. The annotation keys read by this class are ``intro_start``,
    ``intro_end``, ``outro_start`` and ``outro_end`` (all optional, in seconds).

    Each item is a tuple ``(frames, labels)``:

    * ``frames`` - float tensor ``(clip_length * fps, 3, frame_size, frame_size)``,
      normalized with the mean/std passed to ``transforms.Normalize`` below and
      zero-padded when the video ends early.
    * ``labels`` - float tensor ``(clip_length * fps, 3)``, one-hot over
      ``[intro, regular, outro]``.
    """

    # Seconds skipped at the start of a segment before clips are sampled from it.
    _EDGE_MARGIN_S = 5

    def __init__(self, video_dir: str, json_dir: str,
                 clip_length: int = 30, fps: int = 8,
                 split: str = 'train', seed: int = 42,
                 frame_size: int = 224):
        """Index the videos of one split and pre-compute the clip list.

        Args:
            video_dir: Directory containing the ``*.mp4`` files.
            json_dir: Directory containing one ``<video stem>.json`` per video.
            clip_length: Clip duration in seconds.
            fps: Frames sampled per second of video.
            split: ``'train'`` (first 80% of the shuffled videos), ``'val'``
                (next 10%); any other value selects the remaining 10% (test).
            seed: Seed for the shuffle and the clip sampling. Use the same
                seed for every split so the splits stay disjoint.
            frame_size: Side length (pixels) frames are resized to.
        """
        self.video_dir = Path(video_dir)
        self.json_dir = Path(json_dir)
        self.clip_length = clip_length
        self.fps = fps
        self.frames_per_clip = clip_length * fps
        self.frame_size = frame_size

        # A private generator yields the same sequence as seeding the global
        # one, without resetting the RNG state of the rest of the notebook.
        self._rng = random.Random(seed)

        # Sort first so the shuffle (and therefore the split) is reproducible.
        self.video_files = sorted(self.video_dir.glob('*.mp4'))

        # Split data
        self._rng.shuffle(self.video_files)
        n_videos = len(self.video_files)
        train_idx = int(0.8 * n_videos)
        val_idx = int(0.9 * n_videos)

        if split == 'train':
            self.video_files = self.video_files[:train_idx]
        elif split == 'val':
            self.video_files = self.video_files[train_idx:val_idx]
        else:  # test
            self.video_files = self.video_files[val_idx:]

        # Load all annotations; videos without a JSON file are skipped later.
        self.annotations = {}
        for video_file in self.video_files:
            json_file = self.json_dir / f"{video_file.stem}.json"
            if not json_file.exists():
                print(f"Warning: Missing JSON for {video_file.stem}")
                continue
            with open(json_file, 'r') as f:
                self.annotations[video_file.stem] = json.load(f)

        # Per-frame preprocessing applied in _load_frames.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((frame_size, frame_size)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Create balanced clips with debug info
        print(f"Creating clips for {len(self.video_files)} videos...")
        self.clips = self._create_balanced_clips()

    def _sample_starts(self, kind: str, seg_start: float, seg_end: float,
                       step: float) -> List[float]:
        """Return clip start times spaced ``step`` seconds apart inside a segment.

        Starts begin ``_EDGE_MARGIN_S`` seconds after ``seg_start`` and stop
        early enough that a whole clip ends before ``seg_end``. ``kind`` is
        only used in the warning printed when sampling fails.
        """
        try:
            starts = np.arange(seg_start + self._EDGE_MARGIN_S,
                               seg_end - self.clip_length, step)
        except ValueError as e:
            print(f"Warning: Error sampling {kind} clips: {e}")
            return []
        # Plain floats keep the clip list free of numpy scalar types.
        return [float(t) for t in starts]

    def _create_balanced_clips(self) -> List[Tuple[Path, float]]:
        """Creates balanced list of (video_path, start_time) clips.

        Four pools are collected (intro, outro, regular, transition) and then
        sub-sampled in a 2 : 2 : 1 : 4 ratio (intro : outro : transition :
        regular) relative to the scarcest of the intro/outro/transition pools.
        """
        intro_clips = []
        outro_clips = []
        regular_clips = []
        transition_clips = []
        half_clip = self.clip_length / 2

        for video_file in tqdm(self.video_files, desc="Creating clips"):
            if video_file.stem not in self.annotations:
                continue

            ann = self.annotations[video_file.stem]
            try:
                duration = self._get_video_duration(video_file)
            except Exception:
                # `Exception` (not a bare except) so Ctrl-C still interrupts.
                print(f"Warning: Could not read duration for {video_file}")
                continue

            # Debug print
            print(f"\nProcessing {video_file.stem}:")
            print(f"Duration: {duration:.2f}s")
            print(f"Annotations: {ann}")

            # Intro: clips inside it plus clips centred on each boundary.
            if 'intro_start' in ann and 'intro_end' in ann:
                start = ann['intro_start']
                end = ann['intro_end']
                if start > half_clip:
                    transition_clips.append((video_file, max(0, start - half_clip)))

                intro_clips.extend(
                    (video_file, t)
                    for t in self._sample_starts('intro', start, end, half_clip)
                )

                if end + half_clip < duration:
                    # Clamp like the other transitions: a short intro would
                    # otherwise produce a negative start time.
                    transition_clips.append((video_file, max(0, end - half_clip)))

            # Outro: clips inside it plus one clip centred on its start.
            if 'outro_start' in ann and 'outro_end' in ann:
                start = ann['outro_start']
                end = min(ann['outro_end'], duration)  # Ensure we don't exceed duration
                if start > half_clip:
                    transition_clips.append((video_file, max(0, start - half_clip)))

                outro_clips.extend(
                    (video_file, t)
                    for t in self._sample_starts('outro', start, end, half_clip)
                )

            # Regular content: everything between intro end and outro start,
            # sampled sparsely (one clip every two clip lengths).
            regular_start = ann.get('intro_end', 0)
            regular_end = ann.get('outro_start', duration)
            if regular_end > regular_start:
                regular_clips.extend(
                    (video_file, t)
                    for t in self._sample_starts('regular', regular_start,
                                                 regular_end, self.clip_length * 2)
                )

        # Print statistics before balancing
        print("\nBefore balancing:")
        print(f"Intro clips: {len(intro_clips)}")
        print(f"Outro clips: {len(outro_clips)}")
        print(f"Regular clips: {len(regular_clips)}")
        print(f"Transition clips: {len(transition_clips)}")

        # Balance dataset with minimum available clips (never below 1 so an
        # empty pool does not zero out the others).
        min_clips = min(
            max(len(intro_clips) // 2, 1),
            max(len(outro_clips) // 2, 1),
            max(len(transition_clips), 1)
        )

        balanced_clips = (
            self._rng.sample(intro_clips, min(len(intro_clips), min_clips * 2)) +
            self._rng.sample(outro_clips, min(len(outro_clips), min_clips * 2)) +
            self._rng.sample(transition_clips, min(len(transition_clips), min_clips)) +
            self._rng.sample(regular_clips, min(len(regular_clips), min_clips * 4))
        )

        # Print final statistics
        print("\nAfter balancing:")
        print(f"Total clips: {len(balanced_clips)}")

        self._rng.shuffle(balanced_clips)
        return balanced_clips

    def _get_video_duration(self, video_path: Path) -> float:
        """Get video duration in seconds (frame count / container FPS).

        Raises:
            ValueError: If the container reports a non-positive frame rate,
                which is what happens when the file cannot be opened.
        """
        cap = cv2.VideoCapture(str(video_path))
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()
        if not fps > 0:
            raise ValueError(f"Invalid frame rate {fps!r} reported for {video_path}")
        return frame_count / fps

    def _load_frames(self, video_path: Path, start_time: float) -> torch.Tensor:
        """Load ``frames_per_clip`` frames starting at ``start_time`` seconds.

        Frame ``i`` is taken as close as possible to ``start_time + i / fps``,
        so the clip always spans ``clip_length`` seconds whatever the native
        frame rate is. Frames that cannot be read are replaced by zeros.

        Returns:
            Tensor of shape ``(frames_per_clip, 3, frame_size, frame_size)``.
        """
        frames = []
        cap = cv2.VideoCapture(str(video_path))
        try:
            # Set starting position
            cap.set(cv2.CAP_PROP_POS_MSEC, start_time * 1000)

            native_fps = cap.get(cv2.CAP_PROP_FPS)
            if not native_fps > 0:
                # Unknown frame rate: fall back to taking every frame.
                native_fps = float(self.fps)

            frames_consumed = 0  # frames pulled from the capture so far
            for i in range(self.frames_per_clip):
                # Native-frame offset closest to the i-th sampling instant.
                # Rounding (not truncating the stride) keeps e.g. 23.976 fps
                # video at 8 samples/s instead of drifting to 12 samples/s.
                target = int(round(i * native_fps / self.fps))

                if frames and target < frames_consumed:
                    # Native rate below the sampling rate: repeat the last frame.
                    frames.append(frames[-1])
                    continue

                exhausted = False
                while frames_consumed < target:
                    # grab() advances without converting the frame.
                    if not cap.grab():
                        exhausted = True
                        break
                    frames_consumed += 1
                if exhausted:
                    break

                ret, frame = cap.read()
                if not ret:
                    break
                frames_consumed += 1

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(frame)
                frames.append(self.transform(frame))
        finally:
            cap.release()

        # Pad sequence if needed; also covers a clip with no readable frame,
        # where there is no first frame to copy the shape from.
        if len(frames) < self.frames_per_clip:
            if frames:
                blank = torch.zeros_like(frames[0])
            else:
                blank = torch.zeros(3, self.frame_size, self.frame_size)
            frames.extend(blank for _ in range(self.frames_per_clip - len(frames)))

        return torch.stack(frames)

    def _get_labels(self, video_file: Path, start_time: float) -> torch.Tensor:
        """Generate frame-level one-hot labels ``[intro, regular, outro]``.

        Frame ``i`` is labelled at ``start_time + i / fps``, the same instant
        ``_load_frames`` samples it at. Intro takes precedence over outro if
        the annotated ranges overlap; both range ends are inclusive.
        """
        ann = self.annotations[video_file.stem]
        times = start_time + np.arange(self.frames_per_clip) / self.fps

        # Missing keys fall back to an empty range so the test is never true.
        intro_lo = ann.get('intro_start', float('inf'))
        intro_hi = ann.get('intro_end', -float('inf'))
        outro_lo = ann.get('outro_start', float('inf'))
        outro_hi = ann.get('outro_end', -float('inf'))

        is_intro = (times >= intro_lo) & (times <= intro_hi)
        is_outro = ~is_intro & (times >= outro_lo) & (times <= outro_hi)
        class_idx = np.where(is_intro, 0, np.where(is_outro, 2, 1))

        labels = torch.zeros((self.frames_per_clip, 3))
        rows = torch.arange(self.frames_per_clip)
        labels[rows, torch.as_tensor(class_idx, dtype=torch.long)] = 1
        return labels

    def __len__(self) -> int:
        """Number of clips selected for this split."""
        return len(self.clips)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(frames, labels)`` for clip ``idx``."""
        video_file, start_time = self.clips[idx]
        frames = self._load_frames(video_file, start_time)
        labels = self._get_labels(video_file, start_time)
        return frames, labels




In [None]:
def visualize_clip(dataset: AnimeFrameDataset, idx: Optional[int] = None) -> None:
    """Show three frames of one clip together with their labels.

    Args:
        dataset: Dataset to draw the clip from.
        idx: Clip index. ``None`` (the default) picks a random clip; negative
            indices count from the end, like list indexing.

    Raises:
        ValueError: If the dataset contains no clips.
        IndexError: If ``idx`` is outside the dataset.
    """
    n_clips = len(dataset)
    if n_clips == 0:
        raise ValueError("Cannot visualize a clip from an empty dataset")
    if idx is None:
        idx = random.randint(0, n_clips - 1)
    elif not -n_clips <= idx < n_clips:
        raise IndexError(
            f"Clip index {idx} is out of range for a dataset with {n_clips} clips"
        )

    frames, labels = dataset[idx]
    video_file, start_time = dataset.clips[idx]

    # Frames at 1/4, 1/2 and 3/4 of the clip.
    n_frames = frames.shape[0]
    frame_indices = [n_frames // 4, n_frames // 2, 3 * n_frames // 4]

    # Statistics used to undo the dataset's normalization for display.
    std = torch.tensor([0.229, 0.224, 0.225])
    mean = torch.tensor([0.485, 0.456, 0.406])
    label_names = ['Intro', 'Regular', 'Outro']

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, frame_idx in zip(axes, frame_indices):
        # (C, H, W) -> (H, W, C), de-normalize, clamp to the displayable range.
        frame = frames[frame_idx].permute(1, 2, 0)
        frame = (frame * std + mean).clip(0, 1).numpy()

        ax.imshow(frame)
        frame_time = start_time + frame_idx / dataset.fps
        label_name = label_names[labels[frame_idx].argmax().item()]

        ax.set_title(f'Time: {frame_time:.2f}s\nLabel: {label_name}')
        ax.axis('off')

    fig.suptitle(f'Video: {video_file.stem}')
    fig.tight_layout()
    plt.show()

In [None]:
def analyze_dataset(dataset: AnimeFrameDataset) -> Dict:
    """Summarize the label distribution of a dataset.

    Only the labels are needed, so they are computed straight from the clip
    list; no video frame is decoded.

    Returns:
        Dict with ``total_clips``, ``total_frames``, ``label_distribution``
        (frame counts per class) and ``clip_types`` (clips containing a single
        class are counted under that class, mixed clips under ``transition``).
    """
    label_counts = torch.zeros(3)
    clip_types = {'intro': 0, 'outro': 0, 'regular': 0, 'transition': 0}

    for video_file, start_time in tqdm(dataset.clips, desc="Analyzing dataset"):
        labels = dataset._get_labels(video_file, start_time)
        label_indices = labels.argmax(dim=1)

        # Count frame labels
        label_counts += torch.bincount(label_indices, minlength=3)

        # Categorize clip type
        present = set(label_indices.tolist())
        if len(present) > 1:
            clip_types['transition'] += 1
        elif 0 in present:
            clip_types['intro'] += 1
        elif 2 in present:
            clip_types['outro'] += 1
        else:
            clip_types['regular'] += 1

    return {
        'total_clips': len(dataset),
        'total_frames': label_counts.sum().item(),
        'label_distribution': {
            'intro': label_counts[0].item(),
            'regular': label_counts[1].item(),
            'outro': label_counts[2].item()
        },
        'clip_types': clip_types
    }

In [None]:
def test_dataset(video_dir, json_dir):
    """Smoke-test dataset creation and return the resulting train split.

    Builds the ``'train'`` split, prints the shapes of its first sample and
    visualizes that same sample.

    Args:
        video_dir: Directory with the video files.
        json_dir: Directory with the annotation JSON files.

    Returns:
        The created ``AnimeFrameDataset``.

    Raises:
        Exception: Any error raised while building or reading the dataset is
            reported and re-raised unchanged.
    """
    print("Testing dataset creation...")
    print(f"Video directory: {video_dir}")
    print(f"JSON directory: {json_dir}")

    try:
        dataset = AnimeFrameDataset(
            video_dir=video_dir,
            json_dir=json_dir,
            split='train'
        )

        print("\nDataset created successfully!")
        print(f"Total clips: {len(dataset)}")

        if len(dataset) > 0:
            frames, labels = dataset[0]
            print("Sample shapes:")
            print(f"- Frames: {frames.shape}")
            print(f"- Labels: {labels.shape}")

            # Visualize the clip that was just loaded. The index is passed
            # explicitly because it is the only one guaranteed to exist here.
            visualize_clip(dataset, idx=0)

        return dataset

    except Exception as e:
        print(f"Error creating dataset: {str(e)}")
        raise

In [None]:
# Build the 'train' split and smoke-test it; `dataset` is kept as a notebook
# global for the cells below.
# NOTE(review): the two directories differ only by a space ('100anime' vs
# '100 anime') - confirm the annotations really live in a separate folder.
# NOTE(review): these are absolute, machine-specific paths; consider a
# configurable data directory so the notebook runs elsewhere.
dataset = test_dataset(
    video_dir='/teamspace/studios/this_studio/100anime',
    json_dir='/teamspace/studios/this_studio/100 anime'
)

In [None]:
class AnimeSceneClassifier(nn.Module):
    """Per-frame scene classifier: MobileNetV2 features + bidirectional GRU.

    Input:  ``(batch, seq_len, channels, height, width)`` frame sequences.
    Output: ``(batch, seq_len, num_classes)`` unnormalized logits, one row per
    frame.

    Args:
        hidden_dim: Size of the reduced frame embedding and of the GRU state.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability used after the embedding, between GRU
            layers (only when ``num_layers > 1``) and in the classifier head.
        num_classes: Number of output classes (default 3: intro, regular, outro).
    """

    def __init__(self, hidden_dim=256, num_layers=2, dropout=0.3, num_classes=3):
        super().__init__()

        # MobileNetV2 feature extractor with ImageNet weights. Newer
        # torchvision releases deprecate `pretrained=True` in favour of the
        # `weights=` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True`
        # resolves to, so both paths load the same parameters.
        weights_enum = getattr(models, 'MobileNet_V2_Weights', None)
        if weights_enum is not None:
            mobilenet = models.mobilenet_v2(weights=weights_enum.IMAGENET1K_V1)
        else:
            mobilenet = models.mobilenet_v2(pretrained=True)
        # Drop the last child (the classification head). The extra Sequential
        # wrapper is kept so parameter names match existing checkpoints.
        self.features = nn.Sequential(*list(mobilenet.children())[:-1])

        # Feature dimension reduction (backbone width -> hidden_dim)
        self.reduce_dim = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(mobilenet.last_channel, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Bidirectional GRU over the frame embeddings
        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        # Classification head; the GRU output is 2 * hidden_dim wide because
        # both directions are concatenated.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        """Classify every frame of every sequence in ``x``."""
        batch_size, seq_len, c, h, w = x.size()

        # Fold time into the batch axis so the CNN sees plain images.
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        x = x.reshape(batch_size * seq_len, c, h, w)
        x = self.features(x)
        x = self.reduce_dim(x)

        # Restore the (batch, time, feature) layout for the GRU
        x = x.reshape(batch_size, seq_len, -1)

        # Apply GRU
        x, _ = self.gru(x)

        # Classify each timestep
        return self.classifier(x)

In [None]:
def test_model():
    """Smoke-test the classifier with one random batch and return the model.

    The forward pass runs in eval mode without gradient tracking: in train
    mode the random input would be folded into the BatchNorm running
    statistics of the pretrained backbone, and autograd would keep the
    activations of all 480 frames alive. The model's original train/eval
    mode is restored before it is returned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AnimeSceneClassifier().to(device)

    # Test with random input
    x = torch.randn(2, 240, 3, 224, 224).to(device)  # batch_size=2, seq_len=240

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            y = model(x)
    finally:
        model.train(was_training)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    return model

In [None]:
# model = test_model()

In [None]:
def train_model(model, train_loader, val_loader, config):
    """Train ``model`` with gradient accumulation, mixed precision and early stopping.

    Args:
        model: Module mapping a frame batch to per-frame logits.
        train_loader: Iterable of ``(frames, labels)`` training batches; must
            support ``len()``.
        val_loader: Validation batches, passed to ``validate``.
        config: Mapping with the keys ``LEARNING_RATE``, ``WEIGHT_DECAY``,
            ``NUM_EPOCHS``, ``WARMUP_EPOCHS``, ``MIN_LR``,
            ``ACCUMULATION_STEPS``, ``VAL_FREQ``, ``PATIENCE`` and ``SAVE_FREQ``.

    Returns:
        The trained model (weights of the last epoch that ran).

    Raises:
        ValueError: If ``train_loader`` yields no batches.

    Side effects:
        Starts a wandb run and writes ``best_model.pth``, periodic
        ``checkpoint_epoch_<n>.pth`` files and ``final_model.pth`` into a
        timestamped ``checkpoints_*`` directory.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # Mixed precision only makes sense (and only works without warnings) on GPU.
    use_amp = device.type == 'cuda'

    accum_steps = config['ACCUMULATION_STEPS']
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")
    # One optimizer/scheduler step per accumulation group, including the
    # partial group at the end of an epoch (ceil division).
    optimizer_steps_per_epoch = (num_batches + accum_steps - 1) // accum_steps

    optimizer = AdamW(model.parameters(), lr=config['LEARNING_RATE'], weight_decay=config['WEIGHT_DECAY'])
    criterion = nn.BCEWithLogitsLoss()
    scaler = GradScaler(enabled=use_amp)

    # OneCycle scheduler. It must be sized in *optimizer* steps, otherwise the
    # cycle only advances 1/ACCUMULATION_STEPS of the way. OneCycleLR ends at
    # initial_lr / final_div_factor with initial_lr = max_lr / div_factor, so
    # the factor below makes the final learning rate equal MIN_LR.
    div_factor = 25.0  # OneCycleLR's default, spelled out because it is used below
    initial_lr = config['LEARNING_RATE'] / div_factor
    scheduler = OneCycleLR(
        optimizer,
        max_lr=config['LEARNING_RATE'],
        epochs=config['NUM_EPOCHS'],
        steps_per_epoch=optimizer_steps_per_epoch,
        pct_start=config['WARMUP_EPOCHS'] / config['NUM_EPOCHS'],
        div_factor=div_factor,
        final_div_factor=initial_lr / config['MIN_LR']
    )

    # Initialize wandb
    run = wandb.init(project="anime-scene-classifier", config=config)

    # Training state
    best_val_loss = float('inf')
    patience_counter = 0

    # Save directory
    save_dir = f"checkpoints_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(save_dir, exist_ok=True)

    try:
        for epoch in range(config['NUM_EPOCHS']):
            # Training
            model.train()
            train_loss = 0
            optimizer.zero_grad()

            for batch_idx, (frames, labels) in enumerate(tqdm(train_loader)):
                frames, labels = frames.to(device), labels.to(device)

                # Forward pass with mixed precision
                with autocast(enabled=use_amp):
                    outputs = model(frames)
                    # Scale so the accumulated gradient is the group mean.
                    loss = criterion(outputs, labels) / accum_steps

                # Backward pass
                scaler.scale(loss).backward()

                is_last_batch = (batch_idx + 1) == num_batches
                if (batch_idx + 1) % accum_steps == 0 or is_last_batch:
                    # Also step on the last batch so gradients of a partial
                    # group are applied instead of being zeroed next epoch.
                    scale_before = scaler.get_scale()
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    # GradScaler lowers its scale when it skipped the step
                    # (inf/nan gradients); only advance the schedule when the
                    # optimizer actually stepped.
                    if scaler.get_scale() >= scale_before:
                        scheduler.step()

                batch_loss = loss.item() * accum_steps
                train_loss += batch_loss

                # Log batch metrics
                if batch_idx % 10 == 0:
                    wandb.log({
                        'batch_loss': batch_loss,
                        'learning_rate': scheduler.get_last_lr()[0]
                    })

            train_loss /= num_batches

            # Validation
            if epoch % config['VAL_FREQ'] == 0:
                val_loss, val_acc = validate(model, val_loader, criterion, device)

                wandb.log({
                    'epoch': epoch,
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'val_accuracy': val_acc
                })

                print(f'Epoch {epoch}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}')

                # Early stopping
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'val_loss': val_loss,
                    }, f'{save_dir}/best_model.pth')
                else:
                    patience_counter += 1

                if patience_counter >= config['PATIENCE']:
                    print(f'Early stopping after {epoch} epochs')
                    break

            # Regular checkpoints
            if epoch % config['SAVE_FREQ'] == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, f'{save_dir}/checkpoint_epoch_{epoch}.pth')

    except KeyboardInterrupt:
        print('Training interrupted')

    finally:
        # Runs on normal completion, early stop, Ctrl-C and errors alike.
        run.finish()
        torch.save(model.state_dict(), f'{save_dir}/final_model.pth')

    return model

In [None]:
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Module mapping a frame batch to logits with the class axis at
            dimension 2.
        val_loader: Iterable of ``(frames, labels)`` batches supporting ``len()``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device every batch is moved to.

    Returns:
        ``(mean loss per batch, fraction of frames whose argmax matches the label)``.
    """
    model.eval()

    loss_sum = 0
    hits = 0
    seen = 0

    with torch.no_grad():
        for clip_batch, target_batch in val_loader:
            clip_batch = clip_batch.to(device)
            target_batch = target_batch.to(device)

            logits = model(clip_batch)
            loss_sum += criterion(logits, target_batch).item()

            # Compare the predicted class with the one-hot target, frame by frame.
            predicted = logits.argmax(dim=2)
            expected = target_batch.argmax(dim=2)
            hits += predicted.eq(expected).sum().item()
            seen += expected.numel()

    mean_loss = loss_sum / len(val_loader)
    accuracy = hits / seen
    return mean_loss, accuracy

In [None]:
# Datasets. The training split is the one built by the smoke-test cell above;
# the validation split is created from the same directories with the same
# (default) seed, so both splits come from one shuffle and do not overlap.
train_dataset = dataset
val_dataset = AnimeFrameDataset(
    video_dir=dataset.video_dir,
    json_dir=dataset.json_dir,
    clip_length=dataset.clip_length,
    fps=dataset.fps,
    split='val'
)

# Model, configured from CONFIG rather than from the constructor defaults.
model = AnimeSceneClassifier(
    hidden_dim=CONFIG['HIDDEN_DIM'],
    num_layers=CONFIG['NUM_LAYERS'],
    dropout=CONFIG['DROPOUT']
)

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=True,
    num_workers=CONFIG['NUM_WORKERS'],
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=False,
    num_workers=CONFIG['NUM_WORKERS'],
    pin_memory=True
)

trained_model = train_model(model, train_loader, val_loader, CONFIG)