In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from timm import create_model
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the global RNGs so runs are reproducible: the few-shot support set is
# drawn with np.random.shuffle (create_support_and_query) and the freshly
# created classification head is initialised from torch's RNG. Without a seed
# every run trains and evaluates on a different random support set.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Dataset and transforms
data_dir = "/kaggle/input/breakhist-binary-classificationfew-shot/BreakHist/Final 200X"  # Change to your dataset path
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = ImageFolder(root=f"{data_dir}/Train", transform=transform)
test_dataset = ImageFolder(root=f"{data_dir}/Test", transform=transform)

# Load timm's default pretrained weights for vit_small_patch16_224 and replace
# its classifier with a fresh linear head sized to the dataset's class count.
vit_model = create_model('vit_small_patch16_224', pretrained=True)
vit_model.head = nn.Linear(vit_model.num_features, len(train_dataset.classes))  # Add classification head
vit_model.to(device)

# Hyperparameters
num_epochs = 50
# NOTE(review): 1e-3 with Adam on *all* parameters of a pretrained ViT is an
# aggressive fine-tuning rate; if accuracy oscillates from epoch to epoch,
# try 1e-4 or lower.
learning_rate = 0.001
batch_size = 32
shots_per_class = 5

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

# Prepare support and query sets (only once)
def create_support_and_query(dataset, shots_per_class, seed=None):
    """Split ``dataset`` into a few-shot support set and a query set.

    For each class, ``shots_per_class`` randomly chosen sample indices go to
    the support set and every remaining index of that class goes to the
    query set.

    Args:
        dataset: Map-style dataset exposing ``classes`` and yielding
            ``(sample, label)`` pairs whose labels lie in
            ``range(len(dataset.classes))``. If it also exposes ``targets``
            (as ``torchvision.datasets.ImageFolder`` does), labels are read
            from there instead of loading and transforming every sample.
        shots_per_class: Number of support samples per class. A class with
            fewer samples puts all of them in the support set and none in the
            query set.
        seed: Optional seed for a private RNG. ``None`` (default) shuffles
            with NumPy's global RNG.

    Returns:
        ``(support_subset, query_subset)`` as ``torch.utils.data.Subset``s.
    """
    targets = getattr(dataset, "targets", None)
    if targets is None:
        # No label list available: fall back to loading each sample once.
        targets = [label for _, label in dataset]

    class_indices = {label: [] for label in range(len(dataset.classes))}
    for idx, label in enumerate(targets):
        # int() accepts Python ints, NumPy integers and 0-d tensors alike.
        class_indices[int(label)].append(idx)

    shuffle = np.random.shuffle if seed is None else np.random.default_rng(seed).shuffle

    support_indices, query_indices = [], []
    for indices in class_indices.values():
        shuffle(indices)
        support_indices.extend(indices[:shots_per_class])
        query_indices.extend(indices[shots_per_class:])

    return Subset(dataset, support_indices), Subset(dataset, query_indices)

support_set, query_set = create_support_and_query(train_dataset, shots_per_class)

# Shuffle the support set: create_support_and_query appends indices class by
# class, so an unshuffled batch-size-1 loader would apply every class-0 update
# before any class-1 update in each epoch.
support_loader = DataLoader(support_set, batch_size=1, shuffle=True)
# Evaluation-only loaders: order and batch size do not affect predictions in
# eval mode, so use the configured batch_size to speed up feature extraction.
query_loader = DataLoader(query_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Feature extraction
def extract_features(data_loader, model, device=None):
    """Run ``model`` over every batch of ``data_loader`` and collect outputs.

    The "features" are exactly what ``model(images)`` returns. For the ViT
    built above, whose ``head`` was replaced by an ``nn.Linear`` onto the
    class count, that is the classifier output (one value per class), not
    the pooled backbone embedding.

    Args:
        data_loader: Iterable yielding ``(images, targets)`` tensor batches.
        model: Network to evaluate. It is put in eval mode for the pass, and
            its previous train/eval mode is restored afterwards, even if an
            exception is raised.
        device: Device the images are moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(features, labels)``: the ``np.vstack`` of the per-batch outputs
        and the ``np.hstack`` of the per-batch targets.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    feature_batches, label_batches = [], []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, targets in tqdm(data_loader, desc="Extracting features"):
                outputs = model(images.to(device))
                feature_batches.append(outputs.cpu().numpy())
                label_batches.append(targets.numpy())
    finally:
        # Don't leak eval mode (disabled dropout etc.) back to the caller.
        model.train(was_training)

    if not feature_batches:
        raise ValueError("extract_features: data_loader yielded no batches")
    return np.vstack(feature_batches), np.hstack(label_batches)

# Training loop with support set only
def train_model_on_support(model, support_loader, criterion, optimizer, num_epochs, device=None):
    """Fine-tune ``model`` on the support set only, printing metrics per epoch.

    Args:
        model: Classifier whose outputs are fed to ``criterion`` together with
            the integer labels; the argmax over dim 1 is the prediction.
        support_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)`` returning a
            scalar tensor.
        optimizer: Optimizer over the parameters to be trained.
        num_epochs: Number of full passes over ``support_loader``.
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    The printed loss is the mean of the per-batch losses, so it stays
    comparable across support sets of different sizes (a raw sum would grow
    with the number of batches).

    Raises:
        ValueError: If ``support_loader`` yields no samples in an epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch; feature extraction switches the
        # model to eval mode.
        model.train()
        loss_sum = 0.0
        num_batches = 0
        correct_train = 0
        total_train = 0

        for images, labels in tqdm(support_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        if total_train == 0:
            raise ValueError("train_model_on_support: support_loader yielded no samples")

        mean_loss = loss_sum / num_batches
        train_accuracy = 100 * correct_train / total_train
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")

# Fine-tune the whole network using only the few-shot support samples.
train_model_on_support(
    model=vit_model,
    support_loader=support_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)

# Evaluate on merged query and test dataset
def evaluate_on_merged(model, query_loader, test_loader, support=None):
    """Nearest-prototype accuracy over the query and test splits combined.

    Each class prototype is the mean ``extract_features`` output of that
    class's support samples. Every query and test sample is assigned the
    label of the prototype with the highest cosine similarity. One accuracy
    over both splits is printed and returned.

    Args:
        model: Network passed to ``extract_features``.
        query_loader: Loader over the query split (in this notebook, the
            non-support part of the training folder).
        test_loader: Loader over the test split.
        support: Loader over the support set used to build the prototypes.
            Defaults to the module-level ``support_loader``.

    Returns:
        The merged accuracy, in percent.
    """
    if support is None:
        support = support_loader

    # Extract features from the query and test sets and merge them
    query_features, query_labels = extract_features(query_loader, model)
    test_features, test_labels = extract_features(test_loader, model)
    all_features = np.vstack((query_features, test_features))
    all_labels = np.hstack((query_labels, test_labels))

    # Compute class prototypes from the support set
    support_features, support_labels = extract_features(support, model)
    prototype_labels = np.unique(support_labels)
    class_prototypes = np.stack(
        [support_features[support_labels == label].mean(axis=0) for label in prototype_labels]
    )

    # One (n_samples, n_prototypes) similarity matrix instead of a per-sample
    # loop. Softmax is monotonic, so the argmax of the raw similarities is the
    # same decision as the argmax of their softmax.
    similarities = cosine_similarity(all_features, class_prototypes)
    # Map prototype row index -> class label; the two only coincide when every
    # class 0..C-1 is present in the support set.
    predicted = prototype_labels[similarities.argmax(axis=1)]

    merged_accuracy = np.count_nonzero(predicted == all_labels) / len(all_labels) * 100
    print(f"Accuracy on merged query and test dataset: {merged_accuracy:.4f}%")
    return merged_accuracy

# Final check: prototype-based accuracy over the query split plus the test split.
evaluate_on_merged(
    model=vit_model,
    query_loader=query_loader,
    test_loader=test_loader,
)

model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

Epoch 1/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.51it/s]


Epoch [1/100], Loss: 28.7897, Training Accuracy: 75.98%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 97.50it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.53it/s]


Epoch [1/100], Query Accuracy: 59.39%
Best model saved with Query Accuracy: 59.39%


Epoch 2/100 - Training: 100%|██████████| 28/28 [00:10<00:00,  2.62it/s]


Epoch [2/100], Loss: 14.5792, Training Accuracy: 79.55%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 129.63it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.05it/s]


Epoch [2/100], Query Accuracy: 60.74%
Best model saved with Query Accuracy: 60.74%


Epoch 3/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.47it/s]


Epoch [3/100], Loss: 13.9348, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 115.23it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.88it/s]


Epoch [3/100], Query Accuracy: 62.88%
Best model saved with Query Accuracy: 62.88%


Epoch 4/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [4/100], Loss: 19.9202, Training Accuracy: 75.98%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 122.66it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.53it/s]


Epoch [4/100], Query Accuracy: 62.99%
Best model saved with Query Accuracy: 62.99%


Epoch 5/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.20it/s]


Epoch [5/100], Loss: 14.2611, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 120.59it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 127.46it/s]


Epoch [5/100], Query Accuracy: 53.21%


Epoch 6/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [6/100], Loss: 14.0296, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 113.43it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.29it/s]


Epoch [6/100], Query Accuracy: 62.77%


Epoch 7/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.36it/s]


Epoch [7/100], Loss: 13.9851, Training Accuracy: 80.22%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.86it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.97it/s]


Epoch [7/100], Query Accuracy: 61.75%


Epoch 8/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.36it/s]


Epoch [8/100], Loss: 13.4757, Training Accuracy: 80.67%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 124.01it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.91it/s]


Epoch [8/100], Query Accuracy: 66.37%
Best model saved with Query Accuracy: 66.37%


Epoch 9/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [9/100], Loss: 14.0930, Training Accuracy: 79.55%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 120.66it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.28it/s]


Epoch [9/100], Query Accuracy: 67.94%
Best model saved with Query Accuracy: 67.94%


Epoch 10/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [10/100], Loss: 13.7165, Training Accuracy: 80.45%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 121.92it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.03it/s]


Epoch [10/100], Query Accuracy: 41.84%


Epoch 11/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.36it/s]


Epoch [11/100], Loss: 13.7440, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 102.62it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 138.35it/s]


Epoch [11/100], Query Accuracy: 39.48%


Epoch 12/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [12/100], Loss: 14.0170, Training Accuracy: 80.56%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.57it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.79it/s]


Epoch [12/100], Query Accuracy: 41.39%


Epoch 13/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [13/100], Loss: 13.9802, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 123.53it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.51it/s]


Epoch [13/100], Query Accuracy: 38.81%


Epoch 14/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [14/100], Loss: 13.1074, Training Accuracy: 80.56%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.29it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.62it/s]


Epoch [14/100], Query Accuracy: 69.85%
Best model saved with Query Accuracy: 69.85%


Epoch 15/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [15/100], Loss: 12.8782, Training Accuracy: 80.78%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 111.41it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 129.67it/s]


Epoch [15/100], Query Accuracy: 67.15%


Epoch 16/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.31it/s]


Epoch [16/100], Loss: 13.5770, Training Accuracy: 79.66%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.57it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.84it/s]


Epoch [16/100], Query Accuracy: 65.02%


Epoch 17/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [17/100], Loss: 13.4813, Training Accuracy: 80.67%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 108.15it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.24it/s]


Epoch [17/100], Query Accuracy: 70.42%
Best model saved with Query Accuracy: 70.42%


Epoch 18/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [18/100], Loss: 12.9681, Training Accuracy: 80.22%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 116.75it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.06it/s]


Epoch [18/100], Query Accuracy: 69.29%


Epoch 19/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.30it/s]


Epoch [19/100], Loss: 13.7746, Training Accuracy: 79.55%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 121.07it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 131.75it/s]


Epoch [19/100], Query Accuracy: 48.59%


Epoch 20/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [20/100], Loss: 14.3749, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.51it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 130.14it/s]


Epoch [20/100], Query Accuracy: 65.24%


Epoch 21/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [21/100], Loss: 13.8756, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.59it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.75it/s]


Epoch [21/100], Query Accuracy: 55.91%


Epoch 22/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [22/100], Loss: 13.5662, Training Accuracy: 80.45%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.28it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.07it/s]


Epoch [22/100], Query Accuracy: 69.29%


Epoch 23/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [23/100], Loss: 13.5662, Training Accuracy: 81.01%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.87it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.94it/s]


Epoch [23/100], Query Accuracy: 68.95%


Epoch 24/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [24/100], Loss: 13.5891, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.61it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.63it/s]


Epoch [24/100], Query Accuracy: 59.17%


Epoch 25/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [25/100], Loss: 13.6819, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 114.43it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 128.99it/s]


Epoch [25/100], Query Accuracy: 65.02%


Epoch 26/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [26/100], Loss: 13.3817, Training Accuracy: 80.45%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 107.06it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.21it/s]


Epoch [26/100], Query Accuracy: 65.58%


Epoch 27/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [27/100], Loss: 13.5210, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.96it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 138.22it/s]


Epoch [27/100], Query Accuracy: 66.25%


Epoch 28/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [28/100], Loss: 14.0839, Training Accuracy: 80.11%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 123.65it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.10it/s]


Epoch [28/100], Query Accuracy: 65.24%


Epoch 29/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [29/100], Loss: 14.2227, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.16it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.58it/s]


Epoch [29/100], Query Accuracy: 62.65%


Epoch 30/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [30/100], Loss: 13.6288, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 112.03it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 130.88it/s]


Epoch [30/100], Query Accuracy: 69.40%


Epoch 31/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [31/100], Loss: 13.6612, Training Accuracy: 80.45%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 121.72it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.05it/s]


Epoch [31/100], Query Accuracy: 65.02%


Epoch 32/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.30it/s]


Epoch [32/100], Loss: 13.2089, Training Accuracy: 80.56%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 100.69it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.09it/s]


Epoch [32/100], Query Accuracy: 60.52%


Epoch 33/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [33/100], Loss: 13.6976, Training Accuracy: 79.89%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.21it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.83it/s]


Epoch [33/100], Query Accuracy: 59.39%


Epoch 34/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [34/100], Loss: 13.5311, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 112.85it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.57it/s]


Epoch [34/100], Query Accuracy: 67.27%


Epoch 35/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [35/100], Loss: 13.3053, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 116.34it/s]
Extracting features: 100%|██████████| 889/889 [00:07<00:00, 126.73it/s]


Epoch [35/100], Query Accuracy: 60.40%


Epoch 36/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.33it/s]


Epoch [36/100], Loss: 13.0680, Training Accuracy: 80.56%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.28it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.97it/s]


Epoch [36/100], Query Accuracy: 69.07%


Epoch 37/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [37/100], Loss: 12.8167, Training Accuracy: 81.12%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 99.69it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.17it/s]


Epoch [37/100], Query Accuracy: 64.00%


Epoch 38/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [38/100], Loss: 12.8038, Training Accuracy: 80.89%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 114.61it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.79it/s]


Epoch [38/100], Query Accuracy: 75.25%
Best model saved with Query Accuracy: 75.25%


Epoch 39/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [39/100], Loss: 12.8141, Training Accuracy: 80.34%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.86it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.08it/s]


Epoch [39/100], Query Accuracy: 57.71%


Epoch 40/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [40/100], Loss: 12.5311, Training Accuracy: 80.89%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 120.66it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 130.45it/s]


Epoch [40/100], Query Accuracy: 63.67%


Epoch 41/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [41/100], Loss: 12.4637, Training Accuracy: 80.22%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 123.60it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 131.35it/s]


Epoch [41/100], Query Accuracy: 70.42%


Epoch 42/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [42/100], Loss: 12.4906, Training Accuracy: 80.89%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.44it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 138.22it/s]


Epoch [42/100], Query Accuracy: 70.98%


Epoch 43/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [43/100], Loss: 12.9414, Training Accuracy: 80.67%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 121.85it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.27it/s]


Epoch [43/100], Query Accuracy: 70.98%


Epoch 44/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [44/100], Loss: 12.0686, Training Accuracy: 81.45%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 113.60it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 132.52it/s]


Epoch [44/100], Query Accuracy: 64.23%


Epoch 45/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [45/100], Loss: 13.2020, Training Accuracy: 80.45%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 105.10it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 129.06it/s]


Epoch [45/100], Query Accuracy: 51.63%


Epoch 46/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [46/100], Loss: 13.1586, Training Accuracy: 80.45%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 104.16it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 138.54it/s]


Epoch [46/100], Query Accuracy: 67.04%


Epoch 47/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [47/100], Loss: 12.6732, Training Accuracy: 81.01%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 122.76it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.69it/s]


Epoch [47/100], Query Accuracy: 56.92%


Epoch 48/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [48/100], Loss: 12.5705, Training Accuracy: 81.56%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 116.07it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.05it/s]


Epoch [48/100], Query Accuracy: 52.42%


Epoch 49/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [49/100], Loss: 12.3107, Training Accuracy: 80.89%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 116.14it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.46it/s]


Epoch [49/100], Query Accuracy: 55.79%


Epoch 50/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [50/100], Loss: 12.6020, Training Accuracy: 80.22%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.05it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.24it/s]


Epoch [50/100], Query Accuracy: 64.12%


Epoch 51/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [51/100], Loss: 11.7230, Training Accuracy: 81.56%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.67it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.06it/s]


Epoch [51/100], Query Accuracy: 44.77%


Epoch 52/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.33it/s]


Epoch [52/100], Loss: 12.0399, Training Accuracy: 80.89%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 112.52it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 130.62it/s]


Epoch [52/100], Query Accuracy: 71.32%


Epoch 53/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.32it/s]


Epoch [53/100], Loss: 12.8163, Training Accuracy: 80.11%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 120.35it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.47it/s]


Epoch [53/100], Query Accuracy: 70.08%


Epoch 54/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.30it/s]


Epoch [54/100], Loss: 12.1459, Training Accuracy: 81.23%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 120.27it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 130.04it/s]


Epoch [54/100], Query Accuracy: 74.92%


Epoch 55/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.31it/s]


Epoch [55/100], Loss: 11.5559, Training Accuracy: 82.91%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 113.85it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 130.46it/s]


Epoch [55/100], Query Accuracy: 67.83%


Epoch 56/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [56/100], Loss: 11.3655, Training Accuracy: 82.01%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.61it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.70it/s]


Epoch [56/100], Query Accuracy: 77.73%
Best model saved with Query Accuracy: 77.73%


Epoch 57/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [57/100], Loss: 11.2409, Training Accuracy: 84.13%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 121.96it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 131.37it/s]


Epoch [57/100], Query Accuracy: 60.07%


Epoch 58/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [58/100], Loss: 12.6161, Training Accuracy: 80.89%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 116.64it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 132.56it/s]


Epoch [58/100], Query Accuracy: 67.15%


Epoch 59/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [59/100], Loss: 11.4836, Training Accuracy: 82.35%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.65it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.35it/s]


Epoch [59/100], Query Accuracy: 75.59%


Epoch 60/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [60/100], Loss: 10.8761, Training Accuracy: 83.24%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.34it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.72it/s]


Epoch [60/100], Query Accuracy: 81.66%
Best model saved with Query Accuracy: 81.66%


Epoch 61/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.31it/s]


Epoch [61/100], Loss: 10.3326, Training Accuracy: 84.25%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.91it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.09it/s]


Epoch [61/100], Query Accuracy: 65.02%


Epoch 62/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [62/100], Loss: 10.7653, Training Accuracy: 83.91%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.15it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.79it/s]


Epoch [62/100], Query Accuracy: 78.85%


Epoch 63/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [63/100], Loss: 10.5775, Training Accuracy: 84.92%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.90it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 130.75it/s]


Epoch [63/100], Query Accuracy: 76.94%


Epoch 64/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [64/100], Loss: 11.2057, Training Accuracy: 83.58%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 115.86it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.35it/s]


Epoch [64/100], Query Accuracy: 55.01%


Epoch 65/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [65/100], Loss: 11.8550, Training Accuracy: 81.90%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.44it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.13it/s]


Epoch [65/100], Query Accuracy: 62.32%


Epoch 66/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [66/100], Loss: 11.3817, Training Accuracy: 84.25%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.25it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.93it/s]


Epoch [66/100], Query Accuracy: 78.63%


Epoch 67/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [67/100], Loss: 9.4722, Training Accuracy: 86.48%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 123.36it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.33it/s]


Epoch [67/100], Query Accuracy: 78.29%


Epoch 68/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [68/100], Loss: 11.1496, Training Accuracy: 83.13%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.33it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 133.09it/s]


Epoch [68/100], Query Accuracy: 21.71%


Epoch 69/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [69/100], Loss: 11.0357, Training Accuracy: 83.35%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.43it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.99it/s]


Epoch [69/100], Query Accuracy: 23.28%


Epoch 70/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.32it/s]


Epoch [70/100], Loss: 12.2807, Training Accuracy: 80.78%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.66it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.98it/s]


Epoch [70/100], Query Accuracy: 31.16%


Epoch 71/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [71/100], Loss: 11.7659, Training Accuracy: 81.79%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 100.12it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.79it/s]


Epoch [71/100], Query Accuracy: 18.79%


Epoch 72/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [72/100], Loss: 11.1796, Training Accuracy: 82.46%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 120.68it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.12it/s]


Epoch [72/100], Query Accuracy: 70.53%


Epoch 73/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [73/100], Loss: 11.3319, Training Accuracy: 83.02%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.01it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 127.52it/s]


Epoch [73/100], Query Accuracy: 71.32%


Epoch 74/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [74/100], Loss: 10.5973, Training Accuracy: 84.58%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 97.66it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.96it/s]


Epoch [74/100], Query Accuracy: 22.50%


Epoch 75/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [75/100], Loss: 10.0852, Training Accuracy: 86.37%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 122.66it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.91it/s]


Epoch [75/100], Query Accuracy: 62.54%


Epoch 76/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [76/100], Loss: 9.6184, Training Accuracy: 86.59%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 115.50it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.61it/s]


Epoch [76/100], Query Accuracy: 37.80%


Epoch 77/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [77/100], Loss: 9.7647, Training Accuracy: 86.59%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 116.17it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.09it/s]


Epoch [77/100], Query Accuracy: 41.28%


Epoch 78/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [78/100], Loss: 9.4968, Training Accuracy: 86.26%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.62it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 131.73it/s]


Epoch [78/100], Query Accuracy: 53.99%


Epoch 79/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [79/100], Loss: 11.8867, Training Accuracy: 81.23%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 106.16it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.43it/s]


Epoch [79/100], Query Accuracy: 77.84%


Epoch 80/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.33it/s]


Epoch [80/100], Loss: 11.0105, Training Accuracy: 84.13%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.99it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.12it/s]


Epoch [80/100], Query Accuracy: 66.14%


Epoch 81/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [81/100], Loss: 10.2778, Training Accuracy: 85.25%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.62it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.46it/s]


Epoch [81/100], Query Accuracy: 71.99%


Epoch 82/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [82/100], Loss: 11.1834, Training Accuracy: 84.02%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 116.54it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.51it/s]


Epoch [82/100], Query Accuracy: 82.00%
Best model saved with Query Accuracy: 82.00%


Epoch 83/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [83/100], Loss: 9.8494, Training Accuracy: 86.03%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.84it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 131.02it/s]


Epoch [83/100], Query Accuracy: 80.31%


Epoch 84/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [84/100], Loss: 9.8452, Training Accuracy: 86.37%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 109.01it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.24it/s]


Epoch [84/100], Query Accuracy: 45.11%


Epoch 85/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [85/100], Loss: 10.3367, Training Accuracy: 85.36%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 103.46it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.42it/s]


Epoch [85/100], Query Accuracy: 33.52%


Epoch 86/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [86/100], Loss: 10.2268, Training Accuracy: 84.69%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 115.89it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.58it/s]


Epoch [86/100], Query Accuracy: 64.00%


Epoch 87/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [87/100], Loss: 10.6108, Training Accuracy: 83.24%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 122.17it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.69it/s]


Epoch [87/100], Query Accuracy: 81.21%


Epoch 88/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [88/100], Loss: 9.2904, Training Accuracy: 86.70%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.33it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 130.08it/s]


Epoch [88/100], Query Accuracy: 83.58%
Best model saved with Query Accuracy: 83.58%


Epoch 89/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.33it/s]


Epoch [89/100], Loss: 8.6446, Training Accuracy: 87.26%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 118.76it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.80it/s]


Epoch [89/100], Query Accuracy: 47.81%


Epoch 90/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [90/100], Loss: 8.6266, Training Accuracy: 87.49%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 120.45it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.07it/s]


Epoch [90/100], Query Accuracy: 86.84%
Best model saved with Query Accuracy: 86.84%


Epoch 91/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [91/100], Loss: 8.6361, Training Accuracy: 88.04%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 115.66it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.63it/s]


Epoch [91/100], Query Accuracy: 46.12%


Epoch 92/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [92/100], Loss: 8.9178, Training Accuracy: 87.49%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 122.14it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.92it/s]


Epoch [92/100], Query Accuracy: 91.90%
Best model saved with Query Accuracy: 91.90%


Epoch 93/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [93/100], Loss: 8.0880, Training Accuracy: 88.94%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 117.73it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.84it/s]


Epoch [93/100], Query Accuracy: 85.60%


Epoch 94/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [94/100], Loss: 7.2091, Training Accuracy: 89.94%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.23it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.93it/s]


Epoch [94/100], Query Accuracy: 87.85%


Epoch 95/100 - Training: 100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


Epoch [95/100], Loss: 6.5348, Training Accuracy: 90.39%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 120.70it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 134.69it/s]


Epoch [95/100], Query Accuracy: 88.53%


Epoch 96/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [96/100], Loss: 7.3039, Training Accuracy: 89.39%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 114.33it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 135.75it/s]


Epoch [96/100], Query Accuracy: 92.01%
Best model saved with Query Accuracy: 92.01%


Epoch 97/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.35it/s]


Epoch [97/100], Loss: 8.0800, Training Accuracy: 88.38%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 122.11it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 137.66it/s]


Epoch [97/100], Query Accuracy: 86.28%


Epoch 98/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.33it/s]


Epoch [98/100], Loss: 8.5041, Training Accuracy: 88.83%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 116.18it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 130.28it/s]


Epoch [98/100], Query Accuracy: 90.33%


Epoch 99/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.33it/s]


Epoch [99/100], Loss: 7.3033, Training Accuracy: 89.61%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 116.58it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 136.33it/s]


Epoch [99/100], Query Accuracy: 83.91%


Epoch 100/100 - Training: 100%|██████████| 28/28 [00:11<00:00,  2.34it/s]


Epoch [100/100], Loss: 8.1044, Training Accuracy: 89.27%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 119.70it/s]
Extracting features: 100%|██████████| 889/889 [00:06<00:00, 138.20it/s]
  vit_model.load_state_dict(torch.load(best_model_path))


Epoch [100/100], Query Accuracy: 91.45%


Extracting features: 100%|██████████| 6/6 [00:00<00:00, 114.43it/s]
Extracting features: 100%|██████████| 376/376 [00:04<00:00, 84.98it/s]


Accuracy on test dataset: 78.7234%
