<a href="https://colab.research.google.com/github/NitishaS-812k/comet/blob/master/comet_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.

import torch
import torch.nn.parallel
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from IPython import embed

class Downsample(nn.Module):
    """Blur-then-subsample layer for 4-D inputs (N, C, H, W).

    The input is padded, convolved per channel with a normalized 2-D
    binomial filter (outer product of a row of Pascal's triangle with
    itself) and subsampled with ``stride``.

    Args:
        pad_type: padding mode name understood by ``get_pad_layer``.
        filt_size: side length of the blur filter; any positive integer.
            ``1`` disables blurring and only subsamples.
        stride: subsampling stride.
        channels: number of input channels. Required: the filter buffer is
            replicated once per channel, so ``None`` fails at construction.
        pad_off: extra padding added to every side.

    Raises:
        ValueError: if ``filt_size`` is not a positive integer.
    """

    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        pad_lo = int(1. * (filt_size - 1) / 2)
        pad_hi = int(np.ceil(1. * (filt_size - 1) / 2))
        # Order is (left, right, top, bottom), as expected by the 2-D pad layers.
        self.pad_sizes = [pad + pad_off for pad in (pad_lo, pad_hi, pad_lo, pad_hi)]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        a = self._binomial_taps(self.filt_size)

        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        # One copy of the filter per channel: forward() runs a grouped
        # (depthwise) convolution with groups == number of input channels.
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    @staticmethod
    def _binomial_taps(filt_size):
        """Return the un-normalized binomial coefficients of length ``filt_size``."""
        if filt_size < 1 or int(filt_size) != filt_size:
            raise ValueError('filt_size must be a positive integer, got %r' % (filt_size,))
        taps = np.array([1.])
        # Repeated convolution with [1, 1] walks down Pascal's triangle:
        # [1] -> [1, 1] -> [1, 2, 1] -> [1, 3, 3, 1] -> ...
        for _ in range(int(filt_size) - 1):
            taps = np.convolve(taps, [1., 1.])
        return taps

    def forward(self, inp):
        """Return the blurred and strided version of ``inp``."""
        if self.filt_size == 1:
            # A 1x1 filter is the identity, so only (optional) padding and
            # strided slicing are needed.
            if self.pad_off == 0:
                return inp[:, :, ::self.stride, ::self.stride]
            return self.pad(inp)[:, :, ::self.stride, ::self.stride]
        return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])

def get_pad_layer(pad_type):
    """Return the 2-D padding layer class selected by ``pad_type``.

    Args:
        pad_type: ``'refl'``/``'reflect'``, ``'repl'``/``'replicate'`` or
            ``'zero'``.

    Returns:
        The ``nn`` padding class (not an instance).

    Raises:
        ValueError: if ``pad_type`` is not one of the names above. The
            original code only printed a message and then failed with an
            unrelated UnboundLocalError.
    """
    if pad_type in ('refl', 'reflect'):
        return nn.ReflectionPad2d
    if pad_type in ('repl', 'replicate'):
        return nn.ReplicationPad2d
    if pad_type == 'zero':
        return nn.ZeroPad2d
    raise ValueError('Pad type [%s] not recognized' % pad_type)

class Downsample1D(nn.Module):
    """Blur-then-subsample layer for 3-D inputs (N, C, L).

    The input is padded, convolved per channel with a normalized 1-D
    binomial filter and subsampled with ``stride``.

    Args:
        pad_type: padding mode name understood by ``get_pad_layer_1d``.
        filt_size: length of the blur filter; any positive integer.
            ``1`` disables blurring and only subsamples.
        stride: subsampling stride.
        channels: number of input channels. Required: the filter buffer is
            replicated once per channel, so ``None`` fails at construction.
        pad_off: extra padding added to both ends.

    Raises:
        ValueError: if ``filt_size`` is not a positive integer.
    """

    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample1D, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        pad_lo = int(1. * (filt_size - 1) / 2)
        pad_hi = int(np.ceil(1. * (filt_size - 1) / 2))
        self.pad_sizes = [pad + pad_off for pad in (pad_lo, pad_hi)]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        a = self._binomial_taps(self.filt_size)

        filt = torch.Tensor(a)
        filt = filt / torch.sum(filt)
        # One copy of the filter per channel: forward() runs a grouped
        # (depthwise) convolution with groups == number of input channels.
        self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))

        self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)

    @staticmethod
    def _binomial_taps(filt_size):
        """Return the un-normalized binomial coefficients of length ``filt_size``."""
        if filt_size < 1 or int(filt_size) != filt_size:
            raise ValueError('filt_size must be a positive integer, got %r' % (filt_size,))
        taps = np.array([1.])
        # Repeated convolution with [1, 1] walks down Pascal's triangle:
        # [1] -> [1, 1] -> [1, 2, 1] -> [1, 3, 3, 1] -> ...
        for _ in range(int(filt_size) - 1):
            taps = np.convolve(taps, [1., 1.])
        return taps

    def forward(self, inp):
        """Return the blurred and strided version of ``inp``."""
        if self.filt_size == 1:
            # A length-1 filter is the identity, so only (optional) padding
            # and strided slicing are needed.
            if self.pad_off == 0:
                return inp[:, :, ::self.stride]
            return self.pad(inp)[:, :, ::self.stride]
        return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])

def get_pad_layer_1d(pad_type):
    """Return a factory for the 1-D padding layer selected by ``pad_type``.

    Args:
        pad_type: ``'refl'``/``'reflect'``, ``'repl'``/``'replicate'`` or
            ``'zero'``.

    Returns:
        A callable that takes the padding sizes and returns a padding
        module. For ``'zero'`` this is ``nn.ZeroPad1d`` when the installed
        PyTorch provides it, otherwise an equivalent wrapper around
        ``nn.ConstantPad1d`` with value 0.

    Raises:
        ValueError: if ``pad_type`` is not one of the names above. The
            original code only printed a message and then failed with an
            unrelated UnboundLocalError.
    """
    if pad_type in ('refl', 'reflect'):
        return nn.ReflectionPad1d
    if pad_type in ('repl', 'replicate'):
        return nn.ReplicationPad1d
    if pad_type == 'zero':
        zero_pad = getattr(nn, 'ZeroPad1d', None)
        if zero_pad is None:
            # Older PyTorch releases have no ZeroPad1d; constant padding
            # with 0 produces the same result.
            def zero_pad(padding):
                return nn.ConstantPad1d(padding, 0.)
        return zero_pad
    raise ValueError('Pad type [%s] not recognized' % pad_type)

In [3]:
from torch.nn import ModuleList
import torch.nn.functional as F
import torch.nn as nn
import torch
from easydict import EasyDict

def swish(x):
    """Apply the swish activation, ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return x * gate

class View(nn.Module):
    """Module form of ``Tensor.view`` with a shape fixed at construction.

    Args:
        size: target shape handed to ``Tensor.view`` on every call.
    """

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, tensor):
        """Return ``tensor`` viewed with the stored shape."""
        target_shape = self.size
        return tensor.view(target_shape)


def kaiming_init(m):
    """Initialize a single module in place.

    * ``nn.Linear`` / ``nn.Conv2d``: Kaiming-normal weights, zero bias.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight 1, bias 0. Layers
      built with ``affine=False`` have no weight/bias and are left as is.
    * Any other module type is left untouched.

    Args:
        m: the module to initialize.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # kaiming_normal_ is the in-place replacement for the deprecated
        # kaiming_normal.
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)


def reparametrize(mu, logvar):
    """Draw ``mu + std * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

    Args:
        mu: mean tensor.
        logvar: log-variance tensor; ``std`` is computed as ``exp(logvar / 2)``.

    Returns:
        A sample that stays differentiable with respect to ``mu`` and
        ``logvar``; the noise itself carries no gradient.
    """
    std = logvar.div(2).exp()
    # randn_like replaces the deprecated Variable(std.data.new(...).normal_())
    # idiom: same shape, dtype and device as std, standard-normal values.
    eps = torch.randn_like(std)
    return mu + std * eps



class Self_Attn(nn.Module):
    """Self-attention layer over the spatial positions of a feature map.

    Queries and keys are 1x1 convolutions down to ``in_dim // 8`` channels,
    values are a 1x1 convolution keeping ``in_dim`` channels. The attended
    values are scaled by the learnable scalar ``gamma`` (initialized to 0)
    and added back to the input.
    """

    def __init__(self, in_dim, activation):
        super().__init__()
        self.chanel_in = in_dim
        # Stored only; not applied anywhere in this class.
        self.activation = activation

        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps (B x C x W x H)
            returns :
                out : gamma * attended values + input feature
                attention : B x N x N (N is Width*Height)
        """
        batch, channels, width, height = x.size()
        n_pos = width * height

        queries = self.query_conv(x).view(batch, -1, n_pos).transpose(1, 2)  # B x N x C'
        keys = self.key_conv(x).view(batch, -1, n_pos)                       # B x C' x N
        attention = self.softmax(torch.bmm(queries, keys))                   # B x N x N
        values = self.value_conv(x).view(batch, -1, n_pos)                   # B x C x N

        attended = torch.bmm(values, attention.transpose(1, 2))
        attended = attended.view(batch, channels, width, height)

        return self.gamma * attended + x, attention


class CondResBlock(nn.Module):
    """Residual block whose two convolutions are modulated by a latent vector.

    Two linear layers map the latent to a per-channel (gain, bias) pair for
    each convolution. After the residual sum the block optionally applies
    a further convolution (doubling the channel count when ``rescale`` is
    true) followed by a stride-2 average pool.

    ``bn1``/``bn2`` are constructed but not applied in ``forward``;
    ``im_size`` and ``latent_grid`` are only stored.
    """

    def __init__(self, downsample=True, rescale=True, filters=64, latent_dim=64, im_size=64, latent_grid=False):
        super().__init__()

        self.filters = filters
        self.latent_dim = latent_dim
        self.im_size = im_size
        self.downsample = downsample
        self.latent_grid = latent_grid

        few_filters = filters <= 128

        self.bn1 = (nn.InstanceNorm2d(filters, affine=False) if few_filters
                    else nn.GroupNorm(32, filters, affine=False))

        self.conv1 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)

        self.bn2 = (nn.InstanceNorm2d(filters, affine=False) if few_filters
                    else nn.GroupNorm(32, filters, affine=False))

        # Near-zero second convolution.
        torch.nn.init.normal_(self.conv2.weight, mean=0.0, std=1e-5)

        # Each layer emits `filters` gains followed by `filters` biases.
        self.latent_fc1 = nn.Linear(latent_dim, 2 * filters)
        self.latent_fc2 = nn.Linear(latent_dim, 2 * filters)

        if downsample:
            out_filters = 2 * filters if rescale else filters
            self.conv_downsample = nn.Conv2d(filters, out_filters, kernel_size=3, stride=1, padding=1)
            self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x, latent):
        """Run the block on feature map ``x`` conditioned on ``latent``."""
        n = self.filters

        film1 = self.latent_fc1(latent)
        film2 = self.latent_fc2(latent)

        # First half of each projection is the gain, second half the bias,
        # broadcast over the two spatial axes.
        h = self.conv1(x)
        h = swish(film1[:, :n, None, None] * h + film1[:, n:, None, None])

        h = self.conv2(h)
        h = swish(film2[:, :n, None, None] * h + film2[:, n:, None, None])

        result = x + h

        if not self.downsample:
            return result

        return self.avg_pool(swish(self.conv_downsample(result)))


class CondResBlockNoLatent(nn.Module):
    """Unconditioned residual block with optional up- or down-sampling.

    Two swish-activated 3x3 convolutions are added to the input. The
    result is then optionally upsampled by 2 and convolved, and/or
    convolved and average-pooled with stride 2.

    ``bn1``/``bn2`` are constructed but not applied in ``forward``. When
    both ``upsample`` and ``downsample`` are true the same
    ``conv_downsample`` (the one created for upsampling) is used in both
    stages.
    """

    def __init__(self, downsample=True, rescale=True, filters=64, upsample=False):
        super().__init__()

        self.filters = filters
        self.downsample = downsample

        few_filters = filters <= 128

        if few_filters:
            self.bn1 = nn.GroupNorm(int(32 * filters / 128), filters, affine=True)
        else:
            self.bn1 = nn.GroupNorm(32, filters, affine=False)

        self.conv1 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)

        bn2_groups = int(32 * filters / 128) if few_filters else 32
        self.bn2 = nn.GroupNorm(bn2_groups, filters, affine=True)

        self.upsample = upsample
        self.upsample_module = nn.Upsample(scale_factor=2)

        if downsample:
            out_filters = 2 * filters if rescale else filters
            self.conv_downsample = nn.Conv2d(filters, out_filters, kernel_size=3, stride=1, padding=1)
            self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)

        if upsample:
            # Replaces any convolution created in the downsample branch.
            self.conv_downsample = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the block on feature map ``x``."""
        h = swish(self.conv1(x))
        h = swish(self.conv2(h))

        result = x + h

        if self.upsample:
            result = swish(self.conv_downsample(self.upsample_module(result)))

        if self.downsample:
            result = self.avg_pool(swish(self.conv_downsample(result)))

        return result


class BroadcastConvDecoder(nn.Module):
    """Spatial-broadcast decoder: tile a latent over a grid, then convolve.

    The latent is broadcast to a ``(im_size + 8)``-sided square, two
    coordinate channels in [0, 1] are appended, and four unpadded 3x3
    convolutions (each trimming 2 pixels per axis) plus a final 1x1
    convolution bring the map back to ``im_size`` with ``latent_dim``
    channels.
    """

    def __init__(self, im_size, latent_dim):
        super().__init__()
        # Working resolution: 8 pixels larger to absorb the unpadded convs.
        self.im_size = im_size + 8
        self.latent_dim = latent_dim
        self.init_grid()

        stages = [nn.Conv2d(self.latent_dim + 2, 32, 3, 1, 0), nn.ReLU(True)]
        for _ in range(3):
            stages.append(nn.Conv2d(32, 32, 3, 1, 0))
            stages.append(nn.ReLU(True))
        stages.append(nn.Conv2d(32, self.latent_dim, 1, 1, 0))
        self.g = nn.Sequential(*stages)

    def init_grid(self):
        """Build the two coordinate grids, stored as plain tensor attributes."""
        axis = torch.linspace(0, 1, self.im_size)
        self.x_grid, self.y_grid = torch.meshgrid(axis, axis)

    def broadcast(self, z):
        """Tile ``z`` over the grid and append the x/y coordinate channels."""
        batch = z.size(0)
        side = self.im_size
        # The grids are not registered buffers, so move them to z's device here.
        coords = [grid.expand(batch, 1, -1, -1).to(z.device)
                  for grid in (self.x_grid, self.y_grid)]
        tiled = z.view((batch, -1, 1, 1)).expand(-1, -1, side, side)
        return torch.cat((tiled, coords[0], coords[1]), dim=1)

    def forward(self, z):
        """Decode latent ``z`` into a feature map."""
        return self.g(self.broadcast(z))


class LatentEBM128(nn.Module):
    """Latent-conditioned energy model with a five-block energy trunk.

    ``forward(x, latent)`` returns one scalar energy per image;
    ``embed_latent(im)`` encodes an image into ``components`` latents of
    size ``latent_dim``, concatenated along the last axis.

    Args:
        args: configuration object providing ``filter_dim``, ``latent_dim``,
            ``components``, ``pos_embed``, ``recurrent_model`` and
            ``dataset``.
        dataset: only ``len(dataset)`` is evaluated.

    Several sub-modules (``avg_pool``, ``gain``, ``bias``, ``latent_map``,
    ``decode_layer*``, ``latent_decode``, ``downsample``, ``pos_embedding``)
    are constructed but not used by the methods of this class. They are
    left in place; removing them would change the parameter set --
    presumably they matter for existing checkpoints, TODO confirm.
    """

    def __init__(self, args, dataset):
        super(LatentEBM128, self).__init__()

        filter_dim = args.filter_dim
        self.filter_dim = filter_dim
        latent_dim_expand = args.latent_dim * args.components
        latent_dim = args.latent_dim

        self.components = args.components

        # Value is unused; the call is kept so an object without __len__
        # still fails at construction exactly as before.
        n_instance = len(dataset)
        self.pos_embed = args.pos_embed

        if self.pos_embed:
            self.conv1 = nn.Conv2d(3, filter_dim // 2, kernel_size=3, stride=1, padding=1, bias=True)
            self.conv1_embed = nn.Conv2d(2, filter_dim // 2, kernel_size=3, stride=1, padding=1, bias=True)
        else:
            self.conv1 = nn.Conv2d(3, filter_dim // 4, kernel_size=3, stride=1, padding=1, bias=True)
        self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)

        self.gain = nn.Linear(args.latent_dim, filter_dim // 4)
        self.bias = nn.Linear(args.latent_dim, filter_dim // 4)

        self.recurrent_model = args.recurrent_model

        # Side length of the coordinate grid used when pos_embed is set.
        if args.dataset == "tetris":
            self.im_size = 35
        else:
            self.im_size = 64

        # Energy trunk. Channel counts: f/4 -> f/2 -> f -> f -> f -> 2f.
        self.layer_encode = CondResBlock(filters=filter_dim//4, latent_dim=latent_dim, rescale=True)
        self.layer1 = CondResBlock(filters=filter_dim//2, latent_dim=latent_dim, rescale=True)
        self.layer2 = CondResBlock(filters=filter_dim, latent_dim=latent_dim, rescale=False)
        self.layer3 = CondResBlock(filters=filter_dim, latent_dim=latent_dim, rescale=False)
        self.layer4 = CondResBlock(filters=filter_dim, latent_dim=latent_dim)
        self.mask_decode = BroadcastConvDecoder(64, latent_dim)

        self.latent_map = nn.Linear(latent_dim, filter_dim * 8)
        self.energy_map = nn.Linear(filter_dim * 2, 1)

        # Image encoder used by embed_latent.
        self.embed_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)
        self.embed_layer1 = CondResBlockNoLatent(filters=filter_dim, rescale=False, downsample=True)
        self.embed_layer2 = CondResBlockNoLatent(filters=filter_dim, rescale=False, downsample=True)
        self.embed_layer3 = CondResBlockNoLatent(filters=filter_dim, rescale=False, downsample=True)

        self.decode_layer1 = CondResBlockNoLatent(filters=filter_dim, rescale=False, upsample=True, downsample=False)
        self.decode_layer2 = CondResBlockNoLatent(filters=filter_dim, rescale=False, upsample=True, downsample=False)
        self.decode_layer3 = CondResBlockNoLatent(filters=filter_dim, rescale=False, upsample=True, downsample=False)

        self.latent_decode = nn.Conv2d(filter_dim, latent_dim_expand, kernel_size=3, stride=1, padding=1)

        self.downsample = Downsample(channels=args.latent_dim)
        self.dataset = args.dataset

        if self.recurrent_model:
            self.embed_layer4 = CondResBlockNoLatent(filters=filter_dim, rescale=False, downsample=True)
            self.lstm = nn.LSTM(filter_dim, filter_dim, 1)
            self.embed_fc2 = nn.Linear(filter_dim, latent_dim)

            self.at_fc1 = nn.Linear(filter_dim*2, filter_dim)
            self.at_fc2 = nn.Linear(filter_dim, 1)

            self.map_embed = nn.Linear(filter_dim*2, filter_dim)

            if args.dataset == "tetris":
                self.pos_embedding = nn.Parameter(torch.zeros(9, filter_dim))
            else:
                self.pos_embedding = nn.Parameter(torch.zeros(16, filter_dim))
        else:
            self.embed_fc1 = nn.Linear(filter_dim, filter_dim)
            self.embed_fc2 = nn.Linear(filter_dim, latent_dim_expand)

        self.init_grid()

    def gen_mask(self, latent):
        """Decode ``latent`` with the broadcast decoder."""
        return self.mask_decode(latent)

    def init_grid(self):
        """Build the [0, 1] coordinate grids used for positional embedding."""
        x = torch.linspace(0, 1, self.im_size)
        y = torch.linspace(0, 1, self.im_size)
        self.x_grid, self.y_grid = torch.meshgrid(x, y)

    def embed_latent(self, im):
        """Encode image batch ``im`` into a (B, components * latent_dim) tensor."""
        x = self.embed_conv1(im)
        x = F.relu(x)
        x = self.embed_layer1(x)
        x = self.embed_layer2(x)
        x = self.embed_layer3(x)

        if self.recurrent_model:
            x = self.embed_layer4(x)

            # Flatten the spatial axes: (B, C, H, W) -> (B, H*W, C).
            s = x.size()
            x = x.view(s[0], s[1], -1)
            x = x.permute(0, 2, 1).contiguous()

            # Zero-initialized LSTM (hidden, cell) state.
            h = (torch.zeros(1, im.size(0), self.filter_dim).to(x.device),
                 torch.zeros(1, im.size(0), self.filter_dim).to(x.device))
            outputs = []

            # One LSTM step per component; each step attends over the
            # spatial positions using the current cell state as the query.
            for _ in range(self.components):
                _, cx = h

                cx = cx.permute(1, 0, 2).contiguous()
                context = torch.cat([cx.expand(-1, x.size(1), -1), x], dim=-1)
                at_wt = self.at_fc2(F.relu(self.at_fc1(context)))
                at_wt = F.softmax(at_wt, dim=1)
                inp = (at_wt * context).sum(dim=1, keepdim=True)
                inp = self.map_embed(inp)
                inp = inp.permute(1, 0, 2).contiguous()

                output, h = self.lstm(inp, h)
                outputs.append(output)

            # (components, B, C) -> (B, components, C) -> (B, components * latent_dim)
            output = torch.cat(outputs, dim=0)
            output = output.permute(1, 0, 2).contiguous()
            output = self.embed_fc2(output)
            s = output.size()
            output = output.view(s[0], -1)
        else:
            # Global average pool over the spatial axes.
            x = x.mean(dim=2).mean(dim=2)

            x = x.view(x.size(0), -1)
            # The original evaluated embed_fc1 twice and discarded the first
            # result; a single evaluation yields the same output.
            x = F.relu(self.embed_fc1(x))
            output = self.embed_fc2(x)

        return output

    def forward(self, x, latent):
        """Return the energy of images ``x`` conditioned on ``latent``."""

        if self.pos_embed:
            b = x.size(0)
            x_grid = self.x_grid.expand(b, 1, -1, -1).to(x.device)
            y_grid = self.y_grid.expand(b, 1, -1, -1).to(x.device)
            coord_grid = torch.cat([x_grid, y_grid], dim=1)

        inter = self.conv1(x)
        inter = swish(inter)

        if self.pos_embed:
            pos_inter = self.conv1_embed(coord_grid)
            pos_inter = swish(pos_inter)

            # NOTE(review): with pos_embed the concatenation has
            # 2 * (filter_dim // 2) channels while layer_encode is built for
            # filter_dim // 4 -- this path looks like it cannot run as
            # written; confirm the intended widths before enabling it.
            inter = torch.cat([inter, pos_inter], dim=1)

        x = self.layer_encode(inter, latent)
        x = self.layer1(x, latent)

        x = self.layer2(x, latent)
        x = self.layer3(x, latent)
        x = self.layer4(x, latent)

        # Global average pool over the spatial axes, then a linear read-out.
        x = x.mean(dim=2).mean(dim=2)
        x = x.view(x.size(0), -1)

        energy = self.energy_map(x)

        return energy


class LatentEBM(nn.Module):
    """Latent-conditioned energy model with a three-block energy trunk.

    ``forward(x, latent)`` returns one scalar energy per image;
    ``embed_latent(im)`` encodes an image into ``components`` latents of
    size ``latent_dim``, concatenated along the last axis.

    Args:
        args: configuration object providing ``filter_dim``, ``latent_dim``,
            ``components``, ``pos_embed``, ``recurrent_model`` and
            ``dataset``.
        dataset: only ``len(dataset)`` is evaluated.

    Several sub-modules (``gain``, ``bias``, ``latent_map``,
    ``decode_layer*``, ``latent_decode``, ``downsample``, ``pos_embedding``)
    are constructed but not used by the methods of this class. They are
    left in place; removing them would change the parameter set --
    presumably they matter for existing checkpoints, TODO confirm.
    """

    def __init__(self, args, dataset):
        super(LatentEBM, self).__init__()

        filter_dim = args.filter_dim
        self.filter_dim = filter_dim
        latent_dim_expand = args.latent_dim * args.components
        latent_dim = args.latent_dim

        self.components = args.components

        # Value is unused; the call is kept so an object without __len__
        # still fails at construction exactly as before.
        n_instance = len(dataset)
        self.pos_embed = args.pos_embed

        # With positional embedding the image and coordinate branches each
        # produce half of the filter_dim channels and are concatenated.
        if self.pos_embed:
            self.conv1 = nn.Conv2d(3, filter_dim // 2, kernel_size=3, stride=1, padding=1, bias=True)
            self.conv1_embed = nn.Conv2d(2, filter_dim // 2, kernel_size=3, stride=1, padding=1, bias=True)
        else:
            self.conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1, bias=True)

        self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)

        self.gain = nn.Linear(args.latent_dim, filter_dim)
        self.bias = nn.Linear(args.latent_dim, filter_dim)

        self.recurrent_model = args.recurrent_model

        # Side length of the coordinate grid used when pos_embed is set.
        if args.dataset == "tetris":
            self.im_size = 35
        else:
            self.im_size = 64

        # Energy trunk. Channel counts: f -> f -> f -> 2f.
        self.layer_encode = CondResBlock(filters=filter_dim, latent_dim=latent_dim, rescale=False)
        self.layer1 = CondResBlock(filters=filter_dim, latent_dim=latent_dim, rescale=False)
        self.layer2 = CondResBlock(filters=filter_dim, latent_dim=latent_dim)
        self.mask_decode = BroadcastConvDecoder(64, latent_dim)

        self.latent_map = nn.Linear(latent_dim, filter_dim * 8)
        self.energy_map = nn.Linear(filter_dim * 2, 1)

        # Image encoder used by embed_latent.
        self.embed_conv1 = nn.Conv2d(3, filter_dim, kernel_size=3, stride=1, padding=1)
        self.embed_layer1 = CondResBlockNoLatent(filters=filter_dim, rescale=False, downsample=True)
        self.embed_layer2 = CondResBlockNoLatent(filters=filter_dim, rescale=False, downsample=True)
        self.embed_layer3 = CondResBlockNoLatent(filters=filter_dim, rescale=False, downsample=True)

        self.decode_layer1 = CondResBlockNoLatent(filters=filter_dim, rescale=False, upsample=True, downsample=False)
        self.decode_layer2 = CondResBlockNoLatent(filters=filter_dim, rescale=False, upsample=True, downsample=False)
        self.decode_layer3 = CondResBlockNoLatent(filters=filter_dim, rescale=False, upsample=True, downsample=False)

        self.latent_decode = nn.Conv2d(filter_dim, latent_dim_expand, kernel_size=3, stride=1, padding=1)

        self.downsample = Downsample(channels=args.latent_dim)
        self.dataset = args.dataset

        if self.recurrent_model:
            self.embed_layer4 = CondResBlockNoLatent(filters=filter_dim, rescale=False, downsample=True)
            self.lstm = nn.LSTM(filter_dim, filter_dim, 1)
            self.embed_fc2 = nn.Linear(filter_dim, latent_dim)

            self.at_fc1 = nn.Linear(filter_dim*2, filter_dim)
            self.at_fc2 = nn.Linear(filter_dim, 1)

            self.map_embed = nn.Linear(filter_dim*2, filter_dim)

            if args.dataset == "tetris":
                self.pos_embedding = nn.Parameter(torch.zeros(9, filter_dim))
            else:
                self.pos_embedding = nn.Parameter(torch.zeros(16, filter_dim))
        else:
            self.embed_fc1 = nn.Linear(filter_dim, filter_dim)
            self.embed_fc2 = nn.Linear(filter_dim, latent_dim_expand)

        self.init_grid()

    def gen_mask(self, latent):
        """Decode ``latent`` with the broadcast decoder."""
        return self.mask_decode(latent)

    def init_grid(self):
        """Build the [0, 1] coordinate grids used for positional embedding."""
        x = torch.linspace(0, 1, self.im_size)
        y = torch.linspace(0, 1, self.im_size)
        self.x_grid, self.y_grid = torch.meshgrid(x, y)

    def embed_latent(self, im):
        """Encode image batch ``im`` into a (B, components * latent_dim) tensor."""
        x = self.embed_conv1(im)
        x = F.relu(x)
        x = self.embed_layer1(x)
        x = self.embed_layer2(x)
        x = self.embed_layer3(x)

        if self.recurrent_model:
            x = self.embed_layer4(x)

            # Flatten the spatial axes: (B, C, H, W) -> (B, H*W, C).
            s = x.size()
            x = x.view(s[0], s[1], -1)
            x = x.permute(0, 2, 1).contiguous()

            # Zero-initialized LSTM (hidden, cell) state.
            h = (torch.zeros(1, im.size(0), self.filter_dim).to(x.device),
                 torch.zeros(1, im.size(0), self.filter_dim).to(x.device))
            outputs = []

            # One LSTM step per component; each step attends over the
            # spatial positions using the current cell state as the query.
            for _ in range(self.components):
                _, cx = h

                cx = cx.permute(1, 0, 2).contiguous()
                context = torch.cat([cx.expand(-1, x.size(1), -1), x], dim=-1)
                at_wt = self.at_fc2(F.relu(self.at_fc1(context)))
                at_wt = F.softmax(at_wt, dim=1)
                inp = (at_wt * context).sum(dim=1, keepdim=True)
                inp = self.map_embed(inp)
                inp = inp.permute(1, 0, 2).contiguous()

                output, h = self.lstm(inp, h)
                outputs.append(output)

            # (components, B, C) -> (B, components, C) -> (B, components * latent_dim)
            output = torch.cat(outputs, dim=0)
            output = output.permute(1, 0, 2).contiguous()
            output = self.embed_fc2(output)
            s = output.size()
            output = output.view(s[0], -1)
        else:
            # Global average pool over the spatial axes.
            x = x.mean(dim=2).mean(dim=2)

            x = x.view(x.size(0), -1)
            # The original evaluated embed_fc1 twice and discarded the first
            # result; a single evaluation yields the same output.
            x = F.relu(self.embed_fc1(x))
            output = self.embed_fc2(x)

        return output

    def forward(self, x, latent):
        """Return the energy of images ``x`` conditioned on ``latent``."""

        if self.pos_embed:
            b = x.size(0)
            x_grid = self.x_grid.expand(b, 1, -1, -1).to(x.device)
            y_grid = self.y_grid.expand(b, 1, -1, -1).to(x.device)
            coord_grid = torch.cat([x_grid, y_grid], dim=1)

        inter = self.conv1(x)
        inter = swish(inter)

        if self.pos_embed:
            pos_inter = self.conv1_embed(coord_grid)
            pos_inter = swish(pos_inter)

            inter = torch.cat([inter, pos_inter], dim=1)

        x = self.avg_pool(inter)

        x = self.layer_encode(x, latent)

        x = self.layer1(x, latent)
        x = self.layer2(x, latent)

        # Global average pool over the spatial axes, then a linear read-out.
        x = x.mean(dim=2).mean(dim=2)
        x = x.view(x.size(0), -1)

        energy = self.energy_map(x)

        return energy

class ToyEBM(nn.Module):
    """Small latent-conditioned energy model with its own image encoder.

    Args:
        args: configuration object providing ``filter_dim``, ``latent_dim``,
            ``components`` and ``num_steps``.
        dataset: only ``len(dataset)`` is evaluated.
    """

    def __init__(self, args, dataset):
        super().__init__()

        width = args.filter_dim
        latent_dim = args.latent_dim
        total_latent = latent_dim * args.components

        # Result unused; evaluated as in the original constructor.
        _ = len(dataset)

        # Energy trunk.
        self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1, bias=True)
        self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
        self.layer_encode = CondResBlock(filters=width, latent_dim=latent_dim, rescale=False)
        self.layer1 = CondResBlock(filters=width, latent_dim=latent_dim, rescale=False)
        self.layer2 = CondResBlock(filters=width, latent_dim=latent_dim)

        self.fc1 = nn.Linear(width * 2, width * 2)

        self.latent_map = nn.Linear(latent_dim, width * 8)
        self.energy_map = nn.Linear(width * 2, 1)

        # Image encoder used by embed_latent.
        self.embed_conv1 = nn.Conv2d(3, width, kernel_size=5, stride=2, padding=3)
        self.embed_layer1 = CondResBlockNoLatent(filters=width, rescale=False)
        self.embed_layer2 = CondResBlockNoLatent(filters=width)
        self.embed_fc1 = nn.Linear(width * 2, width * 2)
        self.embed_fc2 = nn.Linear(width * 2, total_latent)

        # One learnable scalar per step, initialized to 1.
        self.steps = torch.nn.parameter.Parameter(torch.ones(args.num_steps), requires_grad=True)

    def embed_latent(self, im):
        """Encode image batch ``im`` into a (B, latent_dim * components) tensor."""
        feat = F.relu(self.embed_conv1(im))
        feat = self.embed_layer2(self.embed_layer1(feat))
        # Global average pool over the spatial axes.
        feat = feat.mean(dim=2).mean(dim=2)
        feat = feat.view(feat.size(0), -1)
        feat = F.relu(self.embed_fc1(feat))
        return self.embed_fc2(feat)

    def forward(self, x, latent):
        """Return the energy of images ``x`` conditioned on ``latent``."""
        feat = self.avg_pool(swish(self.conv1(x)))
        for block in (self.layer_encode, self.layer1, self.layer2):
            feat = block(feat, latent)
        # Global average pool over the spatial axes.
        feat = feat.mean(dim=2).mean(dim=2)
        feat = feat.view(feat.size(0), -1)

        feat = swish(self.fc1(feat))
        return self.energy_map(feat)


class DisentangleModel(nn.Module):
    """Three-layer MLP mapping 256 features to 6 outputs (ReLU in between)."""

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 6)

    def forward(self, x):
        """Return the 6-dimensional output for ``x``; no final activation."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class PMM(nn.Module):
    """Four-layer MLP from the concatenated latents to one value in (0, 1).

    Args:
        args: configuration object providing ``latent_dim`` and
            ``components``; the input width is their product.
    """

    def __init__(self, args):
        super().__init__()

        self.latent_dim = args.latent_dim * args.components
        self.inner_dim = 1024

        self.fc1 = nn.Linear(self.latent_dim, self.inner_dim)
        self.fc2 = nn.Linear(self.inner_dim, self.inner_dim)
        self.fc3 = nn.Linear(self.inner_dim, self.inner_dim)
        self.fc4 = nn.Linear(self.inner_dim, 1)

    def forward(self, x):
        """Return the sigmoid of the final layer applied to ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), 0.2)
        return torch.sigmoid(self.fc4(hidden))


class LabelModel(nn.Module):
    """Three-layer MLP mapping 64 features to 5 outputs.

    Leaky ReLU (slope 0.2) follows every layer, including the last one.
    ``args`` is accepted but not read.
    """

    def __init__(self, args):
        super().__init__()

        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 5)

    def forward(self, x):
        """Return the 5-dimensional output for ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), 0.2)
        return hidden

class BetaVAE_H(nn.Module):
    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017).

    Args:
        z_dim: size of the latent code. The encoder emits ``2 * z_dim``
            values per sample: the mean followed by the log-variance.
        nc: number of image channels.

    The shape comments in the layer lists assume 64x64 inputs.
    """

    def __init__(self, z_dim=10, nc=3):
        super(BetaVAE_H, self).__init__()
        self.z_dim = z_dim
        self.nc = nc
        self.encoder = nn.Sequential(
            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),          # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),          # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),            # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256*1*1)),                 # B, 256
            nn.Linear(256, z_dim*2),             # B, z_dim*2
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),               # B, 256
            View((-1, 256, 1, 1)),               # B, 256,  1,  1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),      # B,  64,  4,  4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1), # B,  64,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1),  # B, nc, 64, 64
        )

        self.weight_init()

    def weight_init(self):
        """Apply ``kaiming_init`` to every layer of the encoder and decoder."""
        for block in self._modules:
            for m in self._modules[block]:
                kaiming_init(m)

    def forward(self, x, return_z=False):
        """Encode ``x``, sample a latent and decode it.

        Returns:
            ``(x_recon, mu, logvar)``, plus the sampled ``z`` as a fourth
            element when ``return_z`` is true. ``x_recon`` is the raw
            decoder output (no final activation).
        """
        distributions = self._encode(x)
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = reparametrize(mu, logvar)
        x_recon = self._decode(z)

        if return_z:
            return x_recon, mu, logvar, z
        else:
            return x_recon, mu, logvar

    def _encode(self, x):
        return self.encoder(x)

    def _decode(self, z):
        return self.decoder(z)

    def reconstruction_loss(self, x, x_recon, distribution):
        """Return the summed reconstruction loss divided by the batch size.

        Args:
            x: target images.
            x_recon: raw decoder output (treated as logits).
            distribution: ``'bernoulli'`` for binary cross-entropy on the
                logits, ``'gaussian'`` for squared error on
                ``sigmoid(x_recon)``.

        Returns:
            A scalar tensor, or ``None`` for any other ``distribution``.
        """
        batch_size = x.size(0)
        assert batch_size != 0

        # reduction='sum' is the current spelling of the deprecated
        # size_average=False.
        if distribution == 'bernoulli':
            recon_loss = F.binary_cross_entropy_with_logits(
                x_recon, x, reduction='sum').div(batch_size)
        elif distribution == 'gaussian':
            # torch.sigmoid replaces the deprecated F.sigmoid.
            x_recon = torch.sigmoid(x_recon)
            recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
        else:
            recon_loss = None

        return recon_loss

    def compute_cross_ent_normal(self, mu, logvar):
        """Element-wise cross-entropy of N(mu, exp(logvar)) against N(0, 1)."""
        return 0.5 * (mu**2 + torch.exp(logvar)) + np.log(np.sqrt(2 * np.pi))

    def compute_ent_normal(self, logvar):
        """Element-wise entropy of a normal with log-variance ``logvar``."""
        return 0.5 * (logvar + np.log(2 * np.pi * np.e))

'''
if __name__ == "__main__":
    args = EasyDict()
    args.filter_dim = 64
    args.latent_dim = 64
    args.im_size = 256

    model = LatentEBM(args).cuda()
    x = torch.zeros(1, 3, 256, 256).cuda()
    latent = torch.zeros(1, 64).cuda()
    model(x, latent)
  '''

'\nif __name__ == "__main__":\n    args = EasyDict()\n    args.filter_dim = 64\n    args.latent_dim = 64\n    args.im_size = 256\n\n    model = LatentEBM(args).cuda()\n    x = torch.zeros(1, 3, 256, 256).cuda()\n    latent = torch.zeros(1, 64).cuda()\n    model(x, latent)\n  '

In [1]:
!tar -xvzf /content/images_clevr.tar.gz

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
images_clevr/CLEVR_new_006654.png
images_clevr/CLEVR_new_007014.png
images_clevr/CLEVR_new_004101.png
images_clevr/CLEVR_new_008166.png
images_clevr/CLEVR_new_009726.png
images_clevr/CLEVR_new_005741.png
images_clevr/CLEVR_new_009081.png
images_clevr/CLEVR_new_007807.png
images_clevr/CLEVR_new_007475.png
images_clevr/CLEVR_new_006235.png
images_clevr/CLEVR_new_003919.png
images_clevr/CLEVR_new_006592.png
images_clevr/CLEVR_new_009347.png
images_clevr/CLEVR_new_005320.png
images_clevr/CLEVR_new_009892.png
images_clevr/CLEVR_new_005487.png
images_clevr/CLEVR_new_004560.png
images_clevr/CLEVR_new_004912.png
images_clevr/CLEVR_new_008975.png
images_clevr/CLEVR_new_008507.png
images_clevr/CLEVR_new_001599.png
images_clevr/CLEVR_new_005681.png
images_clevr/CLEVR_new_009141.png
images_clevr/CLEVR_new_005126.png
images_clevr/CLEVR_new_000678.png
images_clevr/CLEVR_new_001038.png
images_clevr/CLEVR_new_004766.png
images_clevr/CLEV

In [4]:
import os
import os.path as osp
import numpy as np
import json

import torchvision.transforms.functional as TF
import random

from PIL import Image
import torch.utils.data as data
import torch
import cv2
from torchvision import transforms
import glob

try:
    import multi_dsprites
    import tetrominoes
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
except:
    pass

from glob import glob

from imageio import imread
from skimage.transform import resize as imresize
from torchvision.datasets import FashionMNIST


class GaussianBlur(object):
    """Callable transform that Gaussian-blurs its input on a random subset of calls."""

    def __init__(self, min=0.1, max=2.0, kernel_size=9):
        """Store the sigma range [min, max) and the square kernel size."""
        self.min, self.max = min, max
        self.kernel_size = kernel_size

    def __call__(self, sample):
        """Return ``sample`` as a numpy array, blurred when the first random draw is below 0.5."""
        image = np.array(sample)

        # First draw decides whether to blur at all.
        if np.random.random_sample() >= 0.5:
            return image

        # Second draw picks sigma uniformly between self.min and self.max.
        sigma_span = self.max - self.min
        sigma = sigma_span * np.random.random_sample() + self.min
        kernel = (self.kernel_size, self.kernel_size)
        return cv2.GaussianBlur(image, kernel, sigma)


class IntPhysDataset(data.Dataset):
    """Dataset of frame pairs drawn from per-video image folders.

    Every sub-directory of the hard-coded training root is treated as one
    video whose frames are the files inside its ``imgs`` folder. An item is a
    randomly chosen frame of the selected video, optionally stacked with a
    second frame taken at most ``_MAX_FRAME_GAP`` frames later.
    """

    # Largest allowed distance (in frames) between the two sampled frames.
    _MAX_FRAME_GAP = 19

    def __init__(self, args):
        """Index the frames of every video under the training root.

        Parameters:
            args -- experiment flags; ``im_size`` and ``temporal`` are read
                    here, ``single`` is read on every ``__getitem__`` call
        """
        train_path = "/private/home/yilundu/dataset/intphys"
        # train_path = "/data/vision/billf/scratch/jerrymei/newIntPhys/render/output/train_v7"

        files = []
        for d in os.listdir(train_path):
            base_path = osp.join(train_path, d, 'imgs')
            # Frames are kept in sorted (lexicographic) file-name order.
            files.append([osp.join(base_path, im) for im in sorted(os.listdir(base_path))])

        self.args = args
        self.A_paths = files
        # No depth maps are collected; the attribute stays an empty list.
        self.D_paths = []
        self.frames = 2
        self.im_size = args.im_size
        self.temporal = args.temporal

    @staticmethod
    def _sample_frame_pair(num_frames, max_gap=19):
        """Draw ``(ix, ix_next)`` with ``0 <= ix <= ix_next < num_frames`` and a gap of at most ``max_gap``.

        For videos with more than ``max_gap`` frames this draws exactly like
        ``randint(0, num_frames - 20)`` followed by ``ix + randint(0, 19)``.
        Shorter videos are clamped to their last frame instead of making
        ``random.randint`` fail on an empty range.

        Raises:
            ValueError -- if ``num_frames`` is smaller than 1
        """
        if num_frames < 1:
            raise ValueError("cannot sample frames from an empty video")

        ix = random.randint(0, max(0, num_frames - (max_gap + 1)))
        ix_next = min(ix + random.randint(0, max_gap), num_frames - 1)
        return ix, ix_next

    @staticmethod
    def _load_frame(path):
        """Read one frame, keep its first three channels, resize to 64x64 and return it channels-first."""
        frame = imread(path)[:, :, :3]
        frame = imresize(frame, (64, 64))[:, :, :3]
        return torch.Tensor(frame).permute(2, 0, 1)

    def __getitem__(self, index):
        """Return ``(frames, video_index)`` for a randomly sampled position in one video.

        Parameters:
            index -- requested video index; wrapped modulo the number of
                     videos, and forced to 0 when ``args.single`` is set

        Returns a tuple of
            frames -- one channels-first frame, or the two sampled frames
                      stacked along a new leading axis when ``temporal`` is set
            index -- the video index that was actually used
        """
        if self.args.single:
            index = 0

        index = index % len(self.A_paths)
        A_path = self.A_paths[index]

        ix, ix_next = self._sample_frame_pair(len(A_path), self._MAX_FRAME_GAP)

        im = self._load_frame(A_path[ix])

        if self.temporal:
            # The second frame is only decoded when it is actually returned.
            im = torch.stack([im, self._load_frame(A_path[ix_next])], dim=0)

        return im, index

    def __len__(self):
        """Return a fixed nominal length; indices are wrapped in ``__getitem__``."""
        return 1000000


class ToyDataset(data.Dataset):
    """Synthetic dataset of 3x64x64 images holding one filled square per channel."""

    def __init__(self, opt):
        """Store the experiment flags.

        Parameters:
            opt -- experiment flags; only stored, never read by this class
        """
        self.opt = opt
        self.components = 3

        # self.nsample = 10000
        self.samples = []

    def __getitem__(self, index):
        """Generate a fresh random image and return it together with ``index``.

        Parameters:
            index -- passed through unchanged; it does not influence the image

        Returns a tuple of
            image -- tensor of shape (3, 64, 64); channel ``i`` contains the
                     ``i``-th square filled with 0.8 on a zero background
            index -- the index that was requested
        """
        half = 5
        low, high = half, 64 - half
        canvas = np.zeros((3, 64, 64))
        centers = []

        for channel in range(self.components):
            placed = False
            while not placed:
                cx = random.randint(low, high)
                cy = random.randint(low, high)

                # Reject a centre that is closer than one full square width
                # to an earlier centre along both axes.
                clashes = any(
                    abs(cx - px) < 2 * half and abs(cy - py) < 2 * half
                    for px, py in centers
                )
                if clashes:
                    continue

                canvas[channel, cx - half:cx + half, cy - half:cy + half] = 0.8
                centers.append((cx, cy))
                placed = True

        return torch.Tensor(canvas), index

    def __len__(self):
        """Return a fixed nominal length; every item is generated on demand."""
        return 1000000


# class Blender(data.Dataset):
#     def __init__(self, stage=0):
#         self.path = "/private/home/yilundu/sandbox/data/blender/continual/test_rotation"
#
#     def __len__(self):
#         return 10000
#
#     def __getitem__(self, index):
#         im = imread(osp.join(self.path, "r_{}.png".format(index)))
#         im = imresize(im, (64, 64))[:, :, :3]
#         im = im / 255.
#
#         im = torch.Tensor(im).permute(2, 0, 1)
#
#         return im, index

# class Blender(data.Dataset):
#     def __init__(self, stage=0):
#         self.path = "/private/home/yilundu/sandbox/data/CLEVR_v1.0/images/train"
#         self.images = glob.glob(self.path + "/*.png")
#
#     def __len__(self):
#         return len(self.images)
#
#     def __getitem__(self, index):
#         im_path = self.images[index]
#         im = imread(im_path)
#         im = imresize(im, (64, 64))[:, :, :3]
#         im = im / 255.
#
#         im = torch.Tensor(im).permute(2, 0, 1)
#
#         return im, index

class Blender(data.Dataset):
    """Image dataset over the PNG files of a hard-coded training directory."""

    def __init__(self, stage=0):
        """Collect the image paths.

        Parameters:
            stage -- accepted for signature compatibility; not used
        """
        self.path = "/private/home/yilundu/dataset/shop_vrb/images/train"
        # This file does `from glob import glob` after `import glob`, so the
        # name `glob` is the function; `glob.glob(...)` raises AttributeError.
        self.images = glob(self.path + "/*.png")

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.images)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(image, index)``.

        The image is resized to 64x64, cut to its first three channels,
        divided by 255 and returned channels-first.
        """
        im_path = self.images[index]
        im = imread(im_path)
        im = imresize(im, (64, 64))[:, :, :3]
        # NOTE(review): if `imresize` already returns floats in [0, 1] this
        # division scales the image a second time -- confirm the intended range.
        im = im / 255.

        im = torch.Tensor(im).permute(2, 0, 1)

        return im, index

class FashionMnistDataset(data.Dataset):
    """
    A PyTorch Dataset class for FashionMNIST, similar to the other datasets
    in this file. It leverages the torchvision FashionMNIST dataset.
    """

    def __init__(self, args, root, train=True, transform=None, download=True, im_size=64):
        """
        Initializes the FashionMNIST dataset.

        Args:
            args: Experiment flags. Accepted for signature compatibility with
                  the other datasets; not used.
            root (string): Root directory where the FashionMNIST data is stored.
            train (bool, optional): If True, creates dataset from training set, otherwise
                                     creates from test set. Defaults to True.
            transform (callable, optional): A function/transform that takes in a PIL image
                                            and returns a transformed version. When omitted,
                                            images are resized to ``im_size`` x ``im_size``
                                            and converted to tensors.
            download (bool, optional): If true, downloads the dataset from the internet and
                                      puts it in root directory.  If dataset is already downloaded,
                                      it is not downloaded again.
            im_size (int, optional):  Side length used by the default transform. Defaults to 64.
        """
        self.root = root
        self.im_size = im_size

        if transform is None:
            # Default pipeline: honour `im_size`, then convert to a tensor.
            transform = transforms.Compose([
                transforms.Resize((im_size, im_size)),
                transforms.ToTensor(),
            ])
        self.transform = transform

        # The wrapped dataset gets no transform of its own: the transform is
        # applied exactly once, in __getitem__, to the raw sample.
        self.fashion_mnist = FashionMNIST(root=root, train=train, download=download, transform=None)

    def __len__(self):
        """Returns the total number of images in the dataset."""
        return len(self.fashion_mnist)

    def __getitem__(self, index):
        """
        Returns a data point and its label.

        Args:
            index (int): Index of the data point.

        Returns:
            tuple: (image, label) where:
                image: The sample after ``self.transform`` has been applied once.
                label (int):  Corresponding class label.
        """
        image, label = self.fashion_mnist[index]
        return self.transform(image), label

class Cub(data.Dataset):
    """Image dataset over the JPEG files matched by a hard-coded CUB glob pattern."""

    def __init__(self, stage=0):
        """Collect the image paths.

        Parameters:
            stage -- accepted for signature compatibility; not used
        """
        self.path = "/private/home/yilundu/sandbox/data/CUB/images/*/*.jpg"
        # This file does `from glob import glob` after `import glob`, so the
        # name `glob` is the function; `glob.glob(...)` raises AttributeError.
        self.images = glob(self.path)

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.images)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(image, index)``.

        The image is resized to 64x64, cut to its first three channels,
        divided by 255 and returned channels-first.
        """
        im_path = self.images[index]
        im = imread(im_path)
        im = imresize(im, (64, 64))[:, :, :3]
        # NOTE(review): if `imresize` already returns floats in [0, 1] this
        # division scales the image a second time -- confirm the intended range.
        im = im / 255.

        im = torch.Tensor(im).permute(2, 0, 1)

        return im, index


class Nvidia(data.Dataset):
    """Image dataset addressed by row number of a label array loaded from disk."""

    def __init__(self, stage=0, filter_light=False):
        """Load the label array and build the list of usable row indices.

        Parameters:
            stage -- accepted for signature compatibility; not used
            filter_light -- stored on the instance; the filtering it refers to
                            is currently disabled, so every row is kept
        """
        self.path = "/data/vision/billf/scratch/yilundu/dataset/disentanglement/Falcor3D_down128/images/{:06}.png"
        self.labels = np.load("/data/vision/billf/scratch/yilundu/dataset/disentanglement/Falcor3D_down128/train-rec.labels")
        self.filter_light = filter_light

        # Mask of rows whose first label column lies strictly inside (0, 1).
        # It is computed but not applied while the filter below stays disabled.
        first_column = self.labels[:, 0]
        interior_mask = (first_column > 0) & (first_column < 1)

        # if self.filter_light:
        #     self.idxs = idxs[interior_mask]
        # else:
        self.idxs = np.arange(self.labels.shape[0])

    def __len__(self):
        """Return the number of usable rows."""
        return self.idxs.shape[0]

    def __getitem__(self, index):
        """Load the image for position ``index`` and return ``(image, row_index)``.

        The image is resized to 128x128, cut to its first three channels, has
        its channel order reversed, and is returned channels-first.
        """
        row = self.idxs[index]
        frame = imresize(imread(self.path.format(row)), (128, 128))
        # Reverse the last axis; `.copy()` makes the reversed view contiguous.
        frame = frame[:, :, :3][:, :, ::-1].copy()

        return torch.Tensor(frame).permute(2, 0, 1), row


class NvidiaDisentangle(data.Dataset):
    """Image dataset that pairs each image with the remaining columns of its label row."""

    def __init__(self, stage=0, filter_light=True):
        """Load the label array and index every row.

        Parameters:
            stage -- accepted for signature compatibility; not used
            filter_light -- accepted for signature compatibility; not used
        """
        self.path = "/data/vision/billf/scratch/yilundu/dataset/disentanglement/Falcor3D_down128/images/{:06}.png"
        self.labels = np.load("/data/vision/billf/scratch/yilundu/dataset/disentanglement/Falcor3D_down128/train-rec.labels")
        self.idxs = np.arange(self.labels.shape[0])

    def __len__(self):
        """Return the number of label rows."""
        return self.idxs.shape[0]

    def __getitem__(self, index):
        """Load the image for position ``index`` and return ``(image, label)``.

        The image is resized to 128x128, cut to its first three channels and
        returned channels-first. ``label`` is the label row without its first
        column.
        """
        row = self.idxs[index]
        frame = imresize(imread(self.path.format(row)), (128, 128))[:, :, :3]

        # label = int(self.labels[index, 0] * 5)
        target = self.labels[row, 1:]

        return torch.Tensor(frame).permute(2, 0, 1), target

class Clevr(data.Dataset):
    """Image dataset over the PNG files matched by a hard-coded glob pattern."""

    def __init__(self, stage=0):
        """Collect the image paths in sorted order.

        Parameters:
            stage -- accepted for signature compatibility; not used
        """
        # Previous location:
        # "/data/vision/billf/scratch/yilundu/dataset/clevr/images_clevr/*.png"
        self.path = "/content/images_clevr/*.png"
        matches = glob(self.path)
        self.images = sorted(matches)

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.images)

    def __getitem__(self, index):
        """Load image ``index``, resize it to 64x64 and return ``(image, index)``.

        Only the first three channels are kept; the tensor is channels-first.
        """
        pixels = imresize(imread(self.images[index]), (64, 64))
        tensor = torch.Tensor(pixels[:, :, :3])

        return tensor.permute(2, 0, 1), index


class ClevrLighting(data.Dataset):
    """Image dataset over the PNG files matched by a hard-coded glob pattern."""

    def __init__(self, stage=0):
        """Collect the image paths in sorted order.

        Parameters:
            stage -- accepted for signature compatibility; not used
        """
        self.path = "/data/vision/billf/scratch/yilundu/dataset/clevr_lighting/images_large_lighting/*.png"
        matches = glob(self.path)
        self.images = sorted(matches)

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.images)

    def __getitem__(self, index):
        """Load image ``index``, resize it to 64x64 and return ``(image, index)``.

        Only the first three channels are kept; the tensor is channels-first.
        """
        pixels = imresize(imread(self.images[index]), (64, 64))
        tensor = torch.Tensor(pixels[:, :, :3])

        return tensor.permute(2, 0, 1), index


class Exercise(data.Dataset):
    """Dataset of image pairs whose file names differ by ``im1`` / ``im2``."""

    def __init__(self, args):
        """Collect the paths of all first images.

        Parameters:
            args -- experiment flags; only ``temporal`` is read
        """
        self.temporal = args.temporal
        self.path = "/private/home/yilundu/sandbox/data/release_data_set/images/*_im1.png"
        self.images = glob(self.path)

    def __len__(self):
        """Return the number of first images that were found."""
        return len(self.images)

    def _load_pair(self, index):
        """Load pair ``index`` and return ``(image, index)``; raises if either file cannot be used."""
        im_path = self.images[index]
        im_path_next = im_path.replace("im1", "im2")
        im = imread(im_path)
        # The partner is always read, so a pair with a missing or broken
        # second image is skipped even when only the first image is returned.
        im_next = imread(im_path_next)
        im = imresize(im, (64, 64))[:, :, :3]
        im_next = imresize(im_next, (64, 64))[:, :, :3]

        im = torch.Tensor(im).permute(2, 0, 1)
        im_next = torch.Tensor(im_next).permute(2, 0, 1)

        if self.temporal:
            im = torch.stack([im, im_next], dim=0)

        return im, index

    def __getitem__(self, index):
        """Return the pair at ``index``, moving on to the following pairs while loading fails.

        Each pair is tried at most once per call. If none can be loaded the
        last error is re-raised, instead of recursing without bound.
        ``KeyboardInterrupt`` and ``SystemExit`` are never swallowed.

        Returns a tuple of
            image -- channels-first 64x64 image, or both images of the pair
                     stacked along a new leading axis when ``temporal`` is set
            index -- index of the pair that was actually loaded
        """
        total = len(self.images)
        last_error = None

        for _ in range(max(total, 1)):
            try:
                return self._load_pair(index)
            except Exception as err:
                last_error = err
                if total == 0:
                    break
                index = (index + 1) % total

        raise last_error


class CelebaHQ(data.Dataset):
    """Image dataset over the JPEG files of a hard-coded CelebA-HQ directory."""

    def __init__(self, resolution=64):
        """Collect the image paths.

        Parameters:
            resolution -- side length every image is resized to
        """
        self.resolution = resolution
        self.name = 'celebahq'
        self.channels = 3
        self.paths = glob("/data/vision/billf/scratch/yilundu/dataset/celebahq/data128x128/*.jpg")

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load image ``index``, resize it to the configured resolution and return ``(image, index)``.

        Only the first three channels are kept; the tensor is channels-first.
        """
        side = self.resolution
        pixels = imresize(imread(self.paths[index]), (side, side))
        tensor = torch.Tensor(pixels[:, :, :3])

        return tensor.permute(2, 0, 1), index


class Airplane(data.Dataset):
    """Image dataset over the PNG files of a hard-coded render directory."""

    def __init__(self, stage=0):
        """Collect the image paths.

        Parameters:
            stage -- accepted for signature compatibility; not used
        """
        # NOTE(review): `name` is 'celebahq' although the files come from a
        # different directory -- looks copied; confirm before relying on it.
        self.name = 'celebahq'
        self.channels = 3
        self.paths = glob("/data/vision/billf/scratch/yilundu/nerf-pytorch/large_render/*.png")

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load image ``index``, resize it to 64x64 and return ``(image, index)``.

        Only the first three channels are kept; the tensor is channels-first.
        """
        pixels = imresize(imread(self.paths[index]), (64, 64))
        tensor = torch.Tensor(pixels[:, :, :3])

        return tensor.permute(2, 0, 1), index

class Anime(data.Dataset):
    """Image dataset over the first 30000 (sorted) JPEG files of a hard-coded directory."""

    # Number of random replacement indices tried before the error is raised.
    _MAX_RETRIES = 100

    def __init__(self, stage=0):
        """Collect the image paths.

        Parameters:
            stage -- accepted for signature compatibility; not used
        """
        # NOTE(review): `name` is 'celebahq' although the files come from an
        # anime directory -- looks copied; confirm before relying on it.
        self.name = 'celebahq'
        self.channels = 3
        self.paths = sorted(glob("/data/vision/billf/scratch/yilundu/dataset/anime/cropped/*.jpg"))[:30000]

    def __len__(self):
        """Return the number of image paths kept."""
        return len(self.paths)

    def _load(self, index):
        """Load image ``index`` and return ``(image, index)``; raises if the file cannot be used."""
        frame = imread(self.paths[index])
        frame = cv2.resize(frame, (64, 64), interpolation=cv2.INTER_AREA)

        im = torch.Tensor(frame).permute(2, 0, 1) / 255.

        return im, index

    def __getitem__(self, index):
        """Return the image at ``index``, falling back to random other images while loading fails.

        At most ``_MAX_RETRIES`` replacements are drawn; after that the error
        of the last attempt propagates, instead of recursing without bound.
        ``KeyboardInterrupt`` and ``SystemExit`` are never swallowed.

        Returns a tuple of
            image -- channels-first 64x64 image divided by 255
            index -- index of the image that was actually loaded
        """
        for _ in range(self._MAX_RETRIES):
            try:
                return self._load(index)
            except Exception:
                if not self.paths:
                    # Nothing to fall back to.
                    raise
                index = random.randint(0, len(self.paths) - 1)

        # Final attempt: let its error reach the caller.
        return self._load(index)


class Faces(data.Dataset):
    """Shuffled image dataset mixing two hard-coded JPEG directories."""

    # Number of random replacement indices tried before the error is raised.
    _MAX_RETRIES = 100

    def __init__(self, stage=0):
        """Collect and shuffle the image paths of both directories.

        Parameters:
            stage -- accepted for signature compatibility; not used
        """
        self.name = 'celebahq'
        self.channels = 3
        paths = sorted(glob("/data/vision/billf/scratch/yilundu/dataset/celebahq/data128x128/*.jpg"))
        # Only the first 30000 (sorted) files of the second directory are used.
        anime_paths = sorted(glob("/data/vision/billf/scratch/yilundu/dataset/anime/cropped/*.jpg"))[:30000]
        paths = paths + anime_paths
        random.shuffle(paths)
        self.paths = paths

    def __len__(self):
        """Return the number of image paths kept."""
        return len(self.paths)

    def _load(self, index):
        """Load image ``index`` and return ``(image, index)``; raises if the file cannot be used."""
        frame = imread(self.paths[index])
        frame = cv2.resize(frame, (64, 64), interpolation=cv2.INTER_AREA)

        im = torch.Tensor(frame).permute(2, 0, 1) / 255.

        return im, index

    def __getitem__(self, index):
        """Return the image at ``index``, falling back to random other images while loading fails.

        At most ``_MAX_RETRIES`` replacements are drawn; after that the error
        of the last attempt propagates, instead of recursing without bound.
        ``KeyboardInterrupt`` and ``SystemExit`` are never swallowed.

        Returns a tuple of
            image -- channels-first 64x64 image divided by 255
            index -- index of the image that was actually loaded
        """
        for _ in range(self._MAX_RETRIES):
            try:
                return self._load(index)
            except Exception:
                if not self.paths:
                    # Nothing to fall back to.
                    raise
                index = random.randint(0, len(self.paths) - 1)

        # Final attempt: let its error reach the caller.
        return self._load(index)


class DSprites(data.Dataset):
    """Dataset that returns randomly chosen dSprites images replicated over three channels."""

    def __init__(self, opt):
        """Load the image array from the dSprites archive in the working directory.

        Parameters:
            opt -- experiment flags; ``components`` is read and stored
        """
        self.opt = opt
        self.components = opt.components
        archive = np.load("dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")
        self.data = archive['imgs']
        self.n = self.data.shape[0]

    def __getitem__(self, index):
        """Return ``(image, index)`` for a uniformly random sprite.

        Parameters:
            index -- passed through unchanged; it does not select the sprite

        Returns a tuple of
            image -- tensor of shape (3, 64, 64) holding the same sprite in
                     every channel
            index -- the index that was requested
        """
        sprite = self.data[random.randint(0, self.n - 1)]

        # Broadcasting copies the single sprite into all three channels.
        stacked = np.zeros((3, 64, 64))
        stacked[:] = sprite

        return torch.Tensor(stacked), index

    def __len__(self):
        """Return a fixed nominal length; sprites are drawn at random."""
        return 100000

class MultiDspritesLoader():
    """Endless iterator over batches read from a multi-dSprites TFRecord file.

    Relies on the ``multi_dsprites`` and ``tf`` modules, which are imported
    at the top of this file inside a ``try`` block that ignores failures.
    """

    def __init__(self, batchsize):
        """Build the TF input pipeline and open an interactive session.

        Parameters:
            batchsize -- number of examples per batch
        """
        tf_records_path = 'dataset/multi_dsprites_colored_on_colored.tfrecords'

        dataset = multi_dsprites.dataset(tf_records_path, 'colored_on_colored')
        batched_dataset = dataset.batch(batchsize)  # optional batching
        iterator = batched_dataset.make_one_shot_iterator()
        self.data = iterator.get_next()
        self.sess = tf.InteractiveSession()

    def __iter__(self):
        """Return the loader itself; batches come from ``__next__``."""
        return self

    def __next__(self):
        """Fetch the next batch and return ``(images, dummy_label)``.

        ``images`` is the batch's ``'image'`` entry moved from channels-last
        to channels-first and divided by 255; the label is ``torch.ones(1)``.
        """
        d = self.sess.run(self.data)
        img = d['image']
        img = img.transpose((0, 3, 1, 2))
        img = img / 255.
        img = torch.Tensor(img)

        return img, torch.ones(1)

    def __len__(self):
        """Return a nominal length of one million.

        ``__len__`` must return an ``int``; the former float ``1e6`` made
        ``len(loader)`` raise ``TypeError``.
        """
        return 1000000


class TetrominoesLoader():
    """Endless iterator over batches read from a Tetrominoes TFRecord file.

    Relies on the ``tetrominoes`` and ``tf`` modules, which are imported at
    the top of this file inside a ``try`` block that ignores failures.
    """

    def __init__(self, batchsize):
        """Build the TF input pipeline and open an interactive session.

        Parameters:
            batchsize -- number of examples per batch
        """
        # tf_records_path = '/home/yilundu/my_repos/dataset/tetrominoes_train.tfrecords'
        tf_records_path = '/home/gridsan/yilundu/my_files/ebm_video/dataset/tetrominoes_train.tfrecords'

        dataset = tetrominoes.dataset(tf_records_path)
        batched_dataset = dataset.batch(batchsize)  # optional batching
        iterator = batched_dataset.make_one_shot_iterator()
        self.data = iterator.get_next()
        # The session is configured with a GPU device count of zero.
        config = tf.ConfigProto(
                device_count = {'GPU': 0}
            )
        self.sess = tf.InteractiveSession(config=config)

    def __iter__(self):
        """Return the loader itself; batches come from ``__next__``."""
        return self

    def __next__(self):
        """Fetch the next batch and return ``(images, dummy_label)``.

        ``images`` is the batch's ``'image'`` entry moved from channels-last
        to channels-first, divided by 255 and made contiguous; the label is
        ``torch.ones(1)``.
        """
        d = self.sess.run(self.data)
        img = d['image']
        img = img.transpose((0, 3, 1, 2))
        img = img / 255.
        img = torch.Tensor(img).contiguous()

        return img, torch.ones(1)

    def __len__(self):
        """Return a nominal length of one million.

        ``__len__`` must return an ``int``; the former float ``1e6`` made
        ``len(loader)`` raise ``TypeError``.
        """
        return 1000000

class TFImagenetLoader(data.Dataset):
    """Iterator that pulls pre-batched ImageNet data out of a TensorFlow session.

    Although it subclasses ``data.Dataset``, batches are obtained through
    ``__next__`` / ``__iter__``; no ``__getitem__`` is defined here.
    """

    def __init__(self, split, batchsize, idx, num_workers, return_label=False):
        """Build the TF input pipeline over the record files and start a session.

        Parameters:
            split -- "train" selects the training records, anything else the
                     validation records
            batchsize -- number of examples per batch
            idx -- not used in this constructor
            num_workers -- not used in this constructor
            return_label -- whether ``__next__`` returns labels with the images
        """
        IMAGENET_NUM_TRAIN_IMAGES = 1281167
        IMAGENET_NUM_VAL_IMAGES = 50000
        self.return_label = return_label

        if split == "train":
            im_length = IMAGENET_NUM_TRAIN_IMAGES
        else:
            im_length = IMAGENET_NUM_VAL_IMAGES

        # Counts the number of batches handed out by __next__.
        self.curr_sample = 0

        # index.json maps record file names to counts; only the names are used.
        index_path = osp.join('/data/vision/billf/scratch/yilundu/imagenet', 'index.json')
        with open(index_path) as f:
            metadata = json.load(f)
            counts = metadata['record_counts']

        if split == 'train':
            files = list(sorted([x for x in counts.keys() if x.startswith('train')]))
        else:
            files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))

        files = [osp.join('/data/vision/billf/scratch/yilundu/imagenet', x) for x in files]
        # NOTE(review): ImagenetPreprocessor is not defined in the visible part
        # of this file -- confirm where it comes from.
        preprocess_function = ImagenetPreprocessor(224, dtype=tf.float32, train=False).parse_and_preprocess

        ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
        ds = ds.apply(tf.data.TFRecordDataset)
        ds = ds.take(im_length)
        # ds = ds.prefetch(buffer_size=4)
        ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
        # NOTE(review): `batching` is not defined in the visible part of this
        # file either -- confirm the import.
        ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=batchsize, num_parallel_batches=4))
        ds = ds.prefetch(buffer_size=2)

        ds_iterator = ds.make_initializable_iterator()
        labels, images = ds_iterator.get_next()
        # Scale by 1/256, add uniform noise below 1/256, then clip to [0, 1].
        self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
        self.labels = labels

        # The session is configured with a GPU device count of zero.
        config = tf.ConfigProto(device_count = {'GPU': 0})
        sess = tf.Session(config=config)
        sess.run(ds_iterator.initializer)

        # self.im_length = im_length // batchsize
        # Reported length is the number of images, not the number of batches.
        self.im_length = im_length

        self.sess = sess

    def __next__(self):
        """Run the session once and return the next batch.

        Returns ``(images, labels)`` when ``return_label`` is set, otherwise
        the images with an extra axis inserted at position 1.
        """
        self.curr_sample += 1

        sess = self.sess

        label, im = sess.run([self.labels, self.images])
        # Labels are shifted down by one after squeezing.
        label = label.squeeze() - 1
        # Move the last axis (channels) in front of the two spatial axes.
        im = torch.from_numpy(im).permute((0, 3, 1, 2))
        label = torch.LongTensor(label)

        if self.return_label:
            return im, label
        else:
            return im[:, None, :]

    def __iter__(self):
        """Return the loader itself; batches come from ``__next__``."""
        return self

    def __len__(self):
        """Return the image count of the selected split."""
        return self.im_length


class TFTaskAdaptation(data.Dataset):
    """Iterator that pulls batches of a task-adaptation dataset out of a TensorFlow session.

    Although it subclasses ``data.Dataset``, batches are obtained through
    ``__next__`` / ``__iter__``; no ``__getitem__`` is defined here.
    """

    def __init__(self, split, batchsize):
        """Build the data pipeline for ``split`` and start a session.

        Parameters:
            split -- split name; "test" switches the pipeline to evaluation
                     data and "train" changes the batch layout in ``__next__``
            batchsize -- batch size used for both training and evaluation
        """
        data_params = {
            # "dataset": "data." + "clevr(task='count_all')",
            "dataset": "data." + "svhn()",
            "dataset_train_split_name": "trainval",
            "dataset_eval_split_name": "test",
            "shuffle_buffer_size": 10000,
            "prefetch": True,
            "train_examples": None,
            "batch_size": batchsize,
            "batch_size_eval": batchsize,
            "data_for_eval": split == "test",
            "data_dir": "/private/home/yilundu/tensorflow_datasets",
            "input_range": [0.0, 1.0]
        }
        # NOTE(review): build_data_pipeline is not defined in the visible part
        # of this file; it appears to return a callable taking a params dict.
        ds = build_data_pipeline(data_params, split)
        ds = ds({'batch_size': batchsize})
        # ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
        # ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
        # ds = ds.prefetch(buffer_size=2)

        ds_iterator = tf.data.make_initializable_iterator(ds)
        outputs = ds_iterator.get_next()
        image, label = outputs['image'], outputs['label']
        self.images = image
        self.labels = label
        self.split = split

        # The session is configured with a GPU device count of zero.
        config = tf.ConfigProto(device_count = {'GPU': 0})
        sess = tf.Session(config=config)
        sess.run(ds_iterator.initializer)

        # Nominal length; __next__ ends an epoch on its 1000th call.
        self.im_length = 1000
        self.curr_sample = 0

        self.sess = sess

    def __next__(self):
        """Run the session once and return ``[images, labels]`` as tensors.

        Raises:
            StopIteration -- on every 1000th call, after resetting the counter
        """
        self.curr_sample += 1

        sess = self.sess
        label, im = sess.run([self.labels, self.images])

        if self.split == "train":
            # Take entry 0 along axis 1, then move the last axis to position 2.
            im = im[:, 0].transpose((0, 1, 4, 2, 3))
        else:
            # Move the last axis (channels) in front of the two spatial axes.
            im = im.transpose((0, 3, 1, 2))

        # The batch fetched above is dropped when the epoch ends here.
        if self.curr_sample == 1000:
            self.curr_sample = 0
            raise StopIteration

        return [torch.Tensor(im[:]), torch.Tensor(label).long()]

    def __iter__(self):
        """Return the loader itself; batches come from ``__next__``."""
        return self

    def __len__(self):
        """Return the nominal number of batches per epoch."""
        return self.im_length

class CubesColor(data.Dataset):
    """Dataset of images and labels loaded from a hard-coded ``.npz`` archive."""

    def __init__(self, opt, return_label=False, train=True):
        """Load the archive and keep either the first 90% or the last 10% of it.

        Parameters:
            opt -- experiment flags; ``im_size`` is read in ``__getitem__``
            return_label -- whether ``__getitem__`` also returns the label
            train -- True keeps the first 90% of the samples, False the rest
        """
        self.opt = opt
        self.return_label = return_label
        self.data = np.load("/private/home/yilundu/dataset/cubes_varied_position_812.npz")

        images = np.array(self.data['ims'])
        targets = np.array(self.data['labels'])

        cut = int(0.9 * images.shape[0])
        chosen = slice(None, cut) if train else slice(cut, None)

        self.ims = images[chosen]
        self.labels = targets[chosen]

    def __getitem__(self, index):
        """Return the resized, noise-dequantised image at ``index``.

        Parameters:
            index -- position of the sample inside the selected split

        Returns ``(image, label)`` when ``return_label`` is set, otherwise
        just the image. The image is channels-first with an extra leading
        axis of size one; the label is a float tensor.
        """
        side = self.opt.im_size

        raw = np.array(self.ims[index])
        label = torch.FloatTensor(np.array(self.labels[index]))

        resized = imresize(raw, (side, side))
        shape = resized.shape

        # Divide by 256 and add uniform noise scaled by 1/256, drawn in
        # channels-first layout.
        scaled = resized.transpose((2, 0, 1)) / 256
        noise = np.random.uniform(0, 1, (shape[2], shape[0], shape[1])) / 256
        image = torch.FloatTensor((scaled + noise)[None, :])

        if not self.return_label:
            return image
        return image, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.ims.shape[0]


class CubesColorPair(data.Dataset):
    """Dataset of images and labels loaded from a hard-coded ``.npz`` archive."""

    def __init__(self, opt, return_label=False, train=True):
        """Load the archive and keep either the first 90% or the last 10% of it.

        Parameters:
            opt -- experiment flags; ``im_size`` is read in ``__getitem__``
            return_label -- whether ``__getitem__`` also returns the label
            train -- True keeps the first 90% of the samples, False the rest
        """
        self.opt = opt
        self.return_label = return_label
        self.data = np.load("/private/home/yilundu/dataset/cubes_varied_multi_311.npz")

        images = np.array(self.data['ims'])
        targets = np.array(self.data['labels'])

        cut = int(0.9 * images.shape[0])
        chosen = slice(None, cut) if train else slice(cut, None)

        self.ims = images[chosen]
        self.labels = targets[chosen]

    def __getitem__(self, index):
        """Return the resized, noise-dequantised image at ``index``.

        Parameters:
            index -- position of the sample inside the selected split

        Returns ``(image, label)`` when ``return_label`` is set, otherwise
        just the image. The image is channels-first with an extra leading
        axis of size one; the label is a numpy array.
        """
        side = self.opt.im_size

        raw = self.ims[index]
        label = np.array(self.labels[index])

        resized = imresize(raw, (side, side))
        shape = resized.shape

        # Divide by 256 and add uniform noise scaled by 1/256, drawn in
        # channels-first layout.
        scaled = resized.transpose((2, 0, 1)) / 256
        noise = np.random.uniform(0, 1, (shape[2], shape[0], shape[1])) / 256
        image = torch.Tensor((scaled + noise)[None, :])

        if not self.return_label:
            return image
        return image, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.ims.shape[0]


class Kitti(data.Dataset):
    """Image dataset mixing real KITTI frames (listed three times) with Virtual KITTI frames."""

    def __init__(self, opt, return_label=False, train=True):
        """Collect the image paths of both sources.

        Parameters:
            opt -- experiment flags; only stored
            return_label -- accepted for signature compatibility; not used
            train -- accepted for signature compatibility; not used
        """
        real_frames = sorted(glob("/data/vision/billf/scratch/yilundu/dataset/kitti/training/image_02/*/*.png"))
        virtual_frames = sorted(glob("/data/vision/billf/scratch/yilundu/dataset/kitti/virtual_kitti/*/*/frames/rgb/Camera_0/*.jpg"))

        # Every real frame appears three times in the final list.
        self.ims = real_frames * 3 + virtual_frames
        self.opt = opt

    def __getitem__(self, index):
        """Load frame ``index`` and return ``(image, index)``.

        Columns 433 to 807 are cropped out of the frame, which is then
        resized to 64x64, cut to three channels and returned channels-first.
        """
        frame = imread(self.ims[index])
        # 433 - 808
        crop = frame[:, 433:808, :]
        small = imresize(crop, (64, 64))[:, :, :3]

        return torch.Tensor(small).permute(2, 0, 1), index

    def __len__(self):
        """Return the number of entries in the combined path list."""
        return len(self.ims)


class VirtualKitti(data.Dataset):
    """Image dataset over the Virtual KITTI ``Camera_0`` RGB frames."""

    def __init__(self,return_label=False, train=True):
        """Collect the Virtual KITTI frame paths in sorted order.

        Only the virtual frames are listed; the real-KITTI directory is no
        longer scanned, since its result was never used.

        Parameters:
            return_label -- accepted for signature compatibility; not used
            train -- accepted for signature compatibility; not used
        """
        self.ims = sorted(glob("/data/vision/billf/scratch/yilundu/dataset/kitti/virtual_kitti/*/*/frames/rgb/Camera_0/*.jpg"))

    def __getitem__(self, index):
        """Load frame ``index`` and return ``(image, index)``.

        Columns 433 to 807 are cropped out of the frame, which is then
        resized to 64x64, cut to three channels and returned channels-first.
        """
        # 433 - 808
        im = self.ims[index]
        im = imread(im)
        im = im[:, 433:808, :]
        im = imresize(im, (64, 64))[:, :, :3]

        im = torch.Tensor(im).permute(2, 0, 1)

        return im, index

    def __len__(self):
        """Return the number of frames that were found."""
        return len(self.ims)


class KittiLabel(data.Dataset):
    """Virtual KITTI frames labelled by a condition name taken from their path."""

    # Number of random replacement indices tried before the error is raised.
    _MAX_RETRIES = 100

    def __init__(self,return_label=False, train=True):
        """Collect the Virtual KITTI frame paths in sorted order.

        Parameters:
            return_label -- accepted for signature compatibility; not used
            train -- accepted for signature compatibility; not used
        """
        virtual_ims = sorted(glob("/data/vision/billf/scratch/yilundu/dataset/kitti/virtual_kitti/*/*/frames/rgb/Camera_0/*.jpg"))
        # A frame's label id is the position of its condition in this list.
        self.labels = ['fog', 'morning', 'overcast', 'rain', 'sunset']
        self.ims = virtual_ims

    def _load(self, index):
        """Load frame ``index`` and return ``(image, label_id)``.

        Raises if the frame cannot be read or if the fifth-from-last
        component of its path is not one of ``self.labels``.
        """
        im_path = self.ims[index]
        label_id = self.labels.index(im_path.split("/")[-5])

        im = imread(im_path)

        # 433 - 808
        im = im[:, 433:808, :]
        im = imresize(im, (64, 64))[:, :, :3]

        im = torch.Tensor(im).permute(2, 0, 1)

        return im, label_id

    def __getitem__(self, index):
        """Return the frame at ``index``, falling back to random other frames while loading fails.

        Frames whose condition is not listed in ``self.labels`` count as
        failures and are skipped the same way. At most ``_MAX_RETRIES``
        replacements are drawn; after that the error of the last attempt
        propagates, instead of recursing without bound.
        ``KeyboardInterrupt`` and ``SystemExit`` are never swallowed.

        Returns a tuple of
            image -- channels-first 64x64 crop of the frame
            label_id -- index of the frame's condition in ``self.labels``
        """
        for _ in range(self._MAX_RETRIES):
            try:
                return self._load(index)
            except Exception:
                if not self.ims:
                    # Nothing to fall back to.
                    raise
                index = random.randint(0, len(self.ims) - 1)

        # Final attempt: let its error reach the caller.
        return self._load(index)

    def __len__(self):
        """Return the number of frames that were found."""
        return len(self.ims)


class RealKittiLabel(data.Dataset):
    """Real KITTI ``image_02`` frames, each returned with the constant label 0."""

    def __init__(self,return_label=False, train=True):
        """Gather the sorted list of frame paths.

        Parameters:
            return_label -- accepted but not used
            train -- accepted but not used
        """
        pattern = "/data/vision/billf/scratch/yilundu/dataset/kitti/training/image_02/*/*.png"
        self.ims = sorted(glob(pattern))

    def __getitem__(self, index):
        """Load the frame at ``index``.

        Parameters:
            index - - integer index into ``self.ims``

        Returns:
            (tensor, 0) - - the processed image with channels first, and the
            fixed label 0
        """
        raw = imread(self.ims[index])
        # Columns 433..807 of every row are kept.
        cropped = raw[:, 433:808, :]
        # Resize to 64x64, then drop any channel past the third.
        small = imresize(cropped, (64, 64))[:, :, :3]
        # Channel axis goes first.
        tensor = torch.Tensor(small).permute(2, 0, 1)
        return tensor, 0

    def __len__(self):
        """Number of frame paths held by the dataset."""
        return len(self.ims)


class RealKitti(data.Dataset):
    """Real KITTI ``image_02`` frames, each returned together with its index."""

    def __init__(self,return_label=False, train=True):
        """Gather the sorted list of frame paths.

        Parameters:
            return_label -- accepted but not used
            train -- accepted but not used
        """
        pattern = "/data/vision/billf/scratch/yilundu/dataset/kitti/training/image_02/*/*.png"
        self.ims = sorted(glob(pattern))

    def __getitem__(self, index):
        """Load the frame at ``index``.

        Parameters:
            index - - integer index into ``self.ims``

        Returns:
            (tensor, index) - - the processed image with channels first, and
            the index that was requested
        """
        raw = imread(self.ims[index])
        # Columns 433..807 of every row are kept.
        cropped = raw[:, 433:808, :]
        # Resize to 64x64, then drop any channel past the third.
        small = imresize(cropped, (64, 64))[:, :, :3]
        # Channel axis goes first.
        tensor = torch.Tensor(small).permute(2, 0, 1)
        return tensor, index

    def __len__(self):
        """Number of frame paths held by the dataset."""
        return len(self.ims)

if __name__ == "__main__":
    # Smoke test: construct the Kitti dataset, passing None as its only argument.
    loader = Kitti(None)
    # for data in loader:
    #     print("here")

In [5]:
import torch
#from models import LatentEBM, ToyEBM, BetaVAE_H, LatentEBM128
from tensorflow.python.platform import flags
import torch.nn.functional as F
import os
#from dataset import IntPhysDataset, ToyDataset, TFImagenetLoader, CubesColor, CubesColorPair, TFTaskAdaptation, DSprites, Blender, Cub, Nvidia, Clevr, Exercise, CelebaHQ, Kitti, Airplane, Faces, ClevrLighting
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from easydict import EasyDict
import os.path as osp
from torch.nn.utils import clip_grad_norm
import numpy as np
from imageio import imwrite
import cv2
import argparse
import pdb
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
import torch.backends.cudnn as cudnn
import random
from torchvision.utils import make_grid
#from dataset import MultiDspritesLoader, TetrominoesLoader
from imageio import get_writer


"""Parse input arguments"""
# Command-line flags for training / evaluating the EBM model.
parser = argparse.ArgumentParser(description='Train EBM model')


# Run mode
parser.add_argument('--train', action='store_true', help='whether or not to train')
parser.add_argument('--optimize_test', action='store_true', help='when not training, run test_optimize instead of test')
parser.add_argument('--cuda', action='store_true', help='whether to use cuda or not')
parser.add_argument('--single', action='store_true', help='test overfitting of the dataset')


parser.add_argument('--dataset', default='blender', type=str, help='Dataset to use (intphys or others or imagenet or cubes)')
parser.add_argument('--logdir', default='cachedir', type=str, help='location where log of experiments will be stored')
parser.add_argument('--exp', default='default', type=str, help='name of experiments')

# training
parser.add_argument('--resume_iter', default=0, type=int, help='iteration to resume training')
parser.add_argument('--batch_size', default=64, type=int, help='size of batch of input to use')
parser.add_argument('--num_epoch', default=10000, type=int, help='number of epochs of training to run')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate for training')
parser.add_argument('--log_interval', default=10, type=int, help='log outputs every so many batches')
parser.add_argument('--save_interval', default=1000, type=int, help='save outputs every so many batches')

# data
parser.add_argument('--data_workers', default=4, type=int, help='Number of different data workers to load data in parallel')
parser.add_argument('--ensembles', default=1, type=int, help='use an ensemble of models')
# Stored as FLAGS.vae_beta (argparse turns the dash into an underscore).
parser.add_argument('--vae-beta', type=float, default=0.)

# EBM specific settings

# Model specific settings
parser.add_argument('--filter_dim', default=64, type=int, help='number of filters to use')
parser.add_argument('--components', default=2, type=int, help='number of components to explain an image with')
parser.add_argument('--component_weight', action='store_true', help='optimize for weights of the components also')
parser.add_argument('--tie_weight', action='store_true', help='tie the weights between separate models')
parser.add_argument('--optimize_mask', action='store_true', help='also optimize a segmentation mask over image')
parser.add_argument('--recurrent_model', action='store_true', help='use a recurrent model to infer latents')
parser.add_argument('--pos_embed', action='store_true', help='add a positional embedding to model')
parser.add_argument('--spatial_feat', action='store_true', help='use spatial latents for object segmentation')


parser.add_argument('--num_steps', default=10, type=int, help='Steps of gradient descent for training')
parser.add_argument('--num_visuals', default=16, type=int, help='Number of visuals')
parser.add_argument('--num_additional', default=0, type=int, help='Number of additional components to add')

parser.add_argument('--step_lr', default=500.0, type=float, help='step size of latents')

parser.add_argument('--latent_dim', default=64, type=int, help='dimension of the latent')
parser.add_argument('--sample', action='store_true', help='generate negative samples through Langevin')
parser.add_argument('--decoder', action='store_true', help='decoder for model')

# Distributed training hyperparameters
parser.add_argument('--nodes', default=1, type=int, help='number of nodes for training')
parser.add_argument('--gpus', default=1, type=int, help='number of gpus per nodes')
parser.add_argument('--node_rank', default=0, type=int, help='rank of node')



def average_gradients(models):
    """Average every parameter gradient across all distributed processes.

    For each parameter that currently has a gradient, the gradients of all
    processes are summed in place with an all-reduce and then divided by the
    world size. Parameters without a gradient are skipped.

    Parameters:
        models -- iterable of modules whose ``.grad`` tensors are averaged in place

    Requires the default process group to be initialised.
    """
    size = float(dist.get_world_size())

    for model in models:
        for param in model.parameters():
            if param.grad is None:
                continue

            # dist.ReduceOp replaces the deprecated dist.reduce_op alias.
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size


def gen_image(latents, FLAGS, models, im_neg, im, num_steps, sample=False, create_graph=True, idx=None, weights=None):
    """Generate an image from per-component latents.

    Two modes, selected by ``FLAGS.decoder``:

    * decoder: each model's ``forward(None, latent)`` returns ``(color, mask)``;
      the masks are soft-maxed over the component axis and the image is the
      mask-weighted sum of the colors.
    * otherwise: ``im_neg`` is refined with ``num_steps`` gradient-descent steps
      of size ``FLAGS.step_lr`` on the sum of ``models[j].forward(im_neg,
      latents[j])`` over the selected components, clamping to [0, 1] after
      every step.

    Parameters:
        latents -- sequence of latent tensors, stacked along a new first axis
        FLAGS -- options object; reads ``decoder``, ``components``, ``step_lr``
        models -- list of models; latent ``j`` uses ``models[j % FLAGS.components]``
        im_neg -- starting image for the descent (``requires_grad`` is switched
            on in place in the non-decoder branch)
        im -- reference batch; only its size is used
        num_steps -- number of descent steps (non-decoder branch)
        sample -- accepted but not used
        create_graph -- forwarded to ``torch.autograd.grad``
        idx -- if not None, only the latent at this position contributes
        weights -- accepted but not used

    Returns:
        im_neg -- final image (detached with ``requires_grad`` set when at
            least one descent step ran)
        im_negs -- list of the clamped image after every step (decoder branch:
            a one-element list)
        im_grad -- gradient from the last step; all zeros in the decoder branch
            or when ``num_steps`` is 0
        masks -- soft-maxed masks (decoder branch) or a zero tensor of size
            (batch, components, H, W)
    """
    # NOTE: this noise is drawn (and redrawn every step below) but never added
    # to im_neg; the draws are kept so the random number stream is unchanged.
    im_noise = torch.randn_like(im_neg).detach()

    im_negs = []

    latents = torch.stack(latents, dim=0)

    if FLAGS.decoder:
        masks = []
        colors = []
        for i in range(len(latents)):
            if idx is not None and idx != i:
                continue
            color, mask = models[i % FLAGS.components].forward(None, latents[i])
            masks.append(mask)
            colors.append(color)
        masks = F.softmax(torch.stack(masks, dim=1), dim=1)
        colors = torch.stack(colors, dim=1)
        im_neg = torch.sum(masks * colors, dim=1)
        im_negs = [im_neg]
        im_grad = torch.zeros_like(im_neg)
    else:
        im_neg.requires_grad_(requires_grad=True)
        s = im.size()
        masks = torch.zeros(s[0], FLAGS.components, s[-2], s[-1]).to(im_neg.device)
        masks.requires_grad_(requires_grad=True)

        # Defined up front so that num_steps == 0 returns a zero gradient
        # instead of failing on an unbound name.
        im_grad = torch.zeros_like(im_neg)

        for i in range(num_steps):
            im_noise.normal_()

            energy = 0
            for j in range(len(latents)):
                if idx is not None and idx != j:
                    continue
                energy = models[j % FLAGS.components].forward(im_neg, latents[j]) + energy

            im_grad, = torch.autograd.grad([energy.sum()], [im_neg], create_graph=create_graph)

            im_neg = im_neg - FLAGS.step_lr * im_grad

            im_neg = torch.clamp(im_neg, 0, 1)
            # The recorded image stays attached to the graph; the copy carried
            # into the next step is detached.
            im_negs.append(im_neg)
            im_neg = im_neg.detach()
            im_neg.requires_grad_()

    return im_neg, im_negs, im_grad, masks


def ema_model(models, models_ema, mu=0.999):
    """Update exponential-moving-average copies of the models in place.

    Every parameter of each EMA model becomes
    ``mu * ema_param + (1 - mu) * param``, pairing models and parameters
    positionally.

    Parameters:
        models -- iterable of source modules
        models_ema -- iterable of EMA modules, in the same order as ``models``
        mu -- decay factor; larger values make the EMA change more slowly
    """
    for (model, model_ema) in zip(models, models_ema):
        for param, param_ema in zip(model.parameters(), model_ema.parameters()):
            # copy_ (rather than slice assignment) also works for 0-dim parameters.
            param_ema.data.copy_(mu * param_ema.data + (1 - mu) * param.data)


def sync_model(models):
    """Broadcast every parameter of every model from rank 0 to all processes.

    Parameters:
        models -- iterable of modules whose parameter data is overwritten in
            place with the values broadcast from rank 0

    Only ``model.parameters()`` are broadcast. Requires the default process
    group to be initialised whenever there is at least one parameter to send.
    """
    for model in models:
        for param in model.parameters():
            dist.broadcast(param.data, 0)


def init_model(FLAGS, device, dataset):
    """Build the model list and the matching Adam optimizers.

    Without ``FLAGS.tie_weight``: ``FLAGS.ensembles`` independent ``LatentEBM``
    instances, one optimizer each.

    With ``FLAGS.tie_weight``: a single model, chosen by dataset / flags, is
    repeated ``FLAGS.ensembles`` times in the returned list and gets one
    optimizer. Choosing the ``BetaVAE_H`` model (non-zero ``FLAGS.vae_beta``)
    also resets ``FLAGS.ensembles`` and ``FLAGS.components`` to 1.

    Returns:
        (models, optimizers) -- two lists
    """
    if not FLAGS.tie_weight:
        models = [LatentEBM(FLAGS, dataset).to(device) for _ in range(FLAGS.ensembles)]
        optimizers = [Adam(net.parameters(), lr=FLAGS.lr) for net in models]
        return models, optimizers

    if FLAGS.dataset == "toy":
        shared = ToyEBM(FLAGS, dataset).to(device)
    elif FLAGS.vae_beta:
        shared = BetaVAE_H(z_dim=FLAGS.latent_dim, nc=3).to(device)
        FLAGS.ensembles = 1
        FLAGS.components = 1
    elif FLAGS.dataset == "celebahq_128":
        shared = LatentEBM128(FLAGS, dataset).to(device)
    else:
        shared = LatentEBM(FLAGS, dataset).to(device)

    # Every list entry is the same module object, so the weights are shared.
    models = [shared] * FLAGS.ensembles
    optimizers = [Adam(shared.parameters(), lr=FLAGS.lr)]
    return models, optimizers


def test(train_dataloader, models, FLAGS, step=0):
    """Render visualisations for the first batch of a dataloader and save them as PNGs.

    The models are put in eval mode, one batch (truncated to
    ``FLAGS.num_visuals`` samples) is processed, and the models are put back in
    train mode. Files written under ``result/<FLAGS.exp>/``:

    * ``s<step>_grad.png`` -- per-component, min/max-normalised gradient of the
      energy w.r.t. the input image, next to the input images
    * ``s<step>_gen.png`` -- full generations followed by per-component ones
    * ``s<step>_gen_perm.png`` / ``s<step>_gen_add.png`` -- only when
      ``FLAGS.components > 1``: generations from latents recombined across the
      batch, and from the latents plus ``FLAGS.num_additional`` rolled copies

    Parameters:
        train_dataloader -- iterable yielding ``(im, idx)`` batches
        models -- list of models; ``models[0].embed_latent`` produces the latents
        FLAGS -- options object
        step -- iteration number used in the output file names
    """
    if FLAGS.cuda:
        dev = torch.device("cuda")
    else:
        dev = torch.device("cpu")

    replay_buffer = None

    [model.eval() for model in models]
    for im, idx in train_dataloader:

        im = im.to(dev)
        idx = idx.to(dev)
        im = im[:FLAGS.num_visuals]
        idx = idx[:FLAGS.num_visuals]
        batch_size = im.size(0)
        latent = models[0].embed_latent(im)

        # Split the embedding along dim 1 into one chunk per component.
        latents = torch.chunk(latent, FLAGS.components, dim=1)

        # Generation from all components, starting from uniform noise.
        im_init = torch.rand_like(im)
        assert len(latents) == FLAGS.components
        im_neg, _, im_grad, mask = gen_image(latents, FLAGS, models, im_init, im, FLAGS.num_steps, sample=FLAGS.sample,
                                       create_graph=False)
        im_neg = im_neg.detach()
        im_components = []

        if FLAGS.components > 1:
            # One generation per component, each from that component's latent alone.
            # NOTE(review): a one-element latent tuple makes gen_image index
            # models[0] for every component -- confirm this is intended when
            # weights are not tied.
            for i, latent in enumerate(latents):
                im_init = torch.rand_like(im)
                latents_select = latents[i:i+1]
                im_component, _, _, _ = gen_image(latents_select, FLAGS, models, im_init, im, FLAGS.num_steps, sample=FLAGS.sample,
                                           create_graph=False)
                im_components.append(im_component)

            # Recombination: component i's latents are rotated by i positions
            # along the batch axis, mixing components across samples.
            im_init = torch.rand_like(im)
            latents_perm = [torch.cat([latent[i:], latent[:i]], dim=0) for i, latent in enumerate(latents)]
            im_neg_perm, _, im_grad_perm, _ = gen_image(latents_perm, FLAGS, models, im_init, im, FLAGS.num_steps, sample=FLAGS.sample,
                                                     create_graph=False)
            im_neg_perm = im_neg_perm.detach()
            # Extra components: append batch-rolled copies of the first
            # FLAGS.num_additional latents.
            im_init = torch.rand_like(im)
            add_latents = list(latents)
            for i in range(FLAGS.num_additional):
                add_latents.append(torch.roll(latents[i], i + 1, 0))
            im_neg_additional, _, _, _ = gen_image(tuple(add_latents), FLAGS, models, im_init, im, FLAGS.num_steps, sample=FLAGS.sample,
                                                     create_graph=False)

        # Gradient of each component's energy with respect to the real image
        # (zeros when a decoder is used).
        im.requires_grad = True
        im_grads = []

        for i, latent in enumerate(latents):
            if FLAGS.decoder:
                im_grad = torch.zeros_like(im)
            else:
                energy_pos = models[i].forward(im, latents[i])
                im_grad = torch.autograd.grad([energy_pos.sum()], [im])[0]
            im_grads.append(im_grad)

        im_grad = torch.stack(im_grads, dim=1)

        s = im.size()
        # NOTE(review): only the last dimension is read, so the views below
        # assume square, 3-channel images.
        im_size = s[-1]

        im_grad = im_grad.view(batch_size, FLAGS.components, 3, im_size, im_size) # [4, 3, 3, 128, 128]
        im_grad_dense = im_grad.view(batch_size, FLAGS.components, 1, 3 * im_size * im_size, 1) # [4, 3, 1, 49152, 1]
        im_grad_min = im_grad_dense.min(dim=3, keepdim=True)[0]
        im_grad_max = im_grad_dense.max(dim=3, keepdim=True)[0] # [4, 3, 1, 1, 1]


        # Min/max normalise each (sample, component) gradient map; 1e-5 guards
        # against a zero range.
        im_grad = (im_grad - im_grad_min) / (im_grad_max - im_grad_min + 1e-5) # [4, 3, 3, 128, 128]
        # Draw a one-pixel border of value 1 around every map.
        im_grad[:, :, :, :1, :] = 1
        im_grad[:, :, :, -1:, :] = 1
        im_grad[:, :, :, :, :1] = 1
        im_grad[:, :, :, :, -1:] = 1
        # Lay the maps out as one image: samples stacked vertically, components
        # side by side, channels last.
        im_output = im_grad.permute(0, 3, 1, 4, 2).reshape(batch_size * im_size, FLAGS.components * im_size, 3)
        im_output = im_output.cpu().detach().numpy() * 100

        im_output = (im_output - im_output.min()) / (im_output.max() - im_output.min())

        im = im.cpu().detach().numpy().transpose((0, 2, 3, 1)).reshape(batch_size*im_size, im_size, 3)

        # Append the input images as an extra column on the right.
        im_output = np.concatenate([im_output, im], axis=1)
        im_output = (im_output*255).astype(np.uint8)
        imwrite("result/%s/s%08d_grad.png" % (FLAGS.exp,step), im_output)

        # Grid of full generations followed by the per-component generations.
        im_neg = im_neg_tensor = im_neg.detach().cpu()
        im_components = [im_components[i].detach().cpu() for i in range(len(im_components))]
        im_neg = torch.cat([im_neg] + im_components)
        im_neg = np.clip(im_neg, 0.0, 1.0)
        im_neg = make_grid(im_neg, nrow=int(im_neg.shape[0] / (FLAGS.components + 1))).permute(1, 2, 0)
        im_neg = (im_neg.numpy()*255).astype(np.uint8)
        imwrite("result/%s/s%08d_gen.png" % (FLAGS.exp,step), im_neg)

        if FLAGS.components > 1:
            im_neg_perm = im_neg_perm.detach().cpu()
            # Rotate each component image the same way its latents were rotated.
            im_components_perm = []
            for i,im_component in enumerate(im_components):
                im_components_perm.append(torch.cat([im_component[i:], im_component[:i]]))
            im_neg_perm = torch.cat([im_neg_perm] + im_components_perm)
            im_neg_perm = np.clip(im_neg_perm, 0.0, 1.0)
            im_neg_perm = make_grid(im_neg_perm, nrow=int(im_neg_perm.shape[0] / (FLAGS.components + 1))).permute(1, 2, 0)
            im_neg_perm = (im_neg_perm.numpy()*255).astype(np.uint8)
            imwrite("result/%s/s%08d_gen_perm.png" % (FLAGS.exp,step), im_neg_perm)

            im_neg_additional = im_neg_additional.detach().cpu()
            # Mirror the rolled latents with equally rolled component images.
            for i in range(FLAGS.num_additional):
                im_components.append(torch.roll(im_components[i], i + 1, 0))
            im_neg_additional = torch.cat([im_neg_additional] + im_components)
            im_neg_additional = np.clip(im_neg_additional, 0.0, 1.0)
            im_neg_additional = make_grid(im_neg_additional,
                                nrow=int(im_neg_additional.shape[0] / (FLAGS.components + FLAGS.num_additional + 1))).permute(1, 2, 0)
            im_neg_additional = (im_neg_additional.numpy()*255).astype(np.uint8)
            imwrite("result/%s/s%08d_gen_add.png" % (FLAGS.exp,step), im_neg_additional)

            # NOTE(review): this message is only printed when components > 1.
            print('test at step %d done!' % step)
        # Only the first batch is visualised.
        break

    [model.train() for model in models]


def train(train_dataloader, test_dataloader, logger, models, optimizers, FLAGS, logdir, rank_idx):
    """Run the training loop with periodic logging, checkpointing and testing.

    Each iteration embeds the batch with ``models[0].embed_latent``, splits the
    latent into ``FLAGS.components`` chunks, generates an image with
    ``gen_image`` and minimises the mean squared error between the final
    generated image and the input. For the ``celebahq_128`` dataset, from
    iteration 10000 on, the loss is an LPIPS (VGG) term plus 0.1 times that
    squared error. The energy difference ``ml_loss`` is computed for logging
    only; it is not part of the optimised loss.

    Parameters:
        train_dataloader -- iterable yielding ``(im, idx)`` training batches
        test_dataloader -- passed to ``test`` whenever a checkpoint is saved
        logger -- object with ``add_scalar(tag, value, step)``
        models -- list of models (indexed by component)
        optimizers -- list of optimizers, all stepped every iteration
        FLAGS -- options object
        logdir -- directory receiving ``model_<it>.pth`` checkpoints
        rank_idx -- global rank; only rank 0 logs, saves and tests

    The device is hard-coded to CUDA.
    """
    it = FLAGS.resume_iter
    for optimizer in optimizers:
        optimizer.zero_grad()

    dev = torch.device("cuda")

    # Use LPIPS loss for CelebA-HQ 128x128
    if FLAGS.dataset == "celebahq_128":
        import lpips
        loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()

    for epoch in range(FLAGS.num_epoch):
        for im, idx in train_dataloader:

            im = im.to(dev)
            idx = idx.to(dev)

            latent = models[0].embed_latent(im)

            latents = torch.chunk(latent, FLAGS.components, dim=1)

            # Generation starts from uniform noise.
            im_neg = torch.rand_like(im)

            im_neg, im_negs, im_grad, _ = gen_image(latents, FLAGS, models, im_neg, im, FLAGS.num_steps, FLAGS.sample)

            im_negs = torch.stack(im_negs, dim=1)

            # Energies of real and generated images, per component (logged only).
            energy_poss = []
            energy_negs = []
            for i in range(FLAGS.components):
                energy_poss.append(models[i].forward(im, latents[i]))
                energy_negs.append(models[i].forward(im_neg.detach(), latents[i]))

            energy_pos = torch.stack(energy_poss, dim=1)
            energy_neg = torch.stack(energy_negs, dim=1)
            ml_loss = (energy_pos - energy_neg).mean()

            # Squared error between the last generated image and the input.
            im_loss = torch.pow(im_negs[:, -1:] - im[:, None], 2).mean()

            # None unless the LPIPS term is computed this iteration; checked
            # again below when deciding whether to log it.
            vgg_loss = None
            if it < 10000 or FLAGS.dataset != "celebahq_128":
                loss = im_loss
            else:
                vgg_loss = loss_fn_vgg(im_negs[:, -1], im).mean()
                loss = vgg_loss  + 0.1 * im_loss

            loss.backward()
            if FLAGS.gpus > 1:
                average_gradients(models)

            for model in models:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
            for optimizer in optimizers:
                optimizer.step()
            for optimizer in optimizers:
                optimizer.zero_grad()

            if it % FLAGS.log_interval == 0 and rank_idx == 0:
                loss = loss.item()
                energy_pos_mean = energy_pos.mean().item()
                energy_neg_mean = energy_neg.mean().item()
                energy_pos_std = energy_pos.std().item()
                energy_neg_std = energy_neg.std().item()

                kvs = {}
                kvs['loss'] = loss
                kvs['ml_loss'] = ml_loss.item()
                kvs['im_loss'] = im_loss.item()

                if vgg_loss is not None:
                    kvs['vgg_loss'] = vgg_loss.item()

                kvs['energy_pos_mean'] = energy_pos_mean
                kvs['energy_neg_mean'] = energy_neg_mean
                kvs['energy_pos_std'] = energy_pos_std
                kvs['energy_neg_std'] = energy_neg_std
                # Despite the key name, this is the maximum absolute gradient.
                kvs['average_im_grad'] = torch.abs(im_grad).max()

                string = "Iteration {} ".format(it)

                for k, v in kvs.items():
                    string += "%s: %.6f  " % (k,v)
                    logger.add_scalar(k, v, it)

                print(string)

            if it % FLAGS.save_interval == 0 and rank_idx == 0:
                model_path = osp.join(logdir, "model_{}.pth".format(it))

                # The checkpoint stores the flags plus every model and
                # optimizer state dict under an indexed key.
                ckpt = {'FLAGS': FLAGS}

                for i in range(len(models)):
                    ckpt['model_state_dict_{}'.format(i)] = models[i].state_dict()

                for i in range(len(optimizers)):
                    ckpt['optimizer_state_dict_{}'.format(i)] = optimizers[i].state_dict()

                torch.save(ckpt, model_path)
                print("Saving model in directory....")
                print('run test')

                test(test_dataloader, models, FLAGS, step=it)

            it += 1



def main_single(rank, FLAGS):
    """Entry point of one training / evaluation process.

    Builds the dataset selected by ``FLAGS.dataset``, initialises distributed
    communication when more than one process is used, creates (or restores)
    the models and optimizers, builds the dataloaders and finally dispatches to
    ``train``, ``test_optimize`` or ``test``.

    Parameters:
        rank -- index of this process on the current node; also used as the
            CUDA device index
        FLAGS -- parsed options. When ``FLAGS.resume_iter`` is non-zero the
            flags stored in the checkpoint replace them, except for the
            run-time settings listed in the resume branch below.
    """
    rank_idx = FLAGS.node_rank * FLAGS.gpus + rank
    world_size = FLAGS.nodes * FLAGS.gpus

    result_dir = 'result/%s' % FLAGS.exp
    if not os.path.exists(result_dir):
        try:
            os.makedirs(result_dir)
        except OSError:
            # Best effort: another process may have created it in the meantime.
            pass

    if FLAGS.dataset == 'cubes':
        dataset = CubesColor(FLAGS, train=True)
        test_dataset = CubesColor(FLAGS, train=False)
    elif FLAGS.dataset == 'cubes_pair':
        dataset = CubesColorPair(FLAGS, train=True)
        test_dataset = CubesColorPair(FLAGS, train=False)
    elif FLAGS.dataset == "nvidia":
        dataset = Nvidia(FLAGS)
        test_dataset = dataset
    elif FLAGS.dataset == "clevr":
        dataset = Clevr(FLAGS)
        test_dataset = dataset
    elif FLAGS.dataset == "clevr_lighting":
        dataset = ClevrLighting(FLAGS)
        test_dataset = dataset
    elif FLAGS.dataset == "exercise":
        dataset = Exercise(FLAGS)
        test_dataset = dataset
    elif FLAGS.dataset == "intphys":
        dataset = IntPhysDataset(FLAGS)
        test_dataset = dataset
    elif FLAGS.dataset == "celebahq":
        dataset = CelebaHQ(resolution=64)
        test_dataset = dataset
    elif FLAGS.dataset == "celebahq_128":
        dataset = CelebaHQ(resolution=128)
        test_dataset = dataset
    elif FLAGS.dataset == "kitti":
        dataset = Kitti(FLAGS)
        test_dataset = dataset
    elif FLAGS.dataset == "faces":
        dataset = Faces(FLAGS)
        test_dataset = dataset
    elif FLAGS.dataset == "fashionmnist":
        # Imported here because the file-level imports only bring in ImageFolder.
        from torchvision.datasets import FashionMNIST
        dataset = FashionMNIST(root = '/content/sample_data', download = True, transform =  transforms.Compose([transforms.Grayscale(num_output_channels=3), transforms.ToTensor()]))
        test_dataset = dataset
    else:
        dataset = ToyDataset(FLAGS)
        test_dataset = ToyDataset(FLAGS)

    if world_size > 1:
        dist.init_process_group(backend='nccl', init_method='tcp://localhost:8113', world_size=world_size, rank=rank_idx, group_name="default")

    torch.cuda.set_device(rank)
    device = torch.device('cuda')

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    FLAGS_OLD = FLAGS

    if FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}.pth".format(FLAGS.resume_iter))

        # NOTE(review): the checkpoint stores the FLAGS object itself, so this
        # is a full pickle load; only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        FLAGS = checkpoint['FLAGS']

        # Run-time settings always come from the current invocation rather
        # than from the checkpoint. Names the current flags do not define
        # (e.g. 'temporal' and 'sim', which the parser does not declare) are
        # skipped instead of raising AttributeError.
        runtime_flags = (
            'resume_iter', 'save_interval', 'nodes', 'gpus', 'node_rank',
            'train', 'batch_size', 'num_visuals', 'num_additional', 'decoder',
            'optimize_test', 'temporal', 'sim', 'exp', 'step_lr', 'num_steps',
            'vae_beta',
        )
        for flag_name in runtime_flags:
            if hasattr(FLAGS_OLD, flag_name):
                setattr(FLAGS, flag_name, getattr(FLAGS_OLD, flag_name))

        models, optimizers  = init_model(FLAGS, device, dataset)

        for i, (model, optimizer) in enumerate(zip(models, optimizers)):
            model.load_state_dict(checkpoint['model_state_dict_{}'.format(i)], strict=False)
            # Optimizer.load_state_dict takes no `strict` argument.
            optimizer.load_state_dict(checkpoint['optimizer_state_dict_{}'.format(i)])

    else:
        models, optimizers = init_model(FLAGS, device, dataset)

    if FLAGS.gpus > 1:
        sync_model(models)

    if FLAGS.dataset == "multidsprites":
        train_dataloader = MultiDspritesLoader(FLAGS.batch_size)
        test_dataloader = MultiDspritesLoader(FLAGS.batch_size)
    elif FLAGS.dataset == "tetris":
        train_dataloader = TetrominoesLoader(FLAGS.batch_size)
        test_dataloader = TetrominoesLoader(FLAGS.batch_size)
    else:
        train_dataloader = DataLoader(dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size, shuffle=True, pin_memory=False)
        test_dataloader = DataLoader(test_dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.num_visuals, shuffle=True, pin_memory=False, drop_last=True)

    logger = SummaryWriter(logdir)

    if FLAGS.train:
        models = [model.train() for model in models]
    else:
        models = [model.eval() for model in models]

    if FLAGS.train:
        train(train_dataloader, test_dataloader, logger, models, optimizers, FLAGS, logdir, rank_idx)

    elif FLAGS.optimize_test:
        test_optimize(test_dataloader, models, FLAGS, step=FLAGS.resume_iter)
    else:
        test(test_dataloader, models, FLAGS, step=FLAGS.resume_iter)

In [None]:
import easydict
# Small attribute-style options object.
# NOTE(review): `args` is not referenced by any code visible in this part of
# the notebook, and "imagenet" differs from the dataset set on FLAGS in the
# launch cell -- confirm whether this cell is still needed.
args = easydict.EasyDict({
    "dataset": "imagenet",
    "cuda": True,
})

In [None]:
# Parse only the flags the parser knows; any other argv entries are discarded.
FLAGS, _ = parser.parse_known_args()

# Hard-coded overrides for this run, applied on top of the parsed defaults.
vars(FLAGS).update(
    ensembles=FLAGS.components,
    tie_weight=True,
    sample=True,
    cuda=True,
    dataset='clevr',
    train=True,
    batch_size=32,
    lr=1e-4,
)

# Checkpoints and logs go to <logdir>/<exp>.
logdir = osp.join(FLAGS.logdir, FLAGS.exp)

if not osp.exists(logdir):
  os.makedirs(logdir)

# One process per GPU when several are requested, otherwise run in-process.
if FLAGS.gpus <= 1:
  main_single(0, FLAGS)
else:
  mp.spawn(main_single, nprocs=FLAGS.gpus, args=(FLAGS,))

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 0 loss: 0.091789  ml_loss: 0.000190  im_loss: 0.091789  energy_pos_mean: 0.064609  energy_neg_mean: 0.064419  energy_pos_std: 0.004775  energy_neg_std: 0.004778  average_im_grad: 0.000004  
Saving model in directory....
run test


  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


test at step 0 done!
Iteration 10 loss: 0.084738  ml_loss: -0.000477  im_loss: 0.084738  energy_pos_mean: 0.087351  energy_neg_mean: 0.087828  energy_pos_std: 0.002526  energy_neg_std: 0.002472  average_im_grad: 0.000036  
Iteration 20 loss: 0.074805  ml_loss: 0.002110  im_loss: 0.074805  energy_pos_mean: 0.097070  energy_neg_mean: 0.094960  energy_pos_std: 0.004931  energy_neg_std: 0.004816  average_im_grad: 0.000072  
Iteration 30 loss: 0.054956  ml_loss: -0.005163  im_loss: 0.054956  energy_pos_mean: 0.055435  energy_neg_mean: 0.060598  energy_pos_std: 0.009566  energy_neg_std: 0.009556  average_im_grad: 0.000067  
Iteration 40 loss: 0.033699  ml_loss: -0.002246  im_loss: 0.033699  energy_pos_mean: 0.025567  energy_neg_mean: 0.027814  energy_pos_std: 0.013282  energy_neg_std: 0.013251  average_im_grad: 0.000064  
Iteration 50 loss: 0.020740  ml_loss: 0.001356  im_loss: 0.020740  energy_pos_mean: -0.041794  energy_neg_mean: -0.043150  energy_pos_std: 0.003960  energy_neg_std: 0.00386

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 360 loss: 0.003669  ml_loss: 0.015972  im_loss: 0.003669  energy_pos_mean: -21.478693  energy_neg_mean: -21.494665  energy_pos_std: 8.119082  energy_neg_std: 8.120598  average_im_grad: 0.001396  
Iteration 370 loss: 0.003490  ml_loss: 0.014830  im_loss: 0.003490  energy_pos_mean: -21.172680  energy_neg_mean: -21.187511  energy_pos_std: 7.466502  energy_neg_std: 7.469315  average_im_grad: 0.000319  
Iteration 380 loss: 0.003428  ml_loss: 0.014994  im_loss: 0.003428  energy_pos_mean: -21.511909  energy_neg_mean: -21.526905  energy_pos_std: 7.527701  energy_neg_std: 7.529947  average_im_grad: 0.000865  
Iteration 390 loss: 0.003450  ml_loss: 0.014118  im_loss: 0.003450  energy_pos_mean: -21.095295  energy_neg_mean: -21.109413  energy_pos_std: 7.264012  energy_neg_std: 7.268037  average_im_grad: 0.000561  
Iteration 400 loss: 0.003470  ml_loss: 0.015188  im_loss: 0.003470  energy_pos_mean: -20.623621  energy_neg_mean: -20.638809  energy_pos_std: 7.329878  energy_neg_std: 7.330553

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 720 loss: 0.003645  ml_loss: 0.015574  im_loss: 0.003645  energy_pos_mean: -20.445253  energy_neg_mean: -20.460827  energy_pos_std: 6.380224  energy_neg_std: 6.383384  average_im_grad: 0.000258  
Iteration 730 loss: 0.003098  ml_loss: 0.013710  im_loss: 0.003098  energy_pos_mean: -20.341230  energy_neg_mean: -20.354942  energy_pos_std: 6.902572  energy_neg_std: 6.907522  average_im_grad: 0.000203  
Iteration 740 loss: 0.003240  ml_loss: 0.014280  im_loss: 0.003240  energy_pos_mean: -20.114033  energy_neg_mean: -20.128315  energy_pos_std: 6.841996  energy_neg_std: 6.844350  average_im_grad: 0.000391  
Iteration 750 loss: 0.003337  ml_loss: 0.015087  im_loss: 0.003337  energy_pos_mean: -20.700218  energy_neg_mean: -20.715305  energy_pos_std: 6.686415  energy_neg_std: 6.692278  average_im_grad: 0.000563  
Iteration 760 loss: 0.003152  ml_loss: 0.013735  im_loss: 0.003152  energy_pos_mean: -20.876520  energy_neg_mean: -20.890255  energy_pos_std: 6.900044  energy_neg_std: 6.903772

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


test at step 1000 done!
Iteration 1010 loss: 0.003350  ml_loss: 0.014379  im_loss: 0.003350  energy_pos_mean: -20.219444  energy_neg_mean: -20.233824  energy_pos_std: 6.490745  energy_neg_std: 6.493614  average_im_grad: 0.000350  
Iteration 1020 loss: 0.003373  ml_loss: 0.015416  im_loss: 0.003373  energy_pos_mean: -20.382298  energy_neg_mean: -20.397713  energy_pos_std: 6.272751  energy_neg_std: 6.277453  average_im_grad: 0.000846  
Iteration 1030 loss: 0.002991  ml_loss: 0.014323  im_loss: 0.002991  energy_pos_mean: -21.719078  energy_neg_mean: -21.733400  energy_pos_std: 7.472044  energy_neg_std: 7.477097  average_im_grad: 0.000539  
Iteration 1040 loss: 0.003328  ml_loss: 0.014698  im_loss: 0.003328  energy_pos_mean: -19.887405  energy_neg_mean: -19.902103  energy_pos_std: 6.223855  energy_neg_std: 6.228460  average_im_grad: 0.000482  
Iteration 1050 loss: 0.002946  ml_loss: 0.013352  im_loss: 0.002946  energy_pos_mean: -20.806480  energy_neg_mean: -20.819834  energy_pos_std: 6.752

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 1080 loss: 0.003074  ml_loss: 0.012493  im_loss: 0.003074  energy_pos_mean: -19.618387  energy_neg_mean: -19.630878  energy_pos_std: 6.151655  energy_neg_std: 6.155577  average_im_grad: 0.000519  
Iteration 1090 loss: 0.003044  ml_loss: 0.012861  im_loss: 0.003044  energy_pos_mean: -19.159948  energy_neg_mean: -19.172808  energy_pos_std: 6.222581  energy_neg_std: 6.226635  average_im_grad: 0.000222  
Iteration 1100 loss: 0.003107  ml_loss: 0.013566  im_loss: 0.003107  energy_pos_mean: -19.259737  energy_neg_mean: -19.273302  energy_pos_std: 6.047978  energy_neg_std: 6.051506  average_im_grad: 0.000411  
Iteration 1110 loss: 0.003383  ml_loss: 0.014429  im_loss: 0.003383  energy_pos_mean: -19.328409  energy_neg_mean: -19.342838  energy_pos_std: 5.721486  energy_neg_std: 5.724516  average_im_grad: 0.000303  
Iteration 1120 loss: 0.002652  ml_loss: 0.012420  im_loss: 0.002652  energy_pos_mean: -20.990845  energy_neg_mean: -21.003265  energy_pos_std: 6.887709  energy_neg_std: 6.8

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 1430 loss: 0.003204  ml_loss: 0.015688  im_loss: 0.003204  energy_pos_mean: -20.372433  energy_neg_mean: -20.388123  energy_pos_std: 6.083119  energy_neg_std: 6.088202  average_im_grad: 0.000601  
Iteration 1440 loss: 0.003040  ml_loss: 0.014504  im_loss: 0.003040  energy_pos_mean: -19.845812  energy_neg_mean: -19.860315  energy_pos_std: 5.432484  energy_neg_std: 5.437566  average_im_grad: 0.000699  
Iteration 1450 loss: 0.003079  ml_loss: 0.013351  im_loss: 0.003079  energy_pos_mean: -19.178234  energy_neg_mean: -19.191586  energy_pos_std: 4.937485  energy_neg_std: 4.941880  average_im_grad: 0.000137  
Iteration 1460 loss: 0.003411  ml_loss: 0.015719  im_loss: 0.003411  energy_pos_mean: -19.683353  energy_neg_mean: -19.699072  energy_pos_std: 5.354571  energy_neg_std: 5.359975  average_im_grad: 0.000621  
Iteration 1470 loss: 0.003172  ml_loss: 0.014180  im_loss: 0.003172  energy_pos_mean: -19.237595  energy_neg_mean: -19.251774  energy_pos_std: 5.067973  energy_neg_std: 5.0

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 1790 loss: 0.003483  ml_loss: 0.015568  im_loss: 0.003483  energy_pos_mean: -18.962561  energy_neg_mean: -18.978130  energy_pos_std: 4.999571  energy_neg_std: 5.004310  average_im_grad: 0.000379  
Iteration 1800 loss: 0.003238  ml_loss: 0.014855  im_loss: 0.003238  energy_pos_mean: -19.124554  energy_neg_mean: -19.139408  energy_pos_std: 4.521467  energy_neg_std: 4.525697  average_im_grad: 0.000252  
Iteration 1810 loss: 0.003250  ml_loss: 0.013913  im_loss: 0.003250  energy_pos_mean: -17.817547  energy_neg_mean: -17.831459  energy_pos_std: 4.319366  energy_neg_std: 4.323823  average_im_grad: 0.000425  
Iteration 1820 loss: 0.003205  ml_loss: 0.014509  im_loss: 0.003205  energy_pos_mean: -18.879162  energy_neg_mean: -18.893671  energy_pos_std: 4.475304  energy_neg_std: 4.481134  average_im_grad: 0.000616  
Iteration 1830 loss: 0.003217  ml_loss: 0.015007  im_loss: 0.003217  energy_pos_mean: -19.905148  energy_neg_mean: -19.920155  energy_pos_std: 5.196982  energy_neg_std: 5.2

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


test at step 2000 done!
Iteration 2010 loss: 0.002877  ml_loss: 0.013702  im_loss: 0.002877  energy_pos_mean: -19.432549  energy_neg_mean: -19.446251  energy_pos_std: 4.798249  energy_neg_std: 4.803733  average_im_grad: 0.000465  
Iteration 2020 loss: 0.003381  ml_loss: 0.015953  im_loss: 0.003381  energy_pos_mean: -19.014929  energy_neg_mean: -19.030884  energy_pos_std: 4.584660  energy_neg_std: 4.590585  average_im_grad: 0.000699  
Iteration 2030 loss: 0.003017  ml_loss: 0.013541  im_loss: 0.003017  energy_pos_mean: -19.545479  energy_neg_mean: -19.559019  energy_pos_std: 5.240391  energy_neg_std: 5.245588  average_im_grad: 0.000197  
Iteration 2040 loss: 0.003196  ml_loss: 0.014444  im_loss: 0.003196  energy_pos_mean: -18.581249  energy_neg_mean: -18.595694  energy_pos_std: 4.648165  energy_neg_std: 4.654274  average_im_grad: 0.000883  
Iteration 2050 loss: 0.002992  ml_loss: 0.014627  im_loss: 0.002992  energy_pos_mean: -19.735500  energy_neg_mean: -19.750126  energy_pos_std: 4.807

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 2150 loss: 0.003227  ml_loss: 0.014857  im_loss: 0.003227  energy_pos_mean: -19.148037  energy_neg_mean: -19.162891  energy_pos_std: 4.866153  energy_neg_std: 4.869101  average_im_grad: 0.000189  
Iteration 2160 loss: 0.003217  ml_loss: 0.015093  im_loss: 0.003217  energy_pos_mean: -19.241871  energy_neg_mean: -19.256962  energy_pos_std: 4.864297  energy_neg_std: 4.869764  average_im_grad: 0.000272  
Iteration 2170 loss: 0.003015  ml_loss: 0.014201  im_loss: 0.003015  energy_pos_mean: -19.200651  energy_neg_mean: -19.214851  energy_pos_std: 4.427473  energy_neg_std: 4.431752  average_im_grad: 0.000192  
Iteration 2180 loss: 0.002832  ml_loss: 0.013270  im_loss: 0.002832  energy_pos_mean: -19.097538  energy_neg_mean: -19.110809  energy_pos_std: 4.394319  energy_neg_std: 4.398452  average_im_grad: 0.000512  
Iteration 2190 loss: 0.003203  ml_loss: 0.013813  im_loss: 0.003203  energy_pos_mean: -18.330608  energy_neg_mean: -18.344421  energy_pos_std: 4.016957  energy_neg_std: 4.0

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 2500 loss: 0.003321  ml_loss: 0.016597  im_loss: 0.003321  energy_pos_mean: -19.199135  energy_neg_mean: -19.215733  energy_pos_std: 4.644844  energy_neg_std: 4.652752  average_im_grad: 0.000561  
Iteration 2510 loss: 0.003448  ml_loss: 0.015326  im_loss: 0.003448  energy_pos_mean: -17.921412  energy_neg_mean: -17.936737  energy_pos_std: 4.229701  energy_neg_std: 4.234013  average_im_grad: 0.000242  
Iteration 2520 loss: 0.003093  ml_loss: 0.013810  im_loss: 0.003093  energy_pos_mean: -18.478168  energy_neg_mean: -18.491978  energy_pos_std: 4.091642  energy_neg_std: 4.094984  average_im_grad: 0.000135  
Iteration 2530 loss: 0.003125  ml_loss: 0.014974  im_loss: 0.003125  energy_pos_mean: -18.014149  energy_neg_mean: -18.029121  energy_pos_std: 4.213812  energy_neg_std: 4.219507  average_im_grad: 0.000287  
Iteration 2540 loss: 0.003073  ml_loss: 0.014570  im_loss: 0.003073  energy_pos_mean: -18.103243  energy_neg_mean: -18.117813  energy_pos_std: 4.301610  energy_neg_std: 4.3

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 2860 loss: 0.002909  ml_loss: 0.013915  im_loss: 0.002909  energy_pos_mean: -18.009325  energy_neg_mean: -18.023241  energy_pos_std: 3.340082  energy_neg_std: 3.344334  average_im_grad: 0.000683  
Iteration 2870 loss: 0.002884  ml_loss: 0.014963  im_loss: 0.002884  energy_pos_mean: -19.350151  energy_neg_mean: -19.365112  energy_pos_std: 3.815667  energy_neg_std: 3.820024  average_im_grad: 0.000351  
Iteration 2880 loss: 0.002859  ml_loss: 0.012775  im_loss: 0.002859  energy_pos_mean: -17.862999  energy_neg_mean: -17.875772  energy_pos_std: 3.718583  energy_neg_std: 3.721895  average_im_grad: 0.000118  
Iteration 2890 loss: 0.003185  ml_loss: 0.015172  im_loss: 0.003185  energy_pos_mean: -18.055759  energy_neg_mean: -18.070932  energy_pos_std: 3.706158  energy_neg_std: 3.711926  average_im_grad: 0.000631  
Iteration 2900 loss: 0.002868  ml_loss: 0.013559  im_loss: 0.002868  energy_pos_mean: -18.063215  energy_neg_mean: -18.076775  energy_pos_std: 3.534119  energy_neg_std: 3.5

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


test at step 3000 done!
Iteration 3010 loss: 0.003013  ml_loss: 0.013033  im_loss: 0.003013  energy_pos_mean: -17.642246  energy_neg_mean: -17.655279  energy_pos_std: 3.511469  energy_neg_std: 3.515113  average_im_grad: 0.000390  
Iteration 3020 loss: 0.002824  ml_loss: 0.013783  im_loss: 0.002824  energy_pos_mean: -18.910439  energy_neg_mean: -18.924225  energy_pos_std: 4.175560  energy_neg_std: 4.180647  average_im_grad: 0.000144  
Iteration 3030 loss: 0.002932  ml_loss: 0.013460  im_loss: 0.002932  energy_pos_mean: -18.119659  energy_neg_mean: -18.133120  energy_pos_std: 3.694366  energy_neg_std: 3.699204  average_im_grad: 0.000518  
Iteration 3040 loss: 0.002874  ml_loss: 0.014087  im_loss: 0.002874  energy_pos_mean: -19.224403  energy_neg_mean: -19.238489  energy_pos_std: 4.206249  energy_neg_std: 4.211729  average_im_grad: 0.000284  
Iteration 3050 loss: 0.003020  ml_loss: 0.013949  im_loss: 0.003020  energy_pos_mean: -18.212715  energy_neg_mean: -18.226665  energy_pos_std: 3.695

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 3220 loss: 0.003226  ml_loss: 0.015176  im_loss: 0.003226  energy_pos_mean: -17.733503  energy_neg_mean: -17.748680  energy_pos_std: 2.817638  energy_neg_std: 2.823695  average_im_grad: 0.000648  
Iteration 3230 loss: 0.003051  ml_loss: 0.014995  im_loss: 0.003051  energy_pos_mean: -18.675644  energy_neg_mean: -18.690639  energy_pos_std: 3.727898  energy_neg_std: 3.732816  average_im_grad: 0.000318  
Iteration 3240 loss: 0.002884  ml_loss: 0.013280  im_loss: 0.002884  energy_pos_mean: -18.028664  energy_neg_mean: -18.041945  energy_pos_std: 4.057128  energy_neg_std: 4.061463  average_im_grad: 0.000420  
Iteration 3250 loss: 0.002819  ml_loss: 0.012916  im_loss: 0.002819  energy_pos_mean: -18.044987  energy_neg_mean: -18.057903  energy_pos_std: 3.487938  energy_neg_std: 3.493346  average_im_grad: 0.000452  
Iteration 3260 loss: 0.003173  ml_loss: 0.014278  im_loss: 0.003173  energy_pos_mean: -17.362333  energy_neg_mean: -17.376610  energy_pos_std: 3.066273  energy_neg_std: 3.0

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 3570 loss: 0.003105  ml_loss: 0.013272  im_loss: 0.003105  energy_pos_mean: -20.652950  energy_neg_mean: -20.666222  energy_pos_std: 1.195303  energy_neg_std: 1.197419  average_im_grad: 0.001751  
Iteration 3580 loss: 0.002880  ml_loss: 0.011336  im_loss: 0.002880  energy_pos_mean: -20.140715  energy_neg_mean: -20.152050  energy_pos_std: 1.216148  energy_neg_std: 1.217520  average_im_grad: 0.000202  
Iteration 3590 loss: 0.003077  ml_loss: 0.011969  im_loss: 0.003077  energy_pos_mean: -19.606348  energy_neg_mean: -19.618317  energy_pos_std: 0.947692  energy_neg_std: 0.947185  average_im_grad: 0.000127  
Iteration 3600 loss: 0.002792  ml_loss: 0.011045  im_loss: 0.002792  energy_pos_mean: -20.175194  energy_neg_mean: -20.186237  energy_pos_std: 1.546647  energy_neg_std: 1.547405  average_im_grad: 0.000545  
Iteration 3610 loss: 0.002840  ml_loss: 0.010367  im_loss: 0.002840  energy_pos_mean: -19.401302  energy_neg_mean: -19.411669  energy_pos_std: 1.089858  energy_neg_std: 1.0

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


Iteration 3930 loss: 0.002652  ml_loss: 0.010676  im_loss: 0.002652  energy_pos_mean: -18.883175  energy_neg_mean: -18.893850  energy_pos_std: 0.971440  energy_neg_std: 0.970531  average_im_grad: 0.000729  
Iteration 3940 loss: 0.003079  ml_loss: 0.011313  im_loss: 0.003079  energy_pos_mean: -18.310400  energy_neg_mean: -18.321712  energy_pos_std: 0.994031  energy_neg_std: 0.991735  average_im_grad: 0.001378  
Iteration 3950 loss: 0.002919  ml_loss: 0.010548  im_loss: 0.002919  energy_pos_mean: -18.428877  energy_neg_mean: -18.439426  energy_pos_std: 0.905655  energy_neg_std: 0.904086  average_im_grad: 0.000241  
Iteration 3960 loss: 0.002994  ml_loss: 0.010666  im_loss: 0.002994  energy_pos_mean: -18.381668  energy_neg_mean: -18.392334  energy_pos_std: 1.113743  energy_neg_std: 1.111329  average_im_grad: 0.000253  
Iteration 3970 loss: 0.002800  ml_loss: 0.010218  im_loss: 0.002800  energy_pos_mean: -18.137533  energy_neg_mean: -18.147751  energy_pos_std: 0.933254  energy_neg_std: 0.9

  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)
  im = imread(im_path)


test at step 4000 done!
Iteration 4010 loss: 0.002832  ml_loss: 0.010295  im_loss: 0.002832  energy_pos_mean: -18.821920  energy_neg_mean: -18.832218  energy_pos_std: 0.843642  energy_neg_std: 0.843159  average_im_grad: 0.000376  
Iteration 4020 loss: 0.003274  ml_loss: 0.011448  im_loss: 0.003274  energy_pos_mean: -18.313629  energy_neg_mean: -18.325077  energy_pos_std: 0.880485  energy_neg_std: 0.879145  average_im_grad: 0.000277  
Iteration 4030 loss: 0.002429  ml_loss: 0.009444  im_loss: 0.002429  energy_pos_mean: -18.369642  energy_neg_mean: -18.379086  energy_pos_std: 1.177535  energy_neg_std: 1.175031  average_im_grad: 0.000880  
Iteration 4040 loss: 0.003004  ml_loss: 0.010555  im_loss: 0.003004  energy_pos_mean: -17.944725  energy_neg_mean: -17.955280  energy_pos_std: 0.842957  energy_neg_std: 0.841818  average_im_grad: 0.000225  
Iteration 4050 loss: 0.002764  ml_loss: 0.009739  im_loss: 0.002764  energy_pos_mean: -17.650528  energy_neg_mean: -17.660267  energy_pos_std: 1.033