In [2]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn import svm
from sklearn.metrics import accuracy_score

In [3]:
# Define transformation to convert images to tensors and normalize
transform = transforms.Compose([
    transforms.ToTensor()
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Filter for airplane (0) and bird (2) classes
classes_to_keep = [0, 2]
train_indices = [i for i, label in enumerate(train_dataset.targets) if label in classes_to_keep]
test_indices = [i for i, label in enumerate(test_dataset.targets) if label in classes_to_keep]

train_subset = Subset(train_dataset, train_indices)
test_subset = Subset(test_dataset, test_indices)

In [5]:
# Create data loaders
train_loader = DataLoader(train_subset, batch_size=len(train_subset), shuffle=False)
test_loader = DataLoader(test_subset, batch_size=len(test_subset), shuffle=False)

# Extract data
x_train, y_train = next(iter(train_loader))
x_test, y_test = next(iter(test_loader))

# Flatten images into vectors
x_train_flat = x_train.view(x_train.size(0), -1).numpy()
x_test_flat = x_test.view(x_test.size(0), -1).numpy()

# Normalize pixel values
x_train_flat = x_train_flat / 255.0
x_test_flat = x_test_flat / 255.0

In [6]:
# Relabel classes to 0 and 1 for easier interpretation
y_train = y_train.numpy()
y_test = y_test.numpy()
y_train_binary = (y_train == 2).astype(int)  # 0: airplane, 1: bird
y_test_binary = (y_test == 2).astype(int)

# Initialize and train SVM
clf = svm.SVC(kernel='linear')
clf.fit(x_train_flat, y_train_binary)

# Predict and evaluate
y_pred = clf.predict(x_test_flat)
accuracy = accuracy_score(y_test_binary, y_pred)

print(f"SVM accuracy on airplane vs bird: {accuracy * 100:.2f}%")

SVM accuracy on airplane vs bird: 75.85%
