In [11]:
# !tar -xf ytcelebrity.tar
# !mv *.avi ytcelebrity/

In [12]:
import cv2
import os
import csv
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

## Save frames from video
def save_frames(root, file, output_dir):
    """Decode every frame of the video ``root + file`` and write each one as a JPEG.

    Frames are written to ``output_dir + file + "/<index>.jpg"`` with a 0-based
    index, so the per-video directory keeps the video's extension in its name
    (e.g. ``frames/clip.avi/0.jpg``). Paths are built by plain string
    concatenation, so ``root`` and ``output_dir`` should end with a separator.

    Args:
        root: Directory containing the video (with trailing separator).
        file: Video file name inside ``root``.
        output_dir: Directory under which the per-video frame folder is created.

    Returns:
        The number of frames written.

    Raises:
        OSError: If the video cannot be opened or a frame cannot be written.
    """
    frame_dir = output_dir + file
    # exist_ok replaces the racy exists()-then-makedirs() pair.
    os.makedirs(frame_dir, exist_ok=True)
    video = cv2.VideoCapture(root + file)
    try:
        if not video.isOpened():
            raise OSError("Could not open video")
        count = 0
        success, image = video.read()
        while success:
            # Must be an f-string: the plain literal "/{count}.jpg" made every
            # frame overwrite the same file.
            output_file = f"{frame_dir}/{count}.jpg"
            if not cv2.imwrite(output_file, image):
                raise OSError("Could not write image")
            count += 1
            success, image = video.read()
        return count
    finally:
        # Free the capture handle on success and on every error path.
        video.release()
        
## Get frames from video
def get_frames(root, file):
    """Decode all frames of the video ``root + file`` into an in-memory list.

    ``root`` and ``file`` are joined by plain string concatenation, so ``root``
    should end with a path separator.

    Returns:
        The frames in decode order, each as returned by ``cv2.VideoCapture.read``
        (an empty list if the very first read fails).

    Raises:
        OSError: If the video cannot be opened.
    """
    frames = []
    video = cv2.VideoCapture(root + file)
    try:
        if not video.isOpened():
            raise OSError("Could not open video")
        success, image = video.read()
        while success:
            frames.append(image)
            success, image = video.read()
    finally:
        # Release the capture handle even when opening fails.
        video.release()
    return frames


In [13]:
## Dataset for YTCelebrity dataset
class YTCelebrityDataset(Dataset):
    """Map-style dataset over the YouTube Celebrities video clips.

    Each item is ``(frames, label)``:

    * ``frames`` is the list returned by ``get_frames`` for the clip's video.
    * ``label`` is the list of the CSV's remaining columns for that clip, as strings.

    Clips are identified by the key ``"<first>_<last>_<video_id>_<clip_id>"``.
    The key is rebuilt from two sources:

    * video file names of the form ``<prefix>_<video_id>_<clip_id>_<first>_<last>.<ext>``;
    * the CSV's first column, of the form ``<prefix>_<first>_<last>_<video_id>_<clip_id>.<ext>``.

    Args:
        dataset_path: Directory with the video files (with trailing separator,
            since paths are built by string concatenation).
        csv_path: CSV file with a header row, then one row per clip.
    """

    def __init__(self, dataset_path, csv_path):
        self.root = dataset_path
        self.data = []       # clip keys, in sorted file-name order
        self.label = dict()  # clip key -> list of label columns from the CSV
        self.file = dict()   # clip key -> video file name inside dataset_path

        # Sorted so indices (and thus a seeded random_split) are reproducible;
        # raw os.listdir order depends on the filesystem.
        for file_name in sorted(os.listdir(dataset_path)):
            stem = file_name.split(".")[0]
            _, video_id, clip_id, first_name, last_name = stem.split("_")
            key = "_".join([first_name, last_name, video_id, clip_id])
            self.data.append(key)
            self.file[key] = file_name

        # newline="" is what the csv module expects for files it parses.
        with open(csv_path, newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row; tolerate an empty file
            for row in reader:
                if not row:  # blank line
                    continue
                name, label = row[0].split(".")[0], row[1:]
                _, first_name, last_name, video_id, clip_id = name.split("_")
                self.label["_".join([first_name, last_name, video_id, clip_id])] = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        key = self.data[idx]
        # Look the label up first so a missing CSV entry fails before the
        # (potentially expensive) full video decode.
        label = self.label[key]
        frames = get_frames(self.root, self.file[key])
        return frames, label


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Placeholder network: currently the identity mapping, with no parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Identity for now: hand the input straight back.
        output = x
        return output
    
def train(model, device, train_loader, optimizer, i):
    """Training stub: switch ``model`` to training mode and print ``i``.

    ``device``, ``train_loader`` and ``optimizer`` are accepted for the future
    training loop but are not used yet.
    """
    model.train(True)
    print(i)
          
def test(model, device, test_loader):
    model.eval()

In [16]:
root = "ytcelebrity/"
csv_path = "celebrity.csv"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 0.001
epoch = 10
# Data-pipeline settings (formerly inline literals inside main()).
seed = 42                     # seeds the train/test split so it is reproducible
split_fractions = [0.7, 0.3]  # train / test fractions of the dataset
batch_size = 4
num_workers = 4

def main():
    """Build the YTCelebrity dataset and its seeded train/test DataLoaders.

    Returns:
        ``(train_loader, test_loader)`` so later cells can inspect or iterate them.
    """
    dataset = YTCelebrityDataset(root, csv_path)

    generator = torch.Generator().manual_seed(seed)
    training, testing = torch.utils.data.random_split(
        dataset, split_fractions, generator=generator
    )

    # NOTE(review): each item is (list of frames, list of label strings). If clips
    # have different frame counts, the default collate_fn raises once the loaders
    # are iterated with batch_size > 1, so a custom collate_fn is likely needed.
    # NOTE(review): num_workers > 0 with classes defined in a notebook can fail on
    # platforms that spawn worker processes (Windows/macOS). Confirm, or use 0.
    train_loader = DataLoader(training, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Evaluation does not need a random order; a fixed order keeps runs comparable.
    test_loader = DataLoader(testing, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # TODO: enable once Model / train / test are implemented.
    # model = Model().to(device)
    # optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # train(model, device, train_loader, optimizer, epoch)

    # test(model, device, test_loader)

    return train_loader, test_loader

train_loader, test_loader = main()