In [1]:
import numpy as np
import torch 
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.manual_seed(0)
np.random.seed(0)

In [3]:
## Data
input = torch.randn(128, 20)

In [4]:
## Linear Layer 
m = nn.Linear(20, 30)

In [64]:
# U, S, Vh = torch.linalg.svd(m.weight, full_matrices=False)

# rank = 20
# torch.dist(U[:, :rank] @ torch.diag(S[:rank]) @ Vh[:rank, :], m.weight)

tensor(2.4411e-06, grad_fn=<DistBackward0>)

In [65]:
## Low-rank Linear Layer
class LowRankLinearLayer(nn.Module):
    def __init__(self, initial_weight, rank):
        super(LowRankLinearLayer, self).__init__()
        U, S, Vh = torch.linalg.svd(initial_weight, full_matrices=False)
        self.U = nn.Parameter(U[:, :rank])
        self.S = nn.Parameter(S[:rank])
        self.Vh = nn.Parameter(Vh[:rank, :])

    def forward(self, x):
        return torch.einsum('mn, bn -> bm', self.U @ torch.diag(self.S) @ self.Vh, x)

m2 = LowRankLinearLayer(m.weight, rank=20)

In [63]:
torch.dist(m(input), m2(input))

tensor(6.5319, grad_fn=<DistBackward0>)

In [None]:
## Model Parameters
learning_rate = 1e-3
epochs = 5

## Loss Function
loss_function = nn.MSELoss()

## Optimizer 
optimizer = torch.optim.Adam(m.parameters(), lr=learning_rate)
optimizer2 = torch.optim.Adam(m2.parameters(), lr=learning_rate)

## Training Loop
def train_loop(dataloader, model, loss_function, optimizer):
    size = len(dataloader.dataset)
    for batch, (t, x) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(t)
        loss = loss_function(pred, x)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(t)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [None]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(dataloader, m, loss_function, optimizer)
print("Done!")

In [None]:
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(dataloader, m2, loss_function, optimizer)
print("Done!")