In [7]:
import time

# Wall-clock cost of the imports below. Only meaningful on a fresh kernel: on a
# re-run the modules are already cached in sys.modules and this reports ~0 s.
import_start = time.perf_counter()
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.jit as jit
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import_end = time.perf_counter()
print(f'Imports complete in {import_end - import_start} seconds.')

# Seed torch's RNGs so the random weight init and benchmark inputs are reproducible.
SEED = 0
torch.manual_seed(SEED)

out_image_shape = (3, 160, 160)  # Channels - Height - Width
n_latent_vars = 100  # length of the latent (noise) vector fed to the Generator

class Generator(nn.Module):
    """Generator mapping latent vector(s) to image(s) with pixel values in (0, 1).

    A linear layer projects the latent vector to a (30, 20, 20) feature map, and
    three kernel-2 / stride-2 transposed convolutions upsample it
    20 -> 40 -> 80 -> 160, so the output is always 160x160 spatially.
    """

    # Channels and spatial size of the feature map produced by the input projection.
    _FEATURE_CHANNELS = 30
    _FEATURE_SIZE = 20

    def __init__(self, n_latent=None, out_channels=None):
        """Build the layers and apply Kaiming-uniform weight initialisation.

        Args:
            n_latent: length of the latent vector; defaults to the module-level
                ``n_latent_vars``.
            out_channels: channels of the generated image; defaults to
                ``out_image_shape[0]``.
        """
        super().__init__()
        if n_latent is None:
            n_latent = n_latent_vars
        if out_channels is None:
            out_channels = out_image_shape[0]
        channels, size = self._FEATURE_CHANNELS, self._FEATURE_SIZE
        self.input = nn.Linear(n_latent, size * size * channels)
        self.convT1 = nn.ConvTranspose2d(channels, 20, kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.convT2 = nn.ConvTranspose2d(20, 10, kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.convT3 = nn.ConvTranspose2d(10, out_channels, kernel_size=(2, 2), stride=(2, 2), padding=0)
        print('Generator created.')
        self.init_weights()

    def forward(self, x):
        """Generate images from latent input.

        Args:
            x: a single latent vector of shape ``(n_latent,)`` or a batch of shape
                ``(batch, n_latent)``.

        Returns:
            Tensor of shape ``(batch, out_channels, 160, 160)`` with sigmoid
            outputs; ``batch`` is 1 for an unbatched input.
        """
        x = F.relu(self.input(x))
        # -1 infers the batch dimension, so both a single vector and a batch work.
        x = x.view(-1, self._FEATURE_CHANNELS, self._FEATURE_SIZE, self._FEATURE_SIZE)
        x = F.relu(self.convT1(x))
        x = F.relu(self.convT2(x))
        return torch.sigmoid(self.convT3(x))

    def init_weights(self):
        """Kaiming-uniform initialise the weights of every Linear / ConvTranspose2d.

        Biases keep PyTorch's default initialisation.
        """
        for module in self.modules():
            if isinstance(module, (nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_uniform_(module.weight)
        print('Generator weights initialized as per kaiming uniform criterion.')
    
class Discriminator(nn.Module):
    """Convolutional discriminator mapping an image batch to one score per image.

    Three unpadded, stride-1 convolutions are followed by two linear layers; the
    final sigmoid puts each score in (0, 1) (presumably P(image is real), as is
    usual for a GAN discriminator).
    """

    def __init__(self, image_shape=(3, 160, 160)):
        """Build the layers for inputs of shape ``image_shape`` = (C, H, W).

        The input size of ``lin1`` is derived from ``image_shape`` (22500 for the
        default 3x160x160) rather than hard-coded.

        Raises:
            ValueError: if the image is too small to survive the three convolutions.
        """
        super().__init__()
        channels, height, width = image_shape
        self.conv1 = nn.Conv2d(channels, 10, 3)
        self.conv2 = nn.Conv2d(10, 100, 5)
        self.conv3 = nn.Conv2d(100, 1, 5)
        convs = (self.conv1, self.conv2, self.conv3)
        # An unpadded, stride-1 conv shrinks each spatial dim by (kernel_size - 1).
        out_h = height - sum(conv.kernel_size[0] - 1 for conv in convs)
        out_w = width - sum(conv.kernel_size[1] - 1 for conv in convs)
        if out_h < 1 or out_w < 1:
            raise ValueError(f'image_shape {image_shape} is too small for the conv stack')
        n_flat = self.get_flattened((self.conv3.out_channels, out_h, out_w))
        self.lin1 = nn.Linear(n_flat, 500)
        self.out = nn.Linear(500, 1)
        print('Discriminator created.')
        self.init_weights()

    def get_flattened(self, shape_):
        """Return the number of elements in a tensor of shape ``shape_`` (1 for ``()``)."""
        prod = 1
        for element in shape_:
            prod *= element
        return prod

    def forward(self, x):
        """Score images ``x`` of shape (batch, C, H, W) matching ``image_shape``.

        Returns:
            A (batch, 1) tensor of sigmoid scores.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # flatten(1) keeps the batch dim; an input of the wrong spatial size then
        # fails in lin1 instead of being silently reshaped into a different batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.lin1(x))
        return torch.sigmoid(self.out(x))

    def init_weights(self):
        """Kaiming-uniform initialise the weights of every Conv2d / Linear layer.

        Biases keep PyTorch's default initialisation.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_uniform_(module.weight)
        print('Discriminator weights initialized as per kaiming uniform criterion.')

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cpu = torch.device('cpu')
print(f'Running on {device}')
# Example inputs for jit.trace (also reused as the benchmark input below); sized
# from the config constants rather than repeating the literal latent size.
gen_tracer = torch.rand(n_latent_vars, device=device)
dis_tracer = torch.rand(1, *out_image_shape, device=device)
gen = Generator().to(device)
dis = Discriminator().to(device)
# gen / dis are already on `device` before tracing, so the traced modules need no
# extra .to(device).
traced_gen = jit.trace(gen, gen_tracer)
traced_dis = jit.trace(dis, dis_tracer)

n = 4000  # timed iterations per variant


def time_forward_passes(generator, discriminator, latent, n_iters, warmup=10):
    """Time ``n_iters`` generator -> discriminator forward passes on ``latent``.

    ``warmup`` untimed iterations run first so that one-off first-call costs (e.g.
    TorchScript optimisation, cuDNN algorithm selection) are not charged to
    whichever variant happens to run first. CUDA kernel launches are asynchronous,
    so the device is synchronised before starting and before stopping the clock;
    otherwise the timer can stop while GPU work is still queued. Autograd stays
    enabled, as in the original measurement; wrap the call in ``torch.no_grad()``
    to benchmark pure inference instead.

    Returns:
        (elapsed_seconds, last_generator_output, last_discriminator_output)
    """
    def sync():
        if latent.is_cuda:
            torch.cuda.synchronize(latent.device)

    gen_out = dis_out = None
    for _ in range(warmup):
        gen_out = generator(latent)
        dis_out = discriminator(gen_out)
    sync()
    start = time.perf_counter()
    for _ in tqdm(range(n_iters)):
        gen_out = generator(latent)
        dis_out = discriminator(gen_out)
    sync()
    return time.perf_counter() - start, gen_out, dis_out


eagertime, gen_out, dis_out = time_forward_passes(gen, dis, gen_tracer, n)
statictime, gen_out, dis_out = time_forward_passes(traced_gen, traced_dis, gen_tracer, n)

# A ratio below 1 means the traced (static) modules ran faster than eager mode.
print(f'Static vs Eager execution time = {statictime/eagertime}')

Imports complete in 0.00013113021850585938 seconds.
Running on cuda:0
Generator created.
Generator weights initialized as per kaiming uniform criterion.
Discriminator created.
Discriminator weights initialized as per kaiming uniform criterion.


100%|██████████| 4000/4000 [00:25<00:00, 155.65it/s]
100%|██████████| 4000/4000 [00:25<00:00, 154.71it/s]

Static vs Eager execution time = 1.0061345510181274



