In [None]:
# Download the Roboflow dataset export as a zip, extract it quietly into the
# working directory, then delete the archive and any README.* files
# (presumably shipped with the export).
# NOTE(review): the URL embeds what looks like an access key (`key=...`);
# anyone with this notebook can use it. Consider reading it from an
# environment variable before sharing the notebook.
!curl -L "https://app.roboflow.com/ds/RolsO4e05e?key=OgTkpJ7f0p" > roboflow.zip; unzip -q roboflow.zip; rm roboflow.zip
! rm README.*

In [None]:
import os, math, random
import shutil
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, DBSCAN
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tqdm import tqdm

In [None]:
# --- Configuration -----------------------------------------------------------
SOURCE_DIR = './train/'          # expects one sub-folder per entry in CLASSES
DEST_DIR = 'split_dataset'       # output root: <DEST_DIR>/<split>/<class>/
CLASSES = ['ok', 'not_ok']
INPUT_SHAPE = (224, 224, 3)
IMG_SIZE = (INPUT_SHAPE[0], INPUT_SHAPE[1])
NUM_CLUSTERS = 20
SPLIT_RATIOS = {
    'train': 0.75,
    'valid': 0.25,
}
SEED = 42

# Seed the global RNGs so the shuffles used when splitting and the random
# sample grids are reproducible across kernel restarts.
random.seed(SEED)
np.random.seed(SEED)

assert math.isclose(sum(SPLIT_RATIOS.values()), 1.0), "SPLIT_RATIOS must sum to 1"

In [None]:
# ImageNet-pretrained EfficientNetB0 used purely as a feature extractor:
# the classification head is dropped (include_top=False) and pooling='avg'
# reduces the last feature map to a single vector per image.
base_model = EfficientNetB0(
    include_top=False,
    weights='imagenet',
    pooling='avg',
    input_shape=INPUT_SHAPE,
)
model = Model(
    inputs=base_model.input,
    outputs=base_model.output,
)

In [None]:
def load_and_preprocess_image(img_path):
    """Load a single image file as a model-ready batch of size one.

    The file is resized to IMG_SIZE, converted to an array, given a leading
    batch axis, and passed through EfficientNet's `preprocess_input`.
    """
    pil_img = image.load_img(img_path, target_size=IMG_SIZE)
    batch = image.img_to_array(pil_img)[np.newaxis, ...]
    return preprocess_input(batch)

def extract_embeddings(image_paths, batch_size=32):
    """Compute one embedding vector per image with the global `model`.

    Images are loaded with `load_and_preprocess_image` and sent to
    `model.predict` in chunks of up to `batch_size` rather than one call per
    image: every `predict` call has fixed overhead, so per-image calls are far
    slower for essentially the same output.

    Args:
        image_paths: iterable of image file paths.
        batch_size: maximum number of images per `model.predict` call.

    Returns:
        np.ndarray whose i-th row is the embedding of the i-th path
        (an empty array when `image_paths` is empty).

    Raises:
        ValueError: if `batch_size` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    embeddings = []
    pending = []

    def _flush():
        # Run the model on the buffered (1, H, W, C) arrays in one call and
        # keep one output row per image, preserving input order.
        if pending:
            preds = model.predict(np.concatenate(pending, axis=0), verbose=0)
            embeddings.extend(preds)
            pending.clear()

    for img_path in tqdm(image_paths, desc="Extracting embeddings"):
        pending.append(load_and_preprocess_image(img_path))
        if len(pending) >= batch_size:
            _flush()
    _flush()  # remaining partial batch
    return np.array(embeddings)

def _class_of(img_path):
    """Return 'ok' or 'not_ok' for an image path.

    Images are collected as <SOURCE_DIR>/<class>/<file>, so the parent folder
    is the label. Matching the folder name (not a substring of the whole
    path) stops an ok image whose name contains 'not_ok' from being
    mislabelled. Paths without a class folder fall back to the substring rule.
    """
    parent = os.path.basename(os.path.dirname(os.path.normpath(img_path)))
    if parent in ('ok', 'not_ok'):
        return parent
    return 'not_ok' if 'not_ok' in img_path else 'ok'


def _gather_group(clustered_images, group):
    """Merge the clusters in `group`; return (ok_paths, not_ok_paths, label)."""
    ok_paths, not_ok_paths = [], []
    for cidx in group:
        for img_path in clustered_images[cidx]:
            if _class_of(img_path) == 'not_ok':
                not_ok_paths.append(img_path)
            else:
                ok_paths.append(img_path)
    label = "".join(f"{cidx} " for cidx in group)
    return ok_paths, not_ok_paths, label


def split_by_cluster(clustered_images, output_dir, cluster_groups, problem_cluster_groups,
                     seed=None, clear_existing=False):
    """Copy clustered images into <output_dir>/<split>/<class>/ train/valid folders.

    Each group is a list of cluster ids whose images are merged first.
    - Regular groups: each class is shuffled and split at SPLIT_RATIOS['train'].
    - Problem groups: both classes are truncated to the smaller class count
      (the surplus is discarded) and that balanced set is then split.
    Clusters not listed in any group are not copied; they are reported.

    Args:
        clustered_images: mapping cluster id -> list of image paths.
        output_dir: destination root directory.
        cluster_groups: groups split without balancing.
        problem_cluster_groups: groups balanced before splitting.
        seed: optional seed for a private RNG; None uses the global `random`.
        clear_existing: delete the existing split folders under `output_dir`
            first, so a re-run cannot leave stale files (which could put the
            same image in both train and valid).

    Raises:
        ValueError: if a cluster id appears in more than one group.
        KeyError: if a group references an unknown cluster id.
    """
    cluster_groups = [list(g) for g in cluster_groups]
    problem_cluster_groups = [list(g) for g in problem_cluster_groups]

    # A cluster used twice would be shuffled twice and could leak into both splits.
    seen = set()
    for group in cluster_groups + problem_cluster_groups:
        for cidx in group:
            if cidx in seen:
                raise ValueError(f"cluster {cidx} appears in more than one group")
            seen.add(cidx)
    unassigned = [cidx for cidx in clustered_images if cidx not in seen]

    rng = random.Random(seed) if seed is not None else random

    if clear_existing:
        for split in SPLIT_RATIOS:
            shutil.rmtree(os.path.join(output_dir, split), ignore_errors=True)
    stale_dirs = []
    for split in SPLIT_RATIOS:
        for cls in CLASSES:
            target_dir = os.path.join(output_dir, split, cls)
            os.makedirs(target_dir, exist_ok=True)
            if os.listdir(target_dir):
                stale_dirs.append(target_dir)
    if stale_dirs:
        print(f"WARNING: destination already contains files ({', '.join(stale_dirs)}); "
              f"pass clear_existing=True to avoid mixing runs.")

    discarded = {'ok': [], 'not_ok': []}
    to_train = {'ok': [], 'not_ok': []}
    to_valid = {'ok': [], 'not_ok': []}
    train_ratio = SPLIT_RATIOS['train']

    for group in cluster_groups:
        ok_cluster, not_ok_cluster, idx = _gather_group(clustered_images, group)
        rng.shuffle(ok_cluster)
        rng.shuffle(not_ok_cluster)
        for cls, members in (('ok', ok_cluster), ('not_ok', not_ok_cluster)):
            cut = int(len(members) * train_ratio)
            to_train[cls].extend(members[:cut])
            to_valid[cls].extend(members[cut:])
        print(f"Spliting regular cluster {idx}: found {len(ok_cluster)} ok, {len(not_ok_cluster)} not_ok. Keeping {len(ok_cluster) + len(not_ok_cluster)}, discarding {0}")

    for group in problem_cluster_groups:
        ok_cluster, not_ok_cluster, idx = _gather_group(clustered_images, group)
        rng.shuffle(ok_cluster)
        rng.shuffle(not_ok_cluster)
        balanced_min = min(len(ok_cluster), len(not_ok_cluster))
        cut = int(balanced_min * train_ratio)
        for cls, members in (('ok', ok_cluster), ('not_ok', not_ok_cluster)):
            to_train[cls].extend(members[:cut])
            to_valid[cls].extend(members[cut:balanced_min])
            discarded[cls].extend(members[balanced_min:])
        print(f"Spliting problem cluster {idx}: found {len(ok_cluster)} ok, {len(not_ok_cluster)} not_ok. Keeping {2*balanced_min}, discarding {len(ok_cluster) - balanced_min + len(not_ok_cluster) - balanced_min}")

    # split -> [not_ok count, ok count]
    split_count = {'train': [0, 0], 'valid': [0, 0]}
    for split, split_dict in (('train', to_train), ('valid', to_valid)):
        for cls, imgs in split_dict.items():
            for img_path in imgs:
                dest_path = os.path.join(output_dir, split, cls, os.path.basename(img_path))
                shutil.copy2(img_path, dest_path)
                split_count[split][0 if cls == 'not_ok' else 1] += 1

    print(f"{'Split':<10} {'NOT_OK':>10} {'OK':>10} {'Ratio':>10}")
    print('-' * (44))
    for split, (n_not_ok, n_ok) in split_count.items():
        ratio = n_not_ok / n_ok if n_ok else float('nan')  # no ok images -> undefined ratio
        print(f"{split:<10} {n_not_ok:>10} {n_ok:>10} {ratio:>10.2f}")
    print()
    print(f"Total discarded: {len(discarded['ok']) + len(discarded['not_ok'])}")
    if unassigned:
        n_unassigned = sum(len(clustered_images[cidx]) for cidx in unassigned)
        print(f"Clusters not in any group (not copied): "
              f"{', '.join(str(cidx) for cidx in unassigned)} ({n_unassigned} images)")

In [None]:
# Collect every image under SOURCE_DIR/<class>/. The directory listing is
# sorted so the image order (and therefore the KMeans input order, which can
# affect the cluster ids it assigns) does not depend on the filesystem.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
all_image_paths = []
for cls in CLASSES:
    class_dir = os.path.join(SOURCE_DIR, cls)
    all_image_paths.extend(
        os.path.join(class_dir, fname)
        for fname in sorted(os.listdir(class_dir))
        if fname.lower().endswith(IMAGE_EXTENSIONS)
    )
assert all_image_paths, f"no images found under {SOURCE_DIR}"

# Generate embeddings: row i corresponds to all_image_paths[i]
embeddings = extract_embeddings(all_image_paths)
assert len(embeddings) == len(all_image_paths)

In [None]:
# Cluster the embeddings with KMeans.
# n_init is pinned: scikit-learn 1.4 changed its default from 10 to 'auto'
# (a single k-means++ run), so leaving it implicit makes the clusters - and
# the hand-picked cluster ids used when splitting - depend on the installed
# scikit-learn version. Re-check those ids after changing anything here.
kmeans = KMeans(n_clusters=NUM_CLUSTERS, n_init=10, random_state=42)
cluster_labels = kmeans.fit_predict(embeddings)
distinct_labels = set(cluster_labels)
assert len(cluster_labels) == len(all_image_paths), "embeddings and image paths are misaligned"

# Group image paths by cluster label; keys are inserted in ascending label
# order so the per-cluster output below is listed in a stable order.
clustered_images = {label: [] for label in sorted(distinct_labels)}
for img_path, label in zip(all_image_paths, cluster_labels):
    clustered_images[label].append(img_path)

print(f"Divided into {len(distinct_labels)} clusters!")

In [None]:
SAMPLES = 180          # max thumbnails shown per cluster
SAMPLES_PER_ROW = 20   # thumbnails per grid row
for cluster_id, paths in clustered_images.items():
    if not paths:
        continue
    # Label by parent folder (<SOURCE_DIR>/<class>/<file>), not by substring.
    not_ok_count = sum(1 for path in paths if os.path.basename(os.path.dirname(path)) == 'not_ok')
    ok_count = len(paths) - not_ok_count
    # Majority/minority ratio; a single-class cluster has no minority -> inf.
    minority = min(ok_count, not_ok_count)
    ratio = max(ok_count, not_ok_count) / minority if minority else float('inf')
    print(f"\nCluster {cluster_id} ({len(paths)} images) ({not_ok_count} damaged) ({ok_count} ok) - {ratio}")
    sample_paths = random.sample(paths, SAMPLES) if SAMPLES <= len(paths) else paths
    num_images = len(sample_paths)
    cols = min(SAMPLES_PER_ROW, num_images)
    rows = math.ceil(num_images / cols)

    # squeeze=False always yields a 2-D array of axes, whatever rows/cols are.
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    flat_axes = axes.ravel()
    for ax in flat_axes:
        ax.axis('off')  # also hides the empty cells of a partial last row
    for ax, img_path in zip(flat_axes, sample_paths):
        ax.imshow(image.load_img(img_path, target_size=IMG_SIZE))
    plt.show()
    plt.close(fig)  # free memory; one large figure is created per cluster

In [None]:
# Cluster ids chosen by hand (presumably from the sample grids above).
# "Regular" groups are split per class at SPLIT_RATIOS; "problem" groups are
# first cut down to equal ok/not_ok counts (the surplus is discarded).
# Ids that appear in neither list are not copied at all (with
# NUM_CLUSTERS = 20 that is 13, 14 and 15).
# NOTE(review): these ids only make sense for the exact KMeans run above;
# re-check them whenever the clustering inputs or settings change.
regular_groups = [[2], [3], [6], [7], [18], [19]]
problem_groups = [[0], [1], [4, 9], [5, 8], [10], [11], [12, 16], [17]]

split_by_cluster(
    clustered_images,
    DEST_DIR,
    cluster_groups=regular_groups,
    problem_cluster_groups=problem_groups,
)