In [None]:
import functools
import os
import os.path
import shutil
import string
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torchvision.datasets as datasets
from PIL import Image
from torchvision.datasets.utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
#https://pytorch.org/vision/stable/_modules/torchvision/datasets/mnist.html#EMNIST

class EMNIST(datasets.MNIST):
    """EMNIST dataset, a local copy of torchvision's MNIST-derived loader.

    Args:
        root: root directory, forwarded to ``datasets.MNIST``.
        split: one of ``splits``; selects the raw files and the class list.
        **kwargs: forwarded to ``datasets.MNIST`` (e.g. ``train``, ``transform``,
            ``download``).
    """
    url = 'https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
    # Merged Classes assumes Same structure for both uppercase and lowercase version
    _merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'}
    # Digits are classes too: 'byclass' has 10 + 52 = 62 classes and 'bymerge'/'balanced'
    # have 62 - 15 = 47. Sorting puts '0'-'9' first, so digit labels are 0-9.
    _all_classes = set(string.digits + string.ascii_lowercase + string.ascii_uppercase)
    classes_split_dict = {
        'byclass': sorted(list(_all_classes)),
        'bymerge': sorted(list(_all_classes - _merged_classes)),
        'balanced': sorted(list(_all_classes - _merged_classes)),
        'letters': ['N/A'] + list(string.ascii_lowercase),
        'digits': list(string.digits),
        'mnist': list(string.digits),
    }

    def __init__(self, root: str, split: str, **kwargs: Any) -> None:
        self.split = verify_str_arg(split, "split", self.splits)
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super().__init__(root, **kwargs)
        # Copy, so code that mutates dataset.classes (e.g. sorts it in place) does not
        # alter the class-level classes_split_dict shared by every instance.
        self.classes = list(self.classes_split_dict[self.split])

    @staticmethod
    def _training_file(split) -> str:
        return 'training_{}.pt'.format(split)

    @staticmethod
    def _test_file(split) -> str:
        return 'test_{}.pt'.format(split)

    @property
    def _file_prefix(self) -> str:
        return f"emnist-{self.split}-{'train' if self.train else 'test'}"

    @property
    def images_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")

    @property
    def labels_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")

    def _load_data(self):
        """Read (images, labels) tensors from the raw idx files of this split."""
        # These readers were used but never imported; they live in torchvision's mnist module.
        from torchvision.datasets.mnist import read_image_file, read_label_file
        return read_image_file(self.images_file), read_label_file(self.labels_file)

    def _check_exists(self) -> bool:
        # check_integrity without an md5 only checks that each file exists.
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def download(self) -> None:
        """Download the EMNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
        # The archive unpacks into raw_folder/gzip; decompress each .gz into raw_folder.
        gzip_folder = os.path.join(self.raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith('.gz'):
                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
        shutil.rmtree(gzip_folder)


In [None]:
import torch
import torchvision.transforms as tran
import torchvision.utils as vutils
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mimg
from IPython.display import clear_output


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
workers = 2 # Number of DataLoader worker processes
batch_size = 64
#image_size = 64
nc = 1 # Number of channels in the training images
nz = 620 # Size of z latent vector
ngf = 128 # Size of feature maps in generator
ndf = 128 # Size of feature maps in discriminator
lr = 4e-4 # Learning rate for optimizers
beta1 = 0.5 # Beta1 hyperparam for Adam optimizers
ngpu = 1

# Rotating by -90 degrees and then flipping horizontally transposes the image
# (presumably to undo EMNIST's transposed storage of glyphs).
# functools.partial / plain functions are used instead of lambdas so the
# pipeline can be pickled when the DataLoader starts worker processes.
transforms = tran.Compose(
    [
     #tran.Resize(image_size),
     tran.RandomRotation(15),
     #tran.CenterCrop(25),
     #tran.RandomPerspective(distortion_scale=0.15, p=0.6),
     #tran.RandomAffine(degrees=3, translate=None, scale=None, shear=2),
     functools.partial(tran.functional.rotate, angle=-90),
     tran.functional.hflip,
     tran.ToTensor(),
     # Maps ToTensor's [0, 1] range to [-1, 1], matching the generator's tanh output.
     tran.Normalize([0.5] * nc, [0.5] * nc),
    ]
)

#dataset and dataloader
dataset = datasets.EMNIST(root="/dataset/", split="balanced", train=True, transform=transforms, download=True)

# inds = list(range(0, len(dataset),3))
# set_1 = torch.utils.data.Subset(dataset, inds)
# `workers` was previously defined but never used; wire it into the loader.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=workers)

print(device)

In [None]:
# idx = (dataset.targets==26) | (dataset.targets==27) | (dataset.targets==28) | (dataset.targets==29) | (dataset.targets==30) | (dataset.targets==31) | (dataset.targets==32) | (dataset.targets==33) | (dataset.targets==34) | (dataset.targets==35) | (dataset.targets==36) | (dataset.targets==37) | (dataset.targets==38) | (dataset.targets==39) | (dataset.targets==40) | (dataset.targets==41) |  (dataset.targets==42) |  (dataset.targets==43) |  (dataset.targets==44) |  (dataset.targets==45) |  (dataset.targets==46) |  (dataset.targets==47) |  (dataset.targets==48) |  (dataset.targets==49) |  (dataset.targets==50) |  (dataset.targets==51) |  (dataset.targets==52)
# dataset.targets = dataset.targets[idx]
# dataset.data = dataset.data[idx]
# dataset.classes = dataset.classes[idx]

In [None]:
# sorted() builds a new list: dataset.classes may be a class-level list shared by
# every instance, and sorting it in place would silently mutate it.
sorted_classes = sorted(dataset.classes)
print("No of classes: ", len(sorted_classes))
print("List of all classes")
print(sorted_classes)

In [None]:
from torchvision.utils import make_grid

def show_batch(dl, nrow=20):
    """Display the first batch yielded by ``dl`` as a single image grid.

    Args:
        dl: iterable yielding ``(images, labels)`` batches.
        nrow: number of images per grid row (default 20).

    The training transform normalises pixels to [-1, 1], so the grid is built with
    ``normalize=True`` to rescale it for ``imshow``. Otherwise ``imshow`` clips the
    negative half of the range to black.
    """
    batch = next(iter(dl), None)
    if batch is None:  # empty loader: nothing to display
        return
    images, _labels = batch
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([])
    ax.set_yticks([])
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(make_grid(images, nrow=nrow, normalize=True).permute(1, 2, 0))


show_batch(dataloader)

# TODO: inconsistency between the dataloader and the dataset class — describe and resolve.

In [None]:
# Number of conditioning classes, taken from the loaded dataset instead of a magic
# number (47 for the EMNIST "balanced" split). The embedding sizes in any checkpoint
# loaded later must match this value.
classes_num = len(dataset.classes)

class Generator(nn.Module):
    """Conditional generator: (noise, class label) -> single image.

    ``params`` is accepted for call-site compatibility but not used.
    Noise is concatenated channel-wise with a learned label embedding, so it is
    expected to be shaped (batch, nz, 1, 1). The transposed convolutions grow the
    spatial size 1 -> 3 -> 6 -> 14 -> 28, and tanh bounds the output to [-1, 1].
    """

    def __init__(self, params):
        super().__init__()

        # One learned conditioning vector per class, classes_num wide.
        self.label_emb = nn.Embedding(classes_num, classes_num)

        # Layers are created in the same order as before (affects RNG-driven init).
        self.tconv1 = nn.ConvTranspose2d(nz + classes_num, ngf * 8, kernel_size=3, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(ngf * 8)

        self.tconv2 = nn.ConvTranspose2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ngf * 8)

        self.tconv3 = nn.ConvTranspose2d(ngf * 8, ngf * 2, kernel_size=4, stride=2, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(ngf * 2)

        self.tconv4 = nn.ConvTranspose2d(ngf * 2, nc, kernel_size=4, stride=2, padding=1, bias=False)

    def forward(self, x, labels):
        # (batch, classes_num) -> (batch, classes_num, 1, 1) so it can join the noise channels.
        cond = self.label_emb(labels)[:, :, None, None]
        h = torch.cat((x, cond), dim=1)
        for tconv, bn in ((self.tconv1, self.bn1), (self.tconv2, self.bn2), (self.tconv3, self.bn3)):
            h = F.relu(bn(tconv(h)))
        return torch.tanh(self.tconv4(h))


netG = Generator(ngpu).to(device)

class Discriminator(nn.Module):
    """Conditional discriminator: scores an (image, class label) pair in (0, 1).

    ``params`` is accepted for call-site compatibility but not used.
    Assuming 28x28 inputs, the convolutions shrink the spatial size 28 -> 13 -> 6 -> 1,
    so flattening yields exactly ``classes_num`` features. These are concatenated
    with the label embedding before the two fully connected layers.
    """

    def __init__(self, params):
        super().__init__()

        # meta data (label)
        self.label_emb = nn.Embedding(classes_num, classes_num)

        self.conv1 = nn.Conv2d(nc, ndf*2, 5, 2, 1, bias=False)

        self.conv2 = nn.Conv2d(ndf*2, ndf * 2, 5, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(ndf * 2)

        self.conv3 = nn.Conv2d(ndf * 2, classes_num, 5, 2, 0, bias=False)

        # Image features (classes_num) + label embedding (classes_num).
        self.fc1 = nn.Linear(classes_num*2, classes_num)
        self.fc2 = nn.Linear(classes_num, 1)

    def forward(self, x, labels):
        x = F.leaky_relu(self.conv1(x), 0.2, True)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, True)
        x = F.leaky_relu(self.conv3(x))  # default negative slope (0.01), unlike the layers above
        x = torch.flatten(x, 1)

        c = self.label_emb(labels)
        x = torch.cat([x, c], 1)
        x = F.leaky_relu(self.fc1(x))
        # torch.sigmoid replaces the deprecated F.sigmoid (identical values);
        # the probability output is what nn.BCELoss expects.
        x = torch.sigmoid(self.fc2(x))

        return x

netD = Discriminator(ngpu).to(device)

def weights_init(m):
    """Initialise conv and BatchNorm layers in place; intended for ``nn.Module.apply``.

    - Modules whose class name contains 'Conv' (incl. ConvTranspose): weight ~ N(0, 0.02).
    - Modules whose class name contains 'BatchNorm': weight ~ N(1, 0.02), bias = 0.
    - Everything else (Linear, Embedding, ...) is left untouched.
    """
    classname = type(m).__name__
    # nn.init functions already run under torch.no_grad(), so the parameters can be
    # passed directly instead of through the discouraged `.data` attribute.
    if 'Conv' in classname:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # BatchNorm built with affine=False has no weight/bias; skip rather than crash.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Run weights_init over every submodule of both networks; it only touches
# Conv*/BatchNorm* layers. As the cell's last expression, the second call also
# echoes netG's repr, because Module.apply returns the module.
netD.apply(weights_init)
netG.apply(weights_init)

In [None]:
# netD.load_state_dict(torch.load("../input/modelsgw3/discriminator_gw_3.pth"))
# netG.load_state_dict(torch.load("../input/modelsgw3/generator_gw_3.pth"))

In [None]:
# Smoke-test both networks: a small generator batch (first sample shown), then a
# full discriminator pass on generated images. Nothing is trained here, so autograd is off.
# NOTE(review): both nets are still in train mode, so BatchNorm running stats are updated.
with torch.no_grad():
    #test gen
    noise = torch.randn(10, nz, 1, 1, device=device)
    fake_labels = torch.randint(0, classes_num, (10,), device=device)
    fake = netG(noise, fake_labels)

    #test disc
    disc_labels = torch.randint(0, classes_num, (batch_size,), device=device)
    noise2 = torch.randn(batch_size, nz, 1, 1, device=device)
    output = netD(netG(noise2, disc_labels), disc_labels).view(-1)

s = fake.cpu().numpy()
plt.imshow(s[0][0])
print(s[0][0].shape)

assert s.shape[:2] == (10, nc)
assert output.shape == (batch_size,)
assert bool(((output >= 0) & (output <= 1)).all())

In [None]:
# Latent batch kept constant across training for the periodic progress snapshots.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

# BCELoss targets for real and generated images.
real_label, fake_label = 1., 0.

criterion = nn.BCELoss()
adam_betas = (beta1, 0.999)
# Only the discriminator gets L2 weight decay.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas, weight_decay=0.001)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

# Training history, appended to by the training loop.
img_list, G_losses, D_losses, D_xs = [], [], [], []
iters = 0

In [None]:
# Resume from saved checkpoints (presumably a Kaggle input dataset mount).
# map_location remaps stored tensors onto the current device, so checkpoints
# saved on a GPU also load on a CPU-only machine.
checkpoint_dir = "../input/gwbase3"
netG.load_state_dict(torch.load(f"{checkpoint_dir}/generator_gw_8.pth", map_location=device))
netD.load_state_dict(torch.load(f"{checkpoint_dir}/discriminator_gw_8.pth", map_location=device))

In [None]:
num_epochs = 2

# Fixed conditioning labels for the progress snapshots. Cycling through the classes
# (instead of reusing each iteration's random fake_labels) makes snapshots comparable.
fixed_labels = torch.arange(fixed_noise.size(0), device=device) % classes_num


def _plot_training_curves(g_losses, d_losses, d_xs):
    """Plot the recorded G/D losses and D's mean output on real batches."""
    plt.figure(figsize=(10,5))
    plt.title("Generator and Discriminator Loss")
    plt.plot(g_losses,label="Gen")
    plt.plot(d_losses,label="Disc")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    plt.figure(figsize=(10,5))
    plt.title("D(x) ")
    plt.plot(d_xs)
    plt.xlabel("iterations")
    plt.ylabel("D(x)")
    plt.show()


for epoch in range(num_epochs):
    for i, (real_images, real_labels) in enumerate(dataloader):
        clear_output(wait=True)

        real_images = real_images.to(device)
        real_labels = real_labels.to(device)
        b_size = real_images.size(0)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        netD.zero_grad()
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Train with the all-real batch
        output = netD(real_images, real_labels).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # Train with an all-fake batch. Labels are sized by b_size (not batch_size)
        # so they always match the noise batch.
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake_labels = torch.randint(0, classes_num, (b_size,), device=device)
        fake = netG(noise, fake_labels)
        label.fill_(fake_label)
        # detach(): the discriminator's loss must not backpropagate into G.
        output = netD(fake.detach(), fake_labels).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Gradients from both batches have accumulated; errD is kept for logging.
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake, fake_labels).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # print training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            _plot_training_curves(G_losses, D_losses, D_xs)

        G_losses.append(errG.item())
        D_losses.append(errD.item())
        D_xs.append(D_x)

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise, fixed_labels).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1

In [181]:
def generate_digit_from_label(label, seed, generator=None):
  """Generate one image of class ``label`` from the latent vector ``seed``.

  Args:
    label: integer class index.
    seed: latent tensor passed straight to the generator (this notebook uses
      shape (1, nz, 1, 1)).
    generator: conditional generator to call; defaults to the global ``netG``.

  Returns:
    The generated image on the CPU with all size-1 dimensions squeezed out.
  """
  if generator is None:
    generator = netG
  # Build the label on the seed's device instead of hard-coding .cuda(), so this
  # also works on CPU-only machines.
  fake_label = torch.tensor([label], device=seed.device)
  # NOTE(review): the generator is not put in eval() mode, so its BatchNorm layers
  # use this single sample's statistics and update running stats — confirm intended.
  with torch.no_grad():
    fake_ = generator(seed, fake_label).detach().cpu()
  return fake_.squeeze()


# Nine side-by-side panels, plus one base latent vector shared by the panels below.
fig, axes = plt.subplots(nrows=1, ncols=9, figsize=(20, 8))
seed = torch.randn(1, nz, 1, 1, device=device)

def generate_noisy_seed(seed, scale=0.2**0.4):
    """Return ``seed`` perturbed by Gaussian noise with standard deviation ``scale``.

    Args:
        seed: latent tensor to perturb.
        scale: noise standard deviation; the default 0.2**0.4 (~0.525) is the value
            previously hard-coded here.

    The noise matches ``seed``'s shape, dtype and device, so a batched seed gets
    independent noise per sample.
    """
    return seed + scale * torch.randn_like(seed)
    
# One panel per letter of the word. Labels are looked up in the dataset's class list
# instead of hard-coding indices, and each panel uses a fresh perturbation of the
# same base seed. Strokes are inverted (1 - x) to draw dark on light.
panel_text = "GENERATED"
panel_labels = [dataset.classes.index(ch) for ch in panel_text]
for ax, class_idx in zip(axes, panel_labels):
    ax.imshow(1 - generate_digit_from_label(class_idx, generate_noisy_seed(seed)), cmap="gray")
    ax.axis('off')

In [None]:
# Render one inverted sample per class from the shared base seed, one figure each.
for class_idx in range(classes_num):
    inverted = 1 - generate_digit_from_label(class_idx, seed)
    plt.imshow(inverted, cmap='gray')
    plt.show()

In [None]:
# Persist both networks' state dicts to the working directory (generator first).
for ckpt_path, net in (("generator_gw_8.pth", netG), ("discriminator_gw_8.pth", netD)):
    torch.save(net.state_dict(), ckpt_path)

In [None]:
from IPython.display import FileLink

# Clickable download link for the saved discriminator weights.
FileLink(path='./discriminator_gw_8.pth')

In [None]:
# Download link for the generator weights (this cell previously duplicated the
# discriminator link, leaving generator_gw_8.pth without one).
FileLink(r'./generator_gw_8.pth')

In [None]:
def generate_and_save_image_with_multiple_digits(digits, seed, classes=None):
  """Generate one glyph per character of ``digits``, side by side, and save it as PNG.

  Args:
    digits: int or str; each character becomes one generated glyph (pass a str to
      keep leading zeros).
    seed: latent tensor shared by every glyph.
    classes: optional sequence mapping label index -> character (e.g.
      ``dataset.classes``). When given, characters are looked up in it, so letters
      work too. When omitted, each character must be a decimal digit and is used
      directly as the label, as before.

  Returns:
    The inverted image as a numpy array (glyphs concatenated along axis 1). It is
    also written to ``generated<digits>.png``.
  """
  text = str(digits)
  if classes is None:
    labels = [int(ch) for ch in text]
  else:
    lookup = list(classes)
    labels = [lookup.index(ch) for ch in text]
  gen_list = [generate_digit_from_label(label, seed) for label in labels]
  vis = np.concatenate(gen_list, axis=1)
  # Invert for dark strokes on a light background, matching the other cells' 1 - x.
  # The previous 255 - x put tanh's [-1, 1] output into [254, 256], which only looked
  # right because imsave/imshow rescale to the data range, and it lost float32 precision.
  vis = 1 - vis
  mimg.imsave("generated"+str(digits)+".png", vis, cmap='gray')
  return vis

# Label index -> character mapping the model is conditioned on, read from the dataset
# itself. A hand-written copy is easy to get wrong: the list previously here had
# upper/lower case swapped relative to the EMNIST "balanced" split.
dataset.classes

In [None]:
# Show each sample number; every call also writes generated<digits>.png.
for digits in (12110, 410, 6789):
    plt.imshow(generate_and_save_image_with_multiple_digits(digits, seed), cmap='gray')
    plt.axis("off")
    plt.show()