# In [None]:
# Tied Linear Autoencoder for Learned Image Transform (8x8)
#
# Provided by ChatGPT after being prompted with:
#
# I would like to know if I can achieve such learned transform using
# a 3-layers autoencoder, where the weights between the input layer and
# the center layer define the forward transform, and the weights between
# the center layer and the output layer define the inverse transform.
#
# NOTE: this code is untested; verify it before relying on its results.

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Read the example picture as 8-bit grayscale, force it to 256x256 pixels and
# map intensities from [0, 255] onto [0.0, 1.0].
img = Image.open('example.png').convert('L')
img = img.resize((256, 256))
img_np = np.array(img) / 255.0

# Every overlapping 8x8 window (stride 1 in both directions) becomes one row of
# 64 values. Rows are ordered by the window's top-left corner, row-major, and
# each window is flattened row-major; the result is (249*249, 64) float32.
windows = np.lib.stride_tricks.sliding_window_view(img_np, (8, 8))
patches = windows.reshape(-1, 64).astype(np.float32)

# Single-tensor dataset, reshuffled every epoch, served in batches of 512.
dataset = TensorDataset(torch.tensor(patches))
loader = DataLoader(dataset, batch_size=512, shuffle=True)

# Prefer the GPU when PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class TiedAE(nn.Module):
    """Linear autoencoder whose decoder is the transpose of its encoder.

    The only parameter is the ``dim x dim`` weight ``W`` of ``self.encoder``.
    Encoding computes ``z = x @ W.T`` (the learned forward transform), and
    decoding reuses the same matrix as ``x_hat = z @ W`` (the tied inverse
    transform). No biases are used, so the whole model is linear.
    """

    def __init__(self, dim):
        super().__init__()
        # Shared by encoder and decoder; bias-free so the transform is purely linear.
        self.encoder = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        """Return ``(reconstruction, code)`` for an input whose last axis has size ``dim``."""
        code = self.encoder(x)                      # x @ W.T   (forward transform)
        weight_t = self.encoder.weight.t()
        reconstruction = F.linear(code, weight_t)   # code @ W  (tied inverse)
        return reconstruction, code

def orth_loss(W):
    """Return the squared Frobenius distance between ``W @ W.T`` and the identity.

    Penalizing this pushes the rows of ``W`` toward being orthonormal. For a
    square ``W`` that makes ``W.T`` (the tied decoder) the exact inverse of
    the encoder.

    Args:
        W: 2-D weight tensor of shape (out_features, in_features).

    Returns:
        0-dim tensor ``||W W^T - I||_F^2``. It is zero when the rows of ``W``
        are orthonormal.
    """
    gram = W @ W.t()
    # Match W's dtype/device so no implicit promotion or transfer happens.
    eye = torch.eye(gram.size(0), device=W.device, dtype=W.dtype)
    # Summing squares directly gives the squared Frobenius norm without the
    # sqrt-then-square round trip of the (deprecated) torch.norm(...)**2.
    return (gram - eye).pow(2).sum()

# Model, optimizer and the weights of the two regularizers.
model = TiedAE(64).to(device)
opt = optim.Adam(model.parameters(), lr=1e-3)
lam_ortho = 1e-3  # scale of the orthogonality penalty on the encoder weight
lam_l1 = 1e-4     # scale of the L1 penalty on the codes

n_epochs = 5
for epoch in range(n_epochs):
    for (batch,) in loader:
        batch = batch.to(device)
        opt.zero_grad()
        xh, z = model(batch)
        # Objective = reconstruction error + orthogonality penalty + code sparsity.
        recon_term = F.mse_loss(xh, batch)
        ortho_term = lam_ortho * orth_loss(model.encoder.weight)
        sparse_term = lam_l1 * z.abs().mean()
        loss = recon_term + ortho_term + sparse_term
        loss.backward()
        opt.step()

# Visualize the learned basis: each row of the encoder weight is reshaped to
# an 8x8 tile, and the tiles fill an 8x8 grid of subplots.
W = model.encoder.weight.detach().cpu().numpy()
fig, axs = plt.subplots(8, 8, figsize=(8, 8))
for ax, basis_vec in zip(axs.flat, W):
    # imshow autoscales each tile's gray levels independently.
    ax.imshow(basis_vec.reshape(8, 8), cmap='gray')
    ax.axis('off')
plt.tight_layout()
# '/mnt/data' only exists in the ChatGPT sandbox; write to the working
# directory so the script also runs on an ordinary machine.
plt.savefig('basis.png')
