In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


In [2]:
class Generator(nn.Module):
    """Conditional generator mapping (noise, class label) to an image tensor.

    Each label is looked up in a learned embedding of width ``num_classes``,
    concatenated with the noise vector, and passed through a fully connected
    stack that ends in ``Tanh``, so every output value lies in [-1, 1].

    Args:
        latent_dim: Length of the noise vector ``z``.
        num_classes: Number of distinct labels (also the embedding width).
        img_shape: Shape of a single generated image, without the batch axis.
    """

    def __init__(self, latent_dim, num_classes, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Hidden widths of the MLP; the first entry is the concatenated
        # noise + label-embedding width.
        widths = (latent_dim + num_classes, 128, 256, 512)
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        # Project to one value per pixel and squash into [-1, 1].
        layers.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Generate one image per row of ``z``, conditioned on ``labels``.

        Args:
            z: Noise tensor whose last dimension is ``latent_dim``.
            labels: Integer class indices, one per row of ``z``.

        Returns:
            Tensor of shape ``(batch, *img_shape)``.
        """
        label_vec = self.label_emb(labels)
        flat = self.model(torch.cat((z, label_vec), dim=-1))
        return flat.view(flat.size(0), *self.img_shape)


In [3]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring (image, class label) pairs.

    The image is flattened, concatenated with a learned label embedding of
    width ``num_classes``, and passed through a fully connected stack ending
    in ``Sigmoid``, giving one score in [0, 1] per sample.

    Args:
        num_classes: Number of distinct labels (also the embedding width).
        img_shape: Shape of a single input image, without the batch axis.
    """

    def __init__(self, num_classes, img_shape):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Input width = flattened pixels + label-embedding width.
        in_features = int(np.prod(img_shape)) + num_classes
        hidden = [
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        ]
        head = [nn.Linear(256, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*hidden, *head)

    def forward(self, img, labels):
        """Score each image/label pair.

        Args:
            img: Batch of images; everything after the batch axis is flattened.
            labels: Integer class indices, one per image.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        flat_img = img.view(img.size(0), -1)
        label_vec = self.label_emb(labels)
        return self.model(torch.cat((flat_img, label_vec), dim=-1))


In [4]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
latent_dim = 100
num_classes = 10  # For example, MNIST has 10 digit classes
img_shape = (1, 28, 28)
batch_size = 64
lr = 0.0002

# Initialize models
generator = Generator(latent_dim, num_classes, img_shape).to(device)
discriminator = Discriminator(num_classes, img_shape).to(device)

# Loss function (both networks are trained with binary cross-entropy on the
# discriminator's sigmoid output)
adversarial_loss = nn.BCELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# Load dataset (MNIST); Normalize([0.5], [0.5]) rescales pixels from [0, 1]
# to [-1, 1], matching the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
num_epochs = 50

for epoch in range(num_epochs):
    for imgs, labels in dataloader:
        # The final batch of an epoch can be smaller than `batch_size`, so
        # read the actual size from the data. A separate name is used so the
        # `batch_size` hyperparameter above is not overwritten.
        cur_batch_size = imgs.shape[0]
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # Train Discriminator
        optimizer_D.zero_grad()
        # Sample noise and random target labels directly on the training
        # device instead of creating them on CPU and copying.
        z = torch.randn(cur_batch_size, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (cur_batch_size,), device=device)
        fake_imgs = generator(z, gen_labels)

        real_validity = discriminator(real_imgs, labels)
        # detach() keeps the discriminator update from back-propagating into
        # the generator.
        fake_validity = discriminator(fake_imgs.detach(), gen_labels)

        real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity))
        fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity))
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Train Generator: re-score the same fakes with the just-updated
        # discriminator and push them towards the "real" target.
        optimizer_G.zero_grad()
        fake_validity = discriminator(fake_imgs, gen_labels)
        g_loss = adversarial_loss(fake_validity, torch.ones_like(fake_validity))

        g_loss.backward()
        optimizer_G.step()

    # Losses reported here are those of the last batch of the epoch only.
    print(f"Epoch [{epoch+1}/{num_epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")


Epoch [1/50] | D Loss: 0.1475 | G Loss: 3.0319
Epoch [2/50] | D Loss: 0.2876 | G Loss: 3.4285
Epoch [3/50] | D Loss: 0.4485 | G Loss: 1.9518
Epoch [4/50] | D Loss: 0.7276 | G Loss: 3.8470
Epoch [5/50] | D Loss: 0.2845 | G Loss: 2.2963
Epoch [6/50] | D Loss: 0.1322 | G Loss: 3.6823
Epoch [7/50] | D Loss: 0.2187 | G Loss: 3.4079
Epoch [8/50] | D Loss: 0.0678 | G Loss: 4.4234
Epoch [9/50] | D Loss: 0.0982 | G Loss: 5.1831
Epoch [10/50] | D Loss: 0.0744 | G Loss: 4.1677
Epoch [11/50] | D Loss: 0.2181 | G Loss: 3.7302
Epoch [12/50] | D Loss: 0.1228 | G Loss: 2.9255
Epoch [13/50] | D Loss: 0.0995 | G Loss: 4.9800
Epoch [14/50] | D Loss: 0.2097 | G Loss: 3.0810
Epoch [15/50] | D Loss: 0.0802 | G Loss: 3.6244
Epoch [16/50] | D Loss: 0.1530 | G Loss: 3.1879
Epoch [17/50] | D Loss: 0.1684 | G Loss: 3.4257
Epoch [18/50] | D Loss: 0.2246 | G Loss: 2.6944
Epoch [19/50] | D Loss: 0.1221 | G Loss: 3.8075
Epoch [20/50] | D Loss: 0.2353 | G Loss: 3.0551
Epoch [21/50] | D Loss: 0.2824 | G Loss: 4.0538
E

In [1]:
def generate_counterfactuals(generator, labels, latent_dim, img_shape, device):
    """Sample one generated image per entry of ``labels``.

    Args:
        generator: Callable invoked as ``generator(noise, labels)``.
        labels: Tensor of class indices; its first dimension sets the number
            of images produced.
        latent_dim: Length of each noise vector.
        img_shape: Accepted for interface compatibility; not used here.
        device: Device the noise and labels are moved to before the call.

    Returns:
        Whatever ``generator`` returns for the sampled noise and the labels.
    """
    num_samples = labels.shape[0]
    noise = torch.randn(num_samples, latent_dim).to(device)
    return generator(noise, labels.to(device))

# Generate images for validation
labels = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])  # Digits 0-9 for counterfactuals
generated_imgs = generate_counterfactuals(generator, labels, latent_dim, img_shape, device)

# Plot generated images.
# The grid is sized from `labels` rather than a hard-coded 10 so that editing
# the label list above cannot overrun the axes or leave empty panels.
# squeeze=False keeps `axes` two-dimensional even for a single label.
num_imgs = len(labels)
fig, axes = plt.subplots(1, num_imgs, figsize=(num_imgs, 2), squeeze=False)
for ax, img, label in zip(axes[0], generated_imgs, labels):
    # detach() drops the autograd graph before converting to NumPy;
    # squeeze() removes the singleton channel axis for imshow.
    ax.imshow(img.detach().cpu().numpy().squeeze(), cmap="gray")
    ax.set_title(f"Label {label.item()}")
    ax.axis("off")
plt.show()


NameError: name 'torch' is not defined

In [None]:
import torch.nn.functional as F

# Load classifier backbone with torchvision's pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; confirm against the installed version before switching.
classifier = torchvision.models.resnet18(pretrained=True)  # Example classifier
# WARNING: the final layer is replaced with a freshly initialised Linear head,
# so until this classifier is fine-tuned on the digit data (or trained weights
# are loaded), the predictions below do not validate the generated images.
classifier.fc = nn.Linear(classifier.fc.in_features, num_classes)
classifier.to(device)
classifier.eval()

# Validate generated images
with torch.no_grad():
    outputs = classifier(generated_imgs.repeat(1, 3, 1, 1))  # Convert grayscale to 3 channels
    # argmax over the raw logits: softmax is monotonic, so applying it first
    # cannot change the winning class and only adds rounding/saturation risk.
    predictions = torch.argmax(outputs, dim=1)

print("Generated Labels:", labels.tolist())
print("Classifier Predictions:", predictions.tolist())
