In [None]:
import argparse
import os
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Create the "images" directory if it is missing (nothing visible in this
# notebook writes to it -- presumably kept for saving generated samples).
os.makedirs("images", exist_ok=True)

# Every tunable hyper-parameter of the experiment lives on `opt`.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=250, help="number of epochs")
parser.add_argument("--batch_size", type=int, default=64, help="batch size")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimension of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--ngf", type=int, default=64, help="size of feature maps in generator")
parser.add_argument("--ndf", type=int, default=64, help="size of feature maps in discriminator")
parser.add_argument("--multiplier", type=float, default=0.6, help="weighting multiplier, which controls the relative contribution of generated data to the classifier training ")
parser.add_argument("--threshold", type=float, default=0.7, help="confidence threshold, which controls the quality of data to be used for classifier training")
parser.add_argument("--datasize", type=float, default=0.1, help="fraction of the training set kept for the training subset")

# parse_known_args() collects unrecognised arguments in `unknown` instead of
# aborting -- presumably needed because the notebook kernel passes its own argv.
opt, unknown = parser.parse_known_args()

# (channels, height, width) of a single image.
img_shape = (opt.channels, opt.img_size, opt.img_size)

# Use the GPU when one is available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")


In [3]:
# ResNet Classifier
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    The input is added back to the output of the second BatchNorm through
    ``self.shortcut``. The shortcut is an empty ``nn.Sequential`` (identity)
    unless the stride or the channel count changes, in which case it is a
    1x1 convolution followed by BatchNorm so that both tensors match.
    """

    # Ratio between the block's output channels and `planes`.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Build the block.

        Args:
            in_planes: number of input channels.
            planes: number of channels produced by both convolutions.
            stride: stride of the first convolution (and of the projection
                shortcut, when one is needed).
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            # Shapes differ between input and output: project the input.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # Empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)

class ResNet(nn.Module):
    """Four-stage residual classifier with 16/32/64/128 channel widths.

    A 3x3 stem convolution (hard-coded to 3 input channels) feeds four stages
    built from `block`; the result is average-pooled to 1x1, flattened and
    passed through a linear layer that produces the class logits.
    """

    def __init__(self, block, num_blocks, num_classes=opt.n_classes):
        """Build the network.

        Args:
            block: residual block class; must expose an ``expansion`` attribute
                and accept ``(in_planes, planes, stride)``.
            num_blocks: sequence with the number of blocks in each of the
                four stages.
            num_classes: size of the output logits vector.
        """
        super().__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.in_planes = 16
        self.embDim = 128 * block.expansion

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 128, num_blocks[3], stride=2)
        self.linear = nn.Linear(128 * block.expansion, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses `stride`."""
        per_block_strides = [stride] + [1] * (num_blocks - 1)
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features)
        # Flatten to (batch, features) before the linear classifier.
        emb = pooled.view(pooled.size(0), -1)
        return self.linear(emb)

    def get_embedding_dim(self):
        """Return the width of the flattened features fed to the linear layer."""
        return self.embDim

def ResNet18():
    """Build a ResNet with two BasicBlocks in each of its four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage)

In [4]:
#Conditional GAN Generator
class Generator(nn.Module):
    """Conditional generator: latent noise plus class label -> image.

    The label is embedded into an ``opt.n_classes``-dimensional vector and
    concatenated with the noise along the channel axis. Four transposed
    convolutions then upsample the 1x1 input to 4, 8, 16 and finally 32 pixels
    per side, so the output is 32x32 regardless of ``opt.img_size``.
    """

    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, kernel, stride, padding, bias=False):
            """Transposed convolution -> BatchNorm -> in-place ReLU."""
            # `bias` is forwarded to the convolution; it used to be accepted
            # but silently ignored.
            return [
                nn.ConvTranspose2d(in_feat, out_feat, kernel, stride, padding, bias=bias),
                nn.BatchNorm2d(out_feat),
                nn.ReLU(True),
            ]

        self.model = nn.Sequential(
            *block(opt.latent_dim + opt.n_classes, opt.ngf * 4, 4, 1, 0, bias=False),  # 1x1 -> 4x4
            *block(opt.ngf * 4, opt.ngf * 2, 4, 2, 1, bias=False),                     # 4x4 -> 8x8
            *block(opt.ngf * 2, opt.ngf, 4, 2, 1, bias=False),                         # 8x8 -> 16x16
            nn.ConvTranspose2d(opt.ngf, opt.channels, 4, 2, 1, bias=False),            # 16x16 -> 32x32
            nn.Tanh()  # squashes pixel values into [-1, 1]
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: tensor of shape (batch, opt.latent_dim, 1, 1).
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, opt.channels, 32, 32) with values in [-1, 1].
        """
        # Embed the labels and give them 1x1 spatial dimensions so they can be
        # concatenated with the noise along the channel axis.
        label_vec = self.label_emb(labels)
        label_vec = torch.reshape(label_vec, (label_vec.size(0), label_vec.size(1), 1, 1))
        gen_input = torch.cat((label_vec, noise), 1)
        img = self.model(gen_input)

        return img

In [5]:
#Conditional GAN Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator: (image, class label) -> probability of "real".

    The label is embedded, expanded by a linear layer to ``opt.img_size ** 2``
    values and reshaped into one extra image channel that is concatenated
    with the input image. Three stride-2 convolutions halve the resolution
    each time and a final 4x4 convolution with a sigmoid gives the score; for
    32x32 inputs that score has shape (batch, 1, 1, 1).
    """

    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)
        self.linear_expand = nn.Linear(opt.n_classes, int(opt.img_size**2))

        def block(in_feat, out_feat, kernel, stride, padding, bias=False, normalize=True):
            """Convolution -> optional BatchNorm -> in-place LeakyReLU(0.2)."""
            # `bias` is forwarded to the convolution; it used to be accepted
            # but silently ignored.
            layers = [nn.Conv2d(in_feat, out_feat, kernel, stride, padding, bias=bias)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # Input has one extra channel carrying the label map.
            *block(opt.channels+1, opt.ndf, 4, 2, 1, bias=False, normalize=False),
            *block(opt.ndf , opt.ndf * 2, 4, 2, 1, bias=False),
            *block(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False),
            nn.Conv2d(opt.ndf* 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """Score a batch of images given their class labels.

        Args:
            img: tensor of shape (batch, opt.channels, opt.img_size, opt.img_size).
            labels: integer class indices of shape (batch,).

        Returns:
            Sigmoid scores in [0, 1], one per image.
        """
        # Turn each label into a single-channel map the size of the image and
        # stack it onto the image as an additional channel.
        label_map = self.label_emb(labels)
        label_map = self.linear_expand(label_map)
        label_map = torch.reshape(label_map, (label_map.size(0), 1, opt.img_size, opt.img_size))
        d_in = torch.cat((img, label_map), 1)
        validity = self.model(d_in)

        return validity

In [6]:
# Losses: binary cross-entropy for the real/fake game, cross-entropy for the
# classifier.
adversarial_loss = nn.BCELoss()
criterion = nn.CrossEntropyLoss()

# Instantiate the three networks (generator, discriminator, classifier) and
# move each one to the selected device; Module.to() returns the module itself.
generator = Generator().to(device)
discriminator = Discriminator().to(device)
netC = ResNet18().to(device)

# One Adam optimizer per network. The GAN pair takes its betas from `opt`;
# the classifier uses fixed betas and adds weight decay.
gan_betas = (opt.b1, opt.b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=gan_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=gan_betas)
optC = torch.optim.Adam(netC.parameters(), lr=opt.lr, weight_decay=1e-3, betas=(0.5, 0.999))

# CPU tensor type aliases used when building targets and casting batches.
FloatTensor = torch.FloatTensor
LongTensor = torch.LongTensor

In [7]:
# Training-time pipeline with light augmentation.
transform = transforms.Compose([
    transforms.Resize(32),                 # scale the shorter side to 32 px
    transforms.RandomCrop(32, padding=4),  # pad by 4 px, then take a random 32x32 crop
    transforms.RandomRotation(10),         # random rotation within +/-10 degrees
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
])

# Test-time pipeline: no augmentation, same normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


# Loaders over the full SVHN train and test splits.
batch_size = opt.batch_size
trainset = datasets.SVHN("datasets/SVHN", split='train', download=True, transform=transform)
traindataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = datasets.SVHN("datasets/SVHN", split='test', download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Keep a random `opt.datasize` fraction of the training set and discard the rest.
dataSizeConstant = opt.datasize
subset_len = int(dataSizeConstant * len(trainset))
subTrainSet, _ = torch.utils.data.random_split(trainset, [subset_len, len(trainset) - subset_len])

Using downloaded and verified file: datasets/SVHN\train_32x32.mat
Using downloaded and verified file: datasets/SVHN\test_32x32.mat


In [8]:
def make_weights_for_balanced_classes(images, nclasses):
    """Compute one sampling weight per sample, inversely proportional to class frequency.

    Args:
        images: sequence of items whose element ``[1]`` is the integer class
            label (for example ``(image, label)`` tuples).
        nclasses: number of classes; labels are used as indices into a list
            of this length.

    Returns:
        A list of floats, one per item of `images`, where each weight equals
        ``len(images) / (number of samples sharing that item's label)``.
        An empty input yields an empty list.
    """
    labels = [item[1] for item in images]

    count = [0] * nclasses
    for label in labels:
        count[label] += 1

    total = float(len(labels))
    # A class without samples gets weight 0.0 instead of raising
    # ZeroDivisionError; no sample ever looks that entry up anyway.
    weight_per_class = [total / n if n else 0.0 for n in count]

    return [weight_per_class[label] for label in labels]

In [10]:
def gain_sample_w(dataset, batch_size,weights):
    """Wrap `dataset` in a DataLoader that draws samples according to `weights`.

    Args:
        dataset: dataset to load from.
        batch_size: number of samples per batch.
        weights: one sampling weight per dataset item.

    Returns:
        A DataLoader using a WeightedRandomSampler that draws ``len(weights)``
        samples per pass, with 8 worker processes and pinned memory.
    """
    samples_per_pass = len(weights)
    weighted_sampler = torch.utils.data.WeightedRandomSampler(weights, samples_per_pass)
    # Sample order comes from the sampler, so shuffling stays disabled.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=weighted_sampler,
        num_workers=8,
        pin_memory=True,
    )

In [12]:
def validate():
    """Evaluate `netC` on `testloader` and print its top-1 accuracy.

    Side effects:
        * Rebinds the globals `gpred_labels` and `greal_labels` to tensors
          holding every predicted and every true label seen during the pass.
        * Puts `netC` in eval mode for the pass and restores whatever mode
          it was in afterwards, so that a later training step does not run
          with BatchNorm statistics unintentionally frozen.
    """
    global gpred_labels, greal_labels

    was_training = netC.training
    netC.eval()
    correct = 0
    total = 0
    gpred_labels = torch.empty(0).to(device)
    greal_labels = torch.empty(0).to(device)
    try:
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = netC(inputs)
                # Index of the largest logit is the predicted class.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                gpred_labels = torch.cat((gpred_labels, torch.flatten(predicted)))
                greal_labels = torch.cat((greal_labels, torch.flatten(labels)))
    finally:
        # Hand the model back in the mode the caller left it in.
        netC.train(was_training)

    # An empty loader has no defined accuracy; report NaN rather than crash.
    accuracy = 100 * correct / total if total else float("nan")
    print('Accuracy of the network on the test images: %f %%' % accuracy)

In [11]:
# Materialise the training subset once as (image, label) pairs so the labels
# can be counted. Indexing the subset applies the dataset transform.
imgs = [(img, target) for img, target in subTrainSet]

# Per-sample weights that balance the classes. The class count comes from the
# configuration rather than a hard-coded 10.
weights = make_weights_for_balanced_classes(imgs, opt.n_classes)
weights = torch.DoubleTensor(weights)

# Use the configured batch size directly: the module-level `batch_size` is
# reassigned inside the training loop, so re-running this cell after training
# would otherwise pick up the size of the last batch.
subTrainLoader = gain_sample_w(subTrainSet, batch_size=opt.batch_size, weights=weights)

In [None]:
# Per-iteration loss history, consumed by the plotting cell.
G_losses = []
D_losses = []
C_losses = []
# ----------
#  Training
# ----------
for epoch in range(opt.n_epochs):
    # validate() switches netC to eval mode. Make sure the classifier is in
    # training mode at the start of every epoch, otherwise its BatchNorm
    # layers would use frozen running statistics from the second epoch on.
    netC.train()

    for i, (imgs, labels) in enumerate(subTrainLoader):

        # Size of this particular batch (the last one may be smaller). Kept in
        # its own name so the configured global `batch_size` is not clobbered.
        cur_batch_size = imgs.shape[0]
        # BCE targets: 1.0 for "real", 0.0 for "fake".
        validlabel = torch.ones(cur_batch_size, dtype=torch.float, device=device)
        fakelabel = torch.zeros(cur_batch_size, dtype=torch.float, device=device)

        real_imgs = imgs.to(device=device, dtype=torch.float)
        labels = labels.to(device=device, dtype=torch.long)

        #---------------------------------
        # (1) Update D network:
        #---------------------------------
        # Train with all-real data batch
        discriminator.zero_grad()
        # Forward pass real batch through D
        output = discriminator(real_imgs, labels).view(-1)
        # Calculate loss on all-real batch
        errD_real = adversarial_loss(output, validlabel)
        # Calculate gradients for D in backward pass
        errD_real.backward()

        # Train with all-fake data batch
        # Generate a batch of latent vectors; the fake images are conditioned
        # on the labels of the real batch.
        z = torch.randn(cur_batch_size, opt.latent_dim, 1, 1, device=device)
        # Generate fake image batch with G
        fake = generator(z, labels)
        # Discriminate all fake batch with D (detached: no gradient reaches G here)
        output = discriminator(fake.detach(), labels).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = adversarial_loss(output, fakelabel)
        # Gradients accumulate (sum) with those of the real batch. This graph
        # is never back-propagated again, so it does not need to be retained.
        errD_fake.backward()
        # Total D error, used for logging only
        errD = errD_real + errD_fake
        # Update D
        optimizer_D.step()

        #---------------------------------
        # (2) Update G network:
        #---------------------------------
        generator.zero_grad()
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = discriminator(fake, labels).view(-1)
        # G wants D to label its images as real
        errG = adversarial_loss(output, validlabel)
        # Calculate gradients for G
        errG.backward()
        # Update G
        optimizer_G.step()

        #---------------------------------
        # (3) Update C network:
        #---------------------------------
        # Train classifier on real data
        fake = fake.detach().clone()
        predictions = netC(real_imgs)
        realClassifierLoss = criterion(predictions, labels)
        realClassifierLoss.backward()

        optC.step()
        optC.zero_grad()

        # Train classifier on the synthesized samples that the discriminator
        # scored at or above the confidence threshold.
        x = output.ge(opt.threshold)
        Drealfake = fake[x]
        Dreallabels = labels[x]
        if Drealfake.shape[0] != 0:

            predictionsFake = netC(Drealfake)
            # Down-weight the contribution of generated data.
            fakeClassifierLoss = criterion(predictionsFake, Dreallabels) * opt.multiplier
            fakeClassifierLoss.backward()

            optC.step()
            optC.zero_grad()

        if i % 50 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [C loss: %f]"
                % (epoch, opt.n_epochs, i, len(subTrainLoader), errD.item(), errG.item(), realClassifierLoss.item()))

        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Only the loss on real data is tracked for the classifier.
        C_losses.append(realClassifierLoss.item())

    # Report test accuracy after every epoch.
    validate()

In [None]:
import matplotlib.pyplot as plt
plt.style.use('classic')
plt.rcParams['figure.facecolor'] = 'white'

# One figure, two side-by-side panels: GAN losses on the left, classifier
# loss on the right.
fig, (ax_gan, ax_cls) = plt.subplots(1, 2, figsize=(20, 5))

# Left panel: generator (blue) and discriminator (red) loss per iteration.
ax_gan.set_title("Generator and Discriminator Loss During Training")
line1 = ax_gan.plot(G_losses, 'b-')
line2 = ax_gan.plot(D_losses, 'r-')
ax_gan.legend(labels=['Generator_loss', ' Discriminator_loss'])
ax_gan.set_xlabel("iterations")

# Right panel: classifier loss (green) per iteration.
line3 = ax_cls.plot(C_losses, 'g-')
ax_cls.set_title('Classifier loss')
ax_cls.legend(labels=['Classifier_loss'])
ax_cls.set_xlabel("iterations")

# Write the figure to disk before showing it.
fig.savefig('Conditional_classifier_GAN.png', dpi=400, bbox_inches='tight')
plt.show()

In [None]:
# Persist each trained network as a whole pickled module (not just its
# state_dict); the file name carries the training-subset fraction.
checkpoints = (
    (generator, f'generator_{opt.datasize}.pth'),
    (discriminator, f'discriminator_{opt.datasize}.pth'),
    (netC, f'netC_{opt.datasize}.pth'),
)
for model, path in checkpoints:
    torch.save(model, path)