<a href="https://colab.research.google.com/github/NguyenDucAnforwork/NLP_training/blob/main/BERT_with_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install datasets transformers tqdm

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/547.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━[0m [32m266.2/547.8 kB[0m [31m9.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m38.9 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-an

In [14]:
from datasets import load_dataset

# Load the IMDB dataset
dataset = load_dataset("imdb")

# Select a subset of 5000 examples
train_dataset = dataset['train'].shuffle(seed=42).select(range(800))
test_dataset = dataset['test'].shuffle(seed=42).select(range(200))

# Check the data format
print(train_dataset[0])


{'text': 'There is no relation at all between Fortier and Profiler but the fact that both are police series about violent crimes. Profiler looks crispy, Fortier looks classic. Profiler plots are quite simple. Fortier\'s plot are far more complicated... Fortier looks more like Prime Suspect, if we have to spot similarities... The main character is weak and weirdo, but have "clairvoyance". People like to compare, to judge, to evaluate. How about just enjoying? Funny thing too, people writing Fortier looks American but, on the other hand, arguing they prefer American series (!!!). Maybe it\'s the language, or the spirit, but I think this series is more English than American. By the way, the actors are really good and funny. The acting is not superficial at all...', 'label': 1}


In [15]:
from transformers import BertTokenizer

# Load the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

# Remove the original text column
train_dataset = train_dataset.remove_columns(["text"])
test_dataset = test_dataset.remove_columns(["text"])

# Set the format for PyTorch
train_dataset.set_format("torch")
test_dataset.set_format("torch")


Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [18]:
#@title without LoRA
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification

# Initialize the BERT model
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm.notebook import tqdm

# Create DataLoaders with reduced batch size
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

# Define the optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Mixed precision training
scaler = torch.cuda.amp.GradScaler()

# Training loop
num_epochs = 3
model.train()

for epoch in range(num_epochs):
    total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        with torch.cuda.amp.autocast():
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_train_loss}")

    # Clear memory
    del input_ids, attention_mask, labels, loss
    torch.cuda.empty_cache()




Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/3, Loss: 0.6941873168945313


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2/3, Loss: 0.4596139144897461


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3/3, Loss: 0.23722649574279786


  0%|          | 0/25 [00:00<?, ?it/s]

KeyError: 'labels'

In [19]:
# Evaluation
model.eval()

correct = 0
total = 0

with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        _, predicted = torch.max(logits, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy of BERT without LoRA: {accuracy}")

  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy of BERT without LoRA: 0.8


In [21]:
model_path_no_lora = "./bert_no_lora_checkpoint"
model.save_pretrained(model_path_no_lora)

# Save tokenizer associated with this model if needed
tokenizer.save_pretrained(model_path_no_lora)

('./bert_no_lora_checkpoint/tokenizer_config.json',
 './bert_no_lora_checkpoint/special_tokens_map.json',
 './bert_no_lora_checkpoint/vocab.txt',
 './bert_no_lora_checkpoint/added_tokens.json')

In [16]:
import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel, BertConfig

class LoRALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank=8):
        super(LoRALayer, self).__init__()
        self.rank = rank
        self.low_rank_layer_1 = nn.Linear(input_dim, rank, bias=False)
        self.low_rank_layer_2 = nn.Linear(rank, output_dim, bias=False)

    def forward(self, x):
        return self.low_rank_layer_2(self.low_rank_layer_1(x))

class BertWithLoRA(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.lora = LoRALayer(config.hidden_size, config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Freeze the BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[1]
        lora_output = self.lora(sequence_output)
        pooled_output = sequence_output + lora_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return (loss, logits) if loss is not None else logits


In [6]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


In [7]:
# Initialize the original BERT model
original_model = BertModel.from_pretrained("bert-base-uncased")

# Initialize the BertWithLoRA model
config = BertConfig.from_pretrained("bert-base-uncased")
config.num_labels = 2
lora_model = BertWithLoRA.from_pretrained("bert-base-uncased", config=config)

# Freeze BERT parameters in LoRA model
for param in lora_model.bert.parameters():
    param.requires_grad = False

# Count the parameters
original_params = count_parameters(original_model)
lora_params = count_parameters(lora_model)

# Print the number of trainable parameters
print(f"Number of trainable parameters in original BERT model: {original_params}")
print(f"Number of trainable parameters in BertWithLoRA model: {lora_params}")


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertWithLoRA were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'lora.low_rank_layer_1.weight', 'lora.low_rank_layer_2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Number of trainable parameters in original BERT model: 109482240
Number of trainable parameters in BertWithLoRA model: 13826


In [22]:
#@title with LoRA
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm.notebook import tqdm

# Assuming train_loader and test_loader are already defined DataLoader objects
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

# Initialize the model_with_lora
config = BertConfig.from_pretrained("bert-base-uncased")
config.num_labels = 2
model_with_lora = BertWithLoRA(config=config)

# Move the model_with_lora to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_with_lora.to(device)

# Define the optimizer for LoRA and classifier parameters only
optimizer = AdamW(filter(lambda p: p.requires_grad, model_with_lora.parameters()), lr=5e-5)

# Mixed precision training
scaler = torch.cuda.amp.GradScaler()

# Training loop
num_epochs = 10
model_with_lora.train()

for epoch in range(num_epochs):
    total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        with torch.cuda.amp.autocast():
            outputs = model_with_lora(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs[0]

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_train_loss}")

    # Evaluation after each epoch (optional)
    model_with_lora.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model_with_lora(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs[0]

            # Calculate predictions
            _, predicted = torch.max(logits, dim=-1)  # Use dim=-1 to find max along the last dimension

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f"Test Accuracy after epoch {epoch+1}: {accuracy}")

    model_with_lora.train()  # Set the model_with_lora back to train mode

    # Clear memory
    del input_ids, attention_mask, labels, loss
    torch.cuda.empty_cache()




  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/10, Loss: 0.6988143920898438


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 1: 0.52


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2/10, Loss: 0.70306640625


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 2: 0.48


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3/10, Loss: 0.6984735107421876


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 3: 0.48


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 4/10, Loss: 0.7048410034179687


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 4: 0.52


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 5/10, Loss: 0.698941650390625


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 5: 0.52


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 6/10, Loss: 0.70212890625


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 6: 0.5


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 7/10, Loss: 0.6989669799804688


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 7: 0.48


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 8/10, Loss: 0.7037411499023437


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 8: 0.5


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 9/10, Loss: 0.6966702270507813


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 9: 0.52


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 10/10, Loss: 0.695791015625


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy after epoch 10: 0.52


In [23]:
model_path_lora = "./bert_no_lora_checkpoint"
model_with_lora.save_pretrained(model_path_lora)

# Save tokenizer associated with this model if needed
tokenizer.save_pretrained(model_path_lora)

('./bert_no_lora_checkpoint/tokenizer_config.json',
 './bert_no_lora_checkpoint/special_tokens_map.json',
 './bert_no_lora_checkpoint/vocab.txt',
 './bert_no_lora_checkpoint/added_tokens.json')

In [None]:
model.eval()

correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs[1]
        _, predicted = torch.max(logits, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy}")
