### Fine tune RoBerta with LoRA

In [None]:
!pip install -q transformers 
!pip install -q peft
!pip install -q evaluate
!pip install wandb

### 1. Injecting LoRA to the RoBerta model

In [2]:
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    # r: the dimension of the A and B
    # lora_alpha: a scaling factor -> the relative significance of the weights 
    #                                 in ‘A’ and ‘B’ in relation to the model’s 
    #                                 original parameters
    task_type=TaskType.SEQ_CLS, r=8, lora_alpha=8, lora_dropout=0.1
)

In [3]:
import torch
from transformers import RobertaForSequenceClassification

# get RoBerta model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RobertaForSequenceClassification.from_pretrained('roberta-base').to(device)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# insert the ‘A’ and ‘B’ matrices into our model by invoking the ‘get_peft_model’ function.
from peft import get_peft_model

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 887,042 || all params: 125,534,212 || trainable%: 0.7066137476531099


### 2. Get Datasets

In [5]:
from nli_dataset import NliDataset

# get dataset
train_dataset = NliDataset(csv_file="Data/train_data.csv", max_length=256)
val_dataset = NliDataset(csv_file="Data/test_data.csv", max_length=256)


In [6]:
from torch.utils.data import DataLoader
train_loader = DataLoader(
                        dataset=train_dataset,
                        batch_size=64,
                        num_workers=4,
                        prefetch_factor=2,
                        shuffle=True,
                        drop_last=False
                    )

val_loader = DataLoader(
                    dataset=val_dataset,
                    batch_size=64,
                    num_workers=4,
                    prefetch_factor=2,
                    drop_last=False
                )

### Train

In [7]:
from torch.optim import lr_scheduler
from transformers import AdamW

# Define optimizer and loss function
optimizer = AdamW(model.parameters(),
                betas=(0.9, 0.98),  # according to RoBERTa paper
                lr=1e-4, 
                weight_decay=5e-2)

linear_sl = lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=1)
cos_sl = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
scheduler = lr_scheduler.SequentialLR(optimizer, schedulers=[linear_sl, cos_sl], milestones=[2])

loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)



In [8]:
import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project="NLI-Roberta-LoRA",

    # track hyperparameters and run metadata
    config={
        'learning_rate': 1e-4,
        'optimizer': 'AdamW',
        'batch_size': 64,
        'epochs': 20
    }
)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwziyi1169[0m ([33mziyiwang[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Training loop
total_best_val_acc = 0
for epoch in range(20):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = loss_fn(outputs.logits, labels)
        train_loss += loss.item()
        # loss = outputs.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step()
    
    scheduler.step()
    print(f'Epoch {epoch + 1}, Train Loss: {train_loss / len(train_loader)}')
    
    # Validation loop
    model.eval()
    val_loss = 0
    # Initialize lists to store true labels and predicted labels
    true_labels_list = []
    preds_list = []
    for batch in tqdm(val_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = loss_fn(outputs.logits, labels)
            val_loss += loss.item()

            preds = torch.argmax(outputs.logits, dim=1)

            # Append true labels and predicted labels to the lists
            true_labels_list.extend(labels.cpu().numpy())
            preds_list.extend(preds.cpu().numpy())

    # Calculate accuracy and F1-score
    # Calculate evaluation metrics
    val_acc = accuracy_score(true_labels_list, preds_list)
    f1 = f1_score(true_labels_list, preds_list, average='weighted')
    precision = precision_score(true_labels_list, preds_list, average='weighted')
    recall = recall_score(true_labels_list, preds_list, average='weighted')

    wandb.log({"Epoch":epoch + 1, 
                "Train_Loss": train_loss / len(train_loader),
                "Val_Loss": val_loss / len(val_loader),
                "Val_Acc": val_acc,
                "F1": f1,
                "Precision": precision,
                "Recall": recall
            })

    if val_acc > total_best_val_acc:
        total_best_val_acc = val_acc
        torch.save(model.state_dict(), 'roberta_lora/best_model.pth')
    print(f'Validation: Epoch {epoch + 1}, Loss: {val_loss / len(val_loader)}, Acc: {val_acc}, F1: {f1}, Precision: {precision}, Recall: {recall}')

100%|██████████| 337/337 [07:17<00:00,  1.30s/it]


Epoch 1, Train Loss: 0.6952152584709824


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation: Epoch 1, Loss: 0.6921639470493093, Acc: 0.5227314900723696, F1: 0.35889217829243614, Precision: 0.27324821071327987, Recall: 0.5227314900723696


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 2, Train Loss: 0.6270558160207392


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 2, Loss: 0.4885168313980103, Acc: 0.8055297828910745, F1: 0.8052545073313016, Precision: 0.8057039665600675, Recall: 0.8055297828910745


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 3, Train Loss: 0.5025826085038284


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 3, Loss: 0.43710631868418526, Acc: 0.8456114306921507, F1: 0.8454459611042332, Precision: 0.8457757283630717, Recall: 0.8456114306921507


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 4, Train Loss: 0.46784948754381356


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 4, Loss: 0.4226447364863227, Acc: 0.8537762107997773, F1: 0.8535497221539312, Precision: 0.8542039962594918, Recall: 0.8537762107997773


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 5, Train Loss: 0.44903793807199516


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 5, Loss: 0.4170047104358673, Acc: 0.8613843013546112, F1: 0.8609720252186983, Precision: 0.8628411657435872, Recall: 0.8613843013546112


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 6, Train Loss: 0.43815467477906


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 6, Loss: 0.40860189234509187, Acc: 0.8649100018556318, F1: 0.8648715255668282, Precision: 0.8648867829771859, Recall: 0.8649100018556318


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 7, Train Loss: 0.42838973187197565


100%|██████████| 85/85 [00:28<00:00,  3.02it/s]


Validation: Epoch 7, Loss: 0.4061557068544276, Acc: 0.8708480237520876, F1: 0.8704836531730088, Precision: 0.8722624564284023, Recall: 0.8708480237520876


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 8, Train Loss: 0.42127351723721895


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 8, Loss: 0.4018365828429951, Acc: 0.8714047133048803, F1: 0.8711481963156166, Precision: 0.8721991266533288, Recall: 0.8714047133048803


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 9, Train Loss: 0.41333380279625914


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 9, Loss: 0.3973268515923444, Acc: 0.8743737242531082, F1: 0.8741979535699185, Precision: 0.874795854890158, Recall: 0.8743737242531082


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 10, Train Loss: 0.40893154761557765


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 10, Loss: 0.3958342629320481, Acc: 0.8753015401744294, F1: 0.8752633788515324, Precision: 0.8752870981426379, Recall: 0.8753015401744294


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 11, Train Loss: 0.4072410853397245


100%|██████████| 85/85 [00:28<00:00,  3.04it/s]


Validation: Epoch 11, Loss: 0.39510438372107115, Acc: 0.8758582297272222, F1: 0.8758342931581997, Precision: 0.8758370133206729, Recall: 0.8758582297272222


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 12, Train Loss: 0.40662630057122656


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 12, Loss: 0.3948725987883175, Acc: 0.8754871033586936, F1: 0.875470343333397, Precision: 0.8754671752548672, Recall: 0.8754871033586936


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 13, Train Loss: 0.4054595755540299


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 13, Loss: 0.394584180151715, Acc: 0.8773427352013361, F1: 0.8772787836597492, Precision: 0.8773724945937686, Recall: 0.8773427352013361


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 14, Train Loss: 0.40485344300991705


100%|██████████| 85/85 [00:28<00:00,  3.04it/s]


Validation: Epoch 14, Loss: 0.39473073447451873, Acc: 0.8756726665429578, F1: 0.8755430921828143, Precision: 0.8759044336143595, Recall: 0.8756726665429578


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 15, Train Loss: 0.4075680754481862


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 15, Loss: 0.39348207256373235, Acc: 0.8764149192800148, F1: 0.8763174547053145, Precision: 0.8765341736281764, Recall: 0.8764149192800148


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 16, Train Loss: 0.4029840239785192


100%|██████████| 85/85 [00:28<00:00,  3.02it/s]


Validation: Epoch 16, Loss: 0.39517158164697536, Acc: 0.8778994247541287, F1: 0.8777185137514236, Precision: 0.8783849384827278, Recall: 0.8778994247541287


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 17, Train Loss: 0.40590131574639227


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 17, Loss: 0.38971853782148924, Acc: 0.8793839302282427, F1: 0.8792888071448264, Precision: 0.879508941312897, Recall: 0.8793839302282427


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 18, Train Loss: 0.4041649450711044


100%|██████████| 85/85 [00:28<00:00,  3.03it/s]


Validation: Epoch 18, Loss: 0.3955311003853293, Acc: 0.8734459083317869, F1: 0.8730439679322117, Precision: 0.8751762981607469, Recall: 0.8734459083317869


100%|██████████| 337/337 [07:16<00:00,  1.30s/it]


Epoch 19, Train Loss: 0.3997571692976117


100%|██████████| 85/85 [00:27<00:00,  3.04it/s]


Validation: Epoch 19, Loss: 0.3898716975660885, Acc: 0.8784561143069215, F1: 0.8784723652512237, Precision: 0.8785061183017059, Recall: 0.8784561143069215


100%|██████████| 337/337 [07:17<00:00,  1.30s/it]


Epoch 20, Train Loss: 0.4008827855926593


100%|██████████| 85/85 [00:27<00:00,  3.04it/s]

Validation: Epoch 20, Loss: 0.388645177378374, Acc: 0.8791983670439785, F1: 0.8791295195138855, Precision: 0.8792444567064556, Recall: 0.8791983670439785





In [10]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
F1,▁▇██████████████████
Precision,▁▇██████████████████
Recall,▁▇▇▇████████████████
Train_Loss,█▆▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
Val_Acc,▁▇▇▇████████████████
Val_Loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Epoch,20.0
F1,0.87913
Precision,0.87924
Recall,0.8792
Train_Loss,0.40088
Val_Acc,0.8792
Val_Loss,0.38865
