In [None]:
import os
import tarfile
from pathlib import Path

import gdown
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.hub
import torch.nn as nn
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from sklearn import svm
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder
# Order matters: the later tqdm import is the one that stays bound to `tqdm`.
from tqdm.notebook import tqdm
from tqdm.auto import tqdm

# Download the dataset
url = 'https://drive.google.com/uc?id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp'
output = '101_ObjectCategories.tar.gz'
gdown.download(url, output, quiet=False)

# Extract the dataset
# A failed download may leave no file behind (presumably gdown reports the
# failure without raising on some versions -- TODO confirm), and
# tarfile.is_tarfile raises on a missing path, so check for the file first.
if os.path.isfile(output) and tarfile.is_tarfile(output):
    with tarfile.open(output, "r:gz") as tar_ref:
        # The archive comes from the network: use the "data" extraction filter
        # when this Python provides it, falling back to the legacy behaviour.
        if hasattr(tarfile, "data_filter"):
            tar_ref.extractall(filter="data")
        else:
            tar_ref.extractall()
    print("File extracted successfully.")
else:
    print("Downloaded file is not a tar file.")


# Preprocessing applied to every image. Resize and crop run on the PIL image
# *before* ToTensor: resizing a tensor is not antialiased by default on older
# torchvision releases, which degrades the down-sampled input.
transform_image = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

dataset_path = './101_ObjectCategories'
# ImageFolder is used only to enumerate (path, class index) pairs; images are
# loaded later by CustomDataset.
original_dataset = datasets.ImageFolder(dataset_path)
# Drop every sample whose path mentions the BACKGROUND_Google folder.
filtered_data = [(img, label) for img, label in original_dataset.imgs if "BACKGROUND_Google" not in img]

class CustomDataset(torch.utils.data.Dataset):
    """Dataset over a list of ``(image_path, label)`` pairs.

    Images are opened lazily and converted to RGB. Labels greater than the
    ``BACKGROUND_Google`` class index of the module-level ``original_dataset``
    are shifted down by one.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path, raw_label = self.data[index]
        image = Image.open(path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        # Close the gap left by the removed background class.
        background_index = original_dataset.class_to_idx['BACKGROUND_Google']
        shift = 1 if raw_label > background_index else 0
        return image, raw_label - shift


filtered_dataset = CustomDataset(filtered_data, transform=transform_image)


# 80 / 10 / 10 split; the training set absorbs the rounding remainder.
total_size = len(filtered_dataset)
val_size = test_size = int(0.1 * total_size)
train_size = total_size - val_size - test_size

# Split
# A seeded generator keeps train/val/test membership identical across runs, so
# results from different runs are measured on the same held-out samples.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_val_dataset = random_split(
    filtered_dataset, [train_size, val_size + test_size], generator=split_generator
)
val_dataset, test_dataset = random_split(
    test_val_dataset, [val_size, test_size], generator=split_generator
)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

# DinoV2
# Keep using the second GPU when there is one, but fall back to the first GPU
# (or the CPU) instead of failing on machines without a device at index 1.
if not torch.cuda.is_available():
    device = torch.device('cpu')
elif torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
else:
    device = torch.device('cuda:0')
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14").to(device)

# Embeddings
def compute_embeddings(data_loader):
    """Embed every batch of ``data_loader`` with the module-level DINOv2 model.

    Args:
        data_loader: Iterable yielding ``(images, targets)`` tensor batches.

    Returns:
        A ``(embeddings, labels)`` tuple: the per-batch model outputs stacked
        into one NumPy array, and a list of the target values in the same order.
    """
    dinov2_vits14.eval()
    embedding_chunks, labels = [], []
    with torch.no_grad():
        for images, targets in tqdm(data_loader, desc="Computing embeddings"):
            batch_output = dinov2_vits14(images.to(device))
            embedding_chunks.append(batch_output.cpu().numpy())
            labels += list(targets.numpy())
    return np.vstack(embedding_chunks), labels
# Embed each split once with the DINOv2 model; compute_embeddings returns an
# (embeddings, labels) pair per loader.
train_embeddings, train_labels = compute_embeddings(train_loader)
val_embeddings, val_labels = compute_embeddings(val_loader)
test_embeddings, test_labels = compute_embeddings(test_loader)

# Classifier
class EmbeddingClassifier(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear, returning raw logits."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

input_size = 384  # Feature dimension of DINOv2 embeddings
hidden_size = 512
num_classes = 101

classifier = EmbeddingClassifier(input_size, hidden_size, num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)

train_embeddings_tensor = torch.tensor(train_embeddings).float().to(device)
train_labels_tensor = torch.tensor(train_labels).long().to(device)

val_embeddings_tensor = torch.tensor(val_embeddings).float().to(device)
val_labels_tensor = torch.tensor(val_labels).long().to(device)

model_dir = Path('teacher_model')
model_dir.mkdir(parents=True, exist_ok=True)

epochs = 10
batch_size = 32

num_train_samples = len(train_embeddings_tensor)
num_val_samples = len(val_embeddings_tensor)
# Ceiling division so the final, smaller batch is processed instead of dropped.
num_train_batches = (num_train_samples + batch_size - 1) // batch_size
num_val_batches = (num_val_samples + batch_size - 1) // batch_size

best_val_loss = float('inf')
best_model_path = None

# Train
for epoch in range(epochs):
    classifier.train()
    train_loss = 0.0
    # Draw a fresh sample order every epoch so batches are not identical
    # from one epoch to the next.
    epoch_order = torch.randperm(num_train_samples, device=device)
    for i in tqdm(range(num_train_batches), desc=f'Epoch {epoch+1}/{epochs}, Training'):
        batch_indices = epoch_order[i * batch_size:(i + 1) * batch_size]

        images = train_embeddings_tensor[batch_indices]
        labels = train_labels_tensor[batch_indices]

        optimizer.zero_grad()
        outputs = classifier(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the short last batch does not skew the mean.
        train_loss += loss.item() * labels.size(0)

    avg_train_loss = train_loss / num_train_samples
    tqdm.write(f'Epoch [{epoch+1}/{epochs}], Training Loss: {avg_train_loss:.4f}')

    # Validation phase
    classifier.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(num_val_batches), desc=f'Epoch {epoch+1}/{epochs}, Validating'):
            batch_start = i * batch_size
            batch_end = (i + 1) * batch_size

            images = val_embeddings_tensor[batch_start:batch_end]
            labels = val_labels_tensor[batch_start:batch_end]

            outputs = classifier(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_val_loss = val_loss / num_val_samples
    val_accuracy = correct / total * 100
    tqdm.write(f'Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

    # Save
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_path = model_dir / f'best_model_epoch_{epoch+1}.pth'
        torch.save(classifier.state_dict(), best_model_path)
        tqdm.write(f"Best model saved to {best_model_path} with Validation Loss: {best_val_loss:.4f}")

# Test
test_embeddings_tensor = torch.tensor(test_embeddings).float().to(device)
test_labels_tensor = torch.tensor(test_labels).long().to(device)

# Evaluate the checkpoint with the lowest validation loss, not whatever
# weights the final epoch happened to leave behind.
if best_model_path is not None:
    classifier.load_state_dict(torch.load(best_model_path, map_location=device))

classifier.eval()
test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    outputs = classifier(test_embeddings_tensor)  # Forward pass
    _, predicted = torch.max(outputs, 1)
    total = test_labels_tensor.size(0)
    correct = (predicted == test_labels_tensor).sum().item()

# Accuracy
test_accuracy = 100 * correct / total
print(f'Test Accuracy: {test_accuracy:.2f}%')


print("-----------------Starting KD--------------------")


class CosineSimilarityLoss(nn.Module):
    """Loss equal to one minus the mean cosine similarity of two batches."""

    def __init__(self):
        super().__init__()
        self.cosine_similarity = nn.CosineSimilarity()

    def forward(self, student_outputs, teacher_outputs):
        similarities = self.cosine_similarity(student_outputs, teacher_outputs)
        mean_similarity = similarities.mean()
        return 1 - mean_similarity


# KD DinoV2 -> EfficientNetV2
# class EfficientNetV2Embeddings(nn.Module):
#     def __init__(self, output_dim=384):
#         super(EfficientNetV2Embeddings, self).__init__()
#         self.base_model = timm.create_model('efficientnetv2_rw_s', pretrained=True, features_only=True)
#         feature_dim = self.base_model.feature_info[-1]['num_chs']
    
#         self.embedding_layer = nn.Linear(feature_dim, output_dim)

#     def forward(self, x):
#         features = self.base_model(x)[-1] 
#         pooled_features = features.mean([2, 3])
#         embeddings = self.embedding_layer(pooled_features)
#         return embeddings

class EfficientNetV2Embeddings(nn.Module):
    """EfficientNetV2 feature extractor with a linear embedding head.

    The last feature map of the timm backbone is averaged over dimensions 2
    and 3 (presumably height and width), then passed through batch norm,
    dropout and a linear layer producing ``output_dim`` values per sample.
    """

    def __init__(self, output_dim=384, dropout_rate=0.5):
        super().__init__()
        self.base_model = timm.create_model('efficientnetv2_rw_s', pretrained=True, features_only=True)
        last_stage_channels = self.base_model.feature_info[-1]['num_chs']
        self.batch_norm = nn.BatchNorm1d(last_stage_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.embedding_layer = nn.Linear(last_stage_channels, output_dim)

    def forward(self, x):
        last_feature_map = self.base_model(x)[-1]
        pooled = last_feature_map.mean(dim=(2, 3))
        regularized = self.dropout(self.batch_norm(pooled))
        return self.embedding_layer(regularized)
        
# The DINOv2 model is the distillation teacher. The student emits 384 values
# per sample, the same size as the DINOv2 embedding dimension noted above.
teacher_model = dinov2_vits14
student_model = EfficientNetV2Embeddings(output_dim=384)
student_model.to(device)

def validate(model, criterion, data_loader, device, teacher=None):
    """Return the sample-weighted mean distillation loss over ``data_loader``.

    Args:
        model: Student network; switched to evaluation mode.
        criterion: Callable ``criterion(student_out, teacher_out)`` returning
            a scalar tensor.
        data_loader: Iterable of ``(images, labels)`` batches; labels are
            ignored.
        device: Device the image batches are moved to.
        teacher: Network producing the target embeddings. Defaults to the
            module-level ``teacher_model``.

    Returns:
        The loss averaged over all samples as a Python float.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    if teacher is None:
        teacher = teacher_model

    model.eval()  # Set the model to evaluation mode
    # The teacher only provides targets, so it must be in evaluation mode too.
    teacher.eval()
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(device)

            teacher_embeddings = teacher(images)
            student_embeddings = model(images)

            loss = criterion(student_embeddings, teacher_embeddings)
            batch_samples = images.size(0)
            # criterion returns a batch mean; re-weight by the batch size.
            total_loss += loss.item() * batch_samples
            total_samples += batch_samples

    if total_samples == 0:
        raise ValueError("data_loader yielded no samples")
    return total_loss / total_samples

model_save_dir = Path('distilled_effv2')
model_save_dir.mkdir(parents=True, exist_ok=True)

# Directory for the loss plot; defined here so plt.savefig below has a target
# when the notebook is run top to bottom.
plot_dir = Path('plots_new')
plot_dir.mkdir(parents=True, exist_ok=True)

best_val_loss = float('inf')
best_epoch = 0

criterion = CosineSimilarityLoss()
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4, weight_decay=1e-5)
epochs = 15

train_losses = []
val_losses = []

# The teacher is frozen: keep it in evaluation mode for the whole run.
teacher_model.eval()

for epoch in range(epochs):
    student_model.train()
    total_train_loss = 0.0

    for images, _ in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{epochs}"):
        images = images.to(device)

        optimizer.zero_grad()

        # Teacher outputs are targets only; no gradients are needed for them.
        with torch.no_grad():
            teacher_embeddings = teacher_model(images)

        student_embeddings = student_model(images)

        loss = criterion(student_embeddings, teacher_embeddings)
        # Accumulate a sample-weighted sum (the criterion returns a batch mean).
        total_train_loss += loss.item() * images.size(0)

        loss.backward()
        optimizer.step()

    # Divide the sample-weighted sum by the number of samples, matching how
    # validate() averages, so both curves share one scale in the plot.
    avg_train_loss = total_train_loss / len(train_loader.dataset)
    train_losses.append(avg_train_loss)

    avg_val_loss = validate(student_model, criterion, val_loader, device)
    val_losses.append(avg_val_loss)

    print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}')

    # Check if the current model is better than the best model so far
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        best_model_path = model_save_dir / f'best_model_epoch_{best_epoch}.pth'

        torch.save(student_model.state_dict(), best_model_path)
        print(f"New best model saved at epoch {best_epoch} with Validation Loss: {best_val_loss:.4f}")

print(f"Best model was saved at epoch {best_epoch} with Validation Loss: {best_val_loss:.4f}")


# Plotting the training and validation losses
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss during Knowledge Distillation')
plt.legend()
plt.savefig(plot_dir / 'kd_training_validation_losses.jpg')
plt.show()


# Evaluate

# Rebuild the student and restore the weights of the best epoch.
best_model_path = model_save_dir / f'best_model_epoch_{best_epoch}.pth'
student_model = EfficientNetV2Embeddings(output_dim=384).to(device)
# map_location keeps the load working even if the checkpoint was written from
# a different device than the one currently selected.
student_model.load_state_dict(torch.load(best_model_path, map_location=device))

print("Best model loaded.")

# def evaluate_model(model, teacher_model, criterion, data_loader, device):
#     model.eval() 
#     teacher_model.eval()
#     total_loss = 0.0
#     total_cosine_similarity = 0.0
#     total_samples = 0

#     with torch.no_grad():
#         for images, _ in tqdm(data_loader, desc="Evaluating"):
#             images = images.to(device)
#             student_embeddings = model(images)
#             teacher_embeddings = teacher_model(images)
#             loss = criterion(student_embeddings, teacher_embeddings)
#             total_loss += loss.item() * images.size(0)
#             cosine_similarity = nn.functional.cosine_similarity(student_embeddings, teacher_embeddings, dim=1).mean()
#             total_cosine_similarity += cosine_similarity.item() * images.size(0)
#             total_samples += images.size(0)
    
#     avg_loss = total_loss / total_samples
#     avg_cosine_similarity = total_cosine_similarity / total_samples

#     print(f"Test MSE Loss: {avg_loss:.4f}")
#     print(f"Average Cosine Similarity: {avg_cosine_similarity:.4f}")

# evaluate_model(student_model, teacher_model, criterion, test_loader, device)

Downloading...
From (original): https://drive.google.com/uc?id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
From (redirected): https://drive.google.com/uc?id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp&confirm=t&uuid=0e0bca9c-651b-44d5-aa2e-23d49ab9dd68
To: /storage/usersb/CoE_Shraddha/DinoDistilled/101_ObjectCategories.tar.gz
100%|███████████████████████████████████████| 132M/132M [00:01<00:00, 71.6MB/s]


File extracted successfully.
Training set size: 6943
Validation set size: 867
Test set size: 867


Using cache found in /home/users/CoE_Shraddha/.cache/torch/hub/facebookresearch_dinov2_main


Computing embeddings:   0%|          | 0/217 [00:00<?, ?it/s]

Computing embeddings:   0%|          | 0/28 [00:00<?, ?it/s]

Computing embeddings:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 1/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [1/10], Training Loss: 0.5042


Epoch 1/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [1/10], Validation Loss: 0.1233, Validation Accuracy: 95.95%
Best model saved to teacher_model/best_model_epoch_1.pth with Validation Loss: 0.1233


Epoch 2/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [2/10], Training Loss: 0.0661


Epoch 2/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [2/10], Validation Loss: 0.0905, Validation Accuracy: 96.30%
Best model saved to teacher_model/best_model_epoch_2.pth with Validation Loss: 0.0905


Epoch 3/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [3/10], Training Loss: 0.0458


Epoch 3/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [3/10], Validation Loss: 0.0972, Validation Accuracy: 96.30%


Epoch 4/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [4/10], Training Loss: 0.0371


Epoch 4/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [4/10], Validation Loss: 0.0918, Validation Accuracy: 96.64%


Epoch 5/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [5/10], Training Loss: 0.0326


Epoch 5/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [5/10], Validation Loss: 0.0917, Validation Accuracy: 96.76%


Epoch 6/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [6/10], Training Loss: 0.0237


Epoch 6/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [6/10], Validation Loss: 0.0928, Validation Accuracy: 96.76%


Epoch 7/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [7/10], Training Loss: 0.0256


Epoch 7/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [7/10], Validation Loss: 0.0970, Validation Accuracy: 97.11%


Epoch 8/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [8/10], Training Loss: 0.0236


Epoch 8/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [8/10], Validation Loss: 0.1028, Validation Accuracy: 96.88%


Epoch 9/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [9/10], Training Loss: 0.0213


Epoch 9/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [9/10], Validation Loss: 0.0971, Validation Accuracy: 97.22%


Epoch 10/10, Training:   0%|          | 0/216 [00:00<?, ?it/s]

Epoch [10/10], Training Loss: 0.0154


Epoch 10/10, Validating:   0%|          | 0/27 [00:00<?, ?it/s]

Epoch [10/10], Validation Loss: 0.1006, Validation Accuracy: 97.22%
Test Accuracy: 96.19%
-----------------Starting KD--------------------


Training Epoch 1/15:   0%|          | 0/217 [00:00<?, ?it/s]

Epoch 1, Training Loss: 0.7827, Validation Loss: 0.5526
New best model saved at epoch 1 with Validation Loss: 0.5526


Training Epoch 2/15:   0%|          | 0/217 [00:00<?, ?it/s]

Epoch 2, Training Loss: 0.5734, Validation Loss: 0.4300
New best model saved at epoch 2 with Validation Loss: 0.4300


Training Epoch 3/15:   0%|          | 0/217 [00:00<?, ?it/s]

Epoch 3, Training Loss: 0.4879, Validation Loss: 0.3661
New best model saved at epoch 3 with Validation Loss: 0.3661


Training Epoch 4/15:   0%|          | 0/217 [00:00<?, ?it/s]

Epoch 4, Training Loss: 0.4366, Validation Loss: 0.3241
New best model saved at epoch 4 with Validation Loss: 0.3241


Training Epoch 5/15:   0%|          | 0/217 [00:00<?, ?it/s]

Epoch 5, Training Loss: 0.4018, Validation Loss: 0.2992
New best model saved at epoch 5 with Validation Loss: 0.2992


Training Epoch 6/15:   0%|          | 0/217 [00:00<?, ?it/s]

Epoch 6, Training Loss: 0.3761, Validation Loss: 0.2794
New best model saved at epoch 6 with Validation Loss: 0.2794


Training Epoch 7/15:   0%|          | 0/217 [00:00<?, ?it/s]

In [None]:
# Directory for saved figures; `plot_dir` is also referenced by the loss-plot
# code in the knowledge-distillation cell.
plot_dir = Path('plots_new')
plot_dir.mkdir(parents=True, exist_ok=True)

# Save directory for models
# NOTE(review): this rebinds `model_save_dir`, which the distillation cell set
# to 'distilled_effv2' -- confirm which directory later cells should use.
model_save_dir = Path('distilled_effv2_new')
model_save_dir.mkdir(parents=True, exist_ok=True)

In [None]:
class EmbeddingClassifierV2(nn.Module):
    """Two-layer perceptron head: Linear -> ReLU -> Linear, returning logits."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        activated = self.relu(self.fc1(x))
        return self.fc2(activated)

def compute_embeddings(model, loader):
    """Embed every batch of ``loader`` with ``model`` in evaluation mode.

    Args:
        model: Network called on each image batch; switched to eval mode.
        loader: Iterable yielding ``(images, targets)`` tensor batches.

    Returns:
        A ``(embeddings, labels)`` tuple of NumPy arrays: the per-batch outputs
        stacked into one array, and the targets in the same order.
    """
    model.eval()
    embedding_batches, collected_labels = [], []
    with torch.no_grad():
        for images, targets in tqdm(loader, desc="Computing embeddings"):
            batch_output = model(images.to(device))
            embedding_batches.append(batch_output.cpu().numpy())
            collected_labels.extend(targets.cpu().numpy())
    stacked = np.vstack(embedding_batches)
    return stacked, np.array(collected_labels)

# Compute embeddings using the distilled student model
# These calls rebind the train/val/test embedding and label variables that
# earlier held the DINOv2 outputs.
train_embeddings, train_labels = compute_embeddings(student_model, train_loader)
val_embeddings, val_labels = compute_embeddings(student_model, val_loader)
test_embeddings, test_labels = compute_embeddings(student_model, test_loader)


In [None]:
def train_and_evaluate_classifier(train_embeddings, train_labels, val_embeddings, val_labels, test_embeddings, test_labels, num_classes,
                                  *, epochs=10, batch_size=32, lr=0.001, hidden_size=512,
                                  checkpoint_path='best_embedding_classifier_v2.pth'):
    """Train an EmbeddingClassifierV2 on precomputed embeddings and test it.

    The checkpoint with the lowest validation loss is written to
    ``checkpoint_path`` and reloaded before test accuracy is printed. The
    model runs on the module-level ``device``.

    Args:
        train_embeddings, val_embeddings, test_embeddings: Feature arrays with
            one row per sample.
        train_labels, val_labels, test_labels: Class-index arrays aligned with
            the corresponding embeddings.
        num_classes: Number of output classes.
        epochs: Number of training epochs.
        batch_size: Batch size for all three loaders.
        lr: Adam learning rate.
        hidden_size: Width of the classifier's hidden layer.
        checkpoint_path: Where the best state dict is written.

    Returns:
        The classifier with the best-validation weights loaded, so the caller
        can save or reuse it.
    """
    train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_embeddings), torch.tensor(train_labels))
    val_dataset = torch.utils.data.TensorDataset(torch.tensor(val_embeddings), torch.tensor(val_labels))
    test_dataset = torch.utils.data.TensorDataset(torch.tensor(test_embeddings), torch.tensor(test_labels))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = EmbeddingClassifierV2(train_embeddings.shape[1], hidden_size, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                total_val_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)

    # Testing the model
    # Reload the best checkpoint; map_location keeps this working when the
    # checkpoint file was written from a different device.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
            total += targets.size(0)
    test_accuracy = correct / total * 100
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    return model

# Call the function with proper arguments
# The embeddings and labels passed here are the ones produced by the student
# model in the previous cell.
train_and_evaluate_classifier(train_embeddings, train_labels, val_embeddings, val_labels, test_embeddings, test_labels, num_classes=101)


In [None]:
def save_model(model, path):
    """Write ``model``'s state dict to ``path`` and print the destination."""
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"Model saved to {path}")

# After training and validation in your training function
# `model` is local to train_and_evaluate_classifier and does not exist at this
# level, so rebuild the classifier from the best checkpoint that function
# wrote and save that under the final name.
model = EmbeddingClassifierV2(train_embeddings.shape[1], 512, 101).to(device)
model.load_state_dict(torch.load('best_embedding_classifier_v2.pth', map_location=device))
save_model(model, 'distilled_effnetV2_classifier.pth')

In [None]:
def train_and_evaluate_classifier(train_embeddings, train_labels, val_embeddings, val_labels, test_embeddings, test_labels, num_classes, device,
                                  *, epochs=10, batch_size=32, lr=0.001, hidden_size=512,
                                  checkpoint_path='best_embedding_classifier_v2.pth'):
    """Train an EmbeddingClassifierV2 on precomputed embeddings and report test metrics.

    The checkpoint with the lowest validation loss is written to
    ``checkpoint_path`` and reloaded before accuracy and macro-averaged
    precision, recall and F1 are printed for the test split.

    Args:
        train_embeddings, val_embeddings, test_embeddings: Feature arrays with
            one row per sample.
        train_labels, val_labels, test_labels: Class-index arrays aligned with
            the corresponding embeddings.
        num_classes: Number of output classes.
        device: Device the model and batches are placed on.
        epochs: Number of training epochs.
        batch_size: Batch size for all three loaders.
        lr: Adam learning rate.
        hidden_size: Width of the classifier's hidden layer.
        checkpoint_path: Where the best state dict is written.

    Returns:
        The classifier with the best-validation weights loaded.
    """
    train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_embeddings), torch.tensor(train_labels))
    val_dataset = torch.utils.data.TensorDataset(torch.tensor(val_embeddings), torch.tensor(val_labels))
    test_dataset = torch.utils.data.TensorDataset(torch.tensor(test_embeddings), torch.tensor(test_labels))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = EmbeddingClassifierV2(train_embeddings.shape[1], hidden_size, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                total_val_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)

    # Testing the model
    # map_location lets the checkpoint load onto `device` even if it was
    # written from a different one.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    predictions = []
    true_labels = []
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
            total += targets.size(0)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(targets.cpu().numpy())

    test_accuracy = correct / total * 100
    # zero_division=0 scores a class with no predicted (or no true) samples as
    # 0 without emitting an undefined-metric warning for each such class.
    precision = precision_score(true_labels, predictions, average='macro', zero_division=0)
    recall = recall_score(true_labels, predictions, average='macro', zero_division=0)
    f1 = f1_score(true_labels, predictions, average='macro', zero_division=0)
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    return model

# Example call to the function, ensure to pass `device`
# NOTE(review): this rebinds the module-level `device` to the default CUDA
# device (or CPU), which may differ from the device selected earlier in the
# notebook -- confirm that later cells do not mix tensors across devices.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_and_evaluate_classifier(train_embeddings, train_labels, val_embeddings, val_labels, test_embeddings, test_labels, num_classes=101, device=device)