In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from skimage.segmentation import slic
import matplotlib.pyplot as plt

# Step 1: Load CIFAR-10 dataset
# Images are converted to tensors and then normalized per channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Training split first, then the held-out test split; both share the same
# preprocessing and are downloaded into ./data when missing.
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Step 2: Generate superpixels for each image
def generate_superpixels(image):
    """Segment one image into SLIC superpixels.

    Args:
        image: Tensor in channel-first ``(C, H, W)`` layout (the layout the
            dataset transform in this file produces).

    Returns:
        The label map returned by ``slic``: one integer superpixel id per
        pixel, with spatial shape ``(H, W)``.
    """
    pixels = image.detach().cpu()
    if pixels.dim() == 3:
        # slic() treats the LAST axis as the channel axis, while the tensors
        # in this file are channel-first. Without this transpose the image is
        # segmented as a (C, H) picture with W "channels".
        pixels = pixels.permute(1, 2, 0)
    # n_segments is only the approximate number of superpixels requested;
    # the number actually returned can differ from image to image.
    return slic(pixels.contiguous().numpy(), n_segments=100, compactness=10)

# Step 3: Extract features from superpixels
def extract_superpixel_features(image, superpixels, num_superpixels=100):
    """Compute the mean colour of every superpixel of one image.

    Args:
        image: Tensor of shape ``(C, H, W)``.
        superpixels: Integer label map of shape ``(H, W)`` (NumPy array or
            tensor), e.g. the output of ``generate_superpixels``.
        num_superpixels: Number of rows in the result. Images with fewer
            superpixels are zero-padded and images with more are truncated,
            so that a batch can be stacked and fed to a fixed-size layer.
            Pass ``None`` to get one row per superpixel actually present.

    Returns:
        Tensor of shape ``(num_superpixels, C)`` (or ``(K, C)`` when
        ``num_superpixels`` is ``None``). Rows are ordered by ascending
        label id.

    Raises:
        ValueError: If the label map does not have one label per pixel.
    """
    channels = image.shape[0]
    # One row per pixel, one column per channel.
    pixels = image.reshape(channels, -1).t()
    labels = torch.as_tensor(
        superpixels, dtype=torch.long, device=image.device
    ).reshape(-1)
    if labels.numel() != pixels.shape[0]:
        raise ValueError(
            "superpixel map has {} labels but the image has {} pixels".format(
                labels.numel(), pixels.shape[0]))

    # Renumber the labels that actually occur to 0..K-1. Label ids may start
    # at 0 or 1 depending on the SLIC implementation, and iterating over
    # range(max + 1) would average an empty region (NaN) for any unused id.
    _, compact = torch.unique(labels, return_inverse=True)
    found = int(compact.max()) + 1 if compact.numel() else 0

    # Per-superpixel channel sums and pixel counts in a single pass.
    sums = torch.zeros(found, channels, dtype=pixels.dtype, device=pixels.device)
    sums.index_add_(0, compact, pixels)
    counts = torch.bincount(compact, minlength=found).unsqueeze(1)
    means = sums / counts

    if num_superpixels is None:
        return means

    # Fixed-size output: zero rows pad missing superpixels, extras are dropped.
    features = means.new_zeros(num_superpixels, channels)
    kept = min(found, num_superpixels)
    features[:kept] = means[:kept]
    return features

# Step 4: Define the model (a two-layer MLP over superpixel features)
class SuperpixelCNN(nn.Module):
    """Fully connected classifier over per-superpixel feature vectors.

    Despite the name there are no convolutions: the input is flattened and
    passed through two linear layers with a ReLU in between.

    Args:
        num_superpixels: Number of superpixels per image.
        num_features: Number of features per superpixel.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output classes.
    """

    def __init__(self, num_superpixels=100, num_features=3, hidden_dim=128,
                 num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(num_superpixels * num_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Every dimension of ``x`` after the first is flattened, so the input
        must hold ``num_superpixels * num_features`` values per sample.
        """
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Step 5: Training
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
criterion = nn.CrossEntropyLoss()
model = SuperpixelCNN()
model = model.to(device)
# Plain SGD with a fixed learning rate.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

def train(model, train_loader, optimizer, criterion, epoch=None):
    """Run one training pass over ``train_loader``.

    Args:
        model: Network that maps stacked superpixel features to class logits.
        train_loader: Iterable of ``(images, targets)`` batches; must support
            ``len()`` and expose a ``dataset`` attribute.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss called as ``criterion(output, target)``.
        epoch: Epoch number shown in the progress line. When omitted, a
            module-level ``epoch`` variable is used if one exists (this is
            how the function used to obtain it), otherwise ``"?"``.
    """
    if epoch is None:
        # Backward compatibility: earlier callers relied on a global loop
        # variable named `epoch` instead of passing it in.
        epoch = globals().get("epoch", "?")

    # Follow the model's own device rather than a global `device` variable.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Images stay on the CPU because the superpixel step works on NumPy
        # arrays; only the extracted features are moved to the model device.
        data, target = data.to("cpu"), target.to(model_device)
        optimizer.zero_grad()
        superpixels = [generate_superpixels(img) for img in data]
        features = [extract_superpixel_features(img, sp) for img, sp in zip(data, superpixels)]
        features = torch.stack(features).to(model_device)
        output = model(features)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

# Step 6: Evaluation
def test(model, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        model: Network that maps stacked superpixel features to class logits.
        test_loader: Iterable of ``(images, targets)`` batches exposing a
            ``dataset`` attribute.
        criterion: Loss called as ``criterion(output, target)``.
            NOTE(review): the average below assumes it returns the MEAN loss
            over the batch (the default for ``nn.CrossEntropyLoss``).
    """
    # Follow the model's own device rather than a global `device` variable.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # Images stay on the CPU for the NumPy-based superpixel step. The
            # targets must live on the same device as the model output, or
            # the loss computation fails when the model is on a GPU.
            data, target = data.to("cpu"), target.to(model_device)
            superpixels = [generate_superpixels(img) for img in data]
            features = [extract_superpixel_features(img, sp) for img, sp in zip(data, superpixels)]
            features = torch.stack(features).to(model_device)
            output = model(features)
            # Weight each batch mean by its size so that dividing by the
            # dataset size below yields a true per-sample average.
            test_loss += criterion(output, target).item() * len(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    # Guard against an empty dataset to avoid a division by zero.
    num_samples = max(len(test_loader.dataset), 1)
    test_loss /= num_samples
    accuracy = 100. * correct / num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))

# Training loop
# Batches of 64 images; only the training data is reshuffled between epochs.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
)

# Ten epochs, numbered from 1; each epoch trains once and then evaluates.
# The loop variable must stay named `epoch`: train() reads it for its
# progress output.
for epoch in range(1, 10 + 1):
    train(model, train_loader, optimizer, criterion)
    test(model, test_loader, criterion)


Files already downloaded and verified
Files already downloaded and verified


KeyboardInterrupt: 

: 