<a href="https://colab.research.google.com/github/NguyenDucAnforwork/NLP_training/blob/main/bert_using_lora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets transformers

In [9]:
#@title load libraries
from datasets import load_dataset
from transformers import BertTokenizer
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
import numpy as np
from random import randrange, shuffle, randint, random
import math

In [15]:
#@title Load dataset
dataset = load_dataset("imdb").shuffle(seed=42)
train_dataset = dataset['train'].shuffle(seed=42).select(range(200))
test_dataset = dataset['test'].shuffle(seed=42).select(range(200))

# Load pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenization and Padding function
def tokenize_and_pad(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=32, return_token_type_ids=True)

# Map datasets to include input_ids and attention_mask
train_tokenized_dataset = train_dataset.map(tokenize_and_pad, batched=True)
test_tokenized_dataset = test_dataset.map(tokenize_and_pad, batched=True)



class IMDbDataset(Dataset):
    def __init__(self, tokenized_dataset):
        self.input_ids = tokenized_dataset['input_ids']
        self.attention_mask = tokenized_dataset['attention_mask']
        self.token_type_ids = tokenized_dataset['token_type_ids']
        self.labels = tokenized_dataset['label']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        item = {
            'input_ids': torch.tensor(self.input_ids[idx]),
            'attention_mask': torch.tensor(self.attention_mask[idx]),
            'token_type_ids': torch.tensor(self.token_type_ids[idx]),
            'labels': torch.tensor(self.labels[idx])
        }
        return item
# train_dataset = dataset['train'].shuffle(seed=42).select(range(800))

train_dataset = IMDbDataset(train_tokenized_dataset)
test_dataset = IMDbDataset(test_tokenized_dataset)


Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [14]:
#@title BERT from scratch
from dataclasses import dataclass
from typing import Tuple, TypeVar

import torch
import torch.nn as nn
from torch.nn import functional as F


@dataclass
class BertConfig:

    max_seq_length: int = 32
    vocab_size: int = 30522
    n_layers: int = 12
    n_heads: Tuple[int] = (12,) * n_layers
    emb_size: int = 768
    intermediate_size: int = emb_size * 4
    dropout: float = 0.1
    n_classes: int = 2
    layer_norm_eps: float = 1e-12
    pad_token_id: int = 103
    return_pooler_output: bool = False


class BertEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.word_embeddings = nn.Embedding(config.vocab_size, config.emb_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_seq_length, config.emb_size)
        self.token_type_embeddings = nn.Embedding(2, config.emb_size)
        self.LayerNorm = nn.LayerNorm(config.emb_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

        # position ids (used in the pos_emb lookup table) that we do not want updated through backpropogation
        self.register_buffer("position_ids", torch.arange(config.max_seq_length).expand((1, -1)))

    def forward(self, input_ids, token_type_ids):
        word_emb = self.word_embeddings(input_ids)
        pos_emb = self.position_embeddings(self.position_ids)
        type_emb = self.token_type_embeddings(token_type_ids)

        emb = word_emb + pos_emb + type_emb
        emb = self.LayerNorm(emb)
        emb = self.dropout(emb)
        return emb


class BertSelfAttention(nn.Module):
    def __init__(self, config, layer_i):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads[layer_i]
        self.head_size = config.emb_size // self.n_heads
        self.query = nn.Linear(config.emb_size, config.emb_size)
        self.key = nn.Linear(config.emb_size, config.emb_size)
        self.value = nn.Linear(config.emb_size, config.emb_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, emb, att_mask):
        B, T, C = emb.shape  # batch size, sequence length, embedding size

        q = self.query(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = self.key(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = self.value(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        weights = q @ k.transpose(-2, -1) * self.head_size**-0.5

        # set the pad tokens to -inf so that they equal zero after softmax
        if att_mask != None:
            att_mask = (att_mask > 0).unsqueeze(1).repeat(1, att_mask.size(1), 1).unsqueeze(1)
            weights = weights.masked_fill(att_mask == 0, float('-inf'))

        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)

        emb_rich = weights @ v
        emb_rich = emb_rich.transpose(1, 2).contiguous().view(B, T, C)
        return emb_rich


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.emb_size, config.emb_size)
        self.dropout = nn.Dropout(config.dropout)
        self.LayerNorm = nn.LayerNorm(config.emb_size, eps=config.layer_norm_eps)

    def forward(self, emb_rich, emb):
        x = self.dense(emb_rich)
        x = self.dropout(x)
        x = x + emb
        out = self.LayerNorm(x)
        return out


class BertAttention(nn.Module):
    def __init__(self, config, layer_i):
        super().__init__()
        self.self = BertSelfAttention(config, layer_i)
        self.output = BertSelfOutput(config)

    def forward(self, emb, att_mask):
        emb_rich = self.self(emb, att_mask)
        out = self.output(emb_rich, emb)
        return out


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.emb_size, config.intermediate_size)
        self.gelu = nn.GELU()

    def forward(self, att_out):
        x = self.dense(att_out)
        out = self.gelu(x)
        return out


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.emb_size)
        self.dropout = nn.Dropout(config.dropout)
        self.LayerNorm = nn.LayerNorm(config.emb_size, eps=config.layer_norm_eps)

    def forward(self, intermediate_out, att_out):
        x = self.dense(intermediate_out)
        x = self.dropout(x)
        x = x + att_out
        out = self.LayerNorm(x)
        return out


class BertLayer(nn.Module):
    def __init__(self, config, layer_i):
        super().__init__()
        self.attention = BertAttention(config, layer_i)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, emb, att_mask):
        att_out = self.attention(emb, att_mask)
        intermediate_out = self.intermediate(att_out)
        out = self.output(intermediate_out, att_out)
        return out


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([BertLayer(config, layer_i) for layer_i in range(config.n_layers)])

    def forward(self, emb, att_mask):
        for bert_layer in self.layer:
            emb = bert_layer(emb, att_mask)
        return emb


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.emb_size, config.emb_size)
        self.tanh = nn.Tanh()

    def forward(self, encoder_out):
        pool_first_token = encoder_out[:, 0]
        out = self.dense(pool_first_token)
        out = self.tanh(out)
        return out


class BertModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

    def forward(self, input_ids, token_type_ids, att_mask):
        emb = self.embeddings(input_ids, token_type_ids)
        out = self.encoder(emb, att_mask)
        pooled_out = self.pooler(out)
        return out, pooled_out


class BertForSequenceClassification(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.emb_size, config.n_classes)


    def forward(self, input_ids, token_type_ids, attention_mask=None):
        _, pooled_out = self.bert(input_ids, token_type_ids, attention_mask)
        pooled_out = self.dropout(pooled_out)
        logits = self.classifier(pooled_out)

        if self.config.return_pooler_output:
            return pooled_out, logits
        return logits

    def reduce_seq_len(self, seq_len):
        """
        Reduces the accepted sequence length of the inputs into the model.
            e.g. BERT normally accepts a maximum of 512 tokens. This can be reduced to a lesser number of tokens
                for a smaller model that requires less compute during training or inference.

        Args:
            seq_len: int = An integer representing the length of the sequences accepted by the model
        """
        assert seq_len <= self.config.max_seq_length, f"Sequence length must be reduced below current length of {self.config.max_seq_length}"
        self.bert.embeddings.position_embeddings.weight = nn.Parameter(self.bert.embeddings.position_embeddings.weight[:seq_len])
        self.bert.embeddings.position_ids = self.bert.embeddings.position_ids[:, :seq_len]
        print(f"Sequence length successfully reduced to {seq_len}.")
        self.config.max_seq_length = seq_len

    @staticmethod
    def adaptive_copy(orig_wei, new_wei):
        """
        Copies the new weights from the pretrained model into the custom model.
        If the dimensions of the new weights are larger then it only copies the
        portions that fit.

            e.g. old_weight_dim = (1 x 64), new_weight_dim = (1 x 512)
                Replaces the old weights with the first 64 elements of the new weights.

        Args:
            orig_wei: torch.tensor = Torch tensor containing the weights from the custom model
            new_wei: torch.tensor = Torch tensor containing the weights from the pretrained model
        """
        n_dim = orig_wei.dim()

        with torch.no_grad():
            if n_dim == 1:
                dim1 = list(orig_wei.shape)[0]
                orig_wei.copy_(new_wei[:dim1])
            elif n_dim == 2:
                dim1, dim2 = list(orig_wei.shape)
                orig_wei.copy_(new_wei[:dim1, :dim2])
            elif n_dim == 3:
                dim1, dim2, dim3 = list(orig_wei.shape)
                orig_wei.copy_(new_wei[:dim1, :dim2, :dim3])

    @classmethod
    def from_pretrained(cls, model_type, config_args=None, adaptive_weight_copy=False):
        """
        Instantiates the BERT model and loads the weights from a compatible hugging face model.

        Args:
            cls: None = Refers to the class itself, similar to how self refers to the instance of the class.
            model_type: str = Model name (hugging face) or local path
                e.g. 'bert-base-uncased' or './path/bert-base-uncased.pth'
            config_args: dict = Dictionary having all or less of the keys found in BertConfig()
                e.g. config_args = dict(max_seq_length=512, vocab_size=30522, n_classes=2)
            adaptive_weights: bool = Boolean that when true, if the weight dimensions are smaller in the custom model,
                                     it will copy over the the portions of the weights that fit. When false
                                     it will throw an error if mismatch in shape weights.

        Returns:
            torch.nn.Module: A pytorch model
        """
        from transformers import BertForSequenceClassification as HFBertForSequenceClassification

        print(f"Loading weights from pretrained model: {model_type}")

        if config_args:
            config = BertConfig(**config_args)
        else:
            config = BertConfig()

        # init custom model
        model = cls(config)
        sd = model.state_dict()
        sd_keys = sd.keys()

        # init huggingface/transformers model
        model_hf = HFBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path=model_type, num_labels=config.n_classes)
        sd_hf = model_hf.state_dict()
        sd_keys_hf = sd_hf.keys()

        # Check that all keys match between the state dictionary of the custom and pretrained model
        assert len(sd_keys_hf) == len(sd_keys), (
            f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}. "
            "Try using transformers==4.28.1 as this version is known to be compatible."
        )

        # Replace weights in the custom model with the weights from the pretrained model
        for k in sd_keys_hf:

            # copy over weights if they are the same shape
            if not adaptive_weight_copy:

                # Check that the shape of the corresponding weights are the same between the two models
                assert sd_hf[k].shape == sd[k].shape, f"Shape mismatch: {k} --> (hf vs custom) ({sd_hf[k].shape} vs {sd[k].shape})"

                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

            # adaptively copy over weights by cropping them if the dimensions are larger
            else:
                with torch.no_grad():
                    cls.adaptive_copy(sd[k], sd_hf[k])
        return model

In [29]:
#@title configurations
config = BertConfig()
dfs = BertForSequenceClassification(config)

ValueError: Parameter config in `BertForSequenceClassification(config)` should be an instance of class `PretrainedConfig`. To create a model from a pretrained model use `model = BertForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME)`

In [21]:
#@title finetune BERT with GradScaler and Checkpoint Saving
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
import torch
from transformers import BertForSequenceClassification
from tqdm.notebook import tqdm
import os

# Function to save model checkpoint
def save_checkpoint(model, optimizer, epoch, loss, checkpoint_dir="checkpoints"):
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

# Fine-tune BERT
def fine_tune_bert(model, train_dataset, test_dataset, epochs=10, batch_size=8, learning_rate=2e-5, checkpoint_dir="checkpoints", optimizer):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    criterion = CrossEntropyLoss()

    # Initialize GradScaler for mixed precision training
    scaler = torch.cuda.amp.GradScaler()

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(train_loader):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            with torch.cuda.amp.autocast():
                output = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
                loss = criterion(output.logits, labels)

            # Scale loss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_train_loss}")

        # Save checkpoint
        save_checkpoint(model, optimizer, epoch, avg_train_loss, checkpoint_dir)

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)

            with torch.cuda.amp.autocast():
                output = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
                logits = output.logits
                _, predicted = torch.max(logits, dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy}")

# Create model instance
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# define optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Fine-tune the model
fine_tune_bert(model, train_dataset, test_dataset, optimizer)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1/10, Loss: 0.70487060546875
Checkpoint saved at checkpoints/checkpoint_epoch_1.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 2/10, Loss: 0.5891082763671875
Checkpoint saved at checkpoints/checkpoint_epoch_2.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 3/10, Loss: 0.3657722473144531
Checkpoint saved at checkpoints/checkpoint_epoch_3.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 4/10, Loss: 0.17300399780273437
Checkpoint saved at checkpoints/checkpoint_epoch_4.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 5/10, Loss: 0.0790399169921875
Checkpoint saved at checkpoints/checkpoint_epoch_5.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 6/10, Loss: 0.024396018981933595
Checkpoint saved at checkpoints/checkpoint_epoch_6.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 7/10, Loss: 0.016594295501708985
Checkpoint saved at checkpoints/checkpoint_epoch_7.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 8/10, Loss: 0.01260150909423828
Checkpoint saved at checkpoints/checkpoint_epoch_8.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 9/10, Loss: 0.009499883651733399
Checkpoint saved at checkpoints/checkpoint_epoch_9.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 10/10, Loss: 0.007814292907714843
Checkpoint saved at checkpoints/checkpoint_epoch_10.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy: 0.635


In [22]:
#@title LoRA implementation
import torch.nn.utils.parametrize as parametrize

class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )


In [30]:
#@title freeze original parameters and deploy new LoRA parameters
def freeze_parameters_except_lora(model):
    for name, param in model.named_parameters():
        if 'lora' not in name:
            param.requires_grad = False

# Create model instance
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Apply LoRA parameterization to the BERT self-attention layers
for layer in model.bert.encoder.layer:
    parametrize.register_parametrization(layer.attention.self.query, "weight", linear_layer_parameterization(layer.attention.self.query, device))
    parametrize.register_parametrization(layer.attention.self.key, "weight", linear_layer_parameterization(layer.attention.self.key, device))
    parametrize.register_parametrization(layer.attention.self.value, "weight", linear_layer_parameterization(layer.attention.self.value, device))

# Freeze all parameters except those involved in LoRA
freeze_parameters_except_lora(model)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [36]:
#@title finetune BERT with GradScaler and Checkpoint Saving
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
import torch
from transformers import BertForSequenceClassification
from tqdm.notebook import tqdm
import os

def save_checkpoint(model, optimizer, epoch, loss, checkpoint_dir="checkpoints"):
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

def fine_tune_bert(model, train_dataset, test_dataset, epochs=5, batch_size=8, learning_rate=2e-3, checkpoint_dir="checkpoints"):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Only parameters with requires_grad=True will be optimized
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
    criterion = CrossEntropyLoss()

    scaler = torch.cuda.amp.GradScaler()

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(train_loader):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            with torch.cuda.amp.autocast():
                output = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
                loss = criterion(output.logits, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_train_loss}")

        save_checkpoint(model, optimizer, epoch, avg_train_loss, checkpoint_dir)

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)

            with torch.cuda.amp.autocast():
                output = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
                logits = output.logits
                _, predicted = torch.max(logits, dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy}")

# Create model instance
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Apply LoRA parameterization to the BERT self-attention layers
for layer in model.bert.encoder.layer:
    parametrize.register_parametrization(layer.attention.self.query, "weight", linear_layer_parameterization(layer.attention.self.query, device))
    parametrize.register_parametrization(layer.attention.self.key, "weight", linear_layer_parameterization(layer.attention.self.key, device))
    parametrize.register_parametrization(layer.attention.self.value, "weight", linear_layer_parameterization(layer.attention.self.value, device))

# Freeze all parameters except those involved in LoRA
freeze_parameters_except_lora(model)

# Fine-tune the model
fine_tune_bert(model, train_dataset, test_dataset)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1/5, Loss: 0.697830810546875
Checkpoint saved at checkpoints/checkpoint_epoch_1.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 2/5, Loss: 0.705623779296875
Checkpoint saved at checkpoints/checkpoint_epoch_2.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 3/5, Loss: 0.70614990234375
Checkpoint saved at checkpoints/checkpoint_epoch_3.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 4/5, Loss: 0.6933837890625
Checkpoint saved at checkpoints/checkpoint_epoch_4.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 5/5, Loss: 0.69838623046875
Checkpoint saved at checkpoints/checkpoint_epoch_5.pth


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy: 0.525


In [19]:
#@title finetune BERT using library (just to verify our implementation)
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, AdamW
from tqdm.notebook import tqdm

# Load the IMDB dataset
dataset = load_dataset("imdb")
train_dataset_without_lora = dataset['train'].shuffle(seed=42).select(range(200))
test_dataset_without_lora = dataset['test'].shuffle(seed=42).select(range(200))

# Load pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenization and Padding function
def tokenize_and_pad(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=32,
        return_token_type_ids=True
    )

# Map datasets to include input_ids, attention_mask, and token_type_ids
train_tokenized_dataset = train_dataset_without_lora.map(tokenize_and_pad, batched=True)
test_tokenized_dataset = test_dataset_without_lora.map(tokenize_and_pad, batched=True)

class IMDbDataset_without_lora(Dataset):
    def __init__(self, tokenized_dataset):
        self.input_ids = tokenized_dataset['input_ids']
        self.attention_mask = tokenized_dataset['attention_mask']
        self.token_type_ids = tokenized_dataset['token_type_ids']
        self.labels = tokenized_dataset['label']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        item = {
            'input_ids': torch.tensor(self.input_ids[idx]),
            'attention_mask': torch.tensor(self.attention_mask[idx]),
            'token_type_ids': torch.tensor(self.token_type_ids[idx]),
            'labels': torch.tensor(self.labels[idx])
        }
        return item

train_dataset_without_lora = IMDbDataset_without_lora(train_tokenized_dataset)
test_dataset_without_lora = IMDbDataset_without_lora(test_tokenized_dataset)

# Initialize the BERT model_without_lora
model_without_lora = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Move the model_without_lora to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_without_lora.to(device)

# Create DataLoaders with reduced batch size
train_loader_without_lora = DataLoader(train_dataset_without_lora, batch_size=8, shuffle=True)
test_loader_without_lora = DataLoader(test_dataset_without_lora, batch_size=8)

# Define the optimizer
optimizer = AdamW(model_without_lora.parameters(), lr=2e-5)

# Mixed precision training
scaler = torch.cuda.amp.GradScaler()

# Training loop
num_epochs = 3
model_without_lora.train()

for epoch in range(num_epochs):
    total_loss = 0
    for batch in tqdm(train_loader_without_lora):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        with torch.cuda.amp.autocast():
            outputs = model_without_lora(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader_without_lora)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_train_loss}")

    # Clear memory
    del input_ids, attention_mask, labels, loss
    torch.cuda.empty_cache()

# Evaluation
model_without_lora.eval()

correct = 0
total = 0

with torch.no_grad():
    for batch in tqdm(test_loader_without_lora):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model_without_lora(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        _, predicted = torch.max(logits, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy of BERT without LoRA: {accuracy}")


Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1/3, Loss: 0.703927001953125


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 2/3, Loss: 0.60285888671875


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 3/3, Loss: 0.3998068237304688


  0%|          | 0/25 [00:00<?, ?it/s]

Test Accuracy of BERT without LoRA: 0.635
