In [28]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
import pytorch_lightning as pl
from torch.nn.parameter import Parameter
from torchmetrics import PeakSignalNoiseRatio
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import datetime, random
import csv
import os
import argparse
import models
from utils import permute, to_numpy, init_weights, Self_Energy_log, get_dataloaders
from utils import gaussian_noise, sp_noise, delete_square, generate_Y0

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="CIFAR10")
parser.add_argument("--data_size", type=int, default=60000)
parser.add_argument("--train_ratio", type=float, default=1.0)
parser.add_argument("--subset", action="store_true", default=True)
parser.add_argument("--use_label", type=int, default=5)
parser.add_argument("--use_unpaired", action="store_true", default=False)

parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--in_channel", type=int, default=3)
parser.add_argument("--img_size", type=int, default=8)
parser.add_argument("--dim_feature", type=int, default=32)

parser.add_argument("--gaussian_noise", type=float, default=0.005)
parser.add_argument("--sp_noise", type=float, default=0.1)
parser.add_argument("--square_pixels", type=int, default=20)
parser.add_argument("--degradation", type=str, default='gaussian_noise')
parser.add_argument("--Y0_type", type=str, default='random')

parser.add_argument("--lr_energy_model", type=float, default=0.0001)
parser.add_argument("--lr_langevin_min", type=float, default=0.01)  
parser.add_argument("--lr_langevin_max", type=float, default=0.1)  
parser.add_argument("--number_step_langevin", type=int, default=50)  
parser.add_argument("--use_energy_sched", action="store_true", default=False)  
parser.add_argument("--regular_data", type=float, default=0.0)
parser.add_argument("--init_noise_decay", type=float, default=1.0)
# parser.add_argument("--use_gp", action="store_true", default=True)
# parser.add_argument("--use_energy_reg", action="store_true", default=False)
parser.add_argument("--reg_w", type=float, default=0.0005)
# parser.add_argument("--use_energy_L2_reg", action="store_true", default=True)
parser.add_argument("--L2_reg_w", type=float, default=0.1)
parser.add_argument('--num_workers', type=int, default=4, help='')

parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--save_plot", type=int, default=20)

parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")

args = parser.parse_args(args=[])

transforms_train = torchvision.transforms.Compose([
    torchvision.transforms.Resize(args.img_size),
    # torchvision.transforms.RandomCrop(32, padding=4),
    # torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor(),
    # torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    # torchvision.transforms.Normalize((0.5), (0.5))
    ])
scale_range = [-1, 1]
sched_step_size       = 1 # 10
sched_gamma           = 0.93
dim_output            = 1
add_noise             = True
dir_data = '/hdd1/dataset'
list_lr_langevin      = np.linspace(args.lr_langevin_max, args.lr_langevin_min, num=args.epochs, endpoint=True)
pl.seed_everything(0)

train_loader, test_loader= get_dataloaders(dir_data, args.dataset, args.img_size, 
                                            args.batch_size, train_size = args.data_size,
                                            transform=transforms_train, use_subset = args.subset,
                                            parallel=False,
                                            num_workers=args.num_workers)




Global seed set to 0


Files already downloaded and verified
Files already downloaded and verified


In [72]:
a, _ = next(iter(train_loader))
print(torch.max(a), torch.min(a), torch.mean(a))

tensor(0.8039) tensor(0.0667) tensor(0.4184)


In [20]:
sigmas_np = np.linspace(0.05, 0.5, args.batch_size)
noise = torch.Tensor(sigmas_np).view((args.batch_size, 1, 1, 1))
noise

tensor([[[[0.0500]]],


        [[[0.2000]]],


        [[[0.3500]]],


        [[[0.5000]]]])

In [2]:
class Swish(nn.Module):

    def forward(self, x):
        return x * torch.sigmoid(x)


class CNNModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        # We increase the hidden dimension over layers. Here pre-calculated for simplicity.

        # Series of convolutions and Swish activation functions
        self.cnn_layers = nn.Sequential(
                nn.Conv2d(args.in_channel, args.dim_feature, 3, 2, padding=1), # [16x16] - Larger padding to get 32x32 image
                Swish(),
                nn.Conv2d(args.dim_feature, args.dim_feature * 2, 3,2,1), #  [8x8]
                Swish(),
                nn.Conv2d(args.dim_feature * 2, args.dim_feature * 4, kernel_size=3, stride=2, padding=1), # [4x4]
                Swish(),
                nn.Conv2d(args.dim_feature * 4,args.dim_feature*8, kernel_size=3, stride=2, padding=1), # [2x2]
                Swish(),
                nn.Flatten(),
                nn.Linear(args.dim_feature*8 * 4, args.dim_feature*8),
                Swish(),
                nn.Linear(args.dim_feature*8, 1)
        )
        self.initialize()

    def forward(self, x):
        x = self.cnn_layers(x).squeeze(dim=-1)
        return x
    
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight.data)
                # nn.init.xavier_uniform_(m.weight.data)
                #nn.init.uniform_(m.weight)
                if m.bias is not None: 
                    nn.init.constant_(m.bias.data, 0)
            
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)
            
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)
			
            elif isinstance(m, nn.Linear):
                #nn.init.kaiming_uniform_(m.weight)
                nn.init.xavier_uniform_(m.weight.data)
                #nn.init.uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0)

class Sampler:

    def __init__(self, model, img_shape, sample_size, max_len=8192):
        """
        Inputs:
            model - Neural network to use for modeling E_theta
            img_shape - Shape of the images to model
            sample_size - Batch size of the samples
            max_len - Maximum number of data points to keep in the buffer
        """
        super().__init__()
        self.model = model
        self.img_shape = img_shape
        self.sample_size = sample_size
        self.max_len = max_len
        self.examples = [(torch.rand((1,)+img_shape)*2-1) for _ in range(self.sample_size)]

    def sample_new_exmps(self, steps=60, step_size=10):
        """
        Function for getting a new batch of "fake" images.
        Inputs:
            steps - Number of iterations in the MCMC algorithm
            step_size - Learning rate nu in the algorithm above
        """
        # Choose 95% of the batch from the buffer, 5% generate from scratch 
        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = torch.rand((n_new,) + self.img_shape) * 2 - 1
        old_imgs = torch.cat(random.choices(self.examples, k=self.sample_size-n_new), dim=0)
        inp_imgs = torch.cat([rand_imgs, old_imgs], dim=0).detach().to(device)

        # Perform MCMC sampling
        inp_imgs = Sampler.generate_samples(self.model, inp_imgs, steps=steps, step_size=step_size)

        # Add new images to the buffer and remove old ones if needed
        self.examples = list(inp_imgs.to(torch.device("cpu")).chunk(self.sample_size, dim=0)) + self.examples
        self.examples = self.examples[:self.max_len]
        return inp_imgs

    @staticmethod
    def generate_samples(model, inp_imgs, steps=60, step_size=10, return_img_per_step=False):
        """
        Function for sampling images for a given model. 
        Inputs:
            model - Neural network to use for modeling E_theta
            inp_imgs - Images to start from for sampling. If you want to generate new images, enter noise between -1 and 1.
            steps - Number of iterations in the MCMC algorithm.
            step_size - Learning rate nu in the algorithm above
            return_img_per_step - If True, we return the sample at every iteration of the MCMC
        """
        # Before MCMC: set model parameters to "required_grad=False"
        # because we are only interested in the gradients of the input. 
        is_training = model.training
        model.eval()
        for p in model.parameters():
            p.requires_grad = False
        inp_imgs.requires_grad = True
        
        # Enable gradient calculation if not already the case
        had_gradients_enabled = torch.is_grad_enabled()
        torch.set_grad_enabled(True)
        
        # We use a buffer tensor in which we generate noise each loop iteration.
        # More efficient than creating a new tensor every iteration.
        noise = torch.randn(inp_imgs.shape, device=inp_imgs.device)
        
        # List for storing generations at each step (for later analysis)
        imgs_per_step = []
        
        # Loop over K (steps)
        for _ in range(steps):
            # Part 1: Add noise to the input.
            noise.normal_(0, 0.005)
            inp_imgs.data.add_(noise.data)
            inp_imgs.data.clamp_(min=-1.0, max=1.0)
            
            # Part 2: calculate gradients for the current input.
            out_imgs = - model(inp_imgs)
            out_imgs.sum().backward()
            inp_imgs.grad.data.clamp_(-0.03, 0.03) # For stabilizing and preventing too high gradients

            # Apply gradients to our current samples
            inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
            inp_imgs.grad.detach_()
            inp_imgs.grad.zero_()
            inp_imgs.data.clamp_(min=-1.0, max=1.0)
            
            if return_img_per_step:
                imgs_per_step.append(inp_imgs.clone().detach())
        
        # Reactivate gradients for parameters for training
        for p in model.parameters():
            p.requires_grad = True
        model.train(is_training)
        
        # Reset gradient calculation to setting before this function
        torch.set_grad_enabled(had_gradients_enabled)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0)
        else:
            return inp_imgs

In [None]:

class SampleBuffer:
    def __init__(self, max_samples=10000):
        self.max_samples = max_samples
        self.buffer = []

    def __len__(self):
        return len(self.buffer)

    def push(self, samples, class_ids=None):
        samples = samples.detach().to('cpu')
        class_ids = class_ids.detach().to('cpu')

        for sample, class_id in zip(samples, class_ids):
            self.buffer.append((sample.detach(), class_id))

            if len(self.buffer) > self.max_samples:
                self.buffer.pop(0)

    def get(self, n_samples, device=device):
        items = random.choices(self.buffer, k=n_samples)
        samples, class_ids = zip(*items)
        samples = torch.stack(samples, 0)
        class_ids = torch.tensor(class_ids)
        samples = samples.to(device)
        class_ids = class_ids.to(device)

        return samples, class_ids

def sample_buffer(buffer, batch_size=args.batch_size, p=0.95, device=device):
    if len(buffer) < 1:
        return (
            torch.rand(batch_size, 1, 32, 32, device=device),
            torch.randint(0, 10, (batch_size,), device=device),
        )

    n_replay = (np.random.rand(batch_size) < p).sum()

    replay_sample, replay_id = buffer.get(n_replay)
    random_sample = torch.rand(batch_size - n_replay, 1, 32, 32, device=device)
    random_id = torch.randint(0, 10, (batch_size - n_replay,), device=device)

    return (
        torch.cat([replay_sample, random_sample], 0),
        torch.cat([replay_id, random_id], 0),
    )
                   
def langevin(model, inp_imgs, epochs, lr_langevin, noise_decay=1.0, add_noise=True, return_img_per_step=False):
    """
    Function for getting a new batch of "fake" images.
    Inputs:
        epochs, steps - Number of iterations in the MCMC algorithm
        lr, step_size - Learning rate nu in the algorithm above
    """
    # Before MCMC: set model parameters to "required_grad=False"
    # because we are only interested in the gradients of the input.
    is_training = model.training
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    inp_imgs.requires_grad = True

    # Enable gradient calculation if not already the case
    had_gradients_enabled = torch.is_grad_enabled()
    torch.set_grad_enabled(True)

    # We use a buffer tensor in which we generate noise each loop iteration.
    # More efficient than creating a new tensor every iteration.
    if add_noise:
        noise_scale = np.sqrt(lr_langevin) * noise_decay
        noise = 0.5 * torch.rand_like(inp_imgs) * noise_scale
    else:
        noise = 0.0

# List for storing generations at each step (for later analysis)
    imgs_per_step = []
    
    # Loop over K (steps)
    for _ in range(epochs):
        # Part 1: Add noise to the input.
        noise.normal_(0, 0.005)
        inp_imgs.data.add_(noise.data)
        inp_imgs.data.clamp_(min=-1.0, max=1.0)
        
        # Part 2: calculate gradients for the current input.
        out_imgs = -model(inp_imgs)
        out_imgs.sum().backward()
        inp_imgs.grad.data.clamp_(-0.03, 0.03) # For stabilizing and preventing too high gradients

        # Apply gradients to our current samples
        inp_imgs.data.add_(-lr_langevin * inp_imgs.grad.data)
        inp_imgs.grad.detach_()
        inp_imgs.grad.zero_()
        inp_imgs.data.clamp_(min=-1.0, max=1.0)
        
        if return_img_per_step:
            imgs_per_step.append(inp_imgs.clone().detach())
    
    # Reactivate gradients for parameters for training
    for p in model.parameters():
        p.requires_grad = True
    model.train(is_training)

    # Reset gradient calculation to setting before this function
    torch.set_grad_enabled(had_gradients_enabled)

    if return_img_per_step:
        return torch.stack(imgs_per_step, dim=0)
    else:
        return inp_imgs
    
def update_langevin(model, inp, step, lr):
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    
    inp.requires_grad = True
    noise = torch.randn(inp.shape, device=inp.device)
    
    for _ in range(step):
        noise.normal_(0, 0.005)
        inp.data.add_(noise.data)
        inp.data.clamp_(min=-1.0, max=1.0)
        
        out = -model(inp)
        out.sum().backward()
        
        inp.grad.data.add_(-0.5*lr*inp.grad.data)
        inp.grad.detach_()
        inp.grad.zero_()
        inp.data.clamp_(min=-1.0, max=1.0)
    
    for p in model.parameters():
        p.requires_grad = True
    model.train()
    
    return inp    
    
    
def sample_langevin(x, model, stepsize, n_steps, noise_scale=None, intermediate_samples=False):
    """Draw samples using Langevin dynamics
    x: torch.Tensor, initial points
    model: An energy-based model
    stepsize: float
    n_steps: integer
    noise_scale: Optional. float. If None, set to np.sqrt(stepsize * 2)
    """
    if noise_scale is None:
        noise_scale = np.sqrt(stepsize * 2)

    l_samples = []
    l_dynamics = []
    x.requires_grad = True
    for _ in range(n_steps):
        l_samples.append(x.detach().to('cpu'))
        noise = torch.randn_like(x) * noise_scale
        out = model(x)
        grad = torch.autograd.grad(out.sum(), x, only_inputs=True)[0]
        dynamics = stepsize * grad + noise
        x = x + dynamics
        l_samples.append(x.detach().to('cpu'))
        l_dynamics.append(dynamics.detach().to('cpu'))

    if intermediate_samples:
        return l_samples, l_dynamics
    else:
        return l_samples[-1]
    
def clip_grad(parameters, optimizer):
    with torch.no_grad():
        for group in optimizer.param_groups:
            for p in group['params']:
                state = optimizer.state[p]

                if 'step' not in state or state['step'] < 1:
                    continue

                step = state['step']
                exp_avg_sq = state['exp_avg_sq']
                _, beta2 = group['betas']

                bound = 3 * torch.sqrt(exp_avg_sq / (1 - beta2 ** step)) + 0.1
                p.grad.data.copy_(torch.max(torch.min(p.grad.data, bound), -bound))

In [42]:
en = CNNModel(args)
image, _ = next(iter(dl_train))
z = torch.randn(image.shape) # randn_like(image) 와 동일
z.requires_grad = True
# print(torch.max(z), torch.min(z), torch.mean(z)) # tensor(4.4663) tensor(-4.3208) tensor(-0.0002)
out = en(z)
print(out.sum())
out.sum().backward()
print(z.grad.data.sum(), z.grad.sum(()))



tensor(4.9248, grad_fn=<SumBackward0>)
tensor(0.0221) tensor(0.0221)


In [3]:
import torchsummary
en = CNNModel(args).to(device)

print(torchsummary.summary(en, (1,32,32)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 16, 16]             320
             Swish-2           [-1, 32, 16, 16]               0
            Conv2d-3             [-1, 64, 8, 8]          18,496
             Swish-4             [-1, 64, 8, 8]               0
            Conv2d-5            [-1, 128, 4, 4]          73,856
             Swish-6            [-1, 128, 4, 4]               0
            Conv2d-7            [-1, 256, 2, 2]         295,168
             Swish-8            [-1, 256, 2, 2]               0
           Flatten-9                 [-1, 1024]               0
           Linear-10                  [-1, 256]         262,400
            Swish-11                  [-1, 256]               0
           Linear-12                    [-1, 1]             257
Total params: 650,497
Trainable params: 650,497
Non-trainable params: 0
-------------------------------

In [4]:
energy = CNNModel(args).to(device)

optimizer = optim.Adam(energy.parameters(), lr=args.lr_energy_model, betas=(args.b1, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=sched_step_size, gamma=sched_gamma) # Exponential decay over epochs


In [5]:
# noise = torch.randn(args.batch_size, 1, 32, 32) * args.gaussian_noise
# noise = noise.to(device)
# buffer = SampleBuffer()
sampler = Sampler(energy, (1,32,32), args.batch_size)

val_loss_energy_model_mean = np.zeros(args.epochs)
val_psnr_mean              = np.zeros(args.epochs)
val_psnr_langevin_mean     = np.zeros(args.epochs)


def to_numpy(tensor: torch.Tensor) -> np.ndarray:
  return tensor.detach().cpu().numpy()

for i in range(args.epochs):
  val_loss_energy_model = list()
  val_loss_cd = list()
  val_loss_reg = list()
  val_psnr              = list()
  val_psnr_langevin     = list()
  val_fake = list()
  val_real = list()

  for j, (image, _) in enumerate(iter(dl_train)):

    noise = torch.randn_like(image) * args.gaussian_noise
    image.add_(noise).clamp_(min=-1.0, max=1.0)
    real_imgs = image.to(device)
    
    fake_imgs = sampler.sample_new_exmps(steps=60, step_size=10)
    # Predict energy score for all images
    
    optimizer.zero_grad()
    inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)
    real_out, fake_out = energy(inp_imgs).chunk(2, dim=0)
    
    # Calculate losses
    reg_loss = args.L2_reg_w * (real_out ** 2 + fake_out ** 2).mean()
    cdiv_loss = fake_out.mean() - real_out.mean()
    loss = reg_loss + cdiv_loss
    loss.backward()
    scheduler.step()
    
    # val_ = to_numpy(PSNR(fake_out, image))
    val_lan = to_numpy(PSNR(fake_imgs, real_imgs))
    
    # Logging
    val_loss_energy_model.append(loss.item())
    val_loss_cd.append(cdiv_loss.item())
    val_loss_reg.append(reg_loss.item())
    # val_psnr.append(val_)
    val_psnr_langevin.append(val_lan)
    val_real.append(real_out.detach().to(torch.device('cpu')).mean())
    val_fake.append(fake_out.detach().to(torch.device('cpu')).mean())
    
  log = '[%4d/%4d] loss=%5.3f, loss_contrastive_divergence=%5.3f, loss_regularization=%5.3f, psnr=%5.3f, lan_psnr=%5.3f, metrics_avg_real=%5.3f, metrics_avg_fake=%5.3f' % (i, args.epochs, 
                       np.mean(val_loss_energy_model), np.mean(val_loss_cd), np.mean(val_loss_reg), 
                       np.mean(val_psnr), np.mean(val_psnr_langevin), 
                       np.mean(val_real), np.mean(val_fake))
  print(log, flush=True)



[   0/ 100] loss=28.658, loss_contrastive_divergence=12.633, loss_regularization=16.025, psnr=  nan, lan_psnr=3.470, metrics_avg_real=0.017, metrics_avg_fake=12.650


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


[   1/ 100] loss=29.076, loss_contrastive_divergence=12.751, loss_regularization=16.324, psnr=  nan, lan_psnr=3.462, metrics_avg_real=0.018, metrics_avg_fake=12.769
[   2/ 100] loss=29.263, loss_contrastive_divergence=12.804, loss_regularization=16.458, psnr=  nan, lan_psnr=3.459, metrics_avg_real=0.017, metrics_avg_fake=12.822
[   3/ 100] loss=29.418, loss_contrastive_divergence=12.848, loss_regularization=16.570, psnr=  nan, lan_psnr=3.459, metrics_avg_real=0.018, metrics_avg_fake=12.865
[   4/ 100] loss=29.544, loss_contrastive_divergence=12.883, loss_regularization=16.661, psnr=  nan, lan_psnr=3.459, metrics_avg_real=0.017, metrics_avg_fake=12.900
[   5/ 100] loss=29.525, loss_contrastive_divergence=12.878, loss_regularization=16.648, psnr=  nan, lan_psnr=3.458, metrics_avg_real=0.017, metrics_avg_fake=12.895
[   6/ 100] loss=29.588, loss_contrastive_divergence=12.895, loss_regularization=16.693, psnr=  nan, lan_psnr=3.457, metrics_avg_real=0.018, metrics_avg_fake=12.913
[   7/ 100

In [None]:
    neg_img, _ = sample_buffer(buffer, image.shape[0]) # random variables
    neg_img.requires_grad = True
    # fake_imgs = langevin(energy, latent, 10, list_lr_langevin[i])
    for p in energy.parameters():
        p.requires_grad = False
    energy.eval()
    
    for _ in range(args.number_step_langevin):
        if noise.shape[0] != neg_img.shape[0]:
            noise = torch.randn(image.shape[0], 1, 32, 32, device=device)
                
        noise.normal_(0, 0.005)
        neg_img.data.add_(noise.data)
        
        neg_out = -energy(neg_img)
        neg_out.sum().backward()
        neg_img.grad.data.clamp_(-0.03, 0.03)

        neg_img.data.add_(-list_lr_langevin[i] * neg_img.grad.data)

        neg_img.grad.detach_()
        neg_img.grad.zero_()

        neg_img.data.clamp_(0, 1)
        
    neg_img = neg_img.detach()
    for p in energy.parameters():
        p.requires_grad = True
        
    energy.train()
    # fake_imgs = update_langevin(energy, latent, args.number_step_langevin, list_lr_langevin[i])
    optimizer.zero_grad()
    
    pos_out = energy(image)
    neg_out = energy(neg_img)

    loss = args.energy_L2_reg_weight * (pos_out ** 2 + neg_out ** 2)
    loss = loss + ( -pos_out + neg_out)
    loss = loss.mean()
    loss.backward()
    
    clip_grad(energy.parameters(), optimizer)
    
    value_psnr = to_numpy(PSNR(image, neg_img))
    scheduler.step()
    
    val_loss_energy_model.append(loss.item())
    val_psnr.append(value_psnr)
    
  loss = np.mean(val_loss_energy_model)
  val_psnr_mean = np.mean(val_psnr)
    

  log = '[%4d/%4d] loss=%5.3f, psnr=%5.3f' % (i, args.epochs, loss, val_psnr_mean)
  print(log, flush=True)
  
  if i % args.save_plot == 19:
    nRow    = 4 
    nCol    = 4
    fSize   = 3

    fig, ax = plt.subplots(nRow, nCol, figsize=(fSize * nCol, fSize * nRow))

    for r in range(2): 
        for c in range(nCol):
            ax[r+0][c].set_title('data')
            if args.in_channel == 1: 
                p = ax[r+0][c].imshow(image[r*nCol+c].cpu().numpy().squeeze(), cmap='gray')
            else:
                p = ax[r+0][c].imshow(image[r*nCol+c].cpu().numpy().permute(1,2,0))
            plt.colorbar(p, ax=ax[r+0][c])
    
    for r in range(2):    
        for c in range(nCol):
            ax[r+2][c].set_title('fake')
            if args.in_channel == 1: 
                p = ax[r+2][c].imshow(neg_img[r*nCol+c].detach().cpu().numpy().squeeze(), cmap='gray')
            else:
                p = ax[r+2][c].imshow(neg_img[r*nCol+c].detach().cpu().numpy().permute(1,2,0))
            plt.colorbar(p, ax=ax[r+2][c])

    plt.tight_layout()
    plt.show()
    plt.close(fig)

In [75]:
for i in range(args.epochs):
  val_loss_energy_model = list()
  val_psnr              = list()
  val_psnr_langevin     = list()

  for j, (image, _) in enumerate(iter(dl_train)):

    image = image.to(device)
    noise = torch.randn_like(image) * args.gaussian_noise
    noisy_image = image.add_(noise).clamp(min=-1.0, max=1.0).to(device)
    
    latent = torch.randn_like(image.clone()).to(device)
    
    fake_imgs = langevin(energy, latent, 10, list_lr_langevin[i])
    # fake_imgs = update_langevin(energy, latent, args.number_step_langevin, list_lr_langevin[i])
    fake_imgs = fake_imgs.to(device)
    input_imgs = torch.cat([noisy_image, fake_imgs], dim=0)
    
    real, fake = energy(input_imgs).chunk(2, dim=0)
    
    energy.zero_grad()
    reg_loss = args.energy_L2_reg_weight * (real**2 + fake**2).sum()
    cd_loss = fake.sum() - real.sum()
    loss = cd_loss + reg_loss
    
    loss.backward()
    
    value_psnr = to_numpy(PSNR(image, fake_imgs))
    scheduler.step()
    
    val_loss_energy_model.append(loss.item())
    val_psnr.append(value_psnr)
    
  loss = np.mean(val_loss_energy_model)
  val_psnr_mean = np.mean(val_psnr)
    

  log = '[%4d/%4d] loss=%5.3f, psnr=%5.3f' % (i, args.epochs, loss, val_psnr_mean)
  print(log, flush=True)
  
  if i % args.save_plot == 19:
    nRow    = 4 
    nCol    = 4
    fSize   = 3

    fig, ax = plt.subplots(nRow, nCol, figsize=(fSize * nCol, fSize * nRow))

    for r in range(2): 
        for c in range(nCol):
            ax[r+0][c].set_title('data')
            if args.in_channel == 1: 
                p = ax[r+0][c].imshow(image[r*nCol+c].cpu().numpy().squeeze(), cmap='gray')
            else:
                p = ax[r+0][c].imshow(image[r*nCol+c].cpu().numpy().permute(1,2,0))
            plt.colorbar(p, ax=ax[r+0][c])
    
    for r in range(2):    
        for c in range(nCol):
            ax[r+2][c].set_title('fake')
            if args.in_channel == 1: 
                p = ax[r+2][c].imshow(fake_imgs[r*nCol+c].detach().cpu().numpy().squeeze(), cmap='gray')
            else:
                p = ax[r+2][c].imshow(fake_imgs[r*nCol+c].detach().cpu().numpy().permute(1,2,0))
            plt.colorbar(p, ax=ax[r+2][c])

    plt.tight_layout()
    plt.show()
    plt.close(fig)

In [76]:
n_new = np.random.binomial(args.batch_size, 0.05)
rand_imgs = torch.rand((n_new,) + (1,32,32)) * 2 - 1
old_imgs = torch.cat(random.choices(examples, k=args.batch_size-n_new), dim=0)
inp_imgs = torch.cat([rand_imgs, old_imgs], dim=0).detach().to(device)

In [77]:
inp_imgs.shape

torch.Size([100, 1, 32, 32])