In [1]:
from pythonfiles.data_class import *
from pythonfiles.generator import GenModel
from pythonfiles.discriminator import DisModel
from pythonfiles.model_utility import *


from torchsummary import summary
import torch


In [2]:
hparams = {
    'z_shape': 128,
    'gen_final_layer_size': 3,
    'image_input_shape': 64,
    'batch_size': 2,
    'epochs': 5
}


In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using the  Device - ", device)

Using the  Device -  cuda:0


In [4]:

gen_model = GenModel(hparams)
dis_model = DisModel(hparams)

dis_model = dis_model.apply(init_weights)
gen_model = gen_model.apply(init_weights)

dis_model = dis_model.to(device)
gen_model = gen_model.to(device)

In [5]:

gen_optimizer = torch.optim.Adadelta(gen_model.parameters())
dis_optimizer = torch.optim.Adadelta(dis_model.parameters())
loss_function = torch.nn.BCELoss(reduction='mean')

In [6]:
from torchsummary import summary
summary(dis_model,(3,64,64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 32, 32]             294
       BatchNorm2d-2            [-1, 6, 32, 32]              12
             PReLU-3            [-1, 6, 32, 32]               1
            Conv2d-4           [-1, 12, 16, 16]           1,164
       BatchNorm2d-5           [-1, 12, 16, 16]              24
             PReLU-6           [-1, 12, 16, 16]               1
            Conv2d-7             [-1, 24, 8, 8]           4,632
       BatchNorm2d-8             [-1, 24, 8, 8]              48
             PReLU-9             [-1, 24, 8, 8]               1
           Conv2d-10             [-1, 48, 2, 2]          18,480
      BatchNorm2d-11             [-1, 48, 2, 2]              96
            PReLU-12             [-1, 48, 2, 2]               1
           Conv2d-13              [-1, 1, 1, 1]             193
          Flatten-14                   

In [7]:
real_dataloader = data_setup(batch_size=hparams['batch_size'])
real_label = torch.ones(hparams['batch_size'], 1,device=device)
fake_label = torch.zeros(hparams['batch_size'], 1,device=device)

In [8]:
dis_fn_real_list=[]
dis_fn_fake_list=[]

dis_loss_list=[]
gen_loss_list=[]

for epoch in range(hparams['epochs']):
    for i, data in enumerate(real_dataloader, 0):

        dis_model.zero_grad()

        #Discriminator update for real data
        data = data[0]
        data = data.to(device)

        real_outputs = dis_model(data)
        dis_fn_real_list.append(torch.mean(real_outputs.detach().cpu()))
        real_loss = loss_function(real_outputs, real_label)
        real_loss.backward()
        running_real_loss = real_loss.item()

        #Discriminator update for fake data

        fake_noise = torch.randn(
            hparams['batch_size'], hparams['z_shape'], 1, 1,device=device)
        fake_image = gen_model(fake_noise)
        fake_outputs = dis_model(fake_image.detach())
        dis_fn_fake_list.append(torch.mean(fake_outputs.detach().cpu()))

        fake_loss = loss_function(fake_outputs, fake_label)
        fake_loss.backward()
        running_fake_loss = fake_loss.item()

        dis_loss_list.append(running_fake_loss+running_real_loss)

        dis_optimizer.step()
        
        #Generator update

        fake_noise = torch.randn(
            hparams['batch_size'], hparams['z_shape'], 1, 1,device=device)
        fake_image = gen_model(fake_noise)
        fake_outputs = dis_model(fake_image.detach())
        
        gen_loss = loss_function(fake_outputs, real_label)
        running_gen_loss = gen_loss.item()
        gen_loss_list.append(running_gen_loss)

        gen_loss.backward()
        gen_optimizer.step()

        

In [9]:
len(dis_fn_fake_list)

25

In [10]:
len(dis_fn_real_list)

25

In [11]:
len(dis_loss_list)

25

In [12]:
len(gen_loss_list)

25

In [13]:
gen_loss_list

[1.3018591403961182,
 0.8219870924949646,
 0.6639832258224487,
 1.9228219985961914,
 2.0108916759490967,
 3.0045371055603027,
 2.7105469703674316,
 2.7447574138641357,
 2.9884958267211914,
 3.725288152694702,
 3.241199493408203,
 3.0003814697265625,
 4.090398788452148,
 2.816197395324707,
 3.771660327911377,
 3.9374473094940186,
 3.363403558731079,
 3.5021839141845703,
 5.235326766967773,
 3.729485511779785,
 4.513854503631592,
 4.693154335021973,
 4.2474751472473145,
 4.851113319396973,
 4.82398796081543]

In [14]:
dis_loss_list

[2.0547125339508057,
 1.5137582421302795,
 1.273125559091568,
 0.8834559917449951,
 0.30197034776210785,
 0.44785988703370094,
 0.6009686514735222,
 0.11387187242507935,
 0.21805571019649506,
 0.14441107772290707,
 0.17371631041169167,
 0.08864334970712662,
 0.06114707142114639,
 0.07172593660652637,
 0.21339410450309515,
 0.03332630358636379,
 0.07110896334052086,
 0.09051743149757385,
 0.030964283272624016,
 0.07766655273735523,
 0.017642113380134106,
 0.021891257259994745,
 0.1025232058018446,
 0.04282413795590401,
 0.02709413692355156]