In [None]:
# 1️⃣ Mount Google Drive
# Mounts the user's Google Drive at /content/drive. The dataset and checkpoint
# paths configured later in this notebook are all rooted under this mount
# point, so this cell has to run before any of them are touched.
# NOTE(review): `google.colab` is presumably only importable inside a Colab
# runtime — confirm before running this notebook elsewhere.
from google.colab import drive
drive.mount('/content/drive')

import os
import h5py
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Masking, LSTM, Dropout, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

In [None]:
# 2️⃣ Paths (adjust to your folder structure)
DRIVE_BASE = '/content/drive/MyDrive/asl_model_train/'

# Input datasets sit directly under the base folder.
TRAIN_H5, VAL_H5 = (
    os.path.join(DRIVE_BASE, fname)
    for fname in ('train_data.h5', 'val_data.h5')
)

# Every artefact written by a run lives in one sub-folder, which is created
# here if it does not exist yet.
CHECKPOINT_DIR = os.path.join(DRIVE_BASE, 'checkpoints')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Best model, most recent model, resume marker and per-epoch CSV log.
BEST_MODEL_PATH, LATEST_MODEL_PATH, EPOCH_FILE, LOG_CSV = (
    os.path.join(CHECKPOINT_DIR, fname)
    for fname in ('model_best.h5', 'model_latest.h5', 'epoch.txt', 'training_log.csv')
)

# 3️⃣ Hyper‑parameters
# -- input geometry: must agree with how the data was preprocessed --
SEQ_LEN = 300          # time steps per sequence
FEATURE_DIM = 126      # features per time step
                       # NOTE(review): meaning of 126 is not shown here — confirm against preprocessing

# -- label space --
NUM_CLASSES = 1000     # width of the softmax output layer

# -- training schedule --
EPOCHS = 50            # upper bound on epochs passed to model.fit
BATCH_SIZE = 25        # samples per generator batch

In [None]:
# 4️⃣ (Optional) TPU strategy
# Tries the full TPU bring-up sequence: resolve the cluster, connect to it,
# initialise the TPU system, then build a TPUStrategy around the resolver.
# A ValueError raised by any statement inside the `try` is treated as
# "no TPU available" and the notebook falls back to whatever strategy
# tf.distribute.get_strategy() returns. Other exception types are NOT caught
# and will abort the cell.
# Either way, `strategy` is defined afterwards and is used later to scope
# model creation.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    print("⚡ TPU enabled")
except ValueError:
    # NOTE(review): assumes a missing TPU always surfaces as ValueError —
    # confirm for the TensorFlow version in use.
    strategy = tf.distribute.get_strategy()
    print("⚙️ CPU/GPU strategy")

In [None]:
# 5️⃣ Data generator
def data_generator(h5_path, batch_size, shuffle=True):
    """Yield ``(X_batch, y_batch)`` pairs from an HDF5 file, forever.

    The file is opened once and kept open for the lifetime of the generator.
    Each pass over the data visits every sample exactly once; the final batch
    of a pass is smaller than ``batch_size`` when the dataset size is not a
    multiple of it.

    Args:
        h5_path: Path to an HDF5 file containing datasets named ``'X'`` and
            ``'y'`` with the same leading dimension.
        batch_size: Maximum number of samples per yielded batch.
        shuffle: When True (default) samples are assigned to batches in a
            fresh random order on every pass. When False batches follow the
            on-disk order, which gives a stable batch sequence (useful for
            validation).

    Yields:
        Tuples ``(X[batch], y[batch])`` as returned by indexing the datasets.

    Raises:
        ValueError: If the ``'X'`` dataset is empty (raised on first ``next``).
    """
    with h5py.File(h5_path, 'r') as f:
        X, y = f['X'], f['y']
        size = len(X)
        if size == 0:
            # Without this guard the `while True` loop below would spin
            # forever without ever yielding a batch.
            raise ValueError(f"No samples found in {h5_path!r}")
        while True:
            idxs = np.arange(size)
            if shuffle:
                np.random.shuffle(idxs)
            for start in range(0, size, batch_size):
                # h5py only accepts index lists in increasing order, so sort
                # within the batch. Which samples land in which batch is still
                # random; only the order inside a batch becomes ascending.
                batch = np.sort(idxs[start:start + batch_size])
                yield X[batch], y[batch]

In [None]:
# 6️⃣ Model builder (called inside the strategy scope below)
def build_model(seq_len, feature_dim, num_classes):
    """Build and compile the sequence classifier.

    Architecture: Masking (time steps equal to 0. are masked) -> LSTM(64,
    last output only) -> Dropout(0.3) -> Dense(128, relu) ->
    Dense(num_classes, softmax). Compiled with Adam(1e-3),
    sparse categorical cross-entropy and an accuracy metric.

    Args:
        seq_len: Number of time steps per input sequence.
        feature_dim: Number of features per time step.
        num_classes: Size of the softmax output layer.

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    m = Sequential([
        Masking(mask_value=0., input_shape=(seq_len, feature_dim)),
        LSTM(64, return_sequences=False),
        Dropout(0.3),
        Dense(128, activation='relu'),
        Dense(num_classes, activation='softmax')
    ])
    # NOTE(review): the sparse loss presumably expects integer class ids in
    # `y` — confirm against the preprocessing that wrote the HDF5 files.
    m.compile(optimizer=Adam(1e-3),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
    return m


def _read_initial_epoch(epoch_file):
    """Return the number of completed epochs stored in ``epoch_file``.

    Falls back to 0 (with a printed warning) when the file is missing, empty
    or does not contain an integer, so that an incomplete checkpoint folder
    does not abort the resume.
    """
    try:
        with open(epoch_file, 'r') as f:
            return max(int(f.read().strip()), 0)
    except (FileNotFoundError, ValueError):
        print(f"⚠️ Could not read a valid epoch from {epoch_file!r}; "
              "restarting epoch count at 0")
        return 0


# 7️⃣ Resume logic — the model is created or loaded inside the strategy scope
with strategy.scope():
    if os.path.exists(LATEST_MODEL_PATH):
        print("🔄 Resuming from last checkpoint…")
        model = tf.keras.models.load_model(LATEST_MODEL_PATH)
        # The epoch file stores the count of completed epochs, which is
        # exactly the value model.fit expects as `initial_epoch`.
        initial_epoch = _read_initial_epoch(EPOCH_FILE)
    else:
        print("🚀 Starting new training run")
        model = build_model(SEQ_LEN, FEATURE_DIM, NUM_CLASSES)
        initial_epoch = 0

In [None]:
# 8️⃣ Compute steps
def _num_batches(n_samples, batch_size):
    """Return how many batches cover ``n_samples`` (the last may be partial)."""
    # Ceiling division without importing math.
    return -(-n_samples // batch_size)


# Only the dataset lengths are needed here; the files are closed again
# right away.
with h5py.File(TRAIN_H5, 'r') as f:
    train_size = len(f['X'])
with h5py.File(VAL_H5, 'r') as f:
    val_size = len(f['X'])

# data_generator emits a final, smaller batch whenever the dataset size is
# not a multiple of BATCH_SIZE, i.e. ceil(size / BATCH_SIZE) batches per pass.
# Using the same count here keeps each Keras epoch aligned with exactly one
# pass over the data (so validation always covers the full validation set)
# and avoids a step count of 0 when a dataset is smaller than one batch.
steps_per_epoch = _num_batches(train_size, BATCH_SIZE)
val_steps       = _num_batches(val_size, BATCH_SIZE)

In [None]:
# 9️⃣ Callbacks
# Rolling checkpoint: overwritten after every epoch, used for resuming.
checkpoint_latest = ModelCheckpoint(
    filepath=LATEST_MODEL_PATH,
    verbose=1,
    save_best_only=False,
)

# Best-so-far checkpoint: only written when validation accuracy improves.
checkpoint_best = ModelCheckpoint(
    filepath=BEST_MODEL_PATH,
    mode='max',
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
)

# Stop once validation loss has not improved for 5 epochs and roll the
# in-memory weights back to the best ones seen.
early_stop = EarlyStopping(
    restore_best_weights=True,
    patience=5,
    monitor='val_loss',
)

# Append to the existing log so resumed runs extend the same CSV.
csv_logger = CSVLogger(filename=LOG_CSV, append=True)

In [None]:
class EpochTracker(tf.keras.callbacks.Callback):
    """Persist the number of completed epochs to ``EPOCH_FILE``.

    The stored value is ``epoch + 1`` (Keras epochs are 0-based), i.e. the
    ``initial_epoch`` a resumed run should start from.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Write the next starting epoch after each finished epoch.

        The value is written to a temporary sibling file and then moved over
        ``EPOCH_FILE`` with ``os.replace``. Opening ``EPOCH_FILE`` directly in
        ``'w'`` mode truncates it first, so a session killed mid-write would
        leave an empty file that breaks the next resume.
        """
        tmp_path = EPOCH_FILE + '.tmp'
        with open(tmp_path, 'w') as f:
            f.write(str(epoch + 1))
        os.replace(tmp_path, EPOCH_FILE)

In [None]:
# Callbacks handed to model.fit, in this order: both checkpoints, early
# stopping, the CSV log, and finally the epoch tracker.
# NOTE(review): the tracker is presumably kept last so the epoch marker is
# written after the checkpoints for that epoch — confirm Keras runs callbacks
# in list order.
callbacks = [checkpoint_latest, checkpoint_best, early_stop, csv_logger]
callbacks.append(EpochTracker())

In [None]:
# 🔟 Launch Training
# Two independent, endless batch streams: one over the training file, one
# over the validation file.
train_batches = data_generator(TRAIN_H5, BATCH_SIZE)
val_batches = data_generator(VAL_H5, BATCH_SIZE)

# Because the generators never terminate, the step counts tell Keras where
# each epoch (and each validation run) ends. `initial_epoch` lets a resumed
# run continue its epoch numbering instead of starting again from 0.
fit_options = dict(
    validation_data=val_batches,
    steps_per_epoch=steps_per_epoch,
    validation_steps=val_steps,
    epochs=EPOCHS,
    initial_epoch=initial_epoch,
    callbacks=callbacks,
    verbose=1,
)
history = model.fit(train_batches, **fit_options)