Skip to content

Commit 5ef2144

Browse files
committed
update Sim-PE experiments
1 parent 6c2b0e4 commit 5ef2144

File tree

8 files changed

+94
-2
lines changed

8 files changed

+94
-2
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
pe.data.image.imagenet module
2+
=============================
3+
4+
.. automodule:: pe.data.image.imagenet
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:

doc/source/api/pe.data.image.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ Submodules
1818
pe.data.image.cifar10
1919
pe.data.image.digiface1m
2020
pe.data.image.image
21+
pe.data.image.imagenet
2122
pe.data.image.mnist

example/image/simulator/mnist_text_render.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
font_size_variation_degrees=[6, 5, 4, 3, 2],
4848
rotation_degree_variation_degrees=[11, 9, 7, 5, 3],
4949
stroke_width_variation_degrees=[1, 1, 1, 0, 0],
50+
text_variation_degrees=0.0,
5051
)
5152
fld_inception_embedding = FLDInception()
5253
histogram = NearestNeighbors(

pe/api/image/draw_text_api.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
self,
3333
font_root_path,
3434
font_variation_degrees,
35+
text_variation_degrees,
3536
font_size_variation_degrees,
3637
rotation_degree_variation_degrees,
3738
stroke_width_variation_degrees,
@@ -52,6 +53,10 @@ def __init__(
5253
is provided, the same variation degree will be used for all iterations. The value means the probability of
5354
changing the font to a random font.
5455
:type font_variation_degrees: float or list[float]
56+
:param text_variation_degrees: The variation degrees for text utilized at each PE iteration. If a single value
57+
is provided, the same variation degree will be used for all iterations. The value means the probability of
58+
changing the text to a random text.
59+
:type text_variation_degrees: float or list[float]
5560
:param font_size_variation_degrees: The variation degrees for font size utilized at each PE iteration. If a
5661
single value is provided, the same variation degree will be used for all iterations. The value means
5762
the maximum possible variation in font size.
@@ -89,6 +94,7 @@ def __init__(
8994
super().__init__()
9095
self._font_root_path = font_root_path
9196
self._font_variation_degrees = _to_constant_list_if_needed(font_variation_degrees)
97+
self._text_variation_degrees = _to_constant_list_if_needed(text_variation_degrees)
9298
self._font_size_variation_degrees = _to_constant_list_if_needed(font_size_variation_degrees)
9399
self._rotation_degree_variation_degrees = _to_constant_list_if_needed(rotation_degree_variation_degrees)
94100
self._stroke_width_variation_degrees = _to_constant_list_if_needed(stroke_width_variation_degrees)
@@ -223,8 +229,10 @@ def _get_variation_image(
223229
rotation_degree,
224230
font_size_variation_degree,
225231
font_variation_degree,
232+
text_variation_degree,
226233
stroke_width_variation_degree,
227234
rotation_degree_variation_degree,
235+
label_name,
228236
):
229237
"""Get a variation image and its parameters.
230238
@@ -242,16 +250,23 @@ def _get_variation_image(
242250
:type font_size_variation_degree: int
243251
:param font_variation_degree: The degree of variation in font
244252
:type font_variation_degree: float
253+
:param text_variation_degree: The degree of variation in text
254+
:type text_variation_degree: float
245255
:param stroke_width_variation_degree: The degree of variation in stroke width
246256
:type stroke_width_variation_degree: int
247257
:param rotation_degree_variation_degree: The degree of variation in rotation degree
248258
:type rotation_degree_variation_degree: int
259+
:param label_name: The label name
260+
:type label_name: str
249261
:return: The image of the avatar and its parameters
250262
:rtype: tuple[np.ndarray, dict]
251263
"""
252264
do_font_variation = random.random() < font_variation_degree
253265
if do_font_variation:
254266
font_file = random.choice(self._font_files)
267+
do_text_variation = random.random() < text_variation_degree
268+
if do_text_variation:
269+
text = random.choice(self._text_list[label_name])
255270

256271
font_size += random.randint(-font_size_variation_degree, font_size_variation_degree)
257272
font_size = max(min(font_size, max(self._font_size_list)), min(self._font_size_list))
@@ -289,9 +304,11 @@ def variation_api(self, syn_data):
289304
execution_logger.info(f"VARIATION API: creating variations for {len(syn_data.data_frame)} samples")
290305
original_params = list(syn_data.data_frame[TEXT_PARAMS_COLUMN_NAME].values)
291306
original_images = np.stack(syn_data.data_frame[IMAGE_DATA_COLUMN_NAME].values)
307+
original_label_ids = syn_data.data_frame[LABEL_ID_COLUMN_NAME].values
292308
iteration = getattr(syn_data.metadata, "iteration", -1)
293309
font_variation_degree = self._font_variation_degrees[iteration + 1]
294310
font_size_variation_degree = self._font_size_variation_degrees[iteration + 1]
311+
text_variation_degree = self._text_variation_degrees[iteration + 1]
295312
rotation_variation_degree = self._rotation_degree_variation_degrees[iteration + 1]
296313
stroke_width_variation_degree = self._stroke_width_variation_degrees[iteration + 1]
297314

@@ -307,9 +324,11 @@ def variation_api(self, syn_data):
307324
original_param = original_params[i]
308325
image, param = self._get_variation_image(
309326
font_size_variation_degree=font_size_variation_degree,
327+
text_variation_degree=text_variation_degree,
310328
font_variation_degree=font_variation_degree,
311329
rotation_degree_variation_degree=rotation_variation_degree,
312330
stroke_width_variation_degree=stroke_width_variation_degree,
331+
label_name=syn_data.metadata.label_info[int(original_label_ids[i])].name,
313332
**original_param,
314333
)
315334
if image is not None:

pe/api/image/nearest_image_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def __init__(
5858
self._nearest_neighbor_backend = nearest_neighbor_backend
5959
self._variation_degrees = _to_constant_list_if_needed(variation_degrees)
6060
self._max_variation_degree = (
61-
self._variation_degrees[0] if isinstance(variation_degrees, ConstantList) else max(self._variation_degrees)
61+
self._variation_degrees[0]
62+
if isinstance(self._variation_degrees, ConstantList)
63+
else max(self._variation_degrees)
6264
)
6365

6466
if nearest_neighbor_backend.lower() == "faiss":

pe/data/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .data import Data
2-
from .image import load_image_folder, Cifar10, Camelyon17, Cat, MNIST, CelebA, DigiFace1M
2+
from .image import load_image_folder, Cifar10, Camelyon17, Cat, MNIST, CelebA, DigiFace1M, ImageNet
33
from .text import TextCSV, Yelp, PubMed, OpenReview
44

55
__all__ = [
@@ -11,6 +11,7 @@
1111
"MNIST",
1212
"CelebA",
1313
"DigiFace1M",
14+
"ImageNet",
1415
"TextCSV",
1516
"Yelp",
1617
"PubMed",

pe/data/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
from .mnist import MNIST
66
from .celeba import CelebA
77
from .digiface1m import DigiFace1M
8+
from .imagenet import ImageNet

pe/data/image/imagenet.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import pandas as pd
2+
import torchvision.datasets
3+
import torchvision.transforms as T
4+
from tqdm import tqdm
5+
import torch
6+
7+
from pe.data import Data
8+
from pe.constant.data import LABEL_ID_COLUMN_NAME
9+
from pe.constant.data import IMAGE_DATA_COLUMN_NAME
10+
11+
12+
class ImageNet(Data):
13+
"""The ImageNet dataset."""
14+
15+
def __init__(self, root_dir, conditional=False, split="train", res=32, batch_size=1000, num_workers=10):
16+
"""Constructor.
17+
18+
:param root_dir: The root directory of the dataset.
19+
:param conditional: Whether to use conditional ImageNet. Defaults to False
20+
:type conditional: bool, optional
21+
:param split: The split of the dataset, defaults to "train"
22+
:type split: str, optional
23+
:param res: The resolution of the images, defaults to 32
24+
:type res: int, optional
25+
:param batch_size: The batch size to load the images, defaults to 1000
26+
:type batch_size: int, optional
27+
:param num_workers: The number of workers to load the images, defaults to 10
28+
:type num_workers: int, optional
29+
"""
30+
transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.Resize(res), T.PILToTensor()])
31+
dataset = torchvision.datasets.ImageNet(
32+
root=root_dir,
33+
split=split,
34+
transform=transform,
35+
)
36+
data_loader = torch.utils.data.DataLoader(
37+
dataset,
38+
batch_size=batch_size,
39+
shuffle=False,
40+
num_workers=4,
41+
drop_last=False,
42+
)
43+
44+
images = []
45+
for batch in tqdm(data_loader, desc="Loading ImageNet", unit="batch"):
46+
images.append(batch[0])
47+
images = torch.cat(images, dim=0)
48+
images = images.permute(0, 2, 3, 1).numpy()
49+
50+
data_frame = pd.DataFrame(
51+
{
52+
IMAGE_DATA_COLUMN_NAME: list(images),
53+
LABEL_ID_COLUMN_NAME: dataset.targets if conditional else [0] * len(images),
54+
}
55+
)
56+
if conditional:
57+
metadata = {"label_info": [{"name": n} for n in map(str, dataset.classes)]}
58+
else:
59+
metadata = {"label_info": [{"name": "none"}]}
60+
super().__init__(data_frame=data_frame, metadata=metadata)

0 commit comments

Comments
 (0)