In [12]:
import torch
import json
import os
import numpy as np

In [3]:
!unzip train_data.zip

Archive:  train_data.zip
   creating: train_data/
   creating: train_data/-220020068_456255399/
  inflating: train_data/-220020068_456255399/video_tensor.pt  
  inflating: train_data/-220020068_456255399/audio_tensor.pt  
  inflating: train_data/labels.json  


In [4]:
# Directory produced by unzipping train_data.zip into the current working directory
# (resolves to '/content/train_data' on Colab, but also works outside it).
train_path = os.path.abspath('train_data')

In [18]:
from torch.utils.data import Dataset


class IntroDataset(Dataset):
    """Map-style dataset over an unpacked ``train_data`` directory.

    Expected layout (as extracted from train_data.zip)::

        data_path/
            labels.json                    # {video_name: label value(s)}
            <video_name>/video_tensor.pt
            <video_name>/audio_tensor.pt

    Each item is a dict with float32 ``'video'`` and ``'audio'`` tensors and an
    int64 ``'label'`` tensor built from ``labels[video_name]``.
    """

    def __init__(self, data_path, labels_filename='labels.json'):
        """
        Args:
            data_path: Root directory of the dataset.
            labels_filename: Name of the JSON file directly inside ``data_path`` that
                maps each video directory name to its labels.

        Raises:
            FileNotFoundError: If ``data_path`` or the labels file does not exist.
        """
        super().__init__()

        if not os.path.isdir(data_path):
            # next(os.walk(...)) would otherwise fail with an opaque StopIteration.
            raise FileNotFoundError(f'Dataset directory not found: {data_path}')

        self.data_path = data_path
        _, video_dirs, top_level_files = next(os.walk(data_path))
        # os.walk lists entries in filesystem order; sort so indices are reproducible.
        self.video_names = sorted(video_dirs)
        self.labels_name = top_level_files

        # Read the labels from this dataset's own directory (not a notebook global),
        # by name rather than "whichever top-level file was listed first".
        labels_path = os.path.join(data_path, labels_filename)
        with open(labels_path, encoding='utf-8') as f:
            self.labels = json.load(f)

    def __len__(self):
        return len(self.video_names)

    def _load_tensor(self, video_name, filename):
        """Load ``data_path/video_name/filename`` and cast it to float32.

        NOTE: torch.load unpickles data; only use it on trusted dataset files.
        """
        path = os.path.join(self.data_path, video_name, filename)
        return torch.load(path).to(torch.float32)

    def __getitem__(self, index):
        video_name = self.video_names[index]
        return {'video': self._load_tensor(video_name, 'video_tensor.pt'),
                'audio': self._load_tensor(video_name, 'audio_tensor.pt'),
                'label': torch.tensor(self.labels[video_name], dtype=torch.long)}


In [19]:
# Build the training dataset over the directory extracted from train_data.zip.
intro_dataset = IntroDataset(data_path=train_path)

In [24]:
import torch.nn as nn

class AudioBackbone(nn.Module):
    """1-D CNN encoder for a window of spectrogram-like audio features.

    Three Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d stages map an input of shape
    (N, filters_num, L) to (N, 128, L // 4 // 5 // 5). The 'same' padding keeps the
    length through each convolution, so only the pooling layers shrink it.
    """

    def __init__(self, filters_num):
        super().__init__()

        self.block1 = self._conv_block(filters_num, 32, pool_size=4)
        # These two stages previously used nn.MaxUnpool1d, which needs the indices of
        # a matching pooling call and cannot run inside nn.Sequential (forward raised).
        self.block2 = self._conv_block(32, 64, pool_size=5)
        self.block3 = self._conv_block(64, 128, pool_size=5)

    @staticmethod
    def _conv_block(in_channels, out_channels, pool_size):
        """Conv1d(kernel 5, 'same') -> BatchNorm1d -> ReLU -> MaxPool1d(pool_size)."""
        # ReLU has no parameters, so state_dict keys (block*.0 / block*.1) are unchanged.
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=5, padding='same'),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.MaxPool1d(pool_size),
        )

    def forward(self, audio):
        """Encode (N, filters_num, L) audio into a (N, 128, L // 100) feature map."""
        audio = self.block1(audio)
        audio = self.block2(audio)
        audio = self.block3(audio)

        return audio

In [25]:
import torchvision

class VideoBackbone(nn.Module):
    """Per-frame image encoder built on torchvision's ResNet-50 (no pretrained weights).

    Input: an image batch (N, 3, H, W). Output: (N, 2048), i.e. ResNet-50's globally
    average-pooled final-stage features with the classification head removed.
    """

    def __init__(self):
        super().__init__()

        self.resnet = torchvision.models.resnet50()
        # Keep resnet.avgpool. Replacing it with nn.Identity made ResNet.forward flatten
        # the entire final feature map (2048 x 7 x 7 = 100352 values for a 224x224
        # frame) instead of producing one compact 2048-d vector per frame.
        self.resnet.fc = nn.Identity()

    def forward(self, video):
        """Return (N, 2048) features for an (N, 3, H, W) image batch."""
        return self.resnet(video)

In [26]:
# Index once: every intro_dataset[i] call re-reads both .pt files from disk.
sample = intro_dataset[0]
sample['audio'].size(), sample['video'].size()

(torch.Size([240, 128, 1600]), torch.Size([240, 3, 224, 224]))

In [27]:
# Quick check of reshape semantics: (3, 4, 5, 6) -> (12, -1) infers the last dim as 5*6 = 30.
a = torch.zeros(3, 4, 5, 6)
a.reshape(12, -1).shape

torch.Size([12, 30])

In [None]:
class IntroDetecter(nn.Module):
    """Per-frame audio + video feature extractor for intro detection (work in progress).

    The temporal head (a BiLSTM-CRF over the fused per-frame features) is not
    implemented yet, so forward() currently returns the fused features.
    """

    def __init__(self, mel_filters_num):
        super().__init__()
        self.mel_filters_num = mel_filters_num

        self.audio_backbone = AudioBackbone(mel_filters_num)
        self.video_backbone = VideoBackbone()

        # TODO: add the BiLSTM-CRF head over the concatenated per-frame features.

    def forward(self, video, audio):
        """Encode every frame with both backbones and concatenate the features.

        Args:
            video: (batch, T, channels, h, w) frame tensor.
            audio: (batch, T, mel_filters_num, time_points) tensor. This presumably
                matches the dataset's per-item shape (T, 128, 1600) batched along
                dim 0; TODO confirm against the DataLoader.

        Returns:
            (batch, T, audio_dim + video_dim) tensor of fused per-frame features.
        """
        batch_size, T, img_channels, h, w = video.size()
        # The last axis is time; axis 2 is the mel-filter axis.
        time_points_num = audio.size(-1)

        audio = audio.reshape((batch_size * T, self.mel_filters_num, time_points_num))
        video = video.reshape((batch_size * T, img_channels, h, w))

        audio_features = self.audio_backbone(audio)
        video_features = self.video_backbone(video)

        # The audio backbone may return a (N, C, L') map; average over any remaining
        # trailing axes so each frame gets a single C-dim vector.
        if audio_features.dim() > 2:
            audio_features = audio_features.flatten(2).mean(dim=2)

        # Reshape the backbone outputs (not the raw inputs) and infer each feature width.
        audio_features = audio_features.reshape((batch_size, T, -1))
        video_features = video_features.reshape((batch_size, T, -1))

        features = torch.cat((audio_features, video_features), dim=2)

        # Next step: a BiLSTM-CRF over these features, see
        # https://docs.pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html
        return features