# Project: Implementation of MONet

**PSYCH250: High-level Vision: From Neurons to Deep Neural Networks**


_Stanford University. Winter, 2021._

---

**Team Members:** Benjamin Midler, Gongqi Li

---
This Colab notebook re-implements, in simplified form, DeepMind's Multi-Object Network (MONet) [1] for unsupervised scene decomposition.

### References
[1] Christopher P. Burgess, Loic Matthey, Nicholas Watters, Rishabh Kabra, Irina Higgins, Matt Botvinick, and Alexander Lerchner. MONet: Unsupervised scene decomposition and representation. arXiv:1901.11390, 2019.

### Environment Setup


In [9]:
import os
import sys
from collections import namedtuple

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.distributions as dists
from torch.utils.data import Dataset
import numpy as np
from numpy.random import random_integers
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
# Visdom dashboard setup (Colab): install localtunnel and visdom, start a
# visdom server on port 6006 in the background and expose it via a public
# localtunnel URL.
# NOTE(review): the npm/pip installs are unpinned; pin versions for reproducibility.
! npm install -g localtunnel
get_ipython().system_raw('python3 -m pip install visdom')
# Both the server and the tunnel are backgrounded (trailing `&`); their output
# is appended to visdomlog.txt and url.txt respectively.
get_ipython().system_raw('python3 -m visdom.server -port 6006 >> visdomlog.txt 2>&1 &')
get_ipython().system_raw('lt --port 6006 >> url.txt 2>&1 &')
import time
# Fixed sleeps give the background processes time to start; there is no
# readiness check, so on a slow machine these may need to be longer.
time.sleep(5)
# Shows the public URL printed by `lt`.
! cat url.txt
import visdom
time.sleep(5)
# `vis` is the global client that visualize_masks() sends image grids to.
vis = visdom.Visdom(port='6006')
print(vis)
time.sleep(3)
vis.text('MONet Visualization')
# Dump the server log to confirm the client connected.
! cat visdomlog.txt

[K[?25h/tools/node/bin/lt -> /tools/node/lib/node_modules/localtunnel/bin/lt.js
+ localtunnel@2.0.1
added 22 packages from 22 contributors in 1.475s
your url is: https://strong-sheep-96.loca.lt


Setting up a new session...


<visdom.Visdom object at 0x7fdc820e4d10>
  ioloop.install()  # Needs to happen before any tornado imports!
INFO:root:Application Started
INFO:tornado.access:200 POST /env/main (127.0.0.1) 0.60ms
INFO:tornado.access:101 GET /vis_socket (127.0.0.1) 0.42ms
INFO:root:Opened visdom socket from ip: 127.0.0.1
INFO:tornado.access:200 POST /events (127.0.0.1) 0.60ms


### Model

In [3]:
def double_conv(in_channels, out_channels):
    """Build two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    changes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by both convolutions.

    Returns:
        ``nn.Sequential`` of six layers: conv, BN, ReLU, conv, BN, ReLU.
    """
    layers = []
    # First stage maps in->out channels, second stage keeps out channels.
    for stage_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(stage_in, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*layers)


class UNet(nn.Module):
    """U-Net producing per-pixel output channels (used by the attention net).

    Contracting path: ``num_blocks`` double-conv stages, stage ``i`` having
    ``channel_base * 2**i`` channels; all but the last are followed by 2x2
    max pooling. Expanding path: stride-2 transposed convs upsample, the
    matching contracting activation is concatenated (skip connection) and a
    double-conv fuses them. A final 1x1 conv maps to ``out_channels``.

    Input height and width must be divisible by ``2**(num_blocks - 1)`` so
    the upsampled maps line up with the skip connections.
    """

    def __init__(self, num_blocks, in_channels, out_channels, channel_base=64):
        super().__init__()
        self.num_blocks = num_blocks
        widths = [channel_base * 2**level for level in range(num_blocks)]

        # Contracting path.
        self.down_convs = nn.ModuleList()
        prev_width = in_channels
        for width in widths:
            self.down_convs.append(double_conv(prev_width, width))
            prev_width = width

        # Upsamplers, deepest level first; each halves the channel count.
        self.tconvs = nn.ModuleList(
            nn.ConvTranspose2d(widths[level], widths[level - 1], 2, stride=2)
            for level in reversed(range(1, num_blocks))
        )

        # Expanding path: input is upsampled map + skip, i.e. 2x the width.
        self.up_convs = nn.ModuleList(
            double_conv(widths[level + 1], widths[level])
            for level in reversed(range(num_blocks - 1))
        )

        self.final_conv = nn.Conv2d(channel_base, out_channels, 1)

    def forward(self, x):
        skips = []
        cur = x
        for down in self.down_convs[:-1]:
            cur = down(cur)
            skips.append(cur)
            cur = nn.functional.max_pool2d(cur, 2)

        # Bottleneck stage (no pooling afterwards).
        cur = self.down_convs[-1](cur)

        # Skips are consumed deepest-first, mirroring the contracting path.
        for tconv, up, skip in zip(self.tconvs, self.up_convs, reversed(skips)):
            cur = up(torch.cat((tconv(cur), skip), 1))

        return self.final_conv(cur)


class AttentionNet(nn.Module):
    """One step of MONet's recurrent attention.

    Given the input and the current scope, a U-Net predicts two logits per
    pixel; their softmax gives ``alpha`` and ``1 - alpha``. The scope is then
    split into this step's mask (``scope * alpha``) and the scope left for
    later steps (``scope * (1 - alpha)``), so the two always sum to ``scope``.

    ``conf`` must provide ``num_blocks`` and ``channel_base`` for the U-Net.
    The U-Net takes 4 input channels, so ``x`` and ``scope`` concatenated
    along dim 1 must total 4 channels (presumably a 3-channel image plus a
    1-channel scope).
    """

    def __init__(self, conf):
        super().__init__()
        self.conf = conf
        self.unet = UNet(
            num_blocks=conf.num_blocks,
            in_channels=4,
            out_channels=2,
            channel_base=conf.channel_base,
        )

    def forward(self, x, scope):
        logits = self.unet(torch.cat((x, scope), 1))
        # Channel 0 -> alpha_k, channel 1 -> (1 - alpha_k).
        keep, rest = torch.softmax(logits, 1).split(1, dim=1)
        return scope * keep, scope * rest

class EncoderNet(nn.Module):
    """Component-VAE encoder mapping a 4-channel input to 32 raw outputs.

    Four unpadded stride-2 3x3 convs (with ReLU) are followed by an MLP.
    The MLP input size is derived from ``width`` and ``height``, so inputs
    must have exactly that spatial size.
    """

    def __init__(self, width, height):
        super().__init__()
        conv_channels = [(4, 32), (32, 32), (32, 64), (64, 64)]
        layers = []
        for c_in, c_out in conv_channels:
            layers.append(nn.Conv2d(c_in, c_out, 3, stride=2))
            layers.append(nn.ReLU(inplace=True))
        self.convs = nn.Sequential(*layers)

        # An unpadded stride-2 3x3 conv maps a side of length s to (s - 1) // 2.
        out_w, out_h = width, height
        for _ in conv_channels:
            out_w = (out_w - 1) // 2
            out_h = (out_h - 1) // 2

        self.mlp = nn.Sequential(
            nn.Linear(64 * out_w * out_h, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 32),
        )

    def forward(self, x):
        features = torch.flatten(self.convs(x), start_dim=1)
        return self.mlp(features)

class DecoderNet(nn.Module):
    """Spatial-broadcast decoder: latent vector -> 4-channel map of size (height, width).

    The 16-dim latent is tiled over a ``(height + 8) x (width + 8)`` grid and
    concatenated with two coordinate channels spanning [-1, 1]. Four
    unpadded 3x3 convs then shrink each side by 8 back to ``height x width``,
    and a 1x1 conv produces 4 output channels.
    """

    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width

        layers = [nn.Conv2d(18, 32, 3), nn.ReLU(inplace=True)]
        for _ in range(3):
            layers += [nn.Conv2d(32, 32, 3), nn.ReLU(inplace=True)]
        layers.append(nn.Conv2d(32, 4, 1))
        self.convs = nn.Sequential(*layers)

        # Constant (1, 2, H+8, W+8) coordinate grid; channel 0 varies along
        # rows, channel 1 along columns. Stored as a buffer so it follows
        # the module across devices.
        grid_h, grid_w = self.height + 8, self.width + 8
        row_coords, col_coords = torch.meshgrid(
            torch.linspace(-1, 1, grid_h), torch.linspace(-1, 1, grid_w))
        self.register_buffer('coord_map_const',
                             torch.stack((row_coords, col_coords)).unsqueeze(0))

    def forward(self, z):
        batch = z.shape[0]
        grid_h, grid_w = self.height + 8, self.width + 8
        z_tiled = z[:, :, None, None].expand(-1, -1, grid_h, grid_w)
        coords = self.coord_map_const.expand(batch, -1, -1, -1)
        return self.convs(torch.cat((z_tiled, coords), 1))


class Monet(nn.Module):
    """MONet: recurrent attention masks plus a shared component VAE.

    ``conf`` must provide ``num_slots``, ``bg_sigma``, ``fg_sigma`` and the
    U-Net settings used by ``AttentionNet``. ``forward`` returns a dict with:

    - ``'loss'``: per-example loss of shape (B,). It is the sum over slots
      of the masked negative log-likelihood plus ``beta`` times the latent
      KL, plus ``gamma`` times the mask KL.
    - ``'masks'``: attention masks, (B, num_slots, H, W).
    - ``'reconstructions'``: mask-weighted sum of slot reconstructions.
    - ``'loss_dict'``: batch means of the loss terms as Python floats.

    NOTE(review): ``loss_dict`` holds Python floats. ``nn.DataParallel``
    gathering across more than one GPU may not handle non-tensor outputs;
    verify before multi-GPU use.
    """

    def __init__(self, conf, height, width):
        super().__init__()
        self.conf = conf
        self.attention = AttentionNet(conf)
        # EncoderNet's signature is (width, height); pass by keyword.
        self.encoder = EncoderNet(width=width, height=height)
        self.decoder = DecoderNet(height, width)
        self.beta = 0.5    # weight of the latent KL term
        self.gamma = 0.25  # weight of the mask KL term

    def forward(self, x):
        # Recurrent attention: carve the remaining scope into num_slots - 1
        # masks; the final slot takes whatever scope is left, so masks sum to 1.
        scope = torch.ones_like(x[:, 0:1])
        masks = []
        for _ in range(self.conf.num_slots - 1):
            mask, scope = self.attention(x, scope)
            masks.append(mask)
        masks.append(scope)

        loss = torch.zeros_like(x[:, 0, 0, 0])
        mask_preds = []
        full_reconstruction = torch.zeros_like(x)
        p_xs = torch.zeros_like(loss)
        kl_zs = torch.zeros_like(loss)
        for i, mask in enumerate(masks):
            z, kl_z = self.__encoder_step(x, mask)
            # Slot 0 uses bg_sigma, all others fg_sigma.
            sigma = self.conf.bg_sigma if i == 0 else self.conf.fg_sigma
            p_x, x_recon, mask_pred = self.__decoder_step(x, z, mask, sigma)
            mask_preds.append(mask_pred)
            loss += -p_x + self.beta * kl_z
            p_xs += -p_x
            kl_zs += kl_z
            full_reconstruction += mask * x_recon

        masks = torch.cat(masks, 1)
        # Put the slot axis last for both distributions: (B, K, H, W) ->
        # (B, H, W, K), matching torch.stack(mask_preds, 3). A transpose(1, 3)
        # would yield (B, W, H, K) and compare spatially transposed pixels
        # (and fail outright for non-square images).
        q_masks = dists.Categorical(probs=masks.permute(0, 2, 3, 1))
        q_masks_recon = dists.Categorical(logits=torch.stack(mask_preds, 3))
        kl_masks = dists.kl_divergence(q_masks, q_masks_recon)
        kl_masks = torch.sum(kl_masks, [1, 2])
        loss_dict = {'px': p_xs.mean().item(),
                     'kl_z': kl_zs.mean().item() * self.beta,
                     'kl_masks': kl_masks.mean().item() * self.gamma}
        loss += self.gamma * kl_masks
        return {'loss': loss,
                'masks': masks,
                'reconstructions': full_reconstruction,
                'loss_dict': loss_dict}

    def __encoder_step(self, x, mask):
        """Encode (x, mask) and sample a 16-dim latent.

        Returns ``(z, kl_z)``, where ``kl_z`` is KL(q(z) || N(0, 1)) summed
        over latent dims, shape (B,).
        """
        encoder_input = torch.cat((x, mask), 1)
        q_params = self.encoder(encoder_input)
        # First 16 outputs -> means in (-3, 3); last 16 -> std in (0, 3).
        # NOTE(review): a saturated sigmoid can yield sigma == 0, which
        # Normal rejects; consider adding a small epsilon.
        means = torch.sigmoid(q_params[:, :16]) * 6 - 3
        sigmas = torch.sigmoid(q_params[:, 16:]) * 3
        dist = dists.Normal(means, sigmas)
        # Reparameterised sample (means + sigmas * eps): gradients reach both
        # means and sigmas. The previous means + Normal(0, sigmas).sample()
        # cut the reconstruction gradient to sigmas.
        z = dist.rsample()
        kl_z = dists.kl_divergence(dist, dists.Normal(0., 1.))
        kl_z = torch.sum(kl_z, 1)
        return z, kl_z

    def __decoder_step(self, x, z, mask, sigma):
        """Decode z into an RGB reconstruction and a mask logit map.

        Returns ``(p_x, x_recon, mask_pred)``. ``p_x`` is the Gaussian
        log-likelihood of ``x`` (std ``sigma``), weighted by ``mask`` and
        summed per example. ``x_recon`` is the sigmoid of decoder channels
        0-2. ``mask_pred`` is the raw decoder channel 3 of shape (B, H, W).
        """
        decoder_output = self.decoder(z)
        x_recon = torch.sigmoid(decoder_output[:, :3])
        mask_pred = decoder_output[:, 3]
        dist = dists.Normal(x_recon, sigma)
        p_x = dist.log_prob(x)
        p_x *= mask
        p_x = torch.sum(p_x, [1, 2, 3])
        return p_x, x_recon, mask_pred


### Datasets

In [4]:

def make_sprites(n=10000, height=64, width=64):
    """Generate a synthetic dataset with 0-3 sprites per image.

    Each image gets a sprite count drawn uniformly from {0, 1, 2, 3}. Each
    sprite is a circle (channel 0), square (channel 1) or diamond
    (channel 2) of size 12-16 px, placed fully inside the canvas.
    Overlapping sprites are summed, then clipped to [0, 1].

    Args:
        n: number of images; the first ``4 * n // 5`` form the training split.
        height, width: canvas size in pixels (each must be >= 12).

    Returns:
        Dict with ``'x_train'``/``'x_test'`` float arrays of shape
        (N, height, width, 3) and ``'count_train'``/``'count_test'`` arrays
        holding the sprite count of each image.
    """
    images = np.zeros((n, height, width, 3))
    counts = np.zeros((n,))
    # Row/column index grids; broadcasting them gives full-canvas masks.
    rows = np.arange(height)[:, None]
    cols = np.arange(width)[None, :]
    print('Generating sprite dataset...')
    for i in range(n):
        # randint's upper bound is exclusive, hence "+ 1" (the deprecated
        # random_integers was inclusive and drew from the same stream).
        num_sprites = np.random.randint(0, 3 + 1)
        counts[i] = num_sprites
        for _ in range(num_sprites):
            # pos_y is the top row (bounded by height), pos_x the left
            # column (bounded by width), so sprites stay on non-square canvases.
            pos_y = np.random.randint(0, height - 12 + 1)
            pos_x = np.random.randint(0, width - 12 + 1)
            scale = np.random.randint(12, min(16, height - pos_y, width - pos_x) + 1)
            cat = np.random.randint(0, 2 + 1)

            center_y = pos_y + scale // 2.0
            center_x = pos_x + scale // 2.0
            radius = scale // 2.0
            if cat == 0:  # circle
                shape = (rows - center_y) ** 2 + (cols - center_x) ** 2 < radius ** 2
            elif cat == 1:  # square
                shape = np.zeros((height, width), dtype=bool)
                shape[pos_y:pos_y + scale, pos_x:pos_x + scale] = True
            else:  # diamond
                shape = np.abs(rows - center_y) + np.abs(cols - center_x) < radius
            images[i, :, :, cat] += shape
        if i % 100 == 0:
            print("Making Sprites: {}/{}".format(i, n))
    images = np.clip(images, 0.0, 1.0)

    split = 4 * n // 5
    return {'x_train': images[:split],
            'count_train': counts[:split],
            'x_test': images[split:],
            'count_test': counts[split:]}


class Sprites(Dataset):
    """Torch dataset over the cached sprite images (see ``make_sprites``).

    The data is cached as ``sprites_{n}_{canvas_size}.npz`` inside
    ``directory`` (the current directory when ``directory`` is empty or
    None). When that file does not exist, it is generated and saved.

    Items are ``(image, count)``; ``transform`` (if given) is applied to the
    raw (canvas_size, canvas_size, 3) image array.
    """

    def __init__(self, directory, n=10000, canvas_size=64,
                 train=True, transform=None):
        cache_dir = directory or '.'
        np_file = os.path.join(cache_dir, 'sprites_{}_{}.npz'.format(n, canvas_size))

        # Generate only when the cache is missing. Other load errors (e.g. a
        # corrupt file) surface instead of being silently swallowed.
        if not os.path.exists(np_file):
            os.makedirs(cache_dir, exist_ok=True)
            gen_data = make_sprites(n, canvas_size, canvas_size)
            np.savez(np_file, **gen_data)

        # The context manager closes the .npz archive once arrays are read.
        with np.load(np_file) as data:
            self.images = data['x_train'] if train else data['x_test']
            self.counts = data['count_train'] if train else data['count_test']
        self.transform = transform

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.counts[idx]


### Visualization

In [8]:
def visualize_masks(imgs, masks, recons, vis_client=None):
    """Send two image grids to visdom: per-slot masks and an argmax segmentation.

    Grid 1 contains the inputs, one grey-scale image per slot mask, and the
    reconstructions. Grid 2 contains the inputs, a colour-coded argmax over
    slots, and the reconstructions.

    Args:
        imgs: input images, array of shape (N, 3, H, W).
        masks: per-slot masks, array of shape (N, K, H, W) with K <= 24
            (the size of the colour palette).
        recons: reconstructions, array of shape (N, 3, H, W); clipped to [0, 1].
        vis_client: object with a visdom-style ``images(array, nrow=...)``
            method. Defaults to the notebook's global ``vis``.
    """
    viz = vis if vis_client is None else vis_client
    recons = np.clip(recons, 0., 1.)
    colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 255, 255), (255, 0, 255), (255, 255, 0)]
    colors.extend([(c[0]//2, c[1]//2, c[2]//2) for c in colors])
    colors.extend([(c[0]//4, c[1]//4, c[2]//4) for c in colors])

    image_list = [imgs]
    for slot in range(masks.shape[1]):
        # Replicate the single-channel mask over 3 channels so it renders grey.
        image_list.append(np.repeat(masks[:, slot:slot + 1], 3, axis=1).astype(imgs.dtype))
    image_list.append(recons)
    viz.images(np.concatenate(image_list, 0), nrow=imgs.shape[0])

    labels = np.argmax(masks, 1)                                # (N, H, W)
    palette = np.asarray(colors, dtype=imgs.dtype) / 255.0      # (24, 3)
    seg_maps = palette[labels].transpose(0, 3, 1, 2)            # (N, 3, H, W)
    viz.images(np.concatenate((imgs, seg_maps, recons), 0), nrow=imgs.shape[0])

### Training


In [6]:
def numpify(tensor):
    """Detach ``tensor`` from autograd, move it to the CPU and return it as a NumPy array."""
    return tensor.detach().cpu().numpy()

def run_training(monet, conf, trainloader, testloader):
    """Train ``monet`` on ``trainloader`` with RMSprop (lr=1e-4).

    Every parameter is first re-initialised from N(0, 0.01). Every
    ``conf.vis_every`` iterations, the running averages of the loss terms
    are printed and appended to the returned histories, and a batch of
    masks and reconstructions is sent to visdom.

    Args:
        monet: a ``Monet`` model, optionally wrapped in ``nn.DataParallel``.
        conf: config providing ``num_epochs`` and ``vis_every``.
        trainloader: iterable of ``(images, counts)`` batches.
        testloader: currently unused (see the TODO at the end of each epoch).

    Returns:
        Tuple ``(total_loss, total_px, total_kl_z, total_kl_masks)`` of lists
        holding one averaged value per logging window.
    """
    # NOTE(review): this also re-initialises BatchNorm scales and all biases
    # to ~N(0, 0.01); confirm that is intended.
    std_init = 0.01
    for w in monet.parameters():
        nn.init.normal_(w, mean=0., std=std_init)
    print('Initialized parameters')

    optimizer = optim.RMSprop(monet.parameters(), lr=1e-4)
    # Send batches to wherever the model's parameters live instead of
    # assuming CUDA.
    device = next(monet.parameters()).device

    total_loss = []
    total_px = []
    total_kl_z = []
    total_kl_masks = []

    for epoch in range(conf.num_epochs):
        # Evaluation code may have switched the model to eval mode;
        # BatchNorm must use batch statistics while training.
        monet.train()
        running_loss = 0.0
        running_px = 0
        running_kl_z = 0
        running_kl_masks = 0
        for i, data in enumerate(trainloader, 0):
            images, counts = data
            images = images.to(device)
            optimizer.zero_grad()
            output = monet(images)
            loss = torch.mean(output['loss'])
            loss.backward()
            optimizer.step()

            loss_dict = output['loss_dict']
            running_px += loss_dict['px']
            running_kl_z += loss_dict['kl_z']
            running_kl_masks += loss_dict['kl_masks']

            running_loss += loss.item()

            if i % conf.vis_every == conf.vis_every - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / conf.vis_every),
                      ' px: ', running_px / conf.vis_every,
                      ' kl_z: ', running_kl_z / conf.vis_every,
                      ' kl_masks', running_kl_masks / conf.vis_every)
                total_loss.append(running_loss / conf.vis_every)
                total_px.append(running_px / conf.vis_every)
                total_kl_z.append(running_kl_z / conf.vis_every)
                total_kl_masks.append(running_kl_masks / conf.vis_every)
                running_loss = 0.0
                running_px = 0
                running_kl_z = 0
                running_kl_masks = 0

                visualize_masks(numpify(images[:8]),
                                numpify(output['masks'][:8]),
                                numpify(output['reconstructions'][:8]))

        # TODO: evaluate with test(testloader, monet) and save
        # monet.state_dict() to conf.checkpoint_file once wanted.

    print('training done')
    return (total_loss, total_px, total_kl_z, total_kl_masks)

def test(loader, model):
    model.eval()
    test_loss = 0.0
    for i, data in enumerate(loader):
        with torch.no_grad():
            images, counts = data
            output = model(images)
            if i % 100 == 99:
                visualize_masks(numpify(images[:6]),
                                numpify(output['masks'][:6]),
                                numpify(output['reconstructions'][:6]))
    loss = torch.mean(output['loss'])
    test_loss += loss.item()
    
    return test_loss


def sprite_experiment():
    """Build the sprite train/test loaders and a 64x64 MONet, then train it.

    All settings come from the global ``sprite_config``. Returns whatever
    ``run_training`` returns.

    NOTE(review): the lambda transform cannot be pickled when DataLoader
    workers use the 'spawn' start method (Windows/macOS).
    """
    conf = sprite_config
    to_float_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.float()),
    ])

    def make_loader(train, shuffle):
        dataset = Sprites(conf.data_dir, train=train, transform=to_float_tensor)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=conf.batch_size,
                                           shuffle=shuffle,
                                           num_workers=2)

    trainloader = make_loader(train=True, shuffle=True)
    testloader = make_loader(train=False, shuffle=False)

    monet = Monet(conf, 64, 64).cuda()
    if conf.parallel:
        monet = nn.DataParallel(monet)
    return run_training(monet, conf, trainloader, testloader)

### Config

In [5]:
config_options = [
    # --- training loop ---
    'vis_every',        # log losses and push visualisations every N iterations
    'batch_size',
    'num_epochs',
    # --- checkpoints, data, devices ---
    'load_parameters',  # whether to load parameters from a checkpoint
    'checkpoint_file',  # file used for loading/storing checkpoints
    'data_dir',         # directory for the training data
    'parallel',         # wrap the model in nn.DataParallel
    # --- model ---
    'num_slots',        # number of slots K
    'num_blocks',       # number of blocks in the attention U-Net
    'channel_base',     # channels of the first U-Net conv layer
    'bg_sigma',         # std of the decoder distribution for the first slot
    'fg_sigma',         # std of the decoder distribution for all other slots
]

MonetConfig = namedtuple('MonetConfig', config_options)

# Configuration used by the sprite experiment.
sprite_config = MonetConfig(**{
    'vis_every': 10,
    'batch_size': 64,
    'num_epochs': 100,
    'load_parameters': True,
    'checkpoint_file': 'sprites.ckpt',
    'data_dir': './data/',
    'parallel': True,
    'num_slots': 4,
    'num_blocks': 5,
    'channel_base': 64,
    'bg_sigma': 0.09,
    'fg_sigma': 0.11,
})


### Sprite Experiment

In [7]:
# Generate/load the sprite data, build the model and train it. The result is
# whatever sprite_experiment() returns: the training-loss histories tuple.
loss = sprite_experiment()

Generating sprite dataset...
Making Sprites: 0/10000


  import sys
  # This is added back by InteractiveShellApp.init_path()
  if sys.path[0] == '':
  
  app.launch_new_instance()
  
  
  
  


Making Sprites: 100/10000
Making Sprites: 200/10000
Making Sprites: 300/10000
Making Sprites: 400/10000
Making Sprites: 500/10000
Making Sprites: 600/10000
Making Sprites: 700/10000
Making Sprites: 800/10000
Making Sprites: 900/10000
Making Sprites: 1000/10000
Making Sprites: 1100/10000
Making Sprites: 1200/10000
Making Sprites: 1300/10000
Making Sprites: 1400/10000
Making Sprites: 1500/10000
Making Sprites: 1600/10000
Making Sprites: 1700/10000
Making Sprites: 1800/10000
Making Sprites: 1900/10000
Making Sprites: 2000/10000
Making Sprites: 2100/10000
Making Sprites: 2200/10000
Making Sprites: 2300/10000
Making Sprites: 2400/10000
Making Sprites: 2500/10000
Making Sprites: 2600/10000
Making Sprites: 2700/10000
Making Sprites: 2800/10000
Making Sprites: 2900/10000
Making Sprites: 3000/10000
Making Sprites: 3100/10000
Making Sprites: 3200/10000
Making Sprites: 3300/10000
Making Sprites: 3400/10000
Making Sprites: 3500/10000
Making Sprites: 3600/10000
Making Sprites: 3700/10000
Making Spr