In [None]:
import numpy as np
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.nn.init as init
import torch.utils.data as data
import torch.utils.data.dataset as dataset
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as v_utils
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import cv2
import math
from collections import OrderedDict
import copy
import time
from model.utils import DataLoader,VideoDataLoader
from model.base_model import *
from sklearn.metrics import roc_auc_score
from utils import *
import random
from tqdm import tqdm
import argparse
import warnings
import numpy as np
from collections import OrderedDict
import os
import glob
import cv2
import torch.utils.data as data
import random
import pickle
warnings.filterwarnings("ignore") 

def np_load_frame(filename, resize_height, resize_width):
    """
    Load an image from disk, resize it and rescale it to float32 in [-1, 1].

    The image is read with ``cv2.imread`` using its default flags, so the
    channel order is whatever OpenCV returns (BGR for colour images). Pixel
    values are mapped from [0, 255] to [-1, 1] via ``x / 127.5 - 1``.

    :param filename: the full path of image
    :param resize_height: resized height
    :param resize_width: resized width
    :return: numpy.ndarray of dtype float32
    :raises IOError: if the image cannot be read or decoded
    """
    image_decoded = cv2.imread(filename)
    if image_decoded is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        raise IOError('could not read image: {!r}'.format(filename))

    # cv2.resize takes the target size as (width, height).
    image_resized = cv2.resize(image_decoded, (resize_width, resize_height))
    image_resized = image_resized.astype(dtype=np.float32)
    image_resized = (image_resized / 127.5) - 1.0
    return image_resized


class VideoDataLoader(data.Dataset):
    """
    Dataset with one item per video folder.

    Each item is built from ``batch_size`` windows of
    ``time_step + num_pred`` consecutive frames. The video is divided into
    ``segs`` equal-length segments; ``batch_size`` distinct segments are drawn
    at random and the same random offset inside the segment is used for all of
    them.

    :param video_folder: directory containing one sub-directory per video
    :param dataset_type: name used to locate the optional frame-list cache
        ``./data/frame_<dataset_type>.pickle``
    :param transform: callable applied to every loaded frame, or None to use
        the raw frame returned by ``np_load_frame``
    :param resize_height: height every frame is resized to
    :param resize_width: width every frame is resized to
    :param time_step: number of input frames per window
    :param segs: number of segments each video is divided into
    :param num_pred: number of additional frames appended to each window
    :param batch_size: number of windows sampled per item
    """

    def __init__(self, video_folder, dataset_type, transform, resize_height, resize_width, time_step=4, segs=32, num_pred=1, batch_size=1):
        self.dir = video_folder
        self.dataset_type = dataset_type
        self.transform = transform
        self.videos = OrderedDict()
        self.video_names = []
        self._resize_height = resize_height
        self._resize_width = resize_width
        self._time_step = time_step
        self._num_pred = num_pred
        self.num_segs = segs
        self.batch_size = batch_size
        # Run the directory scan last so every attribute exists beforehand.
        self.setup()

    def setup(self):
        """
        Fill ``self.videos`` and ``self.video_names``.

        If ``./data/frame_<dataset_type>.pickle`` exists it is loaded and used
        as the video table; otherwise ``self.dir`` is scanned and, for every
        entry, the sorted list of ``*.jpg`` files is recorded.
        """
        cache_path = './data/frame_' + self.dataset_type + '.pickle'

        if os.path.exists(cache_path):
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # cache files produced by a trusted run of this project.
            with open(cache_path, 'rb') as cache_file:
                self.videos = pickle.load(cache_file)
            self.video_names.extend(self.videos)
            return

        for video in sorted(glob.glob(os.path.join(self.dir, '*'))):
            video_name = os.path.basename(video)
            frame_paths = sorted(glob.glob(os.path.join(video, '*.jpg')))
            self.video_names.append(video_name)
            self.videos[video_name] = {
                'path': video,
                'frame': frame_paths,
                'length': len(frame_paths),
            }

    def get_all_samples(self):
        """
        List, per video found in ``self.dir``, every frame path except the
        last ``time_step`` ones.

        :return: tuple ``(frames, num)`` where ``frames`` maps video name to a
            list of frame paths and ``num`` is the total number of paths
        """
        frames = {}
        num = 0
        for video in sorted(glob.glob(os.path.join(self.dir, '*'))):
            video_name = os.path.basename(video)
            frame_paths = self.videos[video_name]['frame']
            # Clamp at zero so a video shorter than time_step yields nothing.
            usable = max(len(frame_paths) - self._time_step, 0)
            frames[video_name] = frame_paths[:usable]
            num += usable

        return frames, num

    def _sample_window_starts(self, video_name):
        """Draw the start frame index of each of the ``batch_size`` windows."""
        window = self._time_step + self._num_pred
        # Number of start positions for which a full window still fits.
        usable = self.videos[video_name]['length'] - window + 1
        seg_len = usable // self.num_segs
        if seg_len < 1:
            raise ValueError(
                'video {!r} has too few frames for {} segments of windows of {} frames'.format(
                    video_name, self.num_segs, window))

        # Keep the two random.sample calls in this order so seeded runs draw
        # the same values as before.
        seg_ind = random.sample(range(0, self.num_segs), self.batch_size)
        offset = random.sample(range(0, seg_len), 1)[0]
        return [seg * seg_len + offset for seg in seg_ind]

    def __getitem__(self, index):
        """
        Load the randomly sampled windows of the ``index``-th video.

        :param index: position of the video in ``self.video_names``
        :return: ``np.concatenate`` along axis 0 of all
            ``batch_size * (time_step + num_pred)`` (optionally transformed)
            frames
        :raises ValueError: if the video is too short to be split into
            ``num_segs`` segments, or ``batch_size`` exceeds ``num_segs``
        """
        video_name = self.video_names[index]
        frame_paths = self.videos[video_name]['frame']
        window = self._time_step + self._num_pred

        batch = []
        for start in self._sample_window_starts(video_name):
            for i in range(window):
                image = np_load_frame(frame_paths[start + i], self._resize_height, self._resize_width)
                # Without a transform the raw frame is used, so the batch is
                # never empty when it reaches np.concatenate.
                batch.append(image if self.transform is None else self.transform(image))
        return np.concatenate(batch, axis=0)

    def __len__(self):
        """Return the number of videos."""
        return len(self.video_names)
    
# Build the training dataset and its loader from the extracted frames.
video_folder = 'data/shanghai/training/frames'
# NOTE(review): dataset_type was never defined in this cell; 'shanghai' is
# inferred from the data path above. It selects the optional cache file
# ./data/frame_<dataset_type>.pickle, so confirm it matches the cache name.
dataset_type = 'shanghai'
# Pass video_folder: the previous argument name (train_folder) was undefined.
train_dataset = VideoDataLoader(video_folder, dataset_type, transforms.Compose([
             transforms.ToTensor(),
             ]), resize_height=256, resize_width=256, time_step=4, segs=32, batch_size=1)

train_size = len(train_dataset)
# One dataset item per step; each item comes from a single video.
train_batch = data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2, drop_last=True)
print(train_dataset)