In [None]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import TensorDataset, DataLoader

In [None]:
from keras.datasets import mnist

# Load the MNIST train/test splits through Keras and unpack each split
# into its (images, labels) pair.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
# Rescale the images by 1/255 (presumably 0-255 intensities -> [0, 1]; confirm
# against the loader) and store images and labels as float32.
# NOTE(review): Generator.forward ends in tanh, i.e. values in [-1, 1], while the
# real images are only divided by 255 here -- confirm the two ranges are meant
# to differ.
x_train = np.divide(x_train, 255.0).astype(np.float32)
x_test = np.divide(x_test, 255.0).astype(np.float32)
y_train = y_train.astype(np.float32)
y_test = y_test.astype(np.float32)

# Copy the NumPy arrays into torch tensors.
x_train_tensor, y_train_tensor = torch.tensor(x_train), torch.tensor(y_train)
x_test_tensor, y_test_tensor = torch.tensor(x_test), torch.tensor(y_test)

# Pair every image with its label.
train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
test_dataset = TensorDataset(x_test_tensor, y_test_tensor)

# One sample per batch; only the training loader shuffles.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1)

In [None]:
class Discriminator(torch.nn.Module):
  """Convolutional real/fake classifier for single-channel 28x28 images.

  Two 5x5 convolutions (1 -> 4 -> 8 channels), each followed by ReLU and a
  2x2 max-pool, feed three linear layers (128 -> 256 -> 10 -> 1). Dropout
  (p=0.25) is applied after the first linear layer and a sigmoid squashes the
  final score into [0, 1].
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    self.conv1 = torch.nn.Conv2d(1,4,5)
    self.conv2 = torch.nn.Conv2d(4,8,5)

    self.pool = torch.nn.MaxPool2d(2,2)

    # Flatten everything except the batch axis. Flattening from dim 0 would
    # merge the batch into the features and only work for one image at a time.
    self.flatten = torch.nn.Flatten(start_dim=1)
    # 28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4, with 8 channels.
    self.fc3 = torch.nn.Linear(4*4*8,256)
    self.fc4 = torch.nn.Linear(256,10)
    self.fc5 = torch.nn.Linear(10,1)

    self.dropout = torch.nn.Dropout(0.25)

  def forward(self,x):
    """Score how likely each input image is to be real.

    Args:
      x: a single image ``(28, 28)``, a batch without channel axis
        ``(N, 28, 28)`` or a full batch ``(N, 1, 28, 28)``.

    Returns:
      Tensor of shape ``(N,)`` holding one sigmoid probability per image
      (``(1,)`` for a single image).

    Raises:
      ValueError: if ``x`` is not 2-, 3- or 4-dimensional.
    """
    # Normalise every accepted layout to (N, 1, H, W) before the convolutions.
    if x.dim() == 2:
      x = x.unsqueeze(0).unsqueeze(0)
    elif x.dim() == 3:
      x = x.unsqueeze(1)
    elif x.dim() != 4:
      raise ValueError(
        "expected input of shape (H, W), (N, H, W) or (N, 1, H, W), got "
        + str(tuple(x.shape)))

    x = F.relu(self.conv1(x))
    x = self.pool(x)
    x = F.relu(self.conv2(x))
    x = self.pool(x)
    x = self.flatten(x)
    x = F.relu(self.fc3(x))
    x = self.dropout(x)
    x = F.relu(self.fc4(x))
    x = torch.sigmoid(self.fc5(x))
    # (N, 1) -> (N,): a lone image yields shape (1,), matching a (1,) target.
    return x.squeeze(-1)

In [None]:
class Generator(torch.nn.Module):
  """Fully connected generator mapping a 100-d latent vector to a 1x28x28 image.

  Three linear layers (100 -> 256 -> 512 -> 784) with ReLU between them and a
  tanh on the output, so generated pixel values lie in [-1, 1].
  """

  def __init__(self):
    super(Generator, self).__init__()
    self.fc1 = torch.nn.Linear(100,256)
    self.fc2 = torch.nn.Linear(256,512)
    self.fc3 = torch.nn.Linear(512,784)

  def forward(self, z=None):
    """Generate image(s) from latent noise.

    Args:
      z: optional latent tensor whose last dimension equals
        ``fc1.in_features`` (100). Shape ``(100,)`` yields one image and
        ``(N, 100)`` yields a batch. When omitted, a single standard-normal
        vector is sampled, as before.

    Returns:
      Tensor of shape ``(1, 28, 28)`` for a single latent vector, or
      ``(N, 1, 28, 28)`` for a batch of ``N`` vectors.
    """
    if z is None:
      # Sample on the same device/dtype as the weights so the model keeps
      # working after ``.to(device)``.
      ref = self.fc1.weight
      n_latent = self.fc1.in_features
      z = torch.normal(
        torch.zeros(n_latent, device=ref.device, dtype=ref.dtype),
        torch.ones(n_latent, device=ref.device, dtype=ref.dtype))
    x = F.relu(self.fc1(z))
    x = F.relu(self.fc2(x))
    x = torch.tanh(self.fc3(x))
    # Keep any leading batch dimensions of z and append (channel, height, width).
    x = x.view(*z.shape[:-1], 1, 28, 28)
    return x


In [None]:
# Training configuration.
NUM_EPOCHS = 200
LEARNING_RATE = 0.0001
LOG_EVERY = 1000  # record the losses once every LOG_EVERY training steps

dis = Discriminator()
gen = Generator()

criterion = torch.nn.BCELoss()
optimizer_dis = optim.Adam(dis.parameters(),lr=LEARNING_RATE)
optimizer_gen = optim.Adam(gen.parameters(),lr=LEARNING_RATE)
real_label = torch.tensor([1],dtype=torch.float32)
fake_label = torch.tensor([0],dtype=torch.float32)
dis_loss = []
gen_loss = []

# plt.imsave does not create missing directories.
os.makedirs("./images", exist_ok=True)

for epochs in range(NUM_EPOCHS):
  print(epochs)
  # Save one generated sample at the start of every epoch to track progress.
  img = gen().detach().reshape((28,28))
  plt.imsave("./images/"+str(epochs)+".png",img)
  for step, (x_real, _) in enumerate(train_dataloader, start=1):
    x_fake = gen()

    # Discriminator step: the fake is detached so this backward pass stops at
    # the discriminator and leaves the generator's graph intact for reuse.
    optimizer_dis.zero_grad()
    pred_real = dis(x_real)
    pred_fake = dis(x_fake.detach())
    loss_dis = criterion(pred_real,real_label)+criterion(pred_fake,fake_label)
    loss_dis.backward()
    optimizer_dis.step()

    # Generator step: score the NON-detached fake with the updated
    # discriminator so the gradient flows back into the generator's weights.
    # Any gradient this leaves on the discriminator is cleared by the next
    # optimizer_dis.zero_grad().
    optimizer_gen.zero_grad()
    pred_fake = dis(x_fake)
    loss_gen = criterion(pred_fake,real_label)
    loss_gen.backward()
    optimizer_gen.step()

    if step % LOG_EVERY == 0:
      # Store plain floats: keeping the tensors would retain their autograd
      # graphs and they could not be plotted directly.
      dis_loss.append(loss_dis.item())
      gen_loss.append(loss_gen.item())

0
1
2
3
4
5


In [None]:
# Draw one sample from the generator, cut it off from autograd and show it
# as a 28x28 image. Requires `gen` from the training cell to exist in the kernel.
sample = gen()
img = sample.detach().reshape((28, 28))
plt.imshow(img)

NameError: ignored