In [23]:
import torch
from torch import nn
from torch import optim
from tqdm import tqdm
import pandas as pd
import numpy as np

from weighted_svd_utils import weighted_frobenious, get_svd, rank_r, get_svd_lora, toy_correlated_matrix, toy_two_subspace_matrix, toy_seven_subspace_matrix, toy_seven_subspace_matrix_unqual_size


In [24]:
def flops_from_svd(U, S, V):
    """Return a FLOP estimate for a factorised matrix given its SVD factors.

    The estimate is ``2 * (V.shape[0] * V.shape[1] + U.shape[0] * U.shape[1])``,
    i.e. twice the number of entries in the first two dimensions of ``U`` and ``V``.
    ``S`` is only checked for presence; its values and shape are not used.

    Args:
        U: Left factor exposing a 2-D ``.shape``, or ``None``.
        S: Singular values, or ``None``.
        V: Right factor exposing a 2-D ``.shape``, or ``None``.

    Returns:
        The FLOP estimate, or ``0`` if any of the three factors is ``None``.
    """
    # Identity checks instead of `!= None`: comparing an array/tensor to None with
    # `!=` depends on the library's __ne__ and can produce an element-wise result
    # whose truth value is ambiguous.
    if U is None or S is None or V is None:
        return 0
    return 2 * (V.shape[0] * V.shape[1] + U.shape[0] * U.shape[1])

In [25]:
# TODO: More accurate flop calculation everywhere

def find_optimal_rank_allocation_F2(mat1, mat2, flops):
    mat1_flops = 2 * mat1.shape[0] * mat1.shape[1] if mat1 != None else 0.0
    mat2_flops = 2 * mat2.shape[0] * mat2.shape[1] if mat2 != None else 0.0
    u1, s1, v1 = get_svd(mat1) if mat1 != None else (None, None, None)
    u2, s2, v2 = get_svd(mat2) if mat2 != None else (None, None, None)
    svd1_flops = flops_from_svd(u1, s1, v1)
    svd2_flops = flops_from_svd(u2, s2, v2)
    if mat1 != None and mat1.shape[0] == 1:
        svd1_flops = mat1_flops
    if mat2 != None and mat2.shape[0] == 1:
        svd2_flops = mat2_flops

    flops_per_sv_1 = 2 * v1.shape[0] + 2 * u1.shape[0] if mat1 != None else 0.0
    flops_per_sv_2 = 2 * v2.shape[0] + 2 * u2.shape[0] if mat2 != None else 0.0

    #F2_loss_1 = torch.diag(s1) if mat1 != None else []
    F2_loss_1 = [s*s for s in s1] if mat1 != None else [] 
    #F2_loss_2 = torch.diag(s2) if mat2 != None else []
    F2_loss_2 = [s*s for s in s2] if mat2 != None else []

    r1 = s1.shape[0] if mat1 != None else 0
    r2 = s2.shape[0] if mat2 != None else 0

    min_F2_loss = torch.inf

    if (mat1 == None and mat2 == None):
        print("Warning! Both matrices empty!")

    # Find the best split given that both matrices are SV-decomposed
    if (mat1 != None and mat2 != None):
        flops_both_svd = svd1_flops + svd2_flops
        F2_loss_both_svd = 0
        i1_both_svd = r1 - 1
        i2_both_svd = r2 - 1
        while flops_both_svd > flops:
            F2_per_flops_1 = F2_loss_1[i1_both_svd] / flops_per_sv_1
            F2_per_flops_2 = F2_loss_2[i2_both_svd] / flops_per_sv_2

            if (F2_per_flops_1 < F2_per_flops_2 or i2_both_svd == 0) and i1_both_svd > 0:
                F2_loss_both_svd += F2_loss_1[i1_both_svd]
                flops_both_svd -= flops_per_sv_1
                i1_both_svd -= 1
            elif i2_both_svd > 0:
                F2_loss_both_svd += F2_loss_2[i2_both_svd]
                flops_both_svd -= flops_per_sv_2
                i2_both_svd -= 1

        r1_optimal = i1_both_svd + 1
        r2_optimal = i2_both_svd + 1
        min_F2_loss = F2_loss_both_svd
        flops1 = r1_optimal * flops_per_sv_1
        flops2 = r2_optimal * flops_per_sv_2

    # Find the best split given that only mat1 is SV-decomposed
    if mat1 != None:
        flops_mat1_svd = svd1_flops + mat2_flops
        F2_loss_mat1_svd = 0
        i1_mat1_svd = r1 - 1
        while flops_mat1_svd > flops and i1_mat1_svd > 0:
            F2_loss_mat1_svd += F2_loss_1[i1_mat1_svd]
            flops_mat1_svd -= flops_per_sv_1
            i1_mat1_svd -= 1

        if i1_mat1_svd < 1:
            F2_loss_mat1_svd = torch.inf

        if F2_loss_mat1_svd <= min_F2_loss:
            r1_optimal = i1_mat1_svd + 1
            r2_optimal = r2
            min_F2_loss = F2_loss_mat1_svd
            flops1 = r1_optimal * flops_per_sv_1
            flops2 = mat2_flops

    # Find the best split given that only mat2 is SV-decomposed
    if mat2 != None:
        flops_mat2_svd = mat1_flops + svd2_flops
        F2_loss_mat2_svd = 0
        i2_mat2_svd = r2 - 1
        while flops_mat2_svd > flops and i2_mat2_svd > 0:
            F2_loss_mat2_svd += F2_loss_2[i2_mat2_svd]
            flops_mat2_svd -= flops_per_sv_2
            i2_mat2_svd -= 1

        if i2_mat2_svd < 1:
                F2_loss_mat2_svd = torch.inf

        if F2_loss_mat2_svd <= min_F2_loss:
            r1_optimal = r1
            r2_optimal = i2_mat2_svd + 1
            min_F2_loss = F2_loss_mat2_svd
            flops1 = mat1_flops
            flops2 = r2_optimal * flops_per_sv_2

    return (min_F2_loss, r1_optimal, r2_optimal, flops1, flops2)


In [26]:
def get_proj_loss_F2(group1, group2, flops):
    """Compute the rank allocation and squared-Frobenius loss for two row groups.

    Each group is a list of ``(index, row)`` pairs. The rows of a non-empty
    group are stacked into a matrix; an empty group is represented by ``None``.
    The loss reported by ``find_optimal_rank_allocation_F2`` is cross-checked
    against the loss measured on the low-rank reconstructions, and a warning is
    printed when the two disagree.

    Returns:
        ``(F2_loss, r1_optimal, r2_optimal, flops1, flops2)`` exactly as
        produced by ``find_optimal_rank_allocation_F2``.
    """

    def _as_matrix(group):
        # Stack the row vectors of a group; an empty group has no matrix.
        if not group:
            return None
        return torch.stack([vec for _, vec in group])

    mat1 = _as_matrix(group1)
    mat2 = _as_matrix(group2)

    allocation = find_optimal_rank_allocation_F2(mat1, mat2, flops)
    F2_loss, r1_optimal, r2_optimal, flops1, flops2 = allocation

    # Check if F2 loss from singular values matches actual F2 of the difference
    approx_1 = get_svd_lora(mat1, r1_optimal) if group1 else None
    approx_2 = get_svd_lora(mat2, r2_optimal) if group2 else None
    reference_1 = _as_matrix(group1)
    reference_2 = _as_matrix(group2)

    if group1:
        frob_1 = weighted_frobenious(reference_1, approx_1)
    else:
        frob_1 = 0.0
    if group2:
        frob_2 = weighted_frobenious(reference_2, approx_2)
    else:
        frob_2 = 0.0
    F2_loss_actual = frob_1 * frob_1 + frob_2 * frob_2

    # Warn only when both the absolute and the relative deviation are large.
    if torch.abs(F2_loss - F2_loss_actual) > 0.5 and abs(F2_loss/F2_loss_actual - 1) > 0.001:
        print(f"Warning! F2 losses dont match: {F2_loss} vs {F2_loss_actual}")

    return F2_loss, r1_optimal, r2_optimal, flops1, flops2



In [27]:
def optimal_row_to_move_F2(sender, receiver, flops):
    """Find the row whose move from ``sender`` to ``receiver`` gives the lowest loss.

    Every row of ``sender`` is tentatively moved to ``receiver`` (on copies;
    the inputs are not modified) and the resulting split is scored with
    ``get_proj_loss_F2``.

    Args:
        sender: List of ``(index, row)`` pairs to pick from.
        receiver: List of ``(index, row)`` pairs that would receive the row.
        flops: FLOP budget passed through to ``get_proj_loss_F2``.

    Returns:
        ``(best_row_tuple, best_F2_loss)``. For an empty ``sender`` this is
        ``(None, torch.inf)``. If no candidate has a loss strictly below
        ``torch.inf`` (all losses are infinite or NaN), the first candidate and
        its loss are returned so callers always get a row to move.
    """
    current_best_F2_loss = torch.inf
    current_best_tuple = None
    fallback_tuple = None
    fallback_F2_loss = torch.inf

    # `candidate` rather than `tuple`, to avoid shadowing the builtin.
    for candidate in sender:
        remaining = list(sender)
        extended = list(receiver)

        extended.append(candidate)
        remaining.remove(candidate)
        F2_loss, _, _, _, _ = get_proj_loss_F2(remaining, extended, flops)

        if fallback_tuple is None:
            fallback_tuple = candidate
            fallback_F2_loss = F2_loss

        if F2_loss < current_best_F2_loss:
            current_best_F2_loss = F2_loss
            current_best_tuple = candidate

    # Previously a non-empty sender could yield None here when every loss was
    # inf/NaN, which made the caller's `sender.remove(None)` raise.
    if current_best_tuple is None and fallback_tuple is not None:
        return fallback_tuple, fallback_F2_loss

    return current_best_tuple, current_best_F2_loss

In [28]:
def greedy_splitting_rows_F2(row_tuples, flops=None, printdepth = 1):
    """Recursively split rows into groups so that per-group SVDs lose less.

    Starting with all rows in group 1, rows are moved one at a time to the
    other group (always the move with the lowest loss) and the best split seen
    during the sweep is remembered. Sweeps alternate direction for as long as
    they improve the loss. The two resulting groups are then split recursively,
    each with its share of the FLOP budget.

    Args:
        row_tuples: List of ``(index, row)`` pairs.
        flops: FLOP budget. Defaults to half the dense cost
            ``2 * n_rows * row_length``.
        printdepth: Recursion depth, used only in progress messages.

    Returns:
        ``(total_F2_loss, total_actual_flops, groups)`` where ``groups`` is a
        list of groups and each group is a list of ``(index, row)`` pairs.
        An empty input returns ``(0.0, 0, [])``.
    """
    # Nothing to split; previously this raised IndexError on row_tuples[0].
    if len(row_tuples) == 0:
        return 0.0, 0, []

    # Check if group is vector
    if len(row_tuples) == 1:
        print("Attempting to split a single vector. Returning vector with zero loss.")
        return 0.0, 2 * len(row_tuples) * len((row_tuples[0])[1]), [row_tuples]

    if flops is None:
        A_flops = 2 * len(row_tuples) * len((row_tuples[0])[1])
        flops = 0.5 * A_flops

    print(f"Depth: {printdepth}")

    sender_idx = 1 # 1 means group1, 2 means group2
    optimal_group1 = list(row_tuples)
    optimal_group2 = []

    current_best_F2_loss, _, _, _, _ = get_proj_loss_F2(optimal_group1, optimal_group2, flops)
    single_svd_F2_loss = current_best_F2_loss # For debugging, can remove later
    print(f"Loss from simple SVD: {single_svd_F2_loss}")

    while True:
        current_best_F2_loss_for_direction, _, _, _, _ = get_proj_loss_F2(optimal_group1, optimal_group2, flops)
        optimal_group1_for_direction = list(optimal_group1)
        optimal_group2_for_direction = list(optimal_group2)
        # Work on copies; groups[0] is always group 1 and groups[1] group 2.
        groups = [list(optimal_group1), list(optimal_group2)]
        sender = groups[sender_idx - 1]
        receiver = groups[-sender_idx]
        # Sweep: move every row across, one best move at a time, and remember
        # the best intermediate split.
        while sender:
            row_to_move, F2_loss = optimal_row_to_move_F2(sender, receiver, flops)
            if row_to_move is None:
                # No movable row was reported; stop the sweep instead of
                # failing in sender.remove(None).
                break
            sender.remove(row_to_move)
            receiver.append(row_to_move)
            if F2_loss < current_best_F2_loss_for_direction:
                optimal_group1_for_direction = list(groups[0])
                optimal_group2_for_direction = list(groups[1])
                current_best_F2_loss_for_direction = F2_loss

        print(f"Loss for direction: {current_best_F2_loss_for_direction}")

        if current_best_F2_loss_for_direction < current_best_F2_loss:
            current_best_F2_loss = current_best_F2_loss_for_direction
            optimal_group1 = list(optimal_group1_for_direction)
            optimal_group2 = list(optimal_group2_for_direction)
            # Try the opposite direction next.
            sender_idx = 2 if sender_idx == 1 else 1
            continue
        else:
            print("Optimal split found. Proceeding to split sub-matrices.")
            break

    F2_loss, r1_optimal, r2_optimal, flops1, flops2 = get_proj_loss_F2(optimal_group1, optimal_group2, flops)
    print(f"Size of group 1: {len(optimal_group1)}\nSize of group 2: {len(optimal_group2)}")
    print(f"Rank of matrix 1: {r1_optimal}\nRank of matrix 2: {r2_optimal}")

    if optimal_group1 and optimal_group2:
        remainderflops = flops - flops1 - flops2 # TODO: Implement efficient usage of remainder flops in another branches
        if remainderflops < 0:
            print(f"Warning: Used more flops than allocated! {flops1} + {flops2} = {flops1 + flops2} vs {flops}")
        # Hand unused budget to whichever side can still use it (rank > 1),
        # or split it evenly when both can.
        if r2_optimal <= 1:
            flops1 += remainderflops
        elif r1_optimal <= 1:
            flops2 += remainderflops
        else:
            flops1 += int(remainderflops * 0.5)
            flops2 += remainderflops - int(remainderflops * 0.5)

        # NOTE(review): the rank<=1 leaves below report a loss of 0.0 and a cost
        # of 2 * row_length + n_rows; both look like placeholders — confirm.
        print(f"Splitting sub-matrix 1 of size {len(optimal_group1)} at depth = {printdepth} with {flops1} flops")
        if r1_optimal > 1:
            F2_loss_1, flops1_actual, groups1 = greedy_splitting_rows_F2(optimal_group1, flops1, printdepth + 1)
        else:
            # Wrap the leaf as a one-element list of groups so the return value
            # has the same shape as every other return path.
            F2_loss_1, flops1_actual, groups1 = 0.0, 2 * len((optimal_group1[0])[1]) + len(optimal_group1), [optimal_group1]
        print(f"Splitting sub-matrix 2 of size {len(optimal_group2)} at depth = {printdepth} with {flops2} flops")
        if r2_optimal > 1:
            F2_loss_2, flops2_actual, groups2 = greedy_splitting_rows_F2(optimal_group2, flops2, printdepth + 1)
        else:
            F2_loss_2, flops2_actual, groups2 = 0.0, 2 * len((optimal_group2[0])[1]) + len(optimal_group2), [optimal_group2]
        total_F2_loss = F2_loss_1 + F2_loss_2
        total_actual_flops = flops1_actual + flops2_actual
        return total_F2_loss, total_actual_flops, groups1 + groups2
    else:
        # No useful split exists: keep all rows together as a single group.
        print("No matrices to split.")
        flops1_actual = flops1
        flops2_actual = flops2
        total_F2_loss = F2_loss
        total_actual_flops = flops1_actual + flops2_actual
        ret_group = optimal_group1 if optimal_group1 else optimal_group2
        ret_group.sort(key=lambda x: x[0])
        return total_F2_loss, total_actual_flops, [ret_group]

In [29]:
# Toy input matrix; the alternative generator is kept for reference.
# A = toy_seven_subspace_matrix(80, 210, 12, True)
A = toy_seven_subspace_matrix_unqual_size(50, 80, 3, True)

# Pair every row with its original index so rows can be tracked across groups.
A_rows = [(row_idx, row) for row_idx, row in enumerate(A)]

In [30]:
A

tensor([[-3.5255,  3.5839, -0.2200,  ...,  2.4400, -0.5140, -0.2197],
        [ 1.5199, -1.7891, -0.8587,  ...,  0.0415,  0.5110, -3.6876],
        [-0.7784,  0.7236, -1.1616,  ...,  0.5156,  0.4050, -0.7047],
        ...,
        [ 5.3366,  4.7646, -7.3674,  ..., -3.7187,  2.4843,  0.7762],
        [ 2.2597, -4.4653,  0.2376,  ...,  6.4451,  2.9782,  1.8185],
        [ 1.1547,  1.8041, -0.5035,  ..., -1.8724, -2.0218, -5.9161]])

In [31]:
# Run the recursive greedy split on all rows with the default FLOP budget.
(total_F2_loss, total_actual_flops, groups) = greedy_splitting_rows_F2(row_tuples=A_rows)

Depth: 1
Loss from simple SVD: 6129.34375
Loss for direction: 4203.85009765625
Loss for direction: 4102.3896484375
Loss for direction: 4102.3896484375
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 37
Size of group 2: 43
Rank of matrix 1: 11
Rank of matrix 2: 11
Splitting sub-matrix 1 of size 37 at depth = 1 with 1934 flops
Depth: 2
Loss from simple SVD: 1940.3934326171875
Loss for direction: 1695.4700927734375
Loss for direction: 1318.014892578125
Loss for direction: 1318.014892578125
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 22
Size of group 2: 15
Rank of matrix 1: 8
Rank of matrix 2: 6
Splitting sub-matrix 1 of size 22 at depth = 2 with 1153 flops
Depth: 3
Loss from simple SVD: 881.4962768554688
Loss for direction: 881.4962768554688
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 22
Size of group 2: 0
Rank of matrix 1: 8
Rank of matrix 2: 0
No matrices to split.
Splitting sub-matrix 2 of size 15 at depth 

In [32]:
# List the original row indices that ended up in each group.
for members in groups:
    group_size = len(members)
    print(f"Size of group: {group_size}")
    for entry in members:
        row_idx = entry[0]
        print(f"Row idx in group: {row_idx}")

Size of group: 22
Row idx in group: 5
Row idx in group: 7
Row idx in group: 60
Row idx in group: 61
Row idx in group: 62
Row idx in group: 63
Row idx in group: 64
Row idx in group: 65
Row idx in group: 66
Row idx in group: 67
Row idx in group: 68
Row idx in group: 69
Row idx in group: 70
Row idx in group: 71
Row idx in group: 72
Row idx in group: 73
Row idx in group: 74
Row idx in group: 75
Row idx in group: 76
Row idx in group: 77
Row idx in group: 78
Row idx in group: 79
Size of group: 15
Row idx in group: 30
Row idx in group: 31
Row idx in group: 32
Row idx in group: 33
Row idx in group: 34
Row idx in group: 35
Row idx in group: 36
Row idx in group: 37
Row idx in group: 38
Row idx in group: 39
Row idx in group: 40
Row idx in group: 41
Row idx in group: 42
Row idx in group: 43
Row idx in group: 44
Size of group: 14
Row idx in group: 2
Row idx in group: 4
Row idx in group: 9
Row idx in group: 20
Row idx in group: 21
Row idx in group: 22
Row idx in group: 23
Row idx in group: 24
Row id

In [33]:
groups

[[(5,
   tensor([-0.0552,  1.0613,  0.1623, -0.8906,  0.6300, -1.7071, -0.2869,  1.2017,
           -0.7861, -1.0423, -0.3417,  0.2348, -2.2136,  0.6383, -0.8885, -1.6155,
           -1.4515, -2.0146, -1.9422, -0.1110, -0.7664, -0.5075,  0.9989,  0.5055,
            1.5191, -1.3595, -0.4680, -0.4318, -0.0081, -0.6060,  0.6152, -0.8297,
           -0.6847,  1.9538, -0.1784, -0.2062,  1.7839, -0.1963,  1.0637, -1.6847,
            0.6636,  0.1178,  0.8497,  0.6996,  0.1070, -0.1907, -0.8738, -0.5575,
           -0.5849, -0.3180])),
  (7,
   tensor([ 0.6702, -3.0882,  1.3666,  1.2812, -0.3196,  0.6010, -1.4105,  1.3855,
           -0.7166,  0.0128,  2.6610,  1.9699,  3.4171, -1.6650,  3.3668,  4.2533,
            0.0363, -0.1216,  1.7665,  0.8234, -0.4439,  1.6215, -0.2282, -1.1640,
           -0.0913,  0.9248,  2.9240,  0.3085,  0.7140, -0.6795, -0.8228,  3.3377,
            0.9799, -4.2597,  0.1406,  0.2117, -1.6329,  0.3897, -2.8691,  2.4396,
            0.0418, -0.5949, -0.1049, -2.77

In [34]:
total_F2_loss

tensor(2802.5923)

In [35]:
total_actual_flops

3916.0

In [36]:
# Baseline: a single truncated SVD of the whole matrix.
# The budget is derived from A's shape instead of a hard-coded 80*50, so it stays
# equal to the greedy run's default (0.5 * 2 * n_rows * n_cols) if A is resized.
baseline_flops = A.shape[0] * A.shape[1]
svd_F2_loss, svd_r1, _, svd_flops, _ = find_optimal_rank_allocation_F2(A, None, baseline_flops)

In [37]:
svd_F2_loss

tensor(6129.3438)

In [38]:
svd_r1

15

In [39]:
svd_flops

3900