In [None]:
import torch

# Prefer a CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

def _load_mnist_split(is_train: bool) -> datasets.MNIST:
    """Return one MNIST split stored under ./mnist_data, downloading it if missing.

    Images are passed through ``ToTensor()``; labels are left untransformed.
    """
    return datasets.MNIST(
        root='mnist_data',
        train=is_train,
        download=True,
        transform=ToTensor(),
        target_transform=None,
    )


# Build the training split first, then the held-out test split.
train_data = _load_mnist_split(is_train=True)
test_data = _load_mnist_split(is_train=False)

BATCH_SIZE = 64

# Only the training loader shuffles; the test loader keeps a fixed order.
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
import torch
from torch import nn

SEED = 42  # single place to change the notebook's RNG seed
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # explicit CUDA seeding as well; harmless when no GPU is present


class HarmonicLayer(nn.Module):
    """Distance layer used with the harmonic loss.

    Holds one learnable prototype row per output unit (``w_matrix``, shape
    ``(out_features, in_features)``) and maps every input vector to its squared
    Euclidean distance from each prototype.

    Args:
        in_features: size of the last dimension of the input.
        out_features: number of prototypes / output scores.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Prototypes initialised from a standard normal distribution.
        self.w_matrix = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return squared Euclidean distances from ``x`` to every prototype.

        Args:
            x: tensor of shape ``(..., in_features)``. The usual case is
                ``(batch, in_features)``, but any leading dimensions broadcast.

        Returns:
            Tensor of shape ``(..., out_features)`` whose entry ``[..., j]`` is
            ``sum_k (w_matrix[j, k] - x[..., k]) ** 2``. This is the *squared*
            distance: no square root is taken.
        """
        # (out, in) - (..., 1, in) broadcasts to (..., out, in).
        difference = self.w_matrix - x.unsqueeze(-2)
        # Reduce over the feature axis -> (..., out).
        squared_distance = difference.pow(2).sum(dim=-1)
        return squared_distance

In [None]:
from torch import nn


class HarmonicLoss(nn.Module):
    """Harmonic loss: negative log-likelihood under a "harmonic max" of distances.

    For distances ``d_j`` (one per class) the class probabilities are

        p_j = d_j ** -n / sum_k d_k ** -n,  with  n = intrinsic_dimensionality ** harmonic_exponent,

    and the loss is the batch mean of ``-log p_true``.

    Args:
        intrinsic_dimensionality: base used to derive the exponent ``n``.
        harmonic_exponent: power applied to ``intrinsic_dimensionality``; 0.5
            gives ``n = sqrt(intrinsic_dimensionality)``.
    """

    def __init__(self, intrinsic_dimensionality: int, harmonic_exponent: float):
        super().__init__()
        self.n = intrinsic_dimensionality ** harmonic_exponent

    def forward(self, x: torch.Tensor, correct_index: torch.Tensor) -> torch.Tensor:
        """Compute the mean harmonic loss over a batch.

        Args:
            x: distances of shape ``(batch, num_classes)``; assumed non-negative.
            correct_index: integer class targets, shape ``(batch, 1)`` or ``(batch,)``.

        Returns:
            Scalar tensor: mean of ``-log p_true`` over the batch.
        """
        if not x.is_floating_point():
            x = x.float()
        # d ** -n normalised over classes equals softmax(-n * log d). Evaluating it in
        # log space avoids float32 overflow of d ** n (e.g. d ~ 900, n = 28 gives ~1e82
        # -> inf, so the direct formula degenerates to 0 / 0 = NaN). Clamping to the
        # smallest positive normal keeps log() finite when a distance is exactly zero.
        log_distance = torch.log(x.clamp_min(torch.finfo(x.dtype).tiny))
        log_probs = torch.log_softmax(-self.n * log_distance, dim=1)
        # Accept 1-D (batch,) targets as well as the (batch, 1) form gather() needs.
        if correct_index.dim() == log_probs.dim() - 1:
            correct_index = correct_index.unsqueeze(1)
        log_p_true = log_probs.gather(dim=1, index=correct_index.long()).squeeze(1)
        return -log_p_true.mean()



In [None]:
import torch
from torch import nn


class Model(nn.Module):
    """Single-layer harmonic classifier.

    Flattens each input sample and returns, via ``HarmonicLayer``, its squared
    distance to one learned prototype per class. Smaller scores mean "closer".

    Args:
        in_features: length of a flattened input sample (default 784 matches the
            MNIST setup used in this notebook).
        num_classes: number of prototypes / output scores (default 10).
    """

    def __init__(self, in_features: int = 784, num_classes: int = 10):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Flatten(),
            HarmonicLayer(in_features=in_features, out_features=num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class distance scores for the batch ``x``."""
        return self.layer1(x)


model = Model().to(device=device)

In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt
from timeit import default_timer as timer

EPOCHS = 10
LEARNING_RATE = 0.001
UPDATE_PARAMETERS_EVERY = 1  # batches whose gradients are accumulated per optimizer step
PRINT_EVERY = 100  # log one training line every N batches instead of flooding the output

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = HarmonicLoss(intrinsic_dimensionality=784, harmonic_exponent=0.5)

train_loss_history = []  # one entry per training batch
test_loss_history = []  # one entry per test batch, all epochs concatenated
train_acc_history = []  # aligned with train_loss_history
test_acc_history = []  # aligned with test_loss_history
current_batch = 0  # batches accumulated since the last optimizer step

num_train_batches = len(train_dataloader)

# `model` outlives this cell: clear gradients a previous (possibly interrupted) run left behind.
optimizer.zero_grad()

for epoch in range(EPOCHS):

    print(f"E {epoch + 1:,}/{EPOCHS:,} ")
    epoch_start = timer()

    model.train()
    for train_batch_index, (x_train, y_train) in enumerate(train_dataloader):
        x_train, y_train = x_train.to(device=device), y_train.to(device=device)
        train_outputs = model(x_train)
        loss = loss_fn(train_outputs, y_train.unsqueeze(1))
        # Scale so the gradient summed over the accumulation window is the window mean.
        (loss / UPDATE_PARAMETERS_EVERY).backward()

        loss_value = loss.item()
        # The model outputs distances, so the predicted class is the SMALLEST score.
        train_acc = (train_outputs.argmin(dim=1) == y_train).float().mean().item()
        train_loss_history.append(loss_value)
        train_acc_history.append(train_acc)

        current_batch += 1
        if current_batch == UPDATE_PARAMETERS_EVERY:
            optimizer.step()
            # Zero AFTER stepping (not at the top of the loop) so gradients really
            # accumulate across UPDATE_PARAMETERS_EVERY batches.
            optimizer.zero_grad()
            current_batch = 0

        if (train_batch_index + 1) % PRINT_EVERY == 0 or train_batch_index + 1 == num_train_batches:
            print(
                f"Batch {train_batch_index + 1:,}/{num_train_batches:,} "
                f"| Loss: {loss_value:.4f} | Acc: {train_acc:.2%}"
            )

    model.eval()
    test_loss_sum = 0.0
    test_correct = 0
    test_seen = 0
    with torch.inference_mode():
        for x_test, y_test in test_dataloader:
            x_test, y_test = x_test.to(device=device), y_test.to(device=device)
            test_outputs = model(x_test)
            test_loss_value = loss_fn(test_outputs, y_test.unsqueeze(1)).item()
            batch_correct = (test_outputs.argmin(dim=1) == y_test).sum().item()
            batch_len = y_test.size(0)

            test_loss_history.append(test_loss_value)
            test_acc_history.append(batch_correct / batch_len)
            # Weight by batch length: the final test batch is usually smaller.
            test_loss_sum += test_loss_value * batch_len
            test_correct += batch_correct
            test_seen += batch_len

    if test_seen:
        print(
            f"Test | Loss: {test_loss_sum / test_seen:.4f} | Acc: {test_correct / test_seen:.2%} "
            f"| Epoch time: {timer() - epoch_start:.1f}s"
        )