In [1]:
import os

# Run the notebook from the PEFT repository checkout. The location can be overridden
# with the PEFT_REPO_DIR environment variable instead of editing this hardcoded path.
PEFT_REPO_DIR = os.environ.get("PEFT_REPO_DIR", "/root/dev/playground/knowledges/peft")
os.chdir(PEFT_REPO_DIR)
os.getcwd()

'/root/dev/playground/knowledges/peft'

In [2]:
!gpustat --no-color

[1m[37maa4173e0a5f2              [m  Mon Jul  8 14:34:55 2024  [1m[30m535.129.03[m
[36m[0][m [34mNVIDIA GeForce RTX 4090[m |[31m 28°C[m, [32m  0 %[m | [36m[1m[33m    2[m / [33m24564[m MB |
[36m[1][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    2[m / [33m24564[m MB |
[36m[2][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    2[m / [33m24564[m MB |
[36m[3][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    2[m / [33m24564[m MB |
[36m[4][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    2[m / [33m24564[m MB |
[36m[5][m [34mNVIDIA GeForce RTX 4090[m |[31m 32°C[m, [32m  0 %[m | [36m[1m[33m    2[m / [33m24564[m MB |


In [3]:
import torch

# GPU index chosen from the gpustat listing (all GPUs were idle when it was taken).
CUDA_DEVICE_INDEX = 5

# Only touch CUDA when it exists; otherwise fall back to CPU instead of crashing in
# set_device. torch.device("cuda") then refers to the current device set above.
if torch.cuda.is_available():
    torch.cuda.set_device(CUDA_DEVICE_INDEX)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

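# Explore LoKr (Low-Rank Kronecker product adaptation)

How LoKr chooses its factor shapes: `factorization` splits each weight dimension into two factors, and the rank `r` decides whether each factor is stored as a full matrix or as a rank-`r` pair.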

In [6]:
from typing import Tuple

def factorization(dimension: int, factor: int = -1) -> Tuple[int, int]:
    """Factorizes the provided number into the product of two numbers.

    Args:
        dimension (`int`): The number that needs to be factorized.
        factor (`int`, optional):
            Factorization divider. If `factor` is positive and divides `dimension`, it is used directly as one of
            the two factors. Otherwise the divisors of `dimension` are walked in increasing order, stopping at the
            first divisor larger than `factor` or once the pair crosses the square root of `dimension`. Any
            negative value (conventionally -1) removes the upper bound, so the search returns the divisor pair
            closest to the square root of `dimension`. A `factor` of 0 yields `(1, dimension)`. Defaults to -1.

    Returns:
        Tuple[`int`, `int`]: A tuple of two numbers, whose product is equal to the provided number. The first number is
        always less than or equal to the second.

    Examples:
        >>> factorization(128, 2)
        (2, 64)
        >>> factorization(128)
        (8, 16)
        >>> factorization(128, 64)
        (2, 64)
    """
    if factor > 0 and dimension % factor == 0:
        m, n = factor, dimension // factor
        # `factor` may exceed sqrt(dimension); keep the documented (small, large) ordering.
        return (m, n) if m <= n else (n, m)
    if factor < 0:
        # No upper bound on the divisor -> search the pair nearest sqrt(dimension).
        factor = dimension
    m, n = 1, dimension
    # Walk divisors of `dimension` upward. A stopping test of the form
    # `new_m + new_n > 1 + dimension` would be redundant: for positive m, n with
    # m * n == dimension, (m - 1) * (n - 1) >= 0 implies m + n <= 1 + dimension.
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        if new_m > factor:
            break
        m, n = new_m, dimension // new_m
    # The last accepted step may have crossed sqrt(dimension); restore (small, large).
    if m > n:
        m, n = n, m
    return m, n

In [7]:
# Decomposition factor for factorization(); -1 would instead search for factors near sqrt(dim).
decompose_factor: int = 2

In [15]:
# Compare the explicit decomposition factor with the automatic search (factor=-1 -> near sqrt).
# p = 128 (in), q = 256 (out). Linear factor shapes as allocated by create_adapter_parameters:
#   C: (u_q, u_p)    A: (v_q, r)    B: (r, v_p)
p, q = 128, 256
for factor in (decompose_factor, -1):
    u_p, v_p = factorization(p, factor)
    u_q, v_q = factorization(q, factor)
    print(u_p, v_p, u_q, v_q)

2 64 2 128
8 16 16 16


In [20]:
# Linear layer: factorize the input and output dimensions with the same decomposition factor.
in_dim, out_dim = 128, 256

u_p, v_p = factorization(in_dim, decompose_factor)
u_q, v_q = factorization(out_dim, decompose_factor)

# shape = ((out_l, out_k), (in_m, in_n)) -- output factors first, then input factors.
shape = ((u_q, v_q), (u_p, v_p))
assert u_q * v_q == out_dim and u_p * v_p == in_dim, shape

In [21]:
# Rank: inner dimension of the low-rank *_a / *_b factor pairs.
r: int = 128

In [23]:
# When True, the w1 (C) factor may itself be stored as a rank-r pair (see the use_w1 cell).
decompose_both: bool = False

In [26]:
# Keep C as one full (u_q, u_p) matrix unless decompose_both is set AND r is
# smaller than half of the larger of u_q / u_p.
use_w1 = (not decompose_both) or (r >= max(u_q, u_p) / 2)
print("r: ", r, "\tmax(u_q, u_p) / 2: ", max(u_q, u_p) / 2)
print(use_w1)

r:  128 	max(u_q, u_p) / 2:  1.0
True


In [22]:
# Keep the second factor as one full (v_q, v_p) matrix unless r is smaller than
# half of the larger of v_q / v_p.
use_w2 = r >= max(v_q, v_p) / 2
print("r: ", r, "\tmax(v_q, v_p) / 2: ", max(v_q, v_p) / 2)
print(use_w2)

r:  128 	max(v_q, v_p) / 2:  64.0
True


In [27]:
def create_adapter_parameters(
    self,
    adapter_name: str,
    r: int,
    shape,
    use_w1: bool,
    use_w2: bool,
    use_effective_conv2d: bool,
):
    """Allocate the LoKr factor parameters for one adapter and store them on ``self``.

    Tensors are created with ``torch.empty`` and are therefore uninitialized; they
    must be initialized separately before use.

    Args:
        self: Object exposing dict-like containers ``lokr_w1``, ``lokr_w1_a``,
            ``lokr_w1_b``, ``lokr_w2``, ``lokr_w2_a``, ``lokr_w2_b`` and ``lokr_t2``.
            Only the containers selected by the flags are written.
        adapter_name: Key under which the new parameters are stored.
        r: Rank, i.e. the inner dimension of the ``*_a`` / ``*_b`` factor pairs.
        shape: ``((out_l, out_k), (in_m, in_n))`` for Linear layers. A 4-element
            shape is treated as Conv2d, with ``shape[2]`` and ``shape[3]`` as the
            kernel dimensions.
        use_w1: If True, store C as one full ``(shape[0][0], shape[1][0])`` matrix.
            Otherwise store it as ``lokr_w1_a`` ``(shape[0][0], r)`` plus
            ``lokr_w1_b`` ``(r, shape[1][0])``.
        use_w2: If True, store the second factor as one full tensor. Otherwise
            store it as a rank-``r`` pair (plus a core tensor for effective conv).
        use_effective_conv2d: Only consulted for Conv2d shapes when ``use_w2`` is
            False. Selects an ``(r, r, k_h, k_w)`` core ``lokr_t2`` with two
            ``(r, ·)`` mode matrices instead of folding the kernel dims into
            ``lokr_w2_b``.
    """
    # Imported locally so the function does not depend on a notebook-level `nn` alias.
    from torch import nn

    # First factor (C).
    if use_w1:
        self.lokr_w1[adapter_name] = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))  # (u_q, u_p) -> C
    else:
        self.lokr_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0][0], r))  # (u_q, r)
        self.lokr_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][0]))  # (r, u_p)

    # Second factor.
    is_conv2d = len(shape) == 4
    if is_conv2d:
        if use_w2:
            # Full (out_k, in_n, k_h, k_w) kernel.
            self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *shape[2:]))
        elif use_effective_conv2d:
            # (r, r, k_h, k_w) core plus one (r, ·) matrix per mode.
            self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
            self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0][1]))  # b, 1-mode
            self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1]))  # d, 2-mode
        else:
            # Kernel dims folded into the second matrix: (r, in_n * k_h * k_w).
            self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
            self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2] * shape[3]))
    else:
        # Linear
        if use_w2:
            self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))  # (v_q, v_p)
        else:
            self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))  # (v_q, r) -> A
            self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1]))  # (r, v_p) -> B