In [16]:
%load_ext autoreload
%autoreload 2
%cd ..

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/Users/akkirr/Desktop/IT


In [20]:
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pathlib import Path

import torch
from torch import nan_to_num
from torchvision import transforms
from torch.utils.data import DataLoader

import numpy as np
from PIL import Image
import requests

from datasets import load_dataset
from torchvision.utils import save_image
from torch.optim import Adam

In [21]:
from mylib import *
from lora import *

In [34]:
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.QKV = nn.Linear(1, 1)
        self.C = nn.Linear(1, 1)
        self.lrelu = nn.LeakyReLU()

    def forward(self, x):
        return self.C(self.lrelu(self.QKV(x)))


class TimeEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_proj = nn.Linear(1, 1)
        self.lrelu = nn.LeakyReLU()

    def forward(self, x):
        return self.lrelu(self.time_proj(x))


class LoraInjected(nn.Module):
    def __init__(self):
        super().__init__()
        self.src_linear = nn.Linear(1, 1)
        self.A = nn.Linear(1, 1)
        self.B = nn.Linear(1, 1)
        self.dropout = nn.Dropout1d()

    def forward(self, x):
        return self.src_linear(x) + self.dropout(self.B(self.A(x)))


class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.just_linear = nn.Linear(1, 1)
        self.attn = Attention()
        self.time_embedder = TimeEmbedding()

    def forward(self, x):
        return self.attn(self.just_linear(x) + self.time_embedder(x))

In [35]:
from torch.optim import Adam
from copy import deepcopy

a = A()
optim = Adam(a.parameters())

In [33]:
[x[2] + '.' + x[3] for x in list(find_modules(
        a,
        ["Attention"],
        [nn.Linear],
        [LoraInjected],
    ))]

				 Attention (<class '__main__.LoraInjected'>,) False
attn. Attention [<class '__main__.LoraInjected'>] False
				 Attention (<class '__main__.LoraInjected'>,) False
	 QKV LoraInjectedLinear
				 LoraInjectedLinear (<class 'torch.nn.modules.linear.Linear'>,) False
	 C LoraInjectedLinear
				 LoraInjectedLinear (<class 'torch.nn.modules.linear.Linear'>,) False
	 lrelu LeakyReLU
				 LeakyReLU (<class 'torch.nn.modules.linear.Linear'>,) False
				 LoraInjectedLinear (<class '__main__.LoraInjected'>,) False
attn.QKV LoraInjectedLinear [<class '__main__.LoraInjected'>] False
				 LoraInjectedLinear (<class '__main__.LoraInjected'>,) False
	 src_linear Linear
				 Linear (<class 'torch.nn.modules.linear.Linear'>,) True
		*
	 lora_down Linear
				 Linear (<class 'torch.nn.modules.linear.Linear'>,) True
		*
	 lora_up Linear
				 Linear (<class 'torch.nn.modules.linear.Linear'>,) True
		*
	 dropout_layer Dropout1d
				 Dropout1d (<class 'torch.nn.modules.linear.Linear'>,) False
				 Linear (

['attn.src_linear',
 'attn.lora_down',
 'attn.lora_up',
 'attn.src_linear',
 'attn.lora_down',
 'attn.lora_up']

In [32]:
inject_lora(
        a,
        2, 0,
        ["Attention"],
        [nn.Linear],
        [LoraInjected],
        verbose=True
    )

				 Attention (<class '__main__.LoraInjected'>,) False
attn. Attention [<class '__main__.LoraInjected'>] False
				 Attention (<class '__main__.LoraInjected'>,) False
	 QKV Linear
				 Linear (<class 'torch.nn.modules.linear.Linear'>,) True
		*
	 C Linear
				 Linear (<class 'torch.nn.modules.linear.Linear'>,) True
		*
	 lrelu LeakyReLU
				 LeakyReLU (<class 'torch.nn.modules.linear.Linear'>,) False
				 Linear (<class '__main__.LoraInjected'>,) False
attn.QKV Linear [<class '__main__.LoraInjected'>] False
				 Linear (<class '__main__.LoraInjected'>,) False
				 Linear (<class '__main__.LoraInjected'>,) False
attn.C Linear [<class '__main__.LoraInjected'>] False
				 Linear (<class '__main__.LoraInjected'>,) False
				 LeakyReLU (<class '__main__.LoraInjected'>,) False
attn.lrelu LeakyReLU [<class '__main__.LoraInjected'>] False
				 LeakyReLU (<class '__main__.LoraInjected'>,) False
Injected lora (1x2x1) in attn.QKV
Injected lora (1x2x1) in attn.C


In [None]:
import torchvision
model = torchvision.models.resnet18()

inject_lora(
        model,
        2, 0,
        ["BasicBlock"],
        [nn.Conv2d],
        [LoraInjectedConv2d],
        verbose=True
    )

				 BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
) (<class 'lora.lora.LoraInjectedConv2d'>,) False
layer1.0. BasicBlock [<class 'lora.lora.LoraInjectedConv2d'>] False
				 BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
) (<class 'lora.lora.LoraInjectedConv2d'>,) False
	 conv1 Conv2d
				 Conv2d(6

TypeError: 'BasicBlock' object does not support item assignment

In [None]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  