In [1]:
import os
import sys
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset
import av
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import VivitForVideoClassification, VivitImageProcessor, get_scheduler
from torch.optim import AdamW
from tqdm import tqdm
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torch.cuda.amp import GradScaler
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch_lr_finder import LRFinder, TrainDataLoaderIter

# Create a custom iterator that handles dictionary batches
class CustomTrainDataLoaderIter(TrainDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        if batch_data is None:
            # If batch is None, return empty tensors
            return torch.tensor([]), torch.tensor([])
        # Extract pixel_values and labels from the dictionary batch
        inputs = batch_data["pixel_values"]
        labels = batch_data["labels"]
        return inputs, labels

# Load and split the data
df = pd.read_csv('E:/SRC-Bhuvaneswari/processed files/video/ftest/test_data.csv')
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])

# Save split datasets
train_df.to_csv('train_data.csv', index=False)
test_df.to_csv('test_data.csv', index=False)
TRAIN_CSV_PATH = 'train_data.csv'
TEST_CSV_PATH = 'test_data.csv'
SAVE_DIR = 'F:/SRC_Bhuvaneswari/typpo/Crimenet/VisTra/Checkpoints/v2.0'
os.makedirs(SAVE_DIR, exist_ok=True)

# Define label mapping
LABEL_MAP = {'Normal': 0, 'Explosion': 1, 'Fighting': 2, 'Car Accident': 3, 'Shooting': 4, 'Riot': 5}
INV_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
NUM_CLASSES = len(LABEL_MAP)

# Define maximum clips per label - adjusted for full frame usage
MAX_TRAIN_SAMPLES = 110
MAX_TEST_SAMPLES = 30

# Define maximum frames per label for 80-20 ratio
# For a 80-20 ratio, if we assume each label should have equal representation
MAX_FRAME = 40000  # 80% of frames per label
MAX_TEST_FRAME = 10000  # 20% of frames per label

# Improved hyperparameters
LEARNING_RATE = 8.538808556803076e-05  # Will be updated by learning rate finder
TRAIN_BATCH_SIZE = 3
EVAL_BATCH_SIZE = 3
SEED = 24
GRADIENT_ACCUMULATION_STEPS = 8
NUM_EPOCHS = 20
CLIP_LEN = 32
FRAME_SAMPLE_RATE = 1
TARGET_FRAMES = 128

# Resume training from a checkpoint
# Set to None to start training a new model from scratch
# Set to path of saved model to continue training from that point
# Can be any model checkpoint - doesn't need to follow a specific naming format
RESUME_FROM_CHECKPOINT = 'F:/SRC_Bhuvaneswari/typpo/Crimenet/VisTra/Checkpoints/v2.0/best_model_acc.pt'

# Early stopping parameters
EARLY_STOPPING_PATIENCE = 3
early_stopping_counter = 0

# Set random seed for reproducibility
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [2]:
def load_data_from_csv(csv_path, video_path_column='rgb_video_path'):
    """Load and clean dataframe from CSV."""
    df = pd.read_csv(csv_path)
    df = df.dropna(subset=[video_path_column, 'label'])
    return df

df_train = load_data_from_csv(TRAIN_CSV_PATH)
df_test = load_data_from_csv(TEST_CSV_PATH)

def balance_dataset(df, max_samples):
    """Balance dataset to have equal number of samples per class."""
    return df.groupby('label', group_keys=False).apply(
        lambda x: x.sample(min(len(x), max_samples), random_state=SEED)
    )

df_train = balance_dataset(df_train, MAX_TRAIN_SAMPLES)
df_test = balance_dataset(df_test, MAX_TEST_SAMPLES)

# Function to balance test and train sets based on frame count
def balance_test_and_train_sets(df_train, df_test):
    """Balance test set to MAX_TEST_FRAME per label and move excess to training."""
    # Group test videos by label
    test_videos_by_label = {label: [] for label in LABEL_MAP.values()}
    test_frames_by_label = {label: 0 for label in LABEL_MAP.values()}
    
    # Count frames per label in test set
    for _, row in df_test.iterrows():
        video_path = row['rgb_video_path']
        label = int(row['label'])
        
        try:
            container = av.open(video_path)
            frame_count = container.streams.video[0].frames
            container.close()
            
            test_videos_by_label[label].append((video_path, frame_count))
            test_frames_by_label[label] += frame_count
        except Exception as e:
            print(f"Error processing test video {video_path}: {e}")
    
    # Move excess videos from test to train
    videos_to_move = []
    
    for label, videos in test_videos_by_label.items():
        if test_frames_by_label[label] > MAX_TEST_FRAME:
            # Sort videos by frame count (optional - depends on your strategy)
            videos.sort(key=lambda x: x[1])
            
            current_frames = 0
            keep_idx = 0
            
            # Find how many videos to keep in test set
            for i, (_, frame_count) in enumerate(videos):
                if current_frames + frame_count <= MAX_TEST_FRAME:
                    current_frames += frame_count
                    keep_idx = i + 1
                else:
                    break
            
            # Identify videos to move to training
            videos_to_move.extend([path for path, _ in videos[keep_idx:]])
            
            print(f"Label {INV_LABEL_MAP[label]}: Moving {len(videos) - keep_idx} videos " 
                  f"({test_frames_by_label[label] - current_frames} frames) from test to train")
            
            # Update frame count
            test_frames_by_label[label] = current_frames
    
    # Create masks for videos to keep in test and move to train
    keep_mask = ~df_test['rgb_video_path'].isin(videos_to_move)
    
    # Extract videos to move
    df_to_move = df_test[~keep_mask].copy()
    
    # Update test dataframe
    df_test_balanced = df_test[keep_mask].copy()
    
    # Add moved videos to training
    df_train_updated = pd.concat([df_train, df_to_move], ignore_index=True)
    
    print("\nUpdated frame distribution:")
    for label, frames in test_frames_by_label.items():
        print(f"Label {INV_LABEL_MAP[label]}: {frames} frames in test set")
    
    return df_train_updated, df_test_balanced

# Apply the balancing function
df_train, df_test = balance_test_and_train_sets(df_train, df_test)

# Print dataset statistics
print("\nBalanced number of videos per label in the training set:")
train_label_counts = df_train['label'].value_counts().sort_index().to_dict()
for label_id, count in train_label_counts.items():
    print(f"{INV_LABEL_MAP[label_id]} ({label_id}): {count} videos")

print("\nBalanced number of videos per label in the validation set:")
val_label_counts = df_test['label'].value_counts().sort_index().to_dict()
for label_id, count in val_label_counts.items():
    print(f"{INV_LABEL_MAP[label_id]} ({label_id}): {count} videos")

  return df.groupby('label', group_keys=False).apply(
  return df.groupby('label', group_keys=False).apply(


Label Fighting: Moving 4 videos (5760 frames) from test to train
Label Riot: Moving 13 videos (18300 frames) from test to train

Updated frame distribution:
Label Normal: 9384 frames in test set
Label Explosion: 8604 frames in test set
Label Fighting: 9841 frames in test set
Label Car Accident: 3861 frames in test set
Label Shooting: 4422 frames in test set
Label Riot: 8920 frames in test set

Balanced number of videos per label in the training set:
Normal (0): 110 videos
Explosion (1): 106 videos
Fighting (2): 114 videos
Car Accident (3): 110 videos
Shooting (4): 110 videos
Riot (5): 123 videos

Balanced number of videos per label in the validation set:
Normal (0): 30 videos
Explosion (1): 27 videos
Fighting (2): 26 videos
Car Accident (3): 30 videos
Shooting (4): 30 videos
Riot (5): 17 videos


In [3]:
# Analyze frame distribution
def analyze_frame_distribution():
    """Analyze frame distribution across labels to understand dataset imbalance."""
    label_frame_counts = {}
    label_video_counts = {}
    problematic_videos = []
    
    for _, row in df_train.iterrows():
        video_path = row['rgb_video_path']
        label = row['label']
        
        try:
            container = av.open(video_path)
            total_frames = container.streams.video[0].frames
            container.close()
            
            if label not in label_frame_counts:
                label_frame_counts[label] = 0
                label_video_counts[label] = 0
                
            label_frame_counts[label] += total_frames
            label_video_counts[label] += 1
            
        except Exception as e:
            problematic_videos.append((video_path, str(e)))
            print(f"Error processing video {video_path}: {e}")
    
    # Print results
    print("Total frames per label in training set:")
    for label, count in label_frame_counts.items():
        label_name = INV_LABEL_MAP[label] if label in INV_LABEL_MAP else label
        avg_frames = count / label_video_counts[label] if label_video_counts[label] > 0 else 0
        print(f"{label_name}: {count} frames across {label_video_counts[label]} videos (avg: {avg_frames:.2f} frames/video)")
    
    results_df = pd.DataFrame({
        'Label': [INV_LABEL_MAP[label] if label in INV_LABEL_MAP else label for label in label_frame_counts.keys()],
        'Total Frames': label_frame_counts.values(),
        'Video Count': [label_video_counts[label] for label in label_frame_counts.keys()],
        'Avg Frames/Video': [label_frame_counts[label]/label_video_counts[label] 
                             if label_video_counts[label] > 0 else 0 
                             for label in label_frame_counts.keys()]
    })
    
    print("\nDataFrame of training results:")
    print(results_df)
    
    # Now analyze test set
    label_frame_counts_test = {}
    label_video_counts_test = {}
    
    for _, row in df_test.iterrows():
        video_path = row['rgb_video_path']
        label = row['label']
        
        try:
            container = av.open(video_path)
            total_frames = container.streams.video[0].frames
            container.close()
            
            if label not in label_frame_counts_test:
                label_frame_counts_test[label] = 0
                label_video_counts_test[label] = 0
                
            label_frame_counts_test[label] += total_frames
            label_video_counts_test[label] += 1
            
        except Exception as e:
            print(f"Error processing test video {video_path}: {e}")
    
    print("\nTotal frames per label in test set:")
    for label, count in label_frame_counts_test.items():
        label_name = INV_LABEL_MAP[label] if label in INV_LABEL_MAP else label
        avg_frames = count / label_video_counts_test[label] if label_video_counts_test[label] > 0 else 0
        print(f"{label_name}: {count} frames across {label_video_counts_test[label]} videos (avg: {avg_frames:.2f} frames/video)")
    
    results_df_test = pd.DataFrame({
        'Label': [INV_LABEL_MAP[label] if label in INV_LABEL_MAP else label for label in label_frame_counts_test.keys()],
        'Total Frames': label_frame_counts_test.values(),
        'Video Count': [label_video_counts_test[label] for label in label_frame_counts_test.keys()],
        'Avg Frames/Video': [label_frame_counts_test[label]/label_video_counts_test[label] 
                             if label_video_counts_test[label] > 0 else 0 
                             for label in label_frame_counts_test.keys()]
    })
    
    print("\nDataFrame of test results:")
    print(results_df_test)
    
    # Calculate total frames and ratio
    total_train_frames = sum(label_frame_counts.values())
    total_test_frames = sum(label_frame_counts_test.values())
    total_frames = total_train_frames + total_test_frames
    
    print(f"\nTotal frames - Train: {total_train_frames}, Test: {total_test_frames}")
    print(f"Train-Test ratio: {total_train_frames/total_frames:.2f}:{total_test_frames/total_frames:.2f}")
    
    if problematic_videos:
        print(f"\nFound {len(problematic_videos)} problematic videos. Check logs for details.")
        with open("problematic_videos.log", "w") as f:
            for path, error in problematic_videos:
                f.write(f"{path}: {error}\n")
    
    return label_frame_counts, label_video_counts

# Run frame distribution analysis
label_frame_counts, label_video_counts = analyze_frame_distribution()

Total frames per label in training set:
Normal: 28988 frames across 110 videos (avg: 263.53 frames/video)
Explosion: 29586 frames across 106 videos (avg: 279.11 frames/video)
Fighting: 57256 frames across 114 videos (avg: 502.25 frames/video)
Car Accident: 15132 frames across 110 videos (avg: 137.56 frames/video)
Shooting: 22851 frames across 110 videos (avg: 207.74 frames/video)
Riot: 126027 frames across 123 videos (avg: 1024.61 frames/video)

DataFrame of training results:
          Label  Total Frames  Video Count  Avg Frames/Video
0        Normal         28988          110        263.527273
1     Explosion         29586          106        279.113208
2      Fighting         57256          114        502.245614
3  Car Accident         15132          110        137.563636
4      Shooting         22851          110        207.736364
5          Riot        126027          123       1024.609756

Total frames per label in test set:
Normal: 9384 frames across 30 videos (avg: 312.80 frame

In [4]:
# Define class-specific augmentation to target problematic classes
def get_class_specific_augment(label, is_training=True, strong_augment=False):
    """Apply different augmentation strategies based on class and confusion patterns."""
    if not is_training:
        return T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor()
        ])
        
    # Stronger augmentation for underrepresented or problematic classes
    if strong_augment:
        # Base strong augmentation
        transform = T.Compose([
            T.RandomResizedCrop(224, scale=(0.7, 1.0)),
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            T.RandomRotation(30),
            T.RandomHorizontalFlip(p=0.7),
            T.RandomVerticalFlip(p=0.3),
            T.RandomGrayscale(p=0.2),
            T.ToTensor()
        ])
    else:
        # Addressing Normal/Car Accident confusion
        if label in [0, 3]:  # Normal or Car Accident
            transform = T.Compose([
                T.RandomResizedCrop(224, scale=(0.65, 0.95)),  # More aggressive cropping
                T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
                T.RandomRotation(45),  # More rotation
                T.RandomHorizontalFlip(p=0.8),
                T.RandomVerticalFlip(p=0.4),
                T.RandomGrayscale(p=0.3),
                T.ToTensor()
            ])
        # Addressing Fighting/Normal confusion
        elif label == 2:  # Fighting
            transform = T.Compose([
                T.RandomResizedCrop(224, scale=(0.7, 1.0)),
                T.ColorJitter(brightness=0.4, contrast=0.6, saturation=0.4),
                T.RandomRotation(35),
                T.RandomHorizontalFlip(p=0.9),  # Higher flip probability
                T.RandomPerspective(distortion_scale=0.3, p=0.5),  # Add perspective transform
                T.ToTensor()
            ])
        # Addressing Shooting class performance decline
        elif label == 4:  # Shooting
            transform = T.Compose([
                T.RandomResizedCrop(224, scale=(0.8, 1.0)),
                T.ColorJitter(brightness=0.3, contrast=0.5, saturation=0.3),
                T.RandomRotation(20),
                T.RandomHorizontalFlip(p=0.6),
                T.RandomGrayscale(p=0.1),  # Less grayscale for shooting class
                T.ToTensor()
            ])
        # Default augmentation
        else:
            transform = T.Compose([
                T.RandomResizedCrop(224, scale=(0.8, 1.0)),
                T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
                T.RandomRotation(15),
                T.RandomHorizontalFlip(p=0.5),
                T.ToTensor()
            ])
            
    return transform

# Enhanced data augmentation with class-specific strategy
def read_video_pyav(container, indices, is_training=True, strong_augment=False, label=None):
    """Read video frames with appropriate augmentation based on class."""
    transform = get_class_specific_augment(label, is_training, strong_augment)

    frames = []
    container.seek(0)
    try:
        for i, frame in enumerate(container.decode(video=0)):
            if i in indices:
                frame = transform(frame.to_image())
                frames.append(frame)
            if len(frames) == len(indices):
                break
    except Exception as e:
        print(f"Error during frame decoding: {e}")
        return None

    if not frames:
        return None
        
    return torch.stack(frames)

def split_into_chunks(total_frames, clip_len, overlap=0):
    """Split video into overlapping chunks."""
    step = clip_len - overlap
    chunks = [(start, start + clip_len) for start in range(0, total_frames, step) if start + clip_len <= total_frames]
    return chunks

In [5]:
class VideoDataset(Dataset):
    def __init__(self, dataframe, processor, clip_len=32, frame_sample_rate=1, 
                 target_frames=128, overlap=0, is_training=True, video_path_column='rgb_video_path',
                 max_frames_per_label=MAX_FRAME):
        self.dataframe = dataframe
        self.processor = processor
        self.clip_len = clip_len
        self.frame_sample_rate = frame_sample_rate
        self.target_frames = target_frames
        self.overlap = overlap
        self.is_training = is_training
        self.video_path_column = video_path_column
        self.max_frames_per_label = max_frames_per_label
        self.problematic_files = []  # Initialize before calling _prepare_data
        self.data = self._prepare_data()
        
    def _prepare_data(self):
        """Prepare dataset by processing videos into clips."""
        prepared_data = []
        short_clips = []
        
        for _, row in self.dataframe.iterrows():
            video_path = row[self.video_path_column]
            label = int(row['label'])
            try:
                container = av.open(video_path)
                total_frames = container.streams.video[0].frames
                container.close()
                
                if total_frames >= self.clip_len * self.frame_sample_rate:
                    chunks = split_into_chunks(total_frames, self.clip_len * self.frame_sample_rate, self.overlap)
                    for start, end in chunks:
                        prepared_data.append({
                            "video_path": video_path, 
                            "label": label, 
                            "start": start, 
                            "end": end,
                            "combined": False
                        })
                else:
                    short_clips.append({
                        "video_path": video_path,
                        "label": label,
                        "frames": total_frames,
                        "combined": True
                    })
            except Exception as e:
                self.problematic_files.append((video_path, str(e)))
                print(f"Error processing video {video_path}: {e}")
        
        # Log problematic files
        if self.problematic_files:
            log_path = "problematic_files_dataset.log"
            with open(log_path, "w") as f:
                for path, error in self.problematic_files:
                    f.write(f"{path}: {error}\n")
            print(f"Logged {len(self.problematic_files)} problematic files to {log_path}")
        
        combined_clips = self._combine_short_clips(short_clips)
        
        # Print combined clips stats before adding them to prepared_data
        print(f"Combined clips processing summary: {len(combined_clips)} clips created from {len(short_clips)} short clips, {len(short_clips) - sum([len(clip['video_paths']) for clip in combined_clips])} clips discarded.")
        
        prepared_data.extend(combined_clips)
        
        # Balance clips by frame count
        if self.is_training:
            prepared_data = self._balance_clips_by_label(prepared_data)
        
        print(f"\nTotal clips: {len(prepared_data)}, including {len(combined_clips)} combined clips")

        return prepared_data
    
    def _balance_clips_by_label(self, clips):
        """Balance clips to ensure each label has max_frames_per_label frames."""
        label_groups = {}
        label_frame_counts = {}
    
        # Calculate frame counts per clip and per label
        for clip in clips:
            label = clip["label"]
            if label not in label_groups:
                label_groups[label] = []
                label_frame_counts[label] = 0
            
            # Calculate actual frame count for this clip
            if not clip.get("combined", False):
                frame_count = (clip["end"] - clip["start"]) // self.frame_sample_rate
            else:
                frame_count = sum(clip.get("frames_per_clip", [0])) // self.frame_sample_rate
            
            clip['frame_count'] = frame_count  # Store frame count directly in clip
            label_groups[label].append(clip)
            label_frame_counts[label] += frame_count
    
        balanced_clips = []
    
        for label, label_clips in label_groups.items():
            current_frames = label_frame_counts[label]
            
            if current_frames > self.max_frames_per_label:
                # Downsample strategy: keep largest clips first
                sorted_clips = sorted(label_clips, key=lambda x: x['frame_count'], reverse=True)
                accumulated = 0
                selected_clips = []
                
                for clip in sorted_clips:
                    if accumulated + clip['frame_count'] <= self.max_frames_per_label:
                        selected_clips.append(clip)
                        accumulated += clip['frame_count']
                    elif accumulated == 0:  # Handle case where single clip exceeds max
                        selected_clips.append(clip)
                        accumulated += clip['frame_count']
                        break
                
                balanced_clips.extend(selected_clips)
                print(f"Label {INV_LABEL_MAP[label]} ({label}): "
                      f"Reduced from {current_frames} to {accumulated} frames")
    
            else:
                # Upsample strategy: augment until reaching max_frames_per_label
                balanced_clips.extend(label_clips)
                accumulated = current_frames
                deficit = self.max_frames_per_label - accumulated
                
                if deficit > 0:
                    original_clips = label_clips.copy()
                    while deficit > 0:
                        for clip in original_clips:
                            if deficit <= 0:
                                break
                                
                            augmented_clip = clip.copy()
                            augmented_clip["augment_strongly"] = True
                            balanced_clips.append(augmented_clip)
                            deficit -= clip['frame_count']
                            accumulated += clip['frame_count']
    
                    print(f"Label {INV_LABEL_MAP[label]} ({label}): "
                          f"Increased from {current_frames} to {accumulated - deficit} frames "
                          f"(added {len(balanced_clips) - len(label_clips)} augmented clips)")
        
        # Verification step to confirm balancing
        final_label_frames = {}
        for clip in balanced_clips:
            label = clip["label"]
            if label not in final_label_frames:
                final_label_frames[label] = 0
            final_label_frames[label] += clip['frame_count']
        
        print("\nFinal frame count per label after balancing:")
        for label, count in final_label_frames.items():
            print(f"Label {INV_LABEL_MAP[label]} ({label}): {count} frames")
    
        return balanced_clips

    
    def _combine_short_clips(self, short_clips):
        """Combine short clips to reach minimum length."""
        label_groups = {}
        for clip in short_clips:
            label = clip["label"]
            if label not in label_groups:
                label_groups[label] = []
            label_groups[label].append(clip)
        
        combined_data = []
        for label, clips in label_groups.items():
            current_clips = []
            current_frames = 0
            
            for clip in clips:
                if current_frames + clip["frames"] <= self.target_frames:
                    current_clips.append(clip)
                    current_frames += clip["frames"]
                    
                    if current_frames >= self.clip_len * self.frame_sample_rate:
                        combined_data.append({
                            "video_paths": [c["video_path"] for c in current_clips],
                            "label": label,
                            "frames_per_clip": [c["frames"] for c in current_clips],
                            "combined": True
                        })
                        current_clips = []
                        current_frames = 0
            
            if current_clips and current_frames >= self.clip_len * self.frame_sample_rate:
                combined_data.append({
                    "video_paths": [c["video_path"] for c in current_clips],
                    "label": label,
                    "frames_per_clip": [c["frames"] for c in current_clips],
                    "combined": True
                })
        
        return combined_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        strong_augment = item.get("augment_strongly", False)
        label = item["label"]
        
        if not item.get("combined", False):
            video_path, start, end = item["video_path"], item["start"], item["end"]
            try:
                container = av.open(video_path)
                indices = list(range(start, end, self.frame_sample_rate))
                if len(indices) != self.clip_len:
                    container.close()
                    return None
                    
                video = read_video_pyav(container, indices, self.is_training, strong_augment, label)
                container.close()
                
                if video is None:
                    return None
                
                # Check if frames are valid
                if torch.isnan(video).any() or torch.isinf(video).any():
                    print(f"Invalid frame values detected in {video_path}")
                    return None
                
                # Add temporal dropout for additional augmentation when training
                if self.is_training and np.random.random() < 0.3:
                    # Randomly drop up to 10% of frames and repeat adjacent frames
                    num_frames_to_drop = max(1, int(0.1 * video.shape[0]))
                    frames_to_drop = np.random.choice(video.shape[0], num_frames_to_drop, replace=False)
                    for idx in frames_to_drop:
                        # Replace with previous or next frame
                        replace_idx = max(0, idx-1) if idx > 0 else min(idx+1, video.shape[0]-1)
                        video[idx] = video[replace_idx]
                
                inputs = self.processor(list(video.permute(0, 2, 3, 1).numpy()), return_tensors="pt")
                return {"pixel_values": inputs["pixel_values"].squeeze(0), "labels": torch.tensor(label, dtype=torch.long)}
            except Exception as e:
                print(f"Error loading video chunk from {video_path}: {e}")
                self.problematic_files.append((video_path, str(e)))
                return None
        
        else:
            video_paths = item["video_paths"]
            frames_per_clip = item["frames_per_clip"]
            
            all_frames = []
            current_frame_count = 0
            
            for i, video_path in enumerate(video_paths):
                try:
                    container = av.open(video_path)
                    frames_needed = min(frames_per_clip[i], self.clip_len * self.frame_sample_rate - current_frame_count)
                    indices = list(range(0, frames_needed))
                    video = read_video_pyav(container, indices, self.is_training, strong_augment, label)
                    container.close()
                    
                    if video is None:
                        continue
                    
                    # Check if frames are valid
                    if torch.isnan(video).any() or torch.isinf(video).any():
                        print(f"Invalid frame values detected in {video_path}")
                        continue
                        
                    all_frames.append(video)
                    current_frame_count += frames_needed
                    
                    if current_frame_count >= self.clip_len * self.frame_sample_rate:
                        break
                except Exception as e:
                    print(f"Error loading combined video from {video_path}: {e}")
                    self.problematic_files.append((video_path, str(e)))
            
            if not all_frames or current_frame_count < self.clip_len * self.frame_sample_rate:
                return None
                
            combined_video = torch.cat(all_frames, dim=0)
            
            if combined_video.shape[0] > self.clip_len:
                combined_video = combined_video[:self.clip_len]
                
            inputs = self.processor(list(combined_video.permute(0, 2, 3, 1).numpy()), return_tensors="pt")
            return {"pixel_values": inputs["pixel_values"].squeeze(0), "labels": torch.tensor(label, dtype=torch.long)}

def collate_fn(batch):
    """Custom collate function that handles None values in batch."""
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    return {"pixel_values": pixel_values, "labels": labels}

In [6]:
# Initialize the processor with do_rescale=None and offset=None as requested
processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2", do_rescale=None, offset=None)
model = VivitForVideoClassification.from_pretrained(
    "google/vivit-b-16x2",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1
)

# Load model from checkpoint if specified
if RESUME_FROM_CHECKPOINT:
    print(f"Loading model from checkpoint: {RESUME_FROM_CHECKPOINT}")
    model.load_state_dict(torch.load(RESUME_FROM_CHECKPOINT))
    print(f"Successfully loaded model from: {RESUME_FROM_CHECKPOINT}")
    print("Starting additional training from this checkpoint...")

# Initialize datasets and dataloaders with different max frames for train and test
train_dataset = VideoDataset(df_train, processor, CLIP_LEN, FRAME_SAMPLE_RATE, 
                            target_frames=TARGET_FRAMES, overlap=16, is_training=True,
                            max_frames_per_label=MAX_FRAME)
test_dataset = VideoDataset(df_test, processor, CLIP_LEN, FRAME_SAMPLE_RATE, 
                           target_frames=TARGET_FRAMES, overlap=16, is_training=False,
                           max_frames_per_label=MAX_TEST_FRAME)

# Separate parameters with and without weight decay
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [p for n, p in model.named_parameters() 
                  if not any(nd in n for nd in no_decay)],
        'weight_decay': 1e-6,
    },
    {
        'params': [p for n, p in model.named_parameters() 
                  if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Adjust class weights based on confusion matrix insights
def get_adjusted_class_weights(class_counts):
    weights = 1.0 / class_counts
    
    # Increase weight for frequently misclassified classes
    # Fighting (2) and Car Accident (3) need more attention
    weights[2] *= 1.3  # Fighting
    weights[3] *= 1.5  # Car Accident
    weights[4] *= 1.2  # Shooting (declining performance)
    
    return torch.tensor(weights, dtype=torch.float)

train_labels = df_train["label"].values
class_counts = np.bincount(train_labels)
class_weights = get_adjusted_class_weights(class_counts).to(device)

# Define focal loss to focus more on hard examples
class FocalLoss(nn.Module):
    def __init__(self, weight=None, gamma=2.0, reduction="mean", label_smoothing=0.0):
        super(FocalLoss, self).__init__()
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.ce = nn.CrossEntropyLoss(weight=weight, reduction="none", label_smoothing=label_smoothing)
        
    def forward(self, inputs, targets):
        # Extract logits if inputs is a model output object (e.g., ImageClassifierOutput)
        if hasattr(inputs, 'logits'):
            inputs = inputs.logits
            
        ce_loss = self.ce(inputs, targets)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        if self.reduction == "mean":
            return focal_loss.mean()
        elif self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss


# Use focal loss with label smoothing
loss_fn = FocalLoss(weight=class_weights, gamma=2.0, label_smoothing=0.1)

# Create optimizer with initial learning rate (will be updated by LR finder)
optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-8)

# Create the GradScaler for mixed precision training
scaler = GradScaler()

# Create train dataloader for LR finder
train_dataloader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, 
                             shuffle=True, collate_fn=collate_fn, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=EVAL_BATCH_SIZE, 
                            shuffle=False, collate_fn=collate_fn, pin_memory=True)

# Only run the learning rate finder if we're not resuming from a checkpoint
if not RESUME_FROM_CHECKPOINT:
    # Set up mixed precision config for torch_lr_finder
    amp_config = {
        'device_type': 'cuda',
        'dtype': torch.float16,
    }

    # Create and run the LR finder with custom iterator
    print("Running learning rate finder...")
    train_data_iter = CustomTrainDataLoaderIter(train_dataloader)
    
    lr_finder = LRFinder(
        model, 
        optimizer, 
        loss_fn, 
        device=device, 
        amp_backend='torch',
        amp_config=amp_config, 
        grad_scaler=scaler
    )

    # Run range test with gradient accumulation
    lr_finder.range_test(
        train_data_iter,  # Use custom iterator instead of dataloader directly
        end_lr=1e-2, 
        num_iter=100, 
        step_mode="exp", 
        accumulation_steps=GRADIENT_ACCUMULATION_STEPS
    )

    # Fixed code
    fig, suggested_lr = lr_finder.plot(suggest_lr=True)  # This returns both the plot and the suggested learning rate
    print(f"Suggested learning rate: {suggested_lr}")

    # Update the learning rate
    LEARNING_RATE = suggested_lr if suggested_lr is not None else LEARNING_RATE
    print(f"Using learning rate: {LEARNING_RATE}")

    # Reset the model and optimizer
    lr_finder.reset()

    # Recreate optimizer with the found learning rate
    optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-8)
else:
    print(f"Skipping learning rate finder since we're resuming from a checkpoint.")
    print(f"Using learning rate: {LEARNING_RATE}")

# Setup learning rate scheduler
num_training_steps = NUM_EPOCHS * len(train_dataloader)
num_warmup_steps = int(num_training_steps * 0.1)
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

# Initialize training history
history = {
    "train_loss": [],
    "train_accuracy": [],
    "val_loss": [],
    "val_accuracy": []
}

best_val_accuracy = 0
early_stopping_counter = 0

Some weights of VivitForVideoClassification were not initialized from the model checkpoint at google/vivit-b-16x2 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([6, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([6]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading model from checkpoint: F:/SRC_Bhuvaneswari/typpo/Crimenet/VisTra/Checkpoints/v2.0/best_model_acc.pt
Successfully loaded model from: F:/SRC_Bhuvaneswari/typpo/Crimenet/VisTra/Checkpoints/v2.0/best_model_acc.pt
Starting additional training from this checkpoint...
Combined clips processing summary: 19 clips created from 42 short clips, 3 clips discarded.
Label Normal (0): Reduced from 53041 to 39985 frames
Label Explosion (1): Reduced from 54426 to 39994 frames
Label Fighting (2): Reduced from 109427 to 39987 frames
Label Car Accident (3): Increased from 25293 to 40026 frames (added 4205 augmented clips)
Label Shooting (4): Reduced from 40820 to 39988 frames
Label Riot (5): Reduced from 247072 to 40000 frames

Final frame count per label after balancing:
Label Normal (0): 39985 frames
Label Explosion (1): 39994 frames
Label Fighting (2): 39987 frames
Label Car Accident (3): 40013 frames
Label Shooting (4): 39988 frames
Label Riot (5): 40000 frames

Total clips: 7489, including 19 

  scaler = GradScaler()


In [None]:
# Start training from epoch 0 regardless of checkpoint source
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    correct_predictions = 0
    total_predictions = 0

    epoch_progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}", leave=True)
    for step, batch in enumerate(epoch_progress_bar):
        if batch is None:
            continue

        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.amp.autocast('cuda'):
            outputs = model(pixel_values=batch["pixel_values"])
            raw_loss = loss_fn(outputs.logits, batch["labels"])
            loss_for_backward = raw_loss / GRADIENT_ACCUMULATION_STEPS

        scaler.scale(loss_for_backward).backward()

        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            # Unscale gradients first
            scaler.unscale_(optimizer)
    
            # Add gradient clipping with smaller max_norm for better stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            lr_scheduler.step()

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        correct_predictions += (predictions == batch["labels"]).sum().item()
        total_predictions += batch["labels"].size(0)

        epoch_loss += loss_for_backward.item() * GRADIENT_ACCUMULATION_STEPS
        epoch_progress_bar.set_postfix({
            "Loss": f"{loss_for_backward.item() * GRADIENT_ACCUMULATION_STEPS:.8f}",
            "Accuracy": f"{(correct_predictions / total_predictions) * 100:.2f}%"
        })

    avg_epoch_loss = epoch_loss / len(train_dataloader)
    train_accuracy = correct_predictions / total_predictions
    history["train_loss"].append(avg_epoch_loss)
    history["train_accuracy"].append(train_accuracy)

    # Validation phase
    model.eval()
    val_loss = 0
    val_correct_predictions = 0
    val_total_predictions = 0
    all_predictions = []
    all_labels = []
    
    # Create validation progress bar similar to training
    val_progress_bar = tqdm(test_dataloader, desc="Validation", leave=True)
    with torch.no_grad():
        for batch in val_progress_bar:
            if batch is None:
                continue
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(pixel_values=batch["pixel_values"], interpolate_pos_encoding=True)
            
            # Calculate batch loss
            batch_loss = loss_fn(outputs.logits, batch["labels"]).item()
            val_loss += batch_loss

            predictions = torch.argmax(outputs.logits, dim=-1)
            batch_correct = (predictions == batch["labels"]).sum().item()
            batch_total = batch["labels"].size(0)
            
            val_correct_predictions += batch_correct
            val_total_predictions += batch_total

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(batch["labels"].cpu().numpy())
            
            # Update progress bar with current batch accuracy and loss
            current_accuracy = (val_correct_predictions / val_total_predictions) * 100
            val_progress_bar.set_postfix({
                "Loss": f"{batch_loss:.8f}",
                "Accuracy": f"{current_accuracy:.2f}%"
            })

    avg_val_loss = val_loss / len(test_dataloader)
    val_accuracy = val_correct_predictions / val_total_predictions
    history["val_loss"].append(avg_val_loss)
    history["val_accuracy"].append(val_accuracy)

    # Save best model based on accuracy
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), os.path.join(SAVE_DIR, "best_model_acc.pt"))
        print(f"New best model saved with validation accuracy: {val_accuracy * 100:.2f}%")
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1

    # Save model at each epoch
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, f"vivit_epoch_{epoch + 1}.pt"))

    print(f"Epoch {epoch + 1} completed. Train Loss: {avg_epoch_loss:.8f}, "
          f"Train Accuracy: {train_accuracy * 100:.2f}%, "
          f"Val Loss: {avg_val_loss:.8f}, "
          f"Val Accuracy: {val_accuracy * 100:.2f}%")

    # Create confusion matrix every epoch with larger figure size for better readability
    cm = confusion_matrix(all_labels, all_predictions)
    plt.figure(figsize=(20, 10))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(LABEL_MAP.keys()))
    disp.plot(cmap='Blues')
    plt.title(f"Confusion Matrix - Epoch {epoch + 1}")
    plt.savefig(os.path.join(SAVE_DIR, f'confusion_matrix_epoch_{epoch + 1}.png'))
    plt.close()
    
    # Implement early stopping
    if early_stopping_counter >= EARLY_STOPPING_PATIENCE:
        print(f"Early stopping triggered after {epoch + 1} epochs")
        break

Epoch 1/20:  10%|███▏                           | 261/2497 [20:12<2:58:01,  4.78s/it, Loss=0.00000021, Accuracy=94.25%]

In [None]:
# Plot training and validation metrics
plt.figure(figsize=(20, 10))

plt.subplot(2, 2, 1)
plt.plot(history["train_accuracy"], label="Train Accuracy")
plt.plot(history["val_accuracy"], label="Validation Accuracy")
plt.ylabel("Accuracy")
plt.xlabel("Epoch")
plt.legend(loc="lower right")
plt.title("Training and Validation Accuracy")

In [None]:
plt.subplot(2, 2, 2)
plt.plot(history["train_loss"], label="Train Loss")
plt.plot(history["val_loss"], label="Validation Loss")
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.legend(loc="upper right")
plt.title("Training and Validation Loss")

In [None]:
plt.tight_layout()
plt.savefig(os.path.join(SAVE_DIR, 'training_metrics.png'))
plt.close()

In [None]:
print(f"Training completed.")