In [1]:
!pip install einops

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

def lower_bound_fwd(x: Tensor, bound: Tensor) -> Tensor:
    """Forward pass of the lower-bound operator: element-wise ``max(x, bound)``."""
    clamped = torch.max(x, bound)
    return clamped


def lower_bound_bwd(x: Tensor, bound: Tensor, grad_output: Tensor):
    """Backward rule of the lower-bound operator.

    The incoming gradient is kept where ``x`` is at or above ``bound`` or where
    the gradient itself is negative; everywhere else it is zeroed.

    Returns:
        ``(grad_x, None)`` -- ``bound`` never receives a gradient.
    """
    at_or_above_bound = x >= bound
    negative_gradient = grad_output < 0
    keep = at_or_above_bound | negative_gradient
    grad_x = keep * grad_output
    return grad_x, None
#Compute PSNR
import math
def compute_psnr(img1, img2, data_range=1.0):
    """Peak signal-to-noise ratio (in dB) between two arrays.

    Args:
        img1, img2: array-likes of identical shape.
        data_range: distance between the smallest and largest possible value
            of the data (1.0 for images scaled to [0, 1]).

    Returns:
        ``10 * log10(data_range ** 2 / mse)`` as a float. Identical inputs
        give ``math.inf`` (rather than a string) so the result can always be
        accumulated or averaged by the caller.
    """
    # Imported locally: this cell does not import numpy itself.
    import numpy as np

    img1 = np.asarray(img1, dtype=np.float64)
    img2 = np.asarray(img2, dtype=np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return math.inf
    return 10 * math.log10(float(data_range) ** 2 / mse)

class LowerBoundFunction(torch.autograd.Function):
    """Autograd function pairing `lower_bound_fwd` with `lower_bound_bwd`."""

    @staticmethod
    def forward(ctx, x, bound):
        # Both operands are needed again by the backward rule.
        ctx.save_for_backward(x, bound)
        clamped = lower_bound_fwd(x, bound)
        return clamped

    @staticmethod
    def backward(ctx, grad_output):
        saved_x, saved_bound = ctx.saved_tensors
        grads = lower_bound_bwd(saved_x, saved_bound, grad_output)
        return grads
class LowerBound(nn.Module):
    """Lower bound operator, computes `torch.max(x, bound)` with a custom
    gradient.

    In eager mode the backward pass is delegated to `LowerBoundFunction`: the
    gradient passes through unchanged where `x >= bound` or where the incoming
    gradient is negative, and is zeroed elsewhere. Under TorchScript a plain
    `torch.max` is used instead.

    Args:
        bound: scalar lower bound, stored as a one-element buffer.
    """

    # Registered as a buffer in __init__, so it follows the module across devices.
    bound: Tensor

    def __init__(self, bound: float):
        super().__init__()
        self.register_buffer("bound", torch.Tensor([float(bound)]))

    @torch.jit.unused
    def lower_bound(self, x):
        """Eager-mode path using the custom autograd function."""
        return LowerBoundFunction.apply(x, self.bound)

    def forward(self, x):
        """Return `x` clamped from below by `self.bound`."""
        # Custom autograd functions cannot be scripted, hence the plain max here.
        if torch.jit.is_scripting():
            return torch.max(x, self.bound)
        return self.lower_bound(x)


class NonNegativeParametrizer(nn.Module):
    """
    Non negative reparametrization.

    A parameter is stored as ``sqrt(value + pedestal)`` (see `init`) and mapped
    back with ``stored ** 2 - pedestal`` after being lower-bounded (see
    `forward`). Used for stability during training.

    Args:
        minimum: smallest value the reparametrized quantity may take.
        reparam_offset: offset whose square is the ``pedestal`` buffer.
    """

    pedestal: Tensor

    def __init__(self, minimum: float = 0, reparam_offset: float = 2 ** -18):
        super().__init__()

        self.minimum = float(minimum)
        self.reparam_offset = float(reparam_offset)

        offset_squared = self.reparam_offset ** 2
        self.register_buffer("pedestal", torch.Tensor([offset_squared]))
        self.lower_bound = LowerBound((self.minimum + offset_squared) ** 0.5)

    def init(self, x: Tensor) -> Tensor:
        """Map an initial value into the stored (square-root) domain."""
        shifted = x + self.pedestal
        return torch.sqrt(torch.max(shifted, self.pedestal))

    def forward(self, x: Tensor) -> Tensor:
        """Map a stored parameter back to its effective value."""
        bounded = self.lower_bound(x)
        return bounded ** 2 - self.pedestal
class GDN(nn.Module):
    r"""Generalized Divisive Normalization layer.

    Introduced in `"Density Modeling of Images Using a Generalized Normalization
    Transformation" <https://arxiv.org/abs/1511.06281>`_,
    by Balle Johannes, Valero Laparra, and Eero P. Simoncelli, (2016).

    .. math::

       y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[j, i] * x[j]^2)}}

    With ``inverse=True`` the input is multiplied by the square root instead of
    being divided by it.

    Args:
        in_channels: number of channels ``C`` of the 4-D input.
        inverse: select the inverse (multiplicative) variant.
        beta_min: lower bound enforced on ``beta``.
        gamma_init: value placed on the diagonal of the initial ``gamma``.
    """

    def __init__(
        self,
        in_channels: int,
        inverse: bool = False,
        beta_min: float = 1e-6,
        gamma_init: float = 0.1,
    ):
        super().__init__()

        self.inverse = bool(inverse)

        # beta: one bias per channel, initialised to 1 in the effective domain.
        self.beta_reparam = NonNegativeParametrizer(minimum=float(beta_min))
        initial_beta = self.beta_reparam.init(torch.ones(in_channels))
        self.beta = nn.Parameter(initial_beta)

        # gamma: C x C mixing matrix, initialised to a scaled identity.
        self.gamma_reparam = NonNegativeParametrizer()
        initial_gamma = self.gamma_reparam.init(
            float(gamma_init) * torch.eye(in_channels))
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, x: Tensor) -> Tensor:
        _, channels, _, _ = x.shape

        beta = self.beta_reparam(self.beta)
        gamma = self.gamma_reparam(self.gamma).reshape(channels, channels, 1, 1)
        # A 1x1 convolution computes beta[i] + sum_j gamma[i, j] * x[j]^2 per pixel.
        norm_pool = F.conv2d(x ** 2, gamma, beta)

        if self.inverse:
            scale = torch.sqrt(norm_pool)
        else:
            scale = torch.rsqrt(norm_pool)

        return x * scale

In [3]:
import torch
import torch.nn as nn


from einops import rearrange
from einops.layers.torch import Rearrange


def conv_3x3_bn(inp, oup, image_size, downsample=False):
    """Stem block: bias-free 3x3 convolution, batch norm, then GELU.

    ``image_size`` is accepted only so the signature matches the other stage
    blocks; it is not used. The convolution has stride 2 unless
    ``downsample == False``.
    """
    if downsample == False:
        stride = 1
    else:
        stride = 2
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.GELU(),
    ]
    return nn.Sequential(*layers)


class PreNorm(nn.Module):
    """Apply a normalisation layer to the input before calling ``fn``.

    Args:
        dim: size handed to the ``norm`` constructor.
        fn: callable invoked on the normalised input.
        norm: normalisation class/factory, called as ``norm(dim)``.
    """

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)


class SE(nn.Module):
    """Squeeze-and-excitation gate over the channels of a 4-D input.

    Args:
        inp: channel count used to size the bottleneck (``int(inp * expansion)``).
        oup: channel count of the tensor that is actually gated.
        expansion: bottleneck ratio relative to ``inp``.
    """

    def __init__(self, inp, oup, expansion=0.25):
        super().__init__()
        bottleneck = int(inp * expansion)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(oup, bottleneck, bias=False),
            nn.GELU(),
            nn.Linear(bottleneck, oup, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        # Squeeze: one average per (sample, channel).
        pooled = self.avg_pool(x).view(batch, channels)
        # Excite: per-channel gate in (0, 1), broadcast over the spatial dims.
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate


class FeedForward(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Dropout -> Linear -> Dropout).

    Maps the last dimension ``dim -> hidden_dim -> dim``.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class MBConv(nn.Module):
    """Pre-normalised inverted-residual (MBConv) block with a skip connection.

    Args:
        inp: input channels.
        oup: output channels.
        image_size: accepted for signature parity with the other blocks; unused.
        downsample: when truthy, the first convolution uses stride 2 and the
            shortcut becomes max-pool followed by a 1x1 projection.
        expansion: width multiplier of the hidden layer (``inp * expansion``).
    """

    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        self.downsample = downsample
        if self.downsample == False:
            stride = 1
        else:
            stride = 2
        hidden_dim = int(inp * expansion)

        if self.downsample:
            # Shortcut branch that matches the resolution/width of the main branch.
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        if expansion == 1:
            body = [
                # depthwise
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # pointwise, linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        else:
            body = [
                # pointwise expansion; carries the stride when down-sampling
                nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # depthwise
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SE(inp, hidden_dim),
                # pointwise, linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]

        self.conv = PreNorm(inp, nn.Sequential(*body), nn.BatchNorm2d)

    def forward(self, x):
        if self.downsample:
            shortcut = self.proj(self.pool(x))
        else:
            shortcut = x
        return shortcut + self.conv(x)


class Attention(nn.Module):
    """Multi-head self-attention over an ``ih x iw`` token grid with a learned
    relative position bias.

    Args:
        inp: channel size of the incoming tokens.
        oup: channel size produced by the output projection.
        image_size: ``(ih, iw)``; the token count must equal ``ih * iw`` because
            the bias is reshaped to ``(ih * iw, ih * iw)``.
        heads: number of attention heads.
        dim_head: channel size of each head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The projection is skipped only for one head whose width equals `inp`.
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias:
        # one row per (dy, dx) offset, one column per head
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # "ij" is the ordering the bare call already produced; passing it
        # explicitly avoids the deprecation warning for the implicit default.
        coords = torch.meshgrid(
            torch.arange(self.ih), torch.arange(self.iw), indexing="ij")
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        # Shift offsets to start at 0, then fold (dy, dx) into one table index.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """Attend over tokens; ``x`` is laid out as ``(batch, tokens, inp)``."""
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class Transformer(nn.Module):
    """Pre-norm transformer block operating on ``(b, c, ih, iw)`` feature maps.

    Args:
        inp: input channels.
        oup: output channels.
        image_size: ``(ih, iw)`` of the feature map seen by the attention.
        heads, dim_head: attention head count and per-head width.
        downsample: when truthy, both branches are max-pooled first and the
            shortcut is projected to ``oup`` channels.
        dropout: dropout probability for attention output and feed-forward.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        hidden_dim = int(inp * 4)

        self.ih, self.iw = image_size
        self.downsample = downsample

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        def on_tokens(dim, fn):
            # Flatten the grid to a token sequence, run `fn` pre-normalised,
            # then restore the (b, c, ih, iw) layout.
            return nn.Sequential(
                Rearrange('b c ih iw -> b (ih iw) c'),
                PreNorm(dim, fn, nn.LayerNorm),
                Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
            )

        attention = Attention(inp, oup, image_size, heads, dim_head, dropout)
        feed_forward = FeedForward(oup, hidden_dim, dropout)

        self.attn = on_tokens(inp, attention)
        self.ff = on_tokens(oup, feed_forward)

    def forward(self, x):
        if self.downsample:
            shortcut = self.proj(self.pool1(x))
            x = shortcut + self.attn(self.pool2(x))
        else:
            x = x + self.attn(x)
        return x + self.ff(x)


class CoAtNet(nn.Module):
    """Five-stage CoAtNet classifier.

    Stage ``s0`` is a ``conv_3x3_bn`` stem; stages ``s1``..``s4`` are MBConv
    (``'C'``) or Transformer (``'T'``) stages selected by ``block_types``.
    Each stage halves the resolution in its first block.

    Args:
        image_size: ``(ih, iw)`` of the input image.
        in_channels: channels of the input image.
        num_blocks: depth of each of the five stages.
        channels: output width of each of the five stages.
        num_classes: size of the classification head.
        block_types: block kind for stages ``s1``..``s4``.
    """

    def __init__(self, image_size, in_channels, num_blocks, channels, num_classes=1000, block_types=['C', 'C', 'T', 'T']):
        super().__init__()
        ih, iw = image_size
        block = {'C': MBConv, 'T': Transformer}

        for stage in range(5):
            if stage == 0:
                block_cls, stage_in = conv_3x3_bn, in_channels
            else:
                block_cls = block[block_types[stage - 1]]
                stage_in = channels[stage - 1]
            reduction = 2 ** (stage + 1)
            setattr(self, f"s{stage}", self._make_layer(
                block_cls, stage_in, channels[stage], num_blocks[stage],
                (ih // reduction, iw // reduction)))

        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        for stage in (self.s0, self.s1, self.s2, self.s3, self.s4):
            x = stage(x)

        width = x.shape[1]
        pooled = self.pool(x)
        return self.fc(pooled.view(-1, width))

    def _make_layer(self, block, inp, oup, depth, image_size):
        """Stack ``depth`` blocks; only the first one down-samples and changes width."""
        stage = [
            block(inp, oup, image_size, downsample=True) if idx == 0
            else block(oup, oup, image_size)
            for idx in range(depth)
        ]
        return nn.Sequential(*stage)


def coatnet_0():
    """Build CoAtNet-0 for 224x224 RGB input and 1000 classes."""
    return CoAtNet(
        (224, 224), 3,
        num_blocks=[2, 2, 3, 5, 2],          # L
        channels=[64, 96, 192, 384, 768],    # D
        num_classes=1000,
    )


def coatnet_1():
    """Build CoAtNet-1 for 224x224 RGB input and 1000 classes."""
    return CoAtNet(
        (224, 224), 3,
        num_blocks=[2, 2, 6, 14, 2],         # L
        channels=[64, 96, 192, 384, 768],    # D
        num_classes=1000,
    )


def coatnet_2():
    """Build CoAtNet-2 for 224x224 RGB input and 1000 classes.

    The last stage width is 1024, as in the CoAtNet paper; the previous value
    of 1026 was a typo. Checkpoints created with the 1026-wide model will not
    load into this one.
    """
    num_blocks = [2, 2, 6, 14, 2]           # L
    channels = [128, 128, 256, 512, 1024]   # D
    return CoAtNet((224, 224), 3, num_blocks, channels, num_classes=1000)


def coatnet_3():
    """Build CoAtNet-3 for 224x224 RGB input and 1000 classes."""
    return CoAtNet(
        (224, 224), 3,
        num_blocks=[2, 2, 6, 14, 2],          # L
        channels=[192, 192, 384, 768, 1536],  # D
        num_classes=1000,
    )


def coatnet_4():
    """Build CoAtNet-4 for 224x224 RGB input and 1000 classes."""
    return CoAtNet(
        (224, 224), 3,
        num_blocks=[2, 2, 12, 28, 2],         # L
        channels=[192, 192, 384, 768, 1536],  # D
        num_classes=1000,
    )


def count_parameters(model):
    """Return the number of scalar parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [4]:
#----------------Same as CoAtNet--------------------------
import torch
import torch.nn as nn


from einops import rearrange
from einops.layers.torch import Rearrange

class PreNorm_IGDN(nn.Module):
    """Like ``PreNorm`` but builds the normalisation layer with ``inverse=True``.

    Args:
        dim: size handed to the ``norm`` constructor.
        fn: callable invoked on the normalised input.
        norm: factory called as ``norm(dim, inverse=True)``.
    """

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim, inverse=True)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)


class SE(nn.Module):
    """Squeeze-and-excitation channel gate for 4-D feature maps.

    The bottleneck width is ``int(inp * expansion)``; the gated tensor has
    ``oup`` channels.
    """

    def __init__(self, inp, oup, expansion=0.25):
        super().__init__()
        reduced = int(inp * expansion)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(oup, reduced, bias=False),
            nn.GELU(),
            nn.Linear(reduced, oup, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        n, c, _, _ = x.size()
        # Global average per channel, then a sigmoid gate per channel.
        squeezed = self.avg_pool(x).view(n, c)
        weights = self.fc(squeezed).view(n, c, 1, 1)
        return x * weights
class Attention(nn.Module):
    """Multi-head self-attention over an ``ih x iw`` token grid with a learned
    relative position bias.

    Args:
        inp: channel size of the incoming tokens.
        oup: channel size produced by the output projection.
        image_size: ``(ih, iw)``; the token count must equal ``ih * iw`` because
            the bias is reshaped to ``(ih * iw, ih * iw)``.
        heads: number of attention heads.
        dim_head: channel size of each head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The projection is skipped only for one head whose width equals `inp`.
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias:
        # one row per (dy, dx) offset, one column per head
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # "ij" is the ordering the bare call already produced; passing it
        # explicitly avoids the deprecation warning for the implicit default.
        coords = torch.meshgrid(
            torch.arange(self.ih), torch.arange(self.iw), indexing="ij")
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        # Shift offsets to start at 0, then fold (dy, dx) into one table index.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """Attend over tokens; ``x`` is laid out as ``(batch, tokens, inp)``."""
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

# -----------------Symmetric CoAtNet Modules--------------
class Bicubic_upsampler(nn.Module):
    """Module wrapper around ``nn.functional.interpolate``.

    Args:
        scale_factor: forwarded as ``scale_factor`` to ``interpolate``.
        mode: forwarded as ``mode`` to ``interpolate`` (e.g. ``"bicubic"``).
    """

    def __init__(self, scale_factor, mode):
        super().__init__()
        self.upsampler = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        return self.upsampler(x, scale_factor=self.scale_factor, mode=self.mode)


class InverseTransformer(nn.Module):
    """Up-sampling counterpart of ``Transformer``.

    With ``upsample`` enabled the input first goes through ``PixelShuffle(2)``,
    which doubles the resolution and divides the channel count by four; the
    shortcut is then projected to ``oup`` channels with a 1x1 transposed
    convolution.

    Args:
        inp: channels of the block input (before the pixel shuffle).
        oup: output channels.
        image_size: ``(ih, iw)`` of the feature map seen by the attention,
            i.e. the size after up-sampling.
        heads, dim_head: attention head count and per-head width.
        upsample: enable the pixel-shuffle up-sampling path.
        dropout: dropout probability for attention output and feed-forward.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, upsample=False, dropout=0.):
        super().__init__()

        self.ih, self.iw = image_size
        self.upsample = upsample

        if self.upsample:
            # 2 is the upsample factor; it could become a hyperparameter.
            # A Bicubic_upsampler is an alternative, but it keeps the channel
            # count, so `inp` would then not be divided.
            self.upsampler = nn.PixelShuffle(2)
            inp = inp // 4  # channels left after the pixel shuffle
            self.proj = nn.ConvTranspose2d(inp, oup, 1, 1, 0, bias=False)

        hidden_dim = int(inp * 4)
        self.attn = Attention(inp, oup, image_size, heads, dim_head, dropout)
        self.ff = FeedForward(oup, hidden_dim, dropout)

        self.attn = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(inp, self.attn, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

        self.ff = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(oup, self.ff, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

    def forward(self, x):
        if self.upsample:
            # Shuffle once and share the result between shortcut and attention.
            upsampled = self.upsampler(x)
            x = self.proj(upsampled) + self.attn(upsampled)
        else:
            x = x + self.attn(x)
        x = x + self.ff(x)
        return x

class InverseMBConv(nn.Module):
    """Up-sampling counterpart of ``MBConv`` built from transposed convolutions.

    Args:
        inp: input channels.
        oup: output channels.
        image_size: accepted for signature parity with the other blocks; unused.
        upsample: when truthy, the first transposed convolution uses stride 2
            and the shortcut becomes ``PixelShuffle(2)`` followed by a 1x1
            projection from ``inp / 4`` to ``oup`` channels.
        expansion: width multiplier of the hidden layer (``inp * expansion``).
    """

    def __init__(self, inp, oup, image_size, upsample=False, expansion=4):
        super().__init__()
        self.upsample = upsample
        if self.upsample == False:
            stride, output_padding = 1, 0
        else:
            stride, output_padding = 2, 1

        if self.upsample:
            # Shortcut: pixel shuffle trades 4x channels for 2x resolution.
            self.upsampler = nn.PixelShuffle(2)
            self.proj = nn.Conv2d(int(inp / 4), oup, 1, 1, 0, bias=False)

        hidden_dim = int(inp * expansion)

        if expansion == 1:
            body = [
                # depthwise; carries the stride when up-sampling
                nn.ConvTranspose2d(inp, hidden_dim, 3, stride, 1,
                                   groups=inp, bias=False,
                                   output_padding=output_padding),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # pointwise, linear
                nn.ConvTranspose2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        else:
            body = [
                # pointwise expansion; up-samples when stride is 2
                nn.ConvTranspose2d(inp, hidden_dim, 1, stride, 0, bias=False,
                                   output_padding=output_padding),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # depthwise
                nn.ConvTranspose2d(hidden_dim, hidden_dim, 3, 1, 1,
                                   groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SE(inp, hidden_dim),
                # pointwise, linear
                nn.ConvTranspose2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]

        self.conv = PreNorm(inp, nn.Sequential(*body), nn.BatchNorm2d)

    def forward(self, x):
        if self.upsample:
            shortcut = self.proj(self.upsampler(x))
        else:
            shortcut = x
        return shortcut + self.conv(x)

def Inverse_conv_3x3_bn(inp, oup, image_size, upsample=False):
    """Output block: bias-free 3x3 transposed convolution, batch norm, then Tanh.

    ``image_size`` is accepted only for signature parity and is not used. The
    transposed convolution has stride 2 and output padding 1 unless
    ``upsample == False``.
    """
    if upsample == False:
        stride, output_padding = 1, 0
    else:
        stride, output_padding = 2, 1
    layers = [
        nn.ConvTranspose2d(inp, oup, 3, stride, 1, bias=False,
                           output_padding=output_padding),
        nn.BatchNorm2d(oup),
        nn.Tanh(),
    ]
    return nn.Sequential(*layers)

In [5]:

class MLP(nn.Module):
    """Per-pixel two-layer MLP that narrows the channels ``inp -> inp/2 -> oup``.

    The input is ``(b, c, h, w)``; channels are moved last so ``nn.Linear`` acts
    on them, then moved back. The two batch-norm layers are created but not
    used in ``forward``.
    """

    def __init__(self, inp, oup, dropout=0.):
        super().__init__()
        hidden = int(inp / 2)
        self.fc1 = nn.Linear(inp, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, oup)
        self.drop = nn.Dropout(dropout)
        self.norm_layer1 = nn.BatchNorm2d(inp)
        self.norm_layer2 = nn.BatchNorm2d(hidden)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)
        hidden = self.drop(self.act(self.fc1(channels_last)))
        projected = self.drop(self.fc2(hidden))
        return projected.permute(0, 3, 1, 2)

class Inverse_MLP(nn.Module):
    """Per-pixel two-layer MLP that widens the channels ``inp -> 2*inp -> oup``.

    The input is ``(b, c, h, w)``; channels are moved last so ``nn.Linear`` acts
    on them, then moved back. The two batch-norm layers are created but not
    used in ``forward``.
    """

    def __init__(self, inp, oup, dropout=0.):
        super().__init__()
        hidden = int(inp * 2)
        self.fc1 = nn.Linear(inp, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, oup)
        self.drop = nn.Dropout(dropout)
        self.norm_layer1 = nn.BatchNorm2d(inp)
        self.norm_layer2 = nn.BatchNorm2d(hidden)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)
        hidden = self.drop(self.act(self.fc1(channels_last)))
        projected = self.drop(self.fc2(hidden))
        return projected.permute(0, 3, 1, 2)

class MergedAutoEncoder(nn.Module):
    """Symmetric CoAtNet-style auto-encoder.

    The analysis path (``a0``..``a4``) stacks one ``conv_3x3_bn`` stage, two
    ``MBConv`` stages and two ``Transformer`` stages. The synthesis path
    (``s4``..``s0``) mirrors it with ``InverseTransformer``, ``InverseMBConv``
    and ``Inverse_conv_3x3_bn``. The ``image_size`` tuples handed to the stages
    run from (16, 16) down to (1, 1) and back up to (32, 32). GDN/IGDN layers
    and the MLP compress/decompress step are not wired in.
    """

    def __init__(self):
        super().__init__()

        # ---- encoder: (block, in_channels, out_channels, depth, image_size)
        self.a0 = self._make_layer_analysis(conv_3x3_bn, 3, 64, 1, (16, 16))
        self.a1 = self._make_layer_analysis(MBConv, 64, 96, 2, (8, 8))
        self.a2 = self._make_layer_analysis(MBConv, 96, 384, 6, (4, 4))
        self.a3 = self._make_layer_analysis(Transformer, 384, 384, 14, (2, 2))
        self.a4 = self._make_layer_analysis(Transformer, 384, 192, 2, (1, 1))

        # ---- decoder: mirrors the encoder stage by stage
        self.s4 = self._make_layer_synthesis(InverseTransformer, 192, 384, 2, (2, 2))
        self.s3 = self._make_layer_synthesis(InverseTransformer, 384, 384, 14, (4, 4))
        self.s2 = self._make_layer_synthesis(InverseMBConv, 384, 48, 6, (8, 8))
        self.s1 = self._make_layer_synthesis(InverseMBConv, 48, 12, 2, (16, 16))
        self.s0 = self._make_layer_synthesis(Inverse_conv_3x3_bn, 12, 3, 1, (32, 32))

    def encode(self, x):
        """Run the analysis stages ``a0``..``a4`` in order."""
        for stage in (self.a0, self.a1, self.a2, self.a3, self.a4):
            x = stage(x)
        return x

    def decode(self, x):
        """Run the synthesis stages ``s4``..``s0`` in order."""
        for stage in (self.s4, self.s3, self.s2, self.s1, self.s0):
            x = stage(x)
        return x

    def forward(self, x):
        latent = self.encode(x)
        return self.decode(latent)

    def _make_layer_analysis(self, block, inp, oup, depth, image_size):
        """Stack ``depth`` blocks; only the first down-samples and changes width."""
        stage = [
            block(inp, oup, image_size, downsample=True) if idx == 0
            else block(oup, oup, image_size)
            for idx in range(depth)
        ]
        return nn.Sequential(*stage)

    def _make_layer_synthesis(self, block, inp, oup, depth, image_size):
        """Stack ``depth`` blocks; only the first up-samples and changes width."""
        stage = [
            block(inp, oup, image_size, upsample=True) if idx == 0
            else block(oup, oup, image_size)
            for idx in range(depth)
        ]
        return nn.Sequential(*stage)

Main

In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler
from tqdm import tqdm
import math

# Prefer the GPU when PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Augment, convert to a tensor, then Normalize(mean=0.5, std=0.5):
# every channel ends up in [-1, 1] (not [0, 1]).
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomInvert(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Evaluation data: no augmentation, same normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

batch_size = 128

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

model = MergedAutoEncoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# NOTE(review): LinearLR is left at its default start/end factors -- confirm
# that the resulting schedule is the intended one.
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=9)

def compute_psnr(img1, img2, data_range=1.0):
    """Peak signal-to-noise ratio (in dB) between two arrays.

    Args:
        img1, img2: array-likes of identical shape.
        data_range: distance between the smallest and largest possible value
            of the data (1.0 for images scaled to [0, 1]).

    Returns:
        ``10 * log10(data_range ** 2 / mse)`` as a float. Identical inputs
        give ``math.inf`` (rather than a string) so the result can always be
        accumulated or averaged by the caller.
    """
    img1 = np.asarray(img1, dtype=np.float64)
    img2 = np.asarray(img2, dtype=np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return math.inf
    return 10 * math.log10(float(data_range) ** 2 / mse)

######## Training #########
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch of the auto-encoder.

    Args:
        dataloader: yields ``(images, labels)`` batches; the labels are ignored
            because the reconstruction target is the input itself.
        model: network to optimise.
        loss_fn: reconstruction loss, called as ``loss_fn(prediction, images)``.
        optimizer: optimizer stepping the model parameters.

    Prints the loss every 100 batches.
    """
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, _) in enumerate(dataloader):
        X = X.to(device)
        pred = model(X)
        # Use the loss passed by the caller rather than the global `criterion`.
        loss = loss_fn(pred, X)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    """Evaluate reconstruction quality and print the mean PSNR and loss.

    Args:
        dataloader: yields ``(images, labels)`` batches; the labels are ignored.
        model: network to evaluate (switched to eval mode).
        loss_fn: reconstruction loss, called as ``loss_fn(prediction, images)``.

    Both figures are averaged over batches. PSNR is measured on images mapped
    back to [0, 1], which matches the peak value of 1 used by `compute_psnr`.
    """
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0.0
    psnr = 0.0
    with torch.no_grad():
        for X, _ in dataloader:
            X = X.to(device)
            pred = model(X)
            # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1]. Without this
            # the PSNR is understated, because the data spans a range of 2.
            pred_unit = pred * 0.5 + 0.5
            target_unit = X * 0.5 + 0.5
            psnr += compute_psnr(pred_unit.cpu().numpy(), target_unit.cpu().numpy())
            # Use the loss passed by the caller rather than the global `criterion`.
            test_loss += loss_fn(pred, X).item()
    print(f"PSNR: {psnr/num_batches}")
    print(f"Test Loss: {test_loss/num_batches}")

# Train and evaluate once per epoch; the LR scheduler advances every 10th epoch.
epochs = 100
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(trainloader, model, criterion, optimizer)
    test(testloader, model, criterion)
    if (t + 1) % 10 != 0:
        continue
    scheduler.step()
    print(f"The last LR is {scheduler.get_last_lr()[0]}")
print("Done!")

In [None]:
# Scratch check of tuple packing. Uses its own name so that the `test`
# evaluation function is not overwritten.
scratch_pair = 0, None
print(scratch_pair)

In [None]:
# Save model: only the weights (state_dict) are written, to a file relative to
# the current working directory.
# NOTE(review): the file name appears in both literals below -- keep them in sync.
torch.save(model.state_dict(), "model-data_augmented_tanh-100e-coat2.pth")
print("Saved PyTorch Model State to model-data_augmented_tanh-100e-coat2.pth")

In [7]:
import matplotlib.pyplot as plt

In [8]:
# Take one batch instead of iterating over the whole test set just to keep the
# last one. The sample is now the first image of the first batch.
test_images, test_labels = next(iter(testloader))
sample_image = test_images[0]
sample_label = test_labels[0]

In [9]:
# The tensor was normalised to [-1, 1]; map it back to [0, 1], because imshow
# clips float RGB data outside that range.
plt.imshow((sample_image.reshape(3,32,32) * 0.5 + 0.5).clamp(0, 1).permute(1,2,0))

In [10]:
# Inference mode for BatchNorm/Dropout, whatever ran before this cell.
model.eval()
with torch.no_grad():
    prediction = model(sample_image.unsqueeze(0).to(device))
    # The model output lives in the normalised [-1, 1] range; map it back to
    # [0, 1], because imshow clips float RGB data outside that range.
    restored = (prediction.cpu().reshape(3, 32, 32) * 0.5 + 0.5).clamp(0, 1)
    plt.imshow(restored.permute(1, 2, 0))