In [1]:
import os, time
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable


In [2]:
# G(z)
class generator_weight(nn.Module):
    # initializers
    def __init__(self):
        super(generator_weight, self).__init__()
        self.w_11 = nn.Parameter(torch.tensor(torch.rand(8,4), requires_grad=True))
        self.w_12 = nn.Parameter(torch.tensor(torch.rand(4,8), requires_grad=True))

        self.w_21 = nn.Parameter(torch.tensor(torch.rand(8,4), requires_grad=True))
        self.w_22 = nn.Parameter(torch.tensor(torch.rand(4,8), requires_grad=True))

        self.w_31 = nn.Parameter(torch.tensor(torch.rand(16,8), requires_grad=True))
        self.w_32 = nn.Parameter(torch.tensor(torch.rand(8,16), requires_grad=True))

        self.w_41 = nn.Parameter(torch.tensor(torch.rand(28,16), requires_grad=True))
        self.w_42 = nn.Parameter(torch.tensor(torch.rand(16,28), requires_grad=True))

    # forward method
    def forward(self, input, label):
        x = self.w_11 @ input @ self.w_12
        y = self.w_21 @ label @ self.w_22
        z = torch.cat([x.unsqueeze(1), y.unsqueeze(1)], dim=1)
        z = torch.mean(z, dim=1)
        z = self.w_31 @ z @ self.w_32
        z = self.w_41 @ z @ self.w_42
        return z

In [4]:
def normal_init(m, mean, std):
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

In [5]:

class discriminator(nn.Module):
    # initializers
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1_1 = nn.Linear(784, 1024)
        self.fc1_2 = nn.Linear(10, 1024)
        self.fc2 = nn.Linear(2048, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.fc3_bn = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 1)

    # weight_init
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    # forward method
    def forward(self, input, label):
        x = F.leaky_relu(self.fc1_1(input), 0.2)
        y = F.leaky_relu(self.fc1_2(label), 0.2)
        x = torch.cat([x, y], 1)
        x = F.leaky_relu(self.fc2_bn(self.fc2(x)), 0.2)
        x = F.leaky_relu(self.fc3_bn(self.fc3(x)), 0.2)
        x = F.sigmoid(self.fc4(x))

        return x


In [6]:
def lable2mat4(y):
    matrix = torch.nn.functional.pad(y, (0, 6), value=0)
    matrix = matrix.view(4, 4)
    return matrix

In [7]:
temp_z_ = torch.rand(10, 4,4)
fixed_z_ = temp_z_
fixed_y_ = torch.zeros(10, 1)
for i in range(9):
    fixed_z_ = torch.cat([fixed_z_, temp_z_], 0)
    temp = torch.ones(10,1) + i
    fixed_y_ = torch.cat([fixed_y_, temp], 0)

In [8]:

fixed_z_ = Variable(fixed_z_.cuda(), volatile=True)
fixed_y_label_ = torch.zeros(100, 10)
fixed_y_label_.scatter_(1, fixed_y_.type(torch.LongTensor), 1)
fixed_y_label_ = Variable(fixed_y_label_.cuda(), volatile=True)
fixed_y_label__mat_ =  torch.stack([lable2mat4(y) for y  in fixed_y_label_])

  fixed_z_ = Variable(fixed_z_.cuda(), volatile=True)
  fixed_y_label_ = Variable(fixed_y_label_.cuda(), volatile=True)


In [9]:
fixed_z_.shape

torch.Size([100, 4, 4])

In [10]:

def show_result(num_epoch, show = False, save = False, path = 'result.png'):

    G.eval()
    test_images = G(fixed_z_, fixed_y_label__mat_)
    G.train()

    size_figure_grid = 10
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    for k in range(10*10):
        i = k // 10
        j = k % 10
        ax[i, j].cla()
        ax[i, j].imshow(test_images[k].cpu().data.view(28, 28).numpy(), cmap='gray')

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()


In [11]:

def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    x = range(len(hist['D_losses']))

    y1 = hist['D_losses']
    y2 = hist['G_losses']

    plt.plot(x, y1, label='D_loss')
    plt.plot(x, y2, label='G_loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()


In [12]:

# training parameters
batch_size = 128
lr = 0.0002
train_epoch = 50


In [13]:

# data_loader
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, ), std=(0.5, ))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)


In [14]:

# network
G = generator_rezhghi()
D = discriminator()
D.weight_init(mean=0, std=0.02)
G.cuda()
D.cuda()


  self.w_11 = nn.Parameter(torch.tensor(torch.rand(8,4), requires_grad=True))
  self.w_12 = nn.Parameter(torch.tensor(torch.rand(4,8), requires_grad=True))
  self.w_21 = nn.Parameter(torch.tensor(torch.rand(8,4), requires_grad=True))
  self.w_22 = nn.Parameter(torch.tensor(torch.rand(4,8), requires_grad=True))
  self.w_31 = nn.Parameter(torch.tensor(torch.rand(16,8), requires_grad=True))
  self.w_32 = nn.Parameter(torch.tensor(torch.rand(8,16), requires_grad=True))
  self.w_41 = nn.Parameter(torch.tensor(torch.rand(28,16), requires_grad=True))
  self.w_42 = nn.Parameter(torch.tensor(torch.rand(16,28), requires_grad=True))


discriminator(
  (fc1_1): Linear(in_features=784, out_features=1024, bias=True)
  (fc1_2): Linear(in_features=10, out_features=1024, bias=True)
  (fc2): Linear(in_features=2048, out_features=512, bias=True)
  (fc2_bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc3_bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [15]:

# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()

# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))


In [16]:

# results save folder
if not os.path.isdir('MNIST_cGAN_results'):
    os.mkdir('MNIST_cGAN_results')
if not os.path.isdir('MNIST_cGAN_results/Fixed_results'):
    os.mkdir('MNIST_cGAN_results/Fixed_results')


In [17]:

train_hist = {}
train_hist['D_losses'] = []
train_hist['G_losses'] = []
train_hist['per_epoch_ptimes'] = []
train_hist['total_ptime'] = []


In [18]:
! pip install torchmetrics

[31mERROR: Operation cancelled by user[0m[31m
[0m

In [19]:
from torchmetrics.image import StructuralSimilarityIndexMeasure
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)


In [None]:

print('training start!')
start_time = time.time()
for epoch in range(train_epoch):
    D_losses = []
    G_losses = []
    ssim_scores = []

    # learning rate decay
    if (epoch+1) == 30:
        G_optimizer.param_groups[0]['lr'] /= 10
        D_optimizer.param_groups[0]['lr'] /= 10
        print("learning rate change!")

    if (epoch+1) == 40:
        G_optimizer.param_groups[0]['lr'] /= 10
        D_optimizer.param_groups[0]['lr'] /= 10
        print("learning rate change!")

    epoch_start_time = time.time()
    for x_, y_ in train_loader:
        # train discriminator D
        D.zero_grad()
        mini_batch = x_.size()[0]
        if(mini_batch < 128):
            continue
        y_real_ = torch.ones(mini_batch)
        y_fake_ = torch.zeros(mini_batch)
        y_label_ = torch.zeros(mini_batch, 10)
        y_label_.scatter_(1, y_.view(mini_batch, 1), 1)

        x_mat_ = x_
        x_ = x_.view(-1, 28 * 28)
        x_, y_label_, y_real_, y_fake_ = Variable(x_.cuda()), Variable(y_label_.cuda()), Variable(y_real_.cuda()), Variable(y_fake_.cuda())
        D_result = D(x_, y_label_).squeeze()
        D_real_loss = BCE_loss(D_result, y_real_)

        z_ = torch.rand((128, 4,4))
        y_ = (torch.rand(mini_batch, 1) * 10).type(torch.LongTensor)
        y_label_ = torch.zeros(mini_batch, 10)
        y_label_.scatter_(1, y_.view(mini_batch, 1), 1)

        z_, y_label_ = Variable(z_.cuda()), Variable(y_label_.cuda())
        y_lable_mat_ =  torch.stack([lable2mat4(y) for y  in y_label_])

        G_result = G(z_, y_lable_mat_)
        G_result = G_result.view(-1, 28 * 28)
        D_result = D(G_result, y_label_).squeeze()
        D_fake_loss = BCE_loss(D_result, y_fake_)
        D_fake_score = D_result.data.mean()

        D_train_loss = D_real_loss + D_fake_loss

        D_train_loss.backward()
        D_optimizer.step()

        D_losses.append(D_train_loss.data)

        # train generator G
        G.zero_grad()

        z_ = torch.rand((128, 4,4))
        y_ = (torch.rand(mini_batch, 1) * 10).type(torch.LongTensor)
        y_label_ = torch.zeros(mini_batch, 10)
        y_label_.scatter_(1, y_.view(mini_batch, 1), 1)
        z_, y_label_ = Variable(z_.cuda()), Variable(y_label_.cuda())
        y_lable_mat_ =  torch.stack([lable2mat4(y) for y  in y_label_])

        G_result_mat_ = G(z_, y_lable_mat_)
        G_result = G_result_mat_.view(-1, 28 * 28)
        D_result = D(G_result, y_label_).squeeze()
        G_train_loss = BCE_loss(D_result, y_real_)
        G_train_loss.backward()
        G_optimizer.step()

        G_losses.append(G_train_loss.data)

        ssim_score = (float)(ssim( G_result_mat_.unsqueeze(1).cpu() , x_mat_.cpu() ))
        ssim_scores.append(ssim_score)

    epoch_end_time = time.time()
    per_epoch_ptime = epoch_end_time - epoch_start_time


    print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f, ssim_score:%f' % ((epoch + 1), train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(D_losses)),
                                                              torch.mean(torch.FloatTensor(G_losses)),torch.mean(torch.FloatTensor(ssim_scores))))
    fixed_p = 'MNIST_cGAN_results/Fixed_results/MNIST_cGAN_' + str(epoch + 1) + '.png'
    show_result((epoch+1), save=True, path=fixed_p)
    train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
    train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

end_time = time.time()
total_ptime = end_time - start_time
train_hist['total_ptime'].append(total_ptime)


training start!
[1/50] - ptime: 41.00, loss_d: 0.710, loss_g: 1.341, ssim_score:-0.000000
[2/50] - ptime: 33.18, loss_d: 0.180, loss_g: 2.562, ssim_score:-0.000000
[3/50] - ptime: 34.41, loss_d: 0.073, loss_g: 3.505, ssim_score:-0.000000
[4/50] - ptime: 33.21, loss_d: 0.052, loss_g: 3.980, ssim_score:-0.000000
[5/50] - ptime: 33.34, loss_d: 0.021, loss_g: 4.742, ssim_score:-0.000000
[6/50] - ptime: 32.94, loss_d: 0.033, loss_g: 4.547, ssim_score:-0.000000
[7/50] - ptime: 33.38, loss_d: 0.011, loss_g: 5.292, ssim_score:-0.000000
[8/50] - ptime: 34.32, loss_d: 0.019, loss_g: 5.693, ssim_score:-0.000003
[9/50] - ptime: 33.44, loss_d: 0.014, loss_g: 5.323, ssim_score:-0.000039
[10/50] - ptime: 33.59, loss_d: 0.015, loss_g: 5.408, ssim_score:0.000033
[11/50] - ptime: 33.94, loss_d: 0.014, loss_g: 5.784, ssim_score:0.000011
[12/50] - ptime: 34.18, loss_d: 0.005, loss_g: 6.181, ssim_score:-0.000001
[13/50] - ptime: 33.82, loss_d: 0.003, loss_g: 7.062, ssim_score:-0.000003
[14/50] - ptime: 33.

In [None]:

print("Avg one epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), train_epoch, total_ptime))
print("Training finish!... save training results")
torch.save(G.state_dict(), "MNIST_cGAN_results/generator_param.pkl")
torch.save(D.state_dict(), "MNIST_cGAN_results/discriminator_param.pkl")
with open('MNIST_cGAN_results/train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)

show_train_hist(train_hist, save=True, path='MNIST_cGAN_results/MNIST_cGAN_train_hist.png')


In [None]:

images = []
for e in range(train_epoch):
    img_name = 'MNIST_cGAN_results/Fixed_results/MNIST_cGAN_' + str(e + 1) + '.png'
    images.append(imageio.imread(img_name))
imageio.mimsave('MNIST_cGAN_results/generation_animation.gif', images, fps=5)
