In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt


In [2]:
class Generator(nn.Module):
    def __init__(self, z_dim=100, img_dim=28*28):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)


class Discriminator(nn.Module):
    def __init__(self, img_dim=28*28):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)


In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# Splitting dataset into labeled and unlabeled sets
labeled_size = 1000
unlabeled_size = len(dataset) - labeled_size
labeled_data, unlabeled_data = random_split(dataset, [labeled_size, unlabeled_size])

labeled_loader = DataLoader(labeled_data, batch_size=64, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_data, batch_size=64, shuffle=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|████████████████████████████| 9912422/9912422 [00:01<00:00, 5598890.78it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|█████████████████████████████████| 28881/28881 [00:00<00:00, 195179.62it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|████████████████████████████| 1648877/1648877 [00:00<00:00, 1657546.28it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████| 4542/4542 [00:00<00:00, 8058599.31it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

z_dim = 100
img_dim = 28*28
lr = 0.0002
num_epochs = 50

gen = Generator(z_dim).to(device)
disc = Discriminator(img_dim).to(device)

opt_gen = optim.Adam(gen.parameters(), lr=lr)
opt_disc = optim.Adam(disc.parameters(), lr=lr)

criterion = nn.BCELoss()

for epoch in range(num_epochs):
    for (real, _), _ in zip(labeled_loader, unlabeled_loader):
        real = real.view(-1, 28*28).to(device)
        batch_size = real.shape[0]

        # Train Discriminator
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward(retain_graph=True)
        opt_disc.step()

        # Train Generator
        output = disc(fake).view(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")

# Save the trained models
torch.save(gen.state_dict(), "generator.pth")
torch.save(disc.state_dict(), "discriminator.pth")


Epoch [1/50] Loss D: 0.4546, loss G: 0.5527
Epoch [2/50] Loss D: 0.3674, loss G: 0.7756
Epoch [3/50] Loss D: 0.4877, loss G: 0.6377
Epoch [4/50] Loss D: 0.4712, loss G: 0.7457
Epoch [5/50] Loss D: 0.3801, loss G: 0.9248
Epoch [6/50] Loss D: 0.4446, loss G: 0.8223
Epoch [7/50] Loss D: 0.4110, loss G: 0.9564
Epoch [8/50] Loss D: 0.2840, loss G: 1.3797
Epoch [9/50] Loss D: 0.1410, loss G: 1.9801
Epoch [10/50] Loss D: 0.1162, loss G: 2.1833
Epoch [11/50] Loss D: 0.1425, loss G: 1.9609
Epoch [12/50] Loss D: 0.1751, loss G: 1.8019
Epoch [13/50] Loss D: 0.1314, loss G: 2.0619
Epoch [14/50] Loss D: 0.1118, loss G: 2.1460
Epoch [15/50] Loss D: 0.1652, loss G: 1.7613
Epoch [16/50] Loss D: 0.2731, loss G: 1.2973
Epoch [17/50] Loss D: 0.4354, loss G: 0.9109
Epoch [18/50] Loss D: 0.2954, loss G: 1.2474
Epoch [19/50] Loss D: 0.5912, loss G: 0.7116
Epoch [20/50] Loss D: 0.6709, loss G: 0.6034
Epoch [21/50] Loss D: 0.5172, loss G: 0.7640
Epoch [22/50] Loss D: 0.3585, loss G: 1.1538
Epoch [23/50] Loss 

In [5]:
class DiscriminatorWithClassifier(nn.Module):
    def __init__(self, img_dim=28*28, num_classes=10):
        super(DiscriminatorWithClassifier, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 256),
            nn.ReLU()
        )
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        features = self.disc(x)
        class_logits = self.classifier(features)
        return class_logits

classifier = DiscriminatorWithClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=lr)

# Train the classifier
for epoch in range(num_epochs):
    for real, labels in labeled_loader:
        real = real.view(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = classifier(real)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {loss:.4f}")

# Evaluate the classifier
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

classifier.eval()
correct = 0
total = 0
with torch.no_grad():
    for real, labels in test_loader:
        real = real.view(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = classifier(real)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {100 * correct / total:.2f}%")


Epoch [1/50] Loss: 1.9339
Epoch [2/50] Loss: 1.4965
Epoch [3/50] Loss: 1.1181
Epoch [4/50] Loss: 1.0773
Epoch [5/50] Loss: 0.8905
Epoch [6/50] Loss: 0.7065
Epoch [7/50] Loss: 0.7730
Epoch [8/50] Loss: 0.5604
Epoch [9/50] Loss: 0.7145
Epoch [10/50] Loss: 0.6730
Epoch [11/50] Loss: 0.3918
Epoch [12/50] Loss: 0.5106
Epoch [13/50] Loss: 0.4982
Epoch [14/50] Loss: 0.5505
Epoch [15/50] Loss: 0.4640
Epoch [16/50] Loss: 0.2667
Epoch [17/50] Loss: 0.5339
Epoch [18/50] Loss: 0.2827
Epoch [19/50] Loss: 0.3184
Epoch [20/50] Loss: 0.4288
Epoch [21/50] Loss: 0.2706
Epoch [22/50] Loss: 0.2987
Epoch [23/50] Loss: 0.2334
Epoch [24/50] Loss: 0.2488
Epoch [25/50] Loss: 0.1967
Epoch [26/50] Loss: 0.5042
Epoch [27/50] Loss: 0.3402
Epoch [28/50] Loss: 0.2527
Epoch [29/50] Loss: 0.3566
Epoch [30/50] Loss: 0.2852
Epoch [31/50] Loss: 0.1853
Epoch [32/50] Loss: 0.2440
Epoch [33/50] Loss: 0.2029
Epoch [34/50] Loss: 0.2275
Epoch [35/50] Loss: 0.3487
Epoch [36/50] Loss: 0.2356
Epoch [37/50] Loss: 0.1936
Epoch [38/