/
data_utils.py
121 lines (100 loc) · 4.3 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision
import numpy as np
import copy
from torch.utils.data import Dataset
# Fixed global seed so the random per-class meta/train split in
# build_dataset() is reproducible across runs.
np.random.seed(6)
def build_dataset(dataset, num_meta):
    """Split a CIFAR training set into a balanced meta set and the remaining train set.

    Args:
        dataset: dataset name, 'cifar10' or 'cifar100'.
        num_meta: number of images to hold out per class for the meta set.

    Returns:
        (train_data_meta, train_data, test_dataset) — the meta split holds
        ``num_meta`` randomly chosen images per class, the train split holds
        everything else, and test_dataset is the untouched test set.

    Raises:
        ValueError: if ``dataset`` is not 'cifar10' or 'cifar100'.
    """
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    # Standard CIFAR augmentation: reflect-pad 4 px, random 32x32 crop, random flip.
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                          (4, 4, 4, 4), mode='reflect').squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    if dataset == 'cifar10':
        train_dataset = torchvision.datasets.CIFAR10(root='../cifar-10', train=True, download=False, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR10('../cifar-10', train=False, transform=transform_test)
        num_classes = 10
    elif dataset == 'cifar100':
        train_dataset = torchvision.datasets.CIFAR100(root='../cifar-100', train=True, download=True, transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR100('../cifar-100', train=False, transform=transform_test)
        num_classes = 100
    else:
        # Previously an unknown name fell through and raised NameError on
        # train_dataset below; fail explicitly instead.
        raise ValueError(f"Unsupported dataset: {dataset!r} (expected 'cifar10' or 'cifar100')")
    img_num_list = [num_meta] * num_classes

    # Indices of every training image, grouped by class label.
    data_list_val = {}
    for j in range(num_classes):
        data_list_val[j] = [i for i, label in enumerate(train_dataset.targets) if label == j]

    idx_to_meta = []
    idx_to_train = []
    print(img_num_list)
    for cls_idx, img_id_list in data_list_val.items():
        # Shuffle within the class so the held-out meta images are random.
        np.random.shuffle(img_id_list)
        img_num = img_num_list[int(cls_idx)]
        idx_to_meta.extend(img_id_list[:img_num])
        idx_to_train.extend(img_id_list[img_num:])

    train_data = copy.deepcopy(train_dataset)
    train_data_meta = copy.deepcopy(train_dataset)
    # Meta copy keeps only the held-out indices; train copy keeps the rest.
    # NOTE(review): np.delete converts .targets from a list to an np.ndarray;
    # downstream indexing still works but the container type changes.
    train_data_meta.data = np.delete(train_dataset.data, idx_to_train, axis=0)
    train_data_meta.targets = np.delete(train_dataset.targets, idx_to_train, axis=0)
    train_data.data = np.delete(train_dataset.data, idx_to_meta, axis=0)
    train_data.targets = np.delete(train_dataset.targets, idx_to_meta, axis=0)
    return train_data_meta, train_data, test_dataset
def get_img_num_per_cls(dataset, imb_factor=None, num_meta=None):
    """Compute the per-class image count for a (possibly imbalanced) CIFAR split.

    Args:
        dataset: dataset name, 'cifar10' or 'cifar100'.
        imb_factor: ratio between the smallest and largest class size
            (e.g. 0.01 for 100x imbalance). None means a balanced split.
        num_meta: total number of images already held out for the meta set;
            subtracted from the 50000 CIFAR training images before dividing.

    Returns:
        A list of ``cls_num`` integer image counts, decaying exponentially
        from ``img_max`` (class 0) to roughly ``img_max * imb_factor``
        (last class); all equal to ``img_max`` when ``imb_factor`` is None.

    Raises:
        ValueError: if ``dataset`` is not 'cifar10' or 'cifar100'.
    """
    if dataset == 'cifar10':
        img_max = (50000 - num_meta) / 10
        cls_num = 10
    elif dataset == 'cifar100':
        img_max = (50000 - num_meta) / 100
        cls_num = 100
    else:
        # Previously an unknown name fell through and raised NameError below.
        raise ValueError(f"Unsupported dataset: {dataset!r} (expected 'cifar10' or 'cifar100')")
    if imb_factor is None:
        # Balanced split. int() for consistency with the imbalanced branch,
        # which also returns integer counts.
        return [int(img_max)] * cls_num
    img_num_per_cls = []
    for cls_idx in range(cls_num):
        # Exponential decay: class 0 gets img_max, the last class gets
        # img_max * imb_factor, intermediate classes interpolate geometrically.
        num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0)))
        img_num_per_cls.append(int(num))
    return img_num_per_cls
class new_dataset(Dataset):
    """Wrap a CIFAR-style dataset (anything exposing .data and .targets) with
    fresh transforms.

    When ``train`` is truthy, samples are augmented (reflect-pad 4 px,
    random 32x32 crop, random horizontal flip) before normalization;
    otherwise they are only converted to tensors and normalized.
    ``__getitem__`` also returns the sample's index, which callers can use
    to track per-example state across epochs.
    """

    def __init__(self, dataset, train=None):
        self.data = dataset.data
        self.targets = dataset.targets
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        if train:
            steps = [
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4), mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        else:
            steps = [transforms.ToTensor(), normalize]
        self.transform = transforms.Compose(steps)

    def __getitem__(self, index):
        """Return (transformed image, label as a shape-(1,) LongTensor, index)."""
        raw_img = self.data[index, ::]
        target = self.targets[index]
        img = self.transform(raw_img)
        label = torch.LongTensor([np.int64(target)])
        return img, label, index

    def __len__(self):
        return len(self.data)