# DeepART Dev Notebook

In [4]:
import torch

import logging
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Any

In [None]:
# model: torch.nn.Module
class LocalUpdate(ABC):
    """Abstract interface for local weight-update rules.

    Concrete subclasses must implement :meth:`update`. Every instance gets
    its own logger whose name combines this module's name with the name of
    the concrete class.
    """

    def __init__(self):
        """Create the per-class logger and store it on ``self.logger``."""
        logger_name = f"{__name__}-{self.__class__.__name__}"
        self.logger = logging.getLogger(logger_name)

    @abstractmethod
    def update(
        self,
        x: torch.Tensor,
        w: torch.Tensor,
    ):
        """Compute an update from ``x`` and ``w``.

        The base implementation does nothing and returns ``None``; the
        meaning of the arguments and of the return value is defined by each
        subclass.

        Args:
            x: Tensor passed to the rule (presumably the layer inputs;
                confirm against the concrete subclass).
            w: Tensor passed to the rule (presumably the layer weights;
                confirm against the concrete subclass).
        """
        return None

In [None]:
class Oja(LocalUpdate):
    """Local weight update following Oja's rule.

    For a single sample ``x`` and weights ``w`` with outputs ``y = w @ x``,
    each weight changes by
    ``d_w[i, j] = eta * y[i] * (x[j] - y[i] * w[i, j])``.
    :meth:`update` returns this quantity averaged over a batch of samples.
    """

    def __init__(self,
        eta=0.1
    ):
        """Store the learning rate.

        Args:
            eta: Factor that scales every weight update.
        """
        super().__init__()
        self.eta = eta

    def update(self,
        x: torch.Tensor,
        w: torch.Tensor,
    ) -> torch.Tensor:
        """Return the batch-averaged Oja update for ``w``.

        ``w`` itself is not modified; the caller decides how to apply the
        returned update.

        Args:
            x: Batch of input samples, shape ``(batch, n_inputs)``.
            w: Weight matrix, shape ``(n_outputs, n_inputs)``.

        Returns:
            A tensor with the same shape as ``w`` holding the mean of the
            per-sample updates. It lives on the device of the inputs and,
            for floating-point inputs, keeps their dtype. An empty batch
            yields an all-zero update instead of NaNs.
        """
        n_samples = x.size(0)
        if n_samples == 0:
            # Averaging over zero samples would produce NaNs that poison the
            # weights; no samples means no change.
            return torch.zeros_like(w)

        # Outputs for every sample at once: y[b, i] = sum_j w[i, j] * x[b, j].
        y = torch.matmul(x, w.t())

        # Mean over the batch of the Hebbian term y[i] * x[j].
        hebbian = torch.matmul(y.t(), x) / n_samples

        # Mean over the batch of the decay term y[i]**2 * w[i, j]. Summing
        # and dividing (rather than ``mean``) also works for integer inputs.
        decay = (y * y).sum(dim=0).unsqueeze(1) * w / n_samples

        return self.eta * (hebbian - decay)