# Imports

In [None]:
import cv2
import torch
import torch.nn as nn
import torch.optim
import matplotlib.pyplot as plt
import numpy as np
from torch.distributions.bernoulli import Bernoulli


# Generator

In [None]:
class PixelGenerator(nn.Module):
    """Policy over binary images: one independent Bernoulli variable per pixel.

    Only the left half (``h`` x ``w``) is learned; ``get_image`` mirrors it to
    produce a left/right-symmetric ``h`` x ``2*w`` image.
    """

    def __init__(self, h, w):
        super().__init__()
        # Zero logits start every pixel at probability 0.5 of being "on".
        initial = torch.zeros((h, w))
        self.logits = nn.Parameter(initial, requires_grad=True)

    def forward(self):
        """Build the per-pixel Bernoulli distribution from the current logits."""
        return Bernoulli(logits=self.logits)

    def get_image(self):
        """Sample one image from the current distribution.

        Returns:
            ``(half_img, full_img)``: the sampled ``h`` x ``w`` half of 0/1
            values, and the ``h`` x ``2*w`` image formed by placing that half
            next to its copy flipped along the width.
        """
        left = self().sample()
        mirrored = torch.flip(left, dims=(1,))
        return left, torch.cat((left, mirrored), dim=1)

    def log_prob(self, img_batch):
        """Return the element-wise log-probability of ``img_batch``.

        ``img_batch`` is presumably a ``(batch, h, w)`` stack of half images
        (it only has to broadcast against the ``(h, w)`` logits); the result
        has the broadcast shape, one log-probability per pixel.
        """
        dist = self()
        return dist.log_prob(img_batch)


# Training Loop

In [None]:
def get_user_reward():
    """Block until the user presses a digit key and return it as a reward.

    Keys ``1``-``9`` map to rewards 1.0-9.0 and ``0`` maps to 10.0. Any other
    key prints a hint and keeps waiting. Polling uses ``cv2.waitKey``, so an
    OpenCV window needs keyboard focus for key presses to be seen.

    Returns:
        float: the reward in ``[1.0, 10.0]``.
    """
    while True:
        code = cv2.waitKey(100)
        if code < 0:
            # Nothing pressed during this 100 ms poll; try again.
            continue
        if not (ord("0") <= code <= ord("9")):
            print("Use keys 1-10(0)")
            continue
        digit = code - ord("0")
        # The "0" key stands for the top score of 10.
        return float(digit or 10)


In [None]:
# Target pattern for simulated training: 1 = pixel on, 0 = off. It is 10 x 10,
# the size of the mirrored output of PixelGenerator(10, 5), and every row is
# left/right symmetric, so a mirrored generator can reproduce it exactly.
# The dtype is explicit so the comparison image is float like the sampled images.
test_img = torch.tensor(
    [
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 1, 0, 0, 0],
        [0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 0],
    ],
    dtype=torch.float32,
)


def get_test_reward(full_img):
    """Score ``full_img`` by how many of its pixels agree with ``test_img``.

    Stand-in for ``get_user_reward`` so training can run without a human.
    ``full_img`` is expected to have the same shape as ``test_img``.

    Returns:
        float: ``10 * matching_pixels / total_pixels``, i.e. a value in
        ``[0, 10]``, on the same scale as the keyboard ratings whatever the
        size of ``test_img``.
    """
    # waitKey also lets OpenCV process window events, so the displayed
    # images keep refreshing while training is automated.
    cv2.waitKey(10)
    matches = float((full_img == test_img).sum())
    return matches * 10.0 / test_img.numel()



In [None]:
name = "cat"
cv2.namedWindow(name, 0)
cv2.namedWindow("logits", 0)

batch_size = 32

buffer = []  # [(half_img, full_img, reward)] -- every rated sample, kept for replay
generator = PixelGenerator(10, 5)
optimiser = torch.optim.SGD(generator.parameters(), lr=1.0)

# REINFORCE-style loop: sample an image, have it rated, and every
# `batch_size` ratings take one policy-gradient step. It runs until
# interrupted (e.g. Jupyter's "interrupt kernel"); the windows are closed
# afterwards in every case, and `buffer`/`generator` stay available.
try:
    while True:
        # Sample the left half; the full image is that half mirrored and
        # concatenated, so every image is left/right symmetric.
        half_img, full_img = generator.get_image()

        # Show the image to the user for feedback.
        cv2.imshow(name, full_img.numpy())

        # Show the generator's per-pixel "on" probabilities.
        cv2.imshow("logits", generator().probs.detach().numpy())

        # Block until the user scores the image (keys 1-9, with 0 meaning 10).
        reward = get_user_reward()
        # To train without a human, score against test_img instead:
        # reward = get_test_reward(full_img)

        # Record the experience.
        buffer.append((half_img, full_img, reward))

        # Once another full batch of ratings has arrived, update the generator.
        if len(buffer) % batch_size == 0:
            # Only the most recent batch is used, i.e. samples drawn from
            # the current policy.
            batch_list = buffer[-batch_size:]

            # Pull out the half images and rewards under new names so the
            # per-sample half_img/full_img/reward are not overwritten.
            half_imgs = [sample[0] for sample in batch_list]
            rewards = [sample[2] for sample in batch_list]

            # (batch, h, w) stack of the sampled half images.
            img_batch = torch.stack(half_imgs)

            # (batch, 1, 1) so it broadcasts against the per-pixel log probs;
            # float dtype so the mean below works even if rewards are ints.
            reward_batch = torch.tensor(rewards, dtype=torch.float32).reshape(-1, 1, 1)

            # Subtract the batch mean as a baseline: above-average images are
            # reinforced and below-average ones discouraged.
            reward_batch -= reward_batch.mean()

            # Log probability of each sampled pixel value under the current logits.
            log_prob_batch = generator.log_prob(img_batch)

            # Loss: negative log probs weighted by reward, summed over batch and pixels.
            loss = (-log_prob_batch * reward_batch).sum()

            # Step the optimiser.
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
except KeyboardInterrupt:
    # Interrupting is the intended way to stop this interactive loop.
    pass
finally:
    cv2.destroyAllWindows()

In [None]:
# Replay a previously recorded experience buffer, showing one image every ~100 ms.
# NOTE(review): `frog` is not defined in any visible cell of this notebook;
# presumably it was saved from an earlier training run (e.g. `frog = buffer`).
# Confirm before running, otherwise this cell raises NameError.
cv2.namedWindow("frog",0)
# Assumes each entry is shaped like `buffer`'s (half_img, full_img, reward) tuples.
for _,full_img,_ in frog:
    cv2.imshow("frog",full_img.numpy())
    cv2.waitKey(100)
    


In [None]:
# Close every OpenCV window left open by the cells above (e.g. the "frog" replay).
cv2.destroyAllWindows()

# Criterion

In [None]:
# Scratch check of torch.distributions.Bernoulli: sample one 0/1 value per
# element from logits, then show each sampled value's probability
# (exp of its log-prob). Per the torch docs, that is sigmoid(logit) where the
# sample is 1 and 1 - sigmoid(logit) where it is 0.
logits = torch.tensor([[1.0,2.0],[3.0,0.0]])
m = Bernoulli(logits=logits)
s = m.sample() 
torch.exp(m.log_prob(s))


In [None]:
# Scratch check: zip(*a) transposes a list of tuples into one tuple per
# position, here [(1, 2, 5), (2, 4, 6)].
a = [(1,2),(2,4),(5,6)]
list(zip(*a))