In [None]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
from torchvision.models import vit_l_16, ViT_L_16_Weights
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from sklearn.metrics import normalized_mutual_info_score
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from finch import FINCH
from tqdm import tqdm

# ======================== DEVICE ========================
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
# Every model and batch below is moved to this device explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ======================== VIT FEATURE EXTRACTOR ========================
class ViTFeatureExtractor(nn.Module):
    """Frozen ViT-L/16 backbone used purely as a feature extractor.

    The backbone is created with the IMAGENET1K_V1 weights, all of its
    parameters have gradients disabled, and its ``heads`` module is swapped
    for ``nn.Identity`` so the forward pass returns whatever the backbone
    feeds into its classification head.
    NOTE(review): presumably a 1024-dimensional embedding per image (the
    downstream classifier uses input_dim=1024) -- confirm against torchvision.
    """

    def __init__(self):
        super().__init__()
        pretrained = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1)
        # Freeze every parameter first; the backbone is never trained here.
        pretrained.requires_grad_(False)
        # Drop the classification head so forward() yields features, not logits.
        pretrained.heads = nn.Identity()
        self.backbone = pretrained

    def forward(self, x):
        """Return the backbone output for the image batch ``x``."""
        features = self.backbone(x)
        return features

# Build the ViT feature extractor once (this loads the IMAGENET1K_V1 weights)
# and move it to the selected device.
model = ViTFeatureExtractor().to(device)

# ======================== EXTRACT FEATURES & SAVE ========================
def extract_features_to_file(model, data_loader, save_path):
    """Run ``model`` over ``data_loader`` and save (feature, label) pairs.

    The model is put in eval mode and executed without gradient tracking on
    the module-level ``device``. Every sample contributes one
    ``[feature, label]`` entry; the full list is written to ``save_path``
    with ``np.save`` as an object array, so it must later be loaded with
    ``allow_pickle=True``.

    Args:
        model: Callable module mapping an image batch to a feature batch.
        data_loader: Iterable yielding ``(images, labels)`` tensor batches.
        save_path: Destination path handed to ``np.save``.
    """
    model.eval()
    pairs = []
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(data_loader):
            batch_feats = model(batch_images.to(device)).cpu().numpy()
            label_values = batch_labels.numpy()
            # One [feature, label] entry per sample, in loader order.
            pairs.extend([feat, label] for feat, label in zip(batch_feats, label_values))
    np.save(save_path, np.array(pairs, dtype=object))

# ======================== CLUSTERING METRIC ========================
def clustering_accuracy(y_true, y_pred):
    """Return clustering accuracy under the best one-to-one cluster/class mapping.

    A contingency matrix between predicted clusters and true classes is built
    and the assignment maximising the number of matched samples is found with
    the Hungarian algorithm (``linear_sum_assignment``).

    Labels are first remapped to dense ids, so arbitrary label values are
    accepted (non-contiguous ids, negative ids such as a ``-1`` noise label,
    plain Python lists). Using raw values as indices would let negative labels
    wrap around and silently corrupt the matrix.

    Args:
        y_true: 1-D array-like of ground-truth labels.
        y_pred: 1-D array-like of cluster assignments, same length as y_true.

    Returns:
        Fraction of samples that agree with their mapped class, in [0, 1].

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {y_true.size} and {y_pred.size}"
        )
    if y_pred.size == 0:
        raise ValueError("clustering_accuracy requires at least one sample")

    # Dense re-indexing keeps the matrix as small as possible and makes the
    # metric independent of the actual label values.
    true_values, true_ids = np.unique(y_true, return_inverse=True)
    pred_values, pred_ids = np.unique(y_pred, return_inverse=True)

    # Square matrix so that every cluster can be assigned to some class slot
    # even when the number of clusters and classes differ.
    D = max(true_values.size, pred_values.size)
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (pred_ids, true_ids), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() / y_pred.size

# ======================== FEATURE EXTRACTION FROM CLASSIFIER ========================
def extract_features(loader, model):
    """Collect ``model.getHidden`` activations and labels for a whole loader.

    The model is switched to eval mode and run without gradient tracking on
    the module-level ``device``.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        model: Module exposing a ``getHidden(x)`` method.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays, each the concatenation
        of the per-batch results in loader order.
    """
    model.eval()
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for inputs, targets in loader:
            hidden = model.getHidden(inputs.to(device))
            feature_chunks.append(hidden.cpu().numpy())
            label_chunks.append(targets.numpy())
    features = np.concatenate(feature_chunks)
    labels = np.concatenate(label_chunks)
    return features, labels

# ======================== FEATURE DATASET ========================
class FeatureDataset(Dataset):
    """Dataset over pre-computed ``(feature, label)`` pairs.

    ``data`` is a sequence whose items expose the feature at index 0 and the
    label at index 1 (for example the object array written by
    ``extract_features_to_file``).
    """

    def __init__(self, data):
        # Keep features and labels in two parallel lists, in input order.
        self.features = [row[0] for row in data]
        self.labels = [row[1] for row in data]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a (float32 feature, int64 label) tensor pair."""
        feature = torch.tensor(self.features[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return feature, label

# ======================== SIMPLE CLASSIFIER ========================
class SimpleClassifier(nn.Module):
    """Two-layer MLP head: Linear -> GELU -> Linear.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=1024, hidden_dim=4096, num_classes=12):
        super().__init__()
        projection = nn.Linear(input_dim, hidden_dim)
        activation = nn.GELU()
        output_layer = nn.Linear(hidden_dim, num_classes)
        self.model = nn.Sequential(projection, activation, output_layer)

    def forward(self, x):
        """Return class logits for the feature batch ``x``."""
        logits = self.model(x)
        return logits

    def getHidden(self, x):
        """Return the output of the first linear layer (before the GELU)."""
        return self.model[0](x)

# ======================== LOOP TRIAL ========================
# Experiment driver. For every trial split it:
#   1. loads the pre-extracted ViT features of the labeled and unlabeled sets,
#   2. trains a SimpleClassifier on the labeled features,
#   3. after every epoch clusters the unlabeled hidden features with FINCH and
#      K-Means and records ACC / NMI,
#   4. writes per-epoch, per-image and per-trial CSV reports.
results_final = []

# Settings that never change between trials are defined once, outside the loop.
NUM_TRIALS = 5
NUM_EPOCHS = 50
batch_size = 16
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]
# FINCH returns one column of labels per partition of its hierarchy; the
# partition at this index is the one evaluated. FINCH may stop before reaching
# it, so the index is clamped to the last available partition (see below).
FINCH_PARTITION_LEVEL = 2


def _finch_labels(partitions, level=FINCH_PARTITION_LEVEL):
    """Return the labels of FINCH partition ``level``, or of the last one if fewer exist."""
    return partitions[:, min(level, partitions.shape[1] - 1)]


# Transforms (only needed by the raw-image loaders used for feature extraction).
labeled_transform = transforms.Compose([
    transforms.RandAugment(),
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])
unlabeled_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

for trial in range(NUM_TRIALS):
    # Dataset folders of this trial split.
    trial_dir = fr"C:\Users\HP\novelty\split_datasets\trial_{trial}\42labeled_10unlabeled"
    labeled_folder = fr"{trial_dir}\labeled"
    unlabeled_folder = fr"{trial_dir}\unlabeled"

    labeled_dataset = ImageFolder(root=labeled_folder, transform=labeled_transform)
    unlabeled_dataset = ImageFolder(root=unlabeled_folder, transform=unlabeled_transform)

    labeled_loader_raw = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)
    unlabeled_loader_raw = DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=False)

    # Feature files produced by extract_features_to_file.
    labeled_feat_path = fr"{trial_dir}\labeled_features_vit_l_16_42labeled_10unlabeled.npy"
    unlabeled_feat_path = fr"{trial_dir}\unlabeled_features_vit_l_16_42labeled_10unlabeled.npy"

    # One-off extraction step; uncomment to (re)generate the feature files.
    # extract_features_to_file(model, labeled_loader_raw, labeled_feat_path)
    # extract_features_to_file(model, unlabeled_loader_raw, unlabeled_feat_path)

    # NOTE: allow_pickle=True executes pickled objects on load; only use this
    # with feature files generated locally by the extraction step above.
    labeled_data = np.load(labeled_feat_path, allow_pickle=True)
    unlabeled_data = np.load(unlabeled_feat_path, allow_pickle=True)

    # The per-image report below pairs cluster ids with
    # unlabeled_dataset.samples by position, so the saved features must have
    # exactly one row per image. Fail before training rather than after it.
    if len(unlabeled_data) != len(unlabeled_dataset):
        raise ValueError(
            f"Trial {trial}: {len(unlabeled_data)} saved unlabeled features but "
            f"{len(unlabeled_dataset)} images in {unlabeled_folder}; "
            f"re-run the feature extraction."
        )

    labeled_loader = DataLoader(FeatureDataset(labeled_data), batch_size=batch_size, shuffle=True)
    # shuffle=False keeps the extracted features in file order.
    unlabeled_loader = DataLoader(FeatureDataset(unlabeled_data), batch_size=batch_size, shuffle=False)

    # Classifier setup
    num_classes = len({row[1] for row in labeled_data})
    classifier = SimpleClassifier(input_dim=1024, hidden_dim=4096, num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

    results = []
    finch_partitions = []
    all_image_info = []

    for epoch in range(NUM_EPOCHS):
        classifier.train()
        running_loss = 0.0

        for x, y in labeled_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = classifier(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_loss = running_loss / len(labeled_loader)
        print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f}")

        # extract_features switches the classifier to eval mode; it is put
        # back into train mode at the top of the next epoch.
        feats, true_labels = extract_features(unlabeled_loader, classifier)

        # FINCH clustering
        c, num_clust, _ = FINCH(feats, use_ann_above_samples=1000, verbose=True)
        partition_info = {"Epoch": epoch + 1}
        for i, clust in enumerate(num_clust):
            partition_info[f"Partition {i}"] = clust
        finch_partitions.append(partition_info)

        finch_clusters = _finch_labels(c)
        finch_nmi = normalized_mutual_info_score(true_labels, finch_clusters)
        finch_acc = clustering_accuracy(true_labels, finch_clusters)

        # KMeans clustering, with the number of clusters taken from the true labels.
        num_clusters = len(np.unique(true_labels))
        kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(feats)
        kmeans_clusters = kmeans.labels_
        kmeans_nmi = normalized_mutual_info_score(true_labels, kmeans_clusters)
        kmeans_acc = clustering_accuracy(true_labels, kmeans_clusters)

        print(f"FINCH ACC: {finch_acc:.4f} | NMI: {finch_nmi:.4f} | Clusters: {len(np.unique(finch_clusters))}")
        print(f"KMeans ACC: {kmeans_acc:.4f} | NMI: {kmeans_nmi:.4f} | Clusters: {len(np.unique(kmeans_clusters))}\n")

        results.append({
            "Epoch": epoch + 1,
            "Loss": train_loss,
            "FINCH_ACC": finch_acc,
            "FINCH_NMI": finch_nmi,
            "FINCH_Clusters": len(np.unique(finch_clusters)),
            "KMeans_ACC": kmeans_acc,
            "KMeans_NMI": kmeans_nmi,
            "KMeans_Clusters": len(np.unique(kmeans_clusters))
        })

    # Final evaluation with the fully trained classifier.
    unlabeled_features, unlabeled_true_labels = extract_features(unlabeled_loader, classifier)

    c, num_clust, req_c = FINCH(unlabeled_features, use_ann_above_samples=1000, verbose=True)
    finch_clusters = _finch_labels(c)

    finch_nmi = normalized_mutual_info_score(unlabeled_true_labels, finch_clusters)
    finch_acc = clustering_accuracy(unlabeled_true_labels, finch_clusters)

    print(f"\nFINCH Clustering Performance:")
    print(f"Clustering Accuracy (ACC): {finch_acc:.4f}")
    print(f"Normalized Mutual Information (NMI): {finch_nmi:.4f}")
    print(f"Number of clusters found: {len(np.unique(finch_clusters))}")

    num_clusters = len(np.unique(unlabeled_true_labels))
    kmeans = KMeans(n_clusters=num_clusters, random_state=42)
    kmeans_clusters = kmeans.fit_predict(unlabeled_features)

    # ----- Evaluation -----
    kmeans_nmi = normalized_mutual_info_score(unlabeled_true_labels, kmeans_clusters)
    kmeans_acc = clustering_accuracy(unlabeled_true_labels, kmeans_clusters)

    print(f"\nK-Means Clustering Performance: {num_clusters} Clusters")
    print(f"Clustering Accuracy (ACC): {kmeans_acc:.4f}")
    print(f"Normalized Mutual Information (NMI): {kmeans_nmi:.4f}")
    print(f"Number of clusters found: {len(np.unique(kmeans_clusters))}")

    # Store this trial's final metrics.
    results_final.append({
        'trial': trial,
        'FINCH_ACC': finch_acc,
        'FINCH_NMI': finch_nmi,
        'FINCH_Num_Clusters': len(np.unique(finch_clusters)),
        'KMeans_ACC': kmeans_acc,
        'KMeans_NMI': kmeans_nmi,
        'KMeans_Num_Clusters': len(np.unique(kmeans_clusters)),
    })

    # Per-image results: sample i of the ImageFolder corresponds to feature
    # row i (lengths were validated right after loading the feature file).
    for (path, true_label), finch_id, kmeans_id in zip(
        unlabeled_dataset.samples, finch_clusters, kmeans_clusters
    ):
        all_image_info.append({
            'trial': trial,
            'image_path': path,
            'true_label': true_label,
            'finch_cluster': int(finch_id),
            'kmeans_cluster': int(kmeans_id)
        })

    # Per-trial CSV; all_image_info is reset every trial, so it only holds
    # this trial's rows.
    df_trial_detail = pd.DataFrame(all_image_info)
    df_trial_detail.to_csv(f'vit_l_16_finetuned_42labeled_10unlabeled_trial_{trial}_image_clustering.csv', index=False)

    # Save the per-epoch metrics and FINCH partition sizes.
    result_csv = fr"clustering_metrics_per_epoch_vit_l_16_finetuned_42labeled_10unlabeled_trial_{trial}.csv"
    partition_csv = fr"finch_partitions_per_epoch_vit_l_16_finetuned_42labeled_10unlabeled_trial_{trial}.csv"

    pd.DataFrame(results).to_csv(result_csv, index=False)
    pd.DataFrame(finch_partitions).to_csv(partition_csv, index=False)

    print(f"\n📄 Hasil clustering disimpan di: {result_csv}")
    print(f"📄 Partisi FINCH disimpan di: {partition_csv}")

# Aggregate across trials and append an "Average" row to the summary table.
df_results = pd.DataFrame(results_final)
mean_values = df_results.select_dtypes(include=np.number).mean()

metric_columns = [
    'FINCH_ACC',
    'FINCH_NMI',
    'FINCH_Num_Clusters',
    'KMeans_ACC',
    'KMeans_NMI',
    'KMeans_Num_Clusters',
]
mean_row = pd.DataFrame({
    'trial': ['Average'],
    **{column: [mean_values[column]] for column in metric_columns},
})

df_results = pd.concat([df_results, mean_row], ignore_index=True)

df_results.to_csv('clustering_results_vit_l_16_finetuned_42labeled_10unlabeled_FINCH_KMEANS.csv', index=False)
print("Results saved to clustering_results_vit_l_16_finetuned_42labeled_10unlabeled_FINCH_KMEANS.csv")


Epoch 1 | Train Loss: 0.5242
Partition 0: 202 clusters
Partition 1: 45 clusters
Partition 2: 14 clusters
Partition 3: 6 clusters
Partition 4: 2 clusters
FINCH ACC: 0.8040 | NMI: 0.8790 | Clusters: 14
KMeans ACC: 0.9420 | NMI: 0.9079 | Clusters: 10

Epoch 2 | Train Loss: 0.1482
Partition 0: 208 clusters
Partition 1: 37 clusters
Partition 2: 12 clusters
Partition 3: 5 clusters
FINCH ACC: 0.8450 | NMI: 0.8830 | Clusters: 12
KMeans ACC: 0.8940 | NMI: 0.8735 | Clusters: 10

Epoch 3 | Train Loss: 0.0752
Partition 0: 213 clusters
Partition 1: 51 clusters
Partition 2: 14 clusters
Partition 3: 3 clusters
FINCH ACC: 0.7280 | NMI: 0.8573 | Clusters: 14
KMeans ACC: 0.7630 | NMI: 0.8458 | Clusters: 10

Epoch 4 | Train Loss: 0.1216
Partition 0: 212 clusters
Partition 1: 43 clusters
Partition 2: 14 clusters
Partition 3: 7 clusters
Partition 4: 6 clusters
FINCH ACC: 0.7970 | NMI: 0.8777 | Clusters: 14
KMeans ACC: 0.7470 | NMI: 0.8023 | Clusters: 10

Epoch 5 | Train Loss: 0.1162
Partition 0: 212 cluste