-
Notifications
You must be signed in to change notification settings - Fork 1
/
percimnist.py
52 lines (35 loc) · 1.76 KB
/
percimnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 8 17:52:50 2023
@author: subhadramokashe
PERCEIVER FOR MNIST
"""
import torch
import math
from torch import nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
import pytorch_lightning as pl
class PositionalImageEmbedding(nn.Module):
    """Fourier positional features for grid-shaped (e.g. image) inputs.

    Per the original author's note, the intent is to flatten the image and
    concatenate a positional encoding to it.  The class as written only
    provides the ``fourier_features`` helper that builds that encoding; it
    defines no ``__init__`` or ``forward`` of its own.
    """

    def fourier_features(self, shape, bands: int) -> torch.Tensor:
        """Build sin/cos Fourier position features for a grid of size ``shape``.

        Args:
            shape: Sizes of the input grid, one entry per spatial dimension
                (any sequence of ints: tuple, list, ``torch.Size``).  This is
                the shape of the *input data*, not of the returned tensor.
            bands: Number of frequency bands per spatial dimension.

        Returns:
            A tensor of shape ``(2 * len(shape) * bands, *shape)``.  The first
            ``len(shape) * bands`` channels are the sines and the remaining
            ones the cosines of ``freq[band] * pi * pos[d]``.  Within each
            half, channel ``band * len(shape) + d`` belongs to frequency band
            ``band`` and spatial dimension ``d``.
        """
        # Normalise once so the tuple concatenations below also work when the
        # caller passes a list (tuple + list would raise TypeError).
        shape = tuple(shape)
        dims = len(shape)

        # Coordinates in [-1, 1] along every axis; pos has shape (dims, *shape).
        # "ij" is meshgrid's historical default; stating it explicitly avoids
        # the deprecation warning newer torch releases emit when it is omitted.
        axes = [torch.linspace(-1.0, 1.0, steps=n) for n in shape]
        pos = torch.stack(torch.meshgrid(*axes, indexing="ij"))

        # Log-spaced frequencies from 1 up to shape[0] / 2, laid out as
        # (bands, 1, 1, ..., 1) so they broadcast over (dims, *shape).
        # NOTE: the upper bound uses the first dimension only, for all axes.
        band_frequencies = torch.logspace(
            math.log(1.0),
            math.log(shape[0] / 2),
            steps=bands,
            base=math.e,
        ).view((bands,) + (1,) * (dims + 1))

        # freq[band] * pi * pos[d] for every band / dimension / grid cell:
        # (bands, dims, *shape) collapsed to (dims * bands, *shape).
        angles = (band_frequencies * math.pi * pos.unsqueeze(0)).reshape(
            (dims * bands,) + shape
        )

        # Use both sin & cos for each band.  Raw positions are NOT appended.
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=0)