In [None]:
!pip install datasets
!pip install accelerate -U

In [None]:
import timm
from PIL import Image
from datasets import load_dataset
import numpy as np
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score, roc_curve, auc, roc_auc_score
import seaborn as sns

# Fetch the dataset by its hub id and keep its 'train' split as the working set.
dataset = load_dataset('GaryHuang/NTU_geolocation')
full_dataset = dataset['train']
print(f"Full dataset size: {len(full_dataset)}")

# Encode the raw 'label' column as integer ids. The fitted encoder is kept at
# module level so ids can be mapped back via `label_encoder.classes_`.
label_encoder = LabelEncoder()
numerical_labels = label_encoder.fit_transform(full_dataset['label'])

# Swap the original 'label' column for the integer-encoded one.
full_dataset = (
    full_dataset
    .remove_columns('label')
    .add_column('label', numerical_labels)
)

# --- Sub-sampling and train/test split ---
# Fraction of the full dataset to keep. 1.0 keeps every example; lower it
# (e.g. 0.1 for 10%) for quicker experiments.
SAMPLE_FRACTION = 1.0
# Fraction of the sampled data used for training; the remainder is the test split.
TRAIN_FRACTION = 0.85
# Seed for NumPy's global RNG so both permutations below are reproducible.
RANDOM_SEED = 42

sample_size = int(SAMPLE_FRACTION * len(full_dataset))

# Shuffle the dataset and keep the first `sample_size` shuffled examples
np.random.seed(RANDOM_SEED)
sample_indices = np.random.permutation(len(full_dataset))[:sample_size]
sampled_dataset = full_dataset.select(sample_indices)

# Verify the sampled dataset size
print(f"Sampled dataset size: {len(sampled_dataset)}")

# Split sizes; test_size is computed by subtraction so every sample is used
train_size = int(TRAIN_FRACTION * len(sampled_dataset))
test_size = len(sampled_dataset) - train_size

# Fail early with a clear message instead of crashing later on an empty split
if train_size == 0 or test_size == 0:
    raise ValueError(
        f"Cannot split {len(sampled_dataset)} samples into non-empty train/test "
        f"subsets with TRAIN_FRACTION={TRAIN_FRACTION}"
    )

# Second shuffle (continuing the seeded RNG stream) decides the split membership
indices = np.random.permutation(len(sampled_dataset))

# Verify indices are within bounds
print(f"Max index in sampled dataset: {indices.max()}")
print(f"Sampled dataset length: {len(sampled_dataset)}")

# Split indices for each dataset
train_indices = indices[:train_size]
test_indices = indices[train_size:]

# Create subsets using the indices
train_subset = sampled_dataset.select(train_indices)
test_subset = sampled_dataset.select(test_indices)

# Print the sizes to confirm
print(f"Train subset size: {len(train_subset)}")
print(f"Test subset size: {len(test_subset)}")

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Pretrained ViT-L/16 backbone. NOTE(review): per the timm docs, num_classes=0
# builds the model without a classifier head so it returns pooled features.
model = timm.create_model(
    'vit_large_patch16_224.augreg_in21k_ft_in1k',
    pretrained=True,
    num_classes=0,
)
model.eval()  # eval() returns the module itself, so rebinding is not required
model = model.to(device)

# Build the preprocessing pipelines from the model's own data configuration:
# `train_transform` is the training-mode variant, `transform` the inference one.
data_config = timm.data.resolve_model_data_config(model)
train_transform = timm.data.create_transform(is_training=True, **data_config)
transform = timm.data.create_transform(is_training=False, **data_config)

class CustomDataset(Dataset):
    """Torch dataset over an indexable collection of image/label records.

    Each item of ``dataset`` must expose an ``'image'`` entry (a PIL image or
    anything ``np.asarray`` can turn into an image array) and a ``'label'``
    entry. Images are always converted to 3-channel RGB so that grayscale,
    palette or RGBA inputs do not reach the model with the wrong channel count.

    Args:
        dataset: Indexable, sized collection of records.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _to_rgb_image(raw_image):
        """Return ``raw_image`` as an RGB PIL image."""
        if isinstance(raw_image, Image.Image):
            # Convert through PIL directly so palette/alpha data is handled properly.
            return raw_image.convert('RGB')
        return Image.fromarray(np.asarray(raw_image)).convert('RGB')

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the image is transformed if a transform is set.

        Raises:
            IndexError: If ``idx`` is past the end of the dataset.
        """
        if idx >= len(self.dataset):
            raise IndexError(f"Index {idx} is out of bounds for dataset with size {len(self.dataset)}")
        item = self.dataset[idx]
        image = self._to_rgb_image(item['image'])

        if self.transform is not None:
            image = self.transform(image)

        return image, item['label']

    def get_embeddings(self, model, device, batch_size=32):
        """Embed every image with ``model`` and return the features with their labels.

        The model is called through ``forward_features`` followed by
        ``forward_head(..., pre_logits=True)`` under ``torch.no_grad()``.
        The caller is responsible for putting the model in eval mode and on
        ``device`` beforehand.

        Args:
            model: Model exposing ``forward_features`` and ``forward_head``.
            device: Device the input batches are moved to.
            batch_size: Number of images per forward pass (must be >= 1).

        Returns:
            Tuple ``(embeddings, labels)`` of NumPy arrays; ``embeddings`` has one
            flattened feature row per sample. Both are empty for an empty dataset.

        Raises:
            ValueError: If no transform is set or ``batch_size`` is not positive.
        """
        if self.transform is None:
            raise ValueError("get_embeddings requires a transform that converts images to tensors")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        total = len(self.dataset)
        embeddings = []
        labels = []
        with torch.no_grad():
            for start in range(0, total, batch_size):
                tensors = []
                for idx in range(start, min(start + batch_size, total)):
                    item = self.dataset[idx]
                    tensors.append(self.transform(self._to_rgb_image(item['image'])))
                    labels.append(item['label'])
                # One forward pass per batch instead of one per image.
                batch = torch.stack(tensors).to(device)
                features = model.forward_features(batch)
                features = model.forward_head(features, pre_logits=True)
                embeddings.append(features.cpu().numpy().reshape(len(tensors), -1))

        if not embeddings:
            return np.array([]), np.array([])
        return np.concatenate(embeddings, axis=0), np.array(labels)

# Create datasets
# Both splits use the deterministic inference transform. The embeddings are
# extracted once and then serve as fixed reference points for KNN, so the
# random training-mode augmentations of `train_transform` would only add noise
# to the reference set and make results differ from run to run.
train_dataset = CustomDataset(train_subset, transform=transform)
test_dataset = CustomDataset(test_subset, transform=transform)

# Debug print statement to verify dataset lengths
print(f"Train dataset length: {len(train_dataset)}")
print(f"Test dataset length: {len(test_dataset)}")

# Extract embeddings (one feature vector per image) together with their labels
print("Extracting train embeddings...")
train_embeddings, train_labels = train_dataset.get_embeddings(model, device)
print("Extracting test embeddings...")
test_embeddings, test_labels = test_dataset.get_embeddings(model, device)

# Debug print statement to verify embeddings
print(f"Train embeddings shape: {train_embeddings.shape}")
print(f"Test embeddings shape: {test_embeddings.shape}")

# Verify embeddings and labels length
print(f"Train embeddings length: {len(train_embeddings)}, Train labels length: {len(train_labels)}")
print(f"Test embeddings length: {len(test_embeddings)}, Test labels length: {len(test_labels)}")

# KNN from scratch
class KNNClassifier:
    """Brute-force k-nearest-neighbours classifier using Euclidean distance.

    Labels must be non-negative integers (they are counted with
    ``np.bincount``). Vote ties are resolved in favour of the smallest label id.

    Args:
        k: Number of neighbours that vote on each prediction (must be >= 1).
    """

    def __init__(self, k=3):
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        self.k = k

    def fit(self, X_train, y_train):
        """Store the training samples; array-likes are converted to NumPy arrays.

        Raises:
            ValueError: If the number of samples and labels differ.
        """
        X_train = np.asarray(X_train)
        y_train = np.asarray(y_train)
        if len(X_train) != len(y_train):
            raise ValueError(
                f"X_train has {len(X_train)} samples but y_train has {len(y_train)} labels"
            )
        self.X_train = X_train
        self.y_train = y_train

    def _neighbor_labels(self, x):
        """Return the labels of the (at most) k training samples closest to ``x``."""
        distances = np.linalg.norm(self.X_train - x, axis=1)
        nearest = np.argsort(distances)[:self.k]
        return self.y_train[nearest]

    def predict(self, X_test):
        """Return the majority-vote label for every row of ``X_test``."""
        return np.array([
            np.bincount(self._neighbor_labels(x)).argmax()
            for x in np.asarray(X_test)
        ])

    def predict_proba(self, X_test):
        """Return per-class vote shares for every row of ``X_test``.

        Column ``c`` always corresponds to label id ``c``. The width is
        ``max(y_train) + 1`` rather than the number of distinct labels, so rows
        stay aligned and equally long even when some label ids are absent from
        the training data.
        """
        # Hoisted out of the loop: the column count does not depend on the query.
        n_columns = int(np.max(self.y_train)) + 1
        probs = []
        for x in np.asarray(X_test):
            votes = self._neighbor_labels(x)
            counts = np.bincount(votes, minlength=n_columns)
            # Divide by the actual number of voters so rows sum to 1 even when
            # fewer than k training samples exist.
            probs.append(counts / len(votes))
        return np.array(probs)

# Fit the hand-written KNN (3 neighbours) on the training embeddings
knn = KNNClassifier(k=3)
knn.fit(train_embeddings, train_labels)

# Score the held-out embeddings: vote shares feed the ROC analysis,
# hard predictions feed accuracy and the classification report
test_probs = knn.predict_proba(test_embeddings)
test_preds = knn.predict(test_embeddings)

test_accuracy = accuracy_score(y_true=test_labels, y_pred=test_preds)
print(f"Test Accuracy: {test_accuracy:.4f}")

print("Test Classification Report:")
test_report = classification_report(test_labels, test_preds)
print(test_report)

# Save the results
# Fix the class axis up front: every metric below is indexed by the encoder's
# integer ids, so a class that happens to be missing from the test split
# cannot shift confusion-matrix rows/columns or the AUC column mapping.
n_classes = len(label_encoder.classes_)
class_ids = np.arange(n_classes)

# Confusion matrix (row/column i always corresponds to class id i)
cm = confusion_matrix(test_labels, test_preds, labels=class_ids)
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig('confusion_matrix.png')
plt.close()

# Calculate evaluation metrics
f1 = f1_score(test_labels, test_preds, average='weighted')
# `labels` tells scikit-learn which class each probability column belongs to.
roc_auc_macro = roc_auc_score(
    test_labels, test_probs, multi_class='ovr', average='macro', labels=class_ids
)
roc_auc_weighted = roc_auc_score(
    test_labels, test_probs, multi_class='ovr', average='weighted', labels=class_ids
)

# Per-class one-vs-rest ROC curves, computed and plotted in a single pass
fpr = {}
tpr = {}
roc_auc = {}

plt.figure()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(test_labels == i, test_probs[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
    plt.plot(fpr[i], tpr[i], label=f'Class {i} (AUC = {roc_auc[i]:.2f})')

plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc='lower right')
plt.savefig('roc_curve.png')
plt.close()

# Write report to text file (explicit encoding keeps the output platform-independent)
with open('evaluation_report.txt', 'w', encoding='utf-8') as f:
    f.write(f"Evaluation Accuracy: {test_accuracy:.4f}\n")
    f.write(f"F1 Score: {f1:.4f}\n")
    f.write(f"Macro AUC: {roc_auc_macro:.4f}\n")
    f.write(f"Weighted AUC: {roc_auc_weighted:.4f}\n")
    f.write("\nClass-wise AUC Scores:\n")
    for i in range(n_classes):
        f.write(f"Class {i} AUC: {roc_auc[i]:.4f}\n")

print("Evaluation report saved to evaluation_report.txt")
print("Confusion matrix image saved to confusion_matrix.png")
print("ROC curve image saved to roc_curve.png")