In [46]:
import torch
from torch import nn
from torch import optim
from tqdm import tqdm
import pandas as pd
import numpy as np

from weighted_svd_utils import weighted_frobenious, get_svd, rank_r, get_svd_lora, toy_correlated_matrix, toy_two_subspace_matrix, toy_seven_subspace_matrix, toy_seven_subspace_matrix_unqual_size


In [47]:
def flops_from_svd(U, S, V):
    """Estimate the FLOP cost of a factorization described by ``U``, ``S``, ``V``.

    The cost is counted as 2 FLOPs per stored entry of ``U`` and of ``V``.
    ``S`` does not enter the count; it is only checked for presence.

    Args:
        U: Left factor exposing a two-element ``.shape``, or ``None``.
        S: Singular values, or ``None``.
        V: Right factor exposing a two-element ``.shape``, or ``None``.

    Returns:
        ``2 * (V.shape[0] * V.shape[1] + U.shape[0] * U.shape[1])`` when all
        three factors are given, otherwise ``0``.
    """
    # Identity checks instead of `!= None`: the latter dispatches to the
    # operand's own comparison operator, which is elementwise for array types
    # (e.g. NumPy) and then cannot be used as a truth value.
    if U is None or S is None or V is None:
        return 0
    return 2 * (V.shape[0] * V.shape[1] + U.shape[0] * U.shape[1])

In [48]:
# TODO: More accurate flop calculation everywhere

def find_optimal_rank_allocation_F2(mat1, mat2, flops):
    """Pick truncation ranks for two matrices that share one FLOP budget.

    Three configurations are evaluated in order and the one with the smallest
    squared-Frobenius loss is returned (ties go to the later configuration):

      1. both matrices are replaced by truncated SVDs,
      2. only ``mat1`` is truncated while ``mat2`` is kept dense,
      3. only ``mat2`` is truncated while ``mat1`` is kept dense.

    The loss of a truncation is the sum of the squared singular values that
    are dropped. Singular values are dropped from the end of ``s``.
    NOTE(review): this assumes ``get_svd`` returns singular values sorted in
    descending order -- confirm against ``weighted_svd_utils``.

    Args:
        mat1: 2-D tensor, or ``None`` when the first group is empty.
        mat2: 2-D tensor, or ``None`` when the second group is empty.
        flops: Total FLOP budget shared by both matrices.

    Returns:
        Tuple ``(min_F2_loss, r1_optimal, r2_optimal, flops1, flops2)``.
        ``flops1`` / ``flops2`` are ``rank * flops_per_singular_value`` for a
        truncated matrix and the dense cost for a matrix kept dense.
        ``min_F2_loss`` is ``torch.inf`` when no configuration fits the budget.
        When both inputs are ``None`` a warning is printed and
        ``(torch.inf, 0, 0, 0.0, 0.0)`` is returned.
    """
    has_mat1 = mat1 is not None
    has_mat2 = mat2 is not None

    # Cost of keeping each matrix dense: 2 FLOPs per entry.
    mat1_flops = 2 * mat1.shape[0] * mat1.shape[1] if has_mat1 else 0.0
    mat2_flops = 2 * mat2.shape[0] * mat2.shape[1] if has_mat2 else 0.0
    u1, s1, v1 = get_svd(mat1) if has_mat1 else (None, None, None)
    u2, s2, v2 = get_svd(mat2) if has_mat2 else (None, None, None)
    svd1_flops = flops_from_svd(u1, s1, v1) if has_mat1 else 0
    svd2_flops = flops_from_svd(u2, s2, v2) if has_mat2 else 0
    # A single-row matrix is charged its dense cost instead of the factorized cost.
    if has_mat1 and mat1.shape[0] == 1:
        svd1_flops = mat1_flops
    if has_mat2 and mat2.shape[0] == 1:
        svd2_flops = mat2_flops

    # FLOPs recovered each time one singular value is dropped.
    flops_per_sv_1 = 2 * v1.shape[0] + 2 * u1.shape[0] if has_mat1 else 0.0
    flops_per_sv_2 = 2 * v2.shape[0] + 2 * u2.shape[0] if has_mat2 else 0.0

    # Squared singular values: the loss incurred by dropping each one.
    F2_loss_1 = [s*s for s in s1] if has_mat1 else []
    F2_loss_2 = [s*s for s in s2] if has_mat2 else []

    r1 = s1.shape[0] if has_mat1 else 0
    r2 = s2.shape[0] if has_mat2 else 0

    # Defaults so the return statement is well defined even when no
    # configuration below applies (both inputs None).
    min_F2_loss = torch.inf
    r1_optimal = 0
    r2_optimal = 0
    flops1 = 0.0
    flops2 = 0.0

    if not has_mat1 and not has_mat2:
        print("Warning! Both matrices empty!")

    # Find the best split given that both matrices are SV-decomposed
    if has_mat1 and has_mat2:
        flops_both_svd = svd1_flops + svd2_flops
        F2_loss_both_svd = 0
        i1_both_svd = r1 - 1
        i2_both_svd = r2 - 1
        while flops_both_svd > flops:
            if i1_both_svd <= 0 and i2_both_svd <= 0:
                # Both matrices are already down to their last singular value
                # and the budget is still exceeded: nothing more can be
                # dropped, so this configuration is infeasible. Without this
                # exit the loop would spin forever.
                F2_loss_both_svd = torch.inf
                break

            # Greedily drop the singular value with the least loss per FLOP saved.
            F2_per_flops_1 = F2_loss_1[i1_both_svd] / flops_per_sv_1
            F2_per_flops_2 = F2_loss_2[i2_both_svd] / flops_per_sv_2

            if (F2_per_flops_1 < F2_per_flops_2 or i2_both_svd == 0) and i1_both_svd > 0:
                F2_loss_both_svd += F2_loss_1[i1_both_svd]
                flops_both_svd -= flops_per_sv_1
                i1_both_svd -= 1
            elif i2_both_svd > 0:
                F2_loss_both_svd += F2_loss_2[i2_both_svd]
                flops_both_svd -= flops_per_sv_2
                i2_both_svd -= 1

        r1_optimal = i1_both_svd + 1
        r2_optimal = i2_both_svd + 1
        min_F2_loss = F2_loss_both_svd
        flops1 = r1_optimal * flops_per_sv_1
        flops2 = r2_optimal * flops_per_sv_2

    # Find the best split given that only mat1 is SV-decomposed
    if has_mat1:
        flops_mat1_svd = svd1_flops + mat2_flops
        F2_loss_mat1_svd = 0
        i1_mat1_svd = r1 - 1
        while flops_mat1_svd > flops and i1_mat1_svd > -1:
            F2_loss_mat1_svd += F2_loss_1[i1_mat1_svd]
            flops_mat1_svd -= flops_per_sv_1
            i1_mat1_svd -= 1

        # Every singular value had to be dropped: treat as infeasible.
        if i1_mat1_svd < 0: # TODO: allow for groups of size < 3
            F2_loss_mat1_svd = torch.inf

        if F2_loss_mat1_svd <= min_F2_loss:
            r1_optimal = i1_mat1_svd + 1
            r2_optimal = r2
            min_F2_loss = F2_loss_mat1_svd
            flops1 = r1_optimal * flops_per_sv_1
            flops2 = mat2_flops

    # Find the best split given that only mat2 is SV-decomposed
    if has_mat2:
        flops_mat2_svd = mat1_flops + svd2_flops
        F2_loss_mat2_svd = 0
        i2_mat2_svd = r2 - 1
        while flops_mat2_svd > flops and i2_mat2_svd > -1:
            F2_loss_mat2_svd += F2_loss_2[i2_mat2_svd]
            flops_mat2_svd -= flops_per_sv_2
            i2_mat2_svd -= 1

        # Every singular value had to be dropped: treat as infeasible.
        if i2_mat2_svd < 0: # TODO: allow for groups of size < 3
            F2_loss_mat2_svd = torch.inf

        if F2_loss_mat2_svd <= min_F2_loss:
            r1_optimal = r1
            r2_optimal = i2_mat2_svd + 1
            min_F2_loss = F2_loss_mat2_svd
            flops1 = mat1_flops
            flops2 = r2_optimal * flops_per_sv_2

    return (min_F2_loss, r1_optimal, r2_optimal, flops1, flops2)


In [49]:
def get_proj_loss_F2(group1, group2, flops):
    """Measure the squared-Frobenius loss of approximating two row groups.

    Each non-empty group is stacked into a matrix, the rank allocation for the
    given FLOP budget is obtained from ``find_optimal_rank_allocation_F2``, and
    the loss is then measured directly between each matrix and its rank-``r``
    approximation from ``get_svd_lora``. A warning is printed when the measured
    total disagrees with the loss predicted from the singular values.

    Args:
        group1: List of ``(index, row)`` tuples; may be empty.
        group2: List of ``(index, row)`` tuples; may be empty.
        flops: FLOP budget shared by both groups.

    Returns:
        Tuple ``(F2_loss_1, F2_loss_2, r1_optimal, r2_optimal, flops1, flops2)``
        where the two losses are the measured (not predicted) values and an
        empty group contributes ``0.0``.
    """

    def _stack_rows(group):
        # Rows of a group as one matrix; an empty group has no matrix.
        if not group:
            return None
        return torch.stack([row for _, row in group])

    mat1 = _stack_rows(group1)
    mat2 = _stack_rows(group2)

    allocation = find_optimal_rank_allocation_F2(mat1, mat2, flops)
    F2_loss, r1_optimal, r2_optimal, flops1, flops2 = allocation

    # Check if F2 loss from singular values matches actual F2 of the difference
    lora_1 = get_svd_lora(mat1, r1_optimal) if group1 else None
    lora_2 = get_svd_lora(mat2, r2_optimal) if group2 else None
    # The reference matrices are stacked afresh rather than reusing mat1/mat2.
    A_partial_1 = _stack_rows(group1)
    A_partial_2 = _stack_rows(group2)
    F_loss_1 = weighted_frobenious(A_partial_1, lora_1) if group1 else 0.0
    F_loss_2 = weighted_frobenious(A_partial_2, lora_2) if group2 else 0.0
    F2_loss_1 = F_loss_1 * F_loss_1
    F2_loss_2 = F_loss_2 * F_loss_2
    F2_loss_actual = F2_loss_1 + F2_loss_2

    # Only warn when the mismatch is large both absolutely and relatively.
    large_absolute_gap = torch.abs(F2_loss - F2_loss_actual) > 0.5
    if large_absolute_gap and abs(F2_loss/F2_loss_actual - 1) > 0.001:
        print(f"Warning! F2 losses dont match: {F2_loss} vs {F2_loss_actual}")

    return F2_loss_1, F2_loss_2, r1_optimal, r2_optimal, flops1, flops2



In [50]:
def optimal_row_to_move_F2(sender, receiver, flops):
    """Find the single row whose move from ``sender`` to ``receiver`` is cheapest.

    Every row of ``sender`` is tried in turn: it is moved to a copy of
    ``receiver`` and the combined loss of the two resulting groups is obtained
    from ``get_proj_loss_F2``. Neither input list is modified.

    Args:
        sender: List of ``(index, row)`` tuples to take a row from.
        receiver: List of ``(index, row)`` tuples to add the row to.
        flops: FLOP budget passed through to ``get_proj_loss_F2``.

    Returns:
        Tuple ``(best_row_tuple, best_loss)``. When ``sender`` is empty, or no
        candidate has a loss below infinity, this is ``(None, torch.inf)``.
    """
    lowest_F2_loss = torch.inf
    best_candidate = None
    for candidate in sender:
        # Work on copies so the caller's lists stay untouched.
        trial_sender = list(sender)
        trial_receiver = list(receiver)

        trial_receiver.append(candidate)
        trial_sender.remove(candidate)
        loss_sender, loss_receiver, _, _, _, _ = get_proj_loss_F2(trial_sender, trial_receiver, flops)
        combined_loss = loss_sender + loss_receiver
        if combined_loss < lowest_F2_loss:
            lowest_F2_loss = combined_loss
            best_candidate = candidate

    return best_candidate, lowest_F2_loss

In [51]:
def _split_or_keep_subgroup(group, rank, F2_loss, flops, label, printdepth):
    """Resolve one side of a split: keep it whole at rank 1 / full rank, else recurse.

    Returns ``(F2_loss, actual_flops, groups)`` where ``groups`` is always a
    list of groups (each a list of ``(index, row)`` tuples).
    """
    n_rows = len(group)
    n_cols = len((group[0])[1])
    print(f"Splitting sub-matrix {label} of size {n_rows} at depth = {printdepth} with {flops} flops")
    if rank == 1:
        # Rank-1 group is kept as is; its cost is counted as 2 * n_cols + n_rows.
        return F2_loss, 2 * n_cols + n_rows, [group]
    if rank == n_rows:
        # Rank equals the number of rows: the group is kept dense.
        if F2_loss != 0:
            print(f"Warning! Loss is {F2_loss} but should be zero!")
        return F2_loss, 2 * n_cols * n_rows, [group]
    return greedy_splitting_rows_F2(group, flops, printdepth + 1)


def greedy_splitting_rows_F2(row_tuples, flops=None, printdepth = 1):
    """Recursively partition rows into groups that are cheap to approximate.

    Starting with all rows in group 1, rows are moved greedily one at a time
    to the other group; the best partition seen during a full sweep is kept.
    Sweeps alternate direction until one fails to improve the loss. Each
    resulting group is then split again with its share of the FLOP budget,
    unless its allocated rank is 1 or equals its number of rows.

    Args:
        row_tuples: List of ``(index, row)`` tuples; ``len(row)`` gives the
            number of columns.
        flops: FLOP budget. Defaults to half the dense cost
            ``2 * n_rows * n_cols``.
        printdepth: Recursion depth, used only in progress messages.

    Returns:
        Tuple ``(total_F2_loss, total_actual_flops, groups)`` where ``groups``
        is a list of groups, each a list of ``(index, row)`` tuples. An empty
        input yields ``(0.0, 0, [])``.
    """

    # Nothing to approximate and nothing to split.
    if len(row_tuples) == 0:
        return 0.0, 0, []

    # Check if group is vector
    if len(row_tuples) == 1:
        print("Attempting to split a single vector. Returning vector with zero loss.")
        return 0.0, 2 * len(row_tuples) * len((row_tuples[0])[1]), [row_tuples]

    if flops is None:
        A_flops = 2 * len(row_tuples) * len((row_tuples[0])[1])
        flops = 0.5 * A_flops

    print(f"Depth: {printdepth}")

    sender_idx = 1 # 1 means group1, 2 means group2
    optimal_group1 = list(row_tuples)
    optimal_group2 = []

    current_best_F2_loss_1, current_best_F2_loss_2, _, _, _, _ = get_proj_loss_F2(optimal_group1, optimal_group2, flops)
    current_best_F2_loss = current_best_F2_loss_1 + current_best_F2_loss_2
    single_svd_F2_loss = current_best_F2_loss # For debugging, can remove later
    print(f"Loss from simple SVD: {single_svd_F2_loss}")

    while True:
        # Baseline for this sweep: the loss of the current best partition.
        current_best_F2_loss_for_direction_1, current_best_F2_loss_for_direction_2, _, _, _, _ = get_proj_loss_F2(optimal_group1, optimal_group2, flops)
        current_best_F2_loss_for_direction = current_best_F2_loss_for_direction_1 + current_best_F2_loss_for_direction_2
        optimal_group1_for_direction = list(optimal_group1)
        optimal_group2_for_direction = list(optimal_group2)
        groups = [list(optimal_group1), list(optimal_group2)]
        sender = groups[sender_idx - 1]
        receiver = groups[-sender_idx]
        # Drain the sender one row at a time, remembering the best partition seen.
        while sender:
            row_to_move, F2_loss = optimal_row_to_move_F2(sender, receiver, flops)
            if row_to_move is None:
                # No candidate produced a loss below infinity (e.g. NaN
                # losses); stop the sweep instead of trying to remove None.
                break
            sender.remove(row_to_move)
            receiver.append(row_to_move)
            if F2_loss < current_best_F2_loss_for_direction:
                optimal_group1_for_direction = list(groups[0])
                optimal_group2_for_direction = list(groups[1])
                current_best_F2_loss_for_direction = F2_loss

        print(f"Loss for direction: {current_best_F2_loss_for_direction}")

        if current_best_F2_loss_for_direction < current_best_F2_loss:
            current_best_F2_loss = current_best_F2_loss_for_direction
            optimal_group1 = list(optimal_group1_for_direction)
            optimal_group2 = list(optimal_group2_for_direction)
            # Next sweep moves rows in the opposite direction.
            sender_idx = 2 if sender_idx == 1 else 1
            continue
        else:
            print("Optimal split found. Proceeding to split sub-matrices.")
            break

    F2_loss_1, F2_loss_2, r1_optimal, r2_optimal, flops1, flops2 = get_proj_loss_F2(optimal_group1, optimal_group2, flops)
    F2_loss = F2_loss_1 + F2_loss_2
    print(f"Size of group 1: {len(optimal_group1)}\nSize of group 2: {len(optimal_group2)}")
    print(f"Rank of matrix 1: {r1_optimal}\nRank of matrix 2: {r2_optimal}")

    if optimal_group1 and optimal_group2:
        remainderflops = flops - flops1 - flops2 # TODO: Implement efficient usage of remainder flops in another branches
        if remainderflops < 0:
            print(f"Warning: Used more flops than allocated! {flops1} + {flops2} = {flops1 + flops2} vs {flops}")
        # Hand unused budget to the side that will be split further; share it
        # evenly when both sides have rank above 1.
        if r2_optimal <= 1:
            flops1 += remainderflops
        elif r1_optimal <= 1:
            flops2 += remainderflops
        else:
            flops1 += int(remainderflops * 0.5)
            flops2 += remainderflops - int(remainderflops * 0.5)

        F2_loss_1, flops1_actual, groups1 = _split_or_keep_subgroup(
            optimal_group1, r1_optimal, F2_loss_1, flops1, 1, printdepth)
        F2_loss_2, flops2_actual, groups2 = _split_or_keep_subgroup(
            optimal_group2, r2_optimal, F2_loss_2, flops2, 2, printdepth)
        total_F2_loss = F2_loss_1 + F2_loss_2
        total_actual_flops = flops1_actual + flops2_actual
        # Both operands are lists of groups, so this stays a list of groups.
        return total_F2_loss, total_actual_flops, groups1 + groups2
    else:
        print("No matrices to split.")
        flops1_actual = flops1
        flops2_actual = flops2
        total_F2_loss = F2_loss
        total_actual_flops = flops1_actual + flops2_actual
        ret_group = optimal_group1 if optimal_group1 else optimal_group2
        ret_group.sort(key=lambda x: x[0])
        return total_F2_loss, total_actual_flops, [ret_group]

In [52]:
# A = toy_seven_subspace_matrix(80, 210, 12, True)
A = toy_seven_subspace_matrix_unqual_size(20, 80, 4, True)
# One random 20-entry offset, broadcast against A's trailing dimension.
# NOTE(review): no RNG seed is set in this notebook, so values differ between runs.
off_center = torch.randn(20) * 10
A_shifted = A + off_center

# Rank-1 approximations of the original and the shifted matrix.
A_r1 = get_svd_lora(A, 1)
A_shifted_r1 = get_svd_lora(A_shifted, 1)

# Residuals after removing the rank-1 approximation ("joint") ...
A_minus_joint = A - A_r1
A_shifted_minus_joint = A_shifted - A_shifted_r1
# ... and after removing the mean over dim 0.
A_minus_mean = A - A.mean(dim=0)
A_shifted_minus_mean = A_shifted - A_shifted.mean(dim=0)

# (index, row) tuples, the input format of greedy_splitting_rows_F2.
A_rows = [(idx, row) for idx, row in enumerate(A)]
A_minus_joint_rows = [(idx, row) for idx, row in enumerate(A_minus_joint)]
A_minus_mean_rows = [(idx, row) for idx, row in enumerate(A_minus_mean)]

A_shifted_rows = [(idx, row) for idx, row in enumerate(A_shifted)]
A_shifted_minus_joint_rows = [(idx, row) for idx, row in enumerate(A_shifted_minus_joint)]
A_shifted_minus_mean_rows = [(idx, row) for idx, row in enumerate(A_shifted_minus_mean)]

In [53]:
# Compare the typical row norm of A with the norm of the offset vector.
norms = A.norm(dim=1)
average_length = norms.mean()
print(f'Average length of row vectors A: {average_length}')

print(f'Off-centering magnitude: {torch.norm(off_center)}')

Average length of row vectors A: 14.219156265258789
Off-centering magnitude: 49.08358383178711


In [54]:
# Baseline: split the rows of A with a budget of 20*80 FLOPs.
total_F2_loss, total_actual_flops, groups = greedy_splitting_rows_F2(
    A_rows,
    flops=20*80,
)

Depth: 1
Loss from simple SVD: 5488.38720703125


Loss for direction: 4333.79638671875
Loss for direction: 4078.24853515625
Loss for direction: 4078.24853515625
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 36
Size of group 2: 44
Rank of matrix 1: 7
Rank of matrix 2: 6
Splitting sub-matrix 1 of size 36 at depth = 1 with 808 flops
Depth: 2
Loss from simple SVD: 1961.669677734375
Loss for direction: 1730.09912109375
Loss for direction: 1730.09912109375
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 9
Size of group 2: 27
Rank of matrix 1: 4
Rank of matrix 2: 6
Splitting sub-matrix 1 of size 9 at depth = 2 with 238 flops
Depth: 3
Loss from simple SVD: 329.6886901855469
Loss for direction: 324.7070007324219
Loss for direction: 324.7070007324219
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 6
Size of group 2: 3
Rank of matrix 1: 2
Rank of matrix 2: 3
Splitting sub-matrix 1 of size 6 at depth = 3 with 111 flops
Depth: 4
Loss from simple SVD: 324.7070007324219
Loss f

In [55]:
# Rows of A with the rank-1 approximation removed. The budget is reduced by
# 200 FLOPs, which are added back when the totals are reported.
# NOTE(review): presumably 200 is the cost of the removed rank-1 term -- confirm.
total_F2_loss_minus_joint, total_actual_flops_minus_joint, groups_minus_joint = greedy_splitting_rows_F2(
    A_minus_joint_rows,
    flops=20*80-200,
)

Depth: 1
Loss from simple SVD: 5488.38818359375
Loss for direction: 4667.1552734375
Loss for direction: 4646.474609375
Loss for direction: 4482.962890625
Loss for direction: 4482.962890625
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 27
Size of group 2: 53
Rank of matrix 1: 4
Rank of matrix 2: 7
Splitting sub-matrix 1 of size 27 at depth = 1 with 377 flops
Depth: 2
Loss from simple SVD: 1673.140869140625
Loss for direction: 1644.368896484375
Loss for direction: 1545.09521484375
Loss for direction: 1545.09521484375
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 11
Size of group 2: 16
Rank of matrix 1: 2
Rank of matrix 2: 3
Splitting sub-matrix 1 of size 11 at depth = 2 with 142 flops
Depth: 3
Loss from simple SVD: 483.59716796875
Loss for direction: 483.59716796875
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 11
Size of group 2: 0
Rank of matrix 1: 2
Rank of matrix 2: 0
No matrices to split.
Splitting sub-mat

In [56]:
# Same budget as the baseline, applied to the rows of the shifted matrix.
total_F2_loss_shifted, total_actual_flops_shifted, groups_shifted = greedy_splitting_rows_F2(
    A_shifted_rows,
    flops=20*80,
)

Depth: 1
Loss from simple SVD: 5771.80517578125
Loss for direction: 4654.6455078125
Loss for direction: 4654.6455078125
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 50
Size of group 2: 30
Rank of matrix 1: 7
Rank of matrix 2: 6
Splitting sub-matrix 1 of size 50 at depth = 1 with 990 flops
Depth: 2
Loss from simple SVD: 2891.427001953125
Loss for direction: 2491.833984375
Loss for direction: 2491.833984375
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 25
Size of group 2: 25
Rank of matrix 1: 6
Rank of matrix 2: 5
Splitting sub-matrix 1 of size 25 at depth = 2 with 540 flops
Depth: 3
Loss from simple SVD: 1132.40478515625
Loss for direction: 1053.3980712890625
Loss for direction: 1053.3980712890625
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 17
Size of group 2: 8
Rank of matrix 1: 5
Rank of matrix 2: 3
Splitting sub-matrix 1 of size 17 at depth = 3 with 371 flops
Depth: 4
Loss from simple SVD: 763.3182983398

In [57]:
# Shifted matrix with its rank-1 approximation removed. The budget is reduced
# by 200 FLOPs, which are added back when the totals are reported.
# NOTE(review): presumably 200 is the cost of the removed rank-1 term -- confirm.
total_F2_loss_shifted_minus_joint, total_actual_flops_shifted_minus_joint, groups_shifted_minus_joint = greedy_splitting_rows_F2(
    A_shifted_minus_joint_rows,
    flops=20*80-200,
)

Depth: 1
Loss from simple SVD: 5771.80517578125
Loss for direction: 4795.716796875
Loss for direction: 4631.595703125
Loss for direction: 4478.7978515625
Loss for direction: 4478.7978515625
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 30
Size of group 2: 50
Rank of matrix 1: 7
Rank of matrix 2: 5
Splitting sub-matrix 1 of size 30 at depth = 1 with 700 flops
Depth: 2
Loss from simple SVD: 1484.903564453125
Loss for direction: 1260.2025146484375
Loss for direction: 1231.3323974609375
Loss for direction: 1231.3323974609375
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 19
Size of group 2: 11
Rank of matrix 1: 5
Rank of matrix 2: 5
Splitting sub-matrix 1 of size 19 at depth = 2 with 390 flops
Depth: 3
Loss from simple SVD: 810.7728271484375
Loss for direction: 810.7727661132812
Loss for direction: 810.7727661132812
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 0
Size of group 2: 19
Rank of matrix 1: 0
Rank of mat

In [58]:
# Rows of A with the dim-0 mean removed. The budget is reduced by 120 FLOPs,
# which are added back when the totals are reported.
# NOTE(review): presumably 120 is the cost attributed to the mean term -- confirm.
total_F2_loss_minus_mean, total_actual_flops_minus_mean, groups_minus_mean = greedy_splitting_rows_F2(
    A_minus_mean_rows,
    flops=20*80-120,
)

Depth: 1
Loss from simple SVD: 6332.521484375


Loss for direction: 4724.13134765625
Loss for direction: 4724.13134765625
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 50
Size of group 2: 30
Rank of matrix 1: 7
Rank of matrix 2: 5
Splitting sub-matrix 1 of size 50 at depth = 1 with 980 flops
Depth: 2
Loss from simple SVD: 3000.599609375
Loss for direction: 2777.8173828125
Loss for direction: 2757.99365234375
Loss for direction: 2755.2392578125
Loss for direction: 2753.83740234375
Loss for direction: 2753.83740234375
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 36
Size of group 2: 14
Rank of matrix 1: 6
Rank of matrix 2: 4
Splitting sub-matrix 1 of size 36 at depth = 2 with 690 flops
Depth: 3
Loss from simple SVD: 1929.2802734375
Loss for direction: 1686.7589111328125
Loss for direction: 1686.7589111328125
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 21
Size of group 2: 15
Rank of matrix 1: 5
Rank of matrix 2: 4
Splitting sub-matrix 1 of size 21 at depth 

In [59]:
# Shifted matrix with its dim-0 mean removed. The budget is reduced by 120
# FLOPs, which are added back when the totals are reported.
# NOTE(review): presumably 120 is the cost attributed to the mean term -- confirm.
total_F2_loss_shifted_minus_mean, total_actual_flops_shifted_minus_mean, groups_shifted_minus_mean = greedy_splitting_rows_F2(
    A_shifted_minus_mean_rows,
    flops=20*80-120,
)

Depth: 1
Loss from simple SVD: 6332.5224609375
Loss for direction: 4724.130859375
Loss for direction: 4724.130859375
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 50
Size of group 2: 30
Rank of matrix 1: 7
Rank of matrix 2: 5
Splitting sub-matrix 1 of size 50 at depth = 1 with 980 flops
Depth: 2
Loss from simple SVD: 3000.59912109375
Loss for direction: 2777.8173828125
Loss for direction: 2757.99365234375
Loss for direction: 2755.2392578125
Loss for direction: 2753.8369140625
Loss for direction: 2753.8369140625
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 36
Size of group 2: 14
Rank of matrix 1: 6
Rank of matrix 2: 4
Splitting sub-matrix 1 of size 36 at depth = 2 with 690 flops
Depth: 3
Loss from simple SVD: 1929.2799072265625
Loss for direction: 1686.7587890625
Loss for direction: 1686.7587890625
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 21
Size of group 2: 15
Rank of matrix 1: 5
Rank of matrix 2: 4
Spl

In [60]:
# Final losses for the unshifted matrix: raw, rank-1 removed, mean removed.
print(total_F2_loss, total_F2_loss_minus_joint, total_F2_loss_minus_mean, sep="\n")
print("---")
# Budget followed by FLOPs actually used; the 200 / 120 FLOPs withheld from
# the two reduced budgets are added back here.
print(
    20*80,
    total_actual_flops,
    total_actual_flops_minus_joint + 200,
    total_actual_flops_minus_mean + 120,
    sep="\n",
)

tensor(3046.5449)
tensor(3359.6262)
tensor(3582.4277)
---
1600
1571.0
1556.0
1558.0


In [61]:
# Final losses for the shifted matrix: raw, rank-1 removed, mean removed.
print(total_F2_loss_shifted, total_F2_loss_shifted_minus_joint, total_F2_loss_shifted_minus_mean, sep="\n")
print("---")
# Budget followed by FLOPs actually used; the 200 / 120 FLOPs withheld from
# the two reduced budgets are added back here.
print(
    20*80,
    total_actual_flops_shifted,
    total_actual_flops_shifted_minus_joint + 200,
    total_actual_flops_shifted_minus_mean + 120,
    sep="\n",
)

tensor(3627.8438)
tensor(3273.4287)
tensor(3582.4277)
---
1600
1592.0
1550.0
1558.0
