## 7×7 kernel

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the custom padding function for periodic padding in the longitude direction
def periodic_padding(input_tensor, padding_size):
    """Wrap-pad ``input_tensor`` along its width (longitude) axis, dimension 3.

    The last ``padding_size`` columns are prepended and the first
    ``padding_size`` columns are appended, so the left and right edges see
    each other as neighbours. The height (dimension 2) is not padded.

    Args:
        input_tensor: tensor indexed as (batch, channels, height, width).
        padding_size: number of columns to add on each side. 0 returns the
            input unchanged.

    Returns:
        Tensor whose width is ``2 * padding_size`` larger than the input's.

    Raises:
        ValueError: if ``padding_size`` is negative or larger than the width.
    """
    if padding_size < 0:
        raise ValueError(f"padding_size must be non-negative, got {padding_size}")
    if padding_size == 0:
        # Must short-circuit: the slice [..., -0:] would select the WHOLE width.
        return input_tensor
    width = input_tensor.shape[3]
    if padding_size > width:
        raise ValueError(
            f"padding_size ({padding_size}) exceeds the width ({width}) of the input"
        )

    left_pad = input_tensor[:, :, :, -padding_size:]  # last columns wrap to the left
    right_pad = input_tensor[:, :, :, :padding_size]  # first columns wrap to the right
    return torch.cat([left_pad, input_tensor, right_pad], dim=3)

# Example input: (batch_size=1, channels=3, height=32, width=32)
input_tensor = torch.randn(1, 3, 32, 32)

# A k x k kernel needs k // 2 pixels of padding on each side to keep the
# spatial size unchanged: 3 pixels for this 7x7 kernel.
kernel_size = 7
padding_size = kernel_size // 2

# Periodic (wrap-around) padding on the left/right (longitude) boundaries.
padded_tensor = periodic_padding(input_tensor, padding_size=padding_size)

# Reflect padding on the top/bottom (latitude) boundaries. This mirrors the
# rows next to each edge (the edge row itself is not repeated); it is NOT
# zero padding. `value` only applies to mode='constant', so it is not passed.
padded_tensor = F.pad(padded_tensor, (0, 0, padding_size, padding_size), mode='reflect')

# 7x7 convolution with no internal padding: all padding was applied above.
conv_layer = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=kernel_size, padding=0)

# Apply the Conv2d layer to the padded input tensor
output_tensor = conv_layer(padded_tensor)

# Expected: torch.Size([1, 6, 32, 32])
output_tensor.shape


torch.Size([1, 6, 32, 32])

## 3×3 kernel

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the custom padding function for periodic padding in the longitude direction
def periodic_padding(input_tensor, padding_size):
    """Wrap-pad ``input_tensor`` along its width (longitude) axis, dimension 3.

    The last ``padding_size`` columns are prepended and the first
    ``padding_size`` columns are appended, so the left and right edges see
    each other as neighbours. The height (dimension 2) is not padded.

    Args:
        input_tensor: tensor indexed as (batch, channels, height, width).
        padding_size: number of columns to add on each side. 0 returns the
            input unchanged.

    Returns:
        Tensor whose width is ``2 * padding_size`` larger than the input's.

    Raises:
        ValueError: if ``padding_size`` is negative or larger than the width.
    """
    if padding_size < 0:
        raise ValueError(f"padding_size must be non-negative, got {padding_size}")
    if padding_size == 0:
        # Must short-circuit: the slice [..., -0:] would select the WHOLE width.
        return input_tensor
    width = input_tensor.shape[3]
    if padding_size > width:
        raise ValueError(
            f"padding_size ({padding_size}) exceeds the width ({width}) of the input"
        )

    left_pad = input_tensor[:, :, :, -padding_size:]  # last columns wrap to the left
    right_pad = input_tensor[:, :, :, :padding_size]  # first columns wrap to the right
    return torch.cat([left_pad, input_tensor, right_pad], dim=3)

# Example input: (batch_size=1, channels=3, height=32, width=32)
input_tensor = torch.randn(1, 3, 32, 32)

# A k x k kernel needs k // 2 pixels of padding on each side to keep the
# spatial size unchanged: 1 pixel for this 3x3 kernel.
kernel_size = 3
padding_size = kernel_size // 2

# Periodic (wrap-around) padding on the left/right (longitude) boundaries.
padded_tensor = periodic_padding(input_tensor, padding_size=padding_size)

# Reflect padding on the top/bottom (latitude) boundaries. This mirrors the
# rows next to each edge (the edge row itself is not repeated); it is NOT
# zero padding. `value` only applies to mode='constant', so it is not passed.
padded_tensor = F.pad(padded_tensor, (0, 0, padding_size, padding_size), mode='reflect')

# 3x3 convolution with no internal padding: all padding was applied above.
conv_layer = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=kernel_size, padding=0)

# Apply the Conv2d layer to the padded input tensor
output_tensor = conv_layer(padded_tensor)

# Expected: torch.Size([1, 6, 32, 32])
output_tensor.shape


torch.Size([1, 6, 32, 32])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    """A single 2-D convolution preceded by longitude-periodic padding.

    Before the convolution the input is padded by ``kernel_size // 2`` on
    every side: periodically (wrap-around) along the width axis and with
    zeros along the height axis. Because the padding matches the kernel, the
    output has the same height and width as the input.

    Args:
        input_channels: number of input channels.
        init_dim: number of output channels.
        kernel_size: odd convolution kernel size (default 7, the original
            hard-coded value).

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer, since an
            even kernel cannot be padded symmetrically to preserve size.
    """

    def __init__(self, input_channels, init_dim, kernel_size=7):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        # Padding is derived from the kernel so the two can never drift apart.
        self.padding = kernel_size // 2

        # No internal padding: forward() applies the custom padding itself.
        self.init_conv = nn.Conv2d(
            input_channels, init_dim, kernel_size=kernel_size, padding=0
        )

    def periodic_padding(self, x, padding_size):
        """Pad ``x`` periodically along width (dim 3) and with zeros along height (dim 2).

        Args:
            x: tensor indexed as (batch, channels, height, width).
            padding_size: pixels added on each side. 0 returns ``x`` unchanged.

        Returns:
            Tensor ``2 * padding_size`` larger in both height and width.

        Raises:
            ValueError: if ``padding_size`` is negative or exceeds the width.
        """
        if padding_size < 0:
            raise ValueError(f"padding_size must be non-negative, got {padding_size}")
        if padding_size == 0:
            # Must short-circuit: the slice [..., -0:] would select the WHOLE width.
            return x
        if padding_size > x.shape[3]:
            raise ValueError(
                f"padding_size ({padding_size}) exceeds the width ({x.shape[3]}) of the input"
            )

        # Periodic padding (left-right): wrap the last/first columns around.
        left_pad = x[:, :, :, -padding_size:]
        right_pad = x[:, :, :, :padding_size]
        x = torch.cat([left_pad, x, right_pad], dim=3)

        # Zero padding (top-bottom).
        return F.pad(x, (0, 0, padding_size, padding_size), mode='constant', value=0)

    def forward(self, x):
        """Apply the custom padding, then the convolution. Spatial size is preserved."""
        x = self.periodic_padding(x, padding_size=self.padding)
        return self.init_conv(x)

# Smoke test: one random 3-channel 32x32 sample should come out with 64
# channels at the same 32x32 resolution.
input_tensor = torch.randn((1, 3, 32, 32))
model = MyModel(3, 64)
output = model(input_tensor)

output.shape


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

# Custom function for periodic padding (left-right periodic)
def periodic_padding(input_tensor, padding_size):
    """Wrap-pad ``input_tensor`` along its width (longitude) axis, dimension 3.

    The last ``padding_size`` columns are prepended and the first
    ``padding_size`` columns are appended. The height (dimension 2) is not
    padded.

    Args:
        input_tensor: tensor indexed as (batch, channels, height, width).
        padding_size: number of columns to add on each side. 0 returns the
            input unchanged.

    Returns:
        Tensor whose width is ``2 * padding_size`` larger than the input's.

    Raises:
        ValueError: if ``padding_size`` is negative or larger than the width.
    """
    if padding_size < 0:
        raise ValueError(f"padding_size must be non-negative, got {padding_size}")
    if padding_size == 0:
        # Must short-circuit: the slice [..., -0:] would select the WHOLE width.
        return input_tensor
    width = input_tensor.shape[3]
    if padding_size > width:
        raise ValueError(
            f"padding_size ({padding_size}) exceeds the width ({width}) of the input"
        )

    left_pad = input_tensor[:, :, :, -padding_size:]  # Last 'padding_size' columns
    right_pad = input_tensor[:, :, :, :padding_size]  # First 'padding_size' columns
    return torch.cat([left_pad, input_tensor, right_pad], dim=3)

# Upsample block with periodic padding applied before Conv2d
def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),  # Upsample
        nn.Conv2d(dim, default(dim_out, dim), kernel_size=3, padding=0),  # No padding here
        nn.Lambda(lambda x: periodic_padding(x, padding_size=1))  # Apply custom periodic padding
    )

# Downsample block: 2x2 space-to-depth reshape followed by a 1x1 Conv2d
def Downsample(dim, dim_out=None):
    """Halve the spatial resolution with a space-to-depth reshape and a 1x1 conv.

    Each 2x2 spatial patch is folded into the channel axis, turning
    (B, dim, H, W) into (B, 4*dim, H/2, W/2). H and W must be even. A 1x1
    convolution then maps the 4*dim channels to ``dim_out``.

    A 1x1 kernel has no spatial neighbourhood, so no boundary padding
    (periodic or otherwise) is needed. Adding any would only widen the output.

    Args:
        dim: input channels.
        dim_out: output channels; defaults to ``dim`` when None.

    Returns:
        ``nn.Sequential`` mapping (B, dim, H, W) to (B, dim_out, H/2, W/2).
    """
    out_dim = dim if dim_out is None else dim_out
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),  # Reshape for downsampling
        nn.Conv2d(dim * 4, out_dim, kernel_size=1),
    )
