In [29]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from transformers import ViTModel, ViTFeatureExtractor

# Define the MLP Classifier with configurable dropout
class MLPClassifier(nn.Module):
    def __init__(self, input_size, num_classes, hidden_units, dropout_rate):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_units)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_units, int(hidden_units / 2))
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate / 2)
        self.fc3 = nn.Linear(int(hidden_units / 2), num_classes)
        self.output = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        return self.output(x)

# Function to extract features using ViT
def extract_features(data_loader, model, device):
    model.eval()
    features = []
    labels = []
    with torch.no_grad():
        for imgs, lbls in data_loader:
            imgs = imgs.to(device)
            lbls = lbls.to(device)
            outputs = model(imgs).last_hidden_state[:, 0, :]
            features.append(outputs)
            labels.append(lbls)
    return torch.cat(features), torch.cat(labels)

# Load a pretrained Vision Transformer model
model_name = "google/vit-base-patch16-224"
model = ViTModel.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Image transformations
transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset and create splits
dataset = ImageFolder(root='images/train', transform=transform)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Hyperparameters to tune
batch_sizes = [16, 32, 64]
learning_rates = [0.001, 0.0005, 0.0001]
dropout_rates = [0.5, 0.3]
hidden_units = [512, 1024]

best_accuracy = 0
best_params = {}

# Testing different hyperparameters
for batch_size in batch_sizes:
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    train_features, train_labels = extract_features(train_loader, model, device)
    test_features, test_labels = extract_features(test_loader, model, device)

    for lr in learning_rates:
        for dropout in dropout_rates:
            for units in hidden_units:
                mlp = MLPClassifier(train_features.shape[1], len(dataset.classes), units, dropout).to(device)
                optimizer = optim.Adam(mlp.parameters(), lr=lr)
                criterion = nn.NLLLoss()
                scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

                # Training loop for current set of hyperparameters
                for epoch in range(10):
                    mlp.train()
                    total_loss = 0
                    for i in range(0, train_features.size(0), batch_size):
                        batch_features = train_features[i:i+batch_size].to(device)
                        batch_labels = train_labels[i:i+batch_size].to(device)

                        optimizer.zero_grad()
                        outputs = mlp(batch_features)
                        loss = criterion(outputs, batch_labels)
                        loss.backward()
                        optimizer.step()
                        total_loss += loss.item()

                    scheduler.step()

                    # Evaluate on test set
                    mlp.eval()
                    correct = 0
                    total = 0
                    with torch.no_grad():
                        for i in range(0, test_features.size(0), batch_size):
                            batch_features = test_features[i:i+batch_size].to(device)
                            batch_labels = test_labels[i:i+batch_size].to(device)
                            outputs = mlp(batch_features)
                            _, predicted = torch.max(outputs.data, 1)
                            total += batch_labels.size(0)
                            correct += (predicted == batch_labels).sum().item()

                    accuracy = 100 * correct / total
                    if accuracy > best_accuracy:
                        best_accuracy = accuracy
                        best_params = {
                            'batch_size': batch_size,
                            'learning_rate': lr,
                            'dropout_rate': dropout,
                            'hidden_units': units
                        }
                        torch.save(mlp.state_dict(), 'mlp_classifier_best.pth')

print(f"Best Hyperparameters: {best_params} with Accuracy: {best_accuracy}%")


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Best Hyperparameters: {'batch_size': 32, 'learning_rate': 0.0001, 'dropout_rate': 0.3, 'hidden_units': 1024} with Accuracy: 61.769297484822204%
