In [1]:
import argparse
from typing import Optional, Union

import pandas as pd
import numpy as np
import torch
import torch.nn as nn

from dataclasses import dataclass

import datasets
from datasets import Dataset

from sklearn.metrics import log_loss

from transformers import (
    AutoTokenizer,
    AutoConfig,
    EarlyStoppingCallback,
    AutoModelForCausalLM,
    AutoModelForMultipleChoice,
    TrainingArguments,
    Trainer,
    RobertaForMultipleChoice,
    AutoModelForSequenceClassification,
    LlamaModel,
    LlamaForSequenceClassification,
    BitsAndBytesConfig,
    get_polynomial_decay_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    TrainerCallback,
    AutoModel
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy

from peft import (
    get_peft_config,
    PeftModel,
    PeftConfig,
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training
)
import os

import random

In [3]:
MODEL = 'google/gemma-2-9b-it'

config = AutoConfig.from_pretrained(MODEL, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True, truncation_side = 'left')

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

model = AutoModel.from_pretrained(MODEL,
                                 config=config,
                                 quantization_config=bnb_config,
                                 torch_dtype=torch.bfloat16,
                                 device_map="auto",
                                 trust_remote_code=True,
                                 attn_implementation='eager')
peft_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,  # For sequence classification
            inference_mode=False,
            r=32,
            lora_alpha=64,
            lora_dropout=0.1,
            #bias = 'none',
            target_modules=['q_proj','k_proj','v_proj','o_proj'] #,
        )
model = get_peft_model(model, peft_config)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
model

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): Gemma2Model(
      (embed_tokens): Embedding(256000, 3584, padding_idx=0)
      (layers): ModuleList(
        (0-41): 42 x Gemma2DecoderLayer(
          (self_attn): Gemma2Attention(
            (q_proj): Linear4bit(
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.1, inplace=False)
              )
              (lora_A): ModuleDict(
                (default): Linear(in_features=3584, out_features=32, bias=False)
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=32, out_features=4096, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (base_layer): Linear4bit(in_features=3584, out_features=4096, bias=False)
            )
            (k_proj): Linear4bit(
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.1, inplac

In [11]:
t = tokenizer("Hello Who are you",return_tensors="pt")

In [12]:
t

{'input_ids': tensor([[   2, 4521, 7702,  708,  692]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [13]:
output = model(**t)

In [14]:
output.last_hidden_state.shape

torch.Size([1, 5, 3584])

In [35]:
output.loss['logits'].shape

torch.Size([1, 2, 256000])

In [38]:
output.logits.shape

torch.Size([1, 2, 256000])

In [39]:
output.loss['logits']

tensor([[[-26.2500,  17.5000,  21.2500,  ..., -10.4375,  -3.5469, -14.0625],
         [-24.8750,   1.1641, -15.9375,  ..., -14.4375, -11.8125, -25.7500]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>)

In [40]:
output.logits

tensor([[[-26.2500,  17.5000,  21.2500,  ..., -10.4375,  -3.5469, -14.0625],
         [-24.8750,   1.1641, -15.9375,  ..., -14.4375, -11.8125, -25.7500]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>)

In [16]:
model.config.hidden_size

3584

In [5]:
input_ids = torch.randint(0, 100, (32, 50))  # 示例输入
attention_mask = torch.ones((32, 50))  # 示例attention mask

# 确保输入张量在正确的设备上
input_ids = input_ids.to(next(model.parameters()).device)
attention_mask = attention_mask.to(next(model.parameters()).device)

In [6]:
class GemmaLSTM(nn.Module):
    def __init__(self, gemma, lstm_hidden_size, num_classes):
        super(GemmaLSTM, self).__init__()
        self.gemma = gemma
        self.lstm = nn.LSTM(gemma.config.hidden_size, lstm_hidden_size, batch_first=True)
        self.classifier = nn.Linear(lstm_hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        output = self.gemma(input_ids, attention_mask)
        hidden_states = output.last_hidden_state
        print(hidden_states)
        lstm_output, _ = self.lstm(hidden_states)
        lstm_output = lstm_output[:,-1,:]
        logits = self.classifier(lstm_output)
        return logits

In [7]:
custom = GemmaLSTM(model, 1024, 3)

In [8]:
t

NameError: name 't' is not defined

In [6]:
custom(input_ids,attention_mask)

tensor([[[ 1.1406,  2.8438,  0.3203,  ...,  0.3535,  0.0669,  3.7500],
         [ 1.2656,  3.2656,  0.4004,  ...,  0.3711,  0.0369,  3.6875],
         [ 1.1172,  3.0312,  0.6602,  ..., -0.4629,  0.2988,  3.4531],
         ...,
         [ 0.3906,  1.5781, -1.3828,  ..., -1.6953,  1.2500,  2.8750],
         [ 0.0173,  2.0312, -1.2422,  ..., -1.0859,  1.2031,  3.0938],
         [ 0.1260,  2.1719, -1.2109,  ..., -1.2109,  1.6250,  2.8281]],

        [[ 1.2031,  3.5781,  0.7305,  ..., -0.1182,  0.1631,  4.2500],
         [ 1.1719,  3.7031,  0.8047,  ..., -0.2334,  0.0215,  4.3438],
         [ 0.9180,  3.2969,  1.3516,  ..., -0.3672,  0.5586,  3.5469],
         ...,
         [ 2.5156,  5.0312, -1.0000,  ..., -2.2031,  2.2969,  3.2812],
         [ 2.0156,  6.0312, -1.2422,  ..., -1.6250,  2.2188,  3.0938],
         [ 1.7188,  6.7812, -1.5469,  ..., -2.6562,  2.4844,  4.1250]],

        [[ 1.3203,  3.5625,  0.6094,  ...,  0.0215, -0.0214,  4.2500],
         [ 1.3047,  3.8906,  0.8008,  ..., -0

RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cuda:0 and parameter tensor at cpu