# In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset
import os
import csv
from torchvision import transforms
# import custom_transforms as tr
from PIL import Image, ImageFile
from glob import glob
from sklearn.utils import shuffle
import pandas as pd
import json
import re

# In [30]:
class Segment_dataloader:
    """Build train/val/test index lists over time-stamped label data.

    ``self.data`` maps a key to a dict of per-sample information and
    ``self.key`` is the sorted list of keys this split iterates over
    (``__len__`` / ``__getitem__`` index into it).

    NOTE(review): this class was unfinished in the original notebook.
    ``load_label`` keys ``self.data`` by time strings such as ``'03072732'``,
    while ``format_data`` treats the same keys as image paths of the form
    ``'<YYYYMMDD>/...'``. The two conventions do not agree -- confirm which
    one is intended before training on this.
    """

    # Position in the sorted train/val keys where the validation split starts.
    # NOTE(review): hard-coded; the ``rate`` argument is not used to derive it.
    _VAL_START = 54800
    # Sorted keys in [_GAP_START, _GAP_END) are excluded from training.
    # NOTE(review): presumably a discontinuity in the recording -- confirm.
    _GAP_START = 38350
    _GAP_END = 38370
    # Label value whose runs are trimmed in test mode.
    _TEST_STOP_LABEL = 14

    def __init__(self, mode="train", length=400, rate=0.8):
        """
        Load the data for ``mode`` and build ``self.key``.

        ``self.endlength`` (``length // 4``) keys are dropped before every
        boundary: before the training gap, at the end of each split, and in
        test mode before the start of each run of label 14. Presumably this
        keeps a window of that many samples starting at any kept key inside
        one contiguous segment -- confirm.

        Parameters
        ----------
        mode : str
            One of "train", "val", "test".
        length : int
            Data length; ``self.endlength`` is ``length // 4``.
        rate : float
            Intended train/val ratio. NOTE(review): currently unused -- the
            split point is the fixed ``_VAL_START``.

        Raises
        ------
        ValueError
            If ``mode`` is not "train", "val" or "test".
        """
        if mode not in ("train", "val", "test"):
            raise ValueError(
                f"mode must be 'train', 'val' or 'test', got {mode!r}")
        self.split = mode  # read by __getitem__ to choose the transform
        self.endlength = length // 4
        self.data = {}
        if mode == "train":
            self.load_label()
            keys = sorted(self.data)[:self._VAL_START]
            self.key = (keys[:max(self._GAP_START - self.endlength, 0)]
                        + self._drop_tail(keys[self._GAP_END:], self.endlength))
        elif mode == "val":
            self.load_label()
            keys = sorted(self.data)[self._VAL_START:]
            self.key = self._drop_tail(keys, self.endlength)
        else:  # "test"
            with open('../operationIwakaInputData15.json') as f:
                self.data = json.load(f)
            self.key = self._trim_test_keys(self.data, self.endlength)

    @staticmethod
    def _drop_tail(seq, n):
        """Return ``seq`` without its last ``n`` items.

        Unlike ``seq[:-n]``, this keeps the whole sequence when ``n == 0``.
        """
        return seq[:max(len(seq) - n, 0)]

    @classmethod
    def _trim_test_keys(cls, data, endlength, stop_label=_TEST_STOP_LABEL):
        """Return the sorted keys of ``data`` whose label is not ``stop_label``.

        Keys are dropped in two places. The first ``stop_label`` key after a
        run of other labels drops the last ``endlength`` keys kept so far. The
        last ``endlength`` keys of the result are also dropped.
        """
        kept = []
        in_stop_run = False
        for k in sorted(data):
            if data[k]["label"] != stop_label:
                in_stop_run = False
                kept.append(k)
            elif not in_stop_run:
                # First stop-label key after a normal run: cut that run's tail.
                in_stop_run = True
                kept = cls._drop_tail(kept, endlength)
        return cls._drop_tail(kept, endlength)

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.key)

    def __getitem__(self, index):
        """Return the transformed ``{'image', 'label'}`` sample at ``index``."""
        _img, _target = self.format_data(index)
        sample = {'image': _img, 'label': _target}
        if self.split == "train":
            return self.transform_tr(sample)
        # "val" and "test" share the evaluation transform.
        return self.transform_val(sample)

    def transform_tr(self, sample):
        """Training-time transform; currently an identity placeholder.

        TODO: the original notebook called this without defining it (the
        ``custom_transforms`` import is commented out). Add augmentation here
        or override in a subclass.
        """
        return sample

    def transform_val(self, sample):
        """Validation/test-time transform; currently an identity placeholder."""
        return sample

    def load_label(self):
        """Fill ``self.data`` with labels from the day-03 and day-04 CSV files.

        Each CSV has a header row followed by ``time, label`` rows. Each row
        is stored as ``self.data[format_time(time, day)] = {'label': label}``,
        with the label kept as the raw CSV string.
        """
        sources = (('./data/0903formated_label.csv', '03'),
                   ('./data/0904formated_label.csv', '04'))
        for path, day in sources:
            with open(path, newline='') as f:
                reader = csv.reader(f)
                next(reader, None)  # skip the header row
                for row in reader:
                    self.data[self.format_time(row[0], day)] = {'label': row[1]}

    def format_data(self, index):
        """
        Load the image and its label for ``self.key[index]``.

        Parameters
        ----------
        index : int
            Position in ``self.key``.

        Returns
        -------
        image : numpy.ndarray
            The decoded image.
        label_id
            The value returned by ``get_label``.

        NOTE(review): treats the key as an image path ``'<YYYYMMDD>/...'``
        relative to ``../kataoka/segemention/``; keys produced by
        ``load_label`` are time strings instead -- confirm the key format.
        """
        key = self.key[index]
        # Close the file once the pixels have been copied into the array.
        with Image.open('../kataoka/segemention/' + key) as img:
            image = np.array(img)
        day = key.split('/')[0]
        time = self.get_time(day, key)
        label_id = self.get_label(day, time)
        # NOTE(review): this keeps every decoded image in memory for the
        # lifetime of the dataset.
        self.data[key] = {'image': image, 'label': label_id, 'date': time}
        return image, label_id

    def format_time(self, time, day):
        """
        Convert an ``'H:M:S'`` time into a key of the form ``day + 'HHMMSS'``.

        Parameters
        ----------
        time : str
            e.g. ``'7:5:14'``; components may or may not be zero-padded.
            Components after the third are ignored.
        day : str
            Day prefix, e.g. ``'03'``.

        Returns
        -------
        str
            e.g. ``'03070514'``.

        Raises
        ------
        ValueError
            If ``time`` has fewer than three ``:``-separated parts or a part
            is not an integer.
        """
        parts = time.split(':')
        if len(parts) < 3:
            raise ValueError(f"expected 'H:M:S', got {time!r}")
        # Pad numerically so already-padded input ('07') stays two digits.
        return day + ''.join(f'{int(p):02d}' for p in parts[:3])

    def get_time(self, day, image_path):
        """
        Look up the capture time of ``image_path``.

        Parameters
        ----------
        day : str
            Image date, e.g. ``'20180904'``; selects
            ``data/2018<MMDD>time.csv``.
        image_path : str
            Value matched against the CSV's ``path`` column.

        Returns
        -------
        str
            ``'HHMMSS'``, e.g. ``'072732'``.

        Raises
        ------
        KeyError
            If ``image_path`` is not in the CSV.
        """
        day = day.split('2018')[1]
        csv_path = 'data/2018' + day + 'time.csv'
        time_pd = pd.read_csv(csv_path)
        matches = time_pd[time_pd['path'] == image_path]
        if matches.empty:
            raise KeyError(f"{image_path!r} not found in {csv_path}")
        # Positional access: the matching row is rarely at index label 0.
        row = matches.iloc[0]
        return (f"{int(row['hour']):02d}{int(row['minute']):02d}"
                f"{int(row['second']):02d}")

    def get_label(self, day, time):
        """
        Look up the ground-truth label at ``time`` on ``day``.

        Parameters
        ----------
        day : str
            Image date, e.g. ``'20180904'``; selects
            ``data/<MMDD>formated_label.csv``.
        time : str
            ``'HHMMSS'``, e.g. ``'072732'``.

        Returns
        -------
        The first matching value of the CSV's ``label`` column.

        Raises
        ------
        KeyError
            If no row has the matching time.
        """
        day = day.split('2018')[1]
        # The label CSV stores unpadded times such as '7:27:32'.
        label_time = f"{int(time[0:2])}:{int(time[2:4])}:{int(time[4:6])}"
        csv_path = 'data/' + day + 'formated_label.csv'
        label_pd = pd.read_csv(csv_path)
        matches = label_pd[label_pd['time'] == label_time]
        if matches.empty:
            raise KeyError(f"time {label_time!r} not found in {csv_path}")
        return matches['label'].iloc[0]
        