In [1]:
import math

In [2]:
math.log2(2 ** 7)

7.0

In [3]:
math.log2(2 ** 5)

5.0

In [5]:
math.log2(7 * 2 ** 5)  # 224 = 7 * 2**5, so this is 5 + log2(7)

7.807354922057604

In [6]:
def scaling_steps(dim_small: float, dim_big: float) -> int:
    """Number of 2x upsampling steps needed to grow ``dim_small`` to ``dim_big``.

    Returns the smallest integer ``k`` with ``dim_small * 2**k >= dim_big``;
    this is 0 or negative when ``dim_big <= dim_small``.

    ``log2`` is taken of the ratio rather than as ``log2(big) - log2(small)``:
    subtracting two separately rounded logarithms can leave a tiny error
    (e.g. ``3.0000000000000004`` when ``dim_big / dim_small == 8``), which
    ``ceil`` would turn into one step too many.
    """
    return math.ceil(math.log2(dim_big / dim_small))


In [7]:
scaling_steps(dim_small=32, dim_big=224)

3

In [32]:
224 >> 3  # 224 halved three times, i.e. 224 // 2**3

28

In [10]:
from pathlib import Path

# Pre-computed embedding file (the directory name suggests embedded text descriptions).
EMBEDDING_FILE = "data/CelebAMask-HQ/descriptions_embedded/0/0.pt"
path_tensor = Path(EMBEDDING_FILE)

In [11]:
import torch
from torch import nn
# weights_only=True restricts loading to tensors/containers instead of
# unpickling arbitrary objects (which can execute code); map_location="cpu"
# lets a tensor that was saved from a GPU load on a CPU-only machine.
tensor = torch.load(path_tensor, map_location="cpu", weights_only=True)

In [13]:
tensor.size()

torch.Size([384])

In [22]:
class ConditionalTextReshaping(nn.Module):
    """Turn a flat text embedding into a square multi-channel feature map.

    A linear layer projects the embedding onto a single-channel
    ``start_dim x start_dim`` grid. A stack of stride-2 ``ConvTranspose2d``
    layers then doubles height and width at every step until the map is
    ``target_dim x target_dim`` with ``hidden`` channels.

    Args:
        embedding_dim: Size of the last dimension of the input embedding.
        target_dim: Height/width of the output feature map.
        max_start_dim: Upper bound on ``start_dim``, the side length of the
            grid produced by the linear projection. A larger ``target_dim``
            therefore gets more upsampling steps.
        hidden: Number of channels produced by every upsampling layer, and
            hence of the output.

    Raises:
        ValueError: If ``target_dim`` is not divisible by ``2**steps`` for the
            chosen number of steps. The output could not be exactly
            ``target_dim x target_dim`` in that case.
    """

    def __init__(self, embedding_dim, target_dim, max_start_dim=16, hidden=64):
        super().__init__()
        # The first ConvTranspose2d below always doubles the spatial size, so at
        # least one step is required; with 0 (or negative) steps the output would
        # come out larger than target_dim.
        scaling_steps = max(1, self._num_scaling_steps(max_start_dim, target_dim))
        upscale_factor = 2**scaling_steps
        self.start_dim = target_dim // upscale_factor
        if self.start_dim * upscale_factor != target_dim:
            reached = self.start_dim * upscale_factor
            raise ValueError(
                f"target_dim={target_dim} is not divisible by 2**{scaling_steps}; "
                f"the output would be {reached}x{reached}, not "
                f"{target_dim}x{target_dim}"
            )
        self.embedding_in = nn.Linear(embedding_dim, self.start_dim * self.start_dim)
        # The first layer lifts the single channel to `hidden`; the rest keep
        # `hidden`. Kernel 2 with stride 2 (no padding) maps H -> 2H exactly.
        # NOTE(review): there is no nonlinearity between these layers, so the
        # whole stack is one linear map of the embedding. Inserting activations
        # would shift the `upscaling.<i>` state_dict indices of saved checkpoints.
        upscaling_list = [nn.ConvTranspose2d(1, hidden, 2, stride=2)]
        for _ in range(scaling_steps - 1):
            upscaling_list.append(nn.ConvTranspose2d(hidden, hidden, 2, stride=2))
        self.upscaling = nn.Sequential(*upscaling_list)

    def _num_scaling_steps(self, dim_small, dim_big):
        """Smallest integer ``k`` with ``dim_small * 2**k >= dim_big``.

        The result is 0 or negative when ``dim_big <= dim_small``. ``log2`` is
        taken of the ratio: a difference of two rounded logarithms can land a
        hair above an integer and make ``ceil`` add a spurious step.
        """
        return math.ceil(math.log2(dim_big / dim_small))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project and upsample ``x``.

        Args:
            x: Embedding(s) whose last dimension is ``embedding_dim``, e.g.
                ``(batch, embedding_dim)``. A single ``(embedding_dim,)``
                vector is treated as a batch of one.

        Returns:
            Tensor of shape ``(batch, hidden, target_dim, target_dim)``.
        """
        x = self.embedding_in(x)
        # Fold leading dims into the batch axis and add a singleton channel axis.
        x = x.view(-1, 1, self.start_dim, self.start_dim)
        return self.upscaling(x)

In [23]:
layer_test = ConditionalTextReshaping(embedding_dim=384, target_dim=64, hidden=5)

In [28]:
output = layer_test(tensor[None])  # add a leading batch axis of size 1
output.size()

torch.Size([1, 5, 64, 64])

In [29]:
tmp = tensor[None]
tmp.size()

torch.Size([1, 384])

In [30]:
batch = torch.cat([tmp] * 3, dim=0)
batch.size()

torch.Size([3, 384])

In [31]:
batch_output = layer_test(batch)  # batched forward pass
batch_output.size()

torch.Size([3, 5, 64, 64])