In [1]:
import torch
import torch.nn as  nn
from models.involution import * 
from models.unet import * 


In [2]:
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize across the channel dimension
        return F.normalize(x, dim=1) * self.scale * self.gamma.view(1, -1, 1, 1)


In [3]:
class BottleNeck_Long(nn.Module):
    def __init__(self, in_channels, out_channels, inv_kernel = 7):
        super(BottleNeck_Long, self).__init__()
        self.invblock = nn.Sequential(
            Involution_CUDA(in_channels, kernel_size = inv_kernel, stride = 1),
            RMSNorm(in_channels),
            nn.GELU()
        )

        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels = in_channels,
                      out_channels = out_channels,
                      kernel_size = 1,
                      stride = 1,
                      padding = 0,
                      bias = False),
            RMSNorm(out_channels),
        )   

        self.mapping=nn.Sequential(
            nn.Conv2d(in_channels = in_channels,
                      out_channels = out_channels,
                      kernel_size = 1,
                      padding = 0),
            RMSNorm(out_channels))
        
        self.gelu = nn.GELU()

    def forward(self, x):
        x1 = self.invblock(x)
        x1 = self.conv_block(x1)

        if x.shape[1] != x1.shape[1]:
            x = self.mapping(x)

        return self.gelu(x1 + x)


In [13]:
class UNET_Long(nn.Module):
    def __init__(self, in_channels = 3, out_channels = 1, 
                 features = [64, 128, 256, 512], device = 'cpu'):
        super().__init__()
        self.in_channels = in_channels
        
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        self.warmup_conv = nn.Conv2d(in_channels = in_channels, 
                                     out_channels = features[0],
                                     kernel_size = 3,
                                     padding = 1)
        
        self.final_layer = nn.Conv2d(in_channels = features[0],
                                     out_channels = out_channels,
                                     kernel_size = 1)
        
        self.bottle_neck = BottleNeck_Long(in_channels = features[-1],
                                      out_channels = features[-1] * 2)

        in_channels_temp = features[0]

        # down
        for i in range(len(features)):
            in_channels_temp = features[i]
            self.downs.append(
                BottleNeck_Long(in_channels_temp,
                           out_channels = in_channels_temp))
            
            if i < len(features) - 1:
                # Instead of using AvgPool2D
                self.downs.append(nn.Conv2d(in_channels = in_channels_temp,
                                        out_channels = features[i + 1],
                                        kernel_size = 1,
                                        stride = 2,
                                        padding = 1))
                


        # ups
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(
                in_channels = feature * 2,
                out_channels = feature, 
                kernel_size = 2, stride = 2))
            self.ups.append(BottleNeck_Long(in_channels = feature * 2, 
                                       out_channels = feature))

        self.device = device

        # print(self.downs)
        
    def forward(self,x):
        skip_connections=[]
        print("Encoder")
        x = self.warmup_conv(x)
        print(x.shape)

    
        for i in range(0, len(self.downs)):
            x = self.downs[i](x)

            if i % 2 == 0:
                skip_connections.append(x)
                print("Involution Block:", x.shape)
            else:
                print("Downsample:", x.shape)


            
        x = self.bottle_neck(x)
        skip_connections = skip_connections[::-1]
        print(f"After Bottleneck: ", x.shape)

        for i in range(0, len(self.ups), 2):

            x = self.ups[i](x)
            print(f"Upsample: ", x.shape)
            skip_connection = skip_connections[i//2]
            if x.shape != skip_connection.shape:
                x = TF.resize(x, skip_connection.shape[2::], antialias=True)

            concat_skip = torch.cat((skip_connection,x),dim=1)
            x = self.ups[i+1](concat_skip)
            print(f"Involution Block: ", x.shape)

        return self.final_layer(x)


In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet_long = UNET_Long(in_channels = 3, out_channels = 3,
                      features = [32, 64, 128, 256], device = device)

unet_long.to(device)

# print(unet_long)
x = torch.randn(1, 3, 256, 256).to(device)
unet_long(x)

torch.Size([1, 32, 256, 256])
Involution Block: torch.Size([1, 32, 256, 256])
Downsample: torch.Size([1, 64, 129, 129])
Involution Block: torch.Size([1, 64, 129, 129])
Downsample: torch.Size([1, 128, 66, 66])
Involution Block: torch.Size([1, 128, 66, 66])
Downsample: torch.Size([1, 256, 34, 34])
Involution Block: torch.Size([1, 256, 34, 34])
After Bottleneck:  torch.Size([1, 512, 34, 34])


tensor([[[[-0.3114, -0.1726, -0.0837,  ..., -0.2129,  0.6130,  0.2561],
          [ 0.1200,  0.4699, -0.3126,  ...,  0.0853, -0.1210, -0.2007],
          [ 0.4664, -0.3636,  0.1720,  ..., -0.1691,  0.3157, -0.1408],
          ...,
          [ 0.0162, -0.3941, -0.0520,  ..., -0.3603, -0.2271, -0.6019],
          [-0.2806, -0.8263, -0.5221,  ...,  0.0626, -0.1143, -0.6053],
          [ 0.3929,  0.3507,  0.5398,  ..., -0.9595, -0.4424,  0.4979]],

         [[-0.3466, -0.6634, -0.6091,  ...,  0.7380, -0.6423, -0.5715],
          [-0.8005, -1.0138, -0.8860,  ..., -0.3679, -0.1191, -0.0067],
          [-0.4197,  0.2794, -0.2652,  ..., -0.0132,  0.2775, -0.4289],
          ...,
          [-0.5381, -0.8163, -0.2908,  ..., -0.4168, -0.2932, -0.1758],
          [ 0.0974, -0.3481,  0.2250,  ..., -0.4825, -0.5164,  0.2487],
          [-0.2580, -0.2641, -0.2939,  ...,  0.2769,  0.4559, -0.3663]],

         [[ 0.5749, -0.6137,  0.0210,  ..., -0.5880,  0.1212, -0.2306],
          [-0.1594,  0.3680,  