In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import os

In [3]:
cuda = True
cnt = 0
lr = 1e-4
batch_size = 4

nc = 3 # number of channels
nz = 64 # size of latent vector
ngf = 64 # decoder (generator) filter factor
ndf = 64 # encoder filter factor
h_dim = 128 # discriminator hidden size
lam = 1 # regulization coefficient

In [4]:
class Decoder(nn.Module):
        def __init__(self):
            super(Decoder, self).__init__()
            self.main = [
                # input is Z, going into a convolution
                nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(ngf * 8),
                nn.ReLU(True),
                # state size. (ngf*8) x 4 x 4
                nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 4),
                nn.ReLU(True),
                # state size. (ngf*4) x 8 x 8
                nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 2),
                nn.ReLU(True),
                # state size. (ngf*2) x 16 x 16
                nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf),
                nn.ReLU(True),
                # state size. (ngf) x 32 x 32
                nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
                nn.Sigmoid()
                # state size. (nc) x 64 x 64
            ]
            for idx, module in enumerate(self.main):
                self.add_module(str(idx), module)

        def forward(self, x):
            for layer in self.main:
                x = layer(x)
            return x

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.main = [
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, nz, 4, 1, 0, bias=False),
        ]

        for idx, module in enumerate(self.main):
            self.add_module(str(idx), module)

    def forward(self, x):
        for layer in self.main:
            x = layer(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = [
            nn.Linear(nz, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        ]
        for idx, module in enumerate(self.main):
            self.add_module(str(idx), module)
    def forward(self, x):
        for layer in self.main:
            x = layer(x)
        return x
    

In [5]:
number_classes = 15
current_directory = os.getcwd()

In [None]:
for x in range(1, number_classes):
    data_path = current_directory + '\\swedish_data\\' + str(x) + '\\'
    print(data_path)
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    
    dataset = datasets.ImageFolder(data_path, transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=20, drop_last=True)
    out_dir = data_path + 'out'
    os.mkdir(out_dir)
    
    Q = Encoder()
    P = Decoder()
    D = Discriminator()

    if cuda:
        Q = Q.cuda()
        P = P.cuda()
        D = D.cuda()

    def reset_grad():
        Q.zero_grad()
        P.zero_grad()
        D.zero_grad()

    Q_solver = optim.Adam(Q.parameters(), lr=lr)
    P_solver = optim.Adam(P.parameters(), lr=lr)
    D_solver = optim.Adam(D.parameters(), lr=lr*0.1)
    
    for it in range(100):
        for batch_idx, batch_item in enumerate(data_loader):
            #X = sample_X(mb_size)
            """ Reconstruction phase """
            X = Variable(batch_item[0])
            if cuda:
                X = X.cuda()

            z_sample = Q(X)

            X_sample = P(z_sample)
            recon_loss = F.mse_loss(X_sample, X)

            recon_loss.backward()
            P_solver.step()
            Q_solver.step()
            reset_grad()

            """ Regularization phase """
            # Discriminator
            for _ in range(5):
                z_real = Variable(torch.randn(batch_size, nz))
                if cuda:
                    z_real = z_real.cuda()

                z_fake = Q(X).view(batch_size,-1)

                D_real = D(z_real)
                D_fake = D(z_fake)

                #D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))
                D_loss = -(torch.mean(D_real) - torch.mean(D_fake))

                D_loss.backward()
                D_solver.step()

                # Weight clipping
                for p in D.parameters():
                    p.data.clamp_(-0.01, 0.01)

                reset_grad()

            # Generator
            z_fake = Q(X).view(batch_size,-1)
            D_fake = D(z_fake)

            #G_loss = -torch.mean(torch.log(D_fake))
            G_loss = -torch.mean(D_fake)

            G_loss.backward()
            Q_solver.step()
            reset_grad()

            if batch_idx % 10 == 0:
                print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'
                      .format(batch_idx, D_loss.item(), G_loss.item(), recon_loss.item()))
                torch.save(Q, out_dir + "\\Q_latest.pth")
                torch.save(P, out_dir + "\\P_latest.pth")
                torch.save(D, out_dir + "\\D_latest.pth")

            # Print and plot every now and then
            if batch_idx % 100 == 0:

                z_real = z_real.unsqueeze(2).unsqueeze(3) # add 2 dimensions
                if cnt % 2 == 0:
                    samples = P(z_real) # Generated
                else:
                    samples = X_sample # Reconstruction
                #samples = X_sample
                if cuda:
                    samples = samples.cpu()
                samples = samples.data.numpy()[:16]

                fig = plt.figure(figsize=(4, 4))
                gs = gridspec.GridSpec(4, 4)
                gs.update(wspace=0.05, hspace=0.05)

                for i, sample in enumerate(samples):
                    ax = plt.subplot(gs[i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    sample = np.swapaxes(sample,0,2)
                    plt.imshow(sample)


                if not os.path.exists(out_dir):
                    os.makedirs(out_dir)

                plt.savefig('{}/{}.png'
                            .format(out_dir,str(cnt).zfill(3)), bbox_inches='tight')
                cnt += 1
                plt.close(fig)

C:\Users\Parth Pankaj Tiwary\Desktop\examples-master\waae\swedish_data\1\
Iter-0; D_loss: 0.0001283; G_loss: -0.4966; recon_loss: 0.1274


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Iter-10; D_loss: 0.009878; G_loss: -0.5073; recon_loss: 0.09019
Iter-0; D_loss: 0.01469; G_loss: -0.512; recon_loss: 0.07822
Iter-10; D_loss: 0.01574; G_loss: -0.5141; recon_loss: 0.06801
Iter-0; D_loss: 0.009383; G_loss: -0.5068; recon_loss: 0.07664
Iter-10; D_loss: -0.00136; G_loss: -0.4964; recon_loss: 0.05659
Iter-0; D_loss: 0.004748; G_loss: -0.502; recon_loss: 0.06798
Iter-10; D_loss: 0.007554; G_loss: -0.5054; recon_loss: 0.05252
Iter-0; D_loss: 0.002973; G_loss: -0.5006; recon_loss: 0.05115
Iter-10; D_loss: -0.002648; G_loss: -0.4942; recon_loss: 0.04905
Iter-0; D_loss: 0.001428; G_loss: -0.4979; recon_loss: 0.03982
Iter-10; D_loss: 0.002143; G_loss: -0.4995; recon_loss: 0.03065
Iter-0; D_loss: 0.00334; G_loss: -0.4999; recon_loss: 0.03815
Iter-10; D_loss: 0.0004434; G_loss: -0.4973; recon_loss: 0.02675
Iter-0; D_loss: -0.0002833; G_loss: -0.4964; recon_loss: 0.03034
Iter-10; D_loss: -0.0006623; G_loss: -0.4969; recon_loss: 0.02408
Iter-0; D_loss: 0.001038; G_loss: -0.4978; rec

Iter-10; D_loss: -0.0003376; G_loss: -0.4967; recon_loss: 0.003917
Iter-0; D_loss: 3.105e-05; G_loss: -0.4969; recon_loss: 0.005287
Iter-10; D_loss: 0.0001315; G_loss: -0.4975; recon_loss: 0.004105
Iter-0; D_loss: -0.000497; G_loss: -0.4969; recon_loss: 0.004085
Iter-10; D_loss: -0.0005092; G_loss: -0.4972; recon_loss: 0.003821
Iter-0; D_loss: -0.000347; G_loss: -0.4973; recon_loss: 0.004724
Iter-10; D_loss: -0.0006266; G_loss: -0.4972; recon_loss: 0.003995
Iter-0; D_loss: -0.0007555; G_loss: -0.4971; recon_loss: 0.003855
Iter-10; D_loss: -0.0002241; G_loss: -0.4974; recon_loss: 0.005077
Iter-0; D_loss: -0.0004602; G_loss: -0.4975; recon_loss: 0.0043
Iter-10; D_loss: -4.938e-05; G_loss: -0.4979; recon_loss: 0.004736
Iter-0; D_loss: -0.001349; G_loss: -0.4965; recon_loss: 0.004249
Iter-10; D_loss: -0.0007824; G_loss: -0.4968; recon_loss: 0.004797
Iter-0; D_loss: -0.0004979; G_loss: -0.497; recon_loss: 0.004856
Iter-10; D_loss: -0.000125; G_loss: -0.4976; recon_loss: 0.003703
Iter-0; D_l

Iter-0; D_loss: 0.0004381; G_loss: -0.5002; recon_loss: 0.02259
Iter-10; D_loss: 0.0005351; G_loss: -0.5; recon_loss: 0.02341
Iter-0; D_loss: -0.0009683; G_loss: -0.5; recon_loss: 0.02535
Iter-10; D_loss: -0.0001058; G_loss: -0.4997; recon_loss: 0.02471
Iter-0; D_loss: -0.002843; G_loss: -0.4979; recon_loss: 0.01957
Iter-10; D_loss: -0.0003577; G_loss: -0.4996; recon_loss: 0.02317
Iter-0; D_loss: 0.0006011; G_loss: -0.5003; recon_loss: 0.02174
Iter-10; D_loss: -6.098e-05; G_loss: -0.5002; recon_loss: 0.02396
Iter-0; D_loss: -0.0002211; G_loss: -0.5004; recon_loss: 0.02477
Iter-10; D_loss: -0.0005088; G_loss: -0.4997; recon_loss: 0.02858
Iter-0; D_loss: -2.801e-05; G_loss: -0.4998; recon_loss: 0.02062
Iter-10; D_loss: -0.0003877; G_loss: -0.5004; recon_loss: 0.02221
Iter-0; D_loss: 0.0009018; G_loss: -0.5009; recon_loss: 0.02028
Iter-10; D_loss: -0.001181; G_loss: -0.5; recon_loss: 0.02466
Iter-0; D_loss: -0.0003709; G_loss: -0.4995; recon_loss: 0.0215
Iter-10; D_loss: -8.523e-06; G_los

Iter-10; D_loss: -0.0002711; G_loss: -0.4999; recon_loss: 0.005188
Iter-0; D_loss: -0.0009394; G_loss: -0.4992; recon_loss: 0.005359
Iter-10; D_loss: -0.000649; G_loss: -0.4994; recon_loss: 0.005381
Iter-0; D_loss: -0.0005788; G_loss: -0.4996; recon_loss: 0.004466
Iter-10; D_loss: 0.0006517; G_loss: -0.4998; recon_loss: 0.006017
Iter-0; D_loss: -0.0005425; G_loss: -0.4998; recon_loss: 0.004339
Iter-10; D_loss: -7.48e-05; G_loss: -0.5001; recon_loss: 0.005124
Iter-0; D_loss: -0.0003549; G_loss: -0.4999; recon_loss: 0.004559
Iter-10; D_loss: -0.0009453; G_loss: -0.4996; recon_loss: 0.005507
Iter-0; D_loss: -0.000959; G_loss: -0.4993; recon_loss: 0.004247
Iter-10; D_loss: -0.0009578; G_loss: -0.4995; recon_loss: 0.00548
Iter-0; D_loss: -0.0002481; G_loss: -0.4996; recon_loss: 0.004619
Iter-10; D_loss: -0.001021; G_loss: -0.4993; recon_loss: 0.006359
Iter-0; D_loss: -0.00062; G_loss: -0.4992; recon_loss: 0.005819
Iter-10; D_loss: -0.0001479; G_loss: -0.5; recon_loss: 0.004683
Iter-0; D_los

Iter-10; D_loss: 0.000133; G_loss: -0.4963; recon_loss: 0.004327
Iter-0; D_loss: 9.733e-05; G_loss: -0.4964; recon_loss: 0.00515
Iter-10; D_loss: -0.000183; G_loss: -0.4958; recon_loss: 0.004527
Iter-0; D_loss: -0.0002408; G_loss: -0.4963; recon_loss: 0.004701
Iter-10; D_loss: -0.0001723; G_loss: -0.496; recon_loss: 0.00515
Iter-0; D_loss: 0.0001196; G_loss: -0.4964; recon_loss: 0.005012
Iter-10; D_loss: -0.0003027; G_loss: -0.4961; recon_loss: 0.004053
Iter-0; D_loss: 0.0001623; G_loss: -0.4962; recon_loss: 0.005099
Iter-10; D_loss: -0.0004659; G_loss: -0.4959; recon_loss: 0.004323
Iter-0; D_loss: -0.000645; G_loss: -0.4959; recon_loss: 0.004033
Iter-10; D_loss: -0.0003343; G_loss: -0.496; recon_loss: 0.005517
Iter-0; D_loss: 1.422e-05; G_loss: -0.4966; recon_loss: 0.006509
Iter-10; D_loss: -0.0006765; G_loss: -0.4958; recon_loss: 0.004877
Iter-0; D_loss: -0.0006924; G_loss: -0.496; recon_loss: 0.005024
Iter-10; D_loss: -0.000477; G_loss: -0.496; recon_loss: 0.004858
Iter-0; D_loss: -

Iter-0; D_loss: -0.0009177; G_loss: -0.5007; recon_loss: 0.03712
Iter-10; D_loss: 0.0007477; G_loss: -0.5008; recon_loss: 0.03815
Iter-0; D_loss: 6.515e-05; G_loss: -0.501; recon_loss: 0.03311
Iter-10; D_loss: 0.000545; G_loss: -0.5009; recon_loss: 0.03163
Iter-0; D_loss: -0.0004584; G_loss: -0.5002; recon_loss: 0.03586
Iter-10; D_loss: -0.0004922; G_loss: -0.5007; recon_loss: 0.03476
Iter-0; D_loss: -0.00149; G_loss: -0.5003; recon_loss: 0.02793
Iter-10; D_loss: -0.001318; G_loss: -0.4996; recon_loss: 0.02578
Iter-0; D_loss: -0.001464; G_loss: -0.5002; recon_loss: 0.02753
Iter-10; D_loss: -0.002145; G_loss: -0.4986; recon_loss: 0.02581
Iter-0; D_loss: -0.001708; G_loss: -0.4996; recon_loss: 0.0281
Iter-10; D_loss: -0.001363; G_loss: -0.4999; recon_loss: 0.02523
Iter-0; D_loss: 0.001025; G_loss: -0.5003; recon_loss: 0.0234
Iter-10; D_loss: -0.0003969; G_loss: -0.5011; recon_loss: 0.02341
Iter-0; D_loss: 0.0005327; G_loss: -0.502; recon_loss: 0.0212
Iter-10; D_loss: -0.0003549; G_loss: 

Iter-0; D_loss: 9.537e-06; G_loss: -0.5001; recon_loss: 0.003484
Iter-10; D_loss: -0.001213; G_loss: -0.4986; recon_loss: 0.003948
Iter-0; D_loss: -0.001411; G_loss: -0.4988; recon_loss: 0.003219
Iter-10; D_loss: -0.001005; G_loss: -0.4995; recon_loss: 0.00302
Iter-0; D_loss: 0.0001152; G_loss: -0.5; recon_loss: 0.002883
Iter-10; D_loss: -0.0004669; G_loss: -0.4997; recon_loss: 0.002993
Iter-0; D_loss: -0.001182; G_loss: -0.4996; recon_loss: 0.002831
Iter-10; D_loss: -0.0007913; G_loss: -0.4993; recon_loss: 0.003275
Iter-0; D_loss: 0.0001572; G_loss: -0.4996; recon_loss: 0.002645
Iter-10; D_loss: 0.0001907; G_loss: -0.5001; recon_loss: 0.003027
Iter-0; D_loss: -0.0002598; G_loss: -0.5002; recon_loss: 0.002961
Iter-10; D_loss: -7.117e-05; G_loss: -0.4998; recon_loss: 0.003683
Iter-0; D_loss: -0.0008948; G_loss: -0.4994; recon_loss: 0.003159
Iter-10; D_loss: -0.000206; G_loss: -0.4998; recon_loss: 0.002835
Iter-0; D_loss: -0.0003042; G_loss: -0.4998; recon_loss: 0.002194
Iter-10; D_loss:

Iter-10; D_loss: -0.001104; G_loss: -0.5009; recon_loss: 0.008588
Iter-0; D_loss: -0.001906; G_loss: -0.5008; recon_loss: 0.007352
Iter-10; D_loss: -0.003235; G_loss: -0.4997; recon_loss: 0.005889
Iter-0; D_loss: -0.0004398; G_loss: -0.501; recon_loss: 0.007292
Iter-10; D_loss: -3.35e-05; G_loss: -0.501; recon_loss: 0.007487
Iter-0; D_loss: 0.001135; G_loss: -0.5032; recon_loss: 0.009266
Iter-10; D_loss: -6.849e-05; G_loss: -0.5013; recon_loss: 0.006768
Iter-0; D_loss: 0.00102; G_loss: -0.501; recon_loss: 0.008853
Iter-10; D_loss: -0.001381; G_loss: -0.5003; recon_loss: 0.007702
Iter-0; D_loss: -0.001917; G_loss: -0.4997; recon_loss: 0.006833
Iter-10; D_loss: -0.001148; G_loss: -0.5005; recon_loss: 0.005759
Iter-0; D_loss: 0.0006082; G_loss: -0.5015; recon_loss: 0.006352
Iter-10; D_loss: -0.0004352; G_loss: -0.5005; recon_loss: 0.006798
Iter-0; D_loss: -0.001148; G_loss: -0.5017; recon_loss: 0.007473
Iter-10; D_loss: 9.775e-06; G_loss: -0.5015; recon_loss: 0.005546
Iter-0; D_loss: -5.2

Iter-0; D_loss: 0.004821; G_loss: -0.506; recon_loss: 0.07168
Iter-10; D_loss: 0.005461; G_loss: -0.5069; recon_loss: 0.05672
Iter-0; D_loss: 0.002835; G_loss: -0.5041; recon_loss: 0.05895
Iter-10; D_loss: 0.002769; G_loss: -0.5039; recon_loss: 0.05175
Iter-0; D_loss: 0.0025; G_loss: -0.504; recon_loss: 0.05058
Iter-10; D_loss: 0.001924; G_loss: -0.5027; recon_loss: 0.0541
Iter-0; D_loss: 0.002284; G_loss: -0.5036; recon_loss: 0.04807
Iter-10; D_loss: 0.002851; G_loss: -0.503; recon_loss: 0.03627
Iter-0; D_loss: -0.001912; G_loss: -0.4989; recon_loss: 0.03341
Iter-10; D_loss: -0.002346; G_loss: -0.4989; recon_loss: 0.03497
Iter-0; D_loss: 0.001643; G_loss: -0.502; recon_loss: 0.03876
Iter-10; D_loss: 0.00077; G_loss: -0.5018; recon_loss: 0.03466
Iter-0; D_loss: -0.0006942; G_loss: -0.4999; recon_loss: 0.02969
Iter-10; D_loss: -0.0006259; G_loss: -0.5004; recon_loss: 0.02797
Iter-0; D_loss: -0.001068; G_loss: -0.4997; recon_loss: 0.02716
Iter-10; D_loss: 0.0008076; G_loss: -0.5015; reco

Iter-0; D_loss: -0.0004386; G_loss: -0.5007; recon_loss: 0.003881
Iter-10; D_loss: -0.0003961; G_loss: -0.5006; recon_loss: 0.003616
Iter-0; D_loss: -0.0005752; G_loss: -0.5007; recon_loss: 0.004316
Iter-10; D_loss: -0.001098; G_loss: -0.5001; recon_loss: 0.004361
Iter-0; D_loss: -0.0004091; G_loss: -0.5007; recon_loss: 0.004476
Iter-10; D_loss: -0.0001675; G_loss: -0.501; recon_loss: 0.003748
Iter-0; D_loss: -0.0001357; G_loss: -0.5006; recon_loss: 0.003922
Iter-10; D_loss: 0.0001173; G_loss: -0.5004; recon_loss: 0.004599
Iter-0; D_loss: -0.000256; G_loss: -0.501; recon_loss: 0.004501
Iter-10; D_loss: -0.0003771; G_loss: -0.5006; recon_loss: 0.005062
Iter-0; D_loss: 0.0002708; G_loss: -0.5009; recon_loss: 0.004305
Iter-10; D_loss: -0.0001433; G_loss: -0.5003; recon_loss: 0.005375
Iter-0; D_loss: -0.0006819; G_loss: -0.5002; recon_loss: 0.004145
Iter-10; D_loss: -0.000315; G_loss: -0.5001; recon_loss: 0.0047
Iter-0; D_loss: -0.000942; G_loss: -0.4999; recon_loss: 0.004094
Iter-10; D_lo