-
Notifications
You must be signed in to change notification settings - Fork 1
/
timesformer.py
executable file
·348 lines (263 loc) · 11 KB
/
timesformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
#https://github.com/lucidrains/TimeSformer-pytorch
from math import log, pi
import torch
from torch import nn, einsum
import torch.nn.functional as F
#pip install einops
from einops import rearrange, repeat
#from timesformer_pytorch.rotary import apply_rot_emb, AxialRotaryEmbedding, RotaryEmbedding
def rotate_every_two(x):
    """Rotate every adjacent feature pair of the last dimension by 90 degrees.

    The last dimension is read as consecutive pairs ``(a, b)``; each pair is
    replaced by ``(-b, a)``. This is the "rotate" half of the rotary update
    performed in ``apply_rot_emb``.
    """
    pairs = rearrange(x, '... (d j) -> ... d j', j = 2)
    even, odd = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-odd, even), dim = -1)
    return rearrange(rotated, '... d j -> ... (d j)')
def apply_rot_emb(q, k, rot_emb):
    """Apply a rotary position embedding to the leading features of q and k.

    Args:
        q, k: query / key tensors.
        rot_emb: a ``(sin, cos)`` pair. Its last dimension (``rot_dim``) sets
            how many leading features of q and k are rotated; both tables must
            broadcast against ``q[..., :rot_dim]``.

    Returns:
        ``(q, k)`` with the first ``rot_dim`` features rotated and the
        remaining features passed through unchanged.
    """
    sin, cos = rot_emb
    rot_dim = sin.shape[-1]

    def _rotate(t):
        # rotate only the head of the feature axis, keep the tail as-is
        head, tail = t[..., :rot_dim], t[..., rot_dim:]
        head = head * cos + rotate_every_two(head) * sin
        return torch.cat((head, tail), dim = -1)

    return _rotate(q), _rotate(k)
class AxialRotaryEmbedding(nn.Module):
    """Sin/cos tables for a 2-D (row x column) rotary position embedding.

    ``dim // 4`` frequency scales are spaced geometrically (base 2) from 1 to
    ``max_freq / 2``. Grid coordinates along each axis lie in ``[-1, 1]``.
    """
    def __init__(self, dim, max_freq = 10):
        super().__init__()
        self.dim = dim
        top_exponent = log(max_freq / 2) / log(2)
        self.register_buffer('scales', torch.logspace(0., top_exponent, self.dim // 4, base = 2))

    def forward(self, h, w, device):
        """Return ``(sin, cos)`` for an ``h x w`` grid.

        Each table is shaped ``(1, h * w, 4 * (dim // 4))``. The first half of
        the features comes from the row coordinate and the second half from
        the column coordinate. Every value is repeated twice consecutively, so
        adjacent feature pairs share one angle.
        """
        freqs = rearrange(self.scales, '... -> () ...').to(device)

        def _axis_angles(steps):
            coords = torch.linspace(-1., 1., steps = steps, device = device).unsqueeze(-1)
            return coords * freqs * pi

        row_angles = repeat(_axis_angles(h), 'i d -> i j d', j = w)
        col_angles = repeat(_axis_angles(w), 'j d -> i j d', i = h)

        tables = []
        for trig in (torch.sin, torch.cos):
            table = torch.cat((trig(row_angles), trig(col_angles)), dim = -1)
            table = rearrange(table, 'i j d -> (i j) d')
            tables.append(repeat(table, 'n d -> () n (d j)', j = 2))
        sin, cos = tables
        return sin, cos
class RotaryEmbedding(nn.Module):
    """Sin/cos tables for a 1-D rotary position embedding over positions 0..n-1.

    Inverse frequencies are ``10000 ** (-2i / dim)`` for ``i`` in
    ``range(ceil(dim / 2))``.
    """
    def __init__(self, dim):
        super().__init__()
        inv_freqs = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freqs', inv_freqs)

    def forward(self, n, device):
        """Return ``(sin, cos)``, each shaped ``(1, n, 2 * len(inv_freqs))``.

        Every angle is repeated twice *consecutively* (``[f0, f0, f1, f1, ...]``).
        ``rotate_every_two`` / ``apply_rot_emb`` rotate adjacent feature pairs
        ``(2i, 2i + 1)``, so both members of a pair must share one angle for the
        update to be a true rotation. The earlier half-split layout
        ``[f0..fk, f0..fk]`` gave each pair two different angles. The interleaved
        layout also matches ``AxialRotaryEmbedding``.
        """
        # build positions in the buffer's float dtype so einsum never has to
        # promote an integer tensor against a float one
        seq = torch.arange(n, device = device, dtype = self.inv_freqs.dtype)
        freqs = einsum('i, j -> i j', seq, self.inv_freqs)
        freqs = freqs.repeat_interleave(2, dim = -1)
        freqs = freqs.unsqueeze(0)
        return freqs.sin(), freqs.cos()
# helpers
def exists(val):
    """Return True unless *val* is None (falsy values such as 0 or [] still count)."""
    return not (val is None)
# classes
class PreNorm(nn.Module):
    """Layer-normalize the input over its last dimension, then call ``fn``.

    Any extra positional or keyword arguments are forwarded to ``fn`` unchanged.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, *args, **kwargs)
# time token shift
def shift(t, amt):
    """Shift ``t`` by ``amt`` steps along its third-from-last axis.

    That axis is the frame axis for the ``b f n d`` chunks PreTokenShift
    passes in. Content pushed past either end is dropped, and vacated
    positions are zero-filled (F.pad with a negative amount crops).
    ``amt > 0`` moves content toward later frames; ``amt < 0`` moves it
    toward earlier ones.
    """
    # `==` rather than `is`: int identity is an implementation detail (numpy /
    # torch integer zeros fail `is 0`), and `is <literal>` is a SyntaxWarning
    # on Python 3.8+.
    if amt == 0:
        return t
    # F.pad pairs run from the last dim backwards: (d_lo, d_hi, n_lo, n_hi, f_lo, f_hi)
    return F.pad(t, (0, 0, 0, 0, amt, -amt))
class PreTokenShift(nn.Module):
    """Mix information across neighbouring frames by shifting feature slices, then call ``fn``.

    The token at position 0 (the cls token) is left untouched. The remaining
    tokens are viewed as ``(b, frames, n, d)``:
    - the first ``d // 3`` features are shifted one frame earlier;
    - the next ``d // 3`` stay in place;
    - the third ``d // 3`` are shifted one frame later;
    - any leftover features pass through unchanged.
    """
    def __init__(self, frames, fn):
        super().__init__()
        self.frames = frames
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        num_frames = self.frames
        slice_width = x.shape[-1] // 3
        cls_x, tokens = x[:, :1], x[:, 1:]
        tokens = rearrange(tokens, 'b (f n) d -> b f n d', f = num_frames)

        pieces = tokens.split(slice_width, dim = -1)
        moved = [shift(piece, amount) for piece, amount in zip(pieces[:3], (-1, 0, 1))]
        tokens = torch.cat((*moved, *pieces[3:]), dim = -1)

        tokens = rearrange(tokens, 'b f n d -> b (f n) d')
        return self.fn(torch.cat((cls_x, tokens), dim = 1), *args, **kwargs)
# feedforward
class GEGLU(nn.Module):
    """Gated GELU unit.

    Splits the last dimension into a value half and a gate half, and returns
    ``value * gelu(gate)``. The output's last dimension is half the input's.
    """
    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim = -1)
        return F.gelu(gate) * value
class FeedForward(nn.Module):
    """Position-wise MLP with a GEGLU gate.

    Projects ``dim`` up to ``2 * mult * dim``, gates it down to ``mult * dim``
    with GEGLU, applies dropout, and projects back to ``dim``.
    """
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        layers = [
            nn.Linear(dim, hidden * 2),  # value and gate halves for GEGLU
            GEGLU(),                     # -> hidden features
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# attention
def attn(q, k, v, mask = None):
    """Dot-product attention over 3-D ``(batch, seq, dim)`` tensors.

    No ``1/sqrt(d)`` scaling is applied here; callers pre-scale ``q`` if they
    want it. ``mask``, if given, is a boolean tensor broadcastable to
    ``(batch, len_q, len_k)``. Positions where it is False get the most
    negative finite value before the softmax, so they receive ~zero weight.
    """
    scores = einsum('b i d, b j d -> b i j', q, k)
    if exists(mask):
        scores.masked_fill_(~mask, -torch.finfo(scores.dtype).max)
    weights = scores.softmax(dim = -1)
    return einsum('b i j, b j d -> b i d', weights, v)
class Attention(nn.Module):
    """Multi-head self-attention along one axis (time or space) of a flattened video sequence.

    The axis is selected per call by two einops patterns. TimeSformer.forward
    uses ``'b (f n) d' -> '(b n) f d'`` for time attention and
    ``'b (f n) d' -> '(b f) n d'`` for space attention. The classification
    token at position 0 takes a separate path: it attends over every token.
    """
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads
        # one projection yields queries, keys and values (split with chunk(3) in forward)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, einops_from, einops_to, mask = None, cls_mask = None, rot_emb = None, **einops_dims):
        """Attend along the axis chosen by ``einops_from`` -> ``einops_to``.

        Args:
            x: ``(b, n, dim)`` tokens, with the classification token at index 0
                of dim 1.
            einops_from / einops_to: rearrange patterns that regroup the
                non-cls tokens, per head, so the attended axis becomes the
                sequence axis.
            mask: optional boolean mask for the per-axis attention, handed to
                ``attn`` (False entries are suppressed there). Its key axis
                must include the prepended cls key.
            cls_mask: optional boolean mask for the cls token's attention over
                all tokens.
            rot_emb: optional ``(sin, cos)`` pair applied to the regrouped
                queries/keys via ``apply_rot_emb``.
            **einops_dims: axis sizes the patterns need (e.g. ``n=`` or ``f=``).

        Returns:
            ``(b, n, dim)`` tensor after the output projection and dropout.
        """
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        # fold heads into the batch axis: (b, n, h*d) -> (b*h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
        # pre-scale queries; attn() itself does not scale
        q = q * self.scale
        # splice out classification token at index 0
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))
        # let classification token attend to key / values of all patches across time and space
        cls_out = attn(cls_q, k, v, mask = cls_mask)
        # rearrange across time or space
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))
        # add rotary embeddings, if applicable
        if exists(rot_emb):
            q_, k_ = apply_rot_emb(q_, k_, rot_emb)
        # expand cls token keys and values across time or space and concat
        # r = number of groups each (batch*head) row was split into by the rearrange above
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))
        k_ = torch.cat((cls_k, k_), dim = 1)
        v_ = torch.cat((cls_v, v_), dim = 1)
        # attention
        out = attn(q_, k_, v_, mask = mask)
        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
        # concat back the cls token
        out = torch.cat((cls_out, out), dim = 1)
        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        # combine heads out
        return self.to_out(out)
# main classes
class TimeSformer(nn.Module):
    """Video classifier built from divided space-time attention blocks.

    Pipeline:
    - each frame is cut into ``patch_size x patch_size`` patches;
    - patches are linearly embedded and a learned classification token is
      prepended;
    - ``depth`` blocks run, each applying residual time attention, spatial
      attention and a feed-forward layer, all behind PreNorm;
    - class logits are read from the final classification token.

    Args:
        dim: token embedding width.
        num_frames: sizes the learned position table (when ``rotary_emb`` is
            False) and the frame grouping used by PreTokenShift.
        num_classes: number of output logits.
        image_size: must be divisible by ``patch_size``; with ``num_frames``
            it sizes the learned position table.
        patch_size, channels: patch geometry; ``patch_dim = channels * patch_size ** 2``.
        depth, heads, dim_head: number of blocks and attention head layout.
        attn_dropout, ff_dropout: dropout rates for attention output / feed-forward.
        rotary_emb: if True, use rotary embeddings (1-D over frames, 2-D over
            the patch grid) inside attention instead of a learned nn.Embedding.
        shift_tokens: if True, wrap every sub-layer in PreTokenShift.
    """
    def __init__(
        self,
        *,
        dim,
        num_frames,
        num_classes,
        image_size = 224,
        patch_size = 16,
        channels = 3,
        depth = 12,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_dropout = 0.,
        rotary_emb = True,
        shift_tokens = False
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        num_positions = num_frames * num_patches
        patch_dim = channels * patch_size ** 2
        self.heads = heads
        self.patch_size = patch_size
        self.to_patch_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, dim))
        self.use_rotary_emb = rotary_emb
        if rotary_emb:
            # rotary tables are sized per head (dim_head) and applied inside Attention
            self.frame_rot_emb = RotaryEmbedding(dim_head)
            self.image_rot_emb = AxialRotaryEmbedding(dim_head)
        else:
            # learned absolute positions: one row per patch token plus one for the cls token
            self.pos_emb = nn.Embedding(num_positions + 1, dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            ff = FeedForward(dim, dropout = ff_dropout)
            time_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
            spatial_attn = Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
            if shift_tokens:
                time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(num_frames, t), (time_attn, spatial_attn, ff))
            # PreNorm is the outermost wrapper, so normalization runs before any token shift
            time_attn, spatial_attn, ff = map(lambda t: PreNorm(dim, t), (time_attn, spatial_attn, ff))
            self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))
        self.to_out = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, video, mask = None):
        """Classify a batch of videos.

        Args:
            video: ``(b, f, c, h, w)`` tensor (see the patch rearrange below);
                ``h`` and ``w`` must be divisible by ``patch_size``.
                NOTE(review): with ``shift_tokens=True``, PreTokenShift regroups
                tokens using ``num_frames``, so ``f`` presumably must equal it.
            mask: optional ``(b, f)`` boolean tensor; False frames are hidden
                from time attention and from the cls token.

        Returns:
            ``(b, num_classes)`` logits computed from the final cls token.
        """
        b, f, _, h, w, *_, device, p = *video.shape, video.device, self.patch_size
        assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}'
        # calculate num patches in height and width dimension, and number of total patches (n)
        hp, wp = (h // p), (w // p)
        n = hp * wp
        # video to patch embeddings: frame-major token order (b, f*hp*wp, p*p*c)
        video = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p)
        tokens = self.to_patch_embedding(video)
        # add cls token
        cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
        x = torch.cat((cls_token, tokens), dim = 1)
        # positional embedding
        frame_pos_emb = None
        image_pos_emb = None
        if not self.use_rotary_emb:
            x += self.pos_emb(torch.arange(x.shape[1], device = device))
        else:
            frame_pos_emb = self.frame_rot_emb(f, device = device)
            image_pos_emb = self.image_rot_emb(hp, wp, device = device)
        # calculate masking for uneven number of frames
        frame_mask = None
        cls_attn_mask = None
        if exists(mask):
            # time attention: each (batch, head, patch) row sees the cls key plus its f frames
            mask_with_cls = F.pad(mask, (1, 0), value = True)
            frame_mask = repeat(mask_with_cls, 'b f -> (b h n) () f', n = n, h = self.heads)
            # cls token: sees itself plus every patch of every valid frame
            cls_attn_mask = repeat(mask, 'b f -> (b h) () (f n)', n = n, h = self.heads)
            cls_attn_mask = F.pad(cls_attn_mask, (1, 0), value = True)
        # time and space attention
        for (time_attn, spatial_attn, ff) in self.layers:
            x = time_attn(x, 'b (f n) d', '(b n) f d', n = n, mask = frame_mask, cls_mask = cls_attn_mask, rot_emb = frame_pos_emb) + x
            x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f, cls_mask = cls_attn_mask, rot_emb = image_pos_emb) + x
            x = ff(x) + x
        cls_token = x[:, 0]
        return self.to_out(cls_token)
if __name__ == '__main__':
    # Smoke test: build a model, push a random batch through it, report sizes.
    print("testing timesformer ")
    hparams = dict(
        dim = 512,
        image_size = 224,
        patch_size = 16,
        num_frames = 8,
        num_classes = 10,
        depth = 12,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.1,
        ff_dropout = 0.1,
    )
    model = TimeSformer(**hparams)

    batch, frames, channels, side = 2, 8, 3, 224
    # (batch x frames x channels x height x width)
    video = torch.randn(batch, frames, channels, side, side)
    # (batch x frame) - all True here; use a mask if videos in one batch differ in length
    mask = torch.ones(batch, frames).bool()

    n_params = sum(param.numel() for param in model.parameters())
    print(n_params / 1e6, 'M parameters')

    pred = model(video, mask = mask)  # (batch, num_classes) = (2, 10)
    print("got random pred ", pred, pred.shape)