In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt 
from tqdm import tqdm
import math

In [21]:
# Configuration for loading the MNIST training split.
DATA_ROOT = "./data"
BATCH_SIZE = 64
SEED = 42

transform = transforms.ToTensor()
# download=True fetches the files into DATA_ROOT on the first run only.
mnist_data = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
# A dedicated, seeded generator makes the shuffle order reproducible across
# Restart & Run All without reseeding the global torch RNG.
loader_generator = torch.Generator().manual_seed(SEED)
data_loader = torch.utils.data.DataLoader(
    dataset=mnist_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=loader_generator,
)

In [24]:
# Display the dataset wrapped by the loader; its repr lists the number of
# datapoints, root location, split and transform (see the output below).
data_loader.dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

In [None]:
# Custom linear layer: a hand-written affine map y = x @ W^T + b.
class CustomLinearLayer(nn.Module):
    """Fully connected layer computing ``x @ weight.T + bias``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: if False, no additive bias is created and ``self.bias`` is None.

    The weight is initialized Glorot/Xavier-uniform in
    ``[-sqrt(6 / (in + out)), sqrt(6 / (in + out))]``; the bias starts at zero.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()  # set up nn.Module's parameter/submodule bookkeeping
        self.in_features = in_features
        self.out_features = out_features
        # nn.Parameter marks the tensor as learnable so it shows up in .parameters().
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            # Register the slot as None so `self.bias` is always defined and
            # can be checked against None in reset_parameters/forward.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the weight with Glorot-uniform values and zero the bias."""
        fan_in = self.in_features
        fan_out = self.out_features
        with torch.no_grad():  # initialization must not be tracked by autograd
            # When both sizes are 0 the weight is empty and the Glorot bound
            # would divide by zero, so there is nothing to draw.
            if fan_in + fan_out > 0:
                limit = math.sqrt(6.0 / (fan_in + fan_out))
                self.weight.uniform_(-limit, limit)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x``, whose last dimension must be ``in_features``.

        Any leading dimensions are broadcast by ``torch.matmul`` and kept in
        the output, whose last dimension is ``out_features``.
        """
        out = x.matmul(self.weight.t())
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self) -> str:
        """Summarize the layer sizes in ``print(model)``, in nn.Linear's format."""
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )
        

In [None]:
class CustomReLU(nn.Module):
    """Element-wise ReLU: keeps entries of ``x`` that are > 0 and zeroes the rest.

    Args:
        inplace: if True, overwrite ``x`` itself and return it instead of
            returning a new tensor.

    An entry is kept only when ``x > 0`` is True, so NaN entries become 0.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Entries to zero: everything not strictly positive (NaN > 0 is False).
        drop = ~(x > 0)
        if self.inplace:
            # masked_fill_ writes into x's own storage instead of building a
            # full replacement tensor and copying it back, and unlike slice
            # assignment (`x[:] = ...`) it also works on 0-dim tensors.
            return x.masked_fill_(drop, 0)
        return x.masked_fill(drop, 0)

    def extra_repr(self) -> str:
        """Show ``inplace=True`` in the module repr when enabled."""
        return 'inplace=True' if self.inplace else ''

class CustomSigmoid(nn.Module):
    """Element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    Floating inputs are evaluated as ``exp(-log(1 + exp(-x)))``, with the log
    term computed by ``torch.logaddexp(0, -x)``. The direct formula's
    ``exp(-x)`` overflows to inf for large negative ``x``. Autograd's backward
    for it then multiplies ``0 * inf`` and yields NaN gradients; the logaddexp
    form keeps both the output and its gradient finite.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not x.is_floating_point():
            # Integer/complex inputs keep the direct formula (torch.exp promotes
            # integers to floating point); torch.logaddexp may not support them.
            return 1.0 / (1.0 + torch.exp(-x))
        # log(1 + exp(-x)) == logaddexp(0, -x), computed without overflow.
        log_one_plus_exp_neg = torch.logaddexp(x.new_zeros(()), -x)
        return torch.exp(-log_one_plus_exp_neg)