In [228]:
# SVDrop: A method to drop out singular vectors from a representation learned by a neural network

# Method consists of:

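# 1. Compute the SVD of the (mean-centered) features R produced by the DNN.
# 2. Premultiply the transposed classifier weights W^T by V Λ V^+, where the columns of V
#    are the right singular vectors of R, Λ = diag(dropout mask) and V^+ is the
#    pseudoinverse of V, so masked singular directions are removed before classification.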
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter
import torch.nn.init as init
import math
from torch.nn import functional as F
from random import randint

class SVDropClassifier(nn.Module):
    """Linear classifier that can drop singular directions of its input features.

    Until :meth:`set_singular` is called this behaves exactly like ``nn.Linear``.
    Afterwards, with ``P = V @ diag(mask) @ pinv(V)``, the layer computes

        y = W' x + (W - W') mu + b,    W' = W P^T

    i.e. ``y = W (P^T (x - mu) + mu) + b``: the input is centred with ``mu``,
    the singular directions whose mask entry is 0 are projected out, the mean
    is added back, and the result is classified with the original weights.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if ``False`` the layer has no additive learnable bias.
        device, dtype: forwarded to the parameter factories.
    """

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        # SVD state: populated by set_singular(), cleared by reset_singular().
        self.V = None
        self.V_inv = None
        self.s_values = None
        self.mu_R = None
        self.Lambda = None

        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            # Keep `self.bias` defined (as None), mirroring nn.Linear.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the dropout mask and re-initialise weight/bias like ``nn.Linear``."""
        self.reset_mask()
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def set_singular(self, V: Tensor, mu: Tensor, s_values=None) -> None:
        """Enable singular-direction dropout.

        Args:
            V: ``(in_features, in_features)`` matrix whose *columns* are the right
                singular vectors of the feature matrix R. For
                ``U, S, Vh = torch.linalg.svd(R)`` pass ``Vh.mH``.
            mu: mean feature vector of R, used to centre the input.
            s_values: optional singular values of R. They are stored on the layer
                for reference only and are not used in ``forward``.
        """
        self.V = V
        self.s_values = s_values
        self.V_inv = torch.linalg.pinv(V)  # equals V.mH when V is orthogonal/unitary
        self.mu_R = mu
        self.Lambda = torch.diagflat(self.mask)

    def reset_singular(self) -> None:
        """Forget the SVD state so the layer acts as a plain linear layer again."""
        self.V = None
        self.Lambda = None
        self.V_inv = None
        self.mu_R = None
        self.s_values = None

    def dropout_dim(self, indices=None) -> None:
        """Zero the mask at ``indices``, or at one uniformly random index if ``None``."""
        if indices is None:
            # random.randint is inclusive on both ends, hence the "- 1".
            indices = [randint(0, self.in_features - 1)]
        self.mask[indices] = 0
        self.Lambda = torch.diagflat(self.mask)

    def reset_mask(self) -> None:
        """Restore an all-ones mask (no direction dropped)."""
        # Create the mask alongside the weight so forward() does not mix devices/dtypes.
        self.mask = torch.ones(self.in_features, device=self.weight.device,
                               dtype=self.weight.dtype)
        self.Lambda = torch.diagflat(self.mask)

    def forward(self, input: Tensor) -> Tensor:
        if self.V is None:
            # No singular vectors set: behave as a regular linear layer.
            return F.linear(input, self.weight, self.bias)
        # Effective weight W' = W P^T with P = V diag(mask) V^+.
        new_weights = (((self.V @ self.Lambda) @ self.V_inv) @ self.weight.T).T
        # One bias entry per output unit: (W - W') mu, so y = W'(x - mu) + W mu (+ b).
        new_bias = (self.weight - new_weights) @ self.mu_R
        if self.bias is not None:
            new_bias = new_bias + self.bias
        return F.linear(input, new_weights, new_bias)

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'

if __name__ == "__main__":
    # Smoke test of SVDropClassifier on random features.
    a = SVDropClassifier(100, 1)
    data = torch.randn(1000, 100)
    print(a(data)[0:2])  # no singular vectors set yet: plain linear layer

    mu = data.mean(dim=0)
    # torch.linalg.svd returns (U, S, Vh): the rows of Vh are the right singular
    # vectors, while the layer expects them as the *columns* of V, so pass Vh.mH.
    # The SVD is taken of the centered data because the layer projects x - mu.
    U, S, Vh = torch.linalg.svd(data - mu, full_matrices=False)
    print(a.Lambda.shape)
    a.set_singular(Vh.mH, mu)
    print(a(data)[0:2])  # all-ones mask: same output as the plain layer
    a.dropout_dim([0, 1])  # drop the two leading singular directions
    print(a(data)[0:2])
    a.reset_mask()
    print(a(data)[0:2])
    a.dropout_dim()  # drop one random singular direction
    print(a(data)[0:2])
    print(a.mask)

tensor([[-0.2663],
        [-0.1317]], grad_fn=<SliceBackward0>)
torch.Size([100, 100])
tensor([[-0.2663],
        [-0.1317]], grad_fn=<SliceBackward0>)
tensor([[-0.1623],
        [-0.2971]], grad_fn=<SliceBackward0>)
tensor([[-0.2663],
        [-0.1317]], grad_fn=<SliceBackward0>)
tensor([[-0.2711],
        [-0.1303]], grad_fn=<SliceBackward0>)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
