In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.nn.utils import spectral_norm

# Compute device: use a CUDA GPU when PyTorch reports one is available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyperparameters
batch_size = 64    # images per mini-batch
z_size = 100       # length of the latent noise vector fed to the generator
img_size = 28      # FashionMNIST images are 28 x 28
num_epochs = 30    # full passes over the training set
lr = 2e-4          # optimizer learning rate
n_heads = 8        # number of attention heads

# FashionMNIST preprocessing: PIL image -> float tensor in [0, 1] (ToTensor),
# then (x - 0.5) / 0.5 on the single channel, which maps values to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# FashionMNIST has 10 classes (labels 0-9). The discriminator's label
# embedding needs one row per class; with 9 rows, label 9 is out of range.
class_num = 10

# FashionMNIST training split, downloaded into ./data if missing and
# preprocessed with `transform`; served in shuffled mini-batches.
train_dataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Generator building block: spectral-normalized residual upsampling (used by SAGANGenerator)
class GeneratorSNBlock(nn.Module):
    """Residual generator block that doubles the spatial resolution.

    Main path: nearest-neighbour 2x upsample, then two spectral-normalized
    3x3 convolutions, each followed by BatchNorm and ReLU.
    Shortcut path: the same 2x upsample followed by a spectral-normalized
    1x1 convolution that matches the output channel count.
    The block returns the sum of the two paths.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Not used by forward (the ReLUs live inside conv_module).
        self.relu = nn.ReLU()
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

        # Two conv -> BN -> ReLU stages; the first maps in_channels to
        # out_channels, the second keeps out_channels.
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                spectral_norm(nn.Conv2d(stage_in, out_channels, 3, 1, padding=1)),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        self.conv_module = nn.Sequential(*layers)

        self.residual_conv = spectral_norm(nn.Conv2d(in_channels, out_channels, 1, 1))

    def forward(self, x):
        # Both paths consume the same upsampled input, so compute it once.
        upsampled = self.upsample(x)
        main = self.conv_module(upsampled)
        shortcut = self.residual_conv(upsampled)
        return main + shortcut

class SAGANGenerator(nn.Module):
    """SAGAN-style generator built from spectral-normalized residual blocks.

    A latent vector is projected by a spectral-normalized linear layer to an
    ``(in_channels, 4, 4)`` feature map. Five ``GeneratorSNBlock`` stages then
    double the resolution each time (4 -> 8 -> 16 -> 32 -> 64 -> 128).

    After the third block (64 channels, 32 x 32), multi-head self-attention
    over the spatial positions is added to the features. The attention output
    is scaled by the learnable scalar ``alpha``. ``alpha`` starts at 0, so the
    attention path initially contributes nothing.

    A final BatchNorm -> ReLU -> 3x3 conv produces the image, squashed with
    ``tanh`` to ``[-1, 1]``.

    Args:
        z_dim: length of the latent vector.
        in_channels: channels of the initial 4 x 4 feature map.
        n_heads: number of attention heads; must divide 64.
        out_channels: channels of the generated image (default 3).
        out_size: if given, the 128 x 128 output is bilinearly resized to
            ``out_size x out_size``. The default ``None`` keeps 128 x 128.
    """

    def __init__(self, z_dim, in_channels, n_heads=1, out_channels=3, out_size=None):
        super(SAGANGenerator, self).__init__()
        self.z_dim = z_dim
        self.in_channels = in_channels
        self.n_heads = n_heads
        self.out_size = out_size

        self.z_linear = spectral_norm(
            nn.Linear(self.z_dim, 4 * 4 * self.in_channels, bias=False)
        )

        self.block1 = GeneratorSNBlock(self.in_channels, 256)  # 8 x 8
        self.block2 = GeneratorSNBlock(256, 128)  # 16 x 16
        self.block3 = GeneratorSNBlock(128, 64)  # 32 x 32
        self.block4 = GeneratorSNBlock(64, 32)  # 64 x 64
        self.block5 = GeneratorSNBlock(32, 16)  # 128 x 128

        # Self-attention over the 64-channel output of block3.
        self.attn1 = nn.MultiheadAttention(64, num_heads=self.n_heads)
        # Gate for the attention branch; starts at 0 (attention disabled).
        self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        self.last = nn.Sequential(
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, out_channels, 3, 1, padding=1),
        )

    def _self_attention(self, x):
        """Return multi-head self-attention of ``x`` across its spatial positions."""
        b, c, h, w = x.shape
        # attn1 uses the default (seq_len, batch, embed) layout, so treat each
        # of the h*w positions as a token with a c-dimensional embedding.
        seq = x.flatten(2).permute(2, 0, 1)
        attended, _ = self.attn1(seq, seq, seq, need_weights=False)
        return attended.permute(1, 2, 0).reshape(b, c, h, w)

    def forward(self, z, labels=None):
        """Generate a batch of images from latent vectors.

        Args:
            z: latent tensor of shape ``(batch, z_dim)``.
            labels: optional and ignored. This generator is unconditional; the
                argument exists only so ``generator(z, labels)`` calls work.

        Returns:
            Tensor of shape ``(batch, out_channels, S, S)`` with values in
            ``[-1, 1]``, where ``S`` is 128, or ``out_size`` when it is set.
        """
        x = self.z_linear(z).view(z.size(0), self.in_channels, 4, 4)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = x + self.alpha * self._self_attention(x)
        x = self.block4(x)
        x = self.block5(x)
        x = torch.tanh(self.last(x))
        if self.out_size is not None:
            # Bilinear interpolation is a convex combination, so values stay in [-1, 1].
            x = nn.functional.interpolate(
                x,
                size=(self.out_size, self.out_size),
                mode="bilinear",
                align_corners=False,
            )
        return x

class Discriminator(nn.Module):
    """Label-conditioned MLP discriminator.

    Each image is flattened to ``img_size * img_size`` values and
    concatenated with a learned ``class_num``-dimensional embedding of its
    label. The result goes through three hidden Linear -> LeakyReLU(0.2) ->
    Dropout(0.3) layers and a final Linear -> Sigmoid, producing the
    probability that the (image, label) pair is real.

    Args:
        discriminator_layer_size: sequence whose first three entries are the
            hidden layer widths.
        img_size: side length of the (assumed single-channel) square images.
        class_num: number of label classes; labels must be in
            ``[0, class_num)``.
    """

    def __init__(self, discriminator_layer_size, img_size, class_num):
        super(Discriminator, self).__init__()

        # Maps each integer label to a learned class_num-dimensional vector.
        self.label_emb = nn.Embedding(class_num, class_num)
        self.img_size = img_size

        self.model = nn.Sequential(
            nn.Linear(self.img_size * self.img_size + class_num, discriminator_layer_size[0]),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(discriminator_layer_size[0], discriminator_layer_size[1]),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(discriminator_layer_size[1], discriminator_layer_size[2]),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(discriminator_layer_size[2], 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        """Score (image, label) pairs.

        Args:
            x: images whose elements can be grouped into rows of
                ``img_size * img_size`` values, one row per label.
            labels: integer class indices, one per image.

        Returns:
            1-D tensor of shape ``(batch,)`` with probabilities in (0, 1).
        """
        # Flatten each image. reshape (unlike view) also accepts
        # non-contiguous inputs such as transposed or sliced tensors.
        x = x.reshape(-1, self.img_size * self.img_size)

        # Integer label -> learned embedding vector
        c = self.label_emb(labels)

        # Concat image & label
        x = torch.cat([x, c], 1)

        out = self.model(x)

        # Drop only the trailing size-1 dim so a batch of one stays shape (1,)
        # instead of collapsing to a 0-d scalar.
        return out.squeeze(-1)


# Build the models. Discriminator indexes its first argument for three
# hidden-layer widths, so it needs a sequence, not a single int.
# NOTE(review): SAGANGenerator's final conv emits 3 channels at 128 x 128,
# while Discriminator flattens inputs into rows of img_size * img_size
# values -- confirm the two output/input shapes agree before training.
generator = SAGANGenerator(z_size, 16, n_heads=n_heads).to(device)
discriminator = Discriminator([1024, 512, 256], img_size, class_num).to(device)

# Define the loss function and optimizers (learning rate from the `lr` hyperparameter)
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=lr)
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

# Train the models: for every mini-batch, one discriminator update followed
# by one generator update, both using BCE against 1 (real) / 0 (fake) targets.
for epoch in range(num_epochs):
    for _, (images, labels) in enumerate(train_loader):
        batch = images.size(0)
        real_images = images.to(device)
        real_labels = labels.to(device)

        # ---- Discriminator step ----
        d_optimizer.zero_grad()

        # Real images should be scored as real (target 1).
        real_validity = discriminator(real_images, real_labels)
        real_loss = criterion(real_validity, torch.ones_like(real_validity))

        # Fake images should be scored as fake (target 0). The generator takes
        # only the noise. detach() keeps this step from backpropagating into
        # the generator's graph; only the discriminator is updated here.
        noise = torch.randn(batch, z_size, device=device)
        fake_images = generator(noise)
        fake_validity = discriminator(fake_images.detach(), real_labels)
        fake_loss = criterion(fake_validity, torch.zeros_like(fake_validity))

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step ----
        g_optimizer.zero_grad()

        # The generator is rewarded when the discriminator scores its fresh
        # samples as real (target 1).
        noise = torch.randn(batch, z_size, device=device)
        fake_images = generator(noise)
        fake_validity = discriminator(fake_images, real_labels)
        g_loss = criterion(fake_validity, torch.ones_like(fake_validity))

        g_loss.backward()
        g_optimizer.step()

    # Print training stats (losses of the epoch's last mini-batch) every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

# Persist the learned parameters (state dicts) of both networks to the
# current working directory.
for model, path in ((generator, "generator.pt"), (discriminator, "discriminator.pt")):
    torch.save(model.state_dict(), path)

ModuleNotFoundError: No module named 'torch'