
# Transfer Learning with ResNet & VGG Encoders (Keras)

Objectives:

- **Problem 1 (Code review)**: Show what changes vs. previous U-Net, and how transfer learning is used.
- **Problem 2 (Rewrite)**: Provide both **ResNet** and **VGG** encoder variants of U-Net.
- **Problem 3 (Train & compare)**: Train/validate both and report metrics + qualitative results.

**Versions targeted:** TensorFlow 1.14.x + Keras 2.3.x (works with TF 2.x via `tf.keras` fallback too).

In [None]:
# Environment checks (works with Keras 2.3.x or tf.keras)
import os, sys, random, math, gc, time, json, glob, pathlib, shutil
import numpy as np
import PIL.Image as Image
import tensorflow as tf

SEED = 42
np.random.seed(SEED)
random.seed(SEED)
try:
    import keras
    from keras import backend as K
    from keras.layers import (Input, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D,
                              BatchNormalization, Activation, Dropout, Concatenate)
    from keras.models import Model
    from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, CSVLogger
    from keras.optimizers import Adam
    from keras.preprocessing.image import ImageDataGenerator
    from keras.applications import ResNet50, VGG16
    USING_TF_KERAS = False
except Exception as e:
    # Fallback to tf.keras
    from tensorflow import keras
    from tensorflow.keras import backend as K
    from tensorflow.keras.layers import (Input, Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D,
                              BatchNormalization, Activation, Dropout, Concatenate)
    from tensorflow.keras.models import Model
    from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, CSVLogger
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    from tensorflow.keras.applications import ResNet50, VGG16
    USING_TF_KERAS = True

print("Keras backend:", "tf.keras" if USING_TF_KERAS else "keras", " — TF version:", tf.__version__)

# GPU memory growth (optional)
try:
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for g in gpus:
        tf.config.experimental.set_memory_growth(g, True)
    print("GPUs:", gpus)
except Exception as _:
    pass


## Configuration
- `DATA_DIR` should contain `train/images` and `train/masks` folders with 101×101 PNGs (TGS Salt).
- Images are padded to 128×128 and repeated to 3 channels for ImageNet backbones.


In [1]:
# Paths & hyperparameters
DATA_DIR = os.environ.get("TGS_DIR", "/content/tgs")  # change if needed
IMG_SIZE = 128
BATCH_SIZE = 16
EPOCHS = 15
VAL_SPLIT = 0.2
FREEZE_ENCODER = True  # freeze backbone at start

# For quick smoke runs if dataset is missing:
FALLBACK_SYNTHETIC = True

NameError: name 'os' is not defined


## Data loading utilities
- Loads image/mask PNGs.
- Pads to 128×128 (constant=0).
- Converts grayscale to 3-channels by repeat for encoders.
- Simple on-the-fly augmentation (flips) applied **consistently** to image/mask.


In [None]:
from sklearn.model_selection import train_test_split

def _pad_to_square(img_np, target=128, constant=0):
    h, w = img_np.shape[:2]
    pad_h = target - h
    pad_w = target - w
    assert pad_h >= 0 and pad_w >= 0
    if img_np.ndim == 2:
        img_np = np.pad(img_np, ((0,pad_h),(0,pad_w)), mode='constant', constant_values=constant)
    else:
        img_np = np.pad(img_np, ((0,pad_h),(0,pad_w),(0,0)), mode='constant', constant_values=constant)
    return img_np

def load_tgs_pair(img_path, mask_path):
    img = Image.open(img_path).convert("L")
    msk = Image.open(mask_path).convert("L")
    img = np.array(img, dtype=np.uint8)
    msk = np.array(msk, dtype=np.uint8)
    # pad to 128×128
    img = _pad_to_square(img, IMG_SIZE, 0)
    msk = _pad_to_square(msk, IMG_SIZE, 0)
    # scale to [0,1]
    img = img.astype(np.float32) / 255.0
    msk = (msk > 127).astype(np.float32)  # binary
    # 3-channel for encoders
    img3 = np.repeat(img[...,None], 3, axis=-1)
    return img3, msk[...,None]

def load_tgs_dataset(data_dir):
    img_dir = os.path.join(data_dir, "train", "images")
    mask_dir = os.path.join(data_dir, "train", "masks")
    img_paths = sorted(glob.glob(os.path.join(img_dir, "*.png")))
    mask_paths = [os.path.join(mask_dir, os.path.basename(p)) for p in img_paths]
    X, Y = [], []
    for ip, mp in zip(img_paths, mask_paths):
        if not os.path.exists(mp):
            continue
        x, y = load_tgs_pair(ip, mp)
        X.append(x); Y.append(y)
    X = np.stack(X, 0); Y = np.stack(Y, 0)
    return X, Y

def get_data_or_synthetic():
    if os.path.exists(os.path.join(DATA_DIR, "train", "images")):
        X, Y = load_tgs_dataset(DATA_DIR)
        print("Loaded TGS:", X.shape, Y.shape)
    else:
        if not FALLBACK_SYNTHETIC:
            raise FileNotFoundError("TGS dataset not found. Set DATA_DIR correctly.")
        # Tiny synthetic set (blobs)
        n = 64
        X = np.zeros((n, IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
        Y = np.zeros((n, IMG_SIZE, IMG_SIZE, 1), dtype=np.float32)
        for i in range(n):
            cx, cy = np.random.randint(20, IMG_SIZE-20, size=2)
            rr = np.random.randint(8, 18)
            yy, xx = np.ogrid[:IMG_SIZE, :IMG_SIZE]
            mask = (xx-cx)**2 + (yy-cy)**2 <= rr**2
            Y[i, mask, 0] = 1.0
            X[i,...] = np.repeat((Y[i,...]*0.7 + np.random.rand(IMG_SIZE,IMG_SIZE,1)*0.3), 3, axis=-1)
        print("Using synthetic data:", X.shape, Y.shape)
    X_tr, X_va, Y_tr, Y_va = train_test_split(X, Y, test_size=VAL_SPLIT, random_state=SEED)
    return X_tr, Y_tr, X_va, Y_va

X_tr, Y_tr, X_va, Y_va = get_data_or_synthetic()


## Losses & metrics
- **Dice coefficient** & **IoU** for monitoring.
- Combined **BCE + Dice loss** often works well for salt segmentation.


In [None]:
def dice_coef(y_true, y_pred, smooth=1.0):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def iou_coef(y_true, y_pred, smooth=1.0):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)

def bce_dice_loss(y_true, y_pred):
    return keras.losses.binary_crossentropy(y_true, y_pred) + (1.0 - dice_coef(y_true, y_pred))


## Model builders
We compose a U-Net-style decoder on top of either **ResNet50** or **VGG16** backbones (`include_top=False`, `weights='imagenet'`).  
Encoder can be frozen for initial epochs, then optionally unfrozen.


In [None]:
def conv_block(x, filters, k=3, act='relu', bn=True, dp=0.0):
    x = Conv2D(filters, k, padding='same')(x)
    if bn: x = BatchNormalization()(x)
    x = Activation(act)(x)
    if dp>0: x = Dropout(dp)(x)
    x = Conv2D(filters, k, padding='same')(x)
    if bn: x = BatchNormalization()(x)
    x = Activation(act)(x)
    return x

def up_block(x, skip, filters):
    x = UpSampling2D((2,2))(x)
    x = Concatenate()([x, skip])
    x = conv_block(x, filters)
    return x

def build_unet_resnet50(input_shape=(128,128,3), freeze=True):
    inputs = Input(shape=input_shape)
    # ResNet50 backbone
    backbone = ResNet50(include_top=False, weights='imagenet', input_tensor=inputs)
    # Feature maps for skips
    skips = {
        "c1": backbone.get_layer("conv1_relu").output,          # 64, /2
        "c2": backbone.get_layer("conv2_block3_out").output,    # 256, /4
        "c3": backbone.get_layer("conv3_block4_out").output,    # 512, /8
        "c4": backbone.get_layer("conv4_block6_out").output,    # 1024, /16
    }
    x = backbone.get_layer("conv5_block3_out").output           # 2048, /32
    # Decoder
    x = up_block(x, skips["c4"], 512)
    x = up_block(x, skips["c3"], 256)
    x = up_block(x, skips["c2"], 128)
    x = up_block(x, skips["c1"], 64)
    x = UpSampling2D((2,2))(x)  # back to /1
    outputs = Conv2D(1, 1, activation="sigmoid")(x)
    model = Model(inputs, outputs, name="UNet_ResNet50")
    if freeze:
        for l in backbone.layers:
            l.trainable = False
    return model

def build_unet_vgg16(input_shape=(128,128,3), freeze=True):
    inputs = Input(shape=input_shape)
    backbone = VGG16(include_top=False, weights='imagenet', input_tensor=inputs)
    # VGG blocks
    b1 = backbone.get_layer("block1_conv2").output   # /2, 64
    b2 = backbone.get_layer("block2_conv2").output   # /4, 128
    b3 = backbone.get_layer("block3_conv3").output   # /8, 256
    b4 = backbone.get_layer("block4_conv3").output   # /16, 512
    b5 = backbone.get_layer("block5_conv3").output   # /32, 512
    # Decoder
    x = up_block(b5, b4, 512)
    x = up_block(x, b3, 256)
    x = up_block(x, b2, 128)
    x = up_block(x, b1, 64)
    x = UpSampling2D((2,2))(x)  # back to /1 from /2
    outputs = Conv2D(1, 1, activation="sigmoid")(x)
    model = Model(inputs, outputs, name="UNet_VGG16")
    if freeze:
        for l in backbone.layers:
            l.trainable = False
    return model


## Training helpers
We compile with **Adam**, `bce_dice_loss`, and track **Dice** & **IoU**.  
Two runs can be performed back-to-back: ResNet50-encoder and VGG16-encoder.


In [None]:
def compile_model(m, lr=1e-3):
    m.compile(optimizer=Adam(lr),
              loss=bce_dice_loss,
              metrics=[dice_coef, iou_coef])
    return m

def train_model(model, X_tr, Y_tr, X_va, Y_va, tag, epochs=EPOCHS):
    ckpt_path = f"/mnt/data/tgs_{tag}_best.h5"
    callbacks = [
        ModelCheckpoint(ckpt_path, monitor="val_dice_coef", mode="max",
                        save_best_only=True, save_weights_only=True, verbose=1),
        ReduceLROnPlateau(monitor="val_dice_coef", factor=0.5, patience=3,
                          mode="max", verbose=1, min_lr=1e-6),
        EarlyStopping(monitor="val_dice_coef", patience=6, mode="max",
                      restore_best_weights=True, verbose=1),
        CSVLogger(f"/mnt/data/tgs_{tag}_log.csv", append=False)
    ]
    hist = model.fit(X_tr, Y_tr, validation_data=(X_va, Y_va),
                     epochs=epochs, batch_size=BATCH_SIZE, callbacks=callbacks, verbose=1)
    # Load best
    if os.path.exists(ckpt_path):
        model.load_weights(ckpt_path)
    val_metrics = model.evaluate(X_va, Y_va, verbose=0)
    print(f"[{tag}] Val metrics: " + ", ".join(f"{n}={v:.4f}" for n,v in zip(model.metrics_names, val_metrics)))
    return model, hist


## Run: ResNet50 encoder
Freeze backbone for stability, then you can unfreeze and fine-tune for a few epochs.


In [None]:
resnet_model = compile_model(build_unet_resnet50((IMG_SIZE,IMG_SIZE,3), freeze=FREEZE_ENCODER), lr=1e-3)
resnet_model, resnet_hist = train_model(resnet_model, X_tr, Y_tr, X_va, Y_va, tag="resnet50", epochs=EPOCHS)

# Optional fine-tuning (unfreeze):
for l in resnet_model.layers:
    l.trainable = True
resnet_model = compile_model(resnet_model, lr=1e-4)
resnet_model, resnet_hist2 = train_model(resnet_model, X_tr, Y_tr, X_va, Y_va, tag="resnet50_ft", epochs=5)


## Run: VGG16 encoder
Same procedure for VGG16.


In [None]:
vgg_model = compile_model(build_unet_vgg16((IMG_SIZE,IMG_SIZE,3), freeze=FREEZE_ENCODER), lr=1e-3)
vgg_model, vgg_hist = train_model(vgg_model, X_tr, Y_tr, X_va, Y_va, tag="vgg16", epochs=EPOCHS)

# Optional fine-tuning
for l in vgg_model.layers:
    l.trainable = True
vgg_model = compile_model(vgg_model, lr=1e-4)
vgg_model, vgg_hist2 = train_model(vgg_model, X_tr, Y_tr, X_va, Y_va, tag="vgg16_ft", epochs=5)


## Learning curves


In [None]:
import matplotlib.pyplot as plt
def plot_history(hist, label):
    h = hist.history
    plt.figure(figsize=(12,4))
    plt.subplot(1,2,1)
    plt.plot(h['loss'], label='train')
    plt.plot(h['val_loss'], label='val')
    plt.title(f'Loss — {label}')
    plt.legend()
    plt.subplot(1,2,2)
    key = 'dice_coef' if 'dice_coef' in h else list(h.keys())[1]
    plt.plot(h[key], label='train')
    plt.plot(h['val_'+key], label='val')
    plt.title(f'Dice — {label}')
    plt.legend()
    plt.show()

plot_history(resnet_hist, "ResNet50")
plot_history(vgg_hist, "VGG16")


## Qualitative results


In [None]:
def visualize_preds(model, X, Y, n=6):
    idx = np.random.choice(len(X), n, replace=False)
    preds = model.predict(X[idx], batch_size=n)
    plt.figure(figsize=(12, 2*n))
    for i, k in enumerate(idx):
        plt.subplot(n,3,3*i+1); plt.imshow(X[k].squeeze(), cmap='gray'); plt.title("Image"); plt.axis('off')
        plt.subplot(n,3,3*i+2); plt.imshow(Y[k].squeeze(), cmap='gray'); plt.title("Mask"); plt.axis('off')
        plt.subplot(n,3,3*i+3); plt.imshow((preds[i].squeeze()>0.5).astype(np.float32), cmap='gray'); plt.title("Pred"); plt.axis('off')
    plt.tight_layout(); plt.show()

print("ResNet50 samples:"); visualize_preds(resnet_model, X_va, Y_va, n=4)
print("VGG16 samples:"); visualize_preds(vgg_model, X_va, Y_va, n=4)