See https://discuss.pytorch.org/t/differentiable-torch-histc/25865/2
for a possible differentiable alternative to `torch.histc`.

In [1]:
import torch
import torch.nn as nn
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class AutoEncoder(nn.Module):
  """Small convolutional autoencoder for 3-channel images.

  The encoder applies three convolutions (the first two with stride 2) and
  the decoder three transposed convolutions. For a (N, 3, 256, 256) input the
  output is also (N, 3, 256, 256), with every value squashed into (0, 1) by a
  final sigmoid.
  """

  def __init__(self):
    super().__init__()
    # encoder: (in_channels, out_channels, kernel_size, stride[, padding])
    self.c1 = nn.Conv2d(3, 50, 3, 2)
    self.c2 = nn.Conv2d(50, 50, 3, 2, 2)
    self.c3 = nn.Conv2d(50, 50, 3, 1)
    # decoder: transposed convolutions bring the feature map back up
    self.dc1 = nn.ConvTranspose2d(50, 50, 3, 2)
    self.dc2 = nn.ConvTranspose2d(50, 50, 4, 2)
    self.dc3 = nn.ConvTranspose2d(50, 3, 3, 1, padding=1)

  def forward(self, x):
    """Encode then decode ``x``; every layer but the last is followed by ReLU."""
    for layer in (self.c1, self.c2, self.c3, self.dc1, self.dc2):
      x = torch.relu(layer(x))
    # torch.sigmoid instead of the deprecated nn.functional.sigmoid alias.
    return torch.sigmoid(self.dc3(x))

In [4]:
# Use the first GPU when available, otherwise fall back to the CPU so this
# cell also runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoEncoder().to(device)

# torchsummary's summary() takes device="cuda" or "cpu" (it defaults to "cuda").
summary(model, (3, 256, 256), device="cpu" if device == "cpu" else "cuda")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 50, 127, 127]           1,400
            Conv2d-2           [-1, 50, 65, 65]          22,550
            Conv2d-3           [-1, 50, 63, 63]          22,550
   ConvTranspose2d-4         [-1, 50, 127, 127]          22,550
   ConvTranspose2d-5         [-1, 50, 256, 256]          40,050
   ConvTranspose2d-6          [-1, 3, 256, 256]           1,353
Total params: 110,453
Trainable params: 110,453
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 41.93
Params size (MB): 0.42
Estimated Total Size (MB): 43.10
----------------------------------------------------------------




HistNet

In [7]:
class HistNet(nn.Module):
  """Fully connected network that predicts a histogram from a square image.

  The input is flattened to rows of ``img_size * img_size`` values, passed
  through one hidden ReLU layer of ``hidden_size`` units, and projected to
  ``n_bins`` outputs squashed into (0, 1) by a sigmoid.

  Args:
    img_size: side length of the (square, single-channel) input image.
    hidden_size: width of the hidden layer (540 by default).
    n_bins: number of predicted histogram bins (256 by default).
  """

  def __init__(self, img_size, hidden_size=540, n_bins=256):
    super().__init__()
    # Intended input: just the hue channel, normalized to the [0, 256) range.
    self.in_size = img_size * img_size
    # NOTE: for a 256x256 input this first layer alone holds ~35 million
    # parameters (see the summary) -- consider convolutions for the reduction.
    self.c1 = nn.Linear(self.in_size, hidden_size)
    self.c2 = nn.Linear(hidden_size, n_bins)

  def forward(self, x):
    """Return one ``n_bins``-long prediction per flattened image in ``x``."""
    x = torch.reshape(x, (-1, self.in_size))
    x = torch.relu(self.c1(x))
    # TODO: try a softmax here -- the target histograms (counts / numel)
    # sum to 1, which a sigmoid does not enforce.
    return torch.sigmoid(self.c2(x))

In [8]:
# Use the first GPU when available, otherwise fall back to the CPU so this
# cell also runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = HistNet(256).to(device)

# torchsummary's summary() takes device="cuda" or "cpu" (it defaults to "cuda").
summary(model, (1, 256, 256), device="cpu" if device == "cpu" else "cuda")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 540]      35,389,980
            Linear-2                  [-1, 256]         138,496
Total params: 35,528,476
Trainable params: 35,528,476
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.25
Forward/backward pass size (MB): 0.01
Params size (MB): 135.53
Estimated Total Size (MB): 135.79
----------------------------------------------------------------


In [None]:
def make_histogram(img_h, bins=256, lo=0., hi=0.):
  """Return the normalized histogram of ``img_h`` as a 1-D tensor of ``bins`` values.

  Each entry is the fraction of the elements of ``img_h`` falling into that
  bin. With the default ``lo == hi == 0`` ``torch.histc`` spans the tensor's
  own min..max, so the bin edges depend on the data; pass e.g.
  ``lo=0., hi=256.`` to bin against a fixed range instead. Elements outside
  an explicit range are not counted, so the result then sums to less than 1.

  Raises:
    ValueError: if ``img_h`` is empty (the normalization would divide by 0).
  """
  n = torch.numel(img_h)
  if n == 0:
    raise ValueError("cannot build a histogram of an empty tensor")
  return torch.histc(img_h, bins=bins, min=lo, max=hi) / n

def plot_histogram_comparison(hist_model, image):
  """Plot the true histogram of ``image`` next to the one ``hist_model`` predicts.

  ``image`` is expected to be a single-channel 2-D tensor, e.g.
  ``noise_image()[2, :, :]``: HistNet flattens its input into rows of
  ``img_size * img_size`` values, so a multi-channel image would produce
  several predictions and no longer line up with the true histogram.
  """
  # Run on the model's device (the model may live on cuda:0) so that a CPU
  # image can still be passed in.
  param = next(hist_model.parameters(), None)
  if param is not None:
    image = image.to(param.device)

  hist_img = make_histogram(image).cpu().numpy()
  with torch.no_grad():  # plotting only: no autograd graph needed
    hist_pred = hist_model(image).cpu().numpy().reshape(-1)

  plt.figure(figsize=(10, 6))
  plt.suptitle("image histogram and generated histogram", fontsize=25)
  plt.subplot(1, 2, 1)
  plt.title("true histogram")
  plt.bar(np.arange(len(hist_img)), hist_img, align='center')
  plt.subplot(1, 2, 2)
  plt.title("predicted histogram")
  plt.bar(np.arange(len(hist_pred)), hist_pred, align='center')

In [None]:
def noise_image(size=256, channels=3, scale=256.):
  """Return a float32 tensor of shape (channels, size, size) of uniform noise.

  Values are drawn from NumPy's global RNG (so ``np.random.seed`` makes the
  result reproducible) in [0, 1) and multiplied by ``scale``. The defaults
  give a (3, 256, 256) image with values in [0, 256).
  """
  noise = np.random.rand(channels, size, size)
  return torch.from_numpy(noise).float() * scale

In [None]:
def create_AE_model(lr=0.02):
  """Build an AutoEncoder together with its loss and optimizer.

  Args:
    lr: Adam learning rate (0.02 by default).

  Returns:
    ``(model, loss, opt)``: the model, an ``nn.MSELoss`` and an Adam
    optimizer over the model's parameters.
  """
  model = AutoEncoder()
  loss = nn.MSELoss()
  opt = torch.optim.Adam(model.parameters(), lr=lr)
  return model, loss, opt

def create_Hist_model(img_size, lr=2):
  """Build a HistNet for ``img_size`` x ``img_size`` inputs with its loss and optimizer.

  Args:
    img_size: side length of the square input images.
    lr: Adam learning rate (2 by default).

  Returns:
    ``(model, loss, opt)``: the model, an ``nn.MSELoss`` and an Adam
    optimizer over the model's parameters.
  """
  model = HistNet(img_size)
  loss = nn.MSELoss()
  # NOTE(review): lr=2 is orders of magnitude above typical Adam rates
  # (~1e-3) and may drive the sigmoid outputs into saturation; the value is
  # kept as the default only for compatibility -- pass a smaller lr to tune.
  opt = torch.optim.Adam(model.parameters(), lr=lr)
  return model, loss, opt

In [None]:

def torch_gaussian(mean, variance, n_samples=360):
  """Draw ``n_samples`` values from N(mean, variance), wrapped modulo 360.

  Presumably meant to produce hue angles in degrees (TODO confirm). Samples
  use the global torch RNG; both tails are wrapped onto [0, 360) so that
  values below 0 and at or above 360 are treated alike.

  Args:
    mean: mean of the distribution.
    variance: variance of the distribution (its square root is the std).
    n_samples: number of samples to draw (360 by default).

  Returns:
    A 1-D float tensor of length ``n_samples``.
  """
  # randn has unit variance, so scale by the standard deviation, not the variance.
  std = variance ** 0.5
  samples = torch.randn(n_samples) * std + mean
  # Vectorized wrap of both tails instead of a per-element Python loop.
  return torch.remainder(samples, 360)

In [None]:
# Compare true vs. predicted histogram on one channel of a fresh noise image.
sample_channel = noise_image()[2]
plot_histogram_comparison(model, sample_channel)

In [None]:
# Inspect the NumPy form of a noise image: its shape and element count.
a = noise_image().numpy()
for label, value in (("shape", a.shape), ("size", a.size)):
  print(label, value)

In [None]:
def train_histmodel(model, batch_size, nb_epochs, mean, variance,
                    loss_fn=None, optimizer=None, nb_images=2000):
  """Train ``model`` to predict the histogram of random single-channel noise images.

  Generates ``ceil(nb_images / batch_size)`` minibatches of ``batch_size``
  images (channel 2 of ``noise_image()``) once, then runs ``nb_epochs``
  passes over them, printing the epoch index and the last minibatch loss.

  Args:
    model: network mapping a batch of 2-D images to per-image histograms.
    batch_size: images per minibatch.
    nb_epochs: number of passes over the generated minibatches.
    mean, variance: currently unused -- presumably intended for
      ``torch_gaussian``-based data; TODO confirm.
    loss_fn: loss module; defaults to the notebook-level ``loss`` global
      (as returned by ``create_Hist_model``).
    optimizer: optimizer over ``model``'s parameters; defaults to the
      notebook-level ``opt`` global.
    nb_images: number of images to generate (rounded up to whole batches).

  Returns:
    The last minibatch loss as a float, or None if no step was taken.
  """
  # Fall back to the globals the original implementation used implicitly.
  loss_fn = loss if loss_fn is None else loss_fn
  optimizer = opt if optimizer is None else optimizer
  # Keep inputs on the same device as the model (it may live on cuda:0).
  device = next(model.parameters()).device

  minibatches = [
      torch.stack([noise_image()[2, :, :] for _ in range(batch_size)])
      for _ in range(0, nb_images, batch_size)
  ]
  # The targets are identical every epoch, so compute them once up front.
  targets = [torch.stack([make_histogram(img) for img in batch])
             for batch in minibatches]

  last_loss = None
  for epoch in range(nb_epochs):
    print("epoch", epoch)
    for batch, target in zip(minibatches, targets):
      pred = model(batch.to(device))
      # PyTorch loss convention: (input, target).
      batch_loss = loss_fn(pred, target.to(device))

      optimizer.zero_grad()
      batch_loss.backward()
      optimizer.step()
      last_loss = batch_loss.item()
    print("loss : ", last_loss)
  return last_loss


In [None]:
# Build the histogram model (plus the global `loss` / `opt` the trainer uses)
# and train it on 256x256 noise images.
model, loss, opt = create_Hist_model(img_size=256)

train_histmodel(model, batch_size=32, nb_epochs=50, mean=50, variance=2)