<a href="https://colab.research.google.com/github/NikosKats/ColabFiles/blob/Uncertainty-Sampling.ipynb/Uncertainty_Sampling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler

# Load the CIFAR-10 dataset
# Training images are augmented (random 32x32 crop after 4px padding, random
# horizontal flip) and then normalised per channel with mean 0.5 / std 0.5.
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Evaluation images get the same tensor conversion and normalisation but no
# random augmentation, so test accuracy is measured on unmodified images.
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

cifar_dataset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
cifar_testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=test_transform)

# The training and evaluation code below iterates over `trainloader` and
# `testloader`; they were previously never defined, which raised a NameError.
batch_size = 64
trainloader = torch.utils.data.DataLoader(cifar_dataset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(cifar_testset, batch_size=batch_size, shuffle=False)

# Define the neural network
class Net(nn.Module):
    """Small convolutional classifier producing 10 class scores per image.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed
    three fully connected layers. The flatten step assumes the pooled
    feature map is 16 x 5 x 5, which is what a 3-channel 32x32 input
    produces (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (un-normalised) class scores for the images in ``x``.

        Args:
            x: Image tensor whose trailing dimensions are (3, 32, 32).

        Returns:
            Tensor with one row of 10 scores per image.
        """
        # ReLU is reached through the already-imported ``nn`` alias: the file
        # never imports ``torch.nn.functional as F``, so ``F.relu`` raised a
        # NameError on the first forward pass.
        relu = nn.functional.relu
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        # Collapse channels and spatial positions into one feature vector.
        x = x.view(-1, 16 * 5 * 5)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network instance that the rest of the script trains and evaluates.
net = Net()

# Loss: cross-entropy between the network's class scores and the labels.
criterion = nn.CrossEntropyLoss()

# Optimizer: SGD with momentum over all parameters of `net`.
optimizer = torch.optim.SGD(
    params=net.parameters(),
    lr=1e-3,
    momentum=0.9,
)

# Define the uncertainty sampling strategy
def uncertainty_sampling(model, dataset, num_samples):
    """Select the samples whose predictions have the highest entropy.

    Every sample yielded by ``dataset`` is scored with the Shannon entropy of
    the model's softmax output, and the ``num_samples`` highest-scoring
    samples are returned.

    Args:
        model: Callable mapping a batch of inputs to per-class scores with
            the class axis at dimension 1.
        dataset: Iterable of ``(inputs, labels)`` pairs. Pairs may be
            batches (as produced by a ``DataLoader``) or single samples with
            a scalar label (as produced by indexing a ``Dataset``).
        num_samples: Maximum number of samples to return.

    Returns:
        A list of ``(input, label, entropy)`` tuples, one per selected
        sample, ordered from most to least uncertain. ``input`` has no batch
        axis and ``label`` / ``entropy`` are 0-dim tensors, so the inputs and
        labels can be combined with ``torch.stack``. The list is empty when
        ``num_samples`` is not positive or ``dataset`` yields nothing.
    """
    if num_samples <= 0:
        return []

    input_chunks, label_chunks, entropy_chunks = [], [], []
    with torch.no_grad():
        for inputs, labels in dataset:
            labels = torch.as_tensor(labels)
            if labels.dim() == 0:
                # Single unbatched sample: add a batch axis so the per-sample
                # bookkeeping below is uniform.
                inputs = inputs.unsqueeze(0)
                labels = labels.unsqueeze(0)
            outputs = model(inputs)
            # log_softmax avoids log(0) = -inf (and the resulting NaN from
            # 0 * -inf) when the softmax saturates.
            log_prob = torch.log_softmax(outputs, dim=1)
            entropy = -torch.sum(log_prob.exp() * log_prob, dim=1)
            input_chunks.append(inputs)
            label_chunks.append(labels)
            entropy_chunks.append(entropy)

    if not entropy_chunks:
        return []

    all_inputs = torch.cat(input_chunks)
    all_labels = torch.cat(label_chunks)
    all_entropy = torch.cat(entropy_chunks)

    # Rank individual samples (not whole batches); topk returns them in
    # descending entropy order.
    k = min(num_samples, all_entropy.numel())
    top_entropy, top_idx = torch.topk(all_entropy, k)
    return list(zip(all_inputs[top_idx], all_labels[top_idx], top_entropy))

# Train the model
# Each epoch is one active-learning round: the training pool is scored once,
# the most uncertain samples are selected, and the network is trained on that
# selection in shuffled mini-batches.
num_epochs = 10
num_samples_to_label = 500
train_batch_size = 50
for epoch in range(num_epochs):
    running_loss = 0

    # Obtain the most uncertain samples from the training set. This is done
    # once per epoch: scoring the whole pool again for every mini-batch made
    # an epoch cost (number of batches) x (full pass over the data).
    uncertain_samples = uncertainty_sampling(net, trainloader, num_samples_to_label)

    # Label the uncertain samples. The labels already stored in the dataset
    # stand in for a human annotator. Each selected item is reshaped to carry
    # a leading batch axis so single samples and whole batches concatenate
    # into one (N, C, H, W) input tensor and one (N,) label tensor.
    selected_inputs = torch.cat(
        [sample[0].reshape(-1, *sample[0].shape[-3:]) for sample in uncertain_samples]
    )
    selected_labels = torch.cat(
        [torch.as_tensor(sample[1]).reshape(-1) for sample in uncertain_samples]
    )

    # Train the model on the labeled samples. The selection arrives sorted by
    # uncertainty, so it is shuffled before being split into mini-batches.
    order = torch.randperm(selected_inputs.size(0))
    for i, batch_idx in enumerate(order.split(train_batch_size)):
        inputs, labels = selected_inputs[batch_idx], selected_labels[batch_idx]

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print the average training loss for the epoch
    print(f'Epoch {epoch + 1} loss: {running_loss / (i + 1)}')

print('Finished Training')

# Test the model on the test set
# Counts how many test images receive the correct top-scoring class.
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, targets in testloader:
        scores = net(images)
        # The predicted class is the index of the largest score in each row.
        predicted = scores.max(dim=1).indices
        matches = predicted == targets
        total += targets.size(0)
        correct += matches.sum().item()

print(f'Accuracy of the network on the test images: {100 * correct / total}%')


This code uses an entropy-based uncertainty sampling strategy to select the training samples the model is least certain about. In a real active-learning setup those samples would be sent to a human annotator for labelling; here the labels already present in the CIFAR-10 training set stand in for that manual step. The model is then trained on the selected samples, and its accuracy is measured on the test set.

Please note that this is a simple example and the results might not be optimal. You might want to adjust the parameters and try different techniques, for instance using a combination of uncertainty, margin, and entropy criteria, or using a deep Q-network strategy.