In [None]:
# Environment setup for a Lambda Labs JupyterLab instance; run these:
# pip install transformers
# pip install tf-keras
# pip install --upgrade "numpy<2"
# pip install datasets
# pip install --upgrade datasets pillow
# pip install --upgrade "accelerate>=0.26.0"
# Then check for dependency conflicts:
# pip check
# If `pip check` reports any issues, run:
# pip install debugpy
# pip install --upgrade argcomplete
# sudo apt-get install python3-cairo

In [1]:
import os
import numpy as np
from functools import partial
from io import BytesIO
from transformers import (
    AutoImageProcessor, 
    AutoModelForImageClassification, 
    EarlyStoppingCallback,
    get_cosine_schedule_with_warmup,
    TrainingArguments, 
    Trainer
)
from datasets import load_dataset, Image as DatasetsImage
import torch
import torchvision.transforms as T
from PIL import Image, ImageOps, ExifTags, UnidentifiedImageError

  from .autonotebook import tqdm as notebook_tqdm
2025-03-20 16:02:57.694652: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-20 16:02:57.710695: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742486577.730768    4942 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742486577.737095    4942 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1742486577.752697    4942 computation_placer.cc:177] computation placer already r

In [2]:
# --------------------------
# 1. Load the Pre-trained Model and Processor
# --------------------------
# The classifier and its image processor are both pulled from the same Hub checkpoint.
checkpoint = "trpakov/vit-face-expression"
model = AutoModelForImageClassification.from_pretrained(checkpoint)
processor = AutoImageProcessor.from_pretrained(checkpoint, use_fast=True)

In [3]:
# --------------------------
# 2. Load Your Dataset
# --------------------------

# Build the dataset from a class-per-folder image directory and take its "train" split.
# NOTE(review): no `decode=False` cast is applied here, so the "image" column
# presumably uses the loader's default decoding -- confirm before relying on
# raw paths/bytes in later cells.
dataset = load_dataset(
    path="imagefolder",
    split="train",
    data_dir="/home/ubuntu/MLexpressionsStorage/img_datasets/combo_ferckja_dataset",
)

In [4]:
def validate_images(example):
    """Return True when the example's image can be fully decoded, else False.

    Intended as a predicate for ``dataset.filter``. The ``"image"`` entry may be:

    * an already-decoded image object (anything exposing ``load()``, e.g. a
      PIL image),
    * a filesystem path (``str``, ``bytes`` or ``os.PathLike``), or
    * an undecoded mapping with ``"bytes"`` and/or ``"path"`` entries.

    For the path and mapping forms the decoded RGB image is written back to
    ``example["image"]``, as the original implementation did for paths.

    Args:
        example: Mapping-like row with an ``"image"`` entry.

    Returns:
        bool: True if the image data loads without error; False if it is
        missing, unreadable, or corrupted (a message is printed in that case).

    Raises:
        KeyError: If the row has no ``"image"`` entry at all. This is left
            uncaught on purpose so a wrong column name fails loudly instead
            of silently emptying the dataset.
    """
    source = "<unknown>"
    try:
        # The lookup itself stays inside the try: rows that decode lazily can
        # raise right here, and the handler below must not read
        # example["image"] again (that would re-raise inside the handler).
        image = example["image"]

        if isinstance(image, (str, bytes, os.PathLike)):
            source = image
            with open(image, "rb") as f:
                img = Image.open(f).convert("RGB")
                img.load()
            example["image"] = img
        elif isinstance(image, dict):
            source = image.get("path") or "<in-memory bytes>"
            raw = image.get("bytes")
            if raw is not None:
                img = Image.open(BytesIO(raw)).convert("RGB")
                img.load()
            else:
                # open(None) raises TypeError, which is handled below.
                with open(image.get("path"), "rb") as f:
                    img = Image.open(f).convert("RGB")
                    img.load()
            example["image"] = img
        else:
            # Already-decoded image (or None / unexpected type, which fails
            # with AttributeError and is reported as unreadable).
            source = getattr(image, "filename", "") or repr(image)
            image.load()
        return True
    except (OSError, AttributeError, TypeError, ValueError, SyntaxError) as e:
        # PIL's UnidentifiedImageError subclasses OSError, so it is covered
        # here without naming it explicitly.
        print(f"Skipping corrupted/unreadable image: {source} - {e}")
        return False

# Keep only the examples whose image passes validate_images.
dataset = dataset.filter(function=validate_images)

In [6]:
# # If issues arise Define allowed image file extensions
# ALLOWED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff'}

# def is_image_file(file_name):
#     """
#     Check if a given file has one of the allowed image extensions.
#     """
#     ext = os.path.splitext(file_name)[1].lower()  # Extract extension and normalize to lowercase
#     return ext in ALLOWED_EXTENSIONS

# def validate_image(file_path):
#     """
#     Validate a local image file.
    
#     1. Opens the file in binary mode and reads its content.
#     2. Wraps the content with BytesIO to create a file-like object.
#     3. Loads the image using Pillow and handles EXIF orientation data.
#     4. Returns True if the image is valid; otherwise, logs an error and returns False.
#     """
#     try:
#         # Open and read the image file as bytes
#         with open(file_path, 'rb') as f:
#             img_bytes = f.read()
        
#         # Wrap the bytes into a BytesIO object for Pillow compatibility
#         img = Image.open(BytesIO(img_bytes))
#         img.load()  # Ensure the image data is fully loaded

#         try:
#             # Attempt to handle EXIF data for orientation
#             exif = img.getexif()
#             if exif is not None:
#                 orientation = exif.get(ExifTags.Base.Orientation)
#                 if orientation is not None:
#                     img = ImageOps.exif_transpose(img)
#         except Exception as ex:
#             # Log any errors encountered during EXIF processing but continue
#             print(f'EXIF handling error for {file_path}: {ex}')
        
#         # If processing reaches here, the image is considered valid
#         return True

#     except Exception as e:
#         # Log errors (e.g., corrupted files or unsupported formats) and mark the image as invalid
#         print(f'Image validation failed for {file_path}: {e}')
#         return False

# def search_and_validate_images(root_dir):
#     """
#     Recursively search for image files within the given directory,
#     validate them, and return a list of valid image file paths.
#     """
#     valid_images = []  # List to store paths of valid images
#     # os.walk traverses the directory tree starting at root_dir
#     for root, dirs, files in os.walk(root_dir):
#         for file in files:
#             # Process the file only if it has an allowed image extension
#             if is_image_file(file):
#                 file_path = os.path.join(root, file)
#                 if validate_image(file_path):
#                     valid_images.append(file_path)
#     return valid_images


In [5]:
# # Example usage:
# root_dir = '/home/ubuntu/MLexpressionsStorage/img_datasets/combo_ferckja_dataset'  # Root directory of your dataset
# valid_images = search_and_validate_images(root_dir)
# print(f"Total valid images: {len(valid_images)}")

# # 2. Function to robustly filter corrupted/unreadable images
# def validate_images(example):
#     try:
#         # Attempt loading image data
#         if isinstance(example["image"], Image.Image):
#             example["image"].load()
#         else:
#             example["image"] = Image.open(BytesIO(example["image"]["bytes"])).convert("RGB")
#             example["image"].load()
#         return True
#     except (UnidentifiedImageError, AttributeError, OSError, Exception) as e:
#         print(f"Removing corrupted image due to error: {e}")
#         return False

# # Apply validation
# dataset = dataset.filter(validate_images)

# print(f"Total examples after removing corrupted images: {len(dataset)}")

In [5]:
# --------------------------
# 3. Define mapping: dataset label -> pre-trained model label
# --------------------------
# Updated mapping using lowercase keys
# Dataset folder name (lowercase) -> label name used by the pre-trained model.
label_mapping = {
    'anger': 'Angry',
    'contempt': 'Disgust',  # Merge "contempt" with "disgust"
    'disgust': 'Disgust',
    'fear': 'Fear',
    'happiness': 'Happy',
    'sadness': 'Sad',
    'surprise': 'Surprise',
    'neutral': 'Neutral'
}

# Pre-trained label name -> class index fed to the model.
# NOTE(review): these ids are hard-coded; confirm they match the loaded
# checkpoint's own label order (model.config.id2label) before trusting
# per-class predictions.
num_mapping = {
    'Angry': 0,
    'Disgust': 1,
    'Fear': 2,
    'Happy': 3,
    'Sad': 4,
    'Surprise': 5,
    'Neutral': 6
}

def reconcile_labels(example, label_feature=None):
    """Rewrite ``example["label"]`` as the pre-trained model's class index.

    The label is normalised (stripped, lowercased), translated through
    ``label_mapping`` and then ``num_mapping``. Labels that cannot be mapped
    are set to ``-1`` so a later filter can drop them.

    Args:
        example: Row with a ``"label"`` entry holding either a class name
            (``str``) or an integer class id.
        label_feature: Optional object exposing ``int2str(int) -> str`` used
            to turn integer ids into class names. When omitted, the notebook
            global ``dataset.features["label"]`` is used, as before.

    Returns:
        The same ``example`` mapping, with ``"label"`` replaced by an index
        from ``num_mapping`` or by ``-1``.

    NOTE(review): integer labels are interpreted through the dataset's label
    names, so feeding already-reconciled rows through this function again
    would presumably remap them incorrectly -- run the mapping cell once per
    freshly loaded dataset.
    """
    raw_label = example["label"]

    if isinstance(raw_label, int):
        if label_feature is None:
            # Backward-compatible default: resolve names via the global dataset.
            label_feature = dataset.features["label"]
        raw_label = label_feature.int2str(raw_label)

    if isinstance(raw_label, str):
        pretrain_label = label_mapping.get(raw_label.strip().lower())
    else:
        # None or any other non-text label cannot be mapped; flag it instead
        # of crashing on .strip().
        pretrain_label = None

    example["label"] = -1 if pretrain_label is None else num_mapping[pretrain_label]
    return example

# Remap every label, then discard the rows that reconcile_labels flagged with -1
# (labels it could not map).
dataset = dataset.map(reconcile_labels).filter(lambda row: row["label"] != -1)

In [6]:
# --------------------------
# 4. Define Data Augmentation and Preprocessing Transformation
# --------------------------
# Use torchvision transforms for lightweight CPU-based augmentation.
# Augmentation pipeline, applied in the order listed.
data_augment = T.Compose(
    transforms=[
        T.RandomHorizontalFlip(),                                      # random left/right mirror
        T.RandomRotation(degrees=20),                                  # rotate within +/-20 degrees
        T.ColorJitter(saturation=0.2, contrast=0.2, brightness=0.2),   # mild colour jitter
        T.RandomResizedCrop(size=224, scale=(0.8, 1.0)),               # random zoom/crop to 224x224
        T.RandomAffine(degrees=15),                                    # additional affine rotation
    ]
)

def transform_function(example, processor, augment=True):
    """Turn one dataset row into model-ready inputs.

    Args:
        example: Row with an ``"image"`` entry (a PIL image, or anything
            ``Image.open`` accepts) and a ``"label"`` entry.
        processor: Callable invoked as ``processor(image, return_tensors="pt")``
            that returns a mapping of batched tensors.
        augment: When True (the default, matching the original behaviour) the
            image is passed through ``data_augment`` before the processor.
            Pass False for deterministic preprocessing, e.g. validation data.

    Returns:
        dict: The processor outputs with their leading batch dimension
        squeezed away, plus ``"labels"`` copied from ``example["label"]``.
    """
    image = example["image"]

    if not isinstance(image, Image.Image):
        # convert() hands back a separate converted copy, so the underlying
        # file can be closed immediately instead of lingering until GC.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif image.mode != "RGB":
        image = image.convert("RGB")

    # Write the RGB image back onto the row, as the original code did.
    example["image"] = image

    # Optionally augment, then let the processor produce the model tensors.
    model_input = data_augment(image) if augment else image
    inputs = processor(model_input, return_tensors="pt")

    # The processor returns a batch of one; drop that batch dimension.
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}

    # Attach the label under the "labels" key.
    inputs["labels"] = example["label"]
    return inputs

# Run transform_function (with the processor bound) over every example.
# NOTE(review): this runs on the full dataset before the train/validation split,
# so the rows that end up in the validation set are augmented too, and each
# image presumably receives a single fixed augmentation rather than a fresh one
# per epoch -- confirm this is intended.
dataset = dataset.map(function=partial(transform_function, processor=processor))

In [7]:
# --------------------------
# 5. Split Dataset into Training and Validation Sets
# --------------------------
# Hold out 20% of the rows for validation. The seed is fixed so the same rows
# land in each split on every run; otherwise validation accuracy is not
# comparable across kernel restarts.
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

In [11]:
# --------------------------
# 6. Define Training Arguments for Robust Fine-Tuning
# --------------------------
training_args = TrainingArguments(
    # --- Output locations ---
    output_dir="./finetuned_vit_model",    # checkpoints and final model
    logging_dir="./logs",                  # TensorBoard logs
    # --- Evaluation / checkpointing: once per epoch, keep the best by accuracy ---
    # NOTE(review): newer transformers releases name this argument
    # `eval_strategy`; confirm which spelling the installed version expects.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # --- Batch sizes and training length ---
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    # --- Optimisation: cosine decay with warm-up over the first 10% of training ---
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # --- Regularisation ---
    label_smoothing_factor=0.1,
)



In [12]:
# --------------------------
# 7. Define a Compute Metrics Function for Evaluation
# --------------------------
def compute_metrics(eval_pred):
    """Compute top-1 accuracy from a ``(logits, labels)`` pair.

    Args:
        eval_pred: Two-item iterable of per-class scores and reference labels.

    Returns:
        dict: ``{"accuracy": fraction of rows whose argmax equals the label}``.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=-1)
    hits = predicted_ids == references
    return {"accuracy": hits.mean()}

In [14]:
# --------------------------
# 8. Initialize and Run the Trainer for Fine-Tuning
# --------------------------
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
    # Halt training early after 2 consecutive evaluations without improvement.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

In [15]:
# 9. Fine-tune the model
# Runs the training loop configured on `trainer`. As the last expression in the
# cell, the returned value is displayed as the cell output.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.076769,0.696508
2,1.302800,1.06627,0.707024
3,0.878200,1.108606,0.712283
4,0.615200,1.166855,0.71808
5,0.491600,1.170024,0.720237


TrainOutput(global_step=2320, training_loss=0.772868064354206, metrics={'train_runtime': 14140.4948, 'train_samples_per_second': 10.489, 'train_steps_per_second': 0.164, 'total_flos': 1.1494126967676273e+19, 'train_loss': 0.772868064354206, 'epoch': 5.0})

In [16]:
# 10. Save final model
# Once training has finished, write the model's weights (its state_dict) to disk.
torch.save(obj=model.state_dict(), f='/home/ubuntu/MLexpressionsStorage/final_model_2.pth')