-
Notifications
You must be signed in to change notification settings - Fork 0
/
rng.py
259 lines (210 loc) · 10.3 KB
/
rng.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import torch
import habana_frameworks.torch.hpu.random as htrandom
from modules import devices, rng_philox, shared, habana
#from habana import HPUGenerator
def randn(seed, shape, generator=None):
    """Generate a tensor with random numbers from a normal distribution using seed.

    Seeds the module-level generator via manual_seed(), then draws the sample.
    The per-source branching (NV / HPU / CPU / default) previously duplicated
    randn_without_seed() line for line; delegate to it instead so the dispatch
    logic lives in exactly one place.

    seed: integer-convertible seed for the global generator.
    shape: shape of the tensor to generate.
    generator: optional explicit generator; used by the NV and torch paths,
        ignored by the HPU path (which always reads the global hpu_rng).
    """
    manual_seed(seed)
    return randn_without_seed(shape, generator=generator)
def randn_local(seed, shape):
    """Generate a tensor with random numbers from a normal distribution using seed.

    Unlike randn(), this does not touch the module-level generator state; each
    call builds its own throwaway generator from `seed`.
    """
    source = shared.opts.randn_source
    if source == "NV":
        return torch.asarray(rng_philox.Generator(seed).randn(shape), device=devices.device)
    # CPU is used both when explicitly requested and on MPS, where on-device
    # seeded generation is not reliable.
    use_cpu = source == "CPU" or devices.device.type == 'mps'
    target_device = devices.cpu if use_cpu else devices.device
    if source == "HPU":
        gen = habana.HPUGenerator().manual_seed(seed)
    else:
        gen = torch.Generator(target_device).manual_seed(int(seed))
    return torch.randn(shape, device=target_device, generator=gen).to(devices.device)
def randn_like(x):
    """Generate a tensor with random numbers from a normal distribution, matching
    the shape, dtype and device of `x`, using the previously initialized generator.

    Use either randn() or manual_seed() to initialize the generator first.
    """
    if shared.opts.randn_source == "NV":
        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
    elif shared.opts.randn_source == "HPU":
        # Bug fix: the old call `torch.randn_like(x, generator=hpu_rng)` was
        # doubly broken -- torch.randn_like() accepts no `generator` keyword,
        # and hpu_rng is a rng_philox.Generator (see manual_seed), not a
        # torch.Generator. Draw from the philox generator like the NV branch.
        return torch.asarray(hpu_rng.randn(x.shape), device=x.device, dtype=x.dtype)
    elif shared.opts.randn_source == "CPU" or x.device.type == 'mps':
        return torch.randn_like(x, device=devices.cpu).to(x.device)
    return torch.randn_like(x)
def randn_without_seed(shape, generator=None):
    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
    Use either randn() or manual_seed() to initialize the generator."""
    source = shared.opts.randn_source
    if source == "NV":
        rng = generator or nv_rng
        return torch.asarray(rng.randn(shape), device=devices.device)
    if source == "HPU":
        return torch.asarray(hpu_rng.randn(shape), device=devices.device)
    if source == "CPU" or devices.device.type == 'mps':
        # Draw on CPU for reproducibility, then move to the target device.
        cpu_noise = torch.randn(shape, device=devices.cpu, generator=generator)
        return cpu_noise.to(devices.device)
    return torch.randn(shape, device=devices.device, generator=generator)
def manual_seed(seed):
    """Set up a global random number generator using the specified seed."""
    global nv_rng, hpu_rng
    seed = int(seed)
    source = shared.opts.randn_source
    if source == "NV":
        nv_rng = rng_philox.Generator(seed)
    elif source == "HPU":
        # HPU currently reuses the CPU-side philox generator for determinism.
        hpu_rng = rng_philox.Generator(seed)
    else:
        torch.manual_seed(seed)
def create_generator(seed):
    """Return a fresh generator object for `seed`, matching shared.opts.randn_source."""
    seed = int(seed)
    source = shared.opts.randn_source
    if source == "NV":
        return rng_philox.Generator(seed)
    if source == "HPU":
        return habana.HPUGenerator().manual_seed(seed)
    # CPU generation is used both when requested and on MPS devices.
    on_cpu = source == "CPU" or devices.device.type == 'mps'
    target_device = devices.cpu if on_cpu else devices.device
    return torch.Generator(device=target_device).manual_seed(seed)
#def randn(seed, shape, generator=None):
# """Generate a tensor with random numbers from a normal distribution using seed.
#
# Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
#
# manual_seed(seed)
#
# if shared.opts.randn_source == "NV":
# return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
#
# if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
# return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
#
# return torch.randn(shape, device=devices.device, generator=generator)
#
#
#def randn_local(seed, shape):
# """Generate a tensor with random numbers from a normal distribution using seed.
#
# Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
#
# if shared.opts.randn_source == "NV":
# rng = rng_philox.Generator(seed)
# return torch.asarray(rng.randn(shape), device=devices.device)
#
# local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
# local_generator = torch.Generator(local_device).manual_seed(int(seed))
# return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
#
#
#def randn_like(x):
# """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
#
# Use either randn() or manual_seed() to initialize the generator."""
#
# if shared.opts.randn_source == "NV":
# return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
#
# if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
# return torch.randn_like(x, device=devices.cpu).to(x.device)
#
# return torch.randn_like(x)
#
#
#def randn_without_seed(shape, generator=None):
# """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
#
# Use either randn() or manual_seed() to initialize the generator."""
#
# if shared.opts.randn_source == "NV":
# return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
#
# if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
# return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
#
# return torch.randn(shape, device=devices.device, generator=generator)
#
#def manual_seed(seed):
# """Set up a global random number generator using the specified seed."""
#
# if shared.opts.randn_source == "NV":
# global nv_rng
# nv_rng = rng_philox.Generator(seed)
# return
#
# torch.manual_seed(seed)
#
#
#def create_generator(seed):
# if shared.opts.randn_source == "NV":
# return rng_philox.Generator(seed)
#
# device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
# generator = torch.Generator(device).manual_seed(int(seed))
# return generator
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherical linear interpolation between tensors `low` and `high`.

    val: interpolation fraction; 0 returns `low`, 1 returns `high`.
    low, high: tensors normalized over dim 1 for the angle computation
        (assumes at least 2-D input with the vector axis at dim 1).

    Falls back to plain linear interpolation when the inputs are nearly
    parallel, where acos() is numerically unstable.
    """
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)
    if dot.mean() > 0.9995:
        # Bug fix: the fallback weights were swapped (low*val + high*(1-val)),
        # so val=0 returned `high` -- inconsistent with the spherical path
        # below, where val=0 yields `low`. Use the standard lerp form.
        return low * (1 - val) + high * val
    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res
class ImageRNG:
    """Produces per-image latent noise for a batch, honoring seeds, subseed
    blending (variation strength) and the seed-resize-from feature.

    The order of RNG draws inside first() is load-bearing: each randn()/
    create_generator() call advances generator state, so reordering calls
    would change generated images for a given seed.
    """

    def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
        # shape: per-image noise shape; assumed (channels, height, width) in
        # latent space -- the //8 scaling in first() suggests pixel->latent
        # conversion. TODO confirm against callers.
        self.shape = tuple(map(int, shape))
        self.seeds = seeds
        self.subseeds = subseeds
        # Blend factor between seed noise and subseed noise (0 disables subseeds).
        self.subseed_strength = subseed_strength
        self.seed_resize_from_h = seed_resize_from_h
        self.seed_resize_from_w = seed_resize_from_w
        # One generator per image in the batch.
        self.generators = [create_generator(seed) for seed in seeds]
        self.is_first = True

    def first(self):
        """Generate the initial noise batch; also applies subseed blending and
        seed-resize cropping/padding. Returns a stacked tensor on shared.device."""
        # When seed-resize is active, noise is generated at the original
        # (resize-from) resolution so the same seed reproduces the same pattern,
        # then center-cropped/padded into the actual shape below.
        noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
        xs = []
        for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
            subnoise = None
            if self.subseeds is not None and self.subseed_strength != 0:
                # Missing subseed entries fall back to seed 0.
                subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
                subnoise = randn(subseed, noise_shape)
            if noise_shape != self.shape:
                # Seed-resize path: seed the noise at the resize-from resolution.
                noise = randn(seed, noise_shape)
            else:
                noise = randn(seed, self.shape, generator=generator)
            if subnoise is not None:
                # Spherically interpolate toward the subseed noise.
                noise = slerp(self.subseed_strength, noise, subnoise)
            if noise_shape != self.shape:
                # Fill a full-size noise tensor, then paste the (possibly
                # cropped) resized noise centered into it. dx/dy are the
                # centering offsets; negative values mean the generated noise
                # is larger than the target and must be cropped.
                x = randn(seed, self.shape, generator=generator)
                dx = (self.shape[2] - noise_shape[2]) // 2
                dy = (self.shape[1] - noise_shape[1]) // 2
                w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
                h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
                tx = 0 if dx < 0 else dx
                ty = 0 if dy < 0 else dy
                dx = max(-dx, 0)
                dy = max(-dy, 0)
                x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
                noise = x
            xs.append(noise)
        # ENSD offsets the seeds used for subsequent sampler noise (next()),
        # matching the historical "eta noise seed delta" behavior.
        eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
        if eta_noise_seed_delta:
            self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
        return torch.stack(xs).to(shared.device)

    def next(self):
        """Return the next noise batch; the first call delegates to first(),
        later calls continue drawing from the per-image generators."""
        if self.is_first:
            self.is_first = False
            return self.first()
        xs = []
        for generator in self.generators:
            x = randn_without_seed(self.shape, generator=generator)
            xs.append(x)
        return torch.stack(xs).to(shared.device)
# Rebind the RNG hooks on the devices module so callers that go through
# devices.* pick up this module's source-aware implementations.
for _hook_name, _hook_impl in (
    ("randn", randn),
    ("randn_local", randn_local),
    ("randn_like", randn_like),
    ("randn_without_seed", randn_without_seed),
    ("manual_seed", manual_seed),
):
    setattr(devices, _hook_name, _hook_impl)
del _hook_name, _hook_impl