**WM811K - Autoencoder → Classifier**
(using C++-processed data: wafer + dist_to_center + bad_mask)

What this script does:
1) Builds datasets from the C++ outputs:
   - processed_images/<label>/<file>.png
   - processed_aux/dist/<label>/<file>.png
   - processed_masks/bad/<label>/<file>.png
   - extracted_images/metadata.csv (location set by `META_CSV` in the Config cell)
2) Pretrains an autoencoder on ALL images (labeled + unlabeled)
3) Reuses the encoder for a classifier (9 classes), trained on labeled data
4) Evaluates the classifier on the labeled test split with macro-F1 and saves both models as SavedModels.

In [None]:
import os
import re
from pathlib import Path

import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

In [None]:
print("TensorFlow:", tf.__version__)
# Report whether TensorFlow can see a GPU. DEVICE is informational: in this
# notebook it is only printed, not used for explicit op placement.
if tf.config.list_physical_devices('GPU'):
    DEVICE = 'GPU'
else:
    DEVICE = 'CPU'
print("Device:", DEVICE)

# Config

In [None]:
# --- Image geometry and training budget --------------------------------------
IMG_SIZE = 64        # side length in pixels; must match the C++ preprocessor
BATCH_AE = 512       # autoencoder pretraining batch size
BATCH_CLF = 512      # classifier batch size
EPOCHS_AE = 3        # autoencoder pretraining epochs
EPOCHS_CLF = 3       # classifier training epochs
LIMIT = None         # e.g. 20000 for a quick run; None uses all samples

# --- Input locations (relative to the working directory) ---------------------
ROOT_PROC = Path('processed_images')      # wafer channel images
ROOT_DIST = Path('processed_aux/dist')    # distance-to-center channel images
ROOT_BAD = Path('processed_masks/bad')    # bad-mask channel images
META_CSV = Path('extracted_images/metadata.csv')

# --- Label space: list position is the integer class id ----------------------
CLASSES = [
    'Center', 'Donut', 'Edge-Loc', 'Edge-Ring', 'Loc',
    'Random', 'Scratch', 'Near-Full', 'None',
]
CLS2ID = dict(zip(CLASSES, range(len(CLASSES))))
N_CLASSES = len(CLASSES)

# Label normalization

In [None]:
# Canonical spelling for every accepted label, keyed by the normalized
# (lower-case, '-'-separated) token produced in canon_label.
_LABEL_TABLE = {
    'center': 'Center', 'donut': 'Donut', 'edge-loc': 'Edge-Loc', 'edge-ring': 'Edge-Ring',
    'loc': 'Loc', 'random': 'Random', 'scratch': 'Scratch', 'near-full': 'Near-Full',
    'none': 'None', 'unlabeled': 'Unlabeled',
}
# Any run of whitespace, underscores or hyphens collapses to a single '-'.
_LABEL_SEP_RE = re.compile(r'[\s_-]+')


def canon_label(s):
    """Map a raw label value to one of CLASSES or 'Unlabeled'.

    Matching is case-insensitive and tolerant of formatting noise:
    surrounding whitespace, wrapping brackets/quotes (e.g. the nested-list
    text ``"[['Edge-Loc']]"``), and any mix of spaces, '_' or '-' between
    words ("Edge_Loc", "edge loc", "_edge__loc_" all -> 'Edge-Loc').

    Anything unrecognised maps to 'Unlabeled'. This includes None, empty
    strings and NaN (whose text form is 'nan').
    """
    if s is None:
        return 'Unlabeled'
    key = str(s).strip().strip("[]'\" ")
    key = _LABEL_SEP_RE.sub('-', key).strip('-').lower()
    return _LABEL_TABLE.get(key, 'Unlabeled')


def is_labeled(lbl):
    """Return True if *lbl* is one of the trainable classes (a key of CLS2ID)."""
    return lbl in CLS2ID


# Build dataframe of triplets (wafer, dist, bad)

In [None]:
# Load and validate the metadata index. Raise instead of `assert` so the check
# survives `python -O`, and report the path that is actually being read.
if not META_CSV.exists():
    raise FileNotFoundError(f"metadata.csv not found at {META_CSV}")
meta = pd.read_csv(META_CSV)

for col in ['file','label','split']:
    if col not in meta.columns:
        raise ValueError(f"metadata.csv missing required column: {col}")

meta['label'] = meta['label'].apply(canon_label)

def triplet_paths(rel):
    """Map a metadata 'file' entry (a relative path) to its (wafer, dist, bad) image paths."""
    p = Path(rel)
    return (ROOT_PROC/p, ROOT_DIST/p, ROOT_BAD/p)

# One row per sample: the three channel image paths plus label and split.
rows = [
    [*(str(p) for p in triplet_paths(rel)), lbl, split]
    for rel, lbl, split in zip(meta['file'], meta['label'], meta['split'])
]
df = pd.DataFrame(rows, columns=['wafer','dist','bad','label','split'])

# Keep only samples whose three channel images all exist on disk.
mask = pd.Series(
    [all(os.path.exists(p) for p in trip) for trip in zip(df['wafer'], df['dist'], df['bad'])],
    index=df.index,
    dtype=bool,
)
df = df[mask].reset_index(drop=True)

labeled = df['label'].apply(is_labeled).astype(bool)
df_l = df[labeled].reset_index(drop=True)
df_u = df[~labeled].reset_index(drop=True)

split_norm = df_l['split'].astype(str).str.strip().str.lower()
train_df = df_l[split_norm.str.startswith('train')].reset_index(drop=True)
test_df  = df_l[split_norm.str.startswith('test')].reset_index(drop=True)

if len(train_df) == 0 or len(test_df) == 0:
    if len(df_l) == 0:
        print("WARNING: No labeled samples found. Classifier will be skipped.")
        train_df = pd.DataFrame(columns=df_l.columns)
        test_df  = pd.DataFrame(columns=df_l.columns)
    else:
        print("INFO: No usable train/test in metadata. Creating a stratified 80/20 split.")
        try:
            train_df, test_df = train_test_split(
                df_l, test_size=0.2, random_state=42, stratify=df_l['label']
            )
        except ValueError as err:
            # Stratification fails e.g. when a class has fewer than 2 samples;
            # fall back to a plain random split rather than aborting.
            print(f"WARNING: stratified split failed ({err}); using a random 80/20 split.")
            train_df, test_df = train_test_split(df_l, test_size=0.2, random_state=42)
        train_df = train_df.reset_index(drop=True)
        test_df  = test_df.reset_index(drop=True)

if LIMIT is not None:
    # Seeded random subsets instead of head slices: if metadata.csv is ordered
    # (e.g. grouped by label), `iloc[:LIMIT]` would silently drop whole classes.
    train_df = train_df.sample(n=min(LIMIT, len(train_df)), random_state=42).reset_index(drop=True)
    test_df  = test_df.sample(n=min(LIMIT, len(test_df)), random_state=42).reset_index(drop=True)

print(f"Triplets: {len(df)} | Train labeled: {len(train_df)} | Test labeled: {len(test_df)} | Unlabeled: {len(df_u)}")

# Image readers and loaders

In [None]:
def read_gray64(path):
    """Read *path* as a single-channel 8-bit image of size IMG_SIZE x IMG_SIZE.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread -> None).

    Images of any other size are resized with nearest-neighbour interpolation,
    which introduces no new intermediate grey values.
    """
    image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(path)
    if image.shape[:2] == (IMG_SIZE, IMG_SIZE):
        return image
    return cv2.resize(image, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

def load_triplet_row(row):
    """Load one sample as an (IMG_SIZE, IMG_SIZE, 3) float32 array in [0, 1].

    *row* is anything indexable by 'wafer', 'dist' and 'bad' (a DataFrame row
    or a dict) that yields the three image paths. Channel order is
    [wafer, dist, bad]:
      - wafer: quantized to three levels by integer division by 127
        (0-126 -> 0.0, 127-253 -> 0.5, 254-255 -> 1.0). This presumably
        assumes the C++ preprocessor writes three grey levels; TODO confirm.
      - dist, bad: linearly rescaled from 0..255 to 0..1.
    """
    wafer, dist, bad = (read_gray64(row[key]) for key in ('wafer', 'dist', 'bad'))
    channels = (
        (wafer // 127).astype(np.float32) / 2.0,
        dist.astype(np.float32) / 255.0,
        bad.astype(np.float32) / 255.0,
    )
    return np.stack(channels, axis=-1)

def gen_x(frame):
    """Yield the stacked (IMG_SIZE, IMG_SIZE, 3) input for each row of *frame*, in row order.

    Walks the three path columns directly instead of using `iterrows`, which
    materializes a Series per row.
    """
    for wafer, dist, bad in zip(frame['wafer'], frame['dist'], frame['bad']):
        yield load_triplet_row({'wafer': wafer, 'dist': dist, 'bad': bad})

def gen_xy(frame):
    """Yield (input, class_id) for each row of *frame*, in row order.

    Every 'label' must be a key of CLS2ID (a KeyError is raised otherwise), so
    *frame* is expected to hold labeled rows only.
    """
    columns = zip(frame['wafer'], frame['dist'], frame['bad'], frame['label'])
    for wafer, dist, bad, label in columns:
        sample = load_triplet_row({'wafer': wafer, 'dist': dist, 'bad': bad})
        yield sample, CLS2ID[label]

# tf.data pipelines

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

def make_ds_all(frame, batch):
    """Build the autoencoder pipeline: batches of (x, x) so Keras has targets.

    Args:
        frame: DataFrame with 'wafer', 'dist', 'bad' path columns (labels unused).
        batch: batch size.

    Returns:
        A tf.data.Dataset of (x, x) batches, x float32 of shape
        (<=batch, IMG_SIZE, IMG_SIZE, 3).

    Rows are re-permuted on every pass over the dataset (i.e. every epoch)
    and then mixed further through a 4096-element shuffle buffer.
    """
    def _rows():
        # tf.data calls this each time the dataset is iterated, so every epoch
        # gets a fresh global permutation. The 4096-element buffer alone only
        # mixes neighbouring rows. If the frame is grouped (e.g. labeled rows
        # before unlabeled ones), whole batches would come from a single group.
        return gen_x(frame.sample(frac=1.0))

    ds_x = tf.data.Dataset.from_generator(
        _rows,
        output_signature=tf.TensorSpec(shape=(IMG_SIZE, IMG_SIZE, 3), dtype=tf.float32),
    )
    ds = ds_x.map(lambda x: (x, x), num_parallel_calls=AUTOTUNE)
    ds = ds.shuffle(4096).batch(batch).prefetch(AUTOTUNE)
    return ds

def make_ds_xy(frame, batch, shuffle=True):
    """Build a labeled pipeline of (x, class_id) batches for the classifier.

    Args:
        frame: DataFrame with 'wafer', 'dist', 'bad' path columns and a
            'label' column whose values are keys of CLS2ID.
        batch: batch size.
        shuffle: if True, re-permute the rows on every epoch and also mix
            through a 4096-element shuffle buffer. If False, rows keep
            their DataFrame order (used for the test set).

    Returns:
        A tf.data.Dataset of (x float32 (<=batch, IMG_SIZE, IMG_SIZE, 3),
        y int32 (<=batch,)) batches.
    """
    def _rows():
        # Invoked by tf.data on each pass over the dataset, so shuffling the
        # frame here yields a fresh global permutation per epoch. The
        # 4096-element buffer alone cannot break up label-grouped row order.
        source = frame.sample(frac=1.0) if shuffle else frame
        return gen_xy(source)

    ds = tf.data.Dataset.from_generator(
        _rows,
        output_signature=(
            tf.TensorSpec(shape=(IMG_SIZE, IMG_SIZE, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
        )
    )
    if shuffle:
        ds = ds.shuffle(4096)
    ds = ds.batch(batch).prefetch(AUTOTUNE)
    return ds

In [None]:
# The AE is pretrained on every usable triplet, labeled or not. df_l contains
# the test rows too, so pretraining sees test images (never their labels).
ae_df = pd.concat([df_l, df_u], axis=0).reset_index(drop=True)
if LIMIT is not None:
    # Honour LIMIT for the AE as well; otherwise a "quick run" still streams
    # every image through the autoencoder. Seeded random subset, not a head slice.
    ae_df = ae_df.sample(n=min(LIMIT, len(ae_df)), random_state=42).reset_index(drop=True)
ae_ds = make_ds_all(ae_df, BATCH_AE)
# None signals "no data" to the classifier cell.
train_ds = make_ds_xy(train_df, BATCH_CLF, shuffle=True) if len(train_df) else None
test_ds  = make_ds_xy(test_df,  BATCH_CLF, shuffle=False) if len(test_df) else None

# Autoencoder

In [None]:
from tensorflow.keras import layers, models, optimizers


def _down(t, filters):
    """3x3 same-padded ReLU conv followed by default 2x2 max-pooling (halves H and W)."""
    t = layers.Conv2D(filters, 3, padding='same', activation='relu')(t)
    return layers.MaxPooling2D()(t)


def _up(t, filters, activation='relu'):
    """Stride-2, 3x3 same-padded transposed conv (doubles H and W)."""
    return layers.Conv2DTranspose(filters, 3, strides=2, padding='same', activation=activation)(t)


# Input: one (IMG_SIZE, IMG_SIZE, 3) sample with channels [wafer, dist, bad].
inp = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='input')

# Encoder: 64x64 -> 32x32 -> 16x16 -> 8x8x64 latent `z`.
x = _down(inp, 32)   # 32x32
x = _down(x, 64)     # 16x16
z = _down(x, 64)     # 8x8 latent

# Decoder mirrors the encoder back to 64x64x3; sigmoid keeps outputs in [0, 1].
y = _up(z, 64)       # 16x16
y = _up(y, 32)       # 32x32
recon = _up(y, 3, activation='sigmoid')  # 64x64

ae = models.Model(inp, recon, name='autoencoder')
ae.compile(optimizer=optimizers.Adam(1e-3), loss='mse')
ae.summary()

In [None]:
history = ae.fit(ae_ds, epochs=EPOCHS_AE, verbose=1)
try:
    # tf.keras 2.x writes a SavedModel directory when the path has no extension.
    ae.save('ae_keras_savedmodel')
except ValueError:
    # Keras 3 rejects extension-less paths in `save` (only .keras/.h5). Its
    # `export` writes an inference-only SavedModel to the same directory.
    ae.export('ae_keras_savedmodel')
print('Saved ae_keras_savedmodel')

# Classifier: reuse encoder + GAP + Dense softmax

In [None]:
# Classifier = encoder + global average pooling + linear layer over N_CLASSES
# + softmax. The encoder is built from the autoencoder's own tensors, so it
# shares (and fine-tunes) the pretrained layers of `ae`.
encoder = models.Model(inp, z, name='encoder')
clf_in = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='input')
feat = encoder(clf_in)
gap  = tf.keras.layers.GlobalAveragePooling2D()(feat)
logits = tf.keras.layers.Dense(N_CLASSES, activation=None, name='logits')(gap)
prob   = tf.keras.layers.Softmax(name='softmax')(logits)
clf = tf.keras.Model(clf_in, prob, name='classifier')
# The model outputs probabilities, so the loss's default from_logits=False applies.
clf.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])
clf.summary()

# train_ds / test_ds are None when the corresponding frame is empty.
if train_ds is None:
    print("WARNING: No labeled training data; skipping classifier training and evaluation.")
else:
    # test_ds is only monitored here (no early stopping or tuning on it).
    hist = clf.fit(train_ds, epochs=EPOCHS_CLF, validation_data=test_ds, verbose=1)

    if test_ds is None:
        print("WARNING: No labeled test data; skipping evaluation.")
    else:
        y_true, y_pred = [], []
        for x, y in test_ds:
            p = clf.predict_on_batch(x)
            y_true.extend(y.numpy().tolist())
            y_pred.extend(np.asarray(p).argmax(axis=1).tolist())
        # Macro-F1 over the classes present in y_true/y_pred (sklearn default).
        # zero_division=0 only silences the warning; the value is unchanged.
        macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
        print('Macro-F1:', macro)
        # Pin `labels` so report and matrix always cover all N_CLASSES. Without
        # it, classification_report raises when a class is absent from both
        # y_true and y_pred (likely with a small LIMIT).
        all_ids = list(range(N_CLASSES))
        print(classification_report(y_true, y_pred, labels=all_ids,
                                    target_names=CLASSES, zero_division=0))
        print(confusion_matrix(y_true, y_pred, labels=all_ids))

    try:
        # tf.keras 2.x writes a SavedModel directory when the path has no extension.
        clf.save('clf_keras_savedmodel')
    except ValueError:
        # Keras 3 rejects extension-less paths in `save`; `export` writes a SavedModel.
        clf.export('clf_keras_savedmodel')
    print('Saved clf_keras_savedmodel')
