In [None]:
import torch
import numpy as np
from PIL import Image

im = Image.open('rose_crop.jpg')
# Force 3-channel RGB before resizing: the model below predicts exactly 3 channels,
# and a grayscale (1-channel) or CMYK (4-channel) JPEG would otherwise produce a
# target array whose channel count does not match the network output.
im_small = im.convert('RGB').resize((128, 128))
im_small

In [None]:
class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal feature encoding of coordinates.

    Every input coordinate is multiplied by ``F = len(range(0, embed_dim, 2))``
    frequencies ``exp(-k / embed_dim)`` for ``k = 0, 2, 4, ...``, and the sine and
    cosine of each product are returned as features.

    Args:
        embed_dim: controls how many frequencies are used and how they are spaced.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # NOTE(review): since k < embed_dim, all frequencies lie in (1/e, 1]; for
        # coordinates in [-1, 1] that is quite low-frequency, which may limit fine
        # detail (NeRF-style encodings use 2**k * pi). Kept as-is to preserve behavior.
        freq = torch.exp(-torch.arange(0, embed_dim, 2).float() / embed_dim)
        # Registered as a buffer so `.cuda()` / `.to(device)` move it along with the
        # module; a plain tensor attribute would stay on the CPU and the product in
        # `forward` would fail with a device mismatch. persistent=False keeps it out of
        # state_dict, as before (it is fully determined by embed_dim).
        self.register_buffer('freq', freq, persistent=False)

    def forward(self, x):
        """Encode ``x`` of shape ``(..., C)`` into shape ``(..., 2 * F * C)``.

        For each frequency the output holds the C sines followed by the C cosines.
        In this notebook the inputs are 2-D (x, y) positions, i.e. ``C == 2``.
        """
        # (..., 1, C) * (F, 1) -> (..., F, C)
        scaled = x[..., None, :] * self.freq[..., None]
        # cat -> (..., F, 2C); view flattens the last two dims -> (..., 2 * F * C)
        return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1).view(*scaled.shape[:-2], -1)
    
#pe = PositionalEmbedding(16)
#pe(torch.tensor([[1, 0], [0, 1]])).shape

In [None]:
from PIL import Image
class Rose(torch.nn.Module):
    """Coordinate MLP that maps 2-D (x, y) positions to 3 output channels (RGB).

    Args:
        embed_dim: size passed to ``PositionalEmbedding``; the MLP input width is
            derived from it. Pass ``None`` to skip the encoding and feed the raw
            (x, y) coordinates directly. Defaults to 12 (the original setup).
    """

    def __init__(self, embed_dim=12):
        super().__init__()
        if embed_dim is None:
            # identity encoding - use raw x,y coordinates
            self.enc = torch.nn.Identity()
            in_features = 2
        else:
            self.enc = PositionalEmbedding(embed_dim)
            # PositionalEmbedding emits sin and cos of both (x, y) coordinates for each
            # of its len(range(0, embed_dim, 2)) frequencies -> 4 features per frequency
            # (24 for the default embed_dim=12).
            in_features = 4 * len(range(0, embed_dim, 2))
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 3), # 3 channels - RGB
        )

    def forward(self, x):
        """Encode positions ``x`` (last dim = 2) and predict 3 channels per position."""
        x = self.enc(x)
        return self.net(x)
    
# Target image as float pixels shifted to [-0.5, 0.5]; shape (H, W, C) from np.array(PIL image).
rose_tensor = torch.as_tensor(np.array(im_small), dtype=torch.float32) / 255. - 0.5
height, width = rose_tensor.shape[:2]
# Input grid sized from the target so the two always match:
# position[i, j] = (linspace(-1, 1, H)[i], linspace(-1, 1, W)[j]).
# indexing='ij' is the legacy default, spelled out to avoid the deprecation warning.
position = torch.stack(
    torch.meshgrid(torch.linspace(-1, 1, height), torch.linspace(-1, 1, width), indexing='ij'),
    dim=-1,
)

torch.manual_seed(0)  # reproducible weight initialisation
net = Rose()

# Use the GPU when present, otherwise fall back to the CPU instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
rose_tensor = rose_tensor.to(device)
position = position.to(device)
net = net.to(device)

optim = torch.optim.Adam(net.parameters(), lr=1e-3)
for it in range(5000):
    optim.zero_grad()
    # mean absolute (L1) reconstruction error over every pixel and channel
    loss = abs(net(position) - rose_tensor).mean()
    if it % 100 == 0:
        print(float(loss))
    loss.backward()
    optim.step()

# Render the fit: undo the -0.5 shift, clamp to [0, 1], quantize to 8-bit.
# no_grad: inference only, no autograd graph needed.
with torch.no_grad():
    rendered = ((net(position) + 0.5).clamp(0, 1) * 255).cpu().to(torch.uint8).numpy()
Image.fromarray(rendered)