In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader
from siren import SirenModel
from PIL import Image

In [2]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device

device(type='cuda', index=0)

In [3]:
# Load the target image and force three 8-bit colour channels: the later
# `reshape(3, -1)` only works for RGB data, while a PNG may just as well be
# RGBA, palette-based or greyscale.
img = Image.open("lenna.png").convert("RGB")

In [4]:
# Echo the size of the loaded image (PIL reports it as (width, height)).
img.size

(512, 512)

In [5]:
# Preprocessing applied to the PIL image before it becomes the regression target.
transformations = transforms.Compose([
    # Resize to exactly 256x256. An int argument would only fix the shorter
    # edge, so a non-square image would no longer match the 256x256
    # coordinate grid built for the inputs.
    transforms.Resize((256, 256)),
    # PIL image -> float tensor in channel-first layout with values in [0, 1].
    transforms.ToTensor(),
    # (v - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
    transforms.Normalize(0.5, 0.5),
])

In [6]:
# Regression targets: one row per pixel, one column per colour channel.
channels_first = transformations(img)
# Collapse the spatial axes into one, then swap axes so pixels index the rows.
y = channels_first.reshape(3, -1).t()
y.shape

torch.Size([65536, 3])

In [7]:
# Network inputs: one (row, column) coordinate pair per pixel, each axis
# sampled at `side` evenly spaced points covering [-1, 1).
side = 256
# An integer arange has an exact length; scaling it afterwards avoids the
# rounding-dependent element count of a float-step arange.
axis = torch.arange(side, dtype=torch.float32) * (2.0 / side) - 1.0
# cartesian_prod enumerates pairs with the first coordinate varying slowest,
# i.e. row-major order. That matches the flattened pixel order of the targets
# and avoids calling torch.meshgrid without an explicit `indexing` argument.
x = torch.cartesian_prod(axis, axis)
x = x.float()
x.shape

torch.Size([65536, 2])

In [8]:
# Pair every coordinate row of x with the pixel row of y at the same index.
dataset = TensorDataset(x, y)

In [9]:
# Mini-batch iterator over the (coordinate, pixel) pairs.
dataloader = DataLoader(
    dataset,
    batch_size=4096,   # 65536 samples -> 16 equally sized batches per epoch
    shuffle=True,      # draw the samples in a new random order every epoch
    pin_memory=True,   # page-locked host buffers for the host-to-GPU copies
)

In [10]:
# Layer widths: 2 inputs (one per coordinate column of x) down to 3 outputs
# (one per colour column of y). NOTE(review): how SirenModel interprets
# `layer_dims` lives in the external `siren` module; confirm there.
layer_dims = [2, 256, 128, 64, 32, 3]
model = SirenModel(layer_dims=layer_dims)
model = model.to(device)

In [11]:
# Put the network in training mode. The call is the cell's last expression,
# so the notebook echoes the module's repr below.
model.train()

SirenModel(
  (layers): ModuleList(
    (0): SirenLayer()
    (1): SirenLayer()
    (2): SirenLayer()
    (3): SirenLayer()
    (4): SirenLayer()
  )
)

In [12]:
# Adam over everything returned by model.parameters(), with a fixed learning rate of 5e-4.
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0005)

In [13]:
# Mean squared error between predicted and target pixel values
# (default reduction: mean over all elements).
criterion = nn.MSELoss()

In [14]:
# Fit the network to the image: 300 passes over every (coordinate, pixel) pair,
# printing the mean batch loss after each pass.
for epoch in range(300):
    losses = []
    for inputs, targets in dataloader:
        # The loader yields pinned host tensors, so these copies may overlap
        # with compute; the flag has no effect when `device` is the CPU.
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        predictions = model(inputs)
        loss = criterion(predictions, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store only the detached value: keeping `loss` itself would hold on
        # to its autograd history for every batch until the epoch ends.
        losses.append(loss.detach())

    # Average of the per-batch means; one device-to-host sync per epoch.
    avg_loss = torch.stack(losses).mean().item()
    print("{} -> {:.4f}".format(epoch, avg_loss))

0 -> 0.4662
1 -> 0.2400
2 -> 0.1356
3 -> 0.0894
4 -> 0.0660
5 -> 0.0520
6 -> 0.0427
7 -> 0.0360
8 -> 0.0311
9 -> 0.0272
10 -> 0.0242
11 -> 0.0217
12 -> 0.0197
13 -> 0.0180
14 -> 0.0166
15 -> 0.0153
16 -> 0.0143
17 -> 0.0134
18 -> 0.0126
19 -> 0.0119
20 -> 0.0112
21 -> 0.0107
22 -> 0.0102
23 -> 0.0097
24 -> 0.0093
25 -> 0.0089
26 -> 0.0086
27 -> 0.0083
28 -> 0.0080
29 -> 0.0077
30 -> 0.0075
31 -> 0.0073
32 -> 0.0070
33 -> 0.0069
34 -> 0.0067
35 -> 0.0065
36 -> 0.0063
37 -> 0.0062
38 -> 0.0060
39 -> 0.0059
40 -> 0.0057
41 -> 0.0056
42 -> 0.0055
43 -> 0.0054
44 -> 0.0053
45 -> 0.0052
46 -> 0.0051
47 -> 0.0050
48 -> 0.0049
49 -> 0.0048
50 -> 0.0047
51 -> 0.0047
52 -> 0.0046
53 -> 0.0045
54 -> 0.0044
55 -> 0.0044
56 -> 0.0043
57 -> 0.0042
58 -> 0.0042
59 -> 0.0041
60 -> 0.0041
61 -> 0.0040
62 -> 0.0040
63 -> 0.0039
64 -> 0.0039
65 -> 0.0038
66 -> 0.0038
67 -> 0.0037
68 -> 0.0037
69 -> 0.0036
70 -> 0.0036
71 -> 0.0036
72 -> 0.0035
73 -> 0.0035
74 -> 0.0034
75 -> 0.0034
76 -> 0.0034
77 -> 0.0

In [15]:
# Evaluate the fitted network on the full coordinate grid; gradients are not
# needed for this forward pass.
# NOTE(review): the model is still in training mode here. Whether that matters
# depends on SirenModel's layers, which are not visible in this notebook.
with torch.no_grad():
    test_y = model(x.to(device)).cpu()
    # (N, 3) -> (3, N) -> (3, 256, 256): back to the channel-first image layout.
    test_y = test_y.transpose(0, 1).reshape(3, 256, 256)

# Echo the shape at top level: an expression nested inside the `with` block is
# not the cell's last top-level statement, so the notebook would not display it.
test_y.shape

torch.Size([3, 256, 256])

In [16]:
# Undo Normalize(0.5, 0.5) to get back to [0, 1], then clamp. Nothing visible
# here guarantees the predictions stay inside [-1, 1], and values outside
# [0, 1] would be corrupted when ToPILImage converts the tensor to 8-bit.
test_img = transforms.ToPILImage(mode='RGB')((test_y * 0.5 + 0.5).clamp(0.0, 1.0))

In [17]:
# Display the reconstructed image via PIL's `show()`.
# NOTE(review): `show()` hands the image to an external viewer on the machine
# running the kernel. On a remote or headless kernel, ending the cell with a
# bare `test_img` would render it inline instead.
test_img.show()

In [18]:
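# Optional: uncomment to write the trained model object to 'siren.cpt' (this pickles the whole module).
# torch.save({'model' : model}, 'siren.cpt')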