In [1]:
import torch
from torch import nn
from torch import optim
from tqdm import tqdm
import pandas as pd
import numpy as np

from weighted_svd_utils import weighted_frobenious, get_svd, rank_r, get_svd_lora, toy_correlated_matrix, toy_two_subspace_matrix, toy_seven_subspace_matrix


In [2]:
def flops_from_svd(U, S, V):
    """Estimate the FLOP cost attributed to a matrix stored as SVD factors.

    The count is two FLOPs per entry of ``U`` and of ``V`` plus one per entry
    along the first axis of ``S``.

    Args:
        U, S, V: SVD factors exposing ``.shape`` (``U`` and ``V`` need at least
            two dimensions, ``S`` at least one), or ``None``.

    Returns:
        The FLOP count as an int, or 0 if any of the three factors is ``None``.
    """
    # Identity checks instead of `!= None`: comparing an array-like to None with
    # `!=` dispatches to the array's own comparison operator, which can return
    # an element-wise result whose truth value is ambiguous.
    if U is None or S is None or V is None:
        return 0
    return 2 * (V.shape[0] * V.shape[1] + U.shape[0] * U.shape[1]) + S.shape[0]

In [3]:
def _svd_stats_F2(mat):
    """Gather the per-matrix quantities used by the rank search.

    Returns ``(dense_flops, svd_flops, flops_per_sv, sq_svals, rank)``. For a
    missing matrix (``None``) every entry is zero/empty and no SVD is computed.
    """
    if mat is None:
        return 0.0, 0, 0.0, [], 0
    u, s, v = get_svd(mat)
    dense_flops = 2 * mat.shape[0] * mat.shape[1]
    # FLOPs saved each time one singular value is dropped.
    flops_per_sv = 2 * v.shape[0] + 2 * u.shape[0]
    # Squared singular values: the F2 loss added by dropping each one.
    return dense_flops, flops_from_svd(u, s, v), flops_per_sv, s * s, s.shape[0]


def _truncate_single_F2(sq_svals, rank, total_flops, flops_per_sv, flops):
    """Drop trailing singular values of one matrix until ``total_flops`` fits ``flops``.

    Returns ``(loss, last_kept_index)``; ``loss`` is ``torch.inf`` when the
    search ends with ``last_kept_index <= 1``.
    """
    loss = 0
    idx = rank - 1
    while total_flops > flops and idx > 0:
        loss = loss + sq_svals[idx]
        total_flops -= flops_per_sv
        idx -= 1
    # NOTE(review): this rejects every result of rank <= 2, even one that met
    # the budget. Kept exactly as before; confirm whether the intended test is
    # "still over budget" instead.
    if idx <= 1:
        loss = torch.inf
    return loss, idx


def find_optimal_rank_allocation_F2(mat1, mat2, flops):
    """Choose truncation ranks for up to two matrices under a FLOP budget.

    Three candidates are evaluated and the one with the smallest accumulated
    squared-singular-value loss wins (earlier candidates win ties):
      1. both matrices stored as truncated SVDs (greedy: repeatedly drop the
         trailing singular value with the smaller loss per FLOP saved),
      2. only ``mat1`` truncated, ``mat2`` kept dense,
      3. only ``mat2`` truncated, ``mat1`` kept dense.

    Args:
        mat1, mat2: 2-D matrices accepted by ``get_svd``, or ``None`` for an
            absent matrix.
        flops: FLOP budget shared by the two matrices.

    Returns:
        ``(min_F2_loss, r1_optimal, r2_optimal, flops1, flops2)``. If no
        candidate is feasible (including the case where both inputs are
        ``None``), the loss is ``torch.inf`` and the remaining fields describe
        the untruncated matrices (full ranks, dense FLOP counts).
    """
    mat1_flops, svd1_flops, flops_per_sv_1, F2_loss_1, r1 = _svd_stats_F2(mat1)
    mat2_flops, svd2_flops, flops_per_sv_2, F2_loss_2, r2 = _svd_stats_F2(mat2)

    if mat1 is None and mat2 is None:
        print("Warning! Both matrices empty!")

    # Fallback result, so the return values are always defined even when no
    # candidate below is accepted.
    min_F2_loss = torch.inf
    r1_optimal, r2_optimal = r1, r2
    flops1, flops2 = mat1_flops, mat2_flops

    # Find the best split given that both matrices are SV-decomposed
    if mat1 is not None and mat2 is not None:
        flops_both_svd = svd1_flops + svd2_flops
        F2_loss_both_svd = 0
        i1_both_svd = r1 - 1
        i2_both_svd = r2 - 1
        both_feasible = True
        while flops_both_svd > flops:
            if i1_both_svd <= 0 and i2_both_svd <= 0:
                # Both matrices are down to rank 1 and the budget is still
                # exceeded: nothing more can be dropped. Without this exit the
                # loop would spin forever.
                both_feasible = False
                break

            F2_per_flops_1 = F2_loss_1[i1_both_svd] / flops_per_sv_1
            F2_per_flops_2 = F2_loss_2[i2_both_svd] / flops_per_sv_2

            if (F2_per_flops_1 < F2_per_flops_2 or i2_both_svd == 0) and i1_both_svd > 0:
                F2_loss_both_svd = F2_loss_both_svd + F2_loss_1[i1_both_svd]
                flops_both_svd -= flops_per_sv_1
                i1_both_svd -= 1
            elif i2_both_svd > 0:
                F2_loss_both_svd = F2_loss_both_svd + F2_loss_2[i2_both_svd]
                flops_both_svd -= flops_per_sv_2
                i2_both_svd -= 1

        if both_feasible:
            r1_optimal = i1_both_svd + 1
            r2_optimal = i2_both_svd + 1
            min_F2_loss = F2_loss_both_svd
            flops1 = r1_optimal * flops_per_sv_1
            flops2 = r2_optimal * flops_per_sv_2

    # Find the best split given that only mat1 is SV-decomposed
    if mat1 is not None:
        F2_loss_mat1_svd, i1_mat1_svd = _truncate_single_F2(
            F2_loss_1, r1, svd1_flops + mat2_flops, flops_per_sv_1, flops
        )
        if F2_loss_mat1_svd < min_F2_loss:
            r1_optimal = i1_mat1_svd + 1
            r2_optimal = r2
            min_F2_loss = F2_loss_mat1_svd
            flops1 = r1_optimal * flops_per_sv_1
            flops2 = mat2_flops

    # Find the best split given that only mat2 is SV-decomposed
    if mat2 is not None:
        F2_loss_mat2_svd, i2_mat2_svd = _truncate_single_F2(
            F2_loss_2, r2, mat1_flops + svd2_flops, flops_per_sv_2, flops
        )
        if F2_loss_mat2_svd < min_F2_loss:
            r1_optimal = r1
            r2_optimal = i2_mat2_svd + 1
            min_F2_loss = F2_loss_mat2_svd
            flops1 = mat1_flops
            flops2 = r2_optimal * flops_per_sv_2

    return (min_F2_loss, r1_optimal, r2_optimal, flops1, flops2)


In [4]:
def get_proj_loss_F2(A_tuple_full, group1, group2, flops):
    """Evaluate a two-group row split under a FLOP budget.

    Stacks the rows of each group into a matrix, asks
    ``find_optimal_rank_allocation_F2`` for the best rank allocation, and then
    cross-checks the predicted loss against the loss measured on the actual
    low-rank reconstructions.

    Args:
        A_tuple_full: Sequence of ``(index, row)`` pairs, indexable by the
            indices stored in the groups.
        group1, group2: Lists of ``(index, row)`` pairs; either may be empty.
        flops: FLOP budget shared by the two groups.

    Returns:
        ``(F2_loss, r1_optimal, r2_optimal, flops1, flops2)`` exactly as
        produced by ``find_optimal_rank_allocation_F2``.
    """
    A_tuple_1 = [A_tuple_full[idx] for idx, _ in group1]
    A_tuple_2 = [A_tuple_full[idx] for idx, _ in group2]

    mat1 = torch.stack([row for _, row in group1]) if group1 else None
    mat2 = torch.stack([row for _, row in group2]) if group2 else None

    F2_loss, r1_optimal, r2_optimal, flops1, flops2 = find_optimal_rank_allocation_F2(mat1, mat2, flops)

    # Check if F2 loss from singular values matches actual F2 of the difference
    lora_1 = get_svd_lora(mat1, r1_optimal) if A_tuple_1 else None
    lora_2 = get_svd_lora(mat2, r2_optimal) if A_tuple_2 else None
    A_partial_1 = torch.stack([row for _, row in A_tuple_1]) if A_tuple_1 else None
    A_partial_2 = torch.stack([row for _, row in A_tuple_2]) if A_tuple_2 else None
    F_loss_1 = weighted_frobenious(A_partial_1, lora_1) if A_tuple_1 else 0.0
    F_loss_2 = weighted_frobenious(A_partial_2, lora_2) if A_tuple_2 else 0.0
    F2_loss_actual = F_loss_1 * F_loss_1 + F_loss_2 * F_loss_2

    # An infinite loss is the "no feasible allocation" marker, so there is
    # nothing meaningful to compare in that case. The built-in abs() is used
    # because it accepts both tensors and plain Python numbers, whereas
    # torch.abs() requires a Tensor.
    if F2_loss != torch.inf and abs(F2_loss - F2_loss_actual) > 0.5:
        print(f"Warning! F2 losses dont match: {F2_loss} vs {F2_loss_actual}")

    return F2_loss, r1_optimal, r2_optimal, flops1, flops2



In [5]:
def optimal_row_to_move_F2(A_tuple_full, sender, receiver, flops):
    """Find the sender entry whose transfer to the receiver yields the lowest loss.

    Every entry of ``sender`` is tentatively moved to ``receiver`` (on copies;
    the inputs are not modified) and the resulting split is scored with
    ``get_proj_loss_F2``.

    Returns:
        ``(best_entry, best_loss)``. ``best_entry`` is ``None`` and
        ``best_loss`` is ``torch.inf`` when no candidate scores below infinity
        (for instance when ``sender`` is empty).
    """
    best_entry = None
    best_loss = torch.inf
    for candidate in sender:
        # Build the trial split: receiver gains the candidate, sender loses it.
        trial_receiver = [*receiver, candidate]
        trial_sender = [entry for entry in sender]
        trial_sender.remove(candidate)

        trial_loss, _, _, _, _ = get_proj_loss_F2(A_tuple_full, trial_sender, trial_receiver, flops)
        if trial_loss < best_loss:
            best_entry, best_loss = candidate, trial_loss

    return best_entry, best_loss

In [6]:
def greedy_splitting_rows_F2(A, flops=None, printdepth = 1):
    """Recursively split the rows of ``A`` into groups that are compressed separately.

    At each level the rows start in group 1. Rows are then greedily moved one
    at a time from the current sender group to the other group, keeping the
    best split seen during the sweep. The sweep direction alternates for as
    long as a sweep improves the loss. If the final split has two non-empty
    groups, each group is split again with its share of the budget.

    Args:
        A: 2-D matrix whose rows are partitioned.
        flops: FLOP budget. ``None`` means half of the dense cost
            ``2 * rows * cols``.
        printdepth: Recursion depth, used only in the progress messages.

    Returns:
        ``(total_F2_loss, total_actual_flops)`` summed over all leaves.
    """

    # Check if A is a single vector
    if A.shape[0] == 1:
        print("Attempting to split a single vector. Returning vector with zero loss.")
        return 0.0, 2 * A.shape[0] * A.shape[1]

    if flops is None:
        A_flops = 2 * A.shape[0] * A.shape[1]
        flops = 0.5 * A_flops

    print(f"Depth: {printdepth}")

    A_tuple_full = [(i, row) for i, row in enumerate(A)]

    sender_idx = 1 # 1 means group1, 2 means group2
    optimal_group1 = list(A_tuple_full)
    optimal_group2 = []

    current_best_F2_loss, _, _, _, _ = get_proj_loss_F2(A_tuple_full, optimal_group1, optimal_group2, flops)
    single_svd_F2_loss = current_best_F2_loss # For debugging, can remove later
    print(f"Loss from simple SVD: {single_svd_F2_loss}")

    while True:
        current_best_F2_loss_for_direction, _, _, _, _ = get_proj_loss_F2(A_tuple_full, optimal_group1, optimal_group2, flops)
        optimal_group1_for_direction = list(optimal_group1)
        optimal_group2_for_direction = list(optimal_group2)
        groups = [list(optimal_group1), list(optimal_group2)]
        # sender/receiver alias the lists inside `groups`, so moving a row
        # updates `groups` too. Index -1 selects group 2 and -2 selects group 1,
        # i.e. always the group that is not the sender.
        sender = groups[sender_idx - 1]
        receiver = groups[-sender_idx]
        while sender:
            row_to_move, F2_loss = optimal_row_to_move_F2(A_tuple_full, sender, receiver, flops)
            if row_to_move is None:
                # No candidate move scored below infinity, so there is no row
                # to transfer; removing None from the list would raise.
                break
            sender.remove(row_to_move)
            receiver.append(row_to_move)
            if F2_loss < current_best_F2_loss_for_direction:
                optimal_group1_for_direction = list(groups[0])
                optimal_group2_for_direction = list(groups[1])
                current_best_F2_loss_for_direction = F2_loss

        print(f"Loss for direction: {current_best_F2_loss_for_direction}")

        if current_best_F2_loss_for_direction < current_best_F2_loss:
            current_best_F2_loss = current_best_F2_loss_for_direction
            optimal_group1 = list(optimal_group1_for_direction)
            optimal_group2 = list(optimal_group2_for_direction)
            # Improvement found: sweep in the opposite direction next.
            sender_idx = 2 if sender_idx == 1 else 1
            continue
        else:
            print("Optimal split found. Proceeding to split sub-matrices.")
            break

    F2_loss, r1_optimal, r2_optimal, flops1, flops2 = get_proj_loss_F2(A_tuple_full, optimal_group1, optimal_group2, flops)
    print(f"Size of group 1: {len(optimal_group1)}\nSize of group 2: {len(optimal_group2)}")
    print(f"Rank of matrix 1: {r1_optimal}\nRank of matrix 2: {r2_optimal}")

    if optimal_group1 and optimal_group2:
        remainderflops = flops - flops1 - flops2 # TODO: Implement efficient usage of remainder flops in another branches
        if remainderflops < 0:
            print("Warning: Used more flops than allocated!")
        # Hand the unused budget to the group that is not already at rank <= 1;
        # otherwise share it between the two groups.
        if r2_optimal <= 1:
            flops1 += remainderflops
        elif r1_optimal <= 1:
            flops2 += remainderflops
        else:
            flops1 += int(remainderflops * 0.5)
            flops2 += remainderflops - int(remainderflops * 0.5)

        mat1 = torch.stack([row for idx, row in optimal_group1])
        mat2 = torch.stack([row for idx, row in optimal_group2])
        print(f"Splitting sub-matrix 1 of depth = {printdepth} with {flops1} flops")
        F2_loss_1, flops1_actual = greedy_splitting_rows_F2(mat1, flops1, printdepth + 1)
        print(f"Splitting sub-matrix 2 of depth = {printdepth} with {flops2} flops")
        F2_loss_2, flops2_actual = greedy_splitting_rows_F2(mat2, flops2, printdepth + 1)
        total_F2_loss = F2_loss_1 + F2_loss_2
        total_actual_flops = flops1_actual + flops2_actual
        return total_F2_loss, total_actual_flops
    else:
        # One group is empty: this matrix is a leaf of the recursion.
        print("No matrices to split.")
        flops1_actual = flops1
        flops2_actual = flops2
        total_F2_loss = F2_loss
        total_actual_flops = flops1_actual + flops2_actual
        return total_F2_loss, total_actual_flops

In [7]:
# Toy input matrix for the experiment below.
# NOTE(review): the recorded outputs are consistent with A being 210 x 80
# (210 rows get split; 16240 / 28 = 580 = 2 * (210 + 80) FLOPs per singular
# value). The meaning of the arguments `12` and `True` is defined in
# weighted_svd_utils and is not visible here -- confirm there.
A = toy_seven_subspace_matrix(80, 210, 12, True)

In [8]:
# Preview the generated matrix.
A

tensor([[ 1.3759,  1.0100,  1.3270,  ..., -0.6386,  1.5534, -3.6360],
        [-1.4919, -2.6958, -5.4791,  ..., -0.9208,  2.2714,  4.6334],
        [ 2.5923,  2.4036, -1.7378,  ...,  3.2576,  2.2373,  1.9948],
        ...,
        [-6.9214, -0.7655,  8.8818,  ...,  1.5633,  2.7789,  2.2670],
        [-7.2989, -2.7452,  6.1330,  ...,  4.8135, -8.2404, -2.0497],
        [ 1.2732, -0.6490, -0.2096,  ..., -0.3631, -2.1555,  0.5720]])

In [9]:
# Run the recursive greedy row split with the default budget
# (flops=None -> half of the dense cost 2 * rows * cols of A).
total_F2_loss, total_actual_flops = greedy_splitting_rows_F2(A)

Depth: 1
Loss from simple SVD: 37750.78125
Loss for direction: 24158.76171875
Loss for direction: 22925.611328125
Loss for direction: 22925.611328125
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 120
Size of group 2: 90
Rank of matrix 1: 23
Rank of matrix 2: 21
Splitting sub-matrix 1 of depth = 1 with 9430 flops
Depth: 2
Loss from simple SVD: 15002.375
Loss for direction: 9047.357421875
Loss for direction: 9047.357421875
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 31
Size of group 2: 89
Rank of matrix 1: 10
Rank of matrix 2: 21
Splitting sub-matrix 1 of depth = 2 with 2276 flops
Depth: 3
Loss from simple SVD: 1492.102294921875
Loss for direction: 1492.1014404296875
Loss for direction: 1492.1014404296875
Optimal split found. Proceeding to split sub-matrices.
Size of group 1: 0
Size of group 2: 31
Rank of matrix 1: 0
Rank of matrix 2: 10
No matrices to split.
Splitting sub-matrix 2 of depth = 2 with 7154 flops
Depth: 3
Loss from simpl

In [10]:
# Total loss returned by the greedy split (sum over all leaf sub-matrices).
total_F2_loss

tensor(4767.0967)

In [11]:
# Total FLOPs reported by the greedy split (sum over all leaf sub-matrices).
total_actual_flops

16060.0

In [12]:
# Baseline: a single truncated SVD of the whole matrix under the same budget
# the greedy split uses by default (0.5 * 2 * rows * cols = rows * cols).
# The budget is derived from A's shape instead of repeating its dimensions.
svd_F2_loss, svd_r1, _, svd_flops, _ = find_optimal_rank_allocation_F2(A, None, A.shape[0] * A.shape[1])

In [13]:
# Loss of the single-SVD baseline.
svd_F2_loss

tensor(37750.7812)

In [14]:
# Rank kept by the single-SVD baseline.
svd_r1

28

In [15]:
# FLOPs attributed to the single-SVD baseline (kept rank * FLOPs per singular value).
svd_flops

16240