
# Image Segmentation with U‑Net

In [None]:
# Environment & GPU check
import os, sys, math, glob, json, shutil, zipfile, random
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Reproducibility: one seed for every RNG the notebook uses.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

try:
    import tensorflow as tf
    print("TensorFlow:", tf.__version__)
    # Seed TF's global RNG as well; otherwise Keras weight initialization (and any
    # TF op without an explicit seed) differs between runs. GPU kernels may still
    # be nondeterministic.
    tf.random.set_seed(SEED)
    gpus = getattr(tf.config, "list_physical_devices", lambda *_: [])("GPU")
    print("GPUs:", gpus)
except Exception as e:
    print("TensorFlow not found (you can install TF 2.x).", e)


## 1) Dataset setup (TGS Salt)

You have two options:

**A. Kaggle API** (recommended)
1. Upload your `kaggle.json` (API token) to this runtime's working directory, or place it under `~/.kaggle/kaggle.json`.
2. Set `USE_KAGGLE = True` in the cell below, then run it to download and unzip the competition data.

**B. Local/Pre‑downloaded**
- If you already have the dataset, copy it into `data_tgs/` (the `BASE` folder created by the cell below). Alternatively, set `DATA_DIR` in section 2 to the root that contains `train/` and `test/`.
  `train/` should contain `images/` and `masks/` (the typical layout).

The expected structure after unzip (typical community mirrors):  
```
data_tgs/            (BASE)
  train/
    images/  (101x101 PNGs)
    masks/   (101x101 PNGs)
  test/
    images/
  depths.csv         (optional metadata)
```


In [3]:
# Choose where to place data
Path = "content/sample_data"
BASE = Path.cwd() / "data_tgs"
BASE.mkdir(parents=True, exist_ok=True)

USE_KAGGLE = False  # flip to True to use Kaggle API in this notebook

if USE_KAGGLE:
    # --- Kaggle download (requires kaggle.json) ---
    # In Colab: from google.colab import files; files.upload()  # upload kaggle.json first
    kaggle_dir = Path.home() / ".kaggle"
    kaggle_dir.mkdir(exist_ok=True)
    if Path("kaggle.json").exists():
        shutil.move("kaggle.json", kaggle_dir / "kaggle.json")
        os.chmod(kaggle_dir / "kaggle.json", 0o600)
    # Install kaggle CLI if needed
    try:
        import kaggle  # noqa: F401
    except:
        !pip -q install kaggle
    # This competition is archived; mirrors exist.
    # If the official download is unavailable, place your zips under BASE and skip this.
    print("Attempting kaggle download (may require acceptance):")
    !kaggle competitions download -c tgs-salt-identification-challenge -p "{BASE}"
    # Unzip files we care about (some mirrors name files differently)
    for z in BASE.glob("*.zip"):
        with zipfile.ZipFile(z, "r") as f:
            f.extractall(BASE)
else:
    print("Set USE_KAGGLE=True to download via Kaggle API, or copy your dataset under:", BASE)

AttributeError: 'str' object has no attribute 'cwd'


## 2) Load, preprocess, and visualize

- The original images are **101×101** grayscale.  
- We **zero‑pad to 128×128**, keeping the image centered. The U‑Net downsamples four times, so each side must be divisible by 16; a power of two keeps this simple.  
- Scale pixel values to `[0,1]`.  
- Split into **train/validation** (80/20, fixed seed).  
- Provide both a **tf.data** pipeline and a simple **NumPy** loader to keep it flexible.


In [None]:
from PIL import Image

def load_pairs(root):
    """Collect matching training (image, mask) file paths under ``root``.

    Every ``*.png`` in ``root/train/images`` is paired with the file of the same
    name in ``root/train/masks``; images without a mask are skipped.

    Returns:
        List of ``(image_path, mask_path)`` ``Path`` tuples, ordered by image path.

    Raises:
        FileNotFoundError: if either directory does not exist.
    """
    train_root = Path(root) / "train"
    images_dir = train_root / "images"
    masks_dir = train_root / "masks"
    if not (images_dir.exists() and masks_dir.exists()):
        raise FileNotFoundError(f"Expected {images_dir} and {masks_dir}")
    return [
        (image_path, masks_dir / image_path.name)
        for image_path in sorted(images_dir.glob("*.png"))
        if (masks_dir / image_path.name).exists()
    ]

def pad_to_square(img, size=128):
    """Zero-pad an image to ``(size, size)`` with the original content centered.

    Args:
        img: array of shape (H, W) or (H, W, ...) with H <= size and W <= size.
            Any trailing dimensions (e.g. channels) are preserved.
        size: target side length.

    Returns:
        Array of shape ``(size, size) + img.shape[2:]`` with ``img``'s dtype. The
        content starts at row ``(size - H) // 2`` and column ``(size - W) // 2``;
        with odd padding the extra row/column of zeros goes to the bottom/right.

    Raises:
        ValueError: if ``img`` has fewer than 2 dimensions or does not fit.
    """
    if img.ndim < 2:
        raise ValueError(f"expected an array with at least 2 dims, got shape {img.shape}")
    h, w = img.shape[:2]
    if h > size or w > size:
        # Without this check the offsets go negative and numpy fails with an
        # opaque broadcast error.
        raise ValueError(f"image of shape {img.shape} does not fit in {size}x{size}")
    out = np.zeros((size, size) + img.shape[2:], dtype=img.dtype)
    y0 = (size - h) // 2
    x0 = (size - w) // 2
    out[y0:y0 + h, x0:x0 + w] = img
    return out

def load_arrays(pairs, size=128):
    """Load (image, mask) PNG pairs into centered, zero-padded float32 arrays.

    Args:
        pairs: iterable of ``(image_path, mask_path)``, e.g. from ``load_pairs``.
        size: side length each image/mask is padded to via ``pad_to_square``.

    Returns:
        ``(X, Y)``, both float32 with shape (N, 1, size, size) (channels-first).
        X holds pixel values divided by 255 (8-bit input is guaranteed by the
        ``convert("L")`` below); Y is 1.0 where the mask pixel is > 0, else 0.0.
        An empty ``pairs`` yields arrays with N == 0 instead of raising.
    """
    X, Y = [], []
    for ip, mp in pairs:
        # `with` closes the files promptly (Image.open keeps the handle open
        # otherwise). convert("L") forces single-channel 8-bit data, matching
        # decode_png(channels=1) in the tf.data pipeline.
        with Image.open(ip) as im:
            img = np.array(im.convert("L"))
        with Image.open(mp) as im:
            msk = np.array(im.convert("L"))
        img = pad_to_square(img, size)
        msk = pad_to_square(msk, size)
        X.append(img[None, ...])    # add channel dim -> (1,H,W)
        Y.append((msk > 0).astype(np.float32)[None, ...])
    if not X:
        # np.stack raises on an empty list; return correctly-shaped empty arrays.
        empty = np.zeros((0, 1, size, size), dtype=np.float32)
        return empty, empty.copy()
    X = np.stack(X, axis=0).astype(np.float32) / 255.0  # (N,1,H,W)
    Y = np.stack(Y, axis=0).astype(np.float32)          # (N,1,H,W)
    return X, Y

DATA_DIR = BASE   # set to your data root if different
pairs = []
try:
    pairs = load_pairs(DATA_DIR)
    print("Found pairs:", len(pairs))
except OSError as e:
    # load_pairs raises FileNotFoundError until the dataset is in place; catching
    # only OS-level errors keeps genuine bugs (TypeError, NameError, ...) visible.
    print("Data not found yet:", e)

# quick peek
if pairs:
    Xnp, Ynp = load_arrays(pairs[:8], size=128)
    print("Batch shapes:", Xnp.shape, Ynp.shape)
    # visualize up to 8 samples; zip() stops early when fewer pairs exist
    # (indexing Xnp[i] blindly would raise IndexError on small datasets).
    fig, axes = plt.subplots(2, 4, figsize=(10,5))
    for ax in axes.ravel():
        ax.axis("off")
    for ax, img, msk in zip(axes.ravel(), Xnp[:, 0], Ynp[:, 0]):
        ax.imshow(img, cmap="gray")  # single-channel data: show as grayscale
        ax.set_title(f"mask sum={int(msk.sum())}")
    plt.show()

In [None]:
# Split
def train_val_split(pairs, val_ratio=0.2, seed=42):
    """Deterministically shuffle a copy of ``pairs`` and split it into train/val.

    The shuffle uses a private ``random.Random(seed)``, so the global RNG is not
    touched and the input list is not modified.

    Returns:
        ``(train, val)`` where ``val`` is the first ``int(len(pairs) * val_ratio)``
        items of the shuffled copy and ``train`` is the remainder.
    """
    shuffled = pairs.copy()
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_ratio)
    val_part = shuffled[:n_val]
    train_part = shuffled[n_val:]
    return train_part, val_part

# 80/20 train/validation split with a fixed seed; empty when no data is present yet.
if pairs:
    train_pairs, val_pairs = train_val_split(pairs, 0.2, 42)
else:
    train_pairs, val_pairs = [], []
print("Train/Val sizes:", len(train_pairs), len(val_pairs))

# tf.data helpers (optional; works when TF is available)
def make_tf_dataset(pairs, size=128, batch=16, shuffle=True):
    """Build a batched tf.data pipeline of ``(image, mask)`` from PNG path pairs.

    Args:
        pairs: sequence of ``(image_path, mask_path)``.
        size: output side length; content is zero-padded and centered at offset
            ``(size - H) // 2``, the same placement as ``pad_to_square``. Inputs
            larger than ``size`` are center-cropped instead of raising.
        batch: batch size.
        shuffle: shuffle with a full-size buffer (seed 42, reshuffled each epoch).

    Returns:
        ``tf.data.Dataset`` of float32 batches ``(B, size, size, 1)`` (channels-last):
        images scaled to [0, 1], masks binarized to 1.0 where the pixel is > 0.
    """
    import tensorflow as tf
    ipaths = [str(p[0]) for p in pairs]
    mpaths = [str(p[1]) for p in pairs]
    # Explicit string dtype so an empty list does not default to float32.
    ds = tf.data.Dataset.from_tensor_slices((
        tf.constant(ipaths, dtype=tf.string),
        tf.constant(mpaths, dtype=tf.string),
    ))
    if shuffle:
        # buffer_size must be >= 1; a full-size buffer gives a uniform permutation.
        ds = ds.shuffle(max(len(ipaths), 1), seed=42, reshuffle_each_iteration=True)

    def _load(ip, mp):
        img = tf.io.decode_png(tf.io.read_file(ip), channels=1)   # (H,W,1) uint8
        img = tf.image.convert_image_dtype(img, tf.float32)        # [0,1]
        msk = tf.io.decode_png(tf.io.read_file(mp), channels=1)
        msk = tf.cast(msk > 0, tf.float32)
        # Native centered padding instead of tf.numpy_function: stays in the graph,
        # doesn't hold the GIL in parallel map calls, and keeps static shapes.
        img = tf.image.resize_with_crop_or_pad(img, size, size)
        msk = tf.image.resize_with_crop_or_pad(msk, size, size)
        img = tf.ensure_shape(img, [size, size, 1])  # (H,W,1) for tf.keras
        msk = tf.ensure_shape(msk, [size, size, 1])
        return img, msk

    ds = ds.map(_load, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch).prefetch(tf.data.AUTOTUNE)
    return ds

# Datasets stay None until the corresponding split has data.
tf_train = None
tf_val = None
if train_pairs:
    tf_train = make_tf_dataset(train_pairs, size=128, batch=16)
if val_pairs:
    tf_val = make_tf_dataset(val_pairs, size=128, batch=16, shuffle=False)


## 3) Loss & metrics

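We train with **binary cross‑entropy** and monitor **Dice** and **IoU**. Both metrics are *soft*: they are computed per image on the predicted probabilities (no 0.5 threshold) and then averaged over the batch.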


In [None]:
import tensorflow as tf

def dice_coef(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient, computed per sample and averaged over the batch.

    Sums run over axes 1-3 (per-sample for rank-4 ``(B, H, W, C)`` tensors) on the
    raw predictions, without thresholding:
        dice = (2 * sum(t * p) + smooth) / (sum(t + p) + smooth)
    ``smooth`` prevents 0/0 and gives 1.0 when both masks are empty.
    """
    truth = tf.cast(y_true, tf.float32)
    prob = tf.cast(y_pred, tf.float32)
    per_sample_axes = [1, 2, 3]
    overlap = tf.reduce_sum(truth * prob, axis=per_sample_axes)
    total = tf.reduce_sum(truth + prob, axis=per_sample_axes)
    return tf.reduce_mean((2.0 * overlap + smooth) / (total + smooth))

def iou_coef(y_true, y_pred, smooth=1e-6):
    """Soft IoU (Jaccard index), computed per sample and averaged over the batch.

    Sums run over axes 1-3 (per-sample for rank-4 ``(B, H, W, C)`` tensors) on the
    raw predictions, without thresholding:
        iou = (sum(t * p) + smooth) / (sum(t + p - t * p) + smooth)
    ``smooth`` prevents 0/0 and gives 1.0 when both masks are empty.
    """
    truth = tf.cast(y_true, tf.float32)
    prob = tf.cast(y_pred, tf.float32)
    per_sample_axes = [1, 2, 3]
    both = truth * prob
    overlap = tf.reduce_sum(both, axis=per_sample_axes)
    union = tf.reduce_sum(truth + prob - both, axis=per_sample_axes)
    return tf.reduce_mean((overlap + smooth) / (union + smooth))


## 4) Model


In [None]:
USE_ZHIXUHAO = True  # set False to force local model

def get_unet(input_shape=(128,128,1), base_filters=32, depth=4):
    """Build a compact U-Net (local fallback if the public repo isn't available).

    Encoder: ``depth`` levels of two 3x3 ReLU convs followed by 2x2 max-pooling,
    with ``base_filters * 2**level`` filters; a bottleneck with
    ``base_filters * 2**depth`` filters; a mirrored decoder of 2x upsampling,
    skip concatenation and two convs; and a 1x1 sigmoid conv producing one
    probability channel. The defaults reproduce the original 32-64-128-256 / 512
    architecture, with layers created in the same order.

    Args:
        input_shape: ``(H, W, C)``. Known spatial dims must be divisible by
            ``2**depth`` so upsampled maps line up with their skips; ``None`` is allowed.
        base_filters: filters at the first encoder level.
        depth: number of pooling (and upsampling) steps.

    Returns:
        A ``tf.keras`` Model named ``"UNet_local"``.

    Raises:
        ValueError: if a known spatial dim is not divisible by ``2**depth``
            (otherwise the failure surfaces later as a Concatenate shape mismatch).
    """
    from tensorflow.keras import layers, models

    factor = 2 ** depth
    for dim in input_shape[:2]:
        if dim is not None and dim % factor:
            raise ValueError(
                f"input_shape {input_shape}: spatial dims must be divisible by {factor}"
            )

    def conv_block(x, f):
        x = layers.Conv2D(f, 3, padding="same", activation="relu")(x)
        x = layers.Conv2D(f, 3, padding="same", activation="relu")(x)
        return x

    inputs = layers.Input(shape=input_shape)

    skips = []
    x = inputs
    for level in range(depth):
        x = conv_block(x, base_filters * 2 ** level)
        skips.append(x)
        x = layers.MaxPooling2D(2)(x)

    x = conv_block(x, base_filters * 2 ** depth)  # bottleneck

    for level in reversed(range(depth)):
        x = layers.UpSampling2D()(x)
        x = layers.Concatenate()([x, skips[level]])
        x = conv_block(x, base_filters * 2 ** level)

    outputs = layers.Conv2D(1, 1, activation="sigmoid")(x)
    return models.Model(inputs, outputs, name="UNet_local")

model = None
if USE_ZHIXUHAO:
    try:
        # Clone/import the public zhixuhao/unet repo.
        # NOTE(review): this downloads and executes third-party code; pin a commit
        # if reproducibility or supply-chain safety matters.
        import importlib.util
        import subprocess
        repo_dir = Path("unet_repo")
        if not repo_dir.exists():
            # check=True raises on failure (no network, git missing), so we fall
            # back right away instead of failing later with a confusing import error.
            subprocess.run(
                ["git", "clone", "--depth", "1",
                 "https://github.com/zhixuhao/unet.git", str(repo_dir)],
                check=True,
            )
        # Keep the repo importable for its own imports, without duplicating the
        # entry each time the cell is re-run.
        if str(repo_dir) not in sys.path:
            sys.path.insert(0, str(repo_dir))
        # Load model.py by explicit path under a unique module name. A plain
        # `from model import unet` can bind a different module called `model`
        # (a stale sys.modules entry or another model.py earlier on sys.path).
        spec = importlib.util.spec_from_file_location("zhixuhao_unet_model", repo_dir / "model.py")
        repo_model = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(repo_model)
        model = repo_model.unet(input_size=(128,128,1))
        print("Using zhixuhao/unet model")
    except Exception as e:
        print("Falling back to local U-Net:", e)
        model = get_unet()
else:
    model = get_unet()

model.summary()


## 5) Train

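We compile with `binary_crossentropy` and monitor Dice/IoU.  
Early stopping (patience 10) and model checkpointing both track validation IoU and keep the best weights. The learning rate is halved when validation loss stops improving for 5 epochs.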


In [None]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

model.compile(
    optimizer=Adam(1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy", dice_coef, iou_coef],
)

ckpt_path = "unet_tgs_best.h5"
# Early stopping and checkpointing both keep the epoch with the best validation IoU
# (higher is better); the learning rate is halved when validation loss plateaus.
iou_monitor = "val_iou_coef"
callbacks = [
    EarlyStopping(monitor=iou_monitor, patience=10, mode="max", restore_best_weights=True),
    ModelCheckpoint(ckpt_path, monitor=iou_monitor, mode="max", save_best_only=True, verbose=1),
    ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, verbose=1),
]

EPOCHS = 40
if not train_pairs:
    print("No data found yet. Once data is available, re-run this cell to train.")
else:
    history = model.fit(
        tf_train,
        validation_data=tf_val,
        epochs=EPOCHS,
        callbacks=callbacks,
    )


## 6) Learning curves


In [None]:
# Plot train/val loss and IoU per epoch (only once `history` exists from training).
if 'history' not in locals():
    print("Train first to see curves.")
else:
    h = history.history
    fig, ax = plt.subplots(1,2, figsize=(12,4))
    curve_panels = (
        (ax[0], h["loss"], h.get("val_loss", []), "Loss"),
        (ax[1], h.get("iou_coef", []), h.get("val_iou_coef", []), "IoU"),
    )
    for panel, train_curve, val_curve, title in curve_panels:
        panel.plot(train_curve)
        panel.plot(val_curve)
        panel.set_title(title)
        panel.set_xlabel("epoch")
        panel.legend(["train","val"])
    plt.show()


## 7) Inference & qualitative results


In [None]:
def predict_and_plot(n=6, th=0.5):
    """Show image / ground-truth mask / thresholded prediction for random val samples.

    Reads the notebook globals ``val_pairs`` and ``model``. Draws up to ``n``
    samples without replacement using the global ``random`` module, so the
    selection depends on the global RNG state.

    Args:
        n: maximum number of samples (figure columns).
        th: threshold applied to the model's sigmoid output.
    """
    if not val_pairs:
        print("No validation data to show.")
        return
    k = min(n, len(val_pairs))
    if k < 1:
        print("Nothing to show for n =", n)
        return
    samp = random.sample(val_pairs, k)
    Xb, Yb = load_arrays(samp, size=128)
    # load_arrays returns channels-first (N,1,H,W); the Keras model takes (N,H,W,1).
    preds = model.predict(np.transpose(Xb, (0,2,3,1)))  # (N,128,128,1)
    preds = preds[...,0]
    # squeeze=False keeps `axes` 2-D even with a single column; otherwise
    # axes[row, i] raises IndexError when only one sample is available.
    fig, axes = plt.subplots(3, k, figsize=(3*k, 8), squeeze=False)
    for i, (ip, _mp) in enumerate(samp):
        col = axes[:, i]
        col[0].imshow(Xb[i, 0], cmap="gray")
        col[0].set_title(Path(ip).name)
        # Fixed 0..1 range so an all-empty mask renders black instead of autoscaled.
        col[1].imshow(Yb[i, 0], cmap="gray", vmin=0, vmax=1)
        col[1].set_title("GT mask")
        col[2].imshow((preds[i] > th).astype(np.float32), cmap="gray", vmin=0, vmax=1)
        col[2].set_title("Pred>th")
        for a in col:
            a.axis("off")
    plt.show()

# Demo predictions on validation samples (needs data; meaningful after training).
if not pairs:
    print("Add data first to demo predictions.")
else:
    try:
        predict_and_plot(6, th=0.5)
    except Exception as e:
        print("Run training first. Error:", e)

## Appendix - Utilities

- Lightweight output‑size helper for conv/pool strides (handy when changing input sizes)
- Simple RLE encoder/decoder stubs if you want to submit to Kaggle (optional; fill as needed)

In [None]:
def conv2d_out(H, W, kH=3, kW=3, stride=1, pad=0, dilation=1):
    """Spatial output size of a 2-D convolution or pooling window.

    Standard floor formula, per axis:
        out = (in + 2*pad - dilation*(k - 1) - 1) // stride + 1
    which reduces to ``(in + 2*pad - k) // stride + 1`` for ``dilation=1``.

    Args:
        H, W: input height and width.
        kH, kW: kernel height and width.
        stride: stride used on both axes (>= 1).
        pad: zero padding added on each side of both axes.
        dilation: kernel dilation (1 = ordinary, dense kernel).

    Returns:
        ``(Hout, Wout)`` as ints.

    Raises:
        ValueError: if ``stride`` or ``dilation`` is < 1, or the (dilated) kernel
            does not fit in the padded input (output size would be < 1).
    """
    if stride < 1 or dilation < 1:
        raise ValueError(f"stride and dilation must be >= 1 (got {stride}, {dilation})")

    def _axis(size, k):
        out = (size + 2*pad - dilation*(k - 1) - 1)//stride + 1
        if out < 1:
            raise ValueError(f"kernel {k} (dilation {dilation}) does not fit input {size} with pad {pad}")
        return out

    return _axis(H, kH), _axis(W, kW)

print("Example conv2d_out(128,128, k=3, s=1, p=0) ->", conv2d_out(128,128))