ViT для видеоэмбеддингов и их интерпретации.
CLIP для одновременной работы с видео и текстовыми метками.
LLM для интерпретации визуальных эмбеддингов и определения операций.

In [1]:
import os

from PIL import Image
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
)
from tqdm import tqdm
from transformers import (
    AutoProcessor, CLIPModel, CLIPProcessor, GPT2Model, GPT2Tokenizer,
)

In [2]:
VIDEO_DIR = '/root/tatneft/datasets/violations_dataset/cuts1'
LABELS_FILE = '/root/tatneft/datasets/violations_dataset/cuts1_train.txt'
VAL_LABELS_FILE = '/root/tatneft/datasets/violations_dataset/cuts1_val.txt'
FRAME_COUNT = 8
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def load_labels(labels_file):
    """Reads video paths and labels from a file. Returns list of (video_file, label) tuples.
    
    Args:
        labels_file: Text file with lines formatted 'video_path label'.
    
    Returns:
        List of tuples (str, int) with video paths and integer labels.
    """
    data = []
    with open(labels_file, 'r') as f:
        for line in f:
            video_file, label = line.strip().split()
            data.append((video_file, int(label)))
    return data

In [4]:
class VideoDataset(Dataset):
    """PyTorch dataset for loading videos and corresponding labels."""

    def __init__(self, video_dir, labels, transform=None):
        """
        Args:
            video_dir: Directory containing video files
            labels: List of (video_filename, label) tuples
            transform: Optional transform to apply to video frames
        """
        self.video_dir = video_dir
        self.labels = labels
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Returns the total number of videos in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Returns:
            tuple: (transformed_frames, label) for video at given index
        """
        video_file, label = self.labels[idx]
        video_path = os.path.join(self.video_dir, video_file)
        frames = self._load_video(video_path)
        frames = torch.stack([self.transform(frame) for frame in frames])
        return frames, label

    def _load_video(self, path):
        """
        Loads and pads video frames to FRAME_COUNT.
        Returns:
            list: PIL Image objects of video frames
        """
        cap = cv2.VideoCapture(path)
        frames = []
        count = 0
        while cap.isOpened() and count < FRAME_COUNT:
            ret, frame = cap.read()
            if not ret:
                break
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(frame)
            count += 1
        cap.release()
        while len(frames) < FRAME_COUNT:
            frames.append(frames[-1])
        return frames

In [5]:
"""Defining ViT precessor and model"""

vit_processor = AutoProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
vit_model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


In [7]:
"""Defining CLIP precessor and model"""

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


In [8]:
class CLIPTextVideoClassifier(nn.Module):
    """Classifier combining CLIP's text and video features for classification.
    
    Uses CLIP's text embeddings from multiple frames and classifies them using
    a linear layer on top of averaged frame embeddings.
    """

    def __init__(self, clip_model, num_classes):
        """
        Args:
            clip_model: Pretrained CLIP model
            num_classes: Number of output classes
        """
        super(CLIPTextVideoClassifier, self).__init__()
        self.clip_model = clip_model
        text_embed_dim = clip_model.config.text_config.hidden_size
        self.fc = nn.Linear(text_embed_dim, num_classes)
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def forward(self, frames, text_inputs):
        """
        Args:
            frames: Input video frames tensor of shape 
                   (batch_size, num_frames, channels, height, width)
            text_inputs: List of text prompts (one per video in batch)
        
        Returns:
            logits: Classification logits of shape (batch_size, num_classes)
        """
        batch_size, num_frames, channels, height, width = frames.size()
        frames_list = [frames[i, j] for i in range(batch_size) for j in range(num_frames)]
        inputs = self.clip_processor(
            images=frames_list,
            text=text_inputs * num_frames,
            return_tensors="pt",
            padding=True
        ).to(frames.device)
        outputs = self.clip_model(**inputs)
        text_embeds = outputs.text_embeds.view(batch_size, num_frames, -1).mean(dim=1)
        return self.fc(text_embeds)

In [11]:
def train_model(model, train_loader, val_loader, criterion, optimizer, save_dir="checkpoints"):
    """Trains a model with validation and checkpoint saving.
    
    Args:
        model: Model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        criterion: Loss function
        optimizer: Optimization algorithm
        save_dir: Directory to save checkpoints (default: "checkpoints")
    
    Returns:
        None (saves best model weights to disk)
    
    Behavior:
        - Trains for EPOCHS iterations
        - Validates after each epoch
        - Saves best model based on validation F1 score
        - Prints training/validation metrics
    """
    os.makedirs(save_dir, exist_ok=True)
    best_metric = -1  
    best_epoch = 0

    for epoch in range(EPOCHS):
        model.train() 
        epoch_loss = 0.0

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            text_inputs = [f"Label: {label.item()}" for label in labels]
            outputs = model(inputs, text_inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        val_loss, val_metrics = validate_model(model, val_loader, criterion)
        print(f"Epoch {epoch+1}/{EPOCHS}: Train Loss {epoch_loss/len(train_loader):.4f}, Val Loss {val_loss:.4f}")
        print(f"Validation Metrics: Precision: {val_metrics['precision']:.4f}, Recall: {val_metrics['recall']:.4f}, F1: {val_metrics['f1']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}")

        current_metric = val_metrics['f1'] 
        if current_metric > best_metric:
            best_metric = current_metric
            best_epoch = epoch + 1
            torch.save(model.state_dict(), os.path.join(save_dir, f"best_model_epoch_{epoch+1}.pth"))

In [12]:
def validate_model(model, val_loader, criterion):
    """Evaluates model performance on validation data.
    
    Args:
        model: Model to evaluate
        val_loader: DataLoader for validation data
        criterion: Loss function
    
    Returns:
        tuple: (val_loss, metrics_dict) where metrics_dict contains:
            - precision (weighted average)
            - recall (weighted average)
            - f1 (weighted average)
            - accuracy
    
    Note:
        Uses weighted averaging for multiclass metrics. Sets model to eval mode.
    """
    model.eval()
    val_loss = 0.0
    val_metrics = {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'accuracy': 0.0}
    y_true, y_pred = [], []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            text_inputs = [f"Label: {label.item()}" for label in labels]
            outputs = model(inputs, text_inputs)
            val_loss += criterion(outputs, labels).item()
            _, predicted = outputs.max(1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    accuracy = accuracy_score(y_true, y_pred)

    val_metrics = {'precision': precision, 'recall': recall, 'f1': f1, 'accuracy': accuracy}
    return val_loss / len(val_loader), val_metrics

In [13]:
"""Loading data"""

train_labels = load_labels(LABELS_FILE)
val_labels = load_labels(VAL_LABELS_FILE)

train_dataset = VideoDataset(VIDEO_DIR, train_labels)
val_dataset = VideoDataset(VIDEO_DIR, val_labels)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [14]:
"""Defining learning data"""

num_classes = len(set(label for _, label in train_labels))
model = CLIPTextVideoClassifier(clip_model, num_classes).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [31]:
"""Training model"""

train_model(model, train_loader, val_loader, criterion, optimizer)

Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [01:54<00:00,  4.78s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1/10: Train Loss 2.4098, Val Loss 2.2957
Validation Metrics: Precision: 0.3140, Recall: 0.4177, F1: 0.2998, Accuracy: 0.4177


Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:04<00:00,  5.17s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 2/10: Train Loss 2.4086, Val Loss 2.2965
Validation Metrics: Precision: 0.1001, Recall: 0.3165, F1: 0.1521, Accuracy: 0.3165


Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:05<00:00,  5.25s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 3/10: Train Loss 2.4159, Val Loss 2.3039
Validation Metrics: Precision: 0.1001, Recall: 0.3165, F1: 0.1521, Accuracy: 0.3165


Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:02<00:00,  5.10s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 4/10: Train Loss 2.4119, Val Loss 2.3062
Validation Metrics: Precision: 0.1001, Recall: 0.3165, F1: 0.1521, Accuracy: 0.3165


Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:07<00:00,  5.31s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 5/10: Train Loss 2.4114, Val Loss 2.3038
Validation Metrics: Precision: 0.1001, Recall: 0.3165, F1: 0.1521, Accuracy: 0.3165


Epoch 6/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:05<00:00,  5.25s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 6/10: Train Loss 2.4138, Val Loss 2.3025
Validation Metrics: Precision: 0.1001, Recall: 0.3165, F1: 0.1521, Accuracy: 0.3165


Epoch 7/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:06<00:00,  5.25s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 7/10: Train Loss 2.4125, Val Loss 2.2967
Validation Metrics: Precision: 0.1001, Recall: 0.3165, F1: 0.1521, Accuracy: 0.3165


Epoch 8/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:06<00:00,  5.28s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 8/10: Train Loss 2.4045, Val Loss 2.2912
Validation Metrics: Precision: 0.3155, Recall: 0.5063, F1: 0.3697, Accuracy: 0.5063


Epoch 9/10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:07<00:00,  5.32s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 9/10: Train Loss 2.4079, Val Loss 2.2939
Validation Metrics: Precision: 0.3155, Recall: 0.5063, F1: 0.3697, Accuracy: 0.5063


Epoch 10/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:06<00:00,  5.28s/it]


Epoch 10/10: Train Loss 2.4053, Val Loss 2.2841
Validation Metrics: Precision: 0.3140, Recall: 0.4177, F1: 0.2998, Accuracy: 0.4177


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [15]:
model.load_state_dict(torch.load('checkpoints/best_model_epoch_8.pth'))

  model.load_state_dict(torch.load('checkpoints/best_model_epoch_8.pth'))


<All keys matched successfully>

In [16]:
"""Best score"""

validate_model(model, val_loader, criterion)

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


(2.291225266456604,
 {'precision': 0.31545107494474584,
  'recall': 0.5063291139240507,
  'f1': 0.3696777905638665,
  'accuracy': 0.5063291139240507})