In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torch.nn.functional import binary_cross_entropy_with_logits as bincre_wlogits

import torchvision
from torch.utils.data import DataLoader
from torch.autograd.variable import Variable

device = torch.device("cuda:0" if torch.cuda.is_available()
                      else "cpu")
torch.manual_seed(1)


<torch._C.Generator at 0x2a269cf54d0>

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_set = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST('./data', train=False, download=True, transform=transform)

In [3]:
# Random sample
def sample_Z(m, n):
    return torch.randn(m, n)

# Re-implement 
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / (in_dim / 2.)**0.5  # Using standard Python sqrt calculation
    return torch.randn(*size) * xavier_stddev

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.W1 = nn.Parameter(xavier_init([784, 128]))
        self.b1 = nn.Parameter(torch.zeros(128))
        self.W2 = nn.Parameter(xavier_init([128, 1]))
        self.b2 = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        x = torch.relu(x @ self.W1 + self.b1)
        logits = x @ self.W2 + self.b2
        prob = torch.sigmoid(logits)
        return prob, logits


class Generator(nn.Module):
    def __init__(self, zdim):
        super(Generator, self).__init__()
        self.W1 = nn.Parameter(xavier_init([zdim, 128]))
        self.b1 = nn.Parameter(torch.zeros(128))
        self.W2 = nn.Parameter(xavier_init([128, 784]))
        self.b2 = nn.Parameter(torch.zeros(784))

    def forward(self, z):
        z = torch.relu(z @ self.W1 + self.b1)
        log_prob = z @ self.W2 + self.b2
        prob = torch.sigmoid(log_prob)
        return prob

In [5]:
m = 50
k = 1
zdim = 100

def train(iterations, loss_criterion=""):
    generator = Generator(zdim=zdim)
    discriminator = Discriminator()
    D_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=0.001)
    G_optimizer = torch.optim.Adam(params=generator.parameters(), lr=0.001)
    
    train_loader = DataLoader(train_set, batch_size=m, shuffle=True)

    if loss_criterion == "logistic":
        print("Running with logistic loss")

    for it in range(iterations):
        
        # Update discriminator (k times)
        for jt in range(k):  
            (batch_images, batch_label)  = next(iter(train_loader))
            batch_images = batch_images.view(batch_images.size(0), -1)  # Flatten the image tensor if needed
            Z = sample_Z(1, zdim)  # Sample noise

            D_optimizer.zero_grad()
            x = generator(Z)

            D_real, D_logit_real = discriminator(batch_images)
            D_fake, D_logit_fake = discriminator(x)    

            # print(D_real.shape, D_fake.shape, D_logit_real.shape, D_logit_fake.shape)

            if loss_criterion == "logistic":
                pass # D_loss_real = -torch.mean(bincre_wlogits(D_logits))
            else:
                D_loss = -torch.mean(torch.log(D_real) + torch.log(torch.ones(D_fake.shape) - D_fake))

            D_optimizer.step()

        # Update generator
        Z = sample_Z(1, zdim)  # Sample noise
        G_optimizer.zero_grad()
        G_fake, G_logit_fake = discriminator(generator(Z))

        if loss_criterion == "logistic":
            pass
        else:
            G_loss = -torch.mean(torch.log(1 - D_fake))
            
        G_optimizer.step()

        if it % 2000 == 0:
            print(f"Iteration {it}: D_loss {D_loss}, G_loss {G_loss}")

    print(f"Final: D_loss {D_loss}, G_loss {G_loss}")
    return discriminator, generator

In [6]:
# Train with std loss
D, G = train(iterations=5_000)

Iteration 0: D_loss 1.5805743932724, G_loss 1.0986809730529785


KeyboardInterrupt: 

In [None]:
# Train with logistic loss
D, G = train(iterations=20_000, loss_criterion="logistic")