In [1]:
%%time
from tqdm.notebook import tqdm
import torch
from torch import nn, optim
from torch.nn import functional as F
from matplotlib import pyplot as plt
import numpy as np
from torchvision.datasets import ImageFolder
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from torch.utils.data import DataLoader, random_split, Dataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from src.arcface import ArcFaceLoss
import graphviz
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy, MulticlassConfusionMatrix
import torchvision
import json
import pandas as pd
from torchmetrics import MetricCollection
import os
import seaborn as sns
import torchvision.transforms.v2 as tf
from src.ds import VideoFrameDataset, ImglistToTensor
from src.video_utils import read_video
from torchvision.models.video.swin_transformer import swin3d_b, Swin3D_B_Weights
from torchvision.models.video.swin_transformer import swin3d_t, Swin3D_T_Weights


from pytorch_lightning.callbacks import LearningRateMonitor
# Select PyTorch's "medium" float32 matmul precision mode for the whole session.
# NOTE(review): per the PyTorch docs this mode trades some float32 matmul
# precision for speed on hardware that supports it — confirm that is acceptable.
torch.set_float32_matmul_precision("medium")

from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam import GradCAM

CPU times: user 5.06 s, sys: 1.03 s, total: 6.09 s
Wall time: 5.14 s




In [2]:
class RSL_DS(torch.utils.data.Dataset):
    """Dataset over the video clips listed in a tab-separated annotations file.

    Rows whose ``text`` equals ``'no_event'`` are dropped, then the remaining
    rows are restricted to one split using the boolean ``train`` column.

    Args:
        annotations_path: Path to the tab-separated annotations file. The
            columns read here are ``text``, ``train`` and ``attachment_id``.
        ds_type: ``'train'`` keeps rows where ``train`` is true, ``'test'``
            keeps the remaining rows.
        root_dir: Directory containing one sub-directory per split; clips are
            read from ``<root_dir>/<ds_type>/<attachment_id>.mp4``.

    Raises:
        ValueError: If ``ds_type`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, annotations_path='data/rsl/annotations.csv', ds_type='train', root_dir='data/rsl'):
        super().__init__()

        # Validate before touching the disk so a typo fails fast.
        if ds_type not in ('train', 'test'):
            raise ValueError(f"Invalid ds type: {ds_type!r} (expected 'train' or 'test')")
        self.ds_type = ds_type
        self.root_dir = root_dir

        annotations = pd.read_csv(annotations_path, sep='\t')
        annotations = annotations[annotations['text'] != 'no_event']

        # Build the label mapping from ALL splits, before filtering, so that a
        # given text gets the same integer id in the train and test datasets.
        # Ids follow order of first appearance in the annotations file.
        self.classes = list(annotations['text'].unique())
        self.text_to_id = {text: i for i, text in enumerate(self.classes)}

        is_train = annotations['train']
        self.df = annotations[is_train] if ds_type == 'train' else annotations[~is_train]

    def __len__(self):
        """Return the number of annotated clips in the selected split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Read clip ``idx`` and return the size of its first dimension.

        The value is ``vid.shape[0]`` of whatever ``read_video`` returns —
        presumably the frame count (TODO confirm against ``src.video_utils``).
        The class id is not part of the return value; look it up through
        ``text_to_id`` if it is needed.
        """
        sample = self.df.iloc[idx]
        video_path = os.path.join(self.root_dir, self.ds_type, sample['attachment_id'] + '.mp4')
        vid, _ = read_video(video_path)
        return vid.shape[0]


In [3]:
# Training split of the RSL annotations (arguments spelled out; they match the defaults).
ds = RSL_DS(annotations_path='data/rsl/annotations.csv', ds_type='train')

In [4]:
from multiprocessing.pool import ThreadPool as Pool

# Collect ds[i] for a random sample of clips, reading them concurrently in a
# thread pool. With RSL_DS, ds[i] is the first dimension of the decoded video
# (presumably its frame count).
N_SAMPLES = 1000  # number of clips to draw, with replacement
SAMPLE_SEED = 0   # fixed seed so Restart & Run All visits the same clips

# Draw every index up front from a seeded generator instead of calling the
# global np.random inside worker threads; this makes the sample reproducible.
sample_indices = np.random.default_rng(SAMPLE_SEED).integers(len(ds), size=N_SAMPLES)

with Pool() as pool:
    # imap yields results in input order; list() drains it before the pool is
    # torn down, so no explicit close()/join() is needed.
    # `ll` keeps its original name because later cells may read it.
    ll = list(tqdm(pool.imap(lambda i: ds[int(i)], sample_indices), total=N_SAMPLES))

  0%|          | 0/1000 [00:00<?, ?it/s]

In [17]:
# Back-of-the-envelope size in GiB of 16 * 20000 items of 256 x 256 x 3 values
# at 4 bytes each — presumably float32 RGB frames; TODO confirm what the 16 and
# 20000 factors stand for.
16 * 20000 * 256 * 256 * 3 * 4 / 1024 ** 3

234.375