In [119]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

In [120]:
class ResLinBlock(nn.Module):
    def __init__(self, in_features, out_features, expansion_factor, identity_downsample= None) -> None:
        """
        Residual Linear Block for neural networks.

        Args:
            in_features (int): Number of input features.
            out_features (int): Number of output features.
            expansion_factor (int): Expansion factor for output features.
            identity_downsample (nn.Module, optional): Module for downsampling the identity (default: None).

        Attributes:
            in_features (int): Number of input features.
            out_features (int): Number of output features.
            expansion_factor (int): Expansion factor for output features.
            identity_downsample (nn.Module): Module for downsampling the identity.
            relu (ReLU): ReLU activation function.
            lin_block (Sequential): Linear block consisting of multiple linear layers and batch normalization.
            mapping_layer (Linear): Linear layer for feature mapping.
        """
        super(ResLinBlock, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.expansion_factor = expansion_factor
        self.identity_downsample = identity_downsample
        self.relu = nn.ReLU()
        self.lin_block = nn.Sequential(
            # 1
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            
            #2
            nn.Linear(out_features,out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            
            #3
            nn.Linear(out_features, out_features*self.expansion_factor),
            nn.BatchNorm1d(out_features*self.expansion_factor),
            )
        self.mapping_layer = nn.Linear(out_features*self.expansion_factor, out_features)
    
    def forward(self, x ):
        """
        Forward pass of the Residual Linear Block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        identity = x.clone()
        x = self.lin_block(x)
        
        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)
  
        x += identity
        x = self.relu(x)
        if self.in_features!=self.out_features:
            x = self.relu(self.mapping_layer(x))
            
            
        return x

In [121]:
in_channels = 3
out_channels = 3
expansion_factor = 4
identity_downsample = nn.Sequential(nn.Linear(in_channels, out_channels*expansion_factor)
                                                )
def test():
    """
    Test function for ResLinBlock.

    This function creates an instance of ResLinBlock and tests it with random input data.

    Parameters:
    None

    Returns:
    None
    """
    net = ResLinBlock(in_channels, out_channels, expansion_factor, identity_downsample= identity_downsample)
    x = torch.randn(2,in_channels)
    y = net(x)
    print(y.shape)

test()

torch.Size([2, 12])


In [122]:
class ResLinNet(nn.Module):
    def __init__(self, in_features = 150, depths = [64, 64, 128, 128, 256], expansion_factor = 4, outputs = 1) -> None:
        """
        Residual Linear Network constructor.

        Args:
        - in_features (int): Number of input features.
        - depths (list): List of integers representing the number of features in each layer.
        - expansion_factor (int): Expansion factor for linear layers.
        - outputs (int): Number of output features.

        Returns:
        None
        """
        super(ResLinNet, self).__init__()
        self.depths = depths
        self.expansion_factor = expansion_factor
        self.in_features = in_features
        self.in_layer = nn.Linear(in_features, self.depths[0])
        self.reslayers = nn.ModuleList([self.create_block(depths[i], depths[i+1]) for i in range(len(depths)-1)])
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(512, outputs)
    def create_block(self, in_features, out_features):
        """
        Create a Residual Linear Block.

        Args:
        - in_features (int): Number of input features.
        - out_features (int): Number of output features.

        Returns:
        - block (ResLinBlock): Residual Linear Block instance.
        """
        if in_features==out_features:
            identity_downsample = None 
            layer_expansion_factor = 1  
        else:
            identity_downsample = nn.Sequential(nn.Linear(in_features, out_features*self.expansion_factor ))
            layer_expansion_factor = self.expansion_factor

            
        return ResLinBlock(in_features, out_features, expansion_factor=layer_expansion_factor, identity_downsample=identity_downsample)
    
    def forward(self, x):
        """
        Forward pass through the Residual Linear Network.

        Args:
        - x (Tensor): Input tensor.

        Returns:
        - out (Tensor): Output tensor.
        """
        x = self.relu(self.in_layer(x))
        for i, layer in enumerate(self.reslayers):
            x = layer(x)
            
        
        flattened_tensor = torch.flatten(x)
        out = self.relu(self.out_layer(flattened_tensor))
        return out
        

In [123]:
def test():
    net = ResLinNet()
    x = torch.randn(2,150)
    y = net(x)
    print(y.shape)

test()

torch.Size([1])
