In [83]:
# https://github.com/wenwei202/pytorch-examples/blob/autogrow/cifar10/get_mean_std.py

from collections import defaultdict
from PIL import Image
from torch.utils.data import Dataset, DataLoader

import json
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torchvision.transforms as transforms

In [84]:
class TestDataset(Dataset):
    """Dataset of video directories from the ``unlabeled`` and ``train`` splits.

    Every item corresponds to one video directory. The frames
    ``image_0.png`` .. ``image_{num_frames - 1}.png`` inside it are opened
    with PIL, passed through ``self.transform`` and stacked along a new
    leading dimension.

    Args:
        data_dir: Root directory that contains the ``unlabeled`` and ``train``
            sub-directories. A trailing slash is optional.
        num_frames: Number of frames read from each video directory.
    """

    # Sub-directories of ``data_dir`` that are scanned for video directories.
    SPLITS = ("unlabeled", "train")

    # Had issues with these files locally, seem corrupted.
    # Entries are relative to ``data_dir``; the commented-out ones are
    # currently left in the dataset.
    CORRUPTED_VIDEOS = (
        # os.path.join("unlabeled", "video_14879"),
        # os.path.join("unlabeled", "video_3110"),
        os.path.join("unlabeled", "video_3768"),
        os.path.join("unlabeled", "video_3776"),
        # os.path.join("unlabeled", "video_6751"),
        # os.path.join("unlabeled", "video_6814"),
    )

    def __init__(self, data_dir = './data/', num_frames = 22):
        self.data_dir = data_dir
        self.num_frames = num_frames

        # os.path.join (instead of string concatenation) keeps the paths valid
        # whether or not data_dir ends with a separator.
        self.video_paths = []
        for split in self.SPLITS:
            # self.path ends up pointing at the last split, as before.
            self.path = os.path.join(self.data_dir, split)
            self.video_paths += [
                os.path.join(self.path, v)
                for v in os.listdir(self.path)
                if os.path.isdir(os.path.join(self.path, v))
            ]
        self.video_paths.sort()

        # Drop the known-bad videos. Filtering (rather than list.remove) means
        # a data_dir that does not contain them no longer raises ValueError.
        corrupted = {os.path.join(self.data_dir, rel) for rel in self.CORRUPTED_VIDEOS}
        self.video_paths = [p for p in self.video_paths if p not in corrupted]

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        """Return the number of video directories in the dataset."""
        return len(self.video_paths)

    def __getitem__(self, index):
        """Load all frames of the video at ``index``.

        Returns:
            A tuple ``(image_tensor, video_path)`` where ``image_tensor`` is
            the ``num_frames`` transformed frames stacked on dim 0 and
            ``video_path`` is the directory they were read from.
        """
        video_path = self.video_paths[index]

        images = []
        for idx in range(self.num_frames):
            img_path = os.path.join(video_path, f"image_{idx}.png")
            # Context manager releases the file handle as soon as the frame
            # has been converted.
            with Image.open(img_path) as img:
                images.append(self.transform(img))
        image_tensor = torch.stack(images, dim = 0)

        return image_tensor, video_path

In [85]:
def accumulate_channel_sums(loader, device, label, center=None):
    """Accumulate per-channel sums over every batch yielded by ``loader``.

    Args:
        loader: Iterable of ``(inputs, paths)`` pairs. ``inputs`` must end in
            ``(C, H, W)``; all leading dimensions (batch, frames, ...) are
            flattened together, so any batch size works.
        device: Device the batches are moved to before reducing.
        label: Word used in the progress message (``"mean"`` or ``"std"``).
        center: Optional tensor broadcastable to ``(N, C, H, W)``. When given,
            the squared deviation ``(x - center) ** 2`` is summed instead of
            ``x`` itself.

    Returns:
        ``(total, count)`` where ``total`` is a float64 tensor of shape
        ``(1, C, 1, 1)`` and ``count`` is the number of values that were
        summed per channel.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    total = None
    count = 0
    for batch_idx, (inputs, _paths) in enumerate(loader):

        if (batch_idx+1) % 100 == 0:
            print(f"Completed {label} for {(batch_idx+1)} videos!")

        # Collapse only the leading dims. A bare squeeze() would also drop a
        # size-1 channel axis and break the (0, 2, 3) reduction below.
        frames = inputs.to(device).reshape(-1, *inputs.shape[-3:])
        if batch_idx == 0 and center is None:
            # Sanity check of the value range on the first raw batch.
            print(frames.min(), frames.max())

        # Accumulate in float64: float32 running sums over billions of pixels
        # lose low-order bits.
        frames = frames.to(torch.float64)
        if center is not None:
            frames = (frames - center).pow(2)
        batch_total = frames.sum(dim=(0, 2, 3), keepdim=True)
        total = batch_total if total is None else total + batch_total
        # Count what was actually reduced instead of assuming a fixed
        # frames-per-video and image size.
        count += frames.size(0) * frames.size(2) * frames.size(3)

    if total is None:
        raise ValueError("loader yielded no batches")
    return total, count


# Guarded so that importing this code (e.g. from a test) does not start a full
# pass over the dataset; inside the notebook __name__ is "__main__", so the
# cell still runs when executed.
if __name__ == "__main__":
    trainset = TestDataset()

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, num_workers=2)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Pass 1: per-channel mean.
    chsum, n_values = accumulate_channel_sums(trainloader, device, "mean")
    mean64 = chsum / n_values
    # Exposed as float32 with shape (1, C, 1, 1), same as before.
    mean = mean64.to(torch.float32)
    print('mean: %s' % mean.view(-1))

    # Pass 2: per-channel unbiased standard deviation around that mean.
    chsum, _ = accumulate_channel_sums(trainloader, device, "std", center=mean64)
    std = torch.sqrt(chsum / (n_values - 1)).to(torch.float32)
    print('std: %s' % std.view(-1))

    print('Done!')

tensor(0.) tensor(1.)
Completed mean for 100 videos!
Completed mean for 200 videos!
Completed mean for 300 videos!
Completed mean for 400 videos!
Completed mean for 500 videos!
Completed mean for 600 videos!
Completed mean for 700 videos!
Completed mean for 800 videos!
Completed mean for 900 videos!
Completed mean for 1000 videos!
Completed mean for 1100 videos!
Completed mean for 1200 videos!
Completed mean for 1300 videos!
Completed mean for 1400 videos!
Completed mean for 1500 videos!
Completed mean for 1600 videos!
Completed mean for 1700 videos!
Completed mean for 1800 videos!
Completed mean for 1900 videos!
Completed mean for 2000 videos!
Completed mean for 2100 videos!
Completed mean for 2200 videos!
Completed mean for 2300 videos!
Completed mean for 2400 videos!
Completed mean for 2500 videos!
Completed mean for 2600 videos!
Completed mean for 2700 videos!
Completed mean for 2800 videos!
Completed mean for 2900 videos!
Completed mean for 3000 videos!
Completed mean for 3100 vid

Completed std for 11900 videos!
Completed std for 12000 videos!
Completed std for 12100 videos!
Completed std for 12200 videos!
Completed std for 12300 videos!
Completed std for 12400 videos!
Completed std for 12500 videos!
Completed std for 12600 videos!
Completed std for 12700 videos!
Completed std for 12800 videos!
Completed std for 12900 videos!
Completed std for 13000 videos!
Completed std for 13100 videos!
Completed std for 13200 videos!
Completed std for 13300 videos!
Completed std for 13400 videos!
Completed std for 13500 videos!
Completed std for 13600 videos!
Completed std for 13700 videos!
Completed std for 13800 videos!
Completed std for 13900 videos!
std: tensor([0.0571, 0.0567, 0.0614])
Done!
