In [1]:
import os
import logging
from sklearn.model_selection import train_test_split
# Suppress warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
logging.getLogger('tensorflow').setLevel(logging.ERROR)

# set backend
os.environ['KERAS_BACKEND'] = 'torch'
import keras
print("Backend after setting:", keras.config.backend())

import tensorflow as tf
import torch

# Check GPU visibility
# print("TensorFlow GPUs:", tf.config.list_physical_devices('GPU'))
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

Backend after setting: torch


In [2]:
class Config:
    def __init__(self, device_id=0):
        self.device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu")

# Set device to GPU 1
config = Config(device_id=1)
device = config.device
torch.cuda.set_device(config.device.index)
print(torch.cuda.current_device()) 
print("PyTorch Device:", device)


1
PyTorch Device: cuda:1


In [3]:
from keras.models import Sequential
from keras import regularizers
from keras.layers import (Input, Conv2D, BatchNormalization, ReLU, MaxPooling2D, 
                          Flatten, Dense, Dropout, Lambda)
from keras.initializers import HeNormal
import keras.ops as K

def get_model(hidden_units, output_units, input_shape, rate, l2_coeff=1e-5):
    """
    Creates a face verification model that outputs normalized embeddings.
    """

    model = Sequential([Input(shape=input_shape)])

    # --- Convolutional blocks / Feature extraction backbone ---

    # note we use he kaiming initialization for the weights
    model.add(Conv2D(32, (3, 3), padding='same', kernel_initializer=HeNormal(),
                     kernel_regularizer=regularizers.l2(l2_coeff)))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(MaxPooling2D((2, 2)))

    # 2nd block
    model.add(Conv2D(64, (3, 3), padding='same', kernel_initializer=HeNormal(),
                     kernel_regularizer=regularizers.l2(l2_coeff)))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(MaxPooling2D((2, 2)))

    # 3rd block
    model.add(Conv2D(128, (3, 3), padding='same', kernel_initializer=HeNormal(),
                     kernel_regularizer=regularizers.l2(l2_coeff)))
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(MaxPooling2D((2, 2)))

    model.add(Flatten())

    # --- Fully connected layers ---
    for units in hidden_units:
        model.add(Dense(units, kernel_initializer=HeNormal(),
                        kernel_regularizer=regularizers.l2(l2_coeff)))
        model.add(BatchNormalization())
        model.add(ReLU())
        model.add(Dropout(rate))

    # --- Output layer for classification ---
    # outputs 8000-probability vector
    model.add(Dense(output_units, kernel_initializer=HeNormal(), activation='softmax'))

    return model

model = get_model(
    hidden_units=[1024, 128],
    output_units=8000, # 8000 identities
    input_shape=(112, 112, 3),
    rate=0.5
)
model.summary()

In [4]:
import os
import random
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, Sampler

class PytorchFaceDataset(Dataset):
    """
    A single unified class that:
      - Loads up to `max_images_per_identity` images from each identity folder.
      - Optionally uses `sample_indexes` for train/val slicing.
    """
    def __init__(
        self, dataset_path, identities,
        max_images_per_identity=10,
        sample_indexes=None,
        classes_per_batch=None,
        samples_per_class=None
    ):
        self.dataset_path = dataset_path
        self.identities = identities
        self.max_images_per_identity = max_images_per_identity
        self.sample_indexes = sample_indexes
        self.classes_per_batch = classes_per_batch
        self.samples_per_class = samples_per_class

        self.image_paths = []
        self.labels = []  # We'll group by label = idx in `identities`
        
        # Gather all images & labels
        for idx, identity in enumerate(identities):
            identity_folder = os.path.join(dataset_path, identity)
            if not os.path.isdir(identity_folder):
                continue  # skip if folder doesn't exist

            image_files = sorted(os.listdir(identity_folder))
            # Take up to max_images_per_identity images
            selected_images = image_files[:max_images_per_identity]
            for img_name in selected_images:
                self.image_paths.append(os.path.join(identity_folder, img_name))
                self.labels.append(idx)

        # sample_indexes is train_indices or val_indices
        if self.sample_indexes is not None:
            self.image_paths = [self.image_paths[i] for i in self.sample_indexes]
            self.labels = [self.labels[i] for i in self.sample_indexes]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        img_path = self.image_paths[index]
        label = self.labels[index]

        # Load image, convert to RGB, resize to (112, 112)
        img = Image.open(img_path).convert("RGB")
        img = img.resize((112, 112), resample=Image.BILINEAR)
        # Normalize to [0, 1]
        img = np.array(img, dtype=np.float32) / 255.0  # shape: (112, 112, 3)

        img_tensor = torch.from_numpy(img)  # currently (112, 112, 3)

        return img_tensor, label

    def get_dataloader(self, batch_size=128, shuffle=True, num_workers=4):
        # If we have custom-batch parameters set, build the custom sampler
        if self.classes_per_batch is not None and self.samples_per_class is not None:
            return DataLoader(
                self,
                batch_sampler=self._BatchSampler(
                    self.labels,
                    self.classes_per_batch,
                    self.samples_per_class
                ),
                num_workers=num_workers
            )
        else:
            return DataLoader(
                self,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers
            )

    class _BatchSampler(Sampler):
        """
        Ensures each batch contains `classes_per_batch` classes, each has
        `samples_per_class` samples or fewer.
        """
        def __init__(self, labels, classes_per_batch, samples_per_class):
            self.labels = labels
            self.classes_per_batch = classes_per_batch
            self.samples_per_class = samples_per_class

            # Group indices by class
            self.class_to_indices = {}
            for idx, label in enumerate(labels):
                self.class_to_indices.setdefault(label, []).append(idx)

            # Keep list of all classes for shuffling each epoch
            self.all_classes = list(self.class_to_indices.keys())

        def __iter__(self):
            '''Defines how the batch indices are generated.'''
            random.shuffle(self.all_classes) # Shuffle list of all labels each epoch

            # We'll chunk the shuffled class list in groups of 'classes_per_batch'

            # loop over shuffled classes in chunks of 'classes_per_batch'
            # e.g. 8000 classes, 10 class per epoch, so 800 iterations
            for start in range(0, len(self.all_classes), self.classes_per_batch):
                chunk_classes = self.all_classes[start:start + self.classes_per_batch]
                
                batch_indices = []
                for cls in chunk_classes: # collect sample indices for each class
                    idx_list = self.class_to_indices[cls] # map label to indices

                    # if a class has enough samples, sample 'samples_per_class' indices
                    # otherwise, take all indices
                    if len(idx_list) >= self.samples_per_class:
                        chosen = random.sample(idx_list, self.samples_per_class)
                    else:
                        chosen = idx_list  # class is smaller than desired
                    batch_indices.extend(chosen)

                yield batch_indices # move to dataloader

        def __len__(self):
            ''' computes how many batches are needed to cover all classes '''
            # e.g. 8000 class, 10 class per batch, so (8010-1) // 10 = 800
            return (len(self.all_classes) + self.classes_per_batch - 1) // self.classes_per_batch


In [5]:
# # Load dataset identities
import numpy as np
dataset_path = "data/casia-webface"
identities = [d for d in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, d))]
print("Number of identities in dataset:", len(identities))
indices = np.arange(80000)
train_indices, val_indices = train_test_split(indices, test_size=0.2, random_state=42)
train_torch_dataset = PytorchFaceDataset(dataset_path, identities, 
                                         max_images_per_identity=10, 
                                         sample_indexes=train_indices,
                                         classes_per_batch=10,   # each batch will pick 10 classes
                                        samples_per_class=10) # each class has 10 samples
train_loader = train_torch_dataset.get_dataloader(num_workers=4) # batch size = 10*10=100
# val_torch_dataset = PytorchFaceDataset(dataset_path, identities,
#                                        max_images_per_identity=10, 
#                                        sample_indexes=val_indices)
val_torch_dataset = PytorchFaceDataset(dataset_path, identities,
                                       max_images_per_identity=10, 
                                       sample_indexes=val_indices,
                                       classes_per_batch=10,   # each batch will pick 10 classes
                                       samples_per_class=10) # each class has 10 samples
# val_loader = val_torch_dataset.get_dataloader(batch_size=128, shuffle=False)
val_loader = val_torch_dataset.get_dataloader(num_workers=4) # batch size = 10*10=100

# Check shape
for images, labels in train_loader:
    print(f"Images shape: {images.shape}, Labels shape: {labels.shape}")
    print("Unique labels in this batch:", len(set(labels.tolist())))
    break


Number of identities in dataset: 8000


  self.pid = os.fork()
  self.pid = os.fork()


Images shape: torch.Size([83, 112, 112, 3]), Labels shape: torch.Size([83])
Unique labels in this batch: 10


In [6]:
for batch, labels in train_loader:
    print(batch.shape, 'batch shape')
    print(labels.shape, 'labels shape')
    break

torch.Size([77, 112, 112, 3]) batch shape
torch.Size([77]) labels shape


In [10]:
import torch
import torch.nn.functional as F
import keras
loss_metric = keras.metrics.Mean()
accuracy_metric = keras.metrics.SparseCategoricalAccuracy()
val_accuracy_metric = keras.metrics.Mean() 

def pt_train_step(model, loss_fn, optimizer, train_batch):
    model.zero_grad()

    images, labels = train_batch  # Move to device if necessary

    outputs = model(images)  # Forward pass
    loss = loss_fn(labels, outputs)  # Compute loss

    loss.backward()  # Backpropagation
    grads = [param.grad for param in model.parameters()]

    return loss.item(), grads

def pt_valid_step(model, val_batch):
    images, labels = val_batch
    images, labels = images.to(device), labels.to(device)

    outputs = model(images)  # Forward pass
    predicted_classes = torch.argmax(outputs, dim=1)
    accuracy = (predicted_classes == labels).float().mean()

    return accuracy


In [11]:
def train_model_custom(mlp_model, loss_fn, opt, training_dataset, validation_dataset, train_step_fn, valid_step_fn, epochs):

    # check if training is using GPU
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = config.device
    print(f"Device: {device}")
    mlp_model.to(device)
    print("Training on GPU:", next(mlp_model.parameters()).is_cuda)

    epoch_losses = []
    val_epoch_acc = []

    for epoch in range(epochs):
        print(f"Epoch {epoch}/{epochs}")
        # Reset metrics for new epoch
        loss_metric.reset_state()
        val_accuracy_metric.reset_state()
        
        # Training loop
        for images, labels in training_dataset:

            loss, grads = train_step_fn(mlp_model, loss_fn, opt, train_batch=(images, labels))
            opt.apply_gradients(zip(grads, mlp_model.trainable_variables))
            
            loss_metric.update_state(loss)

        # Compute training loss and accuracy
        avg_epoch_loss = float(loss_metric.result().cpu().numpy())

        # Validation loop
        with torch.no_grad():  # Disable gradients for validation
            for images, labels in validation_dataset:

                acc = valid_step_fn(mlp_model, val_batch=(images, labels))
                val_accuracy_metric.update_state(acc)

        # Compute validation loss and accuracy
        avg_val_acc = float(val_accuracy_metric.result().cpu().numpy())

        # Store epoch results
        epoch_losses.append(avg_epoch_loss)
        val_epoch_acc.append(avg_val_acc)

        # Print progress
        print(f"Epoch {epoch}: loss - {avg_epoch_loss:.4f}, ")

        # for param in model.parameters():
        #     print('grad:', param.grad)

        # save checkpoint
        # if epoch % 10 == 0:
        print(f"val_acc - {avg_val_acc:.4f}")

        checkpoint_dir = "checkpoint"
        os.makedirs(checkpoint_dir, exist_ok=True)



    return epoch_losses


In [13]:
# optimizer = keras.optimizers.SGD(learning_rate=0.05)
optimizer = keras.optimizers.Adam(learning_rate=0.01)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False)  # If output has softmax


epoch_losses = train_model_custom(model, loss_fn=loss_fn, opt=optimizer,
                                            training_dataset=train_loader, 
                                            validation_dataset=val_loader, 
                                            train_step_fn=pt_train_step, 
                                            valid_step_fn=pt_valid_step,
                                            epochs=10)


Device: cuda:1
Training on GPU: True
Epoch 0/10


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 0: loss - 8.9893, 
val_acc - 0.0001
Epoch 1/10
Epoch 1: loss - 8.9873, 
val_acc - 0.0000
Epoch 2/10
Epoch 2: loss - 8.9878, 
val_acc - 0.0000
Epoch 3/10
Epoch 3: loss - 8.9881, 
val_acc - 0.0000
Epoch 4/10
Epoch 4: loss - 8.9873, 
val_acc - 0.0000
Epoch 5/10
Epoch 5: loss - 8.9873, 
val_acc - 0.0000
Epoch 6/10
Epoch 6: loss - 8.9873, 
val_acc - 0.0000
Epoch 7/10
Epoch 7: loss - 8.9873, 
val_acc - 0.0000
Epoch 8/10
Epoch 8: loss - 8.9873, 
val_acc - 0.0000
Epoch 9/10
Epoch 9: loss - 8.9873, 
val_acc - 0.0000
