<a href="https://colab.research.google.com/github/abhg86/GGP_Breakthrough/blob/main/A2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
class Generator(nn.Module):
    def __init__(self, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))


def save_models(G, D, folder):
    os.makedirs(folder, exist_ok=True)  # Ensure the folder exists
    torch.save(G.state_dict(), os.path.join(folder, 'G.pth'))
    torch.save(D.state_dict(), os.path.join(folder, 'D.pth'))


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

class WDiscriminator(torch.nn.Module):
    '''Return an output that is not a probability.
    In WGAN we try to compare distances thus real are annotated as 1 and fake as -1.'''
    def __init__(self, d_input_dim):
        super(WDiscriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        # no sigmoid because no probability
        return x

class WGAN():
    def __init__(self, lr=0.00005, n_critic=5, clip_value=0.01, batch_size=64):
        self.mnist_dim = 784
        self.lr = lr
        self.n_critic = n_critic
        self.clip_value = clip_value
        self.batch_size = batch_size
        self.cuda = torch.cuda.is_available()

        self.G = Generator(g_output_dim = self.mnist_dim).cuda()
        self.D = WDiscriminator(self.mnist_dim).cuda()
        self.G = torch.nn.DataParallel(self.G).cuda()
        self.D = torch.nn.DataParallel(self.D).cuda()
        self.G_optimizer = torch.optim.RMSprop(self.G.parameters(), lr = lr)
        self.D_optimizer = torch.optim.RMSprop(self.D.parameters(), lr = lr)
        self.criterion = nn.HingeEmbeddingLoss()

        if (self.cuda):
            self.G = self.G.cuda()
            self.D = self.D.cuda()

    def train(self, train_loader, n_epoch):

        iter = self.get_iter(train_loader)

        for epoch in tqdm.tqdm(range(n_epoch)):
            for _ in range(len(train_loader)//self.n_critic):
                for p in self.D.parameters():
                    p.requires_grad = True

                for c_iter in range(self.n_critic):
                    for p in self.D.parameters():
                        p.data.clamp_(-self.clip_value, self.clip_value)

                    self.D.zero_grad()

                    x = next(iter)
                    if (self.cuda):
                        x = x.cuda()
                    x = x.view(-1, self.mnist_dim)

                    print("x", x.shape)
                    y_real = torch.ones(x.shape[0], 1)
                    y_real.requires_grad = True
                    # again, we are not interested in probabilities so we use -1 to represent fake
                    y_fake = torch.ones(x.shape[0], 1) * -1
                    y_fake.requires_grad = True
                    if (self.cuda):
                        y_real = y_real.cuda()
                        y_fake = y_fake.cuda()

                    D_real = self.D(x)
                    print("D", D_real.shape)
                    # D_real = D_real.mean()
                    D_real = self.criterion(y_real, D_real)
                    D_real.backward()

                    z = torch.randn(x.shape[0], 100).cuda()
                    G_z = self.G(z)
                    D_fake = self.D(G_z)
                    # D_fake = D_fake.mean()
                    print("D2", D_fake.shape)
                    D_fake = self.criterion(y_fake, D_fake)
                    D_fake.backward()
                    D_loss = D_fake + D_real
                    self.D_optimizer.step()

                for p in self.D.parameters():
                    # to avoid computation
                    p.requires_grad = False

                self.G.zero_grad()
                z = torch.randn(x.shape[0], 100)
                if (self.cuda):
                    z = z.cuda()
                G_z = self.G(z)
                D_fake = self.D(G_z)
                G_loss = self.criterion(D_fake, y_real)
                G_loss.backward()
                self.G_optimizer.step()

            if epoch % 10 == 0:
                save_models(self.G, self.D, 'checkpoints')


    def get_iter(self, data_loader):
        for i, (x, _) in enumerate(data_loader):
            yield x

    def generate(self, n):
        z = torch.randn(n, 100)
        if (self.cuda):
            z = z.cuda()
        return self.G(z)



In [4]:
# loading dataset MNIST
from torchvision import datasets, transforms

transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5), std=(0.5))])

train_dataset = datasets.MNIST(root='data/MNIST/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='data/MNIST/', train=False, transform=transform, download=False)


train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=2048, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=2048, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]


Extracting data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]


Extracting data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.43MB/s]


Extracting data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.50MB/s]

Extracting data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/MNIST/raw






In [22]:
mnist_dim = 784
G = torch.nn.DataParallel(Generator(g_output_dim = mnist_dim)).cuda()
D = torch.nn.DataParallel(WDiscriminator(mnist_dim)).cuda()

Wgan = WGAN(lr=0.00005, n_critic=5, clip_value=0.01, batch_size=64)
Wgan.train(train_loader, 100)


  0%|          | 0/100 [00:00<?, ?it/s]

x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])
D2 torch.Size([2048, 256])
x torch.Size([2048, 784])
D torch.Size([2048, 256])


  0%|          | 0/100 [00:06<?, ?it/s]


KeyboardInterrupt: 