In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset

class PseudoLabelNet(nn.Module):
    """Two-layer MLP that emits an independent probability per class.

    Pipeline: flatten -> Linear(input_size, hidden_size) -> ReLU ->
    Dropout(p=0.5) -> Linear(hidden_size, num_classes) -> Sigmoid.

    The head is a sigmoid rather than a softmax, so each output lies in
    (0, 1) on its own and a row is not constrained to sum to 1. This makes
    the outputs suitable for ``nn.BCELoss`` against one-hot targets.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Submodules are registered in this order so that parameter
        # initialisation draws from the RNG in a fixed sequence and the
        # state_dict keys are fc1.* / fc2.*.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-class probabilities with shape (batch, num_classes).

        Every dimension after the first is flattened. Each sample must
        therefore contain exactly ``input_size`` elements; for example, a
        (B, 1, 28, 28) image batch needs input_size == 784 (assumed caller
        usage).
        """
        flat = x.view(x.size(0), -1)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.sigmoid(self.fc2(hidden))

def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    The predicted class is the argmax over dim 1 of the model output.
    Returns 0.0 for an empty loader instead of dividing by zero. Leaves the
    model in eval mode.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def train_supervised(model, labeled_loader, test_loader, device, epochs, *, lr=0.1, momentum=0.9):
    """Train ``model`` with mini-batch SGD on labeled data; report test accuracy each epoch.

    The loss is ``nn.BCELoss`` between the model output and a one-hot
    encoding of the integer labels. The model is therefore expected to emit
    probabilities in [0, 1] with shape (batch, num_classes).

    Args:
        model: Module mapping an input batch to per-class probabilities.
        labeled_loader: Iterable of ``(x, y)`` batches, where ``y`` holds
            integer class indices.
        test_loader: Iterable of ``(images, labels)`` batches used for
            evaluation after every epoch.
        device: Device that each batch is moved to.
        epochs: Number of passes over ``labeled_loader``.
        lr: SGD learning rate. The previous hard-coded 1.5 was only
            tolerable because a single step was taken per epoch; with
            per-batch updates a much smaller rate is needed.
        momentum: SGD momentum.

    Returns:
        Test accuracy (percent) after the last epoch, or 0.0 if ``epochs``
        is 0.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.BCELoss()
    test_accuracy = 0.0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0

        # One optimisation step per mini-batch. Each batch's graph is freed
        # right after backward(), instead of all batches being summed into
        # one graph and stepped once per epoch.
        for x, y in labeled_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            y_onehot = torch.zeros_like(output).scatter_(1, y.unsqueeze(1), 1)
            loss = criterion(output, y_onehot)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        labeled_loss = running_loss / num_batches if num_batches else 0.0

        test_accuracy = _evaluate_accuracy(model, test_loader, device)

        print(f'Epoch [{epoch+1}/{epochs}], Labeled Loss: {labeled_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

    return test_accuracy

def main(seed=None):
    """Train PseudoLabelNet on a small labeled subset of MNIST and report test accuracy.

    Downloads MNIST into ``./data`` if it is not already present. It then
    draws 600 random training images as the labeled set, trains for 50
    epochs with ``train_supervised``, and prints the final test-set
    accuracy.

    Args:
        seed: Optional integer seed for drawing the labeled subset. ``None``
            (the default) draws from PyTorch's global RNG, so the split
            differs between runs unless a global seed was set. Pass an int
            to make the split reproducible.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load MNIST; the Normalize constants are the commonly used MNIST pixel mean/std.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_set = MNIST(root='./data', train=True, download=True, transform=transform)
    test_set = MNIST(root='./data', train=False, download=True, transform=transform)

    # Keep a random subset of 600 training images as the labeled set. The
    # remaining training images are not used by this script.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    labeled_indices = torch.randperm(len(train_set), generator=generator)[:600].tolist()
    labeled_dataset = torch.utils.data.Subset(train_set, labeled_indices)

    labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=1000, shuffle=False)

    # 28 * 28 flattened pixels in, 10 digit classes out.
    model = PseudoLabelNet(28 * 28, 5000, 10).to(device)
    train_supervised(model, labeled_loader, test_loader, device, 50)

    # Final evaluation: top-1 accuracy via argmax over the class outputs.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Report two decimals, matching the per-epoch log, instead of truncating to an int.
    accuracy = 100 * correct / total if total else 0.0
    print(f'Accuracy on test set: {accuracy:.2f} %')

if __name__ == "__main__":
    # Run the training/evaluation experiment only when executed directly
    # (or as a notebook cell), not when this module is imported.
    main()

Epoch [1/50], Labeled Loss: 0.6995, Test Accuracy: 9.74%
Epoch [2/50], Labeled Loss: 2.4257, Test Accuracy: 9.82%
Epoch [3/50], Labeled Loss: 1.6666, Test Accuracy: 52.72%
Epoch [4/50], Labeled Loss: 0.2781, Test Accuracy: 65.31%
Epoch [5/50], Labeled Loss: 0.2863, Test Accuracy: 67.30%
Epoch [6/50], Labeled Loss: 0.2363, Test Accuracy: 68.31%
Epoch [7/50], Labeled Loss: 0.1447, Test Accuracy: 73.44%
Epoch [8/50], Labeled Loss: 0.1582, Test Accuracy: 78.94%
Epoch [9/50], Labeled Loss: 0.1214, Test Accuracy: 83.46%
Epoch [10/50], Labeled Loss: 0.1190, Test Accuracy: 82.50%
Epoch [11/50], Labeled Loss: 0.0828, Test Accuracy: 81.45%
Epoch [12/50], Labeled Loss: 0.0870, Test Accuracy: 83.90%
Epoch [13/50], Labeled Loss: 0.0672, Test Accuracy: 85.57%
Epoch [14/50], Labeled Loss: 0.0589, Test Accuracy: 86.55%
Epoch [15/50], Labeled Loss: 0.0443, Test Accuracy: 87.56%
Epoch [16/50], Labeled Loss: 0.0364, Test Accuracy: 87.67%
Epoch [17/50], Labeled Loss: 0.0362, Test Accuracy: 88.06%
Epoch [1