<a href="https://colab.research.google.com/github/RossM/machine-learning-colabs/blob/main/pix2pix_denoising.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

## GPU status

In [None]:
import subprocess
# Colab form toggle: True prints only the GPU list (`nvidia-smi -L`); False prints the
# full `nvidia-smi` status table, then the output of `nvidia-smi -i 0 -e 0`.
simple_nvidia_smi_display = False#@param {type:"boolean"}
if simple_nvidia_smi_display:
    #!nvidia-smi
    nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(nvidiasmi_output)
else:
    #!nvidia-smi -i 0 -e 0
    nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(nvidiasmi_output)
    # NOTE(review): `-i 0 -e 0` asks nvidia-smi to change the ECC mode of GPU 0; only stdout
    # is captured and printed here (stderr is not redirected) — confirm this is intended.
    nvidiasmi_ecc_note = subprocess.run(['nvidia-smi', '-i', '0', '-e', '0'], stdout=subprocess.PIPE).stdout.decode('utf-8')
    print(nvidiasmi_ecc_note)

## Set up Google Drive

In [None]:
# Mount Google Drive at /content/drive; the configuration cell stores checkpoints
# under /content/drive/MyDrive/..., so this must run before training.
from google.colab import drive
drive.mount('/content/drive')

## Download the data

In [None]:
# Local directory holding the training images; used as the ImageFolder root in the Train cell.
DATA_DIR = "/content/danbooru2021/images"

# Copy only the 512px subdirectories whose names match `000*` from the rsync mirror.
# --size-only skips files already present with the same size, so re-running the cell
# does not re-download them.
!mkdir -p $DATA_DIR
!rsync --progress --recursive --size-only --verbose rsync://176.9.41.242:873/danbooru2021/512px/000*/ $DATA_DIR/512px/


## Install dependencies

In [None]:
import os

# On a Colab TPU runtime (COLAB_TPU_ADDR is set), install pinned torch / torch_xla wheels.
# NOTE(review): the pins target torch 1.11 and a CPython 3.7 (cp37) wheel; current Colab
# runtimes likely ship a different Python version — confirm before relying on the TPU path.
if 'COLAB_TPU_ADDR' in os.environ:
  !pip install cloud-tpu-client==0.10 torch==1.11.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl
# einops and denoising-diffusion-pytorch are imported by the model cell.
# NOTE(review): dominate and wandb are not used in the visible cells.
!pip install dominate wandb torchvision einops denoising-diffusion-pytorch

## Define our model

In [None]:
import os, functools
import torch
import torch.nn as nn
import einops
from torch.nn import init
from torch.optim import lr_scheduler
import denoising_diffusion_pytorch.denoising_diffusion_pytorch as dd

# Based on code from pytorch-pix2pix https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

def default(val, d):
    """Return ``val`` unless it is None, otherwise the fallback ``d``.

    If ``d`` is a plain Python function (``inspect.isfunction``) it is called with no
    arguments to produce the fallback lazily; any other object (including classes and
    builtins) is returned as-is.
    """
    # `isfunction` was referenced without ever being imported, so every call with
    # val=None raised NameError; import it locally to keep the fix self-contained.
    from inspect import isfunction

    if val is not None:
        return val
    return d() if isfunction(d) else d


def init_weights(net, init_type='normal', init_gain=0.02):
    """Re-initialize the weights of every conv/linear and BatchNorm2d layer in ``net``.

    Parameters:
        net (network)      -- network whose submodules are visited with ``net.apply``
        init_type (str)    -- normal | xavier | kaiming | orthogonal
        init_gain (float)  -- std (normal) or gain (xavier, orthogonal); kaiming ignores it

    Conv/Linear layers (matched by class name) get their weight re-drawn and their bias
    zeroed. BatchNorm2d layers get weight ~ N(1, init_gain) and a zero bias. An unknown
    ``init_type`` raises NotImplementedError when the first conv/linear layer is reached.
    'normal' matches the original pix2pix/CycleGAN papers; xavier and kaiming may work
    better for some applications.
    """
    weight_initializers = {
        'normal': lambda w: init.normal_(w, 0.0, init_gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=init_gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=init_gain),
    }

    def init_func(module):
        layer_name = type(module).__name__
        is_weighted_layer = 'Conv' in layer_name or 'Linear' in layer_name
        if is_weighted_layer and hasattr(module, 'weight'):
            initializer = weight_initializers.get(init_type)
            if initializer is None:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            initializer(module.weight.data)
            if getattr(module, 'bias', None) is not None:
                init.constant_(module.bias.data, 0.0)
        elif 'BatchNorm2d' in layer_name:
            # BatchNorm weights are a vector, so only the normal distribution applies.
            init.normal_(module.weight.data, 1.0, init_gain)
            init.constant_(module.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, devices=None, dtype=torch.float):
    """Initialize a network: cast it, place it on a device, and initialize its weights.

    Parameters:
        net (network)       -- the network to be initialized
        init_type (str)     -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)   -- scaling factor for normal, xavier and orthogonal.
        devices (list)      -- torch.device objects or device strings; defaults to ["cpu"].
                               For a non-CPU devices[0] the network is moved there and
                               wrapped in torch.nn.DataParallel over ``devices``.
        dtype (torch.dtype) -- dtype the network's parameters and buffers are cast to.

    Return an initialized network (a DataParallel wrapper for non-CPU devices).
    """
    if devices is None:
        devices = ["cpu"]
    net.to(dtype)
    # Compare the device *type*: a torch.device never compares equal to the string
    # "cpu", so the old `devices[0] != "cpu"` check sent CPU models down the
    # DataParallel path (which rejects CPU device ids when CUDA is available).
    if torch.device(devices[0]).type != "cpu":
        net.to(devices[0])
        net = torch.nn.DataParallel(net, devices)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net

def get_scheduler(optimizer, opt):
    """Return a linear-decay learning rate scheduler (torch LambdaLR) for ``optimizer``.

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- experiment flags; reads opt.epoch_count, opt.n_epochs
                              and opt.n_epochs_decay.

    The learning-rate multiplier stays at 1.0 for the first <opt.n_epochs> epochs
    (offset by <opt.epoch_count>). It then decreases linearly by 1/(n_epochs_decay + 1)
    per epoch and is clamped at 0, so stepping past the decay window never yields a
    negative learning rate. Only this 'linear' policy is implemented; opt.lr_policy,
    if present, is not read.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    def lambda_rule(epoch):
        # Epochs elapsed since the constant-LR phase ended (0 while still inside it).
        decay_epochs = max(0, epoch + opt.epoch_count - opt.n_epochs)
        return max(0.0, 1.0 - decay_epochs / float(opt.n_epochs_decay + 1))

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

##############################################################################
# NN-Blocks
##############################################################################

class SequentialExt(nn.Sequential):
    """nn.Sequential whose first module may take several positional inputs.

    Every positional argument of ``forward`` is passed to the first module; each
    later module receives only the output of the module before it.
    """

    def forward(self, *args):
        layers = list(self)
        head, tail = layers[0], layers[1:]
        return functools.reduce(lambda acc, layer: layer(acc), tail, head(*args))

def nameof(obj):
    """Return ``obj.__name__`` when the attribute exists, else ``str(obj)``."""
    missing = object()
    name = getattr(obj, '__name__', missing)
    return str(obj) if name is missing else name

class Residual(nn.Sequential):
    """Run the wrapped modules in sequence, then merge the result with the input.

    By default the merge is ``output + input`` (a residual connection). If ``reduce``
    is given, it is called as ``reduce(output, input)`` instead, e.g. to concatenate.
    """

    def __init__(self, *args, reduce=None):
        super().__init__(*args)
        # Stored as a plain attribute; None selects the default addition.
        self.reduce = reduce

    def forward(self, input):
        x = input
        for module in self:
            x = module(x)
        if self.reduce is not None:
            return self.reduce(x, input)
        return x + input

    def extra_repr(self):
        # nn.Module.extra_repr must return a string; '' means "nothing extra to show".
        if self.reduce is not None:
            return f"reduce={nameof(self.reduce)}"
        return ''

def Bypass(*args, dim=1):
    """Residual variant that concatenates the wrapped modules' output with the input.

    The modules' output comes first and the original input second, joined along ``dim``.
    """
    # Kept as a named inner function (not a lambda) so the Residual repr shows `reduce=cat`.
    def cat(*parts):
        joined = torch.cat(parts, dim=dim)
        return joined

    bypass = Residual(*args, reduce=cat)
    return bypass

class Sum(nn.ModuleList):
    """Apply every wrapped module to the same input and return the sum of the outputs."""

    def __init__(self, *args):
        super().__init__(args)

    def forward(self, input):
        modules = list(self)
        x = modules[0](input)
        for module in modules[1:]:
            # Out-of-place add: the first output may *be* the input (e.g. nn.Identity)
            # or be saved by autograd (e.g. Sigmoid), so `x += ...` would corrupt it.
            x = x + module(input)
        return x

def Downsample(dim, dim_out=None):
    """Strided 4x4 convolution that halves H and W (for even sizes).

    ``dim_out`` defaults to ``dim`` via the ``default`` helper.
    """
    out_channels = default(dim_out, dim)
    return nn.Conv2d(dim, out_channels, kernel_size=4, stride=2, padding=1)

def Upsample(dim, dim_out=None):
    """Transposed 4x4 convolution with stride 2 that doubles H and W.

    ``dim_out`` defaults to ``dim`` via the ``default`` helper.
    """
    out_channels = default(dim_out, dim)
    return nn.ConvTranspose2d(dim, out_channels, kernel_size=4, stride=2, padding=1)

def Block(dim, dim_out, *, kernel_size=3, groups=8):
    """Conv2d -> GroupNorm -> SiLU.

    Padding is ``kernel_size // 2``, so odd kernel sizes keep H and W unchanged.
    ``dim_out`` must be divisible by ``groups`` (a GroupNorm requirement).
    """
    layers = [
        nn.Conv2d(dim, dim_out, kernel_size=kernel_size, padding=kernel_size // 2),
        nn.GroupNorm(groups, dim_out),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)

def ResnetBlock(dim, dim_out, *, kernel_size=3, groups=8):
    """Two Conv-GroupNorm-SiLU blocks with a skip connection.

    When the channel count is unchanged the skip is the identity (Residual); otherwise
    a 1x1 convolution projects the input to ``dim_out`` channels and the two paths are
    summed (Sum).
    """
    main_path = [
        Block(dim, dim_out, kernel_size=kernel_size, groups=groups),
        Block(dim_out, dim_out, kernel_size=kernel_size, groups=groups),
    ]
    if dim == dim_out:
        return Residual(*main_path)
    return Sum(nn.Sequential(*main_path), nn.Conv2d(dim, dim_out, 1))

class LinearAttention(nn.Module):
    """Linear attention over the spatial positions of a feature map.

    ``heads`` query projections share a single key/value projection of width
    ``attention_dim``. Instead of a positions x positions attention matrix, an
    attention_dim x attention_dim context is accumulated over all positions, so the
    cost grows linearly with H*W.

    Parameters:
        dim (int)           -- number of input and output channels
        heads (int)         -- number of query heads
        attention_dim (int) -- per-head query width, and the width of the shared key/value
    """
    def __init__(self, dim, heads=4, attention_dim=32):
        super().__init__()
        self.scale = attention_dim ** -0.5
        self.heads = heads
        # One key and one value projection (attention_dim channels each), shared by all heads.
        self.to_kv = nn.Conv2d(dim, attention_dim * 2, 1, bias=False)
        # `heads` query projections stacked along the channel axis.
        self.to_q = nn.Conv2d(dim, attention_dim * heads, 1, bias=False)
        self.to_out = nn.Conv2d(attention_dim * heads, dim, 1)

    def forward(self, x):
        # k, v: (b, attention_dim, X, Y) each, split from the to_kv output channels.
        k, v = self.to_kv(x).chunk(2, dim=1)
        # q: (b, heads, attention_dim, X, Y)
        q = einops.rearrange(self.to_q(x), 'b (h d) x y -> b h d x y', h=self.heads)

        # Both softmaxes normalise over the feature (attention_dim) axis.
        # NOTE(review): many linear-attention variants softmax k over spatial positions
        # instead; here the context below is an unnormalised sum over all X*Y positions,
        # so its magnitude grows with image size. Confirm this is intended — changing it
        # would alter the behaviour of trained checkpoints.
        k = k.softmax(dim = 1)
        q = q.softmax(dim = 2) * self.scale

        # context[b, d, e] = sum over positions of k[b, d, pos] * v[b, e, pos]
        context = torch.einsum('b d x y, b e x y -> b d e', k, v)
        # Each head reads the shared context using its own query weights.
        out = torch.einsum('b d e, b h d x y -> b h e x y', context, q)
        out = einops.rearrange(out, 'b h e x y -> b (h e) x y')

        return self.to_out(out)

def Attention(dim, attention_dim=32):
    """InstanceNorm -> LinearAttention -> InstanceNorm, added back onto the input.

    Both InstanceNorm2d layers have learnable affine parameters.
    """
    layers = (
        nn.InstanceNorm2d(dim, affine=True),
        LinearAttention(dim, attention_dim=attention_dim),
        nn.InstanceNorm2d(dim, affine=True),
    )
    return Residual(*layers)

class UNetBlock(nn.Sequential):
    """One U-Net level: process at ``ch`` channels, recurse one level down, merge.

    Child layout (nn.Sequential indices appear in state_dict keys, so keep this order
    or saved checkpoints will no longer load):
        0-1  ResnetBlock(ch, ch) x2
        2    skip: Downsample(ch -> inner_ch, halves H/W) -> Attention -> *inner_blocks
             -> Upsample(inner_ch -> ch, doubles H/W), concatenated with this level's
             input along channels (2*ch channels out)
        3    ResnetBlock(2*ch -> ch)
        4    ResnetBlock(ch, ch)
        5    Attention(ch)

    Parameters:
        ch (int)            -- channels at this level
        inner_ch (int)      -- channels after downsampling
        attention_dim (int) -- width passed to every Attention layer
        *inner_blocks       -- modules run at the lower resolution (e.g. a nested UNetBlock)
    """
    @staticmethod
    def cat(x, y):
        # Channel-wise concatenation used as the Residual `reduce` function.
        return torch.cat((x, y), dim=1)

    def __init__(self, ch, inner_ch, attention_dim, *inner_blocks):
        super().__init__(
            ResnetBlock(ch, ch),
            ResnetBlock(ch, ch),
            Residual(
                Downsample(ch, inner_ch),
                Attention(inner_ch, attention_dim),
                *inner_blocks,
                Upsample(inner_ch, ch),
                reduce=self.cat
            ),
            #nn.Conv2d(ch*2, ch, 1),
            # The concatenation doubles the channels; this ResnetBlock projects back to ch.
            ResnetBlock(ch*2, ch),
            ResnetBlock(ch, ch),
            Attention(ch, attention_dim),
        )


#class SoftClamp(nn.Module):
#    def __init__(self, min=None, max=None):
#        super().__init__(self)
#        self.min = min
#        self.max = max
#        self.register_full_backward_hook(self._backward_hook)
#
#    def forward(self, input):
#        return input.clamp(-1, 1)
#
#    def _backward_hook(self, module, grad_input, grad_output):
#        return grad_output
#
#    def extra_repr(self):
#        return str(f"min: {self.min}, max {self.max}")

# This acts like clamp(-1, 1) but passes through gradients unchanged
#class SoftClampFn(torch.autograd.Function):
#    staticmethod
#    def forward(ctx, input):
#        return input.clamp(-1, 1)
#
#    staticmethod
#    def backward(ctx, grad_output):
#        return grad_output
#
#softclamp = SoftClampFn.apply

##############################################################################
# Classes
##############################################################################
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, device=None, dtype=torch.float):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str)            -- the GAN objective: 'vanilla', 'lsgan' or 'wgangp'.
            target_real_label (float) -- label value used for real images
            target_fake_label (float) -- label value used for fake images
            device                    -- device on which the label buffers are created
            dtype (torch.dtype)       -- dtype of the label buffers

        Raises:
            NotImplementedError -- if gan_mode is not one of the supported modes.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super().__init__()
        # Buffers follow the module through .to() and appear in its state_dict.
        self.register_buffer('real_label', torch.tensor(target_real_label, device=device).to(dtype))
        self.register_buffer('fake_label', torch.tensor(target_fake_label, device=device).to(dtype))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create a label tensor with the same size as the input.

        Parameters:
            prediction (tensor)   -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground-truth label is for real or fake images

        Returns:
            A label tensor filled with the ground-truth label, broadcast (via expand_as)
            to the size of the input.
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """Calculate the loss given the discriminator's output and the ground-truth label.

        Implemented as ``forward`` (rather than overriding ``__call__``) so calls go
        through nn.Module.__call__ and registered hooks fire.

        Parameters:
            prediction (tensor)   -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground-truth label is for real or fake images

        Returns:
            the calculated loss (a scalar tensor).
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            return self.loss(prediction, target_tensor)
        # 'wgangp' (the constructor rejects every other mode): maximise the critic
        # score for real samples, minimise it for fake ones.
        if target_is_real:
            return -prediction.mean()
        return prediction.mean()

class Pix2PixModel:
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

    The model training requires '--dataset_mode aligned' dataset.
    By default, it uses a '--netG unet256' U-Net generator,
    a '--netD basic' discriminator (PatchGAN),
    and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """
    def __init__(self, opt):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        # Base model
        self.opt = opt
        self.isTrain = opt.isTrain
        self.devices = opt.devices
        self.dtype = opt.dtype
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        torch.backends.cudnn.benchmark = True
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']
        # define networks (both generator and discriminator)
        self.netG = nn.Sequential(
            nn.Conv2d(opt.input_nc, opt.ngf, 7, padding=3),
            ResnetBlock(opt.ngf, opt.ngf),
            ResnetBlock(opt.ngf, opt.ngf),
            Attention(opt.ndf, opt.attention_dim),
            UNetBlock(opt.ngf, opt.ngf * 2, opt.attention_dim,
                UNetBlock(opt.ngf * 2, opt.ngf * 4, opt.attention_dim,
                    UNetBlock(opt.ngf * 4, opt.ngf * 8, opt.attention_dim,
                        ResnetBlock(opt.ngf * 8, opt.ngf * 8),
                        ResnetBlock(opt.ngf * 8, opt.ngf * 8),
                        Attention(opt.ngf * 8, opt.attention_dim),
                        ResnetBlock(opt.ngf * 8, opt.ngf * 8),
                        ResnetBlock(opt.ngf * 8, opt.ngf * 8),
                        Attention(opt.ngf * 8, opt.attention_dim),
                    )
                )
            ),
            ResnetBlock(opt.ngf, opt.ngf),
            ResnetBlock(opt.ngf, opt.ngf),
            nn.Conv2d(opt.ngf, opt.output_nc, 1)
        )
        self.netG = init_net(self.netG, opt.init_type, opt.init_gain, opt.devices, opt.dtype)

        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
            self.netD = nn.Sequential(
                nn.Conv2d(opt.input_nc+opt.output_nc, opt.ndf, 7, padding=3),
                ResnetBlock(opt.ndf, opt.ndf, kernel_size=7),
                ResnetBlock(opt.ndf, opt.ndf, kernel_size=7),
                Attention(opt.ndf, opt.attention_dim),
                ResnetBlock(opt.ndf, opt.ndf, kernel_size=7),
                ResnetBlock(opt.ndf, opt.ndf, kernel_size=7),
                Downsample(opt.ndf, opt.ndf*2),
                Attention(opt.ndf*2, opt.attention_dim),
                ResnetBlock(opt.ndf*2, opt.ndf*2),
                ResnetBlock(opt.ndf*2, opt.ndf*2),
                Downsample(opt.ndf*2, opt.ndf*4),
                Attention(opt.ndf*4, opt.attention_dim),
                ResnetBlock(opt.ndf*4, opt.ndf*4),
                ResnetBlock(opt.ndf*4, opt.ndf*4),
                Downsample(opt.ndf*4, opt.ndf*8),
                Attention(opt.ndf*8, opt.attention_dim),
                ResnetBlock(opt.ndf*8, opt.ndf*8),
                ResnetBlock(opt.ndf*8, opt.ndf*8),
                Attention(opt.ndf*8, opt.attention_dim),
                nn.Conv2d(opt.ndf*8, opt.ndf*8, 1),
                nn.SiLU(),
                nn.Conv2d(opt.ndf*8, 1, 1),
            )
            self.netD = init_net(self.netD, opt.init_type, opt.init_gain, opt.devices, opt.dtype)

        if self.isTrain:
            # define loss functions
            self.criterionGAN = GANLoss(opt.gan_mode, device=self.devices[0], dtype=self.dtype)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionMSE = torch.nn.MSELoss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999), amsgrad=opt.amsgrad)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999), amsgrad=opt.amsgrad)
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()
        self.pred_fake = pred_fake
        self.pred_real = pred_real

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        self.loss_G_MSE = self.criterionMSE(self.fake_B, self.real_B) * self.opt.lambda_MSE
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1 + self.loss_G_MSE
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()                   # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()     # set D's gradients to zero
        self.backward_D()                # calculate gradients for D
        self.optimizer_D.step()          # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()        # set G's gradients to zero
        self.backward_G()                   # calculate graidents for G
        self.optimizer_G.step()             # udpate G's weights

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        old_lr = self.optimizers[0].param_groups[0]['lr']
        for scheduler in self.schedulers:
            scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate %.7f -> %.7f' % (old_lr, lr))

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                    
                torch.save(net.state_dict(), save_path)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=self.devices[0])
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                    if param.dtype != self.dtype:
                      print(f"Warning! Parameter has wrong dtype {param.dtype}, {param.device}, {param.numel()}. Expected {self.dtype}")
                    if str(param.device) != str(self.opt.devices[0]):
                      print(f"Warning! Parameter has wrong device {param.dtype}, {param.device}, {param.numel()}. Expected {self.opt.devices[0]}")
                for buffer in net.buffers():
                    if buffer.dtype != self.dtype:
                      print(f"Warning! Buffer has wrong dtype {buffer.dtype}, {buffer.device}, {buffer.numel()}. Expected {self.dtype}")
                    if str(buffer.device) != str(self.opt.devices[0]):
                      print(f"Warning! Buffer has wrong device {buffer.dtype}, {buffer.device}, {buffer.numel()}. Expected {self.opt.devices[0]}")
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def to(self, *args):
      self.netG.to(*args)
      self.netD.to(*args)


# Model configuration

In [None]:
# Plain attribute container for the experiment options (read elsewhere as `opt.<name>`).
class TrainOptions():
    pass

opt = TrainOptions()
# Checkpoints are written to <checkpoints_dir>/<name>/ on the mounted Google Drive.
opt.checkpoints_dir = '/content/drive/MyDrive/AI/pix2pix/checkpoints'
opt.name = "denoising9-medium"  #@param {type: "string"}

#@markdown ## Model parameters
#@markdown **Changing these will invalidate checkpoint files!**
# Noisy input and denoised output are both 3-channel (RGB) images.
opt.input_nc = opt.output_nc = 3
# Base channel widths of the generator (ngf) and discriminator (ndf), and the attention width.
opt.ngf = 64              #@param {type: "integer"}
opt.ndf = 64              #@param {type: "integer"}
opt.attention_dim = 32    #@param {type: "integer"}

#@markdown ## Training parameters
# Weight initialisation passed to init_net / init_weights.
opt.init_type = "normal"
opt.init_gain = 0.02
opt.gan_mode = "vanilla"    #@param ["vanilla", "lsgan"]
opt.beta1 = 0.9           #@param {type: "number"}
opt.lr = 0.0002           #@param {type: "number"}
opt.amsgrad = True        #@param {type: "boolean"}
# epoch_count / n_epochs / n_epochs_decay parameterise the linear LR schedule (get_scheduler).
opt.epoch_count = 1
opt.n_epochs = 100
opt.n_epochs_decay = 100
opt.batch_size = 4        #@param {type: "integer"}
# The training cell uses noise variance (1 - noise_schedule_beta) ** t for t in [0, 1000).
opt.noise_schedule_beta = 0.003   #@param {type: "number"}

#@markdown ## Objective parameters
# Weights of the L1 and MSE reconstruction terms added to the generator's GAN loss.
opt.lambda_L1 = 5       #@param {type: "number"}
opt.lambda_MSE = 50    #@param {type: "number"}

#@markdown ## Debugging
opt.verbose = False       #@param {type: 'boolean'}
opt.display_interval = 10 #@param {type: "integer"}
opt.debug_xla = False      #@param {type: 'boolean'}

#@markdown ## Backend selection
opt.backend = 'CUDA' #@param ["CPU", "CUDA", "TPU"]
opt.dtype = 'float32' #@param ["float32", "bfloat16"]

# Resolve the backend name to a single-element device list (opt.devices[0] is used everywhere).
if opt.backend == 'CUDA':
  torch.cuda.set_device(0)
  opt.devices = [torch.device('cuda:0')]
elif opt.backend == 'TPU':
  # torch_xla and xm are also used later by the Train cell (TPU sampler, XLA debug output).
  import torch_xla
  import torch_xla.core.xla_model as xm
  opt.devices = [xm.xla_device()]
else:
  opt.devices = [torch.device('cpu')]

# Replace the dtype name chosen in the form with the corresponding torch dtype.
if opt.dtype == "bfloat16":
  opt.dtype = torch.bfloat16
else:
  opt.dtype = torch.float



# Train

In [None]:
import math, random
import os
import torch
import torchvision as tv
import einops
from torchvision.transforms.functional import to_pil_image
from torchvision.datasets import ImageFolder
from IPython import display

# Images load as [0, 1] tensors. ImageFolder's class labels (data[1]) are ignored below.
dataset = tv.datasets.ImageFolder(root=DATA_DIR,
    transform=tv.transforms.Compose([
        #tv.transforms.Resize(256),
        tv.transforms.ToTensor(),
    ]))
if opt.backend == 'TPU':
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    data_loader = torch.utils.data.DataLoader(dataset,
        sampler=sampler,
        batch_size=opt.batch_size,
        num_workers=4,
        drop_last=True)
else:
    data_loader = torch.utils.data.DataLoader(dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=4)


opt.isTrain = True
opt.continue_train = False

RUN_DIR = f"{opt.checkpoints_dir}/{opt.name}"
os.makedirs(RUN_DIR, exist_ok=True)

model = Pix2PixModel(opt)
model.setup(opt)

start_epoch = 1
last_epoch = 100
# Resume: `latest_epoch` holds the last completed epoch number. Load those weights and
# continue from the following epoch. Any failure (missing file, bad contents, missing or
# incompatible weights) is reported and training restarts from epoch 1. Catching
# Exception rather than using a bare except keeps KeyboardInterrupt working.
try:
    with open(f"{RUN_DIR}/latest_epoch", "r") as cp:
        start_epoch = int(cp.read())
    model.load_networks(str(start_epoch))
    start_epoch += 1
except Exception as err:
    print(f"Failed to load checkpoint {start_epoch}: {err}")
    start_epoch = 1

with open(f"{RUN_DIR}/architecture", "w") as arch:
    arch.write(f"netG: {model.netG}\nnetD: {model.netD}\n")

crop = tv.transforms.RandomCrop(256)

for epoch in range(start_epoch, last_epoch + 1):
    model.update_learning_rate()
    for i, data in enumerate(data_loader):
        # Clean target, rescaled from [0, 1] to [-1, 1].
        output_data = crop(data[0].to(device=opt.devices[0], dtype=opt.dtype)) * 2 - 1
        # Random per-sample timestep in [0, 1000). The noise variance
        # beta = (1 - noise_schedule_beta) ** t is 1 (pure noise) at t = 0 and shrinks as t grows.
        timestep = torch.rand(size=(output_data.size()[0],), device=opt.devices[0], dtype=opt.dtype) * 1000
        beta = einops.rearrange((1 - opt.noise_schedule_beta) ** timestep, 'd -> d 1 1 1')
        input_data = output_data * torch.sqrt(1 - beta) + torch.randn_like(output_data) * torch.sqrt(beta)

        model.real_A = input_data.to(opt.dtype)
        model.real_B = output_data.to(opt.dtype)

        model.optimize_parameters()

        if opt.backend == 'TPU' and opt.debug_xla:
            print(torch_xla._XLAC._get_xla_tensors_text([model.loss_G_GAN, model.loss_G_L1, model.loss_G_MSE, model.loss_D_fake, model.loss_D_real]))

        if opt.display_interval > 0 and i % opt.display_interval == 0:
            display.clear_output(wait=True)
            model.print_networks(verbose=opt.verbose)
            print(f"epoch {epoch}/{last_epoch} batch {i}/{len(data_loader)}")
            if opt.backend != 'TPU':
                sample_index = 0
                # Average the discriminator's logit map down to a single value per sample.
                pred_fake = nn.functional.adaptive_avg_pool3d(model.pred_fake[sample_index], 1)
                pred_real = nn.functional.adaptive_avg_pool3d(model.pred_real[sample_index], 1)
                # Noisy input | clean target | generator output, side by side (width axis).
                sample = to_pil_image(torch.cat((model.real_A[sample_index],
                    model.real_B[sample_index],
                    model.fake_B[sample_index]), dim=2).clamp(-1, 1) * 0.5 + 0.5)
                print(f"Sample {sample_index}: time {timestep[sample_index]}, beta {beta[sample_index][0][0][0]}, pred_fake {pred_fake[0][0][0]}, pred_real {pred_real[0][0][0]}")
                display.display_png(sample)
            print(f"Batch losses: G_GAN {model.loss_G_GAN}, G_L1 {model.loss_G_L1}, G_MSE {model.loss_G_MSE}, D_fake {model.loss_D_fake}, D_real {model.loss_D_real}")

    model.save_networks('latest')
    model.save_networks(epoch)
    with open(f"{RUN_DIR}/latest_epoch", "w") as cp:
        cp.write(str(epoch))


# Denoise

In [None]:
import math, io, requests
import torch
import einops
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image, resize
from IPython import display

# Input image URL or local path. An empty string starts from an all-zero tensor (mid-gray).
image_url = "https://i.imgur.com/3NHuxu1.png" #@param {type: "string"}
# [height, width] to resize to (None keeps the original size; defaults to 256x256 without an image).
image_size = [512,512] #@param {type: "raw"}
# Variance of Gaussian noise mixed into the input before denoising (0 = no added noise).
noise_beta =  0#@param {type: "number"}
# Number of refinement passes through the generator.
iterations =  1#@param {type: "integer", min:1}

#@markdown ## NN visualization
# NOTE(review): visualize_layer / visualize_channels are not read by any visible code.
visualize_layer = "" #@param {type: "string"}
visualize_channels = [0, 64] #@param {type: "raw"}

def fetch(url_or_path):
    """Return a seekable in-memory binary stream with the contents of a URL or a local file.

    Strings starting with 'http://' or 'https://' are downloaded with ``requests``; a
    non-2xx response raises via ``raise_for_status``. Anything else is opened as a
    local file, read fully and closed, so no file handle is left open for the caller.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        response = requests.get(url_or_path)
        response.raise_for_status()
        return io.BytesIO(response.content)
    with open(url_or_path, 'rb') as f:
        return io.BytesIO(f.read())

opt.isTrain = False

model = Pix2PixModel(opt)
model.setup(opt)
model.load_networks('latest')
model.eval()

device = model.devices[0]

if image_url != "":
    pil_image = Image.open(fetch(image_url)).convert('RGB')

    if image_size is not None:
        pil_image = resize(pil_image, image_size)

    print("Input image")
    display.display_png(pil_image)

    image = to_tensor(pil_image).to(device)
    image = einops.rearrange(image, '... -> 1 ...')  # add a batch dimension
    image = image * 2 - 1  # [0, 1] -> [-1, 1], the range the generator was trained on
else:
    if image_size is None:
        image_size = [256, 256]
    image = torch.zeros(1, 3, image_size[0], image_size[1]).to(device)

image = image * math.sqrt(1 - noise_beta) + torch.randn_like(image) * math.sqrt(noise_beta)
# The generator's parameters use opt.dtype (possibly bfloat16); match it to avoid a dtype error.
image = image.to(model.dtype)

model.set_requires_grad(model.netG, False)

with torch.no_grad():
    for step in range(iterations):
        # Move 1/(remaining steps) of the way towards the generator's prediction; the
        # last step (fraction 1) replaces the image with the prediction outright.
        step_beta = 1 / (iterations - step)

        prediction = model.netG(image)
        image = image * (1 - step_beta) + prediction * step_beta

        if step + 1 < iterations:
            print(f"Iteration {step+1}/{iterations}")
            # .float() so bfloat16 results can be converted to a PIL image.
            display.display_png(to_pil_image(prediction[0].float().clamp(-1, 1) * 0.5 + 0.5))

print("Final result")
display.display_png(to_pil_image(image[0].float().clamp(-1, 1) * 0.5 + 0.5))
