In [None]:
!pip install av
!pip install torch torchvision torchaudio librosa


In [2]:
import os
import torch
import torchaudio
import torchvision.io as io
import librosa
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
from torchvision.models import vit_b_16, vit_b_32
import av

In [None]:
# Show which PyAV build is installed, then prefer the GPU whenever torch can see one.
print(av.__version__)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device:{device}")


In [4]:
class AudioVideoDataset(Dataset):
    """Dataset yielding all resized frames of a video plus per-frame onset labels.

    Each item is a ``(frames, labels)`` pair:

    * ``frames`` -- float tensor of shape ``(T, C, 224, 224)`` (every decoded
      frame, resized and converted with ``ToTensor``).
    * ``labels`` -- float32 tensor of shape ``(T,)`` that is 1.0 on frames where
      ``detect_audio_peaks`` reports an audio onset and 0.0 elsewhere.
    """

    def __init__(self, video_files):
        """Store the video paths and build the frame preprocessing pipeline.

        Args:
            video_files: Sequence of paths readable by ``torchvision.io.read_video``.
        """
        self.video_files = video_files
        # Built once here instead of on every __getitem__ call.
        self._frame_transform = transforms.Compose([
            transforms.ToPILImage(),  # Resize operates on PIL images
            transforms.Resize((224, 224)),  # input size expected by the model
            transforms.ToTensor()  # back to a (C, H, W) float tensor
        ])

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.video_files)

    def __getitem__(self, idx):
        """Decode video ``idx`` and return ``(frames, peak_labels)``.

        Raises:
            RuntimeError: If no video frames could be decoded from the file.
        """
        video_path = self.video_files[idx]
        # Only the video stream is needed here; detect_audio_peaks re-reads the audio.
        vframes, _, info = io.read_video(video_path, pts_unit='sec')
        if len(vframes) == 0:
            raise RuntimeError(f"No video frames could be decoded from {video_path}")

        # Frames arrive as (H, W, C); ToPILImage needs (C, H, W).
        frames = torch.stack(
            [self._frame_transform(frame.permute(2, 0, 1)) for frame in vframes]
        )

        onset_times = detect_audio_peaks(video_path)
        peak_labels = self._build_peak_labels(onset_times, info['video_fps'], len(frames))
        return frames, peak_labels

    @staticmethod
    def _build_peak_labels(onset_times, fps, num_frames):
        """Turn onset times (seconds) into a 0/1 float32 label per video frame.

        Onsets whose frame index falls outside ``[0, num_frames)`` are dropped;
        the audio track can run longer than the decoded video, and indexing the
        label array with such an index would raise ``IndexError``.
        """
        onset_frames = (np.asarray(onset_times, dtype=float) * fps).astype(int)
        in_range = (onset_frames >= 0) & (onset_frames < num_frames)
        peak_labels = np.zeros(num_frames)
        peak_labels[onset_frames[in_range]] = 1
        return torch.tensor(peak_labels).float()

def detect_audio_peaks(video_path):
    """Return the onset times librosa detects in the audio track of ``video_path``.

    The audio is loaded with ``torchaudio.load``, averaged over its first axis
    to obtain a mono signal, and passed through librosa's onset-strength and
    onset-detection steps. The detected frame indices are converted to times
    with ``librosa.frames_to_time``.
    """
    waveform, sample_rate = torchaudio.load(video_path)
    mono = waveform.mean(0).numpy()  # average the channels into one signal
    envelope = librosa.onset.onset_strength(y=mono, sr=sample_rate)
    detected = librosa.onset.onset_detect(onset_envelope=envelope, sr=sample_rate)
    return librosa.frames_to_time(detected, sr=sample_rate)




In [5]:
def list_video_files(directory, extension=".mp4"):
    """Return the sorted paths of regular files in ``directory`` ending in ``extension``.

    Args:
        directory: Folder to scan (not recursive).
        extension: Filename suffix to match; the comparison is case-sensitive.

    Returns:
        list[str]: ``os.path.join(directory, name)`` for every match, ordered by
        file name so repeated runs see the videos in the same order
        (``os.listdir`` itself makes no ordering guarantee).
    """
    candidates = (
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.endswith(extension)
    )
    # Skip sub-directories that merely happen to carry the extension.
    return [path for path in candidates if os.path.isfile(path)]

# Folder that holds the training videos
video_directory = "/content/drive/MyDrive/Videos"

# Gather the path of every matching video in that folder
video_files = list_video_files(video_directory)

# Wrap the paths in a dataset and serve one shuffled video per batch
dataset = AudioVideoDataset(video_files)
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)



In [6]:
class VideoPeakTransformer(nn.Module):
    """Per-frame peak classifier: a ViT-B/16 backbone plus a shared linear head.

    Every frame of a clip is embedded independently by the Vision Transformer
    (its classification head is replaced by an identity), and the same linear
    layer is applied to each frame embedding.
    """

    def __init__(self, num_frames, num_classes=1):
        """Build the backbone and the per-frame classifier.

        Args:
            num_frames: Currently unused; kept so existing callers keep working.
            num_classes: Number of outputs produced for every frame.
        """
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; left as-is to stay compatible with older installs.
        self.vit = vit_b_16(pretrained=True)
        self.vit.heads = nn.Identity()  # drop the original classification head

        # Read the embedding width from the backbone when it exposes it
        # (768 for ViT-B/16) instead of relying only on a hard-coded constant.
        feature_dim = getattr(self.vit, "hidden_dim", 768)
        self.time_distributed = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return per-frame probabilities for a batch of clips.

        Args:
            x: Tensor of shape ``(batch, timesteps, C, H, W)``.

        Returns:
            Tensor of shape ``(batch, timesteps, num_classes)`` with values in
            ``[0, 1]`` (a sigmoid is applied to the linear head's output).
        """
        batch_size, timesteps, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. a transposed
        # time-major batch, are accepted instead of raising a RuntimeError.
        frames = x.reshape(batch_size * timesteps, C, H, W)

        # Embed every frame, then restore the (batch, time) layout.
        features = self.vit(frames).reshape(batch_size, timesteps, -1)

        # The linear layer acts on the last axis, i.e. on each timestep separately.
        return torch.sigmoid(self.time_distributed(features))

In [1]:
def train_model(model, dataloader, epochs=5):
    """Train ``model`` with Adam (lr=1e-4) and binary cross-entropy.

    Args:
        model: Module whose forward pass returns probabilities in ``[0, 1]``
            (sigmoid already applied), as ``VideoPeakTransformer`` does.
        dataloader: Iterable of ``(vframes, labels)`` batches. ``labels`` gets a
            trailing axis added so it lines up with the model output.
        epochs: Number of passes over ``dataloader``.

    Returns:
        list[float]: Mean training loss of each epoch (``nan`` for an epoch in
        which the dataloader produced no batches).
    """
    # The model output has already been through a sigmoid, so plain BCE is the
    # matching loss; BCEWithLogitsLoss would squash the values a second time.
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    # Follow the model's own device so inputs AND labels sit next to the weights.
    model_device = next(model.parameters()).device
    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss, num_steps = 0.0, 0
        for i, (vframes, labels) in enumerate(dataloader):
            vframes = vframes.to(model_device)
            labels = labels.unsqueeze(-1).to(model_device)
            optimizer.zero_grad()
            outputs = model(vframes)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            running_loss += step_loss
            num_steps += 1
            if i % 10 == 0:
                print(f"Epoch {epoch}, Step {i}, Loss: {step_loss}")
        epoch_losses.append(running_loss / num_steps if num_steps else float("nan"))
    return epoch_losses


In [None]:
# num_frames should reflect how many frames are typically processed at once
model = VideoPeakTransformer(num_frames=30)
model = model.to(device)
train_model(model=model, dataloader=dataloader)
torch.save(model.state_dict(), f='video_peak_transformer.pth')
