In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'd:/samples'

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [5]:
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),std=(0.5, 0.5, 0.5))])

In [6]:
mnist = torchvision.datasets.MNIST(root='d:/MNIST/',
                                   train=True,
                                   transform=transform,
                                   download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [7]:
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)

In [11]:
D = nn.Sequential(
    nn.Linear(image_size,hidden_size),
    nn.LeakyReLU(.2),
    nn.Linear(hidden_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid()).to(device)

G = nn.Sequential(
    nn.Linear(latent_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,image_size),
    nn.Tanh()).to(device)

In [12]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

In [13]:
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

In [22]:
total_step = len(data_loader)
for epoch in range(num_epochs):
    for idx,(images,_) in enumerate(data_loader):
        images = images.reshape(batch_size,-1).to(device)
        
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        #------------DISCRIMINATOR-------------------------
        outputs = D(images)
        real_loss = criterion(outputs,real_labels)
        
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        fake_loss = criterion(outputs,fake_labels)
        
        d_loss = real_loss + fake_loss
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        #-------------Generator---------------------------
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        g_loss = criterion(outputs,real_labels)
        g_loss.backward()
        g_optimizer.step()
        
        if (idx+1) % 100 == 0:
            print(idx+1, total_step, d_loss.item(), g_loss.item())
            
    if epoch%5==0:
        fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
        save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

100 600 0.012241052463650703 5.728859901428223
200 600 0.01560187991708517 5.394524574279785
300 600 0.01006968505680561 6.615293979644775
400 600 0.007493121549487114 6.205960750579834
500 600 0.015997696667909622 5.115780830383301
600 600 0.019362976774573326 6.4097208976745605
100 600 0.02448098734021187 6.327150821685791
200 600 0.028336849063634872 5.604171276092529
300 600 0.082874596118927 7.147561550140381
400 600 0.0808338075876236 4.517849922180176
500 600 0.34115099906921387 6.846116542816162
600 600 0.0476398840546608 6.767890453338623
100 600 0.17276516556739807 3.8496320247650146
200 600 0.2278597503900528 4.556034088134766
300 600 0.25600844621658325 4.8330559730529785
400 600 0.44862455129623413 5.88062858581543
500 600 0.6797982454299927 3.5120224952697754
600 600 0.7296677231788635 3.0891616344451904
100 600 0.3494322896003723 3.349128484725952
200 600 0.5206717252731323 5.005690097808838
300 600 0.6268661618232727 4.6981682777404785
400 600 0.21782223880290985 4.5921

500 600 0.2684377431869507 4.209503173828125
600 600 0.3678598999977112 5.823183536529541
100 600 0.3849654197692871 3.6394660472869873
200 600 0.41577085852622986 4.256420135498047
300 600 0.23702286183834076 4.056980609893799
400 600 0.39605849981307983 4.16572904586792
500 600 0.577099084854126 2.6691315174102783
600 600 0.5744550228118896 2.7359120845794678
100 600 0.45633307099342346 3.4203689098358154
200 600 0.34383323788642883 4.734043121337891
300 600 0.3679172396659851 4.470970630645752
400 600 0.31219595670700073 4.767660617828369
500 600 0.43386849761009216 3.1205272674560547
600 600 0.6333253383636475 3.382389783859253
100 600 0.5066372156143188 3.496919870376587
200 600 0.33040761947631836 3.483180046081543
300 600 0.3530440628528595 3.658839225769043
400 600 0.3404299020767212 2.5893776416778564
500 600 0.4082452058792114 2.773582696914673
600 600 0.42346179485321045 2.8717477321624756
100 600 0.6381393671035767 3.0425963401794434
200 600 0.46984562277793884 3.3589217662

400 600 0.4081909656524658 2.589768409729004
500 600 0.7332768440246582 1.957647442817688
600 600 0.46015849709510803 2.895503282546997
100 600 0.5547909736633301 2.623548984527588
200 600 0.4659869074821472 2.419830322265625
300 600 0.5224642157554626 2.7527246475219727
400 600 0.4307149648666382 3.294341802597046
500 600 0.35783183574676514 2.535930871963501
600 600 0.46208035945892334 2.8016674518585205
100 600 0.41231781244277954 3.1813485622406006
200 600 0.42454656958580017 4.556445121765137
300 600 0.3976703882217407 2.4108965396881104
400 600 0.36639297008514404 2.3626418113708496
500 600 0.4453986883163452 2.8119254112243652
600 600 0.5366072654724121 3.9463837146759033
100 600 0.38625210523605347 3.6551733016967773
200 600 0.4713834524154663 3.901144504547119
300 600 0.3995468020439148 2.9084084033966064
400 600 0.5220706462860107 4.4046311378479
500 600 0.45284542441368103 2.969169855117798
600 600 0.5105085968971252 3.7193527221679688
100 600 0.5273188352584839 2.2667431831

300 600 0.6240241527557373 2.753700017929077
400 600 0.5668430924415588 2.3514368534088135
500 600 0.6179921627044678 1.5094621181488037
600 600 0.7787387371063232 1.2971781492233276
100 600 0.5596461296081543 1.8670634031295776
200 600 0.5396162271499634 1.567765712738037
300 600 0.6496707201004028 2.2078332901000977
400 600 0.7748672962188721 2.2609000205993652
500 600 0.7148561477661133 2.0183544158935547
600 600 0.9017795324325562 1.7705230712890625
100 600 0.5255138278007507 2.632322072982788
200 600 0.5191242098808289 2.449017286300659
300 600 0.5377441644668579 2.1085712909698486
400 600 0.6011058688163757 2.2717063426971436
500 600 0.46374765038490295 2.9865362644195557
600 600 0.7129180431365967 1.7814042568206787
100 600 0.6628451347351074 1.4624725580215454
200 600 0.5979446172714233 1.9321545362472534
300 600 0.5296754240989685 2.227856397628784
400 600 0.6768769025802612 1.5717779397964478
500 600 0.6336208581924438 1.6926002502441406
600 600 0.5061937570571899 2.367020368

200 600 0.6245473027229309 1.8911477327346802
300 600 0.6250484585762024 1.6594959497451782
400 600 0.5169390439987183 1.9337844848632812
500 600 0.6847496032714844 2.1469717025756836
600 600 0.6590530872344971 2.1933600902557373
100 600 0.5210105180740356 2.3782825469970703
200 600 0.78482985496521 2.0981080532073975
300 600 0.5855755805969238 1.5679110288619995
400 600 0.4560815691947937 1.6975886821746826
500 600 0.5794359445571899 2.092033624649048
600 600 0.523462176322937 2.014965534210205
100 600 0.7840739488601685 2.233156681060791
200 600 0.43409454822540283 2.293928384780884
300 600 0.6679710745811462 1.9925576448440552
400 600 0.4641237258911133 2.119900703430176
500 600 0.6476088762283325 2.0833568572998047
600 600 0.5511729121208191 1.9162194728851318
100 600 0.7501928806304932 1.5880404710769653
200 600 0.528354287147522 2.489081859588623
300 600 0.5651412606239319 2.0790252685546875
400 600 0.7106072902679443 2.108635663986206
500 600 0.5954843759536743 1.398903250694275

100 600 0.8059043884277344 1.441493034362793
200 600 0.5489091277122498 1.9122309684753418
300 600 0.5744408369064331 1.538828730583191
400 600 0.6800681948661804 2.4020285606384277
500 600 0.5342317819595337 1.6301908493041992
600 600 0.6511502861976624 1.975938081741333
100 600 0.6379017233848572 1.468636155128479
200 600 0.7495068907737732 2.066807508468628
300 600 0.5472633838653564 1.5946154594421387
400 600 0.588976263999939 1.9069035053253174
500 600 0.6117066144943237 1.74527907371521
600 600 0.6404586434364319 1.6761984825134277
100 600 0.8216403722763062 1.742174506187439
200 600 0.6643209457397461 1.5419137477874756
300 600 0.5875798463821411 1.7070896625518799
400 600 0.6993129253387451 1.8498018980026245
500 600 0.5248762965202332 2.252140522003174
600 600 0.7152931094169617 1.7848761081695557
100 600 0.48423731327056885 1.756818413734436
200 600 0.5304184556007385 2.2942018508911133
300 600 0.738967776298523 1.9390242099761963
400 600 0.6623252630233765 2.1401400566101074

100 600 0.5761970281600952 1.542922019958496
200 600 0.703243613243103 2.093426465988159
300 600 0.6176347732543945 1.7047940492630005
400 600 0.5850185751914978 1.487375020980835
500 600 0.7270022630691528 1.798017978668213
600 600 0.5452324151992798 1.586493968963623
100 600 0.658220648765564 1.6450690031051636
200 600 0.6541752219200134 1.8619710206985474
300 600 0.646456778049469 1.6302038431167603
400 600 0.5506955981254578 1.8970270156860352
500 600 0.7190794944763184 1.6999106407165527
600 600 0.4509534239768982 1.7432595491409302
100 600 0.6719924211502075 2.154324769973755
200 600 0.7145189642906189 1.2284010648727417
300 600 0.5863392949104309 2.082916021347046
400 600 0.554200291633606 2.5766491889953613
500 600 0.6765995621681213 1.8529303073883057
600 600 0.47587984800338745 1.9713330268859863
100 600 0.5182768702507019 2.2609992027282715
200 600 0.5782467722892761 2.332130193710327
300 600 0.7069176435470581 1.5140316486358643
400 600 0.6805291771888733 1.6664410829544067

In [17]:
images.size()

torch.Size([100])