In [58]:
import torch
import torch.nn as nn
from torchsummary import summary
import numpy as np

In [71]:
class AutoEncoder(nn.Module):
  def __init__(self):
    super().__init__()
    #encoding
    self.c1 = nn.Conv2d(3,50,3,2)
    self.c2 = nn.Conv2d(50,50,3,2,2)
    self.c3 = nn.Conv2d(50,50,3,1)

    self.dc1 = nn.ConvTranspose2d(50,50,3,2)
    self.dc2 = nn.ConvTranspose2d(50,50,4,2)
    self.dc3 = nn.ConvTranspose2d(50,3,3,1,padding=1)

  def forward(self,x):
    x = self.c1(x)
    x = nn.functional.relu(x)
    x = self.c2(x)
    x = nn.functional.relu(x)
    x = self.c3(x)
    x = nn.functional.relu(x)

    x = self.dc1(x)
    x = nn.functional.relu(x)
    x = self.dc2(x)
    x = nn.functional.relu(x)
    x = self.dc3(x)
    x = nn.functional.sigmoid(x)

    return x

In [72]:
model = AutoEncoder()

summary(model, (3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 50, 127, 127]           1,400
            Conv2d-2           [-1, 50, 65, 65]          22,550
            Conv2d-3           [-1, 50, 63, 63]          22,550
   ConvTranspose2d-4         [-1, 50, 127, 127]          22,550
   ConvTranspose2d-5         [-1, 50, 256, 256]          40,050
   ConvTranspose2d-6          [-1, 3, 256, 256]           1,353
Total params: 110,453
Trainable params: 110,453
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 41.93
Params size (MB): 0.42
Estimated Total Size (MB): 43.10
----------------------------------------------------------------




In [73]:
def noise_image():
  return np.random.randn(256*256*3).reshape((3,256,256))

In [74]:
def create_the_model():
  model = AutoEncoder()
  loss = nn.MSELoss()
  opt = torch.optim.Adam(model.parameters(),lr=0.02)
  return model,loss,opt

In [89]:
def torch_gaussian(mean,variance):
  hist = (torch.randn(360)*variance)+mean
  for i in range(len(hist)):
    if hist[i]<0:
      hist[i]=hist[i]%360
  return hist

In [102]:
def train(model,X,batch_size,nb_epochs,mean,variance):
  minibatches = [X[l:l+batch_size] for l in range(0,len(X),batch_size)]
    
  for i in range(nb_epochs):
    print("epoch",i)
    for batch in minibatches:
      pred = model(batch)
      l = loss(batch,pred)
      print(l)

      #naive version of loss function
      hist = torch.Tensor(size=(361,))
      hist.requires_grad_()
      predReshaped=torch.reshape(pred,(256,256,3))
      #construct histogram
      for line in predReshaped:
        for pix in line:
          m = torch.min(pix)
          M = torch.max(pix)
          if M-m>0:
            H = (pix[0]-pix[1])/(M-m)+4 if M==pix[2] else (pix[2]-pix[0])/(M-m)+2 if M==pix[1] else torch.remainder((pix[1]-pix[2])/(M-m),6)
            H = torch.mul(60, H)
            H = torch.remainder(H,360)
            torch.add(hist[int(H)],1)

      #compute KL divergence
      loss2 = nn.functional.kl_div(hist,torch_gaussian(mean,variance))

      print("loss2 : ",loss2)

      #add it to loss
      l = l + loss2
      opt.zero_grad()
      l.backward()
      #optimizer step
      opt.step()


In [101]:
model,loss,opt = create_the_model()

X = torch.Tensor(np.array([noise_image() for _ in range(200)]))

train(model,X,1,5,50,2)

epoch 0
tensor(1.2602, grad_fn=<MseLossBackward0>)
tensor(-1.4741e+12, grad_fn=<MeanBackward0>)


RuntimeError: ignored