-
Notifications
You must be signed in to change notification settings - Fork 340
/
sample.py
115 lines (95 loc) · 4.64 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details
# src/utils/sample.py
import numpy as np
import random
from numpy import linalg
from math import sin,cos,sqrt
from utils.losses import latent_optimise
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
def sample_latents(dist, batch_size, dim, truncated_factor=1, num_classes=None, perturb=None, device=torch.device("cpu"), sampler="default"):
    """Sample generator latent codes and, optionally, fake class labels.

    Args:
        dist: latent distribution, one of "gaussian", "uniform", "hyper_sphere".
        batch_size: number of latents to draw. Overridden to ``num_classes * 8``
            when ``sampler == "class_order_all"``.
        dim: dimensionality of each latent vector.
        truncated_factor: divisor applied to the Gaussian latents only.
        num_classes: number of classes; when falsy no labels are sampled and
            ``y_fake`` is None.
        perturb: when a positive float, a perturbed copy of the latents is
            also returned (``latents + perturb * noise`` for gaussian/uniform).
        device: device of the returned tensors.
        sampler: label strategy: "default" (uniformly random labels),
            "class_order_some" (``batch_size // 8`` random classes, 8
            consecutive samples each), "class_order_all" (every class, 8
            consecutive samples each), or an int (every label equals it).

    Returns:
        ``(latents, y_fake, latents_eps)`` when ``perturb`` is a positive
        float, otherwise ``(latents, y_fake)``.

    Raises:
        NotImplementedError: for an unknown ``sampler`` or ``dist``.
    """
    # ---- fake class labels ----
    y_fake = None
    if num_classes:
        if sampler == "default":
            y_fake = torch.randint(low=0, high=num_classes, size=(batch_size,), dtype=torch.long, device=device)
        elif sampler in ("class_order_some", "class_order_all"):
            if sampler == "class_order_some":
                assert batch_size % 8 == 0, "The size of the batches should be a multiple of 8."
                # NOTE(review): if batch_size // 8 > num_classes, fewer than
                # batch_size labels are produced.
                indices = np.random.permutation(num_classes)[:batch_size // 8]
            else:
                # Every class is shown, so the batch is forced to 8 samples per class.
                batch_size = num_classes * 8
                indices = range(num_classes)
            # Each chosen class fills 8 consecutive slots: [c0]*8 + [c1]*8 + ...
            y_fake = torch.tensor([int(idx) for idx in indices for _ in range(8)], dtype=torch.long).to(device)
        elif isinstance(sampler, int):
            y_fake = torch.tensor([sampler] * batch_size, dtype=torch.long).to(device)
        else:
            raise NotImplementedError

    # ---- latent codes ----
    perturbed = isinstance(perturb, float) and perturb > 0.0
    latents_eps = None
    if dist == "gaussian":
        latents = torch.randn(batch_size, dim, device=device) / truncated_factor
        if perturbed:
            latents_eps = latents + perturb * torch.randn(batch_size, dim, device=device)
    elif dist == "uniform":
        latents = torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device)
        if perturbed:
            latents_eps = latents + perturb * torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device)
    elif dist == "hyper_sphere":
        # random_ball returns NumPy arrays; convert them to float32 tensors on `device`.
        if perturbed:
            latents, latents_eps = random_ball(batch_size, dim, perturb=perturb)
            latents_eps = torch.as_tensor(latents_eps, dtype=torch.float, device=device)
        else:
            # perturb=False guarantees a single array back, even if the caller
            # passed a truthy non-float perturb.
            latents = random_ball(batch_size, dim, perturb=False)
        latents = torch.as_tensor(latents, dtype=torch.float, device=device)
    else:
        raise NotImplementedError(f"Unknown latent distribution: {dist}")

    if perturbed:
        return latents, y_fake, latents_eps
    return latents, y_fake
def random_ball(batch_size, z_dim, perturb=False):
    """Draw points uniformly from the unit ``z_dim``-ball.

    Each direction is a normalized standard-normal vector and each radius is
    ``U ** (1 / z_dim)`` with ``U ~ Uniform[0, 1)``, which spreads the points
    uniformly over the ball's volume.

    Args:
        batch_size: number of points to draw.
        z_dim: dimensionality of each point.
        perturb: when truthy, also return a second batch whose directions are
            the same Gaussian draws plus ``0.05``-scaled Gaussian noise, each
            with an independently drawn radius. Only truthiness is used.

    Returns:
        NumPy array ``zs`` of shape ``(batch_size, z_dim)``, or the pair
        ``(zs, zs_perturb)`` of such arrays when ``perturb`` is truthy.
    """
    normal = np.random.normal(size=(z_dim, batch_size))
    zs = _scale_to_ball(normal, batch_size, z_dim)
    if not perturb:
        return zs
    normal_perturb = normal + 0.05 * np.random.normal(size=(z_dim, batch_size))
    zs_perturb = _scale_to_ball(normal_perturb, batch_size, z_dim)
    return zs, zs_perturb


def _scale_to_ball(normal, batch_size, z_dim):
    """Map column-wise Gaussian draws of shape (z_dim, batch_size) to unit-ball points of shape (batch_size, z_dim)."""
    directions = normal / linalg.norm(normal, axis=0)
    # NumPy's random (not stdlib random.random, which takes no size argument).
    radii = np.random.random(batch_size) ** (1 / z_dim)
    return 1.0 * (directions * radii).T
# Convenience helper: samples integer class indices (not a one-hot encoding).
def sample_1hot(batch_size, num_classes, device='cuda'):
    """Return a ``(batch_size,)`` int64 tensor of class indices drawn uniformly from ``[0, num_classes)``."""
    index_shape = (batch_size,)
    return torch.randint(
        0,
        num_classes,
        index_shape,
        dtype=torch.int64,
        device=device,
        requires_grad=False,
    )
def make_mask(labels, n_cls, device):
    """Build a per-class membership mask for a batch of labels.

    Args:
        labels: tensor of class labels; presumably 1-D of shape
            ``(n_samples,)`` (TODO confirm with callers).
        n_cls: number of classes, i.e. rows of the mask.
        device: device of the returned mask.

    Returns:
        ``torch.long`` tensor of shape ``(n_cls, n_samples)`` whose entry
        ``[c, i]`` is 1 when ``labels[i] == c`` and 0 otherwise. A label
        outside ``[0, n_cls)`` yields an all-zero column.
    """
    labels = labels.detach()
    classes = torch.arange(n_cls, device=labels.device).unsqueeze(1)
    # Broadcast (n_cls, 1) against (1, n_samples): every class vs. every label at once.
    mask_multi = (classes == labels.unsqueeze(0)).long()
    return mask_multi.to(device)
def target_class_sampler(dataset, target_class):
    """Build a sampler over the indices of ``dataset`` whose label equals ``target_class``.

    Labels are read from ``dataset.data.targets`` when that attribute chain
    exists, otherwise from ``dataset.labels``.

    Returns:
        ``(num_samples, sampler)``: the number of matching indices and a
        ``WeightedRandomSampler`` that yields each matching index exactly once
        per pass, in random order.

    Raises:
        ValueError: if no sample has label ``target_class``.
    """
    try:
        targets = dataset.data.targets
    except AttributeError:
        # The dataset does not wrap an inner ``data`` object exposing ``targets``.
        targets = dataset.labels
    weights = [bool(target == target_class) for target in targets]
    num_samples = sum(weights)
    if num_samples == 0:
        raise ValueError(f"No samples with target class {target_class} found in the dataset.")
    # Non-target indices get weight 0. Drawing exactly num_samples times without
    # replacement therefore visits each target index once. Requesting
    # len(weights) draws would fail, because there are not enough non-zero weights.
    sampler = torch.utils.data.sampler.WeightedRandomSampler(
        torch.tensor(weights, dtype=torch.double), num_samples, replacement=False)
    return num_samples, sampler