# Temporal U-Net with Squeezeformer Blocks:

In [2]:
import torch
import torch.nn as nn

class SqueezeformerBlock(nn.Module):
    """Post-norm block: multi-head self-attention, then a feed-forward MLP.

    Each sub-layer is residual and is followed by LayerNorm and a learnable
    scalar scale. A single scale parameter is shared by both sub-layers.

    Args:
        dim: Embedding dimension; ``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        batch_first: If True (default), inputs are ``(batch, seq, dim)``;
            if False, inputs are ``(seq, batch, dim)``.
    """

    def __init__(self, dim, num_heads, *, batch_first=True):
        super().__init__()
        # batch_first must match the input layout. Otherwise attention is
        # computed across the batch axis instead of across time steps.
        self.mha = nn.MultiheadAttention(dim, num_heads, batch_first=batch_first)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Return the block output, which has the same shape as ``x``."""
        # The attention map is not used, so skip computing/averaging it.
        attn_out, _ = self.mha(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out) * self.scale

        x = self.norm2(x + self.ff(x)) * self.scale
        return x

class Squeezeformer(nn.Module):
    """Temporal U-Net built from Squeezeformer blocks.

    The first half of ``blocks`` runs at full temporal resolution. A strided
    depthwise convolution then halves the sequence, and the second half of
    ``blocks`` runs at the reduced rate. A transposed convolution restores
    the rate, followed by a residual connection to the input.

    Args:
        num_blocks: Total number of ``SqueezeformerBlock`` layers. The first
            ``num_blocks // 2`` run before downsampling, the rest after.
        dim: Feature dimension, preserved throughout.
        num_heads: Attention heads per block.
    """

    def __init__(self, num_blocks, dim, num_heads):
        super().__init__()
        self.blocks = nn.ModuleList([
            SqueezeformerBlock(dim, num_heads) for _ in range(num_blocks)
        ])
        # Depthwise (groups=dim), stride 2: length L -> ceil(L / 2).
        self.downsample = nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1, groups=dim)
        # Transposed conv, stride 2: length L' -> 2 * L'.
        self.upsample = nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, dim) to a tensor of the same shape."""
        skip = x
        seq_len = x.size(1)
        split = len(self.blocks) // 2
        for block in self.blocks[:split]:
            x = block(x)
        # Conv1d expects (batch, channels, length), hence the transposes.
        x = self.downsample(x.transpose(1, 2)).transpose(1, 2)
        for block in self.blocks[split:]:
            x = block(x)
        x = self.upsample(x.transpose(1, 2)).transpose(1, 2)
        # For odd seq_len, 2 * ceil(seq_len / 2) == seq_len + 1. Trim the
        # extra frame so the residual add lines up with the input.
        x = x[:, :seq_len]
        return x + skip

# Example usage
if __name__ == "__main__":
    # Demo configuration: push one random sequence through the temporal U-Net.
    input_dim, sequence_length = 128, 100
    num_blocks, num_heads = 4, 4

    # (batch, seq, dim) layout: the model transposes dims 1 and 2 around
    # its Conv1d layers.
    input_tensor = torch.randn((1, sequence_length, input_dim))

    model = Squeezeformer(num_blocks=num_blocks, dim=input_dim, num_heads=num_heads)
    output_tensor = model(input_tensor)

    # Recorded output for this configuration: torch.Size([1, 100, 128]).
    print("Output shape:", output_tensor.shape)

Output shape: torch.Size([1, 100, 128])


# Depthwise Separable Convolution Subsampling:


In [3]:
import torch
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    """Depthwise 1-D convolution followed by a 1x1 pointwise convolution.

    Tensors use the ``nn.Conv1d`` layout (batch, channels, length).

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by the pointwise projection.
        kernel_size: Kernel of the depthwise stage, padded by
            ``kernel_size // 2`` on each side.
        stride: Stride of the depthwise stage; the pointwise stage uses 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # One filter per input channel (groups == in_channels).
        self.depthwise = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            padding=kernel_size // 2,
            groups=in_channels,
        )
        # The 1x1 conv mixes channels and sets the output width.
        self.pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise stage, then the pointwise stage."""
        return self.pointwise(self.depthwise(x))

class SubsamplingBlock(nn.Module):
    """Shorten the temporal axis about 4x with two stride-2 convolutions.

    A regular Conv1d (kernel 3, stride 2, padding 1) maps ``in_channels`` to
    ``out_channels``. A depthwise separable conv (kernel 3, stride 2) then
    halves the length again. Tensors use the Conv1d layout
    (batch, channels, length).

    NOTE(review): no nonlinearity is applied between or after the two
    convolutions, so the whole block is a single affine map. Confirm this
    is intended.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1
        )
        self.conv2 = DepthwiseSeparableConv(
            out_channels, out_channels, kernel_size=3, stride=2
        )

    def forward(self, x):
        """Apply ``conv1`` and then ``conv2``."""
        for stage in (self.conv1, self.conv2):
            x = stage(x)
        return x

# Example usage
if __name__ == "__main__":
    input_dim, sequence_length = 128, 100

    # Conv1d layout: (batch, channels, length).
    input_tensor = torch.randn((1, input_dim, sequence_length))
    model = SubsamplingBlock(in_channels=input_dim, out_channels=256)
    output_tensor = model(input_tensor)

    # Recorded output: torch.Size([1, 256, 25]) -- length 100 -> 50 -> 25.
    print("Output shape:", output_tensor.shape)

Output shape: torch.Size([1, 256, 25])


# Unified Activations with Squeezeformer Block: