In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset

In [None]:

class PseudoLabelDataset(Dataset):
    """Image-only view of a labeled dataset, used as the unlabeled split.

    Every item of the wrapped dataset must unpack into exactly two values
    ``(image, label)``; the label is thrown away and the image tensor is
    returned after being moved to ``device``.

    NOTE(review): calling ``.to(device)`` inside ``__getitem__`` is only safe
    with ``DataLoader(num_workers=0)`` when ``device`` is CUDA -- CUDA cannot
    be initialized in forked worker processes.
    """

    def __init__(self, dataset, device):
        # Wrapped source dataset; indexable and sized.
        self.dataset = dataset
        # Target device every returned image is moved to.
        self.device = device

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, _discarded_label = self.dataset[index]
        return image.to(self.device)

In [None]:
import torch.nn as nn

class PseudoLabelNet(nn.Module):
    """Single-hidden-layer MLP with independent per-class sigmoid outputs.

    Pipeline: flatten -> Linear(input_size, hidden_size) -> ReLU ->
    Dropout(p=0.5) -> Linear(hidden_size, num_classes) -> Sigmoid.

    The sigmoid makes every output an independent value in (0, 1), which is
    the input range ``nn.BCELoss`` requires.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Registration order is kept fixed so parameter / state_dict keys are stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Flatten every non-batch dimension; the product must equal input_size.
        flat = x.view(x.size(0), -1)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.sigmoid(self.fc2(hidden))

In [None]:
def _evaluate(model, loader, device):
    """Return top-1 accuracy (in percent) of ``model`` over ``loader``.

    ``loader`` yields ``(images, labels)`` batches; the prediction is the
    argmax of the model output over dim 1. Puts the model in eval mode.
    Returns NaN when the loader yields no samples instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total if total else float("nan")


def train(model, labeled_loader, unlabeled_loader, test_loader, device, epochs,
          alpha_schedule, lr=1.5, momentum=0.9):
    """Train ``model`` with supervised BCE plus weighted pseudo-label BCE.

    Each epoch has two phases, each taking one SGD step per mini-batch:

    1. Supervised: BCE between the model outputs and one-hot targets built
       from the integer labels of each ``(x, y)`` batch in ``labeled_loader``.
    2. Pseudo-label: for each batch ``x`` in ``unlabeled_loader`` the current
       argmax prediction is turned into a one-hot target, and the BCE against
       it is scaled by ``alpha = alpha_schedule(epoch)``. The phase is skipped
       when ``alpha`` is 0.

    After every epoch the mean per-batch losses and the accuracy on
    ``test_loader`` are printed.

    Args:
        model: module whose output is ``(batch, num_classes)`` with values in
            [0, 1] (required by ``nn.BCELoss``).
        labeled_loader: iterable of ``(x, y)`` batches, ``y`` integer class ids.
        unlabeled_loader: iterable of input batches ``x`` (no labels).
        test_loader: iterable of ``(images, labels)`` batches for evaluation.
        device: device every batch is moved to.
        epochs: number of epochs to run.
        alpha_schedule: callable mapping the 0-based epoch to the pseudo-label
            loss weight.
        lr: SGD learning rate (default 1.5).
        momentum: SGD momentum (default 0.9).
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        model.train()

        # Supervised phase: one optimizer step per labeled batch.
        labeled_sum = 0.0
        for x, y in labeled_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            y_onehot = torch.zeros_like(output).scatter_(1, y.unsqueeze(1), 1.0)
            loss = criterion(output, y_onehot)
            loss.backward()
            optimizer.step()
            # .item() keeps only the number, so no autograd graph outlives the batch.
            labeled_sum += loss.item()
        labeled_loss = labeled_sum / max(len(labeled_loader), 1)

        # Pseudo-label phase.
        alpha = alpha_schedule(epoch)
        unlabeled_sum = 0.0
        # With alpha == 0 the loss has zero gradient, yet an SGD step would still
        # move the weights via the momentum buffer -- so skip the phase entirely.
        if alpha > 0:
            for x in unlabeled_loader:
                x = x.to(device)
                optimizer.zero_grad()
                output = model(x)
                # Pseudo-labels are fixed targets: one-hot of the current prediction,
                # sized from the output so the class count is not hard-coded.
                with torch.no_grad():
                    pseudo_labels = torch.zeros_like(output).scatter_(
                        1, output.argmax(dim=1, keepdim=True), 1.0)
                loss = alpha * criterion(output, pseudo_labels)
                loss.backward()
                optimizer.step()
                unlabeled_sum += loss.item()
        unlabeled_loss = unlabeled_sum / max(len(unlabeled_loader), 1)

        # Evaluate on test set
        test_accuracy = _evaluate(model, test_loader, device)

        print(f'Epoch [{epoch+1}/{epochs}], Labeled Loss: {labeled_loss:.4f}, Unlabeled Loss: {unlabeled_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')


In [None]:
def main(seed=0):
    """Run the MNIST pseudo-label experiment end to end.

    Downloads MNIST into ./data. The 60k training images are split into a
    small labeled subset and a disjoint unlabeled remainder. Then
    ``PseudoLabelNet`` is trained with a ramped pseudo-label weight, and the
    final test accuracy is printed.

    Args:
        seed: seed for torch's RNG, so the labeled/unlabeled split, weight
            initialization, dropout and shuffling are reproducible.
    """
    torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Experiment configuration.
    n_labeled = 600      # size of the labeled subset; the rest is unlabeled
    hidden_size = 5000
    epochs = 50
    T1 = 10              # pseudo-label weight is 0 before this epoch
    T2 = 40              # weight reaches alpha_max at this epoch
    alpha_max = 3

    # Load MNIST dataset
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_set = MNIST(root='./data', train=True, download=True, transform=transform)
    test_set = MNIST(root='./data', train=False, download=True, transform=transform)

    # Split with ONE permutation so the labeled and unlabeled subsets are disjoint
    # and together cover the whole training set (two independent randperm calls
    # would make them overlap).
    perm = torch.randperm(len(train_set))
    labeled_dataset = torch.utils.data.Subset(train_set, perm[:n_labeled])
    unlabeled_dataset = torch.utils.data.Subset(train_set, perm[n_labeled:])

    # Create dataloaders
    labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)
    unlabeled_loader = DataLoader(PseudoLabelDataset(unlabeled_dataset, device), batch_size=256, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=1000, shuffle=False)

    # Create model and train
    model = PseudoLabelNet(28*28, hidden_size, 10).to(device)

    def alpha_schedule(epoch):
        """Pseudo-label weight: 0 before T1, linear ramp to alpha_max at T2, then constant."""
        if epoch < T1:
            return 0
        if epoch < T2:
            return ((epoch - T1) / (T2 - T1)) * alpha_max
        return alpha_max

    train(model, labeled_loader, unlabeled_loader, test_loader, device, epochs, alpha_schedule)

    # Final evaluation on the test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy on test set: %d %%' % (100 * correct / total))


if __name__ == "__main__":
    main()

cuda
Epoch [1/50], Labeled Loss: 0.7296, Unlabeled Loss: 0.0000, Test Accuracy: 9.74%
Epoch [2/50], Labeled Loss: 2.5380, Unlabeled Loss: 0.0000, Test Accuracy: 10.49%
Epoch [3/50], Labeled Loss: 2.0290, Unlabeled Loss: 0.0000, Test Accuracy: 39.19%
Epoch [4/50], Labeled Loss: 0.3259, Unlabeled Loss: 0.0000, Test Accuracy: 61.99%
Epoch [5/50], Labeled Loss: 0.2867, Unlabeled Loss: 0.0000, Test Accuracy: 62.78%
Epoch [6/50], Labeled Loss: 0.2500, Unlabeled Loss: 0.0000, Test Accuracy: 63.66%
Epoch [7/50], Labeled Loss: 0.1553, Unlabeled Loss: 0.0000, Test Accuracy: 69.16%
Epoch [8/50], Labeled Loss: 0.1577, Unlabeled Loss: 0.0000, Test Accuracy: 73.73%
Epoch [9/50], Labeled Loss: 0.1221, Unlabeled Loss: 0.0000, Test Accuracy: 79.78%
Epoch [10/50], Labeled Loss: 0.1134, Unlabeled Loss: 0.0000, Test Accuracy: 79.74%
Epoch [11/50], Labeled Loss: 0.0986, Unlabeled Loss: 0.0000, Test Accuracy: 77.39%
Epoch [12/50], Labeled Loss: 0.0872, Unlabeled Loss: 0.0058, Test Accuracy: 78.61%
Epoch [13