In [10]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import BertTokenizer
from loralib.layers import Linear
from loralib.utils import mark_only_lora_as_trainable

upload dataset

In [11]:
dataset = load_dataset("ag_news")
print(dataset["train"][0])

{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}


BERT pre-training

In [12]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # BERT分词器分词

def encode_with_bert(example):
    tokens = tokenizer(
        example["text"],
        truncation=True,
        padding='max_length',
        max_length=128,
    )
    return {
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
        "label": example["label"]
    }

encoded_dataset = dataset.map(encode_with_bert, batched=True) # Hugging Face Datasets库的map函数批量处理数据

encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"]) # 设置数据格式为PyTorch张量，并指定需要的列


Prepare DataLoader

In [13]:
train_loader = DataLoader(encoded_dataset["train"], batch_size=32, shuffle=True)
test_loader = DataLoader(encoded_dataset["test"], batch_size=32)

Define RNN model with LoRA classification head

In [17]:
class CustomTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.self_attn.q_proj = Linear(d_model, d_model, r=8)
        self.self_attn.v_proj = Linear(d_model, d_model, r=8)
        self.linear1 = nn.Linear(d_model, d_model * 4)
        self.dropout = nn.Dropout(0.1)
        self.linear2 = nn.Linear(d_model * 4, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, src, src_mask=None):
        q = self.self_attn.q_proj(src)
        v = self.self_attn.v_proj(src)
        attn_output, _ = self.self_attn(q, src, v, attn_mask=src_mask)
        src = src + self.dropout1(attn_output)
        src = self.norm1(src)
        ff_output = self.linear2(self.dropout(torch.relu(self.linear1(src))))
        src = src + self.dropout2(ff_output)
        src = self.norm2(src)
        return src

In [18]:
# Define Transformer-based model with LoRA on fc layer
class TextTransformerWithLoRA(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.layers = nn.ModuleList([
            CustomTransformerEncoderLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        x = self.embedding(input_ids)
        x = x.masked_fill(attention_mask.unsqueeze(-1) == 0, 0)
        for layer in self.layers:
            x = layer(x)
        return self.fc(x[:, 0, :])


In [19]:
# Initialize model
model = TextTransformerWithLoRA(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    num_heads=4,
    num_layers=2,
    num_classes=4
)
# Load pretrained weights from IMDB except fc
state_dict = torch.load("Transformer_imdb.pth")
filtered_dict = {k: v for k, v in state_dict.items() if not k.startswith("fc")}
model.load_state_dict(filtered_dict, strict=False)

# freeze all except LoRA
mark_only_lora_as_trainable(model)

# Move to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 7. Define optimizer and loss
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()


Train model

In [20]:
for epoch in range(20):
    model.train()
    total_loss = 0
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")

Epoch 1, Loss: 1.0820
Epoch 2, Loss: 0.8639
Epoch 3, Loss: 0.7935
Epoch 4, Loss: 0.7462
Epoch 5, Loss: 0.7094
Epoch 6, Loss: 0.6829
Epoch 7, Loss: 0.6623
Epoch 8, Loss: 0.6465
Epoch 9, Loss: 0.6327
Epoch 10, Loss: 0.6223
Epoch 11, Loss: 0.6132
Epoch 12, Loss: 0.6033
Epoch 13, Loss: 0.5948
Epoch 14, Loss: 0.5857
Epoch 15, Loss: 0.5751
Epoch 16, Loss: 0.5661
Epoch 17, Loss: 0.5574
Epoch 18, Loss: 0.5499
Epoch 19, Loss: 0.5418
Epoch 20, Loss: 0.5345


Evaluation

In [21]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = model(input_ids, attention_mask)
        predicted = torch.argmax(outputs, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {correct / total:.2%}")



Test Accuracy: 80.78%
