In [98]:
# === IMPORTS ===
import os
import math
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import class_weight
from tensorflow.keras import layers, models, losses, optimizers, applications, mixed_precision

In [158]:
np.random.seed(42)
tf.random.set_seed(42)

# Configuration
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
NUM_EPOCHS = 7  # Increase epochs when in production
LEARNING_RATE = 0.001
TASK = "style"  # Options: "style", "artist", "genre"

In [101]:
# Paths
BASE_DIR = "C:/Users/Ace/Gsoc_HumanAI/wikiart_csv"
WIKIART_DIR = "C:/Users/Ace/Gsoc_HumanAI/wikiart" # artwork images here
MODELS_DIR = "C:/Users/Ace/Gsoc_HumanAI"
TRAIN_DATA_PATH = f"{BASE_DIR}/{TASK}_train.csv"
VAL_DATA_PATH = f"{BASE_DIR}/{TASK}_val.csv"
CLASSES_PATH = f"{BASE_DIR}/{TASK}_class.txt"

In [104]:
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

def load_data(data_path, subset_size=1.0, random_state=42):
    """Load data from CSV file and apply stratified sampling."""
    df = pd.read_csv(data_path)
    df.columns = ['image_path', 'label']

    df['image_path'] = df['image_path'].apply(lambda x: os.path.join(WIKIART_DIR, x))

    df = df[df['image_path'].apply(os.path.exists)]

    if subset_size < 1.0:
        df = df.groupby('label', group_keys=False).apply(
            lambda x: x.sample(frac=subset_size, random_state=random_state)
        )

    # Print sample paths for verification
    sample_paths = df['image_path'].sample(min(5, len(df))).tolist()
    for path in sample_paths:
        print(f"Checking if path exists: {path}")
        print(f"Exists: {os.path.exists(path)}")

    return df

def load_classes(classes_path):
    """Load class names from text file."""
    with open(classes_path, 'r') as f:
        classes = [line.strip() for line in f.readlines()]
    return classes

In [106]:
def preprocess_data(train_df, val_df, classes):
    """Preprocess data for training."""
    is_numeric_labels = isinstance(train_df['label'].iloc[0], (int, np.integer))

    if is_numeric_labels:
        train_df['label_encoded'] = train_df['label']
        val_df['label_encoded'] = val_df['label']

        label_map = {i: class_name for i, class_name in enumerate(classes)}

        train_df['label_name'] = train_df['label'].map(label_map)
        val_df['label_name'] = val_df['label'].map(label_map)

        le = LabelEncoder()
        le.fit(classes)
    else:
        le = LabelEncoder()
        le.fit(classes)

        unknown_train_labels = set(train_df['label']) - set(classes)
        unknown_val_labels = set(val_df['label']) - set(classes)

        if unknown_train_labels:
            print(f"Warning: Found {len(unknown_train_labels)} unknown labels in training data")
            print(f"Sample unknown labels: {list(unknown_train_labels)[:5]}")

            train_df = train_df[train_df['label'].isin(classes)]

        if unknown_val_labels:
            print(f"Warning: Found {len(unknown_val_labels)} unknown labels in validation data")
            print(f"Sample unknown labels: {list(unknown_val_labels)[:5]}")

            val_df = val_df[val_df['label'].isin(classes)]

        train_df['label_encoded'] = le.transform(train_df['label'])
        val_df['label_encoded'] = le.transform(val_df['label'])

    class_weights = compute_class_weight(
        'balanced',
        classes=np.unique(train_df['label_encoded']),
        y=train_df['label_encoded']
    )
    class_weights_dict = {i: weight for i, weight in enumerate(class_weights)}

    return train_df, val_df, le, class_weights_dict

In [108]:
def create_data_generators(df, batch_size, task):
    """Create a tf.data.Dataset compatible with EfficientNetB3."""
    def load_and_preprocess_image(path):
        image = tf.io.read_file(path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
        image = tf.keras.applications.efficientnet.preprocess_input(image)
        return image

    def augment_image(image, label):
        if task == "style":
            image = tf.cast(image, tf.uint8)
            image = tf.image.random_flip_left_right(image)
            
            image = tf.image.random_brightness(image, max_delta=0.1)  
            image = tf.image.random_contrast(image, lower=0.9, upper=1.1) 
            image = tf.image.random_saturation(image, lower=0.9, upper=1.1)  
            
            crop_factor = tf.random.uniform([], 0.9, 1.0)  
            crop_size = tf.cast(
                tf.cast([IMG_HEIGHT, IMG_WIDTH], tf.float32) * crop_factor,
                tf.int32
            )
            crop_size = tf.minimum(crop_size, [IMG_HEIGHT, IMG_WIDTH])
            image = tf.image.random_crop(image, [crop_size[0], crop_size[1], 3])
            image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])

        elif task == "artist":
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_brightness(image, max_delta=0.2)
            image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
            image = tf.image.random_saturation(image, lower=0.8, upper=1.2)

        elif task == "genre":
            image = tf.image.random_flip_left_right(image)
            
            image = tf.image.random_brightness(image, max_delta=0.1)  
            image = tf.image.random_contrast(image, lower=0.9, upper=1.1) 
            image = tf.image.random_saturation(image, lower=0.9, upper=1.1) 
            
            if tf.random.uniform([], 0, 1) > 0.5:
                crop_factor = tf.random.uniform([], 0.95, 1.0)  
                crop_size = tf.cast(
                    tf.cast([IMG_HEIGHT, IMG_WIDTH], tf.float32) * crop_factor,
                    tf.int32
                )
                image = tf.image.random_crop(image, [crop_size[0], crop_size[1], 3])
                image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
        
        return image, label

    paths = df['image_path'].values
    labels = df['label_encoded'].values

    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))

    dataset = dataset.map(
        lambda path, label: (load_and_preprocess_image(path), label),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    dataset = dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)

    dataset = dataset.shuffle(buffer_size=10000)

    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return dataset

In [132]:
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint

def focal_loss(gamma=2., alpha=0.25):
    def focal_loss_fixed(y_true, y_pred):
        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
        y_true = tf.one_hot(tf.cast(y_true, tf.int32), tf.shape(y_pred)[-1])
        alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
        loss = -alpha_t * (y_true * tf.math.pow(1. - y_pred, gamma) * tf.math.log(y_pred))
        return tf.reduce_sum(loss, axis=-1)
    return focal_loss_fixed

class CosineAnnealingWithRestarts(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, initial_lr, T_max, eta_min=0):
        self.initial_lr = initial_lr
        self.T_max = T_max
        self.eta_min = eta_min
        self.t = 0

    def __call__(self, step):
        cos_inner = tf.math.pi * (self.t % self.T_max) / self.T_max
        lr = self.eta_min + (self.initial_lr - self.eta_min) * (1 + tf.math.cos(cos_inner)) / 2
        self.t += 1
        return lr

def build_conv_recurrent_model(num_classes, task, img_height=224, img_width=224):
    """Build improved models for art classification with fixes for artist task."""
    if task == "style" or task == "genre":
        base_model = tf.keras.applications.EfficientNetB2(
            weights='imagenet',
            include_top=False,
            input_shape=(img_height, img_width, 3)
        )
    else:
        base_model = tf.keras.applications.EfficientNetB2(
            weights='imagenet',
            include_top=False,
            input_shape=(img_height, img_width, 3)
        )
    
    if task == "style":
        for layer in base_model.layers[:-30]:
            layer.trainable = False
    elif task == "genre":
        for layer in base_model.layers[:-60]:
            layer.trainable = False
    else:
        for layer in base_model.layers[:-40]:
            layer.trainable = False
    
    inputs = Input(shape=(img_height, img_width, 3))
    x = tf.keras.applications.efficientnet.preprocess_input(inputs)
    x = base_model(x)


    x = GlobalAveragePooling2D()(x)

    if task == "artist":
        x = Dense(1536, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.4)(x)

        shortcut =Dense(768)(x)
        
        x = Dense(768, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.3)(x)
        
        se = Dense(128, activation='relu')(x)
        se = Dense(768, activation='sigmoid')(se)
        x = x * se
        
        x = x + shortcut
    
    elif task == "style":
        x = Dense(1024, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.3)(x)
        
        x = Dense(512, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.3)(x)
        
        x = Dense(256, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.2)(x)
    
    elif task == "genre":
        x = Dense(1024, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.3)(x)
        
        x = Dense(512, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.25)(x)
        
        x = Dense(256, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.2)(x)
    
    outputs = Dense(num_classes, activation='softmax')(x)
    
    model = Model(inputs=inputs, outputs=outputs)

    if task == "style":
        optimizer = Adam(learning_rate=CosineAnnealingWithRestarts(2e-3, 1000), weight_decay=1e-5)
    elif task == "genre":
        initial_learning_rate = 5e-4
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate,
            decay_steps=2000,
            decay_rate=0.95,
            staircase=True
        )
        optimizer = Adam(learning_rate=lr_schedule, weight_decay=1e-5)
    else:
        optimizer = Adam(learning_rate=CosineAnnealingWithRestarts(5e-4, 2000), weight_decay=1e-5)
    
    model.compile(
        optimizer=optimizer,
        loss=focal_loss(gamma=2.0, alpha=0.25),
        metrics=['accuracy']
    )

    return model


In [122]:
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score
from tensorflow.keras.utils import plot_model

def evaluate_model(model, val_dataset, le,TASK):
    """Evaluate the model and handle both integer and one-hot encoded labels."""
    # Create results folder
    os.makedirs('results', exist_ok=True)
    
    plot_model(
        model, 
        to_file=f'results/model_architecture_{TASK}.png', 
        show_shapes=True, 
        show_layer_names=True,
        expand_nested=True
    )
    print(f"Model architecture saved as 'results/model_architecture_{TASK}.png'")

    predictions = model.predict(val_dataset, steps=len(val_dataset), verbose=1)
    predicted_classes = np.argmax(predictions, axis=1)

    true_classes = []
    for _, labels in val_dataset:
        labels = labels.numpy()
        if labels.ndim == 2:
            true_classes.extend(np.argmax(labels, axis=1))
        else:
            true_classes.extend(labels)

    true_classes = np.array(true_classes)

    accuracy = accuracy_score(true_classes, predicted_classes)
    f1 = f1_score(true_classes, predicted_classes, average='weighted')

    print(f"\nAccuracy: {accuracy:.4f}")
    print(f"F1 Score: {f1:.4f}\n")

    conf_matrix = confusion_matrix(true_classes, predicted_classes)
    class_report = classification_report(true_classes, predicted_classes, target_names=le.classes_)


    with open(f'results/classification_report_{TASK}.txt', 'w') as f:
        f.write(class_report)

    plt.figure(figsize=(14, 12))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', 
                xticklabels=le.classes_, yticklabels=le.classes_,
                cbar_kws={'shrink': 0.8}, linewidths=0.5, linecolor='gray',
                annot_kws={"size": 7},
                vmin=0, vmax=conf_matrix.max()) 

    plt.title(f'Confusion Matrix for {TASK}')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    

    plt.xticks(ticks=np.arange(len(le.classes_)) + 0.5, labels=le.classes_, rotation=45, ha='right', fontsize=8)
    plt.yticks(ticks=np.arange(len(le.classes_)) + 0.5, labels=le.classes_, rotation=0, va='center', fontsize=8)

    plt.subplots_adjust(left=0.3, bottom=0.2) 
    
    plt.savefig(f'results/confusion_matrix_{TASK}.png')
    plt.show()


    f1_scores = f1_score(true_classes, predicted_classes, average=None)
    plt.figure(figsize=(16, 6))
    sns.barplot(x=le.classes_, y=f1_scores, palette='viridis')
    
    plt.title(f'Per-Class F1 Scores for {TASK}')
    plt.ylabel('F1 Score')

    plt.xticks(ticks=np.arange(len(le.classes_)), labels=le.classes_, rotation=45, ha='right')
    
    plt.tight_layout()
    plt.savefig(f'results/f1_scores_{TASK}.png')
    plt.show()

    with open(f'results/class_accuracy_{TASK}.txt', 'w') as f:
        f.write("Class-wise Evaluation:\n")
        for i, class_name in enumerate(le.classes_):
            class_acc = conf_matrix[i, i] / conf_matrix[i].sum() if conf_matrix[i].sum() > 0 else 0
            f.write(f"{class_name} - Accuracy: {class_acc:.4f}\n")
            
    prediction_confidence = np.max(predictions, axis=1)
    low_confidence_indices = np.where(prediction_confidence < 0.5)[0]
    misclassified_indices = np.where(predicted_classes != true_classes)[0]
    outlier_indices = np.union1d(low_confidence_indices, misclassified_indices)

    print(f"\nFound {len(outlier_indices)} potential outliers")

    return outlier_indices, predictions, true_classes, predicted_classes


In [124]:
def visualize_training_history(history,TASK):
    """Visualize training history"""
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history['accuracy'], label='Train Accuracy')
    plt.plot(history['val_accuracy'], label='Validation Accuracy')
    plt.title('Model Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(loc='upper left')

    plt.subplot(1, 2, 2)
    plt.plot(history['loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(loc='upper left')

    plt.tight_layout()
    plt.savefig(f'training_history_{TASK}.png')
    plt.show()



def visualize_outliers(outlier_indices, val_dataset, predictions, true_classes, predicted_classes, le, task, num_examples=5):
    """Visualize outlier examples"""
    
    results_dir = f'results/{task}'
    os.makedirs(results_dir, exist_ok=True)

    images = []
    for img_batch, _ in val_dataset:
        images.extend(img_batch.numpy()) 

    if len(outlier_indices) > num_examples:
        sample_indices = np.random.choice(outlier_indices, num_examples, replace=False)
    else:
        sample_indices = outlier_indices
    
    plt.figure(figsize=(12, 12))
    for i, idx in enumerate(sample_indices):
        if idx < len(images):  
            image = images[idx]

            if image.max() <= 1.0:  
                image = (image * 255).astype('uint8')
            else:
                image = np.clip(image, 0, 255).astype('uint8')

            true_label = le.classes_[true_classes[idx]]
            pred_label = le.classes_[predicted_classes[idx]]
            confidence = predictions[idx][predicted_classes[idx]]

            plt.subplot(3, 2, i + 1)
            plt.imshow(image)
            plt.title(f"True: {true_label}\nPred: {pred_label}\nConf: {confidence:.2f}", fontsize=10)
            plt.axis('off')

    plt.tight_layout()
    outlier_path = os.path.join(results_dir, f'outliers_{task}.png')
    plt.savefig(outlier_path)
    plt.show()
    print(f"Outliers saved to: {outlier_path}")


In [126]:
# Data exploration function
def explore_dataset():
    """Explore the dataset structure"""
    for task in ["artist", "genre", "style"]:
        train_path = f"{BASE_DIR}/{task}_train.csv"
        val_path = f"{BASE_DIR}/{task}_val.csv"
        class_path = f"{BASE_DIR}/{task}_class.txt"

        print(f"\n{'='*40}")
        print(f"Exploring {task.upper()} dataset")
        print(f"{'='*40}")

        print(f"Train file exists: {os.path.exists(train_path)}")
        print(f"Val file exists: {os.path.exists(val_path)}")
        print(f"Class file exists: {os.path.exists(class_path)}")

        try:
            train_df = pd.read_csv(train_path)  
            val_df = pd.read_csv(val_path)      

            print(f"\nTrain data shape: {train_df.shape}")
            print(f"Validation data shape: {val_df.shape}")

            print(f"\nTrain columns: {train_df.columns.tolist()}")

            print("\nSample train data (first 3 rows):")
            print(train_df.head(3))

            if len(train_df.columns) >= 2:
                label_col = train_df.iloc[:, 1] 
                unique_labels = label_col.unique()
                print(f"\nNumber of unique labels in training data: {len(unique_labels)}")
                print(f"Sample labels: {unique_labels[:5].tolist()}")

            if len(train_df.columns) >= 1:
                img_col = train_df.iloc[:, 0] 
                sample_paths = img_col.sample(min(3, len(img_col))).tolist()
                print("\nSample image paths:")
                for path in sample_paths:
                    print(f"  {path}")
                    full_path = os.path.join(WIKIART_DIR, path)
                    print(f"  Exists in wikiart folder: {os.path.exists(full_path)}")

        except Exception as e:
            print(f"Error exploring {task} dataset: {str(e)}")

print("Exploring dataset structure...")
explore_dataset()

print("\nStarting model training...")


Exploring dataset structure...

Exploring ARTIST dataset
Train file exists: True
Val file exists: True
Class file exists: True

Train data shape: (13345, 2)
Validation data shape: (5705, 2)

Train columns: ['Realism/vincent-van-gogh_pine-trees-in-the-fen-1884.jpg', '22']

Sample train data (first 3 rows):
  Realism/vincent-van-gogh_pine-trees-in-the-fen-1884.jpg  22
0  Baroque/rembrandt_the-angel-appearing-to-the-s...       20
1  Post_Impressionism/paul-cezanne_portrait-of-th...       16
2  Impressionism/pierre-auguste-renoir_young-girl...       17

Number of unique labels in training data: 23
Sample labels: [20, 16, 17, 9, 1]

Sample image paths:
  Romanticism/gustave-dore_paradise-lost-4.jpg
  Exists in wikiart folder: True
  Post_Impressionism/vincent-van-gogh_the-raising-of-lazarus-1890.jpg
  Exists in wikiart folder: True
  Impressionism/eugene-boudin_the-port-of-deauville-1.jpg
  Exists in wikiart folder: True

Exploring GENRE dataset
Train file exists: True
Val file exists: True

In [160]:
def train_model(task, subset_size=1.0):
    """Train the improved model with enhanced training strategy."""
    train_df = load_data(f"{BASE_DIR}/{task}_train.csv", subset_size=subset_size)
    val_df = load_data(f"{BASE_DIR}/{task}_val.csv", subset_size=subset_size)
    classes = load_classes(f"{BASE_DIR}/{task}_class.txt")

    train_df, val_df, label_encoder, class_weights_dict = preprocess_data(train_df, val_df, classes)

    train_dataset = create_data_generators(train_df, BATCH_SIZE, task)
    val_dataset = create_data_generators(val_df, BATCH_SIZE, task)

    model = build_conv_recurrent_model(len(classes), task)
    print(model.summary())

    # Callbacks for better training
    checkpoint = ModelCheckpoint(
        f'{MODELS_DIR}/models/best_model_{task}.keras',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    )

    early_stopping = EarlyStopping(
        monitor='val_accuracy',
        patience=12,  
        restore_best_weights=True,
        verbose=1
    )

    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,  
        patience=4,  
        min_lr=1e-7,  
        verbose=1
    )

    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE, weight_decay=1e-6),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    print("Phase 1: Initial training with mostly frozen base model...")
    history1 = model.fit(
        train_dataset,
        epochs=NUM_EPOCHS,
        validation_data=val_dataset,
        callbacks=[checkpoint, early_stopping, reduce_lr],
        class_weight=class_weights_dict
    )

    print("Phase 2: Fine-tuning with more layers unfrozen...")
    base_model = model.layers[1]


    for layer in base_model.layers:
        layer.trainable = True

    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE/10, weight_decay=1e-6), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    history2 = model.fit(
        train_dataset,
        epochs=10,  
        initial_epoch=history1.epoch[-1] + 1,
        validation_data=val_dataset,
        callbacks=[checkpoint, early_stopping, reduce_lr],
        class_weight=class_weights_dict
    )

    # Evaluate the model
    print("Evaluating model...")
    outlier_indices, predictions, true_classes, predicted_classes = evaluate_model(
        model, val_dataset, label_encoder,task
    )

    print("Visualizing results...")
    combined_history = {}
    for k in history1.history.keys():
        combined_history[k] = history1.history[k] + history2.history[k]

    visualize_training_history(combined_history,task)
    visualize_outliers(
        outlier_indices, val_dataset, predictions,
        true_classes, predicted_classes, label_encoder,task
    )

    return model, combined_history


In [None]:
def train_all_models_improved():
    """Train improved models for all tasks with stratified sampling."""
    results = {}
    subset_sizes = {
        "artist": 1.0,  # Use 100% of the dataset for Artist
        "genre": 0.5,   # Use 50% of the dataset for Genre
        "style": 0.3    # Use 30% of the dataset for Style
    }

    for task in ["artist"]:
        print(f"\n\n{'='*60}")
        print(f"Training improved model for {task.upper()}")
        print(f"{'='*60}\n")

        model, history = train_model(task, subset_size=subset_sizes[task])
        results[task] = (model, history)

    return results


if __name__ == "__main__":
    # Training all models
    all_results = train_all_models_improved()

In [164]:
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image

In [166]:
def load_trained_model(task):
    """Load a trained model for a specific task"""
    model_path = f'{MODELS_DIR}/models/best_model_{task}.keras'
    if not os.path.exists(model_path):
        print(f"Error: Model for {task} not found at {model_path}")
        return None

    try:
        model = load_model(model_path)
        print(f"Successfully loaded {task} model")
        return model
    except Exception as e:
        print(f"Error loading {task} model: {str(e)}")
        return None

def load_class_labels(task):
    """Load class labels for a specific task"""
    classes_path = f'{BASE_DIR}/{task}_class.txt'
    try:
        with open(classes_path, 'r') as f:
            classes = [line.strip() for line in f.readlines()]
        print(f"Loaded {len(classes)} {task} classes")
        return classes
    except Exception as e:
        print(f"Error loading {task} classes: {str(e)}")
        return None


In [168]:
def preprocess_image(img_path):
    """Preprocess an image for model prediction"""
    try:
        if not os.path.exists(img_path):
            print(f"Error: Image not found at {img_path}")
            return None

        # Load and preprocess the image
        img = image.load_img(img_path, target_size=(224, 224))
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = img_array / 255.0  

        return img_array, img
    except Exception as e:
        print(f"Error preprocessing image: {str(e)}")
        return None, None


In [170]:
def predict_artwork(img_path, tasks=None):
    """
    Predict artist, style, and genre for a given artwork image

    Parameters:
    img_path (str): Path to the artwork image
    tasks (list): List of tasks to perform, default ["artist", "genre", "style"]

    Returns:
    dict: Dictionary with predictions for each task
    """
    if tasks is None:
        tasks = ["artist", "genre", "style"]

    img_array, original_img = preprocess_image(img_path)
    if img_array is None:
        return None

    results = {}
    
    models = {task: load_trained_model(task) for task in tasks}
    class_labels = {task: load_class_labels(task) for task in tasks}

    for task in tasks:
        print(f"\nPredicting {task}...")

        model = models.get(task)
        classes = class_labels.get(task)

        if model is None or classes is None:
            results[task] = {"error": f"Could not load model or classes for {task}"}
            continue

        try:
            img_tensor = tf.convert_to_tensor(img_array)
            img_tensor = tf.ensure_shape(img_tensor, (1, 224, 224, 3))  # Example shape (adjust to your model)

            @tf.function(reduce_retracing=True)
            def predict_step(input_tensor):
                return model(input_tensor)

            predictions = predict_step(img_tensor)

            top_indices = tf.argsort(predictions[0], direction="DESCENDING")[:3]
            top_predictions = [(classes[i.numpy()], float(predictions[0][i].numpy())) for i in top_indices]

            results[task] = {
                "top_predictions": top_predictions,
                "prediction": classes[int(tf.argmax(predictions[0]))],
                "confidence": float(tf.reduce_max(predictions[0]))
            }

            print(f"Top {task} predictions:")
            for class_name, prob in top_predictions:
                print(f"  {class_name}: {prob:.4f}")

        except Exception as e:
            print(f"Error making prediction for {task}: {str(e)}")
            results[task] = {"error": str(e)}

    visualize_prediction_results(img_path, original_img, results)

    return results


In [172]:
def visualize_prediction_results(img_path, original_img, results):
    """Visualize the prediction results"""
    plt.figure(figsize=(12, 8))

    plt.subplot(1, 2, 1)
    plt.imshow(original_img)
    plt.title(f"Artwork: {os.path.basename(img_path)}")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.axis('off')

    result_text = "Prediction Results:\n\n"

    for task in results:
        result_text += f"{task.capitalize()}:\n"

        if "error" in results[task]:
            result_text += f"  Error: {results[task]['error']}\n"
        else:
            for i, (class_name, prob) in enumerate(results[task]["top_predictions"]):
                result_text += f"  {i+1}. {class_name}: {prob:.2%}\n"

        result_text += "\n"

    plt.text(0.1, 0.5, result_text, fontsize=12, verticalalignment='center')

    plt.tight_layout()
    plt.savefig('artwork_prediction.png')
    plt.show()


In [174]:
def batch_predict_artworks(folder_path, tasks=None):
    """
    Predict artist, style, and genre for all artwork images in a folder

    Parameters:
    folder_path (str): Path to the folder containing artwork images
    tasks (list): List of tasks to perform, default ["artist", "genre", "style"]
    """
    if tasks is None:
        tasks = ["artist", "genre", "style"]

    if not os.path.exists(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return

    image_extensions = ['.jpg', '.jpeg', '.png']
    image_files = [f for f in os.listdir(folder_path)
                  if os.path.isfile(os.path.join(folder_path, f)) and
                  any(f.lower().endswith(ext) for ext in image_extensions)]

    if not image_files:
        print(f"No image files found in {folder_path}")
        return

    print(f"Found {len(image_files)} image files. Starting batch prediction...")

    models = {}
    class_labels = {}

    for task in tasks:
        models[task] = load_trained_model(task)
        class_labels[task] = load_class_labels(task)

        if models[task] is None or class_labels[task] is None:
            print(f"Warning: Could not load model or classes for {task}")

    all_results = {}
    for img_file in image_files:
        img_path = os.path.join(folder_path, img_file)
        print(f"\nProcessing {img_file}...")

        img_array, _ = preprocess_image(img_path)
        if img_array is None:
            all_results[img_file] = {"error": "Failed to preprocess image"}
            continue

        img_results = {}
        for task in tasks:
            if models[task] is None or class_labels[task] is None:
                img_results[task] = {"error": f"Model or classes not available for {task}"}
                continue

            try:
                predictions = models[task].predict(img_array)

                top_indices = np.argsort(predictions[0])[-3:][::-1]
                top_predictions = [(class_labels[task][i], float(predictions[0][i])) for i in top_indices]

                img_results[task] = {
                    "top_predictions": top_predictions,
                    "prediction": class_labels[task][np.argmax(predictions)],
                    "confidence": float(np.max(predictions))
                }

            except Exception as e:
                print(f"Error making {task} prediction for {img_file}: {str(e)}")
                img_results[task] = {"error": str(e)}

        all_results[img_file] = img_results

    export_results_to_csv(all_results, folder_path)

    print(f"\nCompleted batch prediction for {len(image_files)} images")
    return all_results


In [176]:
def export_results_to_csv(all_results, folder_path):
    """Export batch prediction results to CSV"""
    import pandas as pd

    rows = []
    for img_file, img_results in all_results.items():
        row = {'image': img_file}

        for task in img_results:
            if "error" in img_results[task]:
                row[f'{task}_prediction'] = "ERROR"
                row[f'{task}_confidence'] = 0.0
            else:
                row[f'{task}_prediction'] = img_results[task]["prediction"]
                row[f'{task}_confidence'] = img_results[task]["confidence"]

                # Add top 3 predictions
                for i, (class_name, prob) in enumerate(img_results[task]["top_predictions"]):
                    row[f'{task}_top{i+1}'] = class_name
                    row[f'{task}_top{i+1}_confidence'] = prob

        rows.append(row)

    df = pd.DataFrame(rows)
    csv_path = os.path.join(folder_path, 'artwork_predictions.csv')
    df.to_csv(csv_path, index=False)
    print(f"Results exported to {csv_path}")

def analyze_single_image(img_path):
    """
    Analyze a single artwork image for artist, style, and genre

    Parameters:
    img_path (str): Path to the artwork image
    """
    print(f"Analyzing artwork: {img_path}")
    results = predict_artwork(img_path)

    if results:
        print("\nSummary of predictions:")
        for task, task_results in results.items():
            if "error" in task_results:
                print(f"  {task.capitalize()}: Error - {task_results['error']}")
            else:
                print(f"  {task.capitalize()}: {task_results['prediction']} ({task_results['confidence']:.2%})")

    return results

In [None]:
if __name__ == "__main__":
    # Single image prediction
    print("\n===== Single Image Prediction =====")
    analyze_single_image(test_image_path)

    # Batch prediction example - uncomment to use
    # test_folder_path = "./test_images"  # Change this to your test folder path
    # print("\n===== Batch Prediction =====")
    # batch_results = batch_predict_artworks(test_folder_path)
