In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
import numpy as np
import timm
import pandas as pd
from sklearn import svm
from sklearn.metrics import accuracy_score

In [None]:
# Fine-tune a pretrained ViT-B/16 on the ImageFolder dataset in `data_dir`.
SEED = 42
torch.manual_seed(SEED)  # makes DataLoader shuffling and the new head's initialisation repeatable

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Standard ImageNet mean/std normalisation.
# NOTE(review): timm's pretrained ViT weights may expect different mean/std
# (see model.pretrained_cfg) — confirm these values match the checkpoint.
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
data_dir = "temp"
dataset = ImageFolder(data_dir, transform=data_transforms)

# Split *indices*, not the dataset: train_test_split on an ImageFolder would load and
# transform every image into a list up front. Splitting np.arange(len(dataset)) with the
# same random_state produces the same partition, and Subset keeps image loading lazy.
train_idx, test_idx = train_test_split(np.arange(len(dataset)), test_size=0.2, random_state=SEED)
train_dataset = torch.utils.data.Subset(dataset, train_idx.tolist())
test_dataset = torch.utils.data.Subset(dataset, test_idx.tolist())

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Size the classifier head to this dataset's classes. Without num_classes the model keeps
# its 1000-way ImageNet head and can predict class indices that do not exist here.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=len(dataset.classes))
model.to(device)

num_epochs = 80
learning_rate = 0.0001
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

total_steps = len(train_dataloader)  # batches per epoch
for epoch in range(num_epochs):
    model.train()  # explicit training mode for every epoch
    running_loss = 0.0
    for images, labels in train_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean over all batches rather than only the (noisy) final batch.
    print(f"Epoch {epoch + 1}, Mean loss: {running_loss / max(total_steps, 1)}")
# Embed both splits with the fine-tuned backbone, score the model's own classification head
# on the test split, and cache the embeddings as CSV for the SVM cell that follows.
test_predictions = []
train_features = []
train_labels = []
test_features = []
test_labels = []

model.eval()
with torch.no_grad():
    for images, labels in train_dataloader:
        tokens = model.forward_features(images.to(device))
        # For ViT, forward_features presumably returns token-level outputs (batch, tokens, dim).
        # forward_head(pre_logits=True) pools them into one vector per image, so each CSV row
        # is one sample. NOTE(review): relies on timm's forward_head API — confirm the version.
        pooled = model.forward_head(tokens, pre_logits=True)
        train_features.extend(pooled.cpu().numpy())
        # Labels stay on the CPU; no need to move them to the device and back.
        train_labels.extend(labels.numpy())

    for images, labels in test_dataloader:
        tokens = model.forward_features(images.to(device))
        pooled = model.forward_head(tokens, pre_logits=True)
        # Reuse the backbone output rather than running model(images) a second time.
        outputs = model.forward_head(tokens)
        test_features.extend(pooled.cpu().numpy())
        test_labels.extend(labels.numpy())
        test_predictions.extend(outputs.argmax(dim=1).cpu().numpy())

# Convert once, after all batches: converting inside the loop would turn the list into an
# ndarray and make the next batch's .extend() raise AttributeError.
test_predictions = np.array(test_predictions)
acc = accuracy_score(np.array(test_labels), test_predictions)
print(acc)

# Persist features/labels in the layout the SVM cell reads (labels under a 'label' column).
pd.DataFrame(np.asarray(train_features)).to_csv("vit_train_features.csv", index=False)
pd.DataFrame({"label": train_labels}).to_csv("vit_train_labels.csv", index=False)
pd.DataFrame(np.asarray(test_features)).to_csv("vit_test_features.csv", index=False)
pd.DataFrame({"label": test_labels}).to_csv("vit_test_labels.csv", index=False)


In [None]:
# Train an SVM (sklearn default hyperparameters) on the cached ViT features and report
# accuracy on the held-out split. Each *_labels CSV is expected to have a 'label' column;
# presumably each *_features CSV row is one sample's embedding — verify against the writer.
train_features_df = pd.read_csv("vit_train_features.csv")
train_labels_df = pd.read_csv("vit_train_labels.csv")
test_features_df = pd.read_csv("vit_test_features.csv")
test_labels_df = pd.read_csv("vit_test_labels.csv")

train_features, test_features = train_features_df.to_numpy(), test_features_df.to_numpy()
train_labels, test_labels = (
    frame['label'].to_numpy() for frame in (train_labels_df, test_labels_df)
)

clf = svm.SVC().fit(train_features, train_labels)  # fit() returns the estimator itself
predictions = clf.predict(test_features)
accuracy = accuracy_score(y_true=test_labels, y_pred=predictions)
print("Accuracy:", accuracy)
