In [57]:
import torch
import torch.nn as nn
import loralib as lora
import math
from transformers import GPT2LMHeadModel, GPT2Model, GPT2Config, GPT2Tokenizer

In [97]:
def set_lora(model):
    """Swap the attention projections of a GPT-2 model for LoRA-enabled ones.

    For every block in ``model.transformer.h`` the fused query/key/value
    projection ``attn.c_attn`` is replaced by a ``Conv1D_LoRA`` layer. When
    ``model.config.add_cross_attention`` is true, ``crossattention.c_attn``
    and ``crossattention.q_attn`` are replaced as well.

    The LoRA hyper-parameters are read from ``model.config``:
    ``lora_attn_dim`` (rank), ``lora_attn_alpha``, ``lora_dropout`` and
    ``merge_weights``.

    Unlike a plain re-instantiation, the weight and bias of each replaced
    layer are copied into its replacement, so calling this on an already
    loaded (pretrained) model does not throw the pretrained projections away.

    Args:
        model: a GPT-2 style model exposing ``config`` and ``transformer.h``.

    Returns:
        The same ``model`` object, modified in place.
    """
    cfg = model.config
    embed_dim = cfg.n_embd

    def _lora_replacement(old_layer, out_features):
        # Build a LoRA layer with the shared hyper-parameters and carry over
        # the parameters (and device / dtype) of the layer it replaces.
        new_layer = Conv1D_LoRA(
            embed_dim, out_features,
            r=cfg.lora_attn_dim,
            lora_alpha=cfg.lora_attn_alpha,
            lora_dropout=cfg.lora_dropout,
            merge_weights=cfg.merge_weights,
        )
        new_layer.to(device=old_layer.weight.device, dtype=old_layer.weight.dtype)
        with torch.no_grad():
            new_layer.weight.copy_(old_layer.weight)
            new_layer.bias.copy_(old_layer.bias)
        return new_layer

    for layer in model.transformer.h:
        # Self-attention: one fused projection producing q, k and v.
        layer.attn.c_attn = _lora_replacement(layer.attn.c_attn, embed_dim * 3)

        if cfg.add_cross_attention:
            # Cross-attention: fused k/v projection plus a separate q projection.
            cross = layer.crossattention
            cross.c_attn = _lora_replacement(cross.c_attn, embed_dim * 2)
            cross.q_attn = _lora_replacement(cross.q_attn, embed_dim)

    return model

class Conv1D(nn.Module):
    """
    Affine projection used by OpenAI GPT / GPT-2 (Radford et al.).

    Functionally a linear layer, except that the weight is stored transposed,
    i.e. with shape ``(nx, nf)`` = ``(in_features, out_features)``.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        # Weight ~ N(0, 0.02^2), laid out as (in_features, out_features).
        self.weight = nn.Parameter(torch.empty(nx, nf))
        nn.init.normal_(self.weight, std=0.02)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        # Collapse all leading dimensions, apply bias + x @ W, then restore them.
        leading_dims = x.shape[:-1]
        flat_input = x.view(-1, x.shape[-1])
        projected = torch.addmm(self.bias, flat_input, self.weight)
        return projected.view(*leading_dims, self.nf)

class Conv1D_LoRA(Conv1D, lora.LoRALayer):
    """GPT-2 style ``Conv1D`` projection with a low-rank (LoRA) update.

    The base layer computes ``x @ weight + bias`` with ``weight`` stored as
    ``(in_channels, out_channels)``. When ``r > 0`` two extra trainable
    matrices are added, ``lora_A`` of shape ``(r, in_channels)`` and
    ``lora_B`` of shape ``(out_channels, r)``, and the base ``weight`` is
    frozen. The effective weight is then
    ``weight + (lora_B @ lora_A).T * (lora_alpha / r)``.

    Args:
        in_channels: number of input features.
        out_channels: number of output features.
        r: LoRA rank; ``0`` disables LoRA and the layer acts as a plain Conv1D.
        lora_alpha: LoRA scaling numerator (the update is scaled by ``lora_alpha / r``).
        lora_dropout: dropout probability applied to the input of the LoRA branch.
        merge_weights: if true, the LoRA update is folded into ``weight`` when
            the layer is switched to evaluation mode and removed again when it
            is switched back to training mode.
        **kwargs: forwarded to ``Conv1D.__init__``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        # Conv1D takes (nf=out_features, nx=in_features), hence the swapped order.
        Conv1D.__init__(self, out_channels, in_channels, **kwargs)
        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                                merge_weights=merge_weights)

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r, in_channels))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels, r))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the LoRA matrices so that the initial update is zero."""
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _lora_delta(self) -> torch.Tensor:
        """Return the scaled LoRA update in the layout of ``self.weight``.

        ``lora_B @ lora_A`` has shape ``(out_channels, in_channels)`` while the
        Conv1D weight is ``(in_channels, out_channels)``, so the product must be
        transposed; a ``view`` would only reinterpret the memory and scramble
        the entries.
        """
        return (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling

    def train(self, mode: bool = True):
        """Set training/evaluation mode and merge or un-merge the LoRA update.

        ``nn.Module.eval()`` on a parent module reaches its children through
        ``child.train(False)``, so the merge for evaluation has to happen here
        rather than only in ``eval``. Nothing is merged when ``r == 0`` because
        no LoRA matrices exist in that case.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract.
        """
        Conv1D.train(self, mode)
        if self.r > 0 and self.merge_weights:
            if mode and self.merged:
                # Make sure that the weights are not merged while training
                with torch.no_grad():
                    self.weight.data -= self._lora_delta()
                self.merged = False
            elif not mode and not self.merged:
                # Merge the weights and mark it
                with torch.no_grad():
                    self.weight.data += self._lora_delta()
                self.merged = True
        return self

    def eval(self):
        """Switch to evaluation mode (equivalent to ``train(False)``)."""
        return self.train(False)

    def forward(self, x: torch.Tensor):
        """Apply the base projection plus, if active and un-merged, the LoRA branch."""
        result = Conv1D.forward(self, x)
        if self.r > 0 and not self.merged:
            # x: (..., in) @ A^T: (in, r) @ B^T: (r, out) -> (..., out)
            low_rank = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
            result = result + low_rank * self.scaling
        return result

In [136]:
# Start from the stock GPT-2 configuration, then attach the extra settings
# that this notebook reads back from ``model.config``.
config = GPT2Config.from_pretrained('gpt2')
config.update({
    'lora_attn_dim': 8,                # LoRA rank
    'lora_attn_alpha': 8,              # LoRA scaling numerator
    'lora_dropout': 0.1,
    'merge_weights': False,
    'freeze_pretrained_layers': True,
    'add_cross_attention': False,
})

In [137]:
# Baseline 1: build GPT-2 from the config only (no checkpoint is loaded here)
# and display the first transformer block.
model = GPT2LMHeadModel(config=config)
model.transformer.h[0]

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D()
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [138]:
model.lm_head.weight

Parameter containing:
tensor([[ 3.4167e-03, -3.9373e-02, -7.9190e-04,  ...,  3.7932e-02,
         -2.2393e-02, -1.8256e-02],
        [ 6.2204e-03,  2.1795e-02,  1.6698e-03,  ...,  1.2985e-02,
         -2.4400e-02, -1.3343e-02],
        [-6.9144e-03, -2.8654e-02,  3.1401e-02,  ...,  1.5245e-02,
          2.6001e-02, -5.7864e-04],
        ...,
        [ 4.2912e-02,  2.5906e-02, -1.0334e-02,  ..., -1.6790e-02,
          1.6977e-02, -3.6761e-02],
        [ 1.1425e-02, -1.7903e-05,  5.6524e-03,  ...,  2.7493e-02,
         -1.1056e-02,  1.7742e-02],
        [ 6.8175e-03, -1.0026e-02, -7.7820e-03,  ...,  2.8702e-03,
          3.3572e-03, -3.0546e-02]], requires_grad=True)

In [139]:
model.transformer.h[0].attn.c_attn.weight

Parameter containing:
tensor([[-0.0045, -0.0025, -0.0126,  ...,  0.0088,  0.0105,  0.0001],
        [-0.0025, -0.0063,  0.0039,  ...,  0.0008,  0.0022,  0.0093],
        [ 0.0102,  0.0097,  0.0131,  ..., -0.0013,  0.0159,  0.0037],
        ...,
        [ 0.0159,  0.0258, -0.0146,  ...,  0.0278, -0.0046, -0.0224],
        [-0.0096,  0.0167, -0.0005,  ..., -0.0083, -0.0074, -0.0398],
        [ 0.0262, -0.0165, -0.0196,  ..., -0.0243,  0.0262,  0.0109]],
       requires_grad=True)

In [140]:
if config.add_cross_attention:
    print(model.transformer.h[0].crossattention.c_attn.weight)

# GPT2

In [141]:
# Baseline 2: load the pretrained 'gpt2' checkpoint with the same config and
# display the first transformer block for comparison.
model = GPT2LMHeadModel.from_pretrained('gpt2', config=config)
model.transformer.h[0]

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D()
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [142]:
model.lm_head.weight

Parameter containing:
tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],
        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],
        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],
        ...,
        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],
        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],
        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]],
       requires_grad=True)

In [143]:
model.transformer.h[0].attn.c_attn.weight

Parameter containing:
tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]],
       requires_grad=True)

In [144]:
if config.add_cross_attention:
    print(model.transformer.h[0].crossattention.c_attn.weight)

# LoRA

In [145]:
# Build a fresh model from the config, swap its attention projections for
# Conv1D_LoRA layers, and display the first block to confirm the replacement.
model = GPT2LMHeadModel(config=config)
model = set_lora(model)
model.transformer.h[0]

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D_LoRA(
      (lora_dropout): Dropout(p=0.1, inplace=False)
    )
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [146]:
# Copy the pretrained 'gpt2' weights into the LoRA model. strict=False is
# needed because the checkpoint has no lora_A / lora_B entries; the recorded
# result lists only those as missing keys and no unexpected keys.
model.load_state_dict(GPT2LMHeadModel.from_pretrained('gpt2', config=config).state_dict(), strict=False)

_IncompatibleKeys(missing_keys=['transformer.h.0.attn.c_attn.lora_A', 'transformer.h.0.attn.c_attn.lora_B', 'transformer.h.1.attn.c_attn.lora_A', 'transformer.h.1.attn.c_attn.lora_B', 'transformer.h.2.attn.c_attn.lora_A', 'transformer.h.2.attn.c_attn.lora_B', 'transformer.h.3.attn.c_attn.lora_A', 'transformer.h.3.attn.c_attn.lora_B', 'transformer.h.4.attn.c_attn.lora_A', 'transformer.h.4.attn.c_attn.lora_B', 'transformer.h.5.attn.c_attn.lora_A', 'transformer.h.5.attn.c_attn.lora_B', 'transformer.h.6.attn.c_attn.lora_A', 'transformer.h.6.attn.c_attn.lora_B', 'transformer.h.7.attn.c_attn.lora_A', 'transformer.h.7.attn.c_attn.lora_B', 'transformer.h.8.attn.c_attn.lora_A', 'transformer.h.8.attn.c_attn.lora_B', 'transformer.h.9.attn.c_attn.lora_A', 'transformer.h.9.attn.c_attn.lora_B', 'transformer.h.10.attn.c_attn.lora_A', 'transformer.h.10.attn.c_attn.lora_B', 'transformer.h.11.attn.c_attn.lora_A', 'transformer.h.11.attn.c_attn.lora_B'], unexpected_keys=[])

In [147]:
model.lm_head.weight

Parameter containing:
tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],
        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],
        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],
        ...,
        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],
        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],
        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]],
       requires_grad=True)

In [148]:
model.transformer.wte.weight

Parameter containing:
tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],
        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],
        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],
        ...,
        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],
        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],
        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]],
       requires_grad=True)

In [149]:
model.transformer.h[0].attn.c_attn.weight

Parameter containing:
tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]])

In [150]:
if config.add_cross_attention:
    print(model.transformer.h[0].crossattention.c_attn.weight)

In [151]:
model.transformer.h[0].attn.c_attn.lora_A

Parameter containing:
tensor([[ 0.0356,  0.0222,  0.0258,  ..., -0.0347, -0.0331,  0.0075],
        [ 0.0211,  0.0300, -0.0234,  ...,  0.0281, -0.0070,  0.0213],
        [-0.0181, -0.0181,  0.0236,  ..., -0.0028, -0.0094, -0.0066],
        ...,
        [-0.0266, -0.0205,  0.0049,  ..., -0.0171, -0.0159,  0.0267],
        [-0.0302,  0.0024,  0.0270,  ...,  0.0355,  0.0327, -0.0252],
        [ 0.0339, -0.0072, -0.0233,  ..., -0.0155, -0.0185, -0.0092]],
       requires_grad=True)

In [152]:
model.transformer.h[0].attn.c_attn.lora_B

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)

In [153]:
if config.add_cross_attention:
    print(model.transformer.h[0].crossattention.c_attn.lora_A)

In [154]:
if config.add_cross_attention:
    print(model.transformer.h[0].crossattention.c_attn.lora_B)

In [155]:
if config.add_cross_attention:
    print(model.transformer.h[0].crossattention.q_attn.lora_A)

In [156]:
if config.add_cross_attention:
    print(model.transformer.h[0].crossattention.q_attn.lora_B)

In [157]:
# Number of parameter elements that currently require gradients.
param = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [158]:
param

103501056

In [159]:
# Restrict training to the LoRA parameters. The trainable count recorded after
# this call is 294912 = 12 blocks x (8*768 + 2304*8), i.e. exactly the
# lora_A / lora_B matrices of the twelve c_attn layers.
lora.mark_only_lora_as_trainable(model)

In [160]:
# Re-count the elements that still require gradients after the freeze.
param = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [161]:
param

294912