In [7]:
import torch
import torch.nn as nn
import loralib as lora
import math
import torch.nn.functional as F

from models import MultiModalDecoder, calc_params
from transformers import GPT2LMHeadModel, GPT2Model, GPT2Config, GPT2Tokenizer

In [2]:
def set_lora(model):
    """Replace the attention projections of every block in ``model.h`` with LoRA layers.

    For each block, ``attn.c_attn`` (n_embd -> 3 * n_embd) becomes a ``Conv1D_LoRA``.
    When ``model.config.add_cross_attention`` is set, ``crossattention.c_attn``
    (n_embd -> 2 * n_embd) and ``crossattention.q_attn`` (n_embd -> n_embd) are
    replaced as well.

    The existing weight and bias of each replaced module are copied into its
    replacement, and the replacement is moved to the old module's device/dtype.
    This makes the call order-independent with respect to loading pretrained
    weights. Only the LoRA A/B matrices start from a fresh initialisation.

    Reads ``n_embd``, ``add_cross_attention``, ``lora_attn_dim``,
    ``lora_attn_alpha``, ``lora_dropout`` and ``merge_weights`` from ``model.config``.

    Args:
        model: A GPT-2 style model exposing ``config`` and a ``h`` list of blocks.

    Returns:
        The same ``model`` object, modified in place.
    """
    cfg = model.config
    lora_kwargs = dict(
        r=cfg.lora_attn_dim,
        lora_alpha=cfg.lora_attn_alpha,
        lora_dropout=cfg.lora_dropout,
        merge_weights=cfg.merge_weights,
    )

    def _to_lora(old, n_out):
        """Build a Conv1D_LoRA replacing ``old``, keeping its weight, bias, device and dtype."""
        new = Conv1D_LoRA(cfg.n_embd, n_out, **lora_kwargs)
        # Without this copy, calling set_lora after loading a checkpoint would
        # silently replace (and, via Conv1D_LoRA, freeze) random projections.
        with torch.no_grad():
            new.weight.copy_(old.weight)
            new.bias.copy_(old.bias)
        return new.to(device=old.weight.device, dtype=old.weight.dtype)

    for layer in model.h:
        layer.attn.c_attn = _to_lora(layer.attn.c_attn, cfg.n_embd * 3)

        if cfg.add_cross_attention:
            layer.crossattention.c_attn = _to_lora(layer.crossattention.c_attn, cfg.n_embd * 2)
            layer.crossattention.q_attn = _to_lora(layer.crossattention.q_attn, cfg.n_embd)

    return model

class Conv1D(nn.Module):
    """
    GPT / GPT-2 style "1D convolution" (Radford et al.): an affine projection whose
    weight is stored as (in_features, out_features), i.e. transposed relative to
    ``nn.Linear``.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        # N(0, 0.02) init, matching GPT-2; weight layout is (nx, nf).
        self.weight = nn.Parameter(nn.init.normal_(torch.empty(nx, nf), std=0.02))
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        # Flatten all leading dims, apply bias + x @ W in one fused call, then restore them.
        *leading, n_in = x.size()
        projected = torch.addmm(self.bias, x.view(-1, n_in), self.weight)
        return projected.view(*leading, self.nf)

class Conv1D_LoRA(Conv1D, lora.LoRALayer):
    """LoRA-augmented GPT-2 ``Conv1D`` layer.

    Computes ``x @ W + b + scaling * dropout(x) @ A^T @ B^T``. Here ``W`` is the
    (frozen when ``r > 0``) base weight, stored Conv1D-style as
    ``(in_channels, out_channels)``. ``A`` has shape ``(r, in_channels)`` and
    ``B`` has shape ``(out_channels, r)``. Because ``B @ A`` is
    ``(out_channels, in_channels)``, it must be transposed before it can be
    merged into ``W``.

    Args:
        in_channels (`int`): Number of input features.
        out_channels (`int`): Number of output features.
        r (`int`): LoRA rank; ``0`` disables the adapter (plain Conv1D behaviour).
        lora_alpha (`int`): Numerator of the adapter scaling ``lora_alpha / r``.
        lora_dropout (`float`): Dropout probability applied to the adapter input.
        merge_weights (`bool`): If True, fold the adapter into ``W`` in eval mode
            and remove it again in train mode.
        **kwargs: Forwarded to ``Conv1D.__init__``. That initialiser accepts no
            extra arguments, so any kwargs raise ``TypeError``.
    """

    def __init__(
        self, 
        in_channels: int, 
        out_channels: int,
        r: int = 0, 
        lora_alpha: int = 1, 
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        Conv1D.__init__(self, out_channels, in_channels, **kwargs)
        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r, in_channels))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels, r))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise A like nn.Linear's default and B to zero (adapter starts as a no-op)."""
        if hasattr(self, 'lora_A'):
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _delta_weight(self):
        """Scaled adapter update expressed in Conv1D's ``(in_channels, out_channels)`` layout."""
        return (self.lora_B @ self.lora_A).T * self.scaling

    def train(self, mode: bool = True):
        """Set train/eval mode, unmerging (train) or merging (eval) the adapter.

        A parent module's ``eval()`` calls ``train(False)`` on each child rather
        than ``eval()``. Merging therefore has to happen here for ``model.eval()``
        to take effect.
        """
        Conv1D.train(self, mode)
        if self.r > 0 and self.merge_weights:
            if mode and self.merged:
                # Back to training: take the update out of W so A and B train separately.
                self.weight.data -= self._delta_weight()
                self.merged = False
            elif not mode and not self.merged:
                # Inference: fold the update into W so forward is a single matmul.
                self.weight.data += self._delta_weight()
                self.merged = True
        return self

    def eval(self):
        """Switch to eval mode (merges the adapter when ``merge_weights`` is set)."""
        return self.train(False)

    def forward(self, x: torch.Tensor):
        # Base projection always uses the Conv1D (in, out) layout. When merged, W
        # already contains the adapter, so nothing else needs to be added.
        result = Conv1D.forward(self, x)
        if self.r > 0 and not self.merged:
            result = result + (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return result

In [162]:
# Base gpt2 config plus the LoRA hyper-parameters read by set_lora; cross-attention off.
# NOTE(review): freeze_pretrained_layers is not read by any code visible in this notebook.
config = GPT2Config.from_pretrained('gpt2')
config.update({
    'lora_attn_dim': 8,
    'lora_attn_alpha': 8,
    'lora_dropout': 0.1,
    'merge_weights': False,
    'freeze_pretrained_layers': True,
    'add_cross_attention': False,
})

In [164]:
# GPT-2 built from the config only (no checkpoint), so all weights are freshly initialised.
# Display the first transformer block.
model = GPT2Model(config=config)
model.h[0]

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D()
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [168]:
# NOTE(review): presumably disabled because GPT2Model has no lm_head (GPT2LMHeadModel does).
# model.lm_head.weight

In [167]:
# Self-attention c_attn weight of block 0 in the freshly initialised model.
model.h[0].attn.c_attn.weight

Parameter containing:
tensor([[-0.0351,  0.0222, -0.0084,  ...,  0.0049,  0.0164,  0.0307],
        [-0.0003,  0.0374, -0.0094,  ..., -0.0145,  0.0292,  0.0094],
        [ 0.0148, -0.0177,  0.0074,  ..., -0.0032,  0.0012, -0.0037],
        ...,
        [-0.0043, -0.0013, -0.0250,  ..., -0.0226,  0.0216, -0.0330],
        [-0.0525,  0.0031, -0.0011,  ..., -0.0094, -0.0126,  0.0471],
        [ 0.0233, -0.0420,  0.0165,  ..., -0.0037, -0.0274, -0.0042]],
       requires_grad=True)

In [166]:
# Cross-attention modules exist only when add_cross_attention is set (False in this config).
if config.add_cross_attention:
    print(model.h[0].crossattention.c_attn.weight)

# GPT-2 with pretrained weights (no LoRA)

In [169]:
# Same config, now with the pretrained 'gpt2' checkpoint weights loaded.
model = GPT2Model.from_pretrained('gpt2', config=config)
model.h[0]

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D()
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [170]:
# NOTE(review): presumably disabled because GPT2Model has no lm_head (GPT2LMHeadModel does).
# model.lm_head.weight

In [171]:
# Pretrained c_attn weight of block 0; compare with the freshly initialised values earlier.
model.h[0].attn.c_attn.weight

Parameter containing:
tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
        [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
        [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
        ...,
        [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
        [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
        [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]],
       requires_grad=True)

In [172]:
# Cross-attention modules exist only when add_cross_attention is set (False in this config).
if config.add_cross_attention:
    print(model.h[0].crossattention.c_attn.weight)

# GPT-2 large with LoRA adapters and cross-attention

In [89]:
# gpt2-large config with cross-attention enabled and the LoRA hyper-parameters read by set_lora.
# NOTE(review): freeze_pretrained_layers is not read by any code visible in this notebook.
config = GPT2Config.from_pretrained('gpt2-large')
config.update({
    'lora_attn_dim': 8,
    'lora_attn_alpha': 8,
    'lora_dropout': 0.1,
    'merge_weights': False,
    'freeze_pretrained_layers': True,
    'add_cross_attention': True,
})

In [90]:
# Freshly initialised gpt2-large skeleton whose attention projections are swapped for
# Conv1D_LoRA layers. Because add_cross_attention=True, this includes the cross-attention ones.
model = GPT2Model(config=config)
model = set_lora(model)
model.h[0]

GPT2Block(
  (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D_LoRA(
      (lora_dropout): Dropout(p=0.1, inplace=False)
    )
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  (crossattention): GPT2Attention(
    (c_attn): Conv1D_LoRA(
      (lora_dropout): Dropout(p=0.1, inplace=False)
    )
    (q_attn): Conv1D_LoRA(
      (lora_dropout): Dropout(p=0.1, inplace=False)
    )
    (c_proj): Conv1D()
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_cross_attn): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D()
    (c_proj): Conv1D()
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [91]:
# Copy pretrained gpt2-large weights into the LoRA model. strict=False because lora_A / lora_B
# have no counterpart in the plain GPT2Model state_dict (they keep their fresh init).
# The cross-attention weights are not in the checkpoint either, so they are random.
# NOTE(review): this materialises a second full gpt2-large model just for its state_dict.
model.load_state_dict(GPT2Model.from_pretrained('gpt2-large', config=config).state_dict(), strict=False)

Some weights of GPT2Model were not initialized from the model checkpoint at gpt2-large and are newly initialized: ['h.23.crossattention.q_attn.weight', 'h.14.crossattention.c_proj.weight', 'h.13.ln_cross_attn.weight', 'h.29.crossattention.masked_bias', 'h.17.crossattention.q_attn.weight', 'h.1.crossattention.q_attn.bias', 'h.24.crossattention.q_attn.weight', 'h.11.crossattention.masked_bias', 'h.9.crossattention.c_attn.weight', 'h.18.ln_cross_attn.weight', 'h.21.crossattention.c_attn.bias', 'h.5.crossattention.q_attn.weight', 'h.0.ln_cross_attn.weight', 'h.3.crossattention.bias', 'h.7.crossattention.q_attn.weight', 'h.20.crossattention.masked_bias', 'h.33.crossattention.c_proj.bias', 'h.32.crossattention.c_attn.weight', 'h.3.crossattention.q_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.17.crossattention.c_proj.bias', 'h.26.crossattention.masked_bias', 'h.31.crossattention.c_proj.weight', 'h.8.crossattention.c_attn.weight', 'h.14.crossattention.c_attn.bias', 'h.26.crossattention.

_IncompatibleKeys(missing_keys=['h.0.attn.c_attn.lora_A', 'h.0.attn.c_attn.lora_B', 'h.0.crossattention.c_attn.lora_A', 'h.0.crossattention.c_attn.lora_B', 'h.0.crossattention.q_attn.lora_A', 'h.0.crossattention.q_attn.lora_B', 'h.1.attn.c_attn.lora_A', 'h.1.attn.c_attn.lora_B', 'h.1.crossattention.c_attn.lora_A', 'h.1.crossattention.c_attn.lora_B', 'h.1.crossattention.q_attn.lora_A', 'h.1.crossattention.q_attn.lora_B', 'h.2.attn.c_attn.lora_A', 'h.2.attn.c_attn.lora_B', 'h.2.crossattention.c_attn.lora_A', 'h.2.crossattention.c_attn.lora_B', 'h.2.crossattention.q_attn.lora_A', 'h.2.crossattention.q_attn.lora_B', 'h.3.attn.c_attn.lora_A', 'h.3.attn.c_attn.lora_B', 'h.3.crossattention.c_attn.lora_A', 'h.3.crossattention.c_attn.lora_B', 'h.3.crossattention.q_attn.lora_A', 'h.3.crossattention.q_attn.lora_B', 'h.4.attn.c_attn.lora_A', 'h.4.attn.c_attn.lora_B', 'h.4.crossattention.c_attn.lora_A', 'h.4.crossattention.c_attn.lora_B', 'h.4.crossattention.q_attn.lora_A', 'h.4.crossattention.q_at

In [92]:
# NOTE(review): presumably disabled because GPT2Model has no lm_head (GPT2LMHeadModel does).
# model.lm_head.weight

In [93]:
# Token-embedding matrix after loading the checkpoint.
model.wte.weight

Parameter containing:
tensor([[-0.0149, -0.0209,  0.0021,  ...,  0.0336, -0.0005, -0.0090],
        [ 0.0055, -0.0438,  0.0013,  ...,  0.0671,  0.0329, -0.0399],
        [ 0.0585,  0.0603,  0.0302,  ..., -0.1041, -0.0566, -0.0330],
        ...,
        [-0.0108, -0.0908,  0.0624,  ..., -0.0337,  0.0777,  0.0293],
        [ 0.0195, -0.0318,  0.0182,  ...,  0.0191, -0.0454, -0.0139],
        [-0.0419,  0.0848, -0.0512,  ..., -0.0083, -0.0447, -0.0274]],
       requires_grad=True)

In [94]:
# Pretrained base weight inside the LoRA layer. Conv1D_LoRA sets requires_grad=False when r > 0,
# so the repr shows no requires_grad flag.
model.h[0].attn.c_attn.weight

Parameter containing:
tensor([[ 0.1655,  0.1230,  0.1003,  ..., -0.0081,  0.0106, -0.0183],
        [-0.2344,  0.1413,  0.0706,  ..., -0.0105,  0.0239, -0.0101],
        [ 0.1063, -0.0397,  0.1085,  ..., -0.0042,  0.0183, -0.0080],
        ...,
        [ 0.0020,  0.1257, -0.0798,  ...,  0.0024,  0.0351,  0.0204],
        [-0.1146, -0.0897, -0.0925,  ...,  0.0013,  0.0007, -0.0041],
        [ 0.0222, -0.0171, -0.0463,  ...,  0.0290,  0.0258, -0.0327]])

In [95]:
# Cross-attention c_attn base weight: random (the checkpoint has none) and frozen by Conv1D_LoRA.
if config.add_cross_attention:
    print(model.h[0].crossattention.c_attn.weight)

Parameter containing:
tensor([[-0.0149, -0.0345,  0.0357,  ..., -0.0123, -0.0229, -0.0260],
        [-0.0354,  0.0232, -0.0051,  ..., -0.0033,  0.0278,  0.0222],
        [-0.0041,  0.0141,  0.0063,  ..., -0.0153, -0.0293, -0.0105],
        ...,
        [-0.0215,  0.0144,  0.0012,  ...,  0.0281,  0.0139, -0.0212],
        [-0.0192,  0.0033,  0.0161,  ...,  0.0365,  0.0213,  0.0209],
        [ 0.0221, -0.0083, -0.0060,  ...,  0.0314, -0.0185,  0.0160]])


In [96]:
# LoRA A of the self-attention projection: kaiming-uniform init from Conv1D_LoRA.reset_parameters.
model.h[0].attn.c_attn.lora_A

Parameter containing:
tensor([[-0.0088,  0.0222,  0.0117,  ..., -0.0116, -0.0125,  0.0190],
        [-0.0096,  0.0007,  0.0225,  ..., -0.0149, -0.0054, -0.0150],
        [-0.0007, -0.0119, -0.0269,  ...,  0.0207, -0.0078,  0.0135],
        ...,
        [-0.0053,  0.0191, -0.0040,  ..., -0.0259,  0.0180, -0.0130],
        [-0.0069,  0.0077,  0.0231,  ...,  0.0128, -0.0095,  0.0179],
        [-0.0258,  0.0083,  0.0131,  ..., -0.0107, -0.0210, -0.0210]],
       requires_grad=True)

In [97]:
# LoRA B starts at zero, so the adapter adds nothing until it is trained.
model.h[0].attn.c_attn.lora_B

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)

In [98]:
# LoRA A of the cross-attention key/value projection (kaiming-uniform init).
if config.add_cross_attention:
    print(model.h[0].crossattention.c_attn.lora_A)

Parameter containing:
tensor([[-2.7540e-02, -1.3705e-02, -2.5356e-02,  ..., -1.5621e-02,
          1.7183e-02,  4.5410e-03],
        [-2.1208e-02,  2.2249e-02,  8.3704e-03,  ...,  2.0344e-02,
          3.9501e-05,  1.4097e-02],
        [ 1.7279e-02,  7.2668e-03, -2.3049e-02,  ..., -1.6981e-02,
         -1.6657e-02, -2.2538e-02],
        ...,
        [-1.2725e-02, -1.8495e-02, -6.5629e-03,  ..., -1.6573e-02,
         -5.2309e-03,  7.3493e-03],
        [ 2.1535e-02,  1.9177e-02,  2.5598e-02,  ..., -1.7152e-02,
          1.9801e-02, -6.6687e-03],
        [ 2.4029e-02,  2.2541e-02, -8.1920e-04,  ...,  2.2911e-02,
         -8.7083e-03, -2.4797e-02]], requires_grad=True)


In [99]:
# LoRA B of the cross-attention c_attn projection (zero init).
if config.add_cross_attention:
    print(model.h[0].crossattention.c_attn.lora_B)

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)


In [100]:
# LoRA A of the cross-attention query projection (kaiming-uniform init).
if config.add_cross_attention:
    print(model.h[0].crossattention.q_attn.lora_A)

Parameter containing:
tensor([[ 0.0151,  0.0266, -0.0242,  ..., -0.0091, -0.0085,  0.0243],
        [-0.0064,  0.0069,  0.0012,  ..., -0.0268,  0.0254,  0.0026],
        [ 0.0011,  0.0042,  0.0096,  ..., -0.0152, -0.0240, -0.0101],
        ...,
        [ 0.0237, -0.0232, -0.0014,  ...,  0.0060,  0.0163,  0.0102],
        [-0.0102, -0.0076, -0.0234,  ...,  0.0265,  0.0042, -0.0258],
        [-0.0106,  0.0184, -0.0261,  ..., -0.0229, -0.0084, -0.0051]],
       requires_grad=True)


In [101]:
# LoRA B of the cross-attention query projection (zero init).
if config.add_cross_attention:
    print(model.h[0].crossattention.q_attn.lora_B)

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)


In [102]:
# Number of trainable (requires_grad=True) parameters before freezing to LoRA-only.
param = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [103]:
# Trainable parameter count before freezing.
param

659659520

In [104]:
# Freeze every parameter except the LoRA A/B matrices (loralib helper).
lora.mark_only_lora_as_trainable(model)

In [105]:
# Number of trainable parameters after mark_only_lora_as_trainable (LoRA A/B only).
param = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [106]:
# Trainable parameter count after freezing: only the LoRA adapters remain.
param

3317760

In [116]:
# Smoke-test forward pass on a (batch=2, seq=4) batch of token ids.
# NOTE(review): the model is presumably still in training mode here, so dropout
# (including LoRA dropout) is active and the output is not deterministic.
output = model(torch.tensor([[3, 2, 1, 3], [5, 1, 2312, 1]], dtype=torch.long))

In [118]:
# Final hidden states for the (2, 4) input batch.
output['last_hidden_state']

tensor([[[-2.6968e-01,  1.9771e-01,  8.4906e-02,  ..., -4.3974e-01,
          -1.7160e-01, -1.4506e+00],
         [-1.8743e-01, -3.9547e-02, -4.5775e-01,  ...,  1.6749e-01,
          -1.7147e-01, -9.6216e-01],
         [ 1.8001e-01, -3.7222e-02,  4.5877e-01,  ..., -2.7406e-01,
           1.0326e-01, -1.3773e+00],
         [-1.1983e-01, -9.5554e-01, -4.7077e-02,  ..., -4.0969e-01,
           1.9316e-01, -2.1377e+00]],

        [[-2.5236e-01, -1.0755e-01, -1.0459e-01,  ..., -4.0032e-01,
           2.6567e-02, -1.8530e+00],
         [-2.7021e-01, -3.1398e-04,  5.7317e-02,  ..., -3.8848e-01,
          -5.2567e-01, -2.0494e+00],
         [ 2.5673e-02, -2.2973e-01,  2.4588e-01,  ..., -2.3121e-01,
           3.8505e-01, -8.1887e-01],
         [-3.3591e-02, -3.3863e-01,  1.6547e-02,  ..., -5.7154e-01,
          -1.5595e-01, -1.8523e+00]]], grad_fn=<ViewBackward>)

# Multimodal decoder: unimodal GPT-2 output fed through a cross-attention decoder (both with LoRA)

In [8]:
# gpt2-sized config with cross-attention enabled and the LoRA hyper-parameters read by set_lora.
config = GPT2Config.from_pretrained('gpt2')
config.update({
    'lora_attn_dim': 8,
    'lora_attn_alpha': 8,
    'lora_dropout': 0.1,
    'merge_weights': False,
    'freeze_pretrained_layers': True,
    'add_cross_attention': True,
})

# Build the decoder, wrap its attention projections with LoRA, then load the pretrained
# gpt2 weights. strict=False because lora_A / lora_B have no counterpart in that state_dict.
multimodal_decoder = set_lora(MultiModalDecoder(config=config))
multimodal_decoder.load_state_dict(
    GPT2Model.from_pretrained('gpt2', config=config).state_dict(),
    strict=False,
)

print('Build a Multimodal Decoder with LoRA')
before_param = calc_params(multimodal_decoder)
lora.mark_only_lora_as_trainable(multimodal_decoder)  # freeze everything but LoRA A/B
after_param = calc_params(multimodal_decoder)
print('Trainable parameters of a Multimodal Decoder change {} to {}'.format(before_param, after_param))

Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.3.ln_cross_attn.bias', 'h.5.crossattention.c_attn.bias', 'h.11.crossattention.bias', 'h.10.crossattention.q_attn.weight', 'h.0.crossattention.c_proj.weight', 'h.1.ln_cross_attn.weight', 'h.3.crossattention.q_attn.bias', 'h.9.ln_cross_attn.weight', 'h.5.crossattention.bias', 'h.0.ln_cross_attn.bias', 'h.7.crossattention.c_attn.bias', 'h.6.ln_cross_attn.weight', 'h.7.crossattention.q_attn.weight', 'h.9.ln_cross_attn.bias', 'h.8.ln_cross_attn.weight', 'h.5.crossattention.c_attn.weight', 'h.0.crossattention.masked_bias', 'h.8.crossattention.c_proj.weight', 'h.11.ln_cross_attn.weight', 'h.1.crossattention.q_attn.bias', 'h.9.crossattention.q_attn.bias', 'h.1.crossattention.c_proj.weight', 'h.6.crossattention.q_attn.bias', 'h.5.crossattention.q_attn.weight', 'h.3.crossattention.masked_bias', 'h.10.crossattention.c_proj.bias', 'h.6.crossattention.q_attn.weight', 'h.11.crossattention.

Build a Multimodal Decoder with LoRA
Trainable parameters of a Multimodal Decoder change 111002880 to 663552


In [12]:
# gpt2 config without cross-attention, plus the LoRA hyper-parameters read by set_lora.
config = GPT2Config.from_pretrained('gpt2')
config.update({
    'lora_attn_dim': 8,
    'lora_attn_alpha': 8,
    'lora_dropout': 0.1,
    'merge_weights': False,
    'freeze_pretrained_layers': True,
    'add_cross_attention': False,
})

# Plain GPT-2 with LoRA self-attention, then load the pretrained gpt2 weights.
# strict=False because lora_A / lora_B are absent from the checkpoint.
unimodal_decoder = set_lora(GPT2Model(config=config))
unimodal_decoder.load_state_dict(
    GPT2Model.from_pretrained('gpt2', config=config).state_dict(),
    strict=False,
)

print('Build a Unimodal Decoder with LoRA')
before_param = calc_params(unimodal_decoder)
lora.mark_only_lora_as_trainable(unimodal_decoder)  # freeze everything but LoRA A/B
after_param = calc_params(unimodal_decoder)
print('Trainable parameters of a Unimodal Decoder change {} to {}'.format(before_param, after_param))

Build a Unimodal Decoder with LoRA
Trainable parameters of a Unimodal Decoder change 103501056 to 294912


In [13]:
# Run the unimodal decoder on a (batch=2, seq=3) batch of token ids.
output = unimodal_decoder(torch.tensor([[123, 52, 512], [7345, 34, 124]], dtype=torch.long))

In [20]:
# Feed the unimodal hidden states (2, 3, 768) through the cross-attention decoder.
# attention_mask covers the 3 text positions. encoder_hidden_states is a random
# (2, 4, 768) tensor, presumably standing in for another modality's encoder output.
# NOTE(review): MultiModalDecoder's forward is defined in models.py; confirm the kwarg semantics there.
multimodal_decoder(
    hidden_states         = output['last_hidden_state'],
    attention_mask        = torch.ones((2,3)),
    encoder_hidden_states = torch.randn(2,4,768)
)

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[ 0.4528,  1.8306, -0.5822,  ..., -0.1598,  0.6187,  0.1820],
         [ 0.2870,  0.2701, -0.6151,  ..., -0.0373,  0.6674,  0.2033],
         [ 0.3178,  0.7613, -0.2611,  ...,  0.0581,  0.7786, -0.0921]],

        [[ 0.0738,  1.6187,  0.1168,  ..., -0.2524, -0.0491, -0.2290],
         [ 0.4228,  1.5954,  0.2552,  ..., -0.4613,  0.0583, -0.2242],
         [ 0.7592,  1.3706,  0.0426,  ..., -0.6588,  0.1929, -0.0444]]],
       grad_fn=<ViewBackward>), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)