In [1]:
import os, sys, copy, time, random, argparse, cv2

import imageio
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.nn.functional import normalize

PI = 3.141592653589793

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DecoderMLP(nn.Module):
    def __init__(self, input_dim=128):
        super(DecoderMLP, self).__init__()
        self.MLP = nn.Sequential(
                nn.Linear(input_dim, input_dim), nn.ReLU(inplace=True),
                nn.Linear(input_dim, input_dim // 2), nn.ReLU(inplace=True),
                nn.Linear(input_dim // 2, 1), nn.Sigmoid())

    def forward(self, x):
        return self.MLP(x)

In [3]:
torch.manual_seed(0)
embedding_dimension = 128
std = 5.0
basis = torch.normal(mean=torch.zeros(1, embedding_dimension // 2), std=std)
colorEmbeddingDecoder = DecoderMLP(input_dim=embedding_dimension)

In [4]:
criterion = nn.MSELoss()  # Mean Squared Error Loss
optimizer = optim.SGD(colorEmbeddingDecoder.parameters(), lr=0.01)  # Stochastic Gradient Descent
batch_size = 4096

In [5]:
num_epochs = 100000

for epoch in range(num_epochs):
    random_data = torch.rand(batch_size, 1)
    mapped_data = (2. * PI * random_data) @ basis
    color_embedding = normalize(torch.cat([torch.sin(mapped_data), torch.cos(mapped_data)], dim=-1), p=2.0, dim=1)

    # Forward pass
    outputs = colorEmbeddingDecoder(color_embedding)
    loss = criterion(outputs, random_data)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

Epoch [100/100000], Loss: 0.0824
Epoch [200/100000], Loss: 0.0855
Epoch [300/100000], Loss: 0.0836
Epoch [400/100000], Loss: 0.0817
Epoch [500/100000], Loss: 0.0812
Epoch [600/100000], Loss: 0.0837
Epoch [700/100000], Loss: 0.0802
Epoch [800/100000], Loss: 0.0822
Epoch [900/100000], Loss: 0.0791
Epoch [1000/100000], Loss: 0.0780
Epoch [1100/100000], Loss: 0.0794
Epoch [1200/100000], Loss: 0.0797
Epoch [1300/100000], Loss: 0.0770
Epoch [1400/100000], Loss: 0.0774
Epoch [1500/100000], Loss: 0.0740
Epoch [1600/100000], Loss: 0.0759
Epoch [1700/100000], Loss: 0.0743
Epoch [1800/100000], Loss: 0.0728
Epoch [1900/100000], Loss: 0.0712
Epoch [2000/100000], Loss: 0.0713
Epoch [2100/100000], Loss: 0.0695
Epoch [2200/100000], Loss: 0.0699
Epoch [2300/100000], Loss: 0.0660
Epoch [2400/100000], Loss: 0.0653
Epoch [2500/100000], Loss: 0.0609
Epoch [2600/100000], Loss: 0.0601
Epoch [2700/100000], Loss: 0.0569
Epoch [2800/100000], Loss: 0.0536
Epoch [2900/100000], Loss: 0.0505
Epoch [3000/100000], Lo

In [6]:
random_data = torch.rand(batch_size, 1)
mapped_data = (2. * PI * random_data) @ basis
color_embedding = normalize(torch.cat([torch.sin(mapped_data), torch.cos(mapped_data)], dim=-1), p=2.0, dim=1)

# Forward pass
outputs = colorEmbeddingDecoder(color_embedding)
loss = criterion(outputs, random_data)

print(random_data)
print(outputs)

tensor([[0.0066],
        [0.5459],
        [0.4202],
        ...,
        [0.7840],
        [0.0829],
        [0.2786]])
tensor([[0.0660],
        [0.5357],
        [0.4347],
        ...,
        [0.8036],
        [0.0893],
        [0.2726]], grad_fn=<SigmoidBackward0>)


In [7]:
checkpoint = {
    'model_state_dict': colorEmbeddingDecoder.state_dict(),
    'basis': basis,
    "embedding_dimension": embedding_dimension
}

torch.save(checkpoint, "color_embedding_decoder.pth")