### Colorspace Formalization
- Idea: AEGAN uses a color module. This module assumes the output of a convolutional Generator will have 16 channels that are then softmaxed to pick a dominant color each for R, G, B
- Implementation uses a lot of duplicated code, which could be wrapped in a loop BUT I have a better idea
- This feels like it should be possible to do directly with matrix multiplication, but the batch dimension makes things slightly tricky

In [3]:
import torch
import torch.nn as nn
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# some variables:
image_size = 5
batch_size = 4
color_channels = 2
intermediate_channels = 4


In [5]:
# simple example
softmax = nn.Softmax(dim=1)

# input upsampled 
im = softmax(torch.randn((intermediate_channels, image_size, image_size)))

# I think we need to be careful when constructing the palette
palette = torch.rand(intermediate_channels, color_channels)



In [6]:
im.transpose_(0, 2) @ palette

tensor([[[0.3240, 0.2982],
         [0.3424, 0.4572],
         [0.3093, 0.3901],
         [1.0629, 1.2004],
         [0.3918, 0.6281]],

        [[0.2763, 0.3047],
         [0.8138, 1.1029],
         [0.7074, 0.8083],
         [0.4189, 0.4584],
         [0.2139, 0.2998]],

        [[0.3711, 0.4637],
         [0.3720, 0.5581],
         [0.7408, 0.7364],
         [0.5720, 0.7847],
         [0.3744, 0.4311]],

        [[0.4013, 0.4718],
         [0.6303, 0.7126],
         [0.3697, 0.4774],
         [0.4200, 0.4611],
         [0.6091, 0.8513]],

        [[0.4939, 0.7284],
         [0.4190, 0.6857],
         [0.4041, 0.3374],
         [0.3818, 0.3744],
         [0.7315, 0.8481]]])

In [10]:
# how it currently works
im = softmax(torch.randn((batch_size, intermediate_channels, image_size, image_size)))
color_palettes = [torch.rand((batch_size, intermediate_channels)) for _ in range(color_channels)]
upsampler = nn.Upsample(scale_factor=image_size)

outs = []

for color in color_palettes:
    print(f"Initial color size: {color.shape}")
    # then a reshape occurs to match (b, c, im, im)
    color = color.view((-1, intermediate_channels, 1, 1))
    print(f"Resized: {color.shape}")
    # upsamples the reshaped tensor
    color = upsampler(color)
    print(f"Upsampled: {color.shape}")

    # next comes the multiplication of the two matrices (I assume elementwise)
    inter_im = im * color
    print(f"Post Mult Size: {inter_im.shape}")

    # final aggregation step
    out_im = torch.sum(inter_im, dim=1, keepdim=True)
    print(f"Post Agg Size: {out_im.shape}")

    outs.append(out_im)

final_im = torch.cat(outs, dim=1)
print(f"Output Image: {final_im.shape}")

Initial color size: torch.Size([4, 4])
Resized: torch.Size([4, 4, 1, 1])
Upsampled: torch.Size([4, 4, 5, 5])
Post Mult Size: torch.Size([4, 4, 5, 5])
Post Agg Size: torch.Size([4, 1, 5, 5])
Initial color size: torch.Size([4, 4])
Resized: torch.Size([4, 4, 1, 1])
Upsampled: torch.Size([4, 4, 5, 5])
Post Mult Size: torch.Size([4, 4, 5, 5])
Post Agg Size: torch.Size([4, 1, 5, 5])
Output Image: torch.Size([4, 2, 5, 5])


In [22]:
# attempt to recreate this

palettes = torch.rand((batch_size, intermediate_channels, color_channels))
outs = []

for curr_im, color in zip(im, palettes):
    print(f"Image Size: {curr_im.shape}\t Palette Size: {color.shape}")
    out_im = curr_im.transpose_(0, 2) @ color
    out_im.transpose_(0, 2)
    print(f"Final Shape: {out_im.shape}")
    outs.append(out_im.view(-1, color_channels, image_size, image_size))

out_im = torch.cat(outs, dim=0)
out_im.shape

Image Size: torch.Size([4, 5, 5])	 Palette Size: torch.Size([4, 2])
Final Shape: torch.Size([2, 5, 5])
Image Size: torch.Size([4, 5, 5])	 Palette Size: torch.Size([4, 2])
Final Shape: torch.Size([2, 5, 5])
Image Size: torch.Size([4, 5, 5])	 Palette Size: torch.Size([4, 2])
Final Shape: torch.Size([2, 5, 5])
Image Size: torch.Size([4, 5, 5])	 Palette Size: torch.Size([4, 2])
Final Shape: torch.Size([2, 5, 5])


torch.Size([4, 2, 5, 5])

In [72]:
# let's make them classes and compare!

class OldMethod(nn.Module):
    def __init__(self, colors, im_size, nc):
        super(OldMethod, self).__init__()
        self.colors = colors
        self.upsampler = nn.Upsample(scale_factor=im_size)
        self.nc = nc

    def forward(self, input):
        output = []
        for color in self.colors:
            color = color.view((-1, self.nc, 1, 1))
            color = self.upsampler(color)
            out_im = input * color
            out_im = torch.sum(out_im, dim=1, keepdim=True)
            output.append(out_im)
        
        return torch.cat(output, dim=1)

class NewMethod(nn.Module):
    def __init__(self, colors, nc, outc):
        super(NewMethod, self).__init__()
        self.colors = colors.view(-1, 1, nc, outc) #(b, nc, outc) -> (b, 1, nc, outc)

    def forward(self, input):
        logits = input.transpose_(1, 3)
        logits = logits @ self.colors
        logits = logits.transpose_(1, 3)
        return logits


In [74]:
b = 8           # batch size
im_size = 5     # image size
nc = 3          # intermediate palette channels
outc = 2        # total colors to output

color_list = [torch.rand(b, nc) for _ in range(outc)] # a list of (b, nc) tensors, length outc
reshuffled_colors = [color.view(-1, nc, 1) for color in color_list] 
color_tensor = torch.cat(reshuffled_colors, dim=2) # creates a (b, nc, outc) tensor

im = softmax(torch.rand(b, nc, im_size, im_size))

champion = OldMethod(colors=color_list, im_size=im_size, nc=nc)
challenger = NewMethod(color_tensor, nc, outc)

old_out = champion(im)
new_out = challenger(im)

torch.allclose(old_out, new_out)

True

# GREAT SUCCESS !!!