In [None]:
# List the datasets available under /kaggle/input (the dataset paths used below live there).
!ls /kaggle/input

In [None]:
# Step 1: Clone the repository
!git clone https://github.com/rishikksh20/ViViT-pytorch.git 
# Step 2: Change directory to the cloned repo
%cd ViViT-pytorch

# Step 4: Run scripts or modify files
!python module.py vivit.py  # Example script to run


## Setup: libraries and device

In [None]:
import os, sys
import natsort # For number sorting
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import natsort

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch import einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from module import Attention, PreNorm, FeedForward
import cv2
from google.colab.patches import cv2_imshow # colab에서 cv2.imshow 사용 불가
sys.path

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda_ok else "cpu")
print(device)
print(cuda_ok)  # True when a GPU is enabled for this session


In [None]:
# CUDA caching-allocator option.
# NOTE(review): PyTorch reads this variable when the CUDA allocator is first
# initialised, so it may need to be set before any CUDA work — confirm it takes
# effect when set at this point of the notebook.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

gpu_count = torch.cuda.device_count()
print(gpu_count)  # number of visible GPUs
for gpu_id in range(gpu_count):
    print(f"GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")


### Prétraitement du DataLoader pour le dataset Surveillance Camera

In [None]:
class DashcamDataset(Dataset):
    """Sliding-window frame-prediction dataset over per-video frame folders.

    Expected layout: ``root_dir/<video>/frame_<N>.<ext>``. Each sample holds
    ``sequence_length`` consecutive input frames plus the frame that follows
    them (the prediction target), all taken from the same video.

    Args:
        root_dir: directory containing one sub-directory per video.
        sequence_length: number of input frames per sample.
        transform: callable applied to each RGB uint8 HWC numpy image
            (e.g. ``ToTensor`` + ``Normalize``). When falsy, frames are
            converted to float CHW tensors in [0, 1] so they can be stacked.
        overlap: if True, windows advance by one frame; otherwise windows are
            disjoint (stride ``sequence_length + 1``).
        image_size: side length frames are resized to (default 256).
    """

    def __init__(self, root_dir, sequence_length=4, transform=None, overlap=True, image_size=256):
        self.root_dir = root_dir
        self.sequence_length = sequence_length
        self.transform = transform
        self.overlap = overlap
        self.image_size = image_size
        self.sequences = []

        window = sequence_length + 1  # input frames + target frame
        stride = 1 if overlap else window

        # Each video is windowed separately so no sample spans two videos.
        for video_folder in sorted(os.listdir(root_dir)):
            video_path = os.path.join(root_dir, video_folder)
            if not os.path.isdir(video_path):
                continue
            # Sort numerically (frame_10 after frame_9), not lexicographically.
            frames = sorted(
                (f for f in os.listdir(video_path) if f.startswith('frame_')),
                key=lambda x: int(x.split('_')[1].split('.')[0]))
            frames = [os.path.join(video_path, f) for f in frames]

            for i in range(0, len(frames) - sequence_length, stride):
                self.sequences.append(frames[i:i + window])

    def __len__(self):
        return len(self.sequences)

    def _load_frame(self, img_path):
        """Read one frame as RGB, resize it, and turn it into a tensor."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals unreadable/missing files by returning None.
            raise ValueError(f"Impossible de charger l'image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.image_size, self.image_size))
        if self.transform:
            return self.transform(img)
        # Without a transform torch.stack would fail on numpy arrays; mirror
        # ToTensor: HWC uint8 -> CHW float in [0, 1].
        return torch.from_numpy(img).permute(2, 0, 1).float().div(255.0)

    def __getitem__(self, idx):
        """Return ``(input_frames, target_frame, target_name)``.

        ``input_frames`` stacks the first ``sequence_length`` frames along a new
        leading time axis; ``target_frame`` is the following frame and
        ``target_name`` its file name.
        """
        frame_paths = self.sequences[idx]
        images = [self._load_frame(p) for p in frame_paths]

        input_frames = torch.stack(images[:self.sequence_length], dim=0)
        target_frame = images[self.sequence_length]
        target_name = os.path.basename(frame_paths[-1])
        return input_frames, target_frame, target_name

In [None]:
# Définition des transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalisation [-1, 1]
])

# Création du Dataset et du DataLoader avec la nouvelle classe améliorée
train_dataset = DashcamDataset(
    root_dir="/kaggle/input/surveillance-camera-fightnofight/Surveillance Camera Fight Dataset/Train",
    sequence_length=4,
    transform=transform,
    overlap=True
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=False,  # Garder False pour l'entraînement pour maintenir l'ordre temporel
    num_workers=4,
    pin_memory=True,
)

test_dataset = DashcamDataset(
    root_dir="/kaggle/input/surveillance-camera-fightnofight/Surveillance Camera Fight Dataset/Test",
    sequence_length=4,
    transform=transform,
    overlap=True  # Paramètre valide pour DashcamDataset
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,  # Désactivé pour maintenir l'ordre temporel
    num_workers=4,
    pin_memory=True,
)



import matplotlib.pyplot as plt
import numpy as np

def display_sequence_with_target(input_frames, target_frame, sequence_idx, target_name):
    """
    Plot a sequence's input frames side by side with its target frame.

    Args:
        input_frames: (T, C, H, W) tensor, normalised to [-1, 1] by the transform above.
        target_frame: (C, H, W) tensor, normalised to [-1, 1].
        sequence_idx: 0-based index of the sequence (shown 1-based in the title).
        target_name: target file name such as ``frame_005.jpg``; the input frame
            numbers are taken to be the T numbers immediately preceding it.
    """
    # CHW -> HWC for imshow, then undo the [-1, 1] normalisation.
    inputs_hwc = (input_frames.cpu().numpy().transpose(0, 2, 3, 1) + 1) / 2
    target_hwc = (target_frame.cpu().numpy().transpose(1, 2, 0) + 1) / 2
    n_inputs = len(inputs_hwc)

    target_num = target_name.split('_')[1].split('.')[0]
    first_num = int(target_num) - n_inputs
    input_nums = [str(first_num + k) for k in range(n_inputs)]

    fig, axes = plt.subplots(1, n_inputs + 1, figsize=(15, 5), squeeze=False)
    fig.suptitle(
        f"Sequence {sequence_idx + 1}\n"
        f"Input frames ({n_inputs}): {', '.join(input_nums)}\n"
        f"Target frame: {target_num}",
        fontsize=12, y=1.05
    )

    # Input panels first, the target last.
    panels = [(img, f"Frame {num}") for img, num in zip(inputs_hwc, input_nums)]
    panels.append((target_hwc, f"Target: {target_num}"))
    for ax, (img, title) in zip(axes[0], panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Paramètre configurable - CHANGEZ ICI LE NOMBRE DE BATCHS À AFFICHER
NUM_BATCHES_TO_DISPLAY = 3  # number of batches to visualise — change as needed

rule = "=" * 40
for batch_idx, (input_frames, target_frames, target_names) in enumerate(train_dataloader):
    if batch_idx >= NUM_BATCHES_TO_DISPLAY:
        break

    print("\n" + rule)
    print(f"=== BATCH {batch_idx + 1}/{NUM_BATCHES_TO_DISPLAY} ===")
    print(rule)
    print(f"Nombre total de séquences dans ce batch: {input_frames.shape[0]}")

    batch_items = zip(input_frames, target_frames, target_names)
    for seq_idx, (current_input, current_target, current_target_name) in enumerate(batch_items):
        # Console summary: zero-padded frame numbers of the inputs and the target.
        target_num = current_target_name.split('_')[1].split('.')[0].zfill(3)
        start = int(target_num) - len(current_input)
        input_nums = [str(start + k).zfill(3) for k in range(len(current_input))]

        print(f"\nSéquence {seq_idx + 1}:")
        print(f"• Frames d'entrée ({len(input_nums)}): {', '.join(input_nums)}")
        print(f"• Frame cible: {target_num}")

        # Figure with the inputs and the target.
        display_sequence_with_target(current_input, current_target, seq_idx, current_target_name)

### Prétraitement du DataLoader pour Avenue

In [None]:
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class AvenueDataset(Dataset):
    """Sliding-window frame-prediction dataset for the Avenue layout.

    Expected layout: ``root_dir/<sequence>/<N>.jpg`` with numeric file names.
    Each sample is ``sequence_length`` consecutive input frames plus the next
    frame as the target, never crossing a sequence folder.

    Args:
        root_dir: directory containing one sub-directory per sequence.
        sequence_length: number of input frames per sample.
        transform: callable applied to each resized RGB numpy image.
        overlap: if True, windows advance by one frame; otherwise they are
            disjoint (stride ``sequence_length + 1``).
    """

    def __init__(self, root_dir, sequence_length=4, transform=None, overlap=True):
        self.root_dir = root_dir
        self.sequence_length = sequence_length
        self.transform = transform
        self.overlap = overlap
        self.sequences = []

        window = sequence_length + 1
        step = 1 if overlap else window

        for seq_folder in sorted(os.listdir(root_dir)):
            seq_path = os.path.join(root_dir, seq_folder)
            if not os.path.isdir(seq_path):
                continue
            # Numeric sort so that 10.jpg comes after 9.jpg.
            jpgs = [name for name in os.listdir(seq_path) if name.endswith('.jpg')]
            jpgs.sort(key=lambda name: int(name.split('.')[0]))
            paths = [os.path.join(seq_path, name) for name in jpgs]

            self.sequences.extend(
                paths[start:start + window]
                for start in range(0, len(paths) - sequence_length, step)
            )

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(input_frames, target_frame, target_name)`` for sample ``idx``."""
        paths = self.sequences[idx]
        tensors = []
        for path in paths:
            bgr = cv2.imread(path)
            if bgr is None:
                raise ValueError(f"Impossible de charger l'image: {path}")
            rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (256, 256))
            tensors.append(self.transform(rgb) if self.transform else rgb)

        return (
            torch.stack(tensors[:self.sequence_length], dim=0),
            tensors[self.sequence_length],
            os.path.basename(paths[-1]),
        )


In [None]:

# Transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Chemins des données - A ADAPTER selon votre structure exacte
train_root = "/kaggle/input/avunue-rachid/avenue/training"
test_root = "/kaggle/input/avunue-rachid/avenue/testing"

# Datasets
train_dataset = AvenueDataset(
    root_dir=train_root,
    sequence_length=4,
    transform=transform,
    overlap=True
)

test_dataset = AvenueDataset(
    root_dir=test_root,
    sequence_length=4,
    transform=transform,
    overlap=True
)

# DataLoaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

def display_sequence(input_frames, target_frame, sequence_idx, target_name):
    """
    Plot an Avenue sequence: its input frames followed by the target frame.

    Args:
        input_frames: (T, C, H, W) tensor, normalised to [-1, 1].
        target_frame: (C, H, W) tensor, normalised to [-1, 1].
        sequence_idx: 0-based position of the sequence in its batch.
        target_name: numeric file name such as ``0005.jpg``; the input frame
            numbers are taken to be the T numbers just before it.
    """
    # To numpy, CHW -> HWC, and map [-1, 1] back to [0, 1] for imshow.
    frames_hwc = (np.transpose(input_frames.cpu().numpy(), (0, 2, 3, 1)) + 1) / 2
    target_hwc = (np.transpose(target_frame.cpu().numpy(), (1, 2, 0)) + 1) / 2
    count = len(frames_hwc)

    target_num = target_name.split('.')[0]
    base = int(target_num) - count
    input_nums = [str(base + k).zfill(4) for k in range(count)]

    fig, axes = plt.subplots(1, count + 1, figsize=(15, 5), squeeze=False)
    fig.suptitle(f"Séquence {sequence_idx + 1} - Target: {target_num}", fontsize=14)

    titled = [(image, f"Frame {num}") for image, num in zip(frames_hwc, input_nums)]
    titled.append((target_hwc, f"Target {target_num}"))
    for ax, (image, title) in zip(axes[0], titled):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualisation des premières séquences
NUM_BATCHES_TO_DISPLAY = 2  # number of batches to visualise

for batch_idx, (inputs, targets, names) in enumerate(train_dataloader):
    if batch_idx >= NUM_BATCHES_TO_DISPLAY:
        break

    # Shapes and file names of the batch, then one figure per sequence.
    print(f"\n=== Batch {batch_idx + 1} ===")
    print(f"Input shape: {inputs.shape}")
    print(f"Target shape: {targets.shape}")
    print(f"Target names: {names}")

    for seq_idx, (seq_inputs, seq_target, seq_name) in enumerate(zip(inputs, targets, names)):
        display_sequence(seq_inputs, seq_target, seq_idx, seq_name)

## Définition du modèle

In [None]:
class Transformer(nn.Module):
    """Pre-norm transformer encoder.

    Stacks ``depth`` (attention, feed-forward) pairs, each wrapped in ``PreNorm``
    and added back to its input as a residual, then applies a final LayerNorm.
    ``Attention``, ``PreNorm`` and ``FeedForward`` come from the cloned
    ViViT-pytorch ``module.py``.
    """
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            attention = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = x + attention(x)       # residual around attention
            x = x + feed_forward(x)    # residual around the MLP
        return self.norm(x)


In [None]:
class ViViT(nn.Module):
    """Factorised ViViT bottleneck: a temporal transformer per patch, then a spatial transformer.

    Input: ``(b, t, in_channels, image_size, image_size)`` feature maps, with
    ``t <= num_frames`` (the temporal position embedding has ``num_frames + 1``
    slots). Each frame is cut into ``patch_size`` x ``patch_size`` patches that
    are linearly embedded to ``dim``. For every patch position, a learnable
    prediction token is prepended to its ``t`` embeddings and a temporal
    transformer is run; only the token's output is kept. A spatial transformer
    then mixes the patch positions.

    Output: ``(b, dim, image_size // patch_size, image_size // patch_size)``.

    NOTE(review): ``depth`` and ``pool`` do not affect the model (``depth_temporal``
    and ``depth_spatial`` set the transformer depths; ``pool`` is only stored).
    They are kept for interface compatibility.
    """
    def __init__(self, image_size, patch_size, num_frames, dim=512, depth=4, heads=3, pool='cls', in_channels=512,
                 dim_head=64, dropout=0., emb_dropout=0., scale_dim=4, depth_spatial=3, depth_temporal=1):
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.in_channels = in_channels
        self.dim = dim
        self.depth_spatial = depth_spatial
        self.depth_temporal = depth_temporal

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = in_channels * patch_size ** 2  # size of one flattened patch

        # Patch embedding: (b, t, c, H, W) -> (b, t, Np, dim)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(self.patch_dim, self.dim)  # (b, t, Np, D)
        )

        # Temporal stage: per-patch position embedding over (prediction token + t frames).
        self.temporal_pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, self.num_frames + 1, self.dim))
        self.temporal_token = nn.Parameter(torch.randn(1, 1, 1, self.dim))  # (1, 1, 1, d)
        self.temporal_transformer = Transformer(self.dim, self.depth_temporal, heads, dim_head, dim * scale_dim, dropout)

        # Spatial stage: position embedding over patch positions.
        self.spatial_pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, self.dim))
        self.space_transformer = Transformer(self.dim, self.depth_spatial, heads, dim_head, dim * scale_dim, dropout)

        self.dropout = nn.Dropout(emb_dropout)
        self.pool = pool

    def forward(self, x):
        # (b, t, c, H, W) -> (b, t, n, d), n = number of patches
        x = self.to_patch_embedding(x)
        b, t, n, d = x.shape

        # Group by patch so attention runs along time: (b, n, t, d)
        x = x.permute(0, 2, 1, 3)

        # Prepend one prediction token per patch position: (b, n, t+1, d)
        pred_token = self.temporal_token.expand(b, n, 1, d)
        x = torch.cat((pred_token, x), dim=2)

        # Temporal position embedding (sliced to the actual t+1 length).
        x = x + self.temporal_pos_embedding[:, :, :x.shape[2], :]
        x = self.dropout(x)

        # Temporal transformer on each patch's sequence (batch and patches merged).
        x = rearrange(x, 'b n t d -> (b n) t d')
        x = self.temporal_transformer(x)
        x = rearrange(x, '(b n) t d -> b n t d', b=b)

        # Keep only the prediction token (position 0): (b, n, d)
        x = x[:, :, 0, :]

        # Spatial transformer across patch positions.
        x = x + self.spatial_pos_embedding
        x = self.space_transformer(x)

        # Back to a feature map on the patch grid: (b, d, H/p, W/p).
        # The grid side is image_size // patch_size; the previous hard-coded
        # image_size // 2 was only correct for patch_size == 2.
        grid = self.image_size // self.patch_size
        x = rearrange(x, 'b (h w) c -> b c h w', h=grid, w=grid)
        return x


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# class FlowNet(nn.Module):
#     """Module simplifié pour calculer le flux optique entre deux frames"""
#     def __init__(self):
#         super(FlowNet, self).__init__()
        
#         # Architecture simple pour estimer le flux optique
#         self.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3)
#         self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
#         self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
#         self.conv4 = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
        
#     def forward(self, frame1, frame2):
#         # Concaténer les deux frames le long de la dimension des canaux
#         x = torch.cat([frame1, frame2], dim=1)
        
#         # Passer à travers les couches
#         x = F.relu(self.conv1(x))
#         x = F.relu(self.conv2(x))
#         x = F.relu(self.conv3(x))
#         flow = self.conv4(x)
        
#         # Redimensionner le flux à la taille originale
#         flow = F.interpolate(flow, scale_factor=4, mode='bilinear', align_corners=True)
        
#         return flow

class ImprovedFlowNet(nn.Module):
    """Small CNN that estimates a 2-channel optical-flow field between two frames.

    ``forward(frame1, frame2)`` concatenates the two 3-channel frames along the
    channel axis (6 channels in), extracts features at 1/8 resolution (three
    stride-2 convs), regresses and refines a 2-channel flow, and upsamples it
    back to the input's spatial size.

    NOTE(review): the network has no pretrained weights here; the flow it
    produces is only as meaningful as whatever trains it.
    """
    def __init__(self):
        super(ImprovedFlowNet, self).__init__()

        # Feature extraction: 6 -> 256 channels, spatial size / 8.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3),
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
        )

        # Flow estimation: 256 -> 2 channels at the same (1/8) resolution.
        self.flow_estimator = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, 2, kernel_size=3, stride=1, padding=1),
        )

        # Refinement of the 2-channel flow.
        self.refinement = nn.Sequential(
            nn.Conv2d(2, 16, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 2, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, frame1, frame2):
        """Return a (B, 2, H, W) flow for frames of shape (B, 3, H, W)."""
        x = torch.cat([frame1, frame2], dim=1)

        features = self.feature_extractor(x)
        flow = self.flow_estimator(features)
        flow = self.refinement(flow)

        # Upsample to the exact input size. The encoder downsamples by 8, so the
        # previous fixed scale_factor=4 returned a flow at half the input
        # resolution (and a mismatched size when H or W is not a multiple of 8).
        flow = F.interpolate(flow, size=frame1.shape[-2:], mode='bilinear', align_corners=True)

        return flow

In [None]:
class TransAnomaly(nn.Module):
    """U-Net-style future-frame predictor with a ViViT bottleneck.

    A shared 2D conv encoder is applied to every input frame independently
    (batch and time are merged). The deepest features of all frames go
    through ViViT, which fuses them into a single feature map. Skip
    connections stack each encoder level's features of all frames along the
    channel axis and reduce them back with a 3x3 conv before they are
    concatenated into the decoder. The output is a single 3-channel frame.

    NOTE(review): ViViT is built with image_size=32, so the deepest encoder
    features must be 32x32, i.e. input frames of 256x256 (three 2x max-pools),
    which is the size the datasets resize to.
    """
    def __init__(self, batch_size=4, num_frames=4):
        super(TransAnomaly, self).__init__()
        # NOTE(review): stored but never read; forward() takes the batch size
        # from the input tensor instead.
        self.batch_size = batch_size
        # Frames per input sample; must match forward()'s time dimension
        # because the skip-connection convs take channels * num_frames inputs.
        self.num_frames = num_frames
        self.channels_1 = 64
        self.channels_2 = 128
        self.channels_3 = 256
        self.channels_4 = 512

        # Encoder: four double-conv blocks separated by 2x max-pooling.
        self.contracting_11 = self.conv_block(in_channels=3, out_channels=self.channels_1)
        self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_21 = self.conv_block(in_channels=self.channels_1, out_channels=self.channels_2)
        self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_31 = self.conv_block(in_channels=self.channels_2, out_channels=self.channels_3)
        self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_41 = self.conv_block(in_channels=self.channels_3, out_channels=self.channels_4)

        # Skip connections: each maps the channel-stacked features of all frames
        # at one encoder level (channels * num_frames) back to that level's
        # channel count. residual_XY: encoder level X -> decoder stage Y.
        self.residual_14 = nn.Conv2d(in_channels=self.channels_1*self.num_frames, out_channels=self.channels_1, kernel_size=3, stride=1, padding=1)
        self.residual_23 = nn.Conv2d(in_channels=self.channels_2*self.num_frames, out_channels=self.channels_2, kernel_size=3, stride=1, padding=1)
        self.residual_32 = nn.Conv2d(in_channels=self.channels_3*self.num_frames, out_channels=self.channels_3, kernel_size=3, stride=1, padding=1)
        self.residual_41 = nn.Conv2d(in_channels=self.channels_4*self.num_frames, out_channels=self.channels_4, kernel_size=3, stride=1, padding=1)

        # ViViT bottleneck over the deepest (512-channel) features of all frames.
        self.middle = ViViT(image_size=32, patch_size=2, num_frames=self.num_frames, in_channels=512)

        # Decoder: each stage upsamples 2x with a transposed conv, concatenates
        # the matching skip connection, and applies a double-conv block.
        self.expansive_11 = nn.ConvTranspose2d(in_channels=self.channels_4, out_channels=self.channels_4, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_12 = self.conv_block(in_channels=self.channels_4*2, out_channels=self.channels_4)
        self.expansive_21 = nn.ConvTranspose2d(in_channels=self.channels_4, out_channels=self.channels_3, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_22 = self.conv_block(in_channels=self.channels_3*2, out_channels=self.channels_3)
        self.expansive_31 = nn.ConvTranspose2d(in_channels=self.channels_3, out_channels=self.channels_2, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_32 = self.conv_block(in_channels=self.channels_2*2, out_channels=self.channels_2)
        self.expansive_41 = nn.ConvTranspose2d(in_channels=self.channels_2, out_channels=self.channels_1, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_42 = self.conv_block(in_channels=self.channels_1*2, out_channels=self.channels_1)
        # Final 3x3 conv to a 3-channel frame (no output activation).
        self.output = nn.Conv2d(in_channels=self.channels_1, out_channels=3, kernel_size=3, stride=1, padding=1)

        # Optical-flow network, used only by compute_optical_flow_loss();
        # forward() does not call it. As a submodule, its parameters are part
        # of self.parameters().
        # self.flownet = FlowNet()
        self.flownet = ImprovedFlowNet()



    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convs, each followed by ReLU then BatchNorm."""
        block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=out_channels),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=out_channels)
        )
        return block

    def compute_optical_flow_loss(self, pred_frame, target_frame, prev_frame):
        """
        Optical-flow consistency loss.

        Runs self.flownet on (prev_frame, pred_frame) and on
        (prev_frame, target_frame) and returns the L1 distance between the
        two estimated flows.
        """
        # Estimate both flows
        flow_pred = self.flownet(prev_frame, pred_frame)
        flow_target = self.flownet(prev_frame, target_frame)

        # L1 difference between the two flows
        loss = F.l1_loss(flow_pred, flow_target)

        return loss

    def forward(self, frames):
        """Predict one frame from a sequence of frames.

        Args:
            frames: tensor of shape (batch, num_frames, c, h, w); c must be 3
                (contracting_11 takes 3 input channels).

        Returns:
            Tensor of shape (batch, 3, h, w): the decoder's four 2x upsamplings
            undo the three max-pools and the patch_size=2 ViViT grid.
        """
        # frames.shape = (batch, num_frames, c, h, w)
        batch_size = frames.shape[0]

        ########## Encoding #######
        # Merge batch and time so the 2D encoder processes every frame.
        tmp_frames = rearrange(frames, 'b t c h w -> (b t) c h w')

        contracting_11_out = self.contracting_11(tmp_frames)
        contracting_12_out = self.contracting_12(contracting_11_out)
        contracting_21_out = self.contracting_21(contracting_12_out)
        contracting_22_out = self.contracting_22(contracting_21_out)
        contracting_31_out = self.contracting_31(contracting_22_out)
        contracting_32_out = self.contracting_32(contracting_31_out)
        contracting_41_out = self.contracting_41(contracting_32_out)

        ####### ViViT layer ########
        # Restore the time axis for ViViT, which fuses the frames into one map.
        vivit_input = rearrange(contracting_41_out, '(b t) c h w -> b t c h w', b=batch_size)
        middle_out = self.middle(vivit_input)

        ######### Residual connections #####
        # For each encoder level: stack the frames of a sample along channels,
        # then reduce back to the level's channel count with a 3x3 conv.
        residual_14_out = rearrange(contracting_11_out, '(b t) c h w -> b (t c) h w', b=batch_size)
        residual_14_out = self.residual_14(residual_14_out)

        residual_23_out = rearrange(contracting_21_out, '(b t) c h w -> b (t c) h w', b=batch_size)
        residual_23_out = self.residual_23(residual_23_out)

        residual_32_out = rearrange(contracting_31_out, '(b t) c h w -> b (t c) h w', b=batch_size)
        residual_32_out = self.residual_32(residual_32_out)

        residual_41_out = rearrange(contracting_41_out, '(b t) c h w -> b (t c) h w', b=batch_size)
        residual_41_out = self.residual_41(residual_41_out)

        ####### Decoding ##########
        # Upsample, concatenate the same-resolution skip, then double-conv.
        expansive_11_out = self.expansive_11(middle_out)
        expansive_12_out = self.expansive_12(torch.cat((expansive_11_out, residual_41_out), dim=1))
        expansive_21_out = self.expansive_21(expansive_12_out)
        expansive_22_out = self.expansive_22(torch.cat((expansive_21_out, residual_32_out), dim=1))
        expansive_31_out = self.expansive_31(expansive_22_out)
        expansive_32_out = self.expansive_32(torch.cat((expansive_31_out, residual_23_out), dim=1))
        expansive_41_out = self.expansive_41(expansive_32_out)
        expansive_42_out = self.expansive_42(torch.cat((expansive_41_out, residual_14_out), dim=1))
        output_out = self.output(expansive_42_out)

        return output_out

## Fonctions de perte (loss)

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class GradientLoss(nn.Module):
#     """Perte de gradient pour améliorer la netteté des images prédites"""
#     def __init__(self):
#         super(GradientLoss, self).__init__()
        
#     def forward(self, pred_frame, target_frame):
#         """
#         Calcule la différence entre les gradients des images prédites et cibles
#         Args:
#             pred_frame: tensor [B, C, H, W] - image prédite
#             target_frame: tensor [B, C, H, W] - image cible
#         Returns:
#             loss: tensor scalaire - perte de gradient
#         """
#         # Calcul des gradients horizontaux et verticaux pour l'image prédite
#         pred_dx = pred_frame[:, :, :, :-1] - pred_frame[:, :, :, 1:]  # Gradient horizontal
#         pred_dy = pred_frame[:, :, :-1, :] - pred_frame[:, :, 1:, :]  # Gradient vertical
        
#         # Calcul des gradients horizontaux et verticaux pour l'image cible
#         target_dx = target_frame[:, :, :, :-1] - target_frame[:, :, :, 1:]
#         target_dy = target_frame[:, :, :-1, :] - target_frame[:, :, 1:, :]
        
#         # Calcul des pertes L1 pour les gradients
#         loss_x = F.l1_loss(pred_dx, target_dx)
#         loss_y = F.l1_loss(pred_dy, target_dy)
        
#         return loss_x + loss_y

# class OpticalFlowLoss(nn.Module):
#     """Perte de flux optique pour la cohérence temporelle"""
#     def __init__(self):
#         super(OpticalFlowLoss, self).__init__()
        
#     def forward(self, pred_frame, target_frame, prev_frame, flownet):
#         """
#         Calcule la perte de flux optique entre:
#         1. Le flux entre la prédiction et la frame précédente
#         2. Le flux entre la cible et la frame précédente
        
#         Args:
#             pred_frame: tensor [B, C, H, W] - image prédite
#             target_frame: tensor [B, C, H, W] - image cible
#             prev_frame: tensor [B, C, H, W] - frame précédente
#             flownet: nn.Module - réseau pour calculer le flux optique
#         Returns:
#             loss: tensor scalaire - perte de flux optique
#         """
#         # Calculer les deux flux optiques
#         flow_pred = flownet(prev_frame, pred_frame)
#         flow_target = flownet(prev_frame, target_frame)
        
#         # Calculer la différence L1 entre les deux flux
#         loss = F.l1_loss(flow_pred, flow_target)
        
#         return loss

# class CombinedLoss(nn.Module):
#     """Combinaison des différentes pertes avec pondération"""
#     def __init__(self, identity_weight=1.0, gradient_weight=1.0, flow_weight=0.05):
#         """
#         Args:
#             identity_weight: float - poids pour la perte d'identité (MSE)
#             gradient_weight: float - poids pour la perte de gradient
#             flow_weight: float - poids pour la perte de flux optique
#         """
#         super(CombinedLoss, self).__init__()
#         self.identity_loss = nn.MSELoss()
#         self.gradient_loss = GradientLoss()
#         self.optical_flow_loss = OpticalFlowLoss()
#         self.identity_weight = identity_weight
#         self.gradient_weight = gradient_weight
#         self.flow_weight = flow_weight
        
#     def forward(self, pred_frame, target_frame, prev_frame, flownet):
#         """
#         Calcule la perte combinée
        
#         Args:
#             pred_frame: tensor [B, C, H, W] - image prédite
#             target_frame: tensor [B, C, H, W] - image cible
#             prev_frame: tensor [B, C, H, W] - frame précédente
#             flownet: nn.Module - réseau pour calculer le flux optique
#         Returns:
#             loss: tensor scalaire - perte combinée pondérée
#         """
#         # Perte d'identité (MSE)
#         identity = self.identity_loss(pred_frame, target_frame) * self.identity_weight
        
#         # Perte de gradient
#         gradient = self.gradient_loss(pred_frame, target_frame) * self.gradient_weight
        
#         # Perte de flux optique
#         flow = self.optical_flow_loss(pred_frame, target_frame, prev_frame, flownet) * self.flow_weight
        
#         return identity + gradient + flow

import torch
import torch.nn as nn
import torch.nn.functional as F

class GradientLoss(nn.Module):
    """Image-gradient loss between a predicted and a target frame.

    Horizontal and vertical finite differences (``x[i] - x[i+1]``) are
    computed per channel with fixed 1x2 / 2x1 kernels (depthwise conv, no
    padding). The loss is the mean of the L1 distances between prediction and
    target gradients in the two directions.

    Args:
        epsilon: kept for interface compatibility; not used.
    """
    def __init__(self, epsilon=1e-6):
        super(GradientLoss, self).__init__()
        self.epsilon = epsilon

        # Finite-difference kernels, registered once as buffers.
        self.register_buffer('kernel_x', torch.tensor([[[[1.0, -1.0]]]]))
        self.register_buffer('kernel_y', torch.tensor([[[[1.0], [-1.0]]]]))

    def _gradients(self, frame):
        """Return (dx, dy) finite differences of a [B, C, H, W] tensor."""
        channels = frame.size(1)
        # Match the input's device and dtype: the loss module is often not moved
        # with .to(device), and float32 buffers would then fail on CUDA or
        # non-float32 inputs.
        kernel_x = self.kernel_x.to(device=frame.device, dtype=frame.dtype).repeat(channels, 1, 1, 1)
        kernel_y = self.kernel_y.to(device=frame.device, dtype=frame.dtype).repeat(channels, 1, 1, 1)
        dx = F.conv2d(frame, kernel_x, padding=0, groups=channels)
        dy = F.conv2d(frame, kernel_y, padding=0, groups=channels)
        return dx, dy

    def forward(self, pred_frame, target_frame):
        """
        Compute the gradient loss.

        Args:
            pred_frame: [B, C, H, W] predicted frame.
            target_frame: [B, C, H, W] target frame.

        Returns:
            Scalar tensor ``(L1(dx) + L1(dy)) / 2``.

        Raises:
            ValueError: if either input is not 4-dimensional.
        """
        if pred_frame.dim() != 4 or target_frame.dim() != 4:
            raise ValueError("Les inputs doivent être de dimension [B, C, H, W]")

        pred_dx, pred_dy = self._gradients(pred_frame)
        target_dx, target_dy = self._gradients(target_frame)

        loss_x = F.l1_loss(pred_dx, target_dx)
        loss_y = F.l1_loss(pred_dy, target_dy)

        return (loss_x + loss_y) / 2

class OpticalFlowLoss(nn.Module):
    """L1 distance between two optical-flow fields.

    Both flows start from the same previous frame: one goes towards the
    predicted frame, the other towards the ground-truth frame.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_frame, target_frame, prev_frame, flownet):
        """Return ``L1(flownet(prev, pred), flownet(prev, target))``.

        ``flownet`` is any callable taking ``(previous_frame, next_frame)``. It
        is called first for the prediction, then for the target.
        """
        predicted_flow, reference_flow = (
            flownet(prev_frame, frame) for frame in (pred_frame, target_frame)
        )
        return F.l1_loss(predicted_flow, reference_flow)

class CombinedLoss(nn.Module):
    """Fixed-weight sum of intensity, gradient and optical-flow losses.

    total = identity_weight * MSE(pred, target)
          + gradient_weight * GradientLoss(pred, target)
          + flow_weight     * OpticalFlowLoss(pred, target, prev, flownet)
    """

    def __init__(self, identity_weight=1.0, gradient_weight=1.0, flow_weight=0.05):
        super().__init__()
        # Sub-losses are assigned as attributes, so they are registered as
        # child modules and follow ``.to(device)`` calls on this module.
        self.identity_loss = nn.MSELoss()
        self.gradient_loss = GradientLoss()
        self.optical_flow_loss = OpticalFlowLoss()

        # Weights are set once at construction time.
        self.identity_weight = identity_weight
        self.gradient_weight = gradient_weight
        self.flow_weight = flow_weight

    def forward(self, pred_frame, target_frame, prev_frame, flownet):
        """Return the weighted total loss as a scalar tensor."""
        weighted_terms = [
            self.identity_loss(pred_frame, target_frame) * self.identity_weight,
            self.gradient_loss(pred_frame, target_frame) * self.gradient_weight,
            self.optical_flow_loss(pred_frame, target_frame, prev_frame, flownet) * self.flow_weight,
        ]
        total = weighted_terms[0]
        for term in weighted_terms[1:]:
            total = total + term
        return total

### Entraînement

In [None]:
# Fonction d'accuracy et d'entraînement
def pixel_accuracy(preds, targets, threshold=0.0):
    """Fraction of pixels where prediction and target fall on the same side of a threshold.

    Both tensors are binarized with the same rule, ``x > threshold``, in the
    raw frame-value space. Previously ``preds`` was binarized via
    ``sigmoid(preds) > 0.5`` (essentially ``preds > 0``) while ``targets``
    used ``targets > 0.5``. With that mismatch, identical frames could score
    far below 1.0. The default ``threshold=0.0`` keeps the old rule for
    ``preds``. It is the midpoint of the [-1, 1] range produced by
    ``Normalize(mean=0.5, std=0.5)``. NOTE(review): this assumes that is the
    normalization of the training frames — confirm.

    Args:
        preds: predicted frames (any shape).
        targets: ground-truth frames, same shape as ``preds``.
        threshold: binarization threshold applied to both tensors.

    Returns:
        0-d tensor in [0, 1].
    """
    pred_mask = preds > threshold
    target_mask = targets > threshold
    correct = (pred_mask == target_mask).float()
    return correct.sum() / correct.numel()

def train(model, dataloader, optimizer, epoch, log_interval=200, criterion=None, device=None):
    """Run one training epoch and return (average loss, average pixel accuracy).

    Args:
        model: network called as ``model(input_sequences)``. It must expose a
            ``flownet`` attribute, which is passed to the criterion.
        dataloader: yields ``(input_sequences, target_frames, _)`` batches.
            ``input_sequences[:, -1]`` is used as the previous frame, so axis 1
            is presumably time — TODO confirm against the dataset.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for logging.
        log_interval: print running averages every ``log_interval`` batches.
        criterion: loss called as ``criterion(pred, target, prev, flownet)``.
            Defaults to the notebook-level ``criterion`` (previous behaviour).
        device: device batches are moved to. Defaults to the notebook-level
            ``device`` (previous behaviour).

    Returns:
        Tuple ``(avg_loss, avg_acc)`` averaged over all batches of the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batch.
    """
    # Fall back to the notebook globals so existing calls keep working, while
    # new calls can pass their dependencies explicitly.
    if criterion is None:
        criterion = globals()["criterion"]
    if device is None:
        device = globals()["device"]

    model.train()
    total_loss, total_acc = 0.0, 0.0
    num_batches = 0

    for batch_idx, (input_sequences, target_frames, _) in enumerate(dataloader):
        input_sequences = input_sequences.to(device)
        target_frames = target_frames.to(device)

        # The last input frame is the one right before the target frame.
        prev_frames = input_sequences[:, -1, :, :, :]

        optimizer.zero_grad()
        predicted_frames = model(input_sequences)

        loss = criterion(predicted_frames, target_frames, prev_frames, model.flownet)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # The metric is for monitoring only: keep it out of the autograd graph.
        with torch.no_grad():
            total_acc += pixel_accuracy(predicted_frames, target_frames).item()
        num_batches += 1

        if batch_idx % log_interval == 0:
            avg_loss = total_loss / num_batches
            avg_acc = total_acc / num_batches
            print(f"Epoch {epoch} | Batch {batch_idx}/{len(dataloader)} | "
                  f"Loss: {avg_loss:.4f} | Acc: {avg_acc:.4f}")

    if num_batches == 0:
        raise ValueError("Le dataloader n'a produit aucun batch.")

    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches
    print(f"\nEpoch {epoch} Summary: Avg Loss: {avg_loss:.4f} | Avg Acc: {avg_acc:.4f}\n")
    return avg_loss, avg_acc

In [None]:
import torch
import matplotlib.pyplot as plt


# Model, optimizer and composite loss
model = TransAnomaly().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
criterion = CombinedLoss(identity_weight=1.0, gradient_weight=1.0, flow_weight=0.05).to(device)

# Training loop: per-epoch averages are collected for the curves below.
losses, accuracies = [], []
EPOCHS = 6

for epoch in range(1, EPOCHS + 1):
    epoch_loss, epoch_acc = train(model, train_dataloader, optimizer, epoch)
    losses.append(epoch_loss)
    accuracies.append(epoch_acc)
    print(f"[EPOCH: {epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

# Plot against the real epoch numbers (1..EPOCHS). Plotting the bare lists
# put the first epoch at x=0 under an "Epoch" label.
epoch_numbers = list(range(1, len(losses) + 1))
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(epoch_numbers, losses, label='Loss')
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Loss")
ax_loss.set_title("Loss par Epoch")
ax_loss.legend()

ax_acc.plot(epoch_numbers, accuracies, label='Accuracy', color='green')
ax_acc.set_xlabel("Epoch")
ax_acc.set_ylabel("Accuracy")
ax_acc.set_title("Accuracy par Epoch")
ax_acc.legend()

plt.tight_layout()
plt.show()


# Évaluation avec la courbe ROC

In [None]:
import numpy as np
from sklearn.metrics import roc_curve, auc, confusion_matrix
import matplotlib.pyplot as plt
import torch
import seaborn as sns  # Pour une meilleure visualisation de la matrice de confusion

def calculate_psnr(img1, img2, max_pixel=1.0):
    """PSNR in dB between two images assumed to be normalized to [-1, 1].

    Both images are mapped back to [0, 1] first, so ``max_pixel`` is the peak
    value in that [0, 1] space. Identical images give ``float('inf')``.
    """
    # Undo the [-1, 1] normalization on both inputs.
    first, second = ((img + 1) / 2 for img in (img1, img2))
    squared_error = (first - second) ** 2
    mse = torch.mean(squared_error)
    if mse != 0:
        return 10 * torch.log10(max_pixel ** 2 / mse).item()
    return float('inf')

def evaluate_model(model, dataloader, device, window_size=64, stride=32, labels=None):
    """Score every test sample by PSNR and summarise detection with ROC/AUC.

    Each predicted frame is compared with its target using ``calculate_psnr``.
    The PSNRs are min-max normalized into regularity scores in [0, 1]
    (low = likely anomaly), and ``1 - score`` is used as the anomaly score.

    Args:
        model: frame predictor, called as ``model(input_frames)``.
        dataloader: yields ``(input_frames, target_frames, _)`` batches.
        device: device to run the model on.
        window_size, stride: unused; kept for call compatibility.
        labels: optional ground-truth labels (1 = anomaly), one per sample in
            dataloader order. If omitted, pseudo-labels are derived from the
            scores themselves (lowest 10% PSNR, at least one sample = anomaly).
            NOTE: pseudo-labels make the ROC circular: the AUC is then ~1 by
            construction and does not measure detection quality.

    Returns:
        ``(roc_auc, (tp, fp, tn, fn), scores, fpr, tpr, conf_matrix)``, where
        ``conf_matrix`` is always 2x2 (rows/cols ordered normal, anomaly).

    Raises:
        ValueError: if the dataloader is empty or ``labels`` has the wrong length.
    """
    model.eval()
    psnrs = []

    with torch.no_grad():
        for input_frames, target_frames, _ in dataloader:
            input_frames = input_frames.to(device)
            target_frames = target_frames.to(device)
            predicted_frames = model(input_frames)
            for i in range(predicted_frames.shape[0]):
                psnrs.append(calculate_psnr(predicted_frames[i], target_frames[i]))

    if not psnrs:
        raise ValueError("Le dataloader n'a produit aucun échantillon.")

    # A perfect reconstruction gives PSNR = inf, which would turn the min-max
    # normalization into NaN, so cap PSNRs at 100 dB. The epsilon prevents a
    # division by zero when all PSNRs are equal.
    psnrs = np.minimum(np.asarray(psnrs, dtype=np.float64), 100.0)
    scores = (psnrs - psnrs.min()) / (psnrs.max() - psnrs.min() + 1e-6)

    if labels is None:
        labels = np.zeros(len(scores), dtype=int)
        n_anomalies = max(1, len(scores) // 10)
        labels[np.argsort(scores)[:n_anomalies]] = 1  # lowest PSNR = anomalies
    else:
        labels = np.asarray(labels, dtype=int)
        if labels.shape != scores.shape:
            raise ValueError(f"labels doit contenir {len(scores)} valeurs, reçu {labels.size}.")

    anomaly_scores = 1 - scores  # invert: low PSNR = anomaly
    fpr, tpr, thresholds = roc_curve(labels, anomaly_scores)
    roc_auc = auc(fpr, tpr)
    # Youden's J statistic: the threshold maximizing TPR - FPR.
    optimal_threshold = thresholds[np.argmax(tpr - fpr)]
    predictions = (anomaly_scores > optimal_threshold).astype(int)
    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent, so the
    # 4-way unpacking cannot fail.
    conf_matrix = confusion_matrix(labels, predictions, labels=[0, 1])
    tn, fp, fn, tp = conf_matrix.ravel()

    return roc_auc, (tp, fp, tn, fn), scores, fpr, tpr, conf_matrix

# Note: this rebinds the notebook-level `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Evaluation and AUC. Without real labels, evaluate_model derives pseudo-labels
# from its own scores, so treat the AUC as optimistic.
roc_auc, (tp, fp, tn, fn), scores, fpr, tpr, conf_matrix = evaluate_model(model, test_dataloader, device)

# Precision/recall are undefined when their denominator is 0: report NaN
# explicitly instead of relying on a numpy divide-by-zero warning.
precision = tp / (tp + fp) if (tp + fp) > 0 else float('nan')
recall = tp / (tp + fn) if (tp + fn) > 0 else float('nan')

print(f"[AUC] = {roc_auc:.4f}")
print(f"Matrice de confusion: TP={tp}, FP={fp}, TN={tn}, FN={fn}")
print(f"Precision = {precision:.4f}, Recall = {recall:.4f}")

fig, (ax_roc, ax_cm) = plt.subplots(1, 2, figsize=(12, 5))

# ROC curve
ax_roc.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
ax_roc.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax_roc.set_xlim([0.0, 1.0])
ax_roc.set_ylim([0.0, 1.05])
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('Receiver Operating Characteristic (ROC)')
ax_roc.legend(loc="lower right")

# Confusion matrix
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Normal', 'Anomaly'],
            yticklabels=['Normal', 'Anomaly'],
            ax=ax_cm)
ax_cm.set_xlabel('Predicted')
ax_cm.set_ylabel('Actual')
ax_cm.set_title('Confusion Matrix')

plt.tight_layout()
plt.show()

## Évaluation pour une seule vidéo

In [None]:
## Extraire les frames d'une seule vidéo

In [None]:
import cv2
import os

# Path of the video to process and root folder for the extracted frames
video_path = "/kaggle/input/anomalous-action-detection-dataset/Anomalous Action Detection Dataset( Ano-AAD)/abnormal class/Fighting/46.mp4"
output_folder="/kaggle/working/Anomalous Action Detection Dataset/NormaleAction1/abnormale video"

os.makedirs(output_folder, exist_ok=True)

# Frames go to a sub-folder named after the video file (without extension).
video_name = os.path.splitext(os.path.basename(video_path))[0]
video_output_folder = os.path.join(output_folder, video_name)
os.makedirs(video_output_folder, exist_ok=True)

cap = cv2.VideoCapture(video_path)
try:
    if not cap.isOpened():
        print(f"❌ Erreur de chargement : {video_path}")
    else:
        print(f"🎥 Traitement de {video_path}")

        # Read and save every frame as frame_XXXX.jpg
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # end of the video

            frame_filename = os.path.join(video_output_folder, f"frame_{frame_idx:04d}.jpg")
            # cv2.imwrite reports failure through its return value instead of
            # raising, so check it explicitly.
            if not cv2.imwrite(frame_filename, frame):
                raise IOError(f"Échec de l'écriture de {frame_filename}")
            frame_idx += 1

        print(f"✔ {frame_idx} Frames sauvegardées dans : {video_output_folder}")
        print("✅ Extraction terminée !")
finally:
    # Release the capture handle even if reading or writing fails.
    cap.release()

## Limiter le nombre de frames utilisées pour chaque vidéo

In [None]:
import os
import shutil

# Source folder holding the frames of one video, and destination folder
input_video_folder = "/kaggle/working/Anomalous Action Detection Dataset/NormaleAction1/abnormale video/46"  # <-- replace with your video's frame folder
output_folder = "/kaggle/working/Anomalous Action Detection Dataset/NormaleAction1/abnormale videok"

# Maximum number of frames to copy
MAX_FRAMES = 450

os.makedirs(output_folder, exist_ok=True)
# NOTE(review): frames left in output_folder by an earlier run are not removed.
# The single-video dataset reads every frame_* file in this folder, so stale
# frames would be mixed in.

# Keep only extracted frames and sort them by their numeric index. A plain
# lexicographic sort breaks once indices outgrow the zero padding
# (e.g. "frame_10000.jpg" < "frame_1001.jpg"), and stray files would
# otherwise be copied too.
frames = sorted(
    (name for name in os.listdir(input_video_folder) if name.startswith('frame_')),
    key=lambda name: int(name.split('_')[1].split('.')[0]),
)

frames_to_copy = frames[:MAX_FRAMES]

for frame in frames_to_copy:
    shutil.copy(os.path.join(input_video_folder, frame),
                os.path.join(output_folder, frame))

print(f"✅ {len(frames_to_copy)} frames copiées depuis {input_video_folder} vers {output_folder}")


## DataLoader pour une seule vidéo

In [None]:
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class SingleVideoDataset(Dataset):
    """Sliding windows over the frames of a single video folder.

    Each item is ``(input_frames, target_frame)``. ``input_frames`` stacks
    ``sequence_length`` consecutive frames along a new first axis, and
    ``target_frame`` is the frame that follows them. Frames are the files whose
    names start with ``frame_``, ordered by the integer after the underscore.

    Args:
        frames_dir: folder holding ``frame_<idx>.<ext>`` images.
        sequence_length: number of input frames per item.
        transform: callable applied to each RGB uint8 HxWx3 array (e.g.
            ToTensor + Normalize). If None, frames are returned as float
            tensors of shape [3, H, W] scaled to [0, 1].
        overlap: if True, a window starts at every frame (stride 1). Otherwise
            consecutive windows (inputs + target) share no frame.
        image_size: ``(width, height)`` passed to ``cv2.resize``.

    Raises:
        ValueError: if there are fewer than ``sequence_length + 1`` frames.
    """

    def __init__(self, frames_dir, sequence_length=4, transform=None, overlap=True,
                 image_size=(256, 256)):
        self.frames_dir = frames_dir
        self.sequence_length = sequence_length
        self.transform = transform
        self.overlap = overlap
        self.image_size = image_size

        # Frame files sorted by numeric index
        self.frames = sorted(
            [f for f in os.listdir(frames_dir) if f.startswith('frame_')],
            key=lambda x: int(x.split('_')[1].split('.')[0]))

        # Number of valid window start positions: each window needs
        # sequence_length inputs + 1 target.
        max_sequences = len(self.frames) - sequence_length
        if max_sequences <= 0:
            raise ValueError("Pas assez de frames pour former une séquence !")

        window = sequence_length + 1
        step = 1 if overlap else window
        self.sequences = [
            [os.path.join(frames_dir, name) for name in self.frames[start:start + window]]
            for start in range(0, max_sequences, step)
        ]
        # No truncation to a multiple of the batch size happens here.
        self.total_sequences = len(self.sequences)

    def __len__(self):
        return self.total_sequences

    def _load_frame(self, path):
        """Read one frame as RGB, resize it and apply the transform."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise IOError(f"Impossible de lire l'image : {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, self.image_size)
        if self.transform:
            return self.transform(img)
        # Without a transform, still return a tensor so torch.stack works.
        return torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

    def __getitem__(self, idx):
        images = [self._load_frame(path) for path in self.sequences[idx]]
        input_frames = torch.stack(images[:-1], dim=0)  # the sequence_length input frames
        target_frame = images[-1]  # the frame to predict
        return input_frames, target_frame

In [None]:
# Per-frame preprocessing: ToTensor gives CHW floats in [0, 1], then
# Normalize(0.5, 0.5) maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Frames of the single (truncated) video to score
video_frames_dir = "/kaggle/working/Anomalous Action Detection Dataset/NormaleAction1/abnormale videok"
dataset = SingleVideoDataset(
    frames_dir=video_frames_dir,
    transform=transform,
    sequence_length=4,
    overlap=True,
)

# Loader that keeps temporal order (no shuffling).
# NOTE(review): drop_last=True discards the final partial batch, so the last
# len(dataset) % 4 sequences of the video are never scored. Confirm the model
# really needs full batches.
signal_dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=4,
    drop_last=True,
    pin_memory=True,
    num_workers=2,
)

# Évaluation d'une longue vidéo ou de plusieurs vidéos moyennes (graphe du score d'anomalie au cours du temps)

In [None]:
import numpy as np
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

def calculate_psnr(pred, target, max_val=1.0):
    """PSNR in dB between a predicted frame and the real frame.

    Frames are [C, H, W] tensors assumed to be normalized to [-1, 1], as done
    by the Normalize(0.5, 0.5) transform of the single-video loader.
    NOTE(review): confirm the model output lives in the same space. Both frames
    are mapped back to [0, 1] before calling skimage, so
    ``data_range=max_val`` (default 1.0) matches the actual data range.
    Passing raw [-1, 1] data with data_range=1.0 understated every PSNR by
    20*log10(2) ~ 6 dB. An earlier cell defines another ``calculate_psnr``
    with the same [-1, 1] -> [0, 1] convention; this definition replaces it
    once this cell has run.

    Returns:
        PSNR as a Python float (inf for identical frames).
    """
    # (C, H, W) tensors -> (H, W, C) numpy images in [0, 1]
    pred_np = ((pred.detach().cpu().numpy() + 1) / 2).transpose(1, 2, 0)
    target_np = ((target.detach().cpu().numpy() + 1) / 2).transpose(1, 2, 0)
    return float(psnr(target_np, pred_np, data_range=max_val))

def sliding_window_psnr(pred, target, window_size=64, stride=32, max_val=1.0):
    """PSNR restricted to the worst-reconstructed patches of a frame.

    The frame is scanned with ``window_size`` x ``window_size`` windows moved by
    ``stride`` over the last two axes. The PSNR is computed from the mean MSE
    of the worst half of the patches (at least one), which makes the score
    sensitive to localized errors.

    Args:
        pred, target: tensors whose last two axes are (H, W).
        window_size: side of the square patches.
        stride: step between patch origins.
        max_val: peak value in the PSNR formula. NOTE(review): data normalized
            to [-1, 1] spans 2, so max_val=1.0 shifts every PSNR by a constant
            ~6 dB. That shift is harmless only if scores are later min-max
            normalized.

    Returns:
        PSNR in dB, or 100.0 when the selected patches match exactly.
    """
    h, w = pred.shape[-2], pred.shape[-1]
    mse_patches = []

    for i in range(0, h - window_size + 1, stride):
        for j in range(0, w - window_size + 1, stride):
            patch_pred = pred[..., i:i+window_size, j:j+window_size]
            patch_target = target[..., i:i+window_size, j:j+window_size]
            mse_patches.append(torch.mean((patch_pred - patch_target) ** 2).item())

    if not mse_patches:
        # Frame smaller than one window: score the whole frame instead of
        # returning 0 dB, which would always look like the worst anomaly.
        mse_patches.append(torch.mean((pred - target) ** 2).item())

    # Worst half of the patches. max(1, ...) makes sure a single-patch frame
    # is still scored: an empty selection made the mean NaN and the function
    # silently returned 100 dB.
    p = max(1, len(mse_patches) // 2)
    top_mse = sorted(mse_patches, reverse=True)[:p]
    avg_mse = np.mean(top_mse)
    return 10 * np.log10(max_val ** 2 / avg_mse) if avg_mse > 0 else 100.0

In [None]:
# without optical flow

# def evaluate_anomaly_unlabeled(model, dataloader, device, use_sliding_window=True, window_size=64, stride=32):
#     model.eval()
#     psnr_scores = []
#     frame_predictions = []

#     with torch.no_grad():
#         for batch_idx, (input_sequences, target_frames) in enumerate(dataloader):  # Ajoutez _ pour ignorer le 3ème élément
#             input_sequences = input_sequences.to(device)
#             target_frames = target_frames.to(device)
#             predicted_frames = model(input_sequences)
            
#             for pred, target in zip(predicted_frames, target_frames):
#                 if use_sliding_window:
#                     psnr_score = sliding_window_psnr(pred, target, window_size, stride)
#                 else:
#                     psnr_score = calculate_psnr(pred, target)
#                 psnr_scores.append(psnr_score)
#                 frame_predictions.append((pred.cpu(), target.cpu()))

#     min_psnr, max_psnr = min(psnr_scores), max(psnr_scores)
#     regularity_scores = [(psnr - min_psnr) / (max_psnr - min_psnr + 1e-6) for psnr in psnr_scores]
    
#     return regularity_scores, frame_predictions

# with optical flow

def evaluate_anomaly_unlabeled(model, dataloader, device, use_sliding_window=True, window_size=64, stride=32):
    """Score every sample of an unlabeled video by next-frame prediction quality.

    For each batch ``(input_sequences, target_frames)`` the model predicts the
    next frame. Its PSNR against the real frame is computed either patch-wise
    (``sliding_window_psnr``) or over the whole frame (``calculate_psnr``),
    then min-max normalized into a regularity score in [0, 1] (low = likely
    anomaly). The optical flow from the last input frame to the prediction is
    also computed with ``model.flownet`` and returned for visualization.

    Args:
        model: predictor exposing ``flownet``, called as ``model(input_sequences)``.
        dataloader: yields ``(input_sequences, target_frames)`` batches.
        device: device to run on.
        use_sliding_window: use the patch-based PSNR instead of the full-frame one.
        window_size, stride: patch settings for the sliding-window PSNR.

    Returns:
        ``(regularity_scores, frame_predictions)``: a list of floats (one per
        sample, dataloader order) and a list of ``(pred, target, flow)`` CPU
        tensors for the same samples.

    Raises:
        ValueError: if the dataloader yields no sample.
    """
    model.eval()
    psnr_scores = []
    frame_predictions = []

    with torch.no_grad():
        for input_sequences, target_frames in dataloader:
            input_sequences = input_sequences.to(device)
            target_frames = target_frames.to(device)

            # The last input frame is the one right before the target.
            prev_frames = input_sequences[:, -1, :, :, :]

            predicted_frames = model(input_sequences)

            # Optical flow between the previous frame and the prediction
            flow_pred = model.flownet(prev_frames, predicted_frames)

            for pred, target, flow in zip(predicted_frames, target_frames, flow_pred):
                if use_sliding_window:
                    psnr_score = sliding_window_psnr(pred, target, window_size, stride)
                else:
                    psnr_score = calculate_psnr(pred, target)
                psnr_scores.append(psnr_score)
                frame_predictions.append((pred.cpu(), target.cpu(), flow.cpu()))

    if not psnr_scores:
        raise ValueError("Le dataloader n'a produit aucun échantillon.")

    # A perfect prediction can yield PSNR = inf, which would make the min-max
    # scaling NaN. Cap at 100 dB, the value sliding_window_psnr uses for an
    # exact match.
    psnr_array = np.minimum(np.asarray(psnr_scores, dtype=np.float64), 100.0)
    min_psnr, max_psnr = psnr_array.min(), psnr_array.max()
    regularity_scores = ((psnr_array - min_psnr) / (max_psnr - min_psnr + 1e-6)).tolist()

    return regularity_scores, frame_predictions

In [None]:
def plot_anomalies(scores, threshold, num_frames_to_plot=5):
    """Plot regularity scores over time together with the anomaly threshold.

    Only the first ``num_frames_to_plot`` samples scoring below ``threshold``
    (in temporal order) are highlighted with red markers.
    """
    score_array = np.array(scores)
    highlighted = np.flatnonzero(score_array < threshold)[:num_frames_to_plot]

    fig, ax = plt.subplots(figsize=(15, 5))
    ax.plot(scores, label='Scores de régularité')
    ax.axhline(y=threshold, color='r', linestyle='--', label='Seuil d\'anomalie')
    ax.scatter(highlighted, [scores[k] for k in highlighted],
               color='red', label='Frames anormales')
    ax.legend()
    ax.set_xlabel("Frame")
    ax.set_ylabel("Score (0=anomalie, 1=normal)")
    ax.set_title("Détection d'anomalies")
    plt.show()



In [None]:
## Visualisation des anomalies (carte d'erreur de prédiction, flux optique calculé)

In [None]:

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

# 1. Regularity scores for every sequence of the video
scores, frames = evaluate_anomaly_unlabeled(model, signal_dataloader, device)

# 2. Threshold: the 5% least regular samples are treated as anomalies
threshold = np.percentile(scores, 5)

# 3. Score curve
plot_anomalies(scores, threshold)

# 4. Indices of the anomalous samples (temporal order)
anomalous_indices = np.where(np.array(scores) < threshold)[0]

# Number of anomalous samples to display
MAX_ANOMALIES_TO_SHOW = 30


def prepare_image(tensor):
    """Convert a [C, H, W] (or [H, W]) CPU tensor to a numpy image min-max scaled to [0, 1]."""
    img = tensor.numpy()
    if len(img.shape) == 3 and img.shape[0] == 3:
        img = img.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    img = (img - img.min()) / (img.max() - img.min() + 1e-6)
    return np.clip(img, 0, 1)


# 5. Display the first MAX_ANOMALIES_TO_SHOW anomalies
for idx in anomalous_indices[:MAX_ANOMALIES_TO_SHOW]:
    pred, target, _ = frames[idx]

    # --- Per-pixel prediction-error mask ---
    diff = torch.abs(pred - target)
    error_map = diff.mean(dim=0)  # mean over the RGB channels
    norm_error = (error_map - error_map.min()) / (error_map.max() - error_map.min() + 1e-6)
    blurred_mask = cv2.GaussianBlur(norm_error.numpy(), (31, 31), 10)  # smooth the mask

    # --- Colour heatmap ---
    heatmap = cv2.applyColorMap((blurred_mask * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # Boost colours in HSV. For float32 images OpenCV uses S and V in [0, 1],
    # so clip both after scaling. An unclipped saturation > 1 produced
    # out-of-range RGB values, which imshow then clips with a warning.
    hsv = cv2.cvtColor(heatmap, cv2.COLOR_RGB2HSV)
    hsv[..., 1] = np.clip(hsv[..., 1] * 1.8, 0, 1)
    hsv[..., 2] = np.clip(hsv[..., 2] * 1.3, 0, 1)
    heatmap = np.clip(cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB), 0, 1)

    # --- Alpha blending over a light-grey background ---
    alpha_mask = np.clip(blurred_mask * 3.5, 0, 1)[..., np.newaxis]
    base = np.ones_like(heatmap) * 0.9
    transparent_heatmap = heatmap * alpha_mask + base * (1 - alpha_mask)

    # --- Display ---
    plt.figure(figsize=(18, 6), facecolor='white')

    # Real frame
    plt.subplot(1, 3, 1)
    plt.imshow(prepare_image(target))
    plt.title("Image Réelle", fontsize=12, pad=10)
    plt.axis('off')

    # Predicted frame
    plt.subplot(1, 3, 2)
    plt.imshow(prepare_image(pred))
    plt.title(f"Prédiction (score: {scores[idx]:.2f})", fontsize=12, pad=10)
    plt.axis('off')

    # Anomaly heatmap. NOTE(review): despite the title, this map is the
    # blurred per-pixel prediction error; the flow stored in `frames` is unused.
    plt.subplot(1, 3, 3)
    plt.imshow(transparent_heatmap)
    plt.title("Flux Optique - Anomalies", fontsize=12, pad=10)
    plt.axis('off')

    plt.tight_layout()
    plt.show()

# Évaluation avec boîtes englobantes (bounding boxes)

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage.metrics import peak_signal_noise_ratio as psnr

def sliding_window_anomaly_detection(pred, target, window_size=64, stride=32, threshold=0.1):
    """Return the boxes of the patches whose reconstruction MSE exceeds ``threshold``.

    Patches of ``window_size`` x ``window_size`` are taken every ``stride``
    pixels over the last two axes (H, W), scanning row by row. Each box is
    ``(x1, y1, x2, y2, mse)``, where x runs along the width and y along the
    height.
    """
    height, width = pred.shape[-2:]

    def _patch_mse(top, left):
        # MSE between the prediction and the target on one square patch
        rows = slice(top, top + window_size)
        cols = slice(left, left + window_size)
        return torch.mean((pred[..., rows, cols] - target[..., rows, cols]) ** 2).item()

    candidates = (
        (left, top, left + window_size, top + window_size, _patch_mse(top, left))
        for top in range(0, height - window_size + 1, stride)
        for left in range(0, width - window_size + 1, stride)
    )
    return [box for box in candidates if box[4] > threshold]

# def evaluate_and_visualize_anomalies(model, dataloader, device, window_size=64, stride=32, anomaly_threshold=0.1, display_count=5):
#     model.eval()
    
#     with torch.no_grad():
#         for batch_idx, (input_sequences, target_frames) in enumerate(dataloader):
#             input_sequences = input_sequences.to(device)
#             target_frames = target_frames.to(device)
#             predicted_frames = model(input_sequences)
            
#             for idx, (pred, target) in enumerate(zip(predicted_frames, target_frames)):
#                 # Convertir les tenseurs en images numpy
#                 pred_np = pred.detach().cpu().numpy().transpose(1, 2, 0)
#                 target_np = target.detach().cpu().numpy().transpose(1, 2, 0)
                
#                 # Normalisation des images
#                 pred_np = (pred_np - pred_np.min()) / (pred_np.max() - pred_np.min() + 1e-6)
#                 target_np = (target_np - target_np.min()) / (target_np.max() - target_np.min() + 1e-6)
                
#                 # Détection des anomalies par sliding window
#                 anomaly_boxes = sliding_window_anomaly_detection(pred, target, window_size, stride, anomaly_threshold)
                
#                 # Visualisation
#                 if len(anomaly_boxes) > 0 and idx < display_count:
#                     plt.figure(figsize=(15, 6))
                    
#                     # Afficher l'image originale avec les bounding boxes
#                     plt.subplot(1, 2, 1)
#                     plt.imshow(target_np)
#                     plt.title("Image Originale avec Anomalies Détectées")
                    
#                     # Dessiner les bounding boxes autour des anomalies
#                     for box in anomaly_boxes:
#                         x1, y1, x2, y2, mse = box
#                         rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, 
#                                                linewidth=2, edgecolor='r', facecolor='none')
#                         plt.gca().add_patch(rect)
#                         plt.text(x1, y1, f"{mse:.2f}", color='white', 
#                                 bbox=dict(facecolor='red', alpha=0.5))
                    
#                     # Afficher l'image prédite pour comparaison
#                     plt.subplot(1, 2, 2)
#                     plt.imshow(pred_np)
#                     plt.title("Image Prédite")
                    
#                     plt.tight_layout()
#                     plt.show()

# # Paramètres
# window_size = 64  # Taille des patches
# stride = 32       # Pas du sliding window
# anomaly_threshold = 0.1  # Seuil d'anomalie (à ajuster selon votre cas)

# # Exécution
# evaluate_and_visualize_anomalies(model, signal_dataloader, device, 
#                                 window_size=window_size, 
#                                 stride=stride, 
#                                 anomaly_threshold=anomaly_threshold)


def evaluate_and_visualize_anomalies(model, dataloader, device, window_size=64, stride=32,
                                   anomaly_threshold=0.1, max_display=20):
    """Plot up to ``max_display`` frames that contain anomalous patches.

    For each predicted frame, the patches whose MSE against the real frame
    exceeds ``anomaly_threshold`` are found with
    ``sliding_window_anomaly_detection``. Frames with at least one such patch
    are shown as the real frame (red boxes labelled with the patch MSE) next
    to the prediction. Iteration stops once ``max_display`` frames are shown.
    """
    model.eval()
    shown = 0

    def _to_display(frame):
        # (C, H, W) tensor -> (H, W, C) numpy image min-max scaled to [0, 1]
        array = frame.detach().cpu().numpy().transpose(1, 2, 0)
        return (array - array.min()) / (array.max() - array.min() + 1e-6)

    with torch.no_grad():
        for input_sequences, target_frames in dataloader:
            if shown >= max_display:
                break

            input_sequences = input_sequences.to(device)
            target_frames = target_frames.to(device)
            predicted_frames = model(input_sequences)

            for pred, target in zip(predicted_frames, target_frames):
                if shown >= max_display:
                    break

                anomaly_boxes = sliding_window_anomaly_detection(pred, target, window_size, stride, anomaly_threshold)
                if not anomaly_boxes:
                    continue

                shown += 1
                target_img = _to_display(target)
                pred_img = _to_display(pred)

                plt.figure(figsize=(15, 6))

                # Real frame with one box per anomalous patch
                plt.subplot(1, 2, 1)
                plt.imshow(target_img)
                plt.title(f"Frame Anormale {shown}/{max_display}")
                ax = plt.gca()
                for x1, y1, x2, y2, mse in anomaly_boxes:
                    ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                                   linewidth=2, edgecolor='r', facecolor='none'))
                    ax.text(x1, y1, f"{mse:.2f}", color='white',
                            bbox=dict(facecolor='red', alpha=0.5))

                # Matching prediction
                plt.subplot(1, 2, 2)
                plt.imshow(pred_img)
                plt.title("Prédiction Correspondante")

                plt.tight_layout()
                plt.show()

# Sliding-window settings: patch side, step between patches, per-patch MSE threshold
window_size = 64
stride = 32
anomaly_threshold = 0.1

# Show the first 20 frames that contain at least one anomalous patch
evaluate_and_visualize_anomalies(
    model,
    signal_dataloader,
    device,
    window_size=window_size,
    stride=stride,
    anomaly_threshold=anomaly_threshold,
    max_display=20,
)