In [2]:
import os
import shutil
import torch
import h5py
import numpy as np
from torch import nn, optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision import models
from torchvision.utils import save_image
from torchvision.datasets import ImageFolder

# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

  from ._conv import register_converters as _register_converters


In [3]:
# HDF5 feature dataset: joins the features extracted by several models.
class h5Dataset(Dataset):
    """Map-style dataset over pre-extracted features stored in HDF5 files.

    Each file is read through its ``'data'`` dataset (the features); the
    ``'label'`` dataset is read from the first file only. The feature arrays
    of all files are concatenated along dimension 1, so sample ``i`` is the
    side-by-side join of row ``i`` of every file.
    """

    def __init__(self, file_list):
        """Load and join the features of every file in ``file_list``.

        Args:
            file_list: paths of the HDF5 files whose features are joined.

        Raises:
            ValueError: if ``file_list`` is empty (there would be no labels
                and therefore no usable dataset).
        """
        file_list = list(file_list)
        if not file_list:
            raise ValueError('file_list must contain at least one HDF5 file path')

        feature_blocks = []
        for position, path in enumerate(file_list):
            with h5py.File(path, 'r') as h5_file:
                # ``Dataset.value`` was removed in h5py 3.0; indexing with
                # ``[()]`` reads the whole dataset into a numpy array.
                if position == 0:
                    # Labels come from the first file only.
                    self.label = torch.from_numpy(h5_file['label'][()])
                feature_blocks.append(torch.from_numpy(h5_file['data'][()]))

        # One concatenation along the feature dimension instead of re-copying
        # the accumulated tensor once per file.
        self.data = torch.cat(feature_blocks, 1)

    def __len__(self):
        """Return the number of samples as a plain Python int."""
        # label.shape is a torch.Size; [0] extracts the sample count.
        return self.label.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        assert index < len(self), 'index range error'
        data = self.data[index]
        label = self.label[index]
        return (data, label)

In [4]:
# Build the training dataset  x:[22500,1024] y:[22500]
train_h5_list = [
    os.path.join('./feature_h5', feature_file)
    for feature_file in ('resnet18_train_feature.hd5f', 'squeeze_train_feature.hd5f')
]
train_h5 = h5Dataset(train_h5_list)
print('Train numbers:', len(train_h5))

# Inspect the first (features, label) pair as a sanity check.
sample_data, sample_label = train_h5[0]
print('Sample:', sample_data, sample_label)
print(sample_data.shape, sample_label.shape)

Train numbers: 22500
Sample: tensor([0.6055, 0.1336, 0.0773,  ..., 0.0970, 2.7742, 0.3208]) tensor(0)
torch.Size([1024]) torch.Size([])


In [5]:
# Build the validation dataset   x:[2500,1024] y:[2500]
val_h5_list = [
    os.path.join('./feature_h5', feature_file)
    for feature_file in ('resnet18_val_feature.hd5f', 'squeeze_val_feature.hd5f')
]
val_h5 = h5Dataset(val_h5_list)
print('Val numbers:', len(val_h5))

Val numbers: 2500


In [18]:
# Batch size used for training the classifier (differs from the batch size
# used at the beginning).
train_batch = 10
torch.multiprocessing.freeze_support()
# The training loader is currently disabled:
# train_h5_loader = DataLoader(train_h5, batch_size=train_batch, shuffle=True) # , num_workers=4
val_h5_loader = DataLoader(
    dataset=val_h5,
    batch_size=train_batch,
    shuffle=False,
)

In [19]:
# Explicit iterator over the validation loader.
it = iter(val_h5_loader)

In [None]:
# Drain the iterator, printing the feature and label shapes of every batch.
for x, y in it:
    print(x.shape, y.shape)

In [9]:
# MNIST training images, converted to tensors only; the
# transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]) step is left disabled.
data_tf = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    root='../../basic/data/MNIST',
    train=True,
    transform=data_tf,
    download=False,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=10,
    shuffle=True,
    num_workers=4,  # pin_memory=True is left disabled
)


In [10]:
# Validation loader again, this time backed by four worker processes.
val_h5_loader = DataLoader(
    dataset=val_h5,
    batch_size=train_batch,
    shuffle=False,
    num_workers=4,
)

In [11]:
# Look at the first training batch only: print the minimum of its first image.
for x, y in train_loader:
    print(x[0].min())
    break

tensor(0.)


In [12]:
class h5Dataset2(Dataset):
    """Map-style dataset over pre-extracted features stored in HDF5 files.

    The ``'label'`` dataset is read from the first file only; the ``'data'``
    dataset of every file is read and the results are concatenated along
    dimension 1.
    """

    def __init__(self, h5py_list):
        """Load labels from the first file and join the features of all files.

        Args:
            h5py_list: paths of the HDF5 files whose features are joined.

        Raises:
            IndexError: if ``h5py_list`` is empty.
        """
        if len(h5py_list) == 0:
            raise IndexError('h5py_list must contain at least one HDF5 file path')

        feature_blocks = []
        for position, path in enumerate(h5py_list):
            # The context manager closes each file once its arrays are in
            # memory, instead of leaving the handle open.
            with h5py.File(path, 'r') as h5_file:
                # ``Dataset.value`` was removed in h5py 3.0; indexing with
                # ``[()]`` reads the whole dataset into a numpy array.
                if position == 0:
                    self.label = torch.from_numpy(h5_file['label'][()])
                feature_blocks.append(torch.from_numpy(h5_file['data'][()]))

        self.nSamples = self.label.size(0)
        # Single concatenation along the feature dimension.
        self.dataset = torch.cat(feature_blocks, 1)

    def __len__(self):
        """Return the number of samples (length of the label tensor)."""
        return self.nSamples

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        assert index < len(self), 'index range error'
        data = self.dataset[index]
        label = self.label[index]
        return (data, label)

In [13]:
# Validation set from the ResNet-18 features only; the SqueezeNet entry
# os.path.join('./feature_h5', 'squeeze_val_feature.hd5f') is commented out.
resnet18_val_path = os.path.join('./feature_h5', 'resnet18_val_feature.hd5f')
val_h5_list = [resnet18_val_path]
val_h5 = h5Dataset2(val_h5_list)
print('Val numbers:', len(val_h5))

Val numbers: 2500


In [14]:
# One sample per batch, in dataset order.
val_h5_loader = DataLoader(
    dataset=val_h5,
    batch_size=1,
    shuffle=False,
)

In [15]:
# Explicit iterator over the single-sample validation loader.
it = iter(val_h5_loader)

In [16]:
# Drain the iterator, printing the shape of each batch's feature tensor.
for batch in it:
    features = batch[0]
    print(features.shape)
    

torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1

torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1

torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1

torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1

torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1

torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1

In [17]:
# x, y = next(it)