In [12]:
import torch
import torch.nn as nn

class SimpleGenerator(nn.Module):
    """Toy generator made of one bias-free transposed convolution with fixed weights.

    The layer maps 3 input channels to 1 output channel with kernel_size=2,
    stride=2, padding=0. Each spatial dimension is therefore doubled
    (H_out = (H - 1) * 2 + 2 = 2 * H). For example, a (1, 3, 1, 1) input
    becomes a (1, 1, 2, 2) output.
    """

    def __init__(self):
        super().__init__()
        # Kernel size equals stride, so every input pixel is expanded into its
        # own non-overlapping 2x2 patch of the output.
        self.convT = nn.ConvTranspose2d(
            3,  # in_channels
            1,  # out_channels
            kernel_size=2,
            stride=2,
            padding=0,
            bias=False,
        )
        self.init_weights()

    def init_weights(self):
        """Overwrite the transposed-conv weight with hand-picked 2x2 kernels.

        ConvTranspose2d stores its weight as (in_channels, out_channels, kH, kW),
        i.e. (3, 1, 2, 2) here: one 2x2 kernel per input channel.
        """
        fixed_kernels = torch.tensor(
            [
                [[[0.29, -0.39], [-0.25, 0.15]]],   # input channel 0
                [[[0.11, -0.13], [0.30, 0.34]]],    # input channel 1
                [[[-0.36, -0.27], [0.46, -0.17]]],  # input channel 2
            ]
        )
        # The weight is a leaf Parameter that requires grad, so the in-place
        # copy must happen outside autograd tracking.
        with torch.no_grad():
            self.convT.weight.copy_(fixed_kernels)

    def forward(self, x):
        """Apply the fixed transposed convolution to ``x``."""
        return self.convT(x)


# Input noise z with shape (1, 3, 1, 1): a batch of one sample, 3 channels, 1x1 spatial.
z = torch.tensor([[[[0.34]], [[0.13]], [[0.23]]]], dtype=torch.float)

gen = SimpleGenerator()
out = gen(z)
print("z.shape =", z.shape)
print("out.shape =", out.shape)
print("Output tensor:\n", out)

# Enforce the expectations instead of leaving them as comments.
assert z.shape == (1, 3, 1, 1), f"unexpected z.shape {tuple(z.shape)}"
assert out.shape == (1, 1, 2, 2), f"unexpected out.shape {tuple(out.shape)}"
# With a 1x1 input and kernel == stride, each output pixel (i, j) is simply
# sum_c z[c] * w[c, 0, i, j]; check the forward pass against that closed form.
expected_out = torch.einsum("c,cij->ij", z.flatten(), gen.convT.weight[:, 0].detach())
assert torch.allclose(out[0, 0].detach(), expected_out, atol=1e-6)

z.shape = torch.Size([1, 3, 1, 1])
out.shape = torch.Size([1, 1, 2, 2])
Output tensor:
 tensor([[[[ 0.0301, -0.2116],
          [ 0.0598,  0.0561]]]], grad_fn=<ConvolutionBackward0>)
