In [1]:
# 导入相关包
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time

In [2]:
class G_model(nn.Module):
    def __init__(self):
        super(G_model,self).__init__()
        self.main = nn.Sequential(
            # input: 100 x 1 x 
            # !!! stride and padding
            # output: 512 x 4 x 4
            nn.ConvTranspose2d(in_channels=100,out_channels=512,kernel_size=4,stride=1,padding=0,bias=False),
            # BN
            nn.BatchNorm2d(num_features=512),
            # Relu
            nn.ReLU(inplace=True),
            
            
            # input: 512 x 4 x4 
            # output : 256 x 8 x 8
            nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=4,stride=2,padding=1,bias=False),
            # BN
            nn.BatchNorm2d(num_features=256),
            # Relu
            nn.ReLU(inplace=True),
            
            # input: 256 x 8 x8
            # output: 128 x 16 x 16
            nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=4,stride=2,padding=1,bias=False),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            
            # input : 128 x 16 x 16 
            # output: 64 x 32 x 32
            nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=4,stride=2,padding=1,bias=False),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            
            # input: 64 x 32 x 32
            # output: 3 x 64 x 64
            nn.ConvTranspose2d(in_channels=64,out_channels=3,kernel_size=4,stride=2,padding=1,bias=False),
            # tanh
            nn.Tanh()
            
        
        )
    
    def forward(self,input):
        return self.main(input)

In [3]:
device =  torch.device("cuda:0")

In [4]:
netG = G_model().to(device)
netG.load_state_dict(torch.load('./models/g_model.pt'))
netG.eval()

G_model(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [5]:
def randomGenerate():
    fixed_noise = torch.randn(16, 100, 1, 1, device=device)
    fake_imgs = netG(fixed_noise).detach().cpu().numpy()
    fig = plt.figure(figsize=(10, 10))
    for i in range(fake_imgs.shape[0]):
        plt.subplot(4, 4, i+1)
        img =  np.transpose(fake_imgs[i],(1,2,0))
        img =(img+1 )/ 2 * 255
        img = img.astype('int')
        plt.imshow(img)
        plt.axis('off')
    plt.show()

In [6]:
from IPython.display import clear_output

In [7]:
count = 0
while True:
    randomGenerate()
    clear_output(wait=True)
    time.sleep(0.1)
    # if count > 20:
    #     break
    count+=1


KeyboardInterrupt: 