In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomFeedForwardNN(nn.Module):
  """Feed-forward classifier built from a configurable stack of Linear layers.

  Each hidden Linear layer is followed by ``activation_fn`` and a shared
  dropout layer; a final Linear layer maps the last hidden width to
  ``num_classes`` raw logits.

  Args:
    input_size: Number of input features (size of the last dimension of the
      tensor passed to ``forward``).
    num_classes: Number of output logits.
    hidden_dims: Non-empty list or tuple of hidden layer widths, in order.
    dropout: Dropout probability passed to ``nn.Dropout``.
    activation_fn: Callable applied after every hidden Linear layer
      (e.g. ``nn.ReLU()``). The same object is reused for every layer.

  Raises:
    ValueError: If ``hidden_dims`` is not a non-empty list or tuple.
  """

  def __init__(self, input_size, num_classes, hidden_dims, dropout, activation_fn):
    super().__init__()

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would turn an empty list into an opaque IndexError below.
    # Tuples are accepted as well as lists.
    if not isinstance(hidden_dims, (list, tuple)) or len(hidden_dims) == 0:
      raise ValueError(
          f"hidden_dims must be a non-empty list or tuple, got {hidden_dims!r}")

    # Consecutive (in, out) width pairs: input -> h0, h0 -> h1, ..., h[-2] -> h[-1].
    # Layers are created in the same order as before, so seeded initialisation
    # and state_dict keys ("hidden_layers.0.weight", ...) are unchanged.
    layer_dims = [input_size, *hidden_dims]
    self.hidden_layers = nn.ModuleList(
        [nn.Linear(in_dim, out_dim)
         for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:])]
    )

    # Nonlinearity applied between layers (one shared instance).
    self.nonlinearity = activation_fn

    # Dropout applied after every hidden activation.
    self.dropout = nn.Dropout(dropout)

    # Final projection to unnormalized class scores (logits), not a distribution.
    self.output_projection = nn.Linear(hidden_dims[-1], num_classes)

  def forward(self, x):
    """Return class logits for ``x``.

    Args:
      x: Tensor whose last dimension has ``input_size`` features; presumably
        (batch, input_size), i.e. images must be flattened by the caller.

    Returns:
      Tensor of raw logits whose last dimension is ``num_classes``. No softmax
      is applied here.
    """
    # Hidden stack: Linear -> nonlinearity -> dropout, for each hidden layer.
    for hidden_layer in self.hidden_layers:
      x = self.dropout(self.nonlinearity(hidden_layer(x)))

    return self.output_projection(x)

In [15]:
import torch.nn as nn

# Define the model parameters
input_size = 28 * 28  # Each 28x28 image flattened into one feature vector
hidden_dims = [512, 256, 128]  # Hidden layer widths, in order
num_classes = 10  # Number of output classes in FashionMNIST
dropout_rate = 0.2  # Dropout probability after each hidden activation
activation_fn = nn.ReLU()  # Nonlinearity applied after every hidden layer

# Seed torch's global RNG so the randomly initialised weights are identical
# on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Instantiate the model. Keyword arguments guard against mixing up the
# positional order (num_classes precedes hidden_dims in the signature).
model = CustomFeedForwardNN(
    input_size=input_size,
    num_classes=num_classes,
    hidden_dims=hidden_dims,
    dropout=dropout_rate,
    activation_fn=activation_fn,
)


In [16]:
from torchinfo import summary
# Layer-by-layer shapes and parameter counts for a batch of one flattened example.
summary(model=model, input_size=(1, input_size))

Layer (type:depth-idx)                   Output Shape              Param #
CustomFeedForwardNN                      [1, 10]                   --
├─ModuleList: 1-7                        --                        (recursive)
│    └─Linear: 2-1                       [1, 512]                  401,920
├─ReLU: 1-2                              [1, 512]                  --
├─Dropout: 1-3                           [1, 512]                  --
├─ModuleList: 1-7                        --                        (recursive)
│    └─Linear: 2-2                       [1, 256]                  131,328
├─ReLU: 1-5                              [1, 256]                  --
├─Dropout: 1-6                           [1, 256]                  --
├─ModuleList: 1-7                        --                        (recursive)
│    └─Linear: 2-3                       [1, 128]                  32,896
├─ReLU: 1-8                              [1, 128]                  --
├─Dropout: 1-9                           [1,

In [11]:
from torchvision.transforms import transforms 
from torchvision.datasets import FashionMNIST

### Loading FashionMNIST data

# Root directory shared by both splits.
data_root = './torchvision-data'

# Define Transformation: convert each image to a tensor. This object (not the
# `transforms` module) is what must be passed as `transform=`; passing the
# module makes dataset indexing fail because a module is not callable.
transform = transforms.ToTensor()

train_dataset = FashionMNIST(root=data_root,
                             train=True,
                             transform=transform,
                             download=True)

test_dataset = FashionMNIST(root=data_root,
                            train=False,
                            transform=transform,
                            download=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./torchvision-data\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:11<00:00, 2392749.71it/s]


Extracting ./torchvision-data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./torchvision-data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./torchvision-data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 120976.14it/s]


Extracting ./torchvision-data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./torchvision-data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./torchvision-data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:03<00:00, 1458406.01it/s]


Extracting ./torchvision-data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./torchvision-data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./torchvision-data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<?, ?it/s]

Extracting ./torchvision-data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./torchvision-data\FashionMNIST\raw




