In [13]:
import cv2
import torch
import torchvision.transforms as T
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm
from CNN_search.extract_segments import get_segments

In [14]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ImageNet-pretrained ResNet-18 from torch hub. `weights=` replaces the
# deprecated `pretrained=True` flag; 'IMAGENET1K_V1' is the checkpoint that
# `pretrained=True` used to select, so the loaded parameters are the same.
cnn = torch.hub.load('pytorch/vision', 'resnet18', weights='IMAGENET1K_V1')
# Drop the last child module (ResNet's final fully connected layer) so the
# network is used as a feature extractor, move it to the chosen device and
# switch to eval mode.
cnn = torch.nn.Sequential(*(list(cnn.children())[:-1])).to(device).eval()

Using cache found in C:\Users\User/.cache\torch\hub\pytorch_vision_main


In [15]:
# Preprocessing applied to each PIL frame before it is fed to the CNN:
# resize to 224x224, convert to a tensor, then normalise each channel with
# the standard ImageNet mean/std used by torchvision's pretrained models.
transform = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [16]:
def extract_video_features(video_path, t_start, t_end, fps=1):
    """Average the CNN embeddings of frames sampled from one segment of a video.

    Frames are sampled every ``1 / fps`` seconds of video time in
    ``[t_start, t_end)``. Each frame is converted from BGR to RGB, passed
    through the module-level ``transform`` and ``cnn``, flattened, and the
    per-frame vectors are averaged.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        t_start: Segment start in seconds (times are multiplied by the
            video's frame rate to obtain frame indices).
        t_end: Segment end in seconds, exclusive.
        fps: Sampling rate, in frames per second of video time.

    Returns:
        A 1-D numpy array holding the mean feature vector, or 512 float32
        zeros when no frame could be read (the video cannot be opened, the
        segment is empty, or every read fails).
    """
    cap = cv2.VideoCapture(video_path)
    features = []
    try:
        if cap.isOpened():
            current_fps = cap.get(cv2.CAP_PROP_FPS)
            # Frame indices that correspond to the sampled times inside the segment.
            offsets = np.arange(0, t_end - t_start, 1 / fps)
            frame_indices = [int((t_start + offset) * current_fps) for offset in offsets]

            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)

                ret, frame = cap.read()
                if not ret:
                    # Best effort: skip frames that cannot be decoded.
                    print('Не удалось считать кадр')
                    continue

                img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img = Image.fromarray(img)
                x = transform(img).unsqueeze(0).to(device)

                with torch.no_grad():
                    feat = cnn(x).flatten().cpu().numpy()

                features.append(feat)
        else:
            # Report the failure once instead of once per sampled frame.
            print(f'Could not open video: {video_path}')
    finally:
        # Always free the capture handle, even if the transform or CNN raises.
        cap.release()

    if features:
        return np.mean(features, axis=0)  # Final embedding [512]
    # float32 so the fallback has the same dtype as a float32 CNN produces.
    return np.zeros(512, dtype=np.float32)

In [17]:
class IntroDataset(Dataset):
    """Map-style dataset yielding ``(feature_tensor, label)`` per video segment.

    Each sample is a mapping with the keys ``'video'``, ``'t_start'``,
    ``'t_end'`` and ``'label'``. Features come from
    ``extract_video_features``, which decodes the video and runs the CNN.
    Because that is expensive and gives the same result for the same
    segment, computed features are cached per index by default so later
    epochs reuse them.

    Args:
        samlpes: Indexable collection of sample mappings. The misspelled
            name is kept so existing keyword callers and attribute accesses
            keep working.
        cache_features: When true (default), keep each computed feature
            tensor in memory and reuse it on later accesses.
    """

    def __init__(self, samlpes, cache_features=True):
        self.samlpes = samlpes
        self.cache_features = cache_features
        self._feature_cache = {}  # idx -> float32 feature tensor

    def __len__(self):
        return len(self.samlpes)

    def __getitem__(self, idx):
        item = self.samlpes[idx]
        cached = self._feature_cache.get(idx) if self.cache_features else None
        if cached is None:
            features = extract_video_features(item['video'], item['t_start'], item['t_end'])
            cached = torch.from_numpy(features).float()
            if self.cache_features:
                self._feature_cache[idx] = cached
        # Return a copy so in-place edits by the caller cannot corrupt the cache.
        return cached.clone(), item['label']

In [18]:
class Classifier(nn.Module):
    """Two-layer MLP head that outputs one raw logit per input vector.

    Architecture: Linear(input_dim -> hidden_dim), ReLU, Dropout(0.2),
    Linear(hidden_dim -> 1). No sigmoid is applied here.
    """

    def __init__(self, input_dim=512, hidden_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits with the trailing size-1 dimension removed."""
        logits = self.model(x)
        return logits.squeeze(-1)


In [19]:
# Labels file describing the training segments; parsed by the project-local
# get_segments helper into the sample list consumed by the dataset.
labels_path = '../data/data_train_short/data_train_short/labels.json'
samples = get_segments(labels_path)

In [20]:
# Seed torch's global RNG so the train/val split, weight initialisation and
# batch shuffling are reproducible across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

# 80/20 random train/validation split over the segment dataset.
dataset = IntroDataset(samples)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)

model = Classifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# The classifier outputs raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

for epoch in range(30):
    # ---- training pass ----
    model.train()
    for feats, labels in tqdm(train_loader):
        feats, labels = feats.to(device), labels.float().to(device)
        logits = model(feats)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- validation pass ----
    model.eval()
    all_preds, all_true = [], []
    with torch.no_grad():
        for feats, labels in val_loader:
            probs = torch.sigmoid(model(feats.to(device)))
            # Hand the metric plain 0/1 ints for both predictions and targets
            # instead of a mix of bool and float arrays.
            all_preds.extend((probs > 0.5).long().cpu().tolist())
            all_true.extend(labels.long().tolist())
    # zero_division=0 reports 0.0 (without a warning) when an epoch has no
    # positive predictions or no positive labels.
    p, r, f, _ = precision_recall_fscore_support(
        all_true, all_preds, average='binary', zero_division=0
    )
    print(f"Epoch {epoch}: Precision={p:.3f}, Recall={r:.3f}, F1={f:.3f}")

  0%|          | 0/278 [00:03<?, ?it/s]


KeyboardInterrupt: 