In [None]:
import torch
import torchvision
# import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import numpy as np
import pandas as pd
import torch.nn as nn
# import torch.nn.functional as F

import torch.optim as optim

# Silence every Python warning for the rest of the session.
# Note that this also hides deprecation notices and library usage warnings.
import warnings
warnings.filterwarnings('ignore')

# Set env CUDA_LAUNCH_BLOCKING=1 (commonly used to make CUDA errors surface
# at the call that caused them, at the cost of speed).
# NOTE(review): presumably this has to be set before CUDA is initialised to
# take effect - confirm.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Use the GPU when one is available, otherwise fall back to the CPU.
# Later cells create their tensors on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Render inline matplotlib figures in 'retina' format
%config InlineBackend.figure_format = 'retina'

# Import einops; if it is missing, install it into the kernel's environment
# and retry the import.
try:
    from einops import rearrange
except ImportError:
    %pip install einops
    from einops import rearrange

In [None]:
if os.path.exists('dog.jpg'):
    print('dog.jpg exists')
else:
    !curl -o dog.jpg https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg

In [None]:
# Low-rank matrix factorization by joint gradient descent (Adam) on U and V
def factorize(A: torch.Tensor, r: int, num_steps: int = 1000, lr: float = 0.01):
    """Approximate ``A`` with a rank-``r`` product ``U @ V`` using Adam.

    The loss is the 2-norm of the residual ``U @ V - A`` taken over the
    entries of ``A`` that are not NaN, so NaN entries are treated as missing
    and do not contribute to the fit.

    Args:
        A: 2-D tensor of shape ``(n, m)``; it does not have to be square.
            It is moved to the notebook-level ``device`` before fitting.
        r: Rank of the factorization (inner dimension of ``U @ V``).
        num_steps: Number of Adam updates to perform.
        lr: Adam learning rate.

    Returns:
        ``(U, V, loss)`` where ``U`` is ``(n, r)`` and ``V`` is ``(r, m)``
        (both leaf tensors on ``device`` with ``requires_grad=True``), and
        ``loss`` is a 0-dim tensor holding the masked residual norm of the
        returned factors.

    Raises:
        ValueError: If ``A`` is not 2-D.
    """
    if A.dim() != 2:
        raise ValueError(f"A must be a 2-D matrix, got shape {tuple(A.shape)}")

    # U and V live on `device`; A must be there too or `U @ V - A` fails
    # with a device mismatch when CUDA is available.
    A = A.to(device)
    n, m = A.shape
    U = torch.randn(n, r, requires_grad=True, device=device)
    V = torch.randn(r, m, requires_grad=True, device=device)

    optimizer = optim.Adam([U, V], lr=lr)
    mask = ~torch.isnan(A)  # True where A holds an observed value

    def masked_loss():
        # Boolean indexing drops the NaN positions before the norm is taken.
        return torch.norm((U @ V - A)[mask])

    for _ in range(num_steps):
        loss = masked_loss()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Re-evaluate after the last update so the reported loss matches the
    # returned factors (and is defined even when num_steps == 0).
    with torch.no_grad():
        loss = masked_loss()
    return U, V, loss

In [None]:
# Read dog.jpg from the working directory into a tensor and report its shape.
img = torchvision.io.read_image("dog.jpg")
print(img.shape)

In [None]:
# Convert to float and average over dim 0 to collapse the image to a single
# 2-D plane (presumably dim 0 is the channel axis of the loaded image - verify
# against the shape printed by the loading cell).
# `Tensor.to` is used because wrapping an existing tensor in `torch.tensor(...)`
# is the discouraged copy-construct form.
# NOTE: this rebinds `img`, so the cell must be run exactly once after loading;
# a second run would average away another dimension.
img = img.to(torch.float).mean(dim=0)
print(img.shape)

In [None]:
# Cut a 500x500 patch out of the image; the keyword arguments spell out
# what each number means.
crop = torchvision.transforms.functional.crop(
    img, top=600, left=800, height=500, width=500
)
crop.shape

In [None]:
# Fit a rank-5 factorization of the crop. Keep the factors in variables and
# display only the scalar loss instead of echoing both full factor matrices.
U, V, loss = factorize(crop, 5)
loss.item()

In [None]:
# Alternating ("WALS-style") factorization: U and V are updated in turn,
# each with its own Adam optimizer, while the other factor is held fixed.
def fac_wals(A: torch.Tensor, r: int, num_steps: int = 1000, lr: float = 0.01):
    """Approximate ``A`` with ``U @ V`` by alternating gradient updates.

    Every iteration performs one Adam step on ``U`` with ``V`` held fixed,
    then one Adam step on ``V`` with ``U`` held fixed. Despite the name, no
    closed-form least-squares solve is performed; the alternation is over
    gradient steps. The loss is the 2-norm of the residual ``U @ V - A`` over
    the entries of ``A`` that are not NaN, so NaN entries are treated as
    missing.

    Args:
        A: 2-D tensor of shape ``(n, m)``; it does not have to be square.
            It is moved to the notebook-level ``device`` before fitting.
        r: Rank of the factorization (inner dimension of ``U @ V``).
        num_steps: Number of alternating rounds (one U step plus one V step).
        lr: Learning rate used by both Adam optimizers.

    Returns:
        ``(U, V, loss)`` where ``U`` is ``(n, r)`` and ``V`` is ``(r, m)``
        (both leaf tensors on ``device`` with ``requires_grad=True``), and
        ``loss`` is a 0-dim tensor holding the masked residual norm of the
        returned factors.

    Raises:
        ValueError: If ``A`` is not 2-D.
    """
    if A.dim() != 2:
        raise ValueError(f"A must be a 2-D matrix, got shape {tuple(A.shape)}")

    # U and V live on `device`; A must be there too or `U @ V - A` fails
    # with a device mismatch when CUDA is available.
    A = A.to(device)
    n, m = A.shape
    U = torch.randn(n, r, requires_grad=True, device=device)
    V = torch.randn(r, m, requires_grad=True, device=device)

    optimizer_U = optim.Adam([U], lr=lr)
    optimizer_V = optim.Adam([V], lr=lr)
    mask = ~torch.isnan(A)  # True where A holds an observed value

    def masked_loss(left, right):
        # Boolean indexing drops the NaN positions before the norm is taken.
        return torch.norm((left @ right - A)[mask])

    for _ in range(num_steps):
        # Fix V, update U. Detaching V means no gradient is computed for
        # the factor that is not being updated in this half-step.
        loss = masked_loss(U, V.detach())
        optimizer_U.zero_grad()
        loss.backward()
        optimizer_U.step()

        # Fix U (already updated above), update V.
        loss = masked_loss(U.detach(), V)
        optimizer_V.zero_grad()
        loss.backward()
        optimizer_V.step()

    # Re-evaluate after the last update so the reported loss matches the
    # returned factors (and is defined even when num_steps == 0).
    with torch.no_grad():
        loss = masked_loss(U, V)
    return U, V, loss