In [1]:
import torch
from pytorch_lightning import LightningModule
from torch import nn

In [100]:
class DenseNet(LightningModule):
    
    def __init__(self, n_hidden_layers, hyper="B"):
        super().__init__()
        self.n_hidden_layers = n_hidden_layers
        hidden_layers = []
        for _ in range(n_hidden_layers):
            hidden_layers.extend([
                nn.Linear(in_features=512, out_features=512),
                nn.ReLU()
            ])
        self.net = nn.Sequential(
            *hidden_layers,
            nn.Linear(in_features=512, out_features=1)
        )
        self.save_hyperparameters()
    def forward(self, inputs):
        self.n_hidden_layers = 4
        return self.net(inputs)
    
    def train_step(self, batch):
        inputs, targets = batch
        outputs = self(inputs).reshape(-1)
        loss_fn = nn.BCEWithLogitsLoss()
        loss = loss_fn(outputs, targets)
        return loss

In [101]:
model = DenseNet(5)

In [102]:
inputs = torch.rand(16, 512)
targets = torch.randint(0, 2, (16, )).float() # binary targets
batch = (inputs, targets)

In [103]:
model(inputs)

tensor([[-0.0279],
        [-0.0296],
        [-0.0292],
        [-0.0257],
        [-0.0268],
        [-0.0282],
        [-0.0263],
        [-0.0266],
        [-0.0274],
        [-0.0259],
        [-0.0301],
        [-0.0279],
        [-0.0283],
        [-0.0268],
        [-0.0269],
        [-0.0273]], grad_fn=<AddmmBackward>)

In [104]:
model.train_step(batch)

tensor(0.6864, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

In [105]:
model.hparams

"hyper":           B
"n_hidden_layers": 5

In [156]:
import torch 
from torch import nn

class DenseLayer(nn.Module):
    """Simple fully connected dense layer."""
    def __init__(self, in_shape, out_shape):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(in_shape, out_shape), requires_grad=True)
    def forward(self, inputs):
        return self.weights.T.unsqueeze(0) @ inputs

dense = DenseLayer(512, 16)
inputs = torch.randn(8, 512)  # Batchsize 8
ouputs = dense(inputs)
print(outputs.size())
print(list(dense.parameters()))

RuntimeError: mat1 and mat2 shapes cannot be multiplied (16x512 and 8x512)

In [169]:
trans = torch.randn(512, 16)
input = torch.randn(8, 512)
out = input.unsqueeze(2) * trans
out.shape

torch.Size([8, 512, 16])

In [162]:
input.unsqueeze(1).shape

torch.Size([8, 1, 512])

In [145]:
list(dense.parameters())

[Parameter containing:
 tensor([[ 0.3942,  1.0553,  0.5875,  ...,  0.1208, -0.1897, -0.3418],
         [-1.3035,  0.5273,  0.1739,  ..., -1.3229,  0.5408, -0.2280],
         [ 0.1097,  0.2074, -1.5073,  ..., -0.8526, -2.5680, -1.6476],
         ...,
         [-0.5353,  2.3021,  0.3930,  ...,  0.8187, -1.1952, -0.9587],
         [ 0.7563,  0.7897,  1.9017,  ..., -1.0961,  2.1129, -0.1100],
         [-1.4382, -1.6479, -1.8615,  ...,  0.9697,  0.0253,  1.4471]],
        requires_grad=True)]