In [1]:
import numpy as np
import torch
from torch import nn,optim,autograd
from torch.nn import functional as F 
import visdom
import random
from matplotlib import pyplot as plt

In [2]:
h_dim=400
batchsz=512
viz=visdom.Visdom() 

Setting up a new session...


In [3]:
class Generator(nn.Module):
    
    def __init__(self):
        super(Generator,self).__init__()
        
        self.net=nn.Sequential(
            nn.Linear(2,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,2)
        )
        
    def forward(self,z):
        output=self.net(z)
        return output
    
    
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator,self).__init__()
        
        self.net=nn.Sequential(
            nn.Linear(2,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim,1),
            nn.Sigmoid()
        )
        
    def forward(self,x):
        output=self.net(x)
        return output.view(-1)

In [4]:
def data_generator():
    
    scale =2. 
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * .02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # stdev
        yield dataset

In [6]:
def weight_init(m):
    
    if isinstance(m,nn.Linear):
        # m.weight.data.normal_(0.0,0.02)
        nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)

In [7]:
def main():
    torch.manual_seed(23)
    np.random.seed(23)
    
    G=Generator().cuda()
    D=Discriminator().cuda()
    G.apply(weight_init)
    D.apply(weight_init)
    
    optim_G=optim.Adam(G.parameters(),lr=1e-3,betas=(0.5,0.9))
    optim_D=optim.Adam(D.parameters(),lr=1e-3,betas=(0.5,0.9))
    
    data_iter=data_generator()
    print('batch:',next(data_iter).shape)
    
    viz.line([[0,0]],[0],win='loss',opts=dict(title='loss',legend=['D','G']))
    
    for epoch in range(50000):
        
        # 1.train discriminator for k steps
        for _ in range(5):
            x=next(data_iter)
            xr=torch.from_numpy(x).cuda()
            
            # [b]
            predr=(D(xr))
            # max log(lossr)
            lossr= -(predr.mean())
            
            # train on fake data
            # [b,2]
            z=torch.randn(batchsz,2).cuda()
            # stop gradient on G
            # [b,2]
            xf=G(z).detach() # stop gradient
            # [b]
            predf=(D(xf))
            # min predf
            lossf=(predf.mean())
            
            # aggregate all
            loss_D=lossr+lossf
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
            
        # 2.train Genarator
        z=torch.randn(batchsz,2).cuda()
        xf=G(z)
        predf=(D(xf))
        # max predf
        loss_G= -(predf.mean())
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()
            
        if epoch % 100 ==0: 
            viz.line([[loss_D.item(),loss_G.item()]],[epoch],win='loss',update='append')
            
            print(loss_D.item(),loss_G.item())

In [8]:
if __name__=='__main__':
    main()

batch: (512, 2)
-0.4008290767669678 -0.3538018465042114
5.960464477539063e-08 -0.9999996423721313
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0
0.0 -1.0


KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>