In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory


# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
! pip install albumentations==1.3.0

In [None]:
# Root folder of the image dataset, relative to the notebook's working directory.
path = Path("./Dataset")

# Folder that receives checkpoints and exported learners. It is created up
# front (including missing parents) and re-running the cell is harmless.
model_dir = Path("./Models")
model_dir.mkdir(exist_ok=True, parents=True)


In [None]:
# Custom transformation class for integrating Albumentations with Fastai
class AlbumentationsTransform(RandTransform):
    """Item transform that runs one Albumentations pipeline on training items
    and a different one on every other split (validation / inference).

    Args:
        train_aug: callable invoked as ``train_aug(image=ndarray)`` whose result
            exposes the augmented array under the ``'image'`` key.
        valid_aug: callable with the same calling convention, used whenever the
            current split index is not 0.
    """

    # None means "run on every split". fastai skips a transform whose
    # split_idx is set and differs from the split being processed, so the
    # previous value of 0 meant valid_aug was never applied.
    split_idx = None
    order = 2  # Position of this transform relative to the other item transforms

    def __init__(self, train_aug, valid_aug):
        self.train_aug = train_aug  # Pipeline for the training split
        self.valid_aug = valid_aug  # Pipeline for validation / inference

    def before_call(self, b, split_idx):
        """Remember which split is being processed so `encodes` can pick a pipeline."""
        self.idx = split_idx

    def encodes(self, img: PILImage):
        """Augment a single image and return it as a new ``PILImage``."""
        img_np = np.array(img)  # Albumentations operates on numpy arrays
        # Split 0 is training; anything else gets the validation pipeline.
        pipeline = self.train_aug if self.idx == 0 else self.valid_aug
        aug_img = pipeline(image=img_np)['image']
        return PILImage.create(aug_img)  # Back to the type fastai expects downstream

# Training-time augmentation pipeline: random geometry, colour jitter and occlusion
def get_train_aug(sz):
    """Build the Albumentations pipeline applied to training images.

    Args:
        sz: side length used as both height and width of the random resized crop.

    Returns:
        An ``A.Compose`` that applies the geometric, colour and occlusion
        transforms below in the listed order.
    """
    geometric = [
        # Always crop a random region (8%-100% of the area) and resize it to sz x sz
        A.RandomResizedCrop(height=sz, width=sz, scale=(0.08, 1.0), p=1),
        # Swap rows and columns half of the time
        A.Transpose(p=0.5),
        # Mirror left-right half of the time
        A.HorizontalFlip(p=0.5),
        # Mirror top-bottom half of the time
        A.VerticalFlip(p=0.5),
        # Small random shift / zoom / rotation (up to 15 degrees)
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    ]
    colour = [
        # NOTE(review): hue_shift_limit=0.2 looks very small if the limit is
        # expressed in 8-bit hue units - confirm this is the intended strength.
        A.HueSaturationValue(hue_shift_limit=0.2, p=0.5),
        # Mild brightness / contrast jitter
        A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]
    occlusion = [
        # Blank out up to 8 patches of at most 32x32 pixels
        A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.5),
    ]
    return A.Compose(geometric + colour + occlusion)

# Validation-time pipeline: a deterministic resize and nothing else
def get_valid_aug(sz):
    """Build the Albumentations pipeline applied to validation images.

    Args:
        sz: target height and width in pixels.

    Returns:
        An ``A.Compose`` containing a single always-applied bilinear resize.
    """
    resize_only = A.Resize(height=sz, width=sz, interpolation=cv2.INTER_LINEAR, p=1.0)
    return A.Compose([resize_only])

# Build the fastai DataLoaders (item-level augmentation + batch-level normalisation)
def get_dls(path, sz=224, bs=64):
    """Create training/validation loaders from a folder of images.

    Args:
        path: dataset root handed to ``ImageDataLoaders.from_folder``.
        sz: image side length used by the resize and by both augmentation pipelines.
        bs: batch size.

    Returns:
        The ``ImageDataLoaders`` built with a 20% validation split (seed 42).
    """
    # Albumentations pipelines wrapped so fastai can run them per item
    albumentations_tfm = AlbumentationsTransform(get_train_aug(sz), get_valid_aug(sz))

    return ImageDataLoaders.from_folder(
        path,
        valid_pct=0.2,  # Hold out 20% of the images for validation
        seed=42,  # Fixed seed so the split is the same on every run
        item_tfms=[Resize(sz), albumentations_tfm],
        # Normalise with ImageNet statistics, as expected by pre-trained backbones
        batch_tfms=[Normalize.from_stats(*imagenet_stats)],
        bs=bs,
    )

# Build the loaders once at notebook level; the training loop and the ensemble
# evaluation further down both read the resulting `dls`.
print("Preparing data...")
dls = get_dls(path, sz=224, bs=64)  # Create data loaders with image size 224 and batch size 64


In [None]:
# Function to print the classification report for the trained model
def print_classification_report(learn):
    """Print a per-class precision/recall/F1 report for the validation set.

    Args:
        learn: learner exposing ``get_preds(dl=...)`` and ``dls.valid`` /
            ``dls.vocab``.

    The label ids are passed explicitly so the report still works when a class
    is absent from the validation targets and predictions. Without ``labels``,
    ``classification_report`` infers them from the data and raises a
    ``ValueError`` if their count differs from ``target_names``.
    """
    preds, targs = learn.get_preds(dl=learn.dls.valid)  # Predictions and targets for the validation set
    pred_classes = preds.argmax(dim=1)  # Most probable class per item
    class_names = learn.dls.vocab  # Class names; position i names label id i
    label_ids = list(range(len(class_names)))
    print(classification_report(targs, pred_classes, labels=label_ids, target_names=class_names))

# Function to train the model with specified parameters
def train_model(model_name, arch, dls, epochs=10):
    """Train `arch` on `dls` in two phases, report metrics and export the result.

    Args:
        model_name: label used in log messages and as the stem of the exported file.
        arch: architecture handed to ``vision_learner``.
        dls: data loaders used for training and validation.
        epochs: number of epochs for each of the two phases.

    Returns:
        The trained learner when its validation accuracy is above 90%
        (it is then also exported to ``<model_dir>/<model_name>.pkl``),
        otherwise ``None``.

    Reads the module-level ``model_dir`` for the output location.
    """
    print(f"\nTraining {model_name}...")

    # Resolve the output directory once. An absolute path is required because
    # fastai joins both `model_dir` and the export file name onto `learn.path`
    # (the dataset folder here); a relative path would therefore point inside
    # the dataset directory instead of ./Models.
    save_dir = model_dir.absolute()

    # Create a learner object with the given data loader, architecture, and metrics
    learn = vision_learner(dls, arch, metrics=[accuracy], wd=1e-2,
                          cbs=[EarlyStoppingCallback(monitor='valid_loss', patience=2),
                               SaveModelCallback(monitor='valid_loss')])  # Callbacks for early stopping and saving model

    learn.model_dir = save_dir  # Checkpoints written by SaveModelCallback go here

    # Phase 1: Fine-tuning the model
    print("="*50)
    print("Phase 1: Fine-tuning...")
    print("="*50)
    learn.fine_tune(epochs)

    # Phase 2: Training with unfrozen layers (train all layers)
    print("\n" + "="*50)
    print("Phase 2: Training with unfrozen layers...")
    print("="*50)
    learn.unfreeze()

    # Ask the learning-rate finder for suggestions; fall back to 1e-3 if it
    # fails or returns something other than a tuple.
    try:
        lr_results = learn.lr_find(suggest_funcs=(steep, valley))
        # First suggestion - presumably the one produced by `steep`, the first
        # entry in suggest_funcs; verify against the fastai version in use.
        lr_min = lr_results[0] if isinstance(lr_results, tuple) else 1e-3
    except Exception as e:
        print(f"Warning: Failed to find learning rate. Using default 1e-3. Error: {e}")
        lr_min = 1e-3

    # Train all layers with the chosen learning rate
    learn.fit_one_cycle(epochs, lr_min)

    # Evaluation phase: confusion matrix and the ten highest-loss predictions
    print("\nGenerating evaluation metrics...")
    interp = ClassificationInterpretation.from_learner(learn)
    interp.plot_confusion_matrix()
    interp.plot_top_losses(10, nrows=2)

    # Per-class precision / recall / F1
    print(f"\nClassification Report for {model_name}:")
    print_classification_report(learn)

    # validate() returns [loss, *metrics]; accuracy is the only metric configured above
    acc = learn.validate()[1] * 100
    print(f"\nFinal Accuracy for {model_name}: {acc:.2f}%")

    # Export only models whose accuracy is above 90%
    if acc > 90:
        model_path = save_dir / f'{model_name}.pkl'  # Absolute, so it is not re-rooted under learn.path
        learn.export(model_path)
        print(f"Model saved as {model_path}")
        return learn

    print(f"Model {model_name} not saved (Accuracy < 90%)")
    return None


In [None]:
# Each entry: (name shown in the logs, architecture, file stem for the exported learner)
models_to_train = [
    ("resnet50", resnet50, "hair-resnet50"),
    ("vgg16_bn", vgg16_bn, "hair-vgg16"),
    ("convnext_tiny", convnext_tiny, "hair-convnext-tiny")
]

# Learners that train_model returned, keyed by the log name;
# these are the members of the ensemble evaluated later.
trained_models = {}

for model_desc, arch, save_name in models_to_train:
    print(f"\n{'='*50}")
    print(f"Starting training for {model_desc}")
    print(f"{'='*50}")

    model = train_model(save_name, arch, dls)

    # train_model hands back None when the model was not exported
    if not model:
        print(f"\nSkipping {model_desc} for ensemble")
        continue

    trained_models[model_desc] = model
    print(f"\nCompleted training {model_desc}")

print("\nTraining completed for all models!")


In [None]:
def ensemble_predict(models, dls):
    """Average the validation predictions of several learners and pick a class.

    Args:
        models: mapping whose values are learners exposing ``get_preds(dl=...)``;
            the keys are not used.
        dls: data loaders; only ``dls.valid`` is read.

    Returns:
        Tensor holding, for each validation item, the index of the class with
        the highest mean prediction across all learners.
    """
    # One prediction tensor per learner, all computed on the same validation loader
    per_model_preds = [learner.get_preds(dl=dls.valid)[0] for learner in models.values()]

    # Stack along a new leading "model" axis and average it away
    mean_preds = torch.stack(per_model_preds).mean(dim=0)

    return mean_preds.argmax(dim=1)

if not trained_models:
    print("\nNo models were trained with accuracy > 90%, skipping ensemble.")
else:
    print("\nGenerating Ensemble Predictions...")
    ensemble_preds = ensemble_predict(trained_models, dls)

    # Ground-truth targets are taken from the first kept learner's pass over
    # the validation loader; every learner is evaluated on the same `dls.valid`.
    _, targs = next(iter(trained_models.values())).get_preds(dl=dls.valid)
    class_names = dls.vocab

    print("\nEnsemble Model Classification Report:")
    print(classification_report(targs, ensemble_preds, target_names=class_names))


In [None]:
import cv2
import torch
import fastai.vision.all as fv
from PIL import Image
import torchvision.transforms as transforms

# Load the exported FastAI learner.
# NOTE(review): load_learner unpickles the file, which can execute arbitrary
# code - only load exports that were produced by a trusted source.
model_path = "Models/hair-vgg16.pkl"
learn = fv.load_learner(model_path)

# NOTE(review): this list is not referenced anywhere in this cell; the label
# drawn on the frame comes from learn.predict. Confirm whether it is still needed.
class_labels = ["Straight", "Wavy", "Curly", "Dreadlocks", "Kinky"]

# Single source for the window title used by imshow and getWindowProperty
window_name = "Webcam - Hair Type Classification"

# Start webcam
cap = cv2.VideoCapture(0)

# try/finally guarantees the camera and the window are released even if
# prediction raises or the kernel is interrupted mid-loop.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Failed to capture image")
            break

        # OpenCV delivers BGR frames; convert to RGB before building the PIL image
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img_pil = Image.fromarray(img)

        # Run inference with FastAI (the learner applies its own preprocessing)
        pred, pred_idx, probs = learn.predict(img_pil)
        predicted_label = str(pred)

        # Display prediction on webcam feed
        cv2.putText(frame, f"Prediction: {predicted_label}", (50, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

        cv2.imshow(window_name, frame)

        # waitKey pumps the GUI event loop, so call it before querying the
        # window state; otherwise the visibility check can see a window that
        # has not been drawn yet. Press 'q' to exit.
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break

        # Stop when the user closes the window
        if cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1:
            print("Window closed, stopping camera...")
            break
finally:
    cap.release()
    cv2.destroyAllWindows()
