In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
from torchvision.utils import save_image
import torchvision

In [2]:

# Define the Generator
class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.output_dim = output_dim

        self.generator = nn.Sequential(
            nn.Linear(self.latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, self.output_dim),
            nn.Tanh()
        )

    def forward(self, z):
        return self.generator(z)


# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.input_dim = input_dim

        self.discriminator = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.discriminator(x)


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:

# Hyperparameters
batch_size = 64
latent_dim = 100
num_epochs_gan = 50
num_epochs_model = 10

# Prepare the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

# Function to train the GAN
def train_gan():
    generator = Generator(latent_dim, 28 * 28)
    discriminator = Discriminator(28 * 28)

    criterion = nn.BCELoss()
    generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)

    for epoch in range(num_epochs_gan):
        for batch_idx, (real_images, _) in enumerate(dataloader):
            real_images = real_images.view(-1, 28 * 28)

            # Train the discriminator with real images
            discriminator_optimizer.zero_grad()
            real_labels = torch.ones(real_images.size(0), 1)
            real_pred = discriminator(real_images)
            real_loss = criterion(real_pred, real_labels)
            real_loss.backward()

            # Train the discriminator with fake images (generated by the generator)
            noise = torch.randn(real_images.size(0), latent_dim)
            fake_images = generator(noise)
            fake_labels = torch.zeros(real_images.size(0), 1)
            fake_pred = discriminator(fake_images.detach())
            fake_loss = criterion(fake_pred, fake_labels)
            fake_loss.backward()

            discriminator_optimizer.step()

            # Train the generator to fool the discriminator
            generator_optimizer.zero_grad()
            noise = torch.randn(real_images.size(0), latent_dim)
            fake_images = generator(noise)
            fake_pred = discriminator(fake_images)
            generator_loss = criterion(fake_pred, real_labels)
            generator_loss.backward()

            generator_optimizer.step()
        if (epoch + 1) % 5 == 0:
            fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
            fake_fname = 'fake_training_sample/fake_images-{0:0=4d}.png'.format(epoch + 1)
            print('Saving', fake_fname)
            save_image(denorm(fake_images), fake_fname)
    torch.save(generator.state_dict(), "generator.pth")
    torch.save(discriminator.state_dict(), "discriminator.pth")
# Train the GAN
train_gan()



Saving fake_training_sample/fake_images-0005.png
Saving fake_training_sample/fake_images-0010.png
Saving fake_training_sample/fake_images-0015.png
Saving fake_training_sample/fake_images-0020.png
Saving fake_training_sample/fake_images-0025.png
Saving fake_training_sample/fake_images-0030.png
Saving fake_training_sample/fake_images-0035.png
Saving fake_training_sample/fake_images-0040.png
Saving fake_training_sample/fake_images-0045.png
Saving fake_training_sample/fake_images-0050.png


In [5]:


# Set random seed for reproducibility
torch.manual_seed(42)

# Hyperparameters
batch_size = 64
learning_rate = 0.003
num_epochs = 15

# MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Model
class DigitRecognizer(nn.Module):
    def __init__(self):
        super(DigitRecognizer, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten the input image
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = DigitRecognizer()

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_step}], Loss: {loss.item():.4f}")

print("Training finished!")

# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Accuracy on the test set: {(100 * correct / total):.2f}%")


Epoch [1/15], Step [100/938], Loss: 0.3914
Epoch [1/15], Step [200/938], Loss: 0.3078
Epoch [1/15], Step [300/938], Loss: 0.4380
Epoch [1/15], Step [400/938], Loss: 0.1774
Epoch [1/15], Step [500/938], Loss: 0.2936
Epoch [1/15], Step [600/938], Loss: 0.4853
Epoch [1/15], Step [700/938], Loss: 0.1584
Epoch [1/15], Step [800/938], Loss: 0.2259
Epoch [1/15], Step [900/938], Loss: 0.2769
Epoch [2/15], Step [100/938], Loss: 0.1586
Epoch [2/15], Step [200/938], Loss: 0.1905
Epoch [2/15], Step [300/938], Loss: 0.3399
Epoch [2/15], Step [400/938], Loss: 0.0947
Epoch [2/15], Step [500/938], Loss: 0.1368
Epoch [2/15], Step [600/938], Loss: 0.2923
Epoch [2/15], Step [700/938], Loss: 0.5561
Epoch [2/15], Step [800/938], Loss: 0.1021
Epoch [2/15], Step [900/938], Loss: 0.3102
Epoch [3/15], Step [100/938], Loss: 0.1538
Epoch [3/15], Step [200/938], Loss: 0.0831
Epoch [3/15], Step [300/938], Loss: 0.1020
Epoch [3/15], Step [400/938], Loss: 0.0685
Epoch [3/15], Step [500/938], Loss: 0.0958
Epoch [3/15

In [6]:
def generate_synthetic_data(num_samples):
    generator = Generator(latent_dim, 28 * 28)
    generator.load_state_dict(torch.load("generator.pth"))
    generator.eval()

    synthetic_data = []

    with torch.no_grad():
        for _ in range(num_samples):
            noise = torch.randn(1, latent_dim)
            generated_image = generator(noise).view(-1, 1, 28, 28)
            synthetic_data.append(generated_image)

    synthetic_data = torch.cat(synthetic_data, dim=0)
    return synthetic_data

# Function to label synthetic data using the trained discriminator
def label_synthetic_data(synthetic_data):
    labels_ = list()
    with torch.no_grad():
        scores = model(synthetic_data)
        for tensor in scores:
            labels_.append(torch.argmax(tensor).item())
    #print(labels_)
    return labels_

# Combine real and synthetic data to create an augmented dataset
def augment_data(num_synthetic_samples):
    # real_data = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    # real_targets = dataset.targets

    synthetic_data = generate_synthetic_data(num_synthetic_samples)

    #synthetic_data = synthetic_data.view(-1, 28*28)  # Flatten the synthetic data to match discriminator's input size

    synthetic_targets = torch.full((num_synthetic_samples,), 10, dtype=torch.int64)  # Use label 10 for synthetic data

    # Create the TensorDataset for real and synthetic data
    # real_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    # synthetic_dataset = torch.utils.data.TensorDataset(synthetic_data, synthetic_targets)

    # Combine the real and synthetic datasets using ConcatDataset
    #augmented_dataset = ConcatDataset([real_dataset, synthetic_dataset])

    # Create DataLoader for the augmented dataset
    #augmented_dataloader = DataLoader(augmented_dataset, batch_size=batch_size, shuffle=True)


    # Label the synthetic data using the trained discriminator
    synthetic_labels = label_synthetic_data(synthetic_data)
    synthetic_labels = torch.tensor(synthetic_labels, dtype=torch.long)
    #synthetic_labels = (synthetic_labels - 0.5) / 0.5
    #synthetic_labels = synthetic_labels.to(torch.long)

    print(synthetic_data.view(-1, 1, 28, 28).size())
    print(synthetic_labels.size())
    # Update the synthetic targets with the predicted labels
    synthetic_dataset = torch.utils.data.TensorDataset(synthetic_data.view(-1, 1, 28, 28), synthetic_labels)

    # Combine the real and labeled synthetic datasets using ConcatDataset
    #augmented_dataset = ConcatDataset([real_dataset, synthetic_dataset])

    # Create DataLoader for the augmented dataset
    #transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    augmented_dataloader = DataLoader(synthetic_dataset, batch_size=batch_size, shuffle=True)
    #print(synthetic_labels[1])
    #print(synthetic_labels)
    return augmented_dataloader

num_synthetic_samples = 1500
#augmented_data, augmented_targets = augment_data(num_synthetic_samples)
aug = augment_data(num_synthetic_samples)

torch.Size([1500, 1, 28, 28])
torch.Size([1500])


In [7]:
model2 = DigitRecognizer()

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model2.parameters(), lr=learning_rate)

# Training loop
total_step = len(aug)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(aug):
            # Convert augmented data to float tensors if needed
        images = images.float()  # Change to .to(torch.float) if you have integer tensors

    # Convert model's parameters to the same data type as the augmented data
        for param in model2.parameters():
            param.data = param.data.to(images.dtype)

        outputs = model2(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for i, (images, labels) in enumerate(train_loader):
            # Convert augmented data to float tensors if needed
        images = images.float()  # Change to .to(torch.float) if you have integer tensors

    # Convert model's parameters to the same data type as the augmented data
        for param in model2.parameters():
            param.data = param.data.to(images.dtype)

        outputs = model2(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_step}], Loss: {loss.item():.4f}")


# Test the model
model2.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in aug:
        outputs = model2(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Accuracy on the test set: {(100 * correct / total):.2f}%")

Epoch [1/15], Step [100/24], Loss: 0.5064
Epoch [1/15], Step [200/24], Loss: 0.2799
Epoch [1/15], Step [300/24], Loss: 0.4024
Epoch [1/15], Step [400/24], Loss: 0.3244
Epoch [1/15], Step [500/24], Loss: 0.2423
Epoch [1/15], Step [600/24], Loss: 0.2242
Epoch [1/15], Step [700/24], Loss: 0.2922
Epoch [1/15], Step [800/24], Loss: 0.2387
Epoch [1/15], Step [900/24], Loss: 0.1081
Epoch [2/15], Step [100/24], Loss: 0.2970
Epoch [2/15], Step [200/24], Loss: 0.1944
Epoch [2/15], Step [300/24], Loss: 0.2548
Epoch [2/15], Step [400/24], Loss: 0.1838
Epoch [2/15], Step [500/24], Loss: 0.1592
Epoch [2/15], Step [600/24], Loss: 0.2140
Epoch [2/15], Step [700/24], Loss: 0.1050
Epoch [2/15], Step [800/24], Loss: 0.1623
Epoch [2/15], Step [900/24], Loss: 0.1229
Epoch [3/15], Step [100/24], Loss: 0.1020
Epoch [3/15], Step [200/24], Loss: 0.1625
Epoch [3/15], Step [300/24], Loss: 0.1015
Epoch [3/15], Step [400/24], Loss: 0.0810
Epoch [3/15], Step [500/24], Loss: 0.1150
Epoch [3/15], Step [600/24], Loss:

In [8]:
print(len(aug))

24
