In [1]:
import optuna

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision 
from torchvision import transforms

from torch.utils.data import Dataset, DataLoader

from tqdm.autonotebook import tqdm

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow

import pandas as pd

from sklearn.metrics import accuracy_score

import time

import idlmam

In [2]:
# Use PyTorch's MPS backend when it is available; otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [3]:
D = 28 * 28
n = 2 
C = 1 
classes = 10 

In [4]:
class TransposeLinear(nn.Module):
    """Linear layer that applies the transpose of another layer's weight.

    Given an ``nn.Linear``-style ``linearLayer`` with weight ``W`` of shape
    ``(out_features, in_features)`` (which computes ``x @ W^T``), this layer
    computes ``x @ W (+ b)``, mapping ``out_features`` back to
    ``in_features``. The weight ``Parameter`` object itself is shared, so any
    update to it is seen by both layers.
    """

    def __init__(self, linearLayer, bias=True):
        """
        linearLayer: the layer whose weight W is reused (weight sharing); this
            layer represents W^T.
        bias: if True, create a new learnable bias of size
            ``linearLayer.weight.shape[1]``. It is separate from any bias in
            ``linearLayer`` and initialized to zero. If False, no bias is used.
        """
        super().__init__()
        # Share the Parameter object (not a copy) so the two layers stay tied.
        self.weight = linearLayer.weight

        if bias:
            # torch.Tensor(k) returns *uninitialized* memory, which can hold
            # garbage/NaN; start from zeros so the bias is well-defined.
            weight = linearLayer.weight
            self.bias = nn.Parameter(
                torch.zeros(weight.shape[1], dtype=weight.dtype, device=weight.device)
            )
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        # F.linear(x, A, b) computes x @ A^T + b; with A = W^T this is x @ W + b.
        return F.linear(x, self.weight.t(), self.bias)

In [5]:
# PCA as a tied-weight linear autoencoder.
# Encoder: flatten each input to D features, then project to n dims with W (no bias).
linearLayer = nn.Linear(D, n, bias=False)
encoder_layers = [nn.Flatten(), linearLayer]
pca_encoder = nn.Sequential(*encoder_layers)

# Decoder: reuse the same W transposed (no bias), then view the result as
# (-1, 1, 28, 28) via idlmam.View.
decoder_layers = [
    TransposeLinear(linearLayer, bias=False),
    idlmam.View((-1, 1, 28, 28)),
]
pca_decoder = nn.Sequential(*decoder_layers)

# Full model: encode, then decode.
pca_model = nn.Sequential(pca_encoder, pca_decoder)


To make this a true PCA, we also need the orthogonality constraint $W W^\top = I$, i.e. the rows of $W$ must be orthonormal.

In [None]:
# Start W with orthonormal rows so the PCA constraint W W^T = I holds initially.
# Used as a statement (not the cell's last expression) so the tensor repr isn't dumped.
nn.init.orthogonal_(linearLayer.weight)
with torch.no_grad():
    _gram = linearLayer.weight @ linearLayer.weight.T
assert torch.allclose(
    _gram, torch.eye(n, dtype=_gram.dtype, device=_gram.device), atol=1e-5
), "rows of W are not orthonormal after orthogonal init"