In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models # ResNet, VGG
from torch.utils.data import DataLoader, TensorDataset
import cv2
from sklearn.neighbors import NearestNeighbors
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit # ensure balanced representation of all labels
from sklearn.metrics import average_precision_score

In [None]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
def preprocess_image(image, target_size):
    """Resize an image and return it as a normalized, channel-first float32 array.

    Args:
        image: H x W x 3 array. Values are divided by 255, so 8-bit
            intensities are presumably expected. The normalization constants
            are the ImageNet statistics, applied per channel along the last
            axis.
        target_size: size tuple passed straight to ``cv2.resize``.

    Returns:
        ``np.float32`` array with the channel axis first (C, H, W).
    """
    image = cv2.resize(image, target_size)
    image = image.astype(np.float32) / 255.0
    # Keep the statistics in float32: float64 constants would silently promote
    # every image to float64 and double its memory footprint.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # ImageNet channel means
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)  # ImageNet channel stds
    image = (image - mean) / std

    image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
    return image

In [None]:
# Load every image listed in the label CSV, preprocess it, and collect the
# images and their multi-label rows into two aligned float32 arrays.
dataset_dir = 'Datasets/DFC-15_MLC'
target_size = (256, 256)

images_path = os.path.join(dataset_dir, "images")
labels_csv_path = os.path.join(dataset_dir, "multilabel.csv")

df = pd.read_csv(labels_csv_path)
image_filenames = df["filename"].values
multi_labels = df.iloc[:, 1:].values  # every column after the first is read as a label column

images = []
labels = []
skipped_files = []  # paths that cv2 could not read

for idx, filename in enumerate(image_filenames):
    img_path = os.path.join(images_path, f"{str(filename)}.png")
    # IMREAD_COLOR always yields an 8-bit, 3-channel BGR array, which is what
    # the BGR->RGB conversion and the /255 scaling in preprocess_image assume.
    # IMREAD_UNCHANGED could hand back 16-bit, grayscale or 4-channel data.
    image = cv2.imread(img_path, cv2.IMREAD_COLOR)

    if image is None:
        # Unreadable or missing file: skip it, but remember it for the report below.
        skipped_files.append(img_path)
        continue

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = preprocess_image(image, target_size)

    images.append(image)
    labels.append(multi_labels[idx])  # keeps images and labels aligned when files are skipped

if skipped_files:
    print(f"Warning: skipped {len(skipped_files)} unreadable image(s), e.g. {skipped_files[0]}")
if not images:
    raise FileNotFoundError(f"No readable images found under {images_path}")

all_images = np.array(images, dtype=np.float32)
all_labels = np.array(labels, dtype=np.float32)

In [None]:
# Split configuration used by the stratified splits below.
labeled_ratio = 0.25  # share of the training split that is kept as labeled data
test_ratio = 0.2  # share of all samples held out as the test split

In [None]:
num_samples = all_images.shape[0]
num_classes = all_labels.shape[1]  # number of label columns

# First split: hold out `test_ratio` of the samples as the test set,
# stratified over the label matrix.
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=42)
train_idx, test_idx = next(msss.split(all_images, all_labels))
train_images, test_images = all_images[train_idx], all_images[test_idx]
train_labels, test_labels = all_labels[train_idx], all_labels[test_idx]

# Second split, on the training set only: `labeled_ratio` stays labeled and
# the remaining 1 - labeled_ratio becomes the unlabeled pool.
msss_labeled = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1-labeled_ratio, random_state=42)
labeled_idx, unlabeled_idx = next(msss_labeled.split(train_images, train_labels))

labeled_images = train_images[labeled_idx]
labeled_images_labels = train_labels[labeled_idx]
# Only the images of the unlabeled pool are extracted; their labels are not read here.
unlabeled_images = train_images[unlabeled_idx]

In [None]:
# Report the number of samples in each split.
print(f"Training: {len(train_images)}, Testing: {len(test_images)}")
print(f"Labeled: {len(labeled_images)}, Unlabeled: {len(unlabeled_images)}")

In [None]:
def extract_features(images, feature_extractor, batch_size=128):
    """Run ``feature_extractor`` over ``images`` and stack the outputs row-wise.

    Args:
        images: array-like of model inputs, batched along the first axis.
        feature_extractor: ``torch.nn.Module``. It is switched to eval mode
            and moved to the module-level ``device`` (both affect the module
            the caller passed in).
        batch_size: number of images per forward pass.

    Returns:
        ``numpy`` array with one row per input image, in the original order.
        4-D model outputs are flattened to ``(batch, features)``.
    """
    feature_extractor.eval()  # disables dropout and batch-norm statistic updates
    feature_extractor = feature_extractor.to(device)

    # as_tensor reuses the memory of a float32 ndarray instead of copying it.
    image_tensors = torch.as_tensor(images, dtype=torch.float32)
    dataset = TensorDataset(image_tensors)
    dataloader = DataLoader(dataset, batch_size, shuffle=False)  # preserves the input order

    features = []
    with torch.no_grad():  # inference only, no gradient bookkeeping
        for batch in dataloader:
            batch_images = batch[0].to(device)
            batch_features = feature_extractor(batch_images)

            if batch_features.dim() == 4:
                # (batch, C, H, W) -> (batch, C*H*W). flatten copies when the
                # tensor is not contiguous, where .view would raise.
                batch_features = torch.flatten(batch_features, start_dim=1)

            features.append(batch_features.cpu().numpy())
    return np.vstack(features)

In [None]:
# ResNet-50 initialised with ImageNet weights; its final fully connected layer
# is replaced by a new linear layer with one output per label.
model_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model_resnet50.fc = nn.Linear(model_resnet50.fc.in_features, num_classes)

# Alternative EfficientNet-B2 backbone, currently disabled:
# model_efficientnet_b2 = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.IMAGENET1K_V1)
# model_efficientnet_b2.classifier[1] = nn.Linear(model_efficientnet_b2.classifier[1].in_features, num_classes)

In [None]:
# Select the backbone to fine-tune and move it to the compute device.
# model = model_efficientnet_b2.to(device)
model=model_resnet50.to(device)

# Labeled images and their label vectors as float tensors; only the labeled
# subset is used for supervised fine-tuning.
train_dataset = TensorDataset(
    torch.tensor(labeled_images).float(),
    torch.tensor(labeled_images_labels).float()
)

print(f"Image tensor shape: {train_dataset.tensors[0].shape}") # Shape: [N, 3, H, W]

In [None]:
# Fine-tune the classifier on the labeled subset, then derive a feature
# extractor from it and plot the per-epoch training loss.
losses = []  # mean batch loss of every epoch

optimizer = optim.Adam(model.parameters(), lr=1e-4) # Update model weights
criterion = nn.BCEWithLogitsLoss() # Sigmoid layer + binary cross-entropy loss: Measures the difference between predicted probabilities and true binary labels
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

epochs=20

model.train()
for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad() # Clears old gradients
        output_labels = model(inputs)  # raw logits; the sigmoid is applied inside the criterion
        loss = criterion(output_labels, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Average of the per-batch losses of this epoch.
    epoch_loss = running_loss / len(train_loader)
    losses.append(epoch_loss)
    print(f"Epoch {epoch + 1}/{epochs} loss: {epoch_loss:.4f}")

# Feature extractor = all child modules of the fine-tuned model except the
# last one (the replaced classification layer).
feature_extractor = nn.Sequential(*list(model.children())[:-1])
feature_extractor=feature_extractor.to(device)

# NOTE(review): the title says "Cross Entropy Loss" while the criterion above
# is BCEWithLogitsLoss - confirm the intended wording.
plt.figure(figsize=(10, 5))
plt.plot(range(1, epochs+1), losses, marker='o', linestyle='-', color='b')
plt.title('Training Loss Curve (Cross Entropy Loss)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.xticks(range(1, epochs+1))
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Turn every image split into a feature matrix with the fine-tuned backbone.
labeled_images_features = extract_features(labeled_images, feature_extractor)
unlabeled_images_features = extract_features(unlabeled_images, feature_extractor)
test_images_features = extract_features(test_images, feature_extractor)

# Shape check for the three feature matrices.
print(f"Labeled features shape: {labeled_images_features.shape}")
print(f"Unlabeled features shape: {unlabeled_images_features.shape}")
print(f"Test features shape: {test_images_features.shape}")

In [None]:
def calculate_variance(node_features, node_labels, w):
    """Score a tree node as a weighted mix of label impurity and feature spread.

    The label term is the mean over label columns of the Gini impurity
    ``2 * p * (1 - p)``, where ``p`` is the column mean over the rows that
    contain no NaN label. It is 0 when the node has no such row. The feature
    term is the mean of the per-column population variances of
    ``node_features``.

    Returns ``w * label_term + (1 - w) * feature_term``.
    """
    fully_labeled = ~np.isnan(node_labels).any(axis=1)

    if fully_labeled.sum() > 0:
        known = node_labels[fully_labeled]
        # Fraction of positives per label column.
        rates = [np.mean(known[:, col]) for col in range(known.shape[1])]
        impurity = np.mean([2 * rate * (1 - rate) for rate in rates])
    else:
        # Nothing labeled in this node: the label term contributes nothing.
        impurity = 0

    spread = np.var(node_features, axis=0, ddof=0).mean()
    return w * impurity + (1 - w) * spread

In [None]:
def build_pct(X, y, depth=0, max_depth=5, min_samples_split=2, w=1, parent_prototype=None):
    """Recursively grow one predictive clustering tree on partially labeled data.

    Args:
        X: 2-D feature matrix, one row per sample.
        y: 2-D label matrix aligned with ``X``; a row containing any NaN is
            treated as unlabeled.
        depth: depth of the node being built (0 for the root).
        max_depth: nodes at this depth become leaves.
        min_samples_split: a node with fewer samples becomes a leaf. The same
            value is also the minimum number of samples each child of a
            split must receive.
        w: weight forwarded to ``calculate_variance``.
        parent_prototype: prototype of the parent node, inherited when this
            node contains no labeled row.

    Returns:
        A nested dict. Leaves are ``{'type': 'leaf', 'prediction': ...}``.
        Internal nodes carry ``'type': 'node'``, ``'feature_idx'``,
        ``'threshold'``, ``'prototype'``, ``'left'`` and ``'right'``. Samples
        with ``feature <= threshold`` go left.

    Notes:
        Each node draws a random subset of ``sqrt(n_features)`` candidate
        features from NumPy's global RNG, so results depend on its state.
    """
    # Prototype: mean label vector of the labeled rows, else the parent's
    # prototype, else all zeros.
    labeled_mask = ~np.isnan(y).any(axis=1)
    if np.sum(labeled_mask) > 0:
        node_prototype = np.nanmean(y[labeled_mask], axis=0)
    elif parent_prototype is not None:
        node_prototype = parent_prototype
    else:
        node_prototype = np.zeros(y.shape[1])

    if depth >= max_depth or len(y) < min_samples_split:
        return {'type': 'leaf', 'prediction': node_prototype}

    best_split = None
    best_score = float('inf')
    best_split_idx = None

    n_features = X.shape[1]
    feature_subset_size = max(1, int(np.sqrt(n_features)))
    feature_indices = np.random.choice(n_features, feature_subset_size, replace=False)

    for feature_idx in feature_indices:
        feature_values = X[:, feature_idx]

        # Computed once per feature (np.unique sorts) and reused for both the
        # cardinality test and the low-cardinality candidate list.
        unique_values = np.unique(feature_values)
        if len(unique_values) > 10:
            # Many distinct values: only try the three quartiles.
            split_candidates = np.percentile(feature_values, [25, 50, 75])
        else:
            split_candidates = unique_values

        for split_val in split_candidates:
            left_mask = feature_values <= split_val
            right_mask = ~left_mask
            n_left = np.sum(left_mask)
            n_right = np.sum(right_mask)

            if n_left < min_samples_split or n_right < min_samples_split:
                continue

            # Size-weighted average of the children's scores; lower is better.
            left_var = calculate_variance(X[left_mask], y[left_mask], w) * n_left
            right_var = calculate_variance(X[right_mask], y[right_mask], w) * n_right
            total_var = (left_var + right_var) / len(y)

            if total_var < best_score:
                best_score = total_var
                best_split = split_val
                best_split_idx = feature_idx

    if best_split is None:
        # No candidate produced two sufficiently large children.
        return {'type': 'leaf', 'prediction': node_prototype}

    left_mask = X[:, best_split_idx] <= best_split
    right_mask = ~left_mask

    node = {
        'type': 'node',
        'feature_idx': best_split_idx,
        'threshold': best_split,
        'prototype': node_prototype,  # kept on internal nodes as well
        'left': build_pct(X[left_mask], y[left_mask], depth+1, max_depth, min_samples_split, w, node_prototype),
        'right': build_pct(X[right_mask], y[right_mask], depth+1, max_depth, min_samples_split, w, node_prototype)
    }
    return node

In [None]:
def build_pct_forest(X, y, n_estimators, max_depth, w, min_samples_split=2, bootstrap=True):
    """Build a list of predictive clustering trees.

    Args:
        X: 2-D feature matrix, one row per sample.
        y: 2-D label matrix aligned with ``X``.
        n_estimators: number of trees to build.
        max_depth: maximum depth passed to ``build_pct``.
        w: weight passed to ``build_pct``.
        min_samples_split: passed to ``build_pct`` (default 2, the value that
            used to be hard-coded here).
        bootstrap: when True (the default, as before) every tree is trained
            on ``n_samples`` rows drawn with replacement from NumPy's global
            RNG; when False every tree sees ``X`` and ``y`` unchanged.

    Returns:
        List with ``n_estimators`` trees as returned by ``build_pct``.
    """
    forest = []
    n_samples = X.shape[0]

    for _ in tqdm(range(n_estimators), desc="Building PCT Forest"):
        if bootstrap:
            # Random sampling with replacement; rows of X and y stay aligned.
            indices = np.random.choice(n_samples, n_samples, replace=True)
            X_bootstrap = X[indices]
            y_bootstrap = y[indices]
        else:
            X_bootstrap = X
            y_bootstrap = y

        tree = build_pct(X_bootstrap, y_bootstrap, max_depth=max_depth, min_samples_split=min_samples_split, w=w)
        forest.append(tree)

    return forest

In [None]:
def predict_pct(tree, X, num_classes):
    """Predict one label vector per row of ``X`` by walking a single tree.

    A falsy ``tree`` yields all-zero predictions. At a leaf every row
    receives the leaf's ``'prediction'``. At an internal node rows with
    ``X[:, feature_idx] <= threshold`` are sent left and the others right.
    Rows whose prediction contains NaN are overwritten with the node's
    ``'prototype'`` when the node has one.

    Returns an array of shape ``(X.shape[0], num_classes)``.
    """
    n_rows = X.shape[0]
    if not tree:
        return np.zeros((n_rows, num_classes))

    has_prototype = 'prototype' in tree

    if tree['type'] == 'leaf':
        out = np.tile(tree['prediction'], (n_rows, 1))
        rows_with_nan = np.isnan(out).any(axis=1)
        if has_prototype and rows_with_nan.any():
            out[rows_with_nan] = tree['prototype']
        return out

    out = np.zeros((n_rows, num_classes))
    goes_left = X[:, tree['feature_idx']] <= tree['threshold']

    # Left subtree first, then right; empty branches are not visited.
    for branch, selector in (('left', goes_left), ('right', ~goes_left)):
        if selector.any():
            out[selector] = predict_pct(tree[branch], X[selector], num_classes)

    rows_with_nan = np.isnan(out).any(axis=1)
    if has_prototype and rows_with_nan.any():
        out[rows_with_nan] = np.tile(tree['prototype'], (np.sum(rows_with_nan), 1))

    return out

In [None]:
def post_process_forest(forest, X_labeled, y_labeled):
    """Replace unusable leaf predictions with the labels of a nearby labeled sample.

    A leaf is repaired when its prediction contains NaN or sums to zero. The
    replacement is the label row of the labeled sample closest to the mean
    feature vector of the labeled samples associated with that leaf.

    Those samples are found by routing ``X_labeled`` down the tree with each
    node's ``'feature_idx'`` / ``'threshold'``. When a branch receives no
    labeled sample, the samples of its parent are used instead. A node that
    carries a non-empty ``'sample_indices'`` entry uses that instead.
    Previously the query was always the zero vector, because trees in this
    notebook never store ``'sample_indices'``, so every repaired leaf got the
    same label row.

    Args:
        forest: list of tree dicts; the trees are modified in place.
        X_labeled: 2-D feature matrix of the labeled samples.
        y_labeled: label matrix aligned with ``X_labeled``.

    Returns:
        The same ``forest`` list.
    """
    if len(X_labeled) == 0:
        return forest

    # Nearest-neighbour index over the labeled feature vectors.
    nn_model = NearestNeighbors(n_neighbors=1)
    nn_model.fit(X_labeled)

    def fix_nan_nodes(node, reaching_idx):
        # `reaching_idx`: indices into X_labeled associated with this node.
        if node is None or not isinstance(node, dict):
            return

        if node['type'] == 'leaf' and (np.isnan(node['prediction']).any() or node['prediction'].sum() == 0):
            if 'sample_indices' in node and len(node['sample_indices']) > 0:
                anchor_idx = node['sample_indices']
            else:
                anchor_idx = reaching_idx

            query = np.zeros((1, X_labeled.shape[1]))
            if len(anchor_idx) > 0:
                query[0] = np.mean(X_labeled[anchor_idx], axis=0)

            distances, indices = nn_model.kneighbors(query)
            node['prediction'] = y_labeled[indices[0][0]]

        # Route the labeled samples to the children. An empty side keeps the
        # parent's samples so that deeper leaves still have an anchor.
        left_idx = right_idx = reaching_idx
        if 'feature_idx' in node and 'threshold' in node:
            goes_left = X_labeled[reaching_idx, node['feature_idx']] <= node['threshold']
            if goes_left.any():
                left_idx = reaching_idx[goes_left]
            if (~goes_left).any():
                right_idx = reaching_idx[~goes_left]

        if 'left' in node:
            fix_nan_nodes(node['left'], left_idx)
        if 'right' in node:
            fix_nan_nodes(node['right'], right_idx)

    all_idx = np.arange(len(X_labeled))
    for tree in forest:
        fix_nan_nodes(tree, all_idx)

    return forest

In [None]:
def predict_pct_forest(forest, X, num_classes):
    """Average the per-tree predictions of ``forest`` for every row of ``X``.

    Each tree is evaluated with ``predict_pct``. A tree output that contains
    NaN is passed through ``np.nan_to_num(..., nan=0.0)`` before averaging.
    Returns the element-wise mean over the trees.
    """
    per_tree = []
    for member in forest:
        member_out = predict_pct(member, X, num_classes)
        # Only outputs that actually contain NaN are sanitised.
        cleaned = np.nan_to_num(member_out, nan=0.0) if np.isnan(member_out).any() else member_out
        per_tree.append(cleaned)

    return np.mean(np.array(per_tree), axis=0)

In [None]:
# Semi-supervised training set: labeled rows first, followed by the unlabeled
# rows whose label vectors are filled with NaN.
X_combined = np.vstack((labeled_images_features, unlabeled_images_features))
y_combined = np.vstack((labeled_images_labels, np.full((unlabeled_images_features.shape[0], num_classes), np.nan)))
# Shorter aliases for the labeled part, used by the weight search below.
X_labeled = labeled_images_features
y_labeled = labeled_images_labels

# Re-derived from the combined label matrix (num_classes was already needed above).
num_classes = y_combined.shape[1]

In [None]:
# Search for the weight w that balances label impurity against feature
# variance, scored by micro-averaged AUPRC on held-out labeled folds.
best_score = -np.inf
best_w = 1.0

w_values = []
mean_scores = []

# Draw the validation folds once, with a fixed seed, so that every candidate
# w is scored on exactly the same partitions. Previously an unseeded splitter
# was re-created per w, so fold noise was mixed into the comparison.
kf = MultilabelStratifiedShuffleSplit(n_splits=3, test_size=0.3, random_state=42)
cv_folds = list(kf.split(X_labeled, y_labeled))

# linspace yields exactly 11 evenly spaced weights in [0, 1]; arange with a
# float step has a length that depends on rounding.
for w in tqdm(np.linspace(0.0, 1.0, 11), desc="Optimizing w"):
    fold_scores = []

    # Distinct loop names keep the earlier train/test split indices intact.
    for fold_no, (fold_train_idx, fold_val_idx) in enumerate(cv_folds):
        X_train, X_val = X_labeled[fold_train_idx], X_labeled[fold_val_idx]
        y_train, y_val = y_labeled[fold_train_idx], y_labeled[fold_val_idx]

        # Add the unlabeled data with NaN placeholder labels
        X_train_full = np.vstack((X_train, unlabeled_images_features))
        y_train_full = np.vstack((y_train, np.full((unlabeled_images_features.shape[0], num_classes), np.nan)))

        # build_pct draws feature subsets from NumPy's global RNG; re-seeding
        # per fold gives every w the same draws, so scores differ only by w.
        # This also leaves the global RNG in a deterministic state afterwards.
        np.random.seed(42 + fold_no)
        tree = build_pct(X_train_full, y_train_full, max_depth=5, w=w)

        # Validate on the labeled validation fold
        y_pred = predict_pct(tree, X_val, num_classes)
        score = average_precision_score(y_val, y_pred, average='micro')
        fold_scores.append(score)

    mean_score = np.mean(fold_scores)
    w_values.append(w)
    mean_scores.append(mean_score)

    if mean_score > best_score:
        best_score = mean_score
        best_w = w

print(f"Optimal w: {best_w:.1f}")

plt.figure(figsize=(8, 5))
plt.plot(w_values, mean_scores, marker='o', linestyle='-', color='b', label='Mean AUPRC Score')
plt.axvline(best_w, color='r', linestyle='--', label=f'Best w = {best_w:.1f}')
plt.xlabel('w')
plt.ylabel('Mean Score (AUPRC)')
plt.title('Optimization of w')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# NOTE(review): hard-coded weight; this assignment replaces whatever value
# best_w held before this cell ran - confirm that is intended.
best_w=0.4

In [None]:
# Train on labeled + unlabeled features (unlabeled rows carry NaN labels).
# NOTE(review): n_estimators=1 means the "forest" consists of a single tree.
ssl_forest = build_pct_forest(
    X_combined, 
    y_combined, 
    n_estimators=1, 
    max_depth=10, 
    w=best_w
)
# Patch leaf predictions using the labeled samples.
ssl_forest = post_process_forest(ssl_forest, labeled_images_features, labeled_images_labels)

In [None]:
# Score the forest on the held-out test features with micro- and
# macro-averaged average precision.
y_pred_ssl = predict_pct_forest(ssl_forest, test_images_features, num_classes)
micro_auprc = average_precision_score(test_labels, y_pred_ssl, average='micro')
macro_auprc = average_precision_score(test_labels, y_pred_ssl, average='macro')

print("\nFinal Results:")
print(f"Micro-AUPRC: {micro_auprc:.4f}, Macro-AUPRC: {macro_auprc:.4f}")

In [None]:
# Hard-coded AUPRC values per method for the figure titled
# "DFC-15 (MLC), EfficientNetB2"; one value per entry of N_values
# (x axis "N (%)", presumably the labeled percentage).
N_values = [1, 5, 10, 25]
auprc_values = {
    "SSL-RForest": [84.69, 93.72, 95.58, 97.03],
    "SL-RForest": [78.23, 91.07, 94.98, 96.5],
    "SSL-PCT": [52.35, 66.69, 73.73, 82.27],
    "SL-PCT": [46.80, 60.66, 68.24, 80.56],
}

# (color, marker, line style) per method.
styles = {
    "SSL-RForest": ("black", "d", "-."),
    "SL-RForest": ("blue", "*", "-."),
    "SSL-PCT": ("green", "o", "-."),
    "SL-PCT": ("red", "s", "-."),
}

fig, ax = plt.subplots(figsize=(7, 3.5))
for key, means in auprc_values.items():
    color, marker, linestyle = styles[key]
    ax.errorbar(
        N_values, means,
        fmt=marker, linestyle=linestyle, color=color, capsize=3, label=key,
    )

ax.set_title("DFC-15 (MLC), EfficientNetB2")
ax.set_xlabel("N (%)", fontsize=12)
ax.set_ylabel("AUPRC", fontsize=12)
ax.set_xticks(N_values)
ax.set_yticks([50, 60, 70, 80, 90, 100])
ax.legend()
ax.grid(True, linestyle="--", alpha=0.5)

# Position the y-axis label in axes coordinates.
ax.yaxis.set_label_coords(-0.06, 0.5)
plt.show()

In [None]:
# Hard-coded AUPRC values per method for the figure titled
# "DFC-15 (MLC), ResNet50"; one value per entry of N_values
# (x axis "N (%)", presumably the labeled percentage).
N_values = [1, 5, 10, 25]
auprc_values = {
    "SSL-RForest": [80.56, 93.49, 96.17, 97.69],
    "SL-RForest": [76.32, 91.64, 94.19, 96.13],
    "SSL-PCT": [61.37, 74.61, 79.02, 83.80],
    "SL-PCT": [55.73, 70.46, 71.96, 78.21],
}

# (color, marker, line style) per method.
styles = {
    "SSL-RForest": ("black", "d", "-."),
    "SL-RForest": ("blue", "*", "-."),
    "SSL-PCT": ("green", "o", "-."),
    "SL-PCT": ("red", "s", "-."),
}

fig, ax = plt.subplots(figsize=(7, 5))
for key, means in auprc_values.items():
    color, marker, linestyle = styles[key]
    ax.errorbar(
        N_values, means,
        fmt=marker, linestyle=linestyle, color=color, capsize=3, label=key,
    )

ax.set_title("DFC-15 (MLC), ResNet50")
ax.set_xlabel("N (%)", fontsize=12)
ax.set_ylabel("AUPRC", fontsize=12)
ax.set_xticks(N_values)
ax.set_yticks([50, 60, 70, 80, 90, 100])
ax.legend()
ax.grid(True, linestyle="--", alpha=0.5)

# Position the y-axis label in axes coordinates.
ax.yaxis.set_label_coords(-0.10, 0.5)
plt.show()


In [None]:
# Hard-coded AUPRC values per method for the figure titled
# "OPTIMAL-31 (MCC), ResNet50"; one value per entry of N_values
# (x axis "N (%)", presumably the labeled percentage).
N_values = [1, 5, 10, 25]
auprc_values = {
    "SSL-RForest": [31.81, 80.11, 91.98, 92.32],
    "SL-RForest": [28.89, 66.54, 89.23, 91.12],
    "SSL-PCT": [11.38, 42.78, 51.66, 58.22],
    "SL-PCT": [5.63, 20.37, 31.56, 55.43],
}

# (color, marker, line style) per method.
styles = {
    "SSL-RForest": ("black", "d", "-."),
    "SL-RForest": ("blue", "*", "-."),
    "SSL-PCT": ("green", "o", "-."),
    "SL-PCT": ("red", "s", "-."),
}

fig, ax = plt.subplots(figsize=(7, 5))
for key, means in auprc_values.items():
    color, marker, linestyle = styles[key]
    ax.errorbar(
        N_values, means,
        fmt=marker, linestyle=linestyle, color=color, capsize=3, label=key,
    )

ax.set_title("OPTIMAL-31 (MCC), ResNet50")
ax.set_xlabel("N (%)", fontsize=12)
ax.set_ylabel("AUPRC", fontsize=12)
ax.set_xticks(N_values)
ax.set_yticks([0, 20, 40, 60, 80, 100])
ax.legend()
ax.grid(True, linestyle="--", alpha=0.5)

# Position the y-axis label in axes coordinates.
ax.yaxis.set_label_coords(-0.10, 0.5)
plt.show()


In [None]:
# Hard-coded AUPRC values per method for the figure titled
# "OPTIMAL-31 (MCC), EfficientNetB2"; one value per entry of N_values
# (x axis "N (%)", presumably the labeled percentage).
N_values = [1, 5, 10, 25]
auprc_values = {
    "SSL-RForest": [34.49, 84.26, 92.38, 93.57],
    "SL-RForest": [31.47, 74.19, 90.11, 91.96],
    "SSL-PCT": [13.78, 47.87, 63.16, 64.37],
    "SL-PCT": [5.77, 17.89, 38.50, 57.44],
}

# (color, marker, line style) per method.
styles = {
    "SSL-RForest": ("black", "d", "-."),
    "SL-RForest": ("blue", "*", "-."),
    "SSL-PCT": ("green", "o", "-."),
    "SL-PCT": ("red", "s", "-."),
}

fig, ax = plt.subplots(figsize=(7, 3.5))
for key, means in auprc_values.items():
    color, marker, linestyle = styles[key]
    ax.errorbar(
        N_values, means,
        fmt=marker, linestyle=linestyle, color=color, capsize=3, label=key,
    )

ax.set_title("OPTIMAL-31 (MCC), EfficientNetB2")
ax.set_xlabel("N (%)", fontsize=12)
ax.set_ylabel("AUPRC", fontsize=12)
ax.set_xticks(N_values)
ax.set_yticks([0, 20, 40, 60, 80, 100])
ax.legend()
ax.grid(True, linestyle="--", alpha=0.5)

# Position the y-axis label in axes coordinates.
ax.yaxis.set_label_coords(-0.06, 0.5)
plt.show()

In [None]:
# Plot a hard-coded list of mean scores against the candidate weights and
# mark the weight with the highest score. Note that this reassigns
# w_values, mean_scores, best_w and best_score.
w_values = np.arange(0, 1.1, 0.1)
mean_scores = [0.35, 0.48, 0.69, 0.71, 0.82, 0.81, 0.77, 0.79, 0.69, 0.63, 0.59]

best_idx = np.argmax(mean_scores)  # position of the highest score
best_w, best_score = w_values[best_idx], mean_scores[best_idx]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(w_values, mean_scores, marker='o', label='Mean Score')
ax.axvline(best_w, color='r', linestyle='--', label=f'Best w = {best_w:.1f}')
ax.set_xlabel('w')
ax.set_ylabel('Mean Score')
ax.set_title('w vs Mean Score')
ax.grid(True)
ax.legend()
plt.show()