<a href="https://colab.research.google.com/github/AndreH2002/Photo_cropping_and_AI_SuperResolution/blob/main/ESRGAN_MODEL_MOBILE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Initial setup

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models

In [2]:
#ConvBlock class
class ConvBlock(tf.keras.layers.Layer):
    """Conv2D followed by an optional activation layer.

    Args:
        filters: Number of output channels of the convolution.
        kernel_size: Convolution kernel size (default 3).
        stride: Convolution stride (default 1).
        padding: Conv2D padding mode (default "same").
        activation: One of
            - "leakyrelu": LeakyReLU with negative slope 0.2,
            - "relu",
            - "prelu": PReLU whose parameters are shared over axes 1 and 2
              (the spatial axes, assuming channels-last input),
            - None: a purely linear convolution.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``activation`` is not one of the values above.
    """

    def __init__(self, filters, kernel_size=3, stride=1, padding="same",
                 activation="leakyrelu", **kwargs):
        super().__init__(**kwargs)
        self.conv = layers.Conv2D(filters, kernel_size, strides=stride, padding=padding)

        if activation == "leakyrelu":
            self.activation = layers.LeakyReLU(0.2)
        elif activation == "relu":
            self.activation = layers.ReLU()
        elif activation == "prelu":
            self.activation = layers.PReLU(shared_axes=[1, 2])
        elif activation is None:
            self.activation = None
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def call(self, x):
        x = self.conv(x)
        # self.activation is either a Keras layer or None; test that explicitly
        # rather than relying on the truthiness of a layer object.
        if self.activation is not None:
            x = self.activation(x)
        return x

In [None]:
#Residual Dense Block
class ResidualDenseBlock(tf.keras.layers.Layer):
    """Residual Dense Block (RDB), the building block of ESRGAN's RRDB.

    Each conv receives the channel-wise concatenation of the block input and
    every previous conv output (dense connectivity).

    - The first ``num_layers - 1`` convs emit ``growth_channels`` channels
      with LeakyReLU.
    - The final "fusion" conv is linear (no activation, as in ESRGAN) and
      emits ``filters`` channels.
    - The fusion output is scaled by 0.2 and added to the block input, so the
      input must have ``filters`` channels for the residual add to be valid.

    Args:
        filters: Output channels of the fusion conv (and of the block).
        growth_channels: Output channels of each intermediate conv.
        num_layers: Total number of convs, including the fusion conv.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, filters=64, growth_channels=32, num_layers=5, **kwargs):
        super().__init__(**kwargs)
        if num_layers < 1:
            # Without at least the fusion conv, call() would have no output.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        blocks = [
            ConvBlock(growth_channels, activation="leakyrelu")
            for _ in range(num_layers - 1)
        ]
        # ESRGAN's last conv is linear; the activation is omitted before the
        # 0.2 residual scaling.
        blocks.append(ConvBlock(filters, activation=None))
        self.blocks = blocks

    def call(self, x):
        features = x
        for block in self.blocks[:-1]:
            features = tf.concat([features, block(features)], axis=-1)
        # The fusion conv's output is not concatenated again; it feeds the
        # residual directly, avoiding a wasted concat.
        out = self.blocks[-1](features)
        return x + 0.2 * out

In [3]:
#Residual-In-Residual Dense Block
class RRDB(layers.Layer):
    """Residual-in-Residual Dense Block.

    Three ResidualDenseBlocks are applied in sequence, each consuming the
    previous one's output. The chain's result is scaled by 0.2 and added to
    the block input (outer residual).

    Args:
        filters: Channel count of the block input/output.
        growth_channels: Growth channels inside each ResidualDenseBlock.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filters=64, growth_channels=32, **kwargs):
        super().__init__(**kwargs)
        self.rdb1 = ResidualDenseBlock(filters, growth_channels)
        self.rdb2 = ResidualDenseBlock(filters, growth_channels)
        self.rdb3 = ResidualDenseBlock(filters, growth_channels)

    def call(self, x):
        # Chain the RDBs. Feeding `x` to each one would discard rdb1/rdb2.
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return x + 0.2 * out

In [None]:
class UpsampleBlock(layers.Layer):
    """Sub-pixel (pixel-shuffle) upsampling by an integer factor.

    1. A linear conv expands the channels to ``filters * scale**2``.
    2. ``tf.nn.depth_to_space`` rearranges those channels into a
       ``scale``-times larger spatial grid with ``filters`` channels.
    3. LeakyReLU (slope 0.2) is applied.

    Args:
        filters: Output channel count after upsampling.
        scale: Spatial upscaling factor (block size of depth_to_space).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    # Spelled `__init__` (was `__init_`): with the typo, Layer.__init__
    # received (filters, scale=...) and self.conv / self.scale never existed.
    def __init__(self, filters, scale=2, **kwargs):
        super().__init__(**kwargs)
        self.conv = ConvBlock(filters * scale ** 2, kernel_size=3, activation=None)
        self.scale = scale

    def call(self, x):
        x = self.conv(x)
        # Rearrange channel blocks into spatial resolution.
        x = tf.nn.depth_to_space(x, block_size=self.scale)
        return tf.nn.leaky_relu(x, alpha=0.2)



In [4]:
#Generator
class Generator(layers.Layer):
    """ESRGAN-style super-resolution generator.

    Pipeline:
        1. Initial linear conv.
        2. ``num_rrdb`` chained RRDBs, then a linear trunk conv.
        3. Global residual add of the trunk onto the initial features.
        4. log2(``scale``) x2 pixel-shuffle upsampling stages.
        5. HR conv with LeakyReLU.
        6. Linear conv to 3 output channels.

    Args:
        num_rrdb: Number of RRDB blocks in the trunk.
        filters: Feature channels used throughout the network.
        growth_channels: Growth channels inside each residual dense block.
        scale: Overall upscaling factor. Must be a positive power of two
            (1, 2, 4, 8, ...) because upsampling is done in x2 stages.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``scale`` is not a positive power of two.
    """

    def __init__(self, num_rrdb=23, filters=64, growth_channels=32, scale=4, **kwargs):
        super().__init__(**kwargs)
        # Each UpsampleBlock doubles H and W, so the factor must be 2**k.
        # `scale // 2` stages was only correct for scale 2 and 4; e.g. 8 gave x16.
        if scale < 1 or scale & (scale - 1):
            raise ValueError(f"scale must be a positive power of two, got {scale}")
        num_upsample = scale.bit_length() - 1  # == log2(scale)

        self.conv_first = ConvBlock(filters, kernel_size=3, activation=None)

        # RRDB trunk
        self.rrdb_blocks = [RRDB(filters, growth_channels) for _ in range(num_rrdb)]
        self.trunk_conv = ConvBlock(filters, kernel_size=3, activation=None)

        # Upsampling: one x2 stage per factor of two in `scale`.
        self.upsample_blocks = [UpsampleBlock(filters, scale=2) for _ in range(num_upsample)]

        self.hr_conv = ConvBlock(filters, kernel_size=3, activation="leakyrelu")
        self.conv_last = ConvBlock(3, kernel_size=3, activation=None)  # 3-channel (RGB) output

    def call(self, x):
        features = self.conv_first(x)

        trunk = features
        for rrdb in self.rrdb_blocks:
            trunk = rrdb(trunk)
        trunk = self.trunk_conv(trunk)

        # Global residual connection around the whole trunk.
        features = features + trunk

        for up in self.upsample_blocks:
            features = up(features)

        features = self.hr_conv(features)
        return self.conv_last(features)



# Training Stage 1: generator only (PSNR-oriented pretraining)

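> Goal: pretrain the generator alone with a content (L1) loss and an optional perceptual loss. This way it produces clean, faithful reconstructions before any adversarial training is added.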



In [None]:
#1.


#Prepare paired data

#Sample an HR image.

#Randomly crop an HR patch of size hr_size (e.g., 512×512).

#Downscale it by the scale factor (e.g., ×4) to get LR patch lr_size = hr_size/scale (128×128).

#(Optional) Add real-world degradations (blur, noise, JPEG) to mimic old iPhone photos.

#Apply random flips/rotations consistently to LR/HR.

#Normalize both LR and HR to the same range (e.g., [-1,1] or [0,1]).


In [None]:
#2.

#Forward pass (Generator)

#Feed LR → Generator → get SR.

In [None]:
#3.

#Compute losses

#Content loss (L1): L1(SR, HR).

#Perceptual loss (optional but recommended):

#Pass SR and HR through a fixed VGG feature extractor.

#Compute L1/L2 in one or more feature layers (e.g., conv3_4, conv5_4).

#(Optional) Regularizers: TV loss, edge-aware loss, color consistency.

In [None]:
#4.

#Aggregate loss

#L_total = λc * L1 + λp * L_perc (+ λtv * L_tv …)

#Choose small λp at first; tune empirically.

In [None]:
#5.

#Backprop + update

#Compute gradients w.r.t. generator params.

#(Optional) Gradient clipping (e.g., global norm).

#Optimizer step (e.g., Adam).

#(Optional) EMA update of generator weights for more stable checkpoints.

In [None]:
#6.

#Log + schedule

#Track running loss; periodically compute PSNR/SSIM on a validation set.

#Apply LR schedule (e.g., cosine, step decay, or ReduceLROnPlateau).

In [None]:
#7.

#Checkpoint

#Save best generator by val metric (PSNR/SSIM) and latest checkpoint.

In [None]:
#8.

#Repeat for epochs

#Stop when content/perceptual losses plateau and visuals look clean (no GAN yet).

# Training Stage 2: adversarial fine-tuning (adding the discriminator)

In [None]:

'''

B) Adversarial Fine-Tuning (Full ESRGAN)

Goal: add realistic texture with a discriminator while preserving fidelity from phase A.

Initialize

Load the best pretrained Generator from Phase A.

Initialize Discriminator (patch-based).

Keep VGG feature extractor frozen.

Prepare paired data (same as A)

HR crop → LR via degradation pipeline; augment; normalize.

Update Discriminator (D)

Forward:

SR_detached = G(LR) (stop gradient).

D_real = D(HR), D_fake = D(SR_detached).

Loss (choose one):

LSGAN: (D_real - 1)^2 + (D_fake)^2
or

BCE: -log(σ(D_real)) - log(1 - σ(D_fake))
or

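Relativistic average GAN (RaGAN, as used in the ESRGAN paper):
- Let C(·) be the raw discriminator logits and D_Ra(a, b) = σ(C(a) - mean(C(b))).
- Loss: -log(D_Ra(HR, SR)) - log(1 - D_Ra(SR, HR)).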

(Optional) Regularization:

R1 gradient penalty on real images, or spectral norm in D.

Backprop + update D.

Update Generator (G)

Forward (fresh SR): SR = G(LR).

Loss components:

Content (L1): L1(SR, HR) (keeps identity/structure).

Perceptual: VGG feature loss between SR and HR (sharpness).

Adversarial (G-side):

For LSGAN: (D(SR) - 1)^2

For BCE: -log(σ(D(SR)))

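(Relativistic variant, if using it in D:
- Let D_Ra(a, b) = σ(C(a) - mean(C(b))), where C(·) are the raw discriminator logits.
- Generator loss: -log(1 - D_Ra(HR, SR)) - log(D_Ra(SR, HR)).)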

(Optional) Feature matching: L1 between intermediate D features for SR vs HR (stabilizes textures).

(Optional) TV/edge regularizers to control noise.

Aggregate:

L_G = λc*L1 + λp*L_perc + λadv*L_adv (+ λfm*L_featmatch + …)

Use small λadv to avoid over-texturing faces.

Backprop + update G.

(Optional) EMA update of G.

Stabilization tactics (each iteration or periodically)

Alternate k steps of D per 1 step of G (e.g., k=1).

Use mixed precision for speed; keep loss scaling safe.

Monitor D/G losses; if D becomes too strong, reduce k or apply dropout/augmentations in D.

Validation + early stopping

On a val set, compute PSNR/SSIM and do a visual panel review (textures vs artifacts).

Track identity preservation on faces (manually or with a face embedding distance if available).

Early stop when textures improve without identity drift.

Checkpointing

Save both G and D regularly.

Also save EMA-smoothed G—often best for inference.

Export

Freeze the final/EMA Generator.

(Optional) Run a calibration pass for INT8 post-training quantization using a few hundred LR samples.

Convert to TFLite (FP16 or INT8) for mobile deployment.
'''