In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Define the Generator class
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()
        )
        self.img_shape = img_shape

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

# Define the Classifier class
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, img):
        return self.model(img)

# Define the Discriminator class
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)

# Hyperparameters
latent_dim = 100
img_shape = (1, 28, 28)
batch_size = 64
lr = 0.0002
epochs = 50

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize models
generator = Generator(latent_dim, img_shape)
classifier = Classifier()
discriminator = Discriminator()
generator.to(device)
classifier.to(device)
discriminator.to(device)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
optimizer_C = optim.Adam(classifier.parameters(), lr=lr)

# Loss functions
adversarial_loss = nn.BCELoss()
classification_loss = nn.CrossEntropyLoss()

# Training loop
for epoch in range(epochs):
    for i, (real_imgs, labels) in enumerate(dataloader):
        batch_size = real_imgs.shape[0]
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        real_imgs = real_imgs.to(device)
        labels = labels.to(device)
        real_labels = real_labels.to(device)
        fake_labels = fake_labels.to(device)


        # Train Classifier
        optimizer_C.zero_grad()
        class_preds = classifier(real_imgs)
        c_loss = classification_loss(class_preds, labels)
        c_loss.backward()
        optimizer_C.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z).detach()
        fake_loss = adversarial_loss(discriminator(fake_imgs), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim).to(device)
        gen_imgs = generator(z)
        g_loss_adversarial = adversarial_loss(discriminator(gen_imgs), real_labels)
        class_preds_fake = classifier(gen_imgs)
        
        class_preds_fake_sm = class_preds_fake.softmax(-1)
        even = torch.ones_like(class_preds_fake_sm) * 0.1
        g_loss_misclassification = (class_preds_fake_sm - even).abs().max()
        g_loss = g_loss_adversarial + g_loss_misclassification * 2
        g_loss.backward()
        optimizer_G.step()

    print(f"Epoch {epoch+1}/{epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f} | C Loss: {c_loss.item():.4f}")


Epoch 1/50 | D Loss: 0.0640 | G Loss: 6.0927 | C Loss: 0.2203
Epoch 2/50 | D Loss: 0.7741 | G Loss: 3.6505 | C Loss: 0.1359
Epoch 3/50 | D Loss: 0.1231 | G Loss: 4.9245 | C Loss: 0.1675
Epoch 4/50 | D Loss: 0.1094 | G Loss: 5.6866 | C Loss: 0.1001
Epoch 5/50 | D Loss: 0.0372 | G Loss: 5.5885 | C Loss: 0.2505
Epoch 6/50 | D Loss: 0.1230 | G Loss: 6.9776 | C Loss: 0.1399
Epoch 7/50 | D Loss: 0.1927 | G Loss: 5.8541 | C Loss: 0.0607
Epoch 8/50 | D Loss: 0.1343 | G Loss: 4.9601 | C Loss: 0.0482
Epoch 9/50 | D Loss: 0.2947 | G Loss: 4.3020 | C Loss: 0.1328
Epoch 10/50 | D Loss: 0.0141 | G Loss: 7.4313 | C Loss: 0.0064
Epoch 11/50 | D Loss: 0.6735 | G Loss: 4.3857 | C Loss: 0.0987
Epoch 12/50 | D Loss: 0.0364 | G Loss: 6.4866 | C Loss: 0.1163
Epoch 13/50 | D Loss: 0.0073 | G Loss: 13.0276 | C Loss: 0.0130
Epoch 14/50 | D Loss: 0.0122 | G Loss: 6.8299 | C Loss: 0.0092
Epoch 15/50 | D Loss: 0.0080 | G Loss: 8.0922 | C Loss: 0.0082
Epoch 16/50 | D Loss: 0.1816 | G Loss: 6.5035 | C Loss: 0.0855


In [2]:

# Hyperparameters
latent_dim = 100
img_shape = (1, 28, 28)
batch_size = 128
lr = 0.002
epochs = 50

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Initialize models
generator = Generator(latent_dim, img_shape)
classifier = Classifier()
discriminator = Discriminator()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
classifier.to(device)
discriminator.to(device)


# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
optimizer_C = optim.Adam(classifier.parameters(), lr=lr)

# Loss functions
adversarial_loss = nn.BCELoss()
classification_loss = nn.CrossEntropyLoss()


In [3]:
import tqdm

In [None]:

# Training loop
for epoch in range(epochs):
    for i, (real_imgs, labels) in tqdm.tqdm(enumerate(dataloader)):
        batch_size = real_imgs.shape[0]
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        real_imgs = real_imgs.to(device)
        labels = labels.to(device)
        real_labels = real_labels.to(device)
        fake_labels = fake_labels.to(device)
        

        # # Train Classifier
        # optimizer_C.zero_grad()
        # class_preds = classifier(real_imgs)
        # c_loss = classification_loss(class_preds, labels)
        # c_loss.backward()
        # optimizer_C.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z).detach()
        fake_loss = adversarial_loss(discriminator(fake_imgs), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim).to(device)
        gen_imgs = generator(z)
        g_loss_adversarial = adversarial_loss(discriminator(gen_imgs), real_labels)
        class_preds_fake = classifier(gen_imgs)
        # g_loss_misclassification = class_preds_fake.std(-1).mean()
        class_preds_fake_sm = class_preds_fake.softmax(-1)
        # dist to uniform
        even = torch.ones_like(class_preds_fake_sm) * 0.1
        g_loss_misclassification = (class_preds_fake_sm - even).abs().max()
        g_loss = g_loss_adversarial# + g_loss_misclassification * 2
        g_loss.backward()
        optimizer_G.step()

        if i % 500 == 0:
            print(f"Epoch {epoch+1}/{epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f} |")

        


13it [00:00, 68.55it/s]

Epoch 1/50 | D Loss: 0.6767 | G Loss: 0.6099 |


469it [00:09, 51.62it/s]
9it [00:00, 42.10it/s]

Epoch 2/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 48.65it/s]
9it [00:00, 41.89it/s]

Epoch 3/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:07, 59.15it/s]
7it [00:00, 36.08it/s]

Epoch 4/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 50.20it/s]
9it [00:00, 42.59it/s]

Epoch 5/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 50.79it/s]
9it [00:00, 40.19it/s]

Epoch 6/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:08, 57.72it/s]
9it [00:00, 41.93it/s]

Epoch 7/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 49.77it/s]
8it [00:00, 36.17it/s]

Epoch 8/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 50.03it/s]
9it [00:00, 43.31it/s]

Epoch 9/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:08, 58.53it/s]
8it [00:00, 39.15it/s]

Epoch 10/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 49.14it/s]
8it [00:00, 39.35it/s]

Epoch 11/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 50.06it/s]
10it [00:00, 49.40it/s]

Epoch 12/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:08, 54.28it/s]
8it [00:00, 38.58it/s]

Epoch 13/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 47.78it/s]
8it [00:00, 37.94it/s]

Epoch 14/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:08, 55.02it/s]
10it [00:00, 47.10it/s]

Epoch 15/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 50.64it/s]
9it [00:00, 41.72it/s]

Epoch 16/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 49.58it/s]
10it [00:00, 46.82it/s]

Epoch 17/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:07, 59.24it/s]
8it [00:00, 39.49it/s]

Epoch 18/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 49.85it/s]
8it [00:00, 36.75it/s]

Epoch 19/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 50.40it/s]
10it [00:00, 44.71it/s]

Epoch 20/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:07, 60.91it/s]
9it [00:00, 43.00it/s]

Epoch 21/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 49.78it/s]
8it [00:00, 39.13it/s]

Epoch 22/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 48.32it/s]
10it [00:00, 47.01it/s]

Epoch 23/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 50.32it/s]
8it [00:00, 38.43it/s]

Epoch 24/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:10, 46.05it/s]
10it [00:00, 43.16it/s]

Epoch 25/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:09, 49.53it/s]
9it [00:00, 46.26it/s]

Epoch 26/50 | D Loss: 0.0000 | G Loss: 100.0000 |


469it [00:08, 56.33it/s]
8it [00:00, 38.28it/s]

Epoch 27/50 | D Loss: 0.0000 | G Loss: 100.0000 |


264it [00:05, 47.52it/s]


KeyboardInterrupt: 

In [103]:
g_loss_adversarial

tensor(0.0127, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward0>)

In [34]:
class_preds_fake_sm[9]

tensor([1.3334e-14, 9.9997e-01, 2.5651e-13, 3.8565e-07, 3.4910e-07, 2.3707e-11,
        2.1518e-09, 8.2955e-08, 2.5312e-05, 1.2578e-13], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [33]:

import plotly.express as px

px.imshow(fake_imgs[9].cpu().detach().numpy().reshape(28, 28)).show()

In [19]:
class_preds_fake.softmax(-1)[0]

tensor([0.3719, 0.0007, 0.1431, 0.0079, 0.1035, 0.0179, 0.1975, 0.0017, 0.1127,
        0.0432], device='cuda:0', grad_fn=<SelectBackward0>)