In [10]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader   
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Global RNG seed so weight initialisation and torch.randn draws are reproducible
# across Restart & Run All. The trailing ';' suppresses the Generator repr output.
SEED = 48
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f8b440e3a70>

In [11]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5):
    """Display the first `num_images` images of a batch as a grid.

    Args:
        image_tensor: tensor holding a batch of images; it is detached, moved to
            CPU and reshaped to (-1, *size), so its total element count must be
            a multiple of prod(size).
        num_images: how many images (from the start of the batch) to show.
            Previously this argument was ignored and 25 were always shown.
        size: per-image shape (channels, height, width) used for the reshape.
        nrow: number of images per row in the grid (default 5, as before).
    """
    # reshape (not view) so non-contiguous inputs work as well.
    image_unflat = image_tensor.detach().cpu().reshape(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    # make_grid returns (C, H, W); imshow expects channels last.
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [14]:
def get_generator_block(input_dim, output_dim):
    """Return one generator layer: Linear -> BatchNorm1d -> ReLU.

    Args:
        input_dim: number of input features to the linear layer.
        output_dim: number of output features; also the BatchNorm1d width.

    Returns:
        nn.Sequential with exactly three modules, in the order above. The ReLU
        is constructed with inplace=True.
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

In [15]:
def test_gen_block(in_features, out_features, num_test=1000):
    """Sanity-check `get_generator_block` for one (in, out) size pair.

    Checks the exact layer classes and order, the output shape for a random
    batch of `num_test` samples, and that the output std lies in (0.55, 0.65).

    Args:
        in_features: input width passed to get_generator_block.
        out_features: output width passed to get_generator_block.
        num_test: number of random samples in the test batch.

    Raises:
        AssertionError: if any check fails.
    """
    block = get_generator_block(in_features, out_features)
    assert len(block) == 3, f"expected 3 layers, got {len(block)}"
    # Exact-type checks (`is`, not isinstance) pin the precise layer classes.
    assert type(block[0]) is nn.Linear, f"layer 0 is {type(block[0]).__name__}"
    assert type(block[1]) is nn.BatchNorm1d, f"layer 1 is {type(block[1]).__name__}"
    assert type(block[2]) is nn.ReLU, f"layer 2 is {type(block[2]).__name__}"

    test_input = torch.randn(num_test, in_features)
    test_output = block(test_input)
    assert tuple(test_output.shape) == (num_test, out_features), (
        f"unexpected output shape {tuple(test_output.shape)}"
    )
    # A freshly built block is in training mode, so BatchNorm1d standardises
    # with batch statistics; ReLU of a standard normal has std ~0.58, which is
    # presumably why this window was chosen.
    output_std = test_output.std().item()
    assert 0.55 < output_std < 0.65, f"output std {output_std:.3f} outside (0.55, 0.65)"

# Exercise the generator block at two different (in_features, out_features) sizes.
for gen_block_dims in ((25, 12), (15, 28)):
    test_gen_block(*gen_block_dims)
print("Success!")

Success!
