In [1]:
import torch
from torchvision.models import resnet18
import numpy

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class SharedEncoder(torch.nn.Module):
    """Encode an RGBA image together with its peel map into one feature vector.

    Pipeline: concatenate the two inputs on the channel axis -> 7x7 stem conv to
    ``stem_channels`` maps (same H x W) -> average pool (H/2, W/2) -> max pool
    (H/4, W/4) -> ResNet-18 trunk with its classification head removed.

    @param in_channels: combined channel count of ``x_img`` and ``x_smpl``
        (default 8, i.e. 4 RGBA channels + 4 peel-map layers).
    @param stem_channels: number of feature maps produced by the stem conv and
        consumed by the ResNet trunk (default 64).
    """

    def __init__(self, in_channels: int = 8, stem_channels: int = 64) -> None:
        super().__init__()

        # concatenated image + peel map -> stem_channels feature maps, same H x W
        self.conv2d = torch.nn.Conv2d(in_channels, stem_channels, 7, padding=3)

        # downsampling: each pool halves H and W (e.g. 512 -> 256 -> 128).
        # Pooling layers take no channel arguments; they preserve channel count.
        self.avg_pooling = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.max_pooling = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # resnet encoding
        # input for a 512 x 512 image: stem_channels x 128 x 128
        # NOTE(review): an earlier comment expected 256 channels here; pooling
        # cannot widen channels, so insert convolutions if that was the intent.
        self.resnet = resnet18()
        # The stock ResNet stem expects 3-channel RGB input. Rebuild it for our
        # stem_channels feature maps, keeping its kernel/stride/padding/bias.
        stock_conv1 = self.resnet.conv1
        self.resnet.conv1 = torch.nn.Conv2d(
            stem_channels,
            stock_conv1.out_channels,
            kernel_size=stock_conv1.kernel_size,
            stride=stock_conv1.stride,
            padding=stock_conv1.padding,
            bias=stock_conv1.bias is not None,
        )
        self.resnet.fc = torch.nn.Identity()    # remove the last fc layer

    def forward(self, x_img: torch.Tensor, x_smpl: torch.Tensor) -> torch.Tensor:
        """Encode one sample (3-D inputs) or a batch of samples (4-D inputs).

        @param x_img: rgba image, C x H x W (e.g. 4 x 512 x 512), or batched
            N x C x H x W.
        @param x_smpl: peel map with one channel per layer, L x H x W
            (e.g. 4 x 512 x 512), or batched N x L x H x W. Must match x_img in
            rank, batch size and spatial size; C + L must equal ``in_channels``.
        @return: feature vector of shape (D,) for unbatched input or (N, D) for
            batched input, where D is the trunk's output width (512 for ResNet-18).
        @raise ValueError: if the inputs are not both 3-D or both 4-D.
        """
        if x_img.dim() != x_smpl.dim() or x_img.dim() not in (3, 4):
            raise ValueError(
                "x_img and x_smpl must both be 3-D (C, H, W) or both 4-D "
                f"(N, C, H, W); got {x_img.dim()}-D and {x_smpl.dim()}-D"
            )
        unbatched = x_img.dim() == 3

        # concat on the channel axis, which is third from the end in both layouts
        x = torch.cat((x_img, x_smpl), dim=-3)
        if unbatched:
            # the ResNet trunk (BatchNorm, flatten(x, 1)) requires a batch axis
            x = x.unsqueeze(0)

        x = self.conv2d(x)
        x = self.avg_pooling(x)
        x = self.max_pooling(x)
        x = self.resnet(x)

        return x.squeeze(0) if unbatched else x


In [None]:
class Decoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()