In [12]:
import torch
from torch import nn
from torchvision import models

c_dim = 256
m = models.convnext_tiny(from_pretrained=True)
m.classifier = nn.Sequential()
m.fc = torch.nn.Linear(768, c_dim)

In [27]:
class ConvNeXt(nn.Module):
    def __init__(self, c_dim):
        super().__init__()
        self.m = models.convnext_tiny(from_pretrained=True)
        self.m.classifier = nn.Sequential()
        self.fc = torch.nn.Linear(768, c_dim)

    def forward(self, x):
        x = self.m(x)
        x = self.fc(x.reshape((-1, 768)))
        return x

In [43]:
class FPConvNeXt(nn.Module):
    def __init__(self, c_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(3, 3, kernel_size=4, stride=2, padding=1)

        self.up1 = nn.ConvTranspose2d(3, 3, kernel_size=4, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(3, 3, kernel_size=4, stride=2, padding=1)

        self.path0 = nn.Conv2d(3, 3, kernel_size=1)
        self.path1 = nn.Conv2d(3, 3, kernel_size=1)

        self.enc0 = ConvNeXt(c_dim)
        self.enc1 = ConvNeXt(c_dim)
        self.enc2 = ConvNeXt(c_dim)

        self.pool = nn.AvgPool1d(3)

    def forward(self, x):
        x0 = x
        x1 = self.conv1(x0)
        x2 = self.conv2(x1)

        x1 = self.up2(x2) + self.path1(x1)
        x0 = self.up1(x1) + self.path0(x0)

        x2 = self.enc2(x2)
        x1 = self.enc1(x1)
        x0 = self.enc0(x0)

        x = torch.concat([x0.unsqueeze(2), x1.unsqueeze(2), x2.unsqueeze(2)], dim=2)
        x = self.pool(x).squeeze(2)

        return x


f = FPConvNeXt(c_dim)
f(torch.rand(1, 3, 224, 224)).shape

torch.Size([1, 256])

In [24]:
sum(p.numel() for p in m.parameters())

28015456

In [20]:
m(torch.zeros([1, 3, 224 // 4, 224 // 4])).shape

torch.Size([1, 768, 1, 1])

In [17]:
x = m.fc(m(torch.zeros([1, 3, 224 // 2, 224 // 2])).reshape(-1, 768))
x.shape

torch.Size([1, 256])