In [1]:
import wandb
import torch
import gc
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoConfig,
    BitsAndBytesConfig,
)
from peft import get_peft_model, LoraConfig
import bitsandbytes as bnb



In [2]:
def nested_dict_from_flat(flat_dict):
    nested_dict = {}
    for key, value in flat_dict.items():
        keys = key.split(".")
        d = nested_dict
        for k in keys[:-1]:
            d = d.setdefault(k, {})
        d[keys[-1]] = value
    return nested_dict


wandb.init(project="qlora_classification")
configuration = nested_dict_from_flat({k: v for k, v in wandb.config.as_dict().items()})

# Clear the GPU
torch.cuda.empty_cache()
gc.collect()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mviveknayak2210[0m ([33mhpml-nyu[0m). Use [1m`wandb login --relogin`[0m to force relogin


386

In [3]:
import pandas as pd
from torch.utils.data import Dataset
from transformers import BertTokenizer


class MultilabelDataset(Dataset):
    def __init__(self, pandas_df, tokenizer, max_length=1024):
        self.data = pandas_df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_cols = [
            "Computer Science",
            "Physics",
            "Mathematics",
            "Statistics",
            "Quantitative Biology",
            "Quantitative Finance",
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]

        # Join title and abstract as specified
        text = f"{row['TITLE']}: {row['ABSTRACT']}"

        # Tokenize the combined text without truncation
        full_inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            return_token_type_ids=True,
            truncation=False,
        )

        # Tokenize with possible truncation
        truncated_inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            return_token_type_ids=True,
            truncation=True,
        )

        # Check if text was truncated
        if len(full_inputs["input_ids"]) > self.max_length:
            truncated_tokens = full_inputs["input_ids"][self.max_length :]
            truncated_text = self.tokenizer.decode(
                truncated_tokens, skip_special_tokens=True
            )
            print(
                f"Text at index {index} was truncated. Truncated text: {truncated_text}"
            )

        # Extract the one-hot encoded labels for the given row
        labels = row[self.label_cols].values.astype(int).tolist()

        return {
            "input_ids": torch.tensor(truncated_inputs["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(
                truncated_inputs["attention_mask"], dtype=torch.long
            ),
            "token_type_ids": torch.tensor(
                truncated_inputs["token_type_ids"], dtype=torch.long
            ),
            "labels": torch.tensor(
                labels, dtype=torch.float
            ),  # Convert labels to tensor here
        }

In [4]:
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer

model_name = "facebook/opt-1.3b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Assuming you have the MultilabelDataset defined as before

# Load the CSV data
df = pd.read_csv(
    "/scratch/vgn2004/fine_tuning/datasets/paper_abstract_topic_prediction.csv"
)

# Split the data into training and validation sets (80% train, 20% validation)
train_df, val_df = train_test_split(df, test_size=0.2)

# Create datasets
train_dataset = MultilabelDataset(train_df, tokenizer)
val_dataset = MultilabelDataset(val_df, tokenizer)

# Create data loaders
batch_size = 4  # Adjust as per your requirements

train_dataloader = DataLoader(
    train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size
)

validation_dataloader = DataLoader(
    val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size
)

In [5]:
def find_lora_target_modules(model):
    """Find all linear layer names in the model. reference from qlora paper."""
    cls = bnb.nn.Linear4bit

    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            # last layer is not add to lora_module_names
            if "lm_head" in name or "score" in name:
                continue
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    return sorted(lora_module_names)


quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True,
)
configuration = AutoConfig.from_pretrained(model_name, num_labels=6)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=configuration,
    device_map="auto",
    quantization_config=quantization_config,
)

lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.4,
    bias="none",
    task_type="SEQ_CLS",
    target_modules=find_lora_target_modules(model),
)
model = get_peft_model(model, lora_config)

Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-1.3b and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
model

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): OPTForSequenceClassification(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): Embedding(50272, 2048, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (layers): ModuleList(
            (0-23): 24 x OPTDecoderLayer(
              (self_attn): OPTAttention(
                (k_proj): Linear4bit(
                  in_features=2048, out_features=2048, bias=True
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.4, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=2048, out_features=64, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=64, out_features=2048, bias=False)
   

In [None]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, hamming_loss
from transformers import get_linear_schedule_with_warmup
from bitsandbytes.optim import AdamW
from tqdm import tqdm  # Import tqdm

optimizer = AdamW(
    params=model.parameters(),
    lr=2e-4,
    is_paged=True,
    optim_bits=8,
)
loss_fn = torch.nn.BCEWithLogitsLoss()
device = "cuda"

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * 3
)


def train(epoch):
    model.train()
    total_loss = 0

    progress_bar = tqdm(
        train_dataloader, desc=f"Epoch {epoch} Training", position=0, leave=True
    )  # Wrap the dataloader with tqdm
    for step, batch in enumerate(progress_bar):
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()

        outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"])
        logits = outputs.logits

        loss = loss_fn(logits, batch["labels"].type_as(logits))
        loss.backward()

        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

        progress_bar.set_postfix(
            {"loss": f"{loss.item():.4f}"}
        )  # Update tqdm progress bar with loss

        if step > 0 and step % 50 == 0:  # Adjust the logging frequency
            print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")
        if step > 0 and step % 500 == 0:
            evaluate()

    print(f"Epoch {epoch}, Training Loss: {total_loss/len(train_dataloader):.4f}")


def evaluate():
    print(f"Evaluating...")
    LABEL_NAMES = [
        "Computer Science",
        "Physics",
        "Mathematics",
        "Statistics",
        "Quantitative Biology",
        "Quantitative Finance",
    ]
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    progress_bar = tqdm(
        validation_dataloader, desc="Validating", position=0, leave=True
    )  # Wrap the dataloader with tqdm
    with torch.no_grad():
        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"])
            logits = outputs.logits

            loss = loss_fn(logits, batch["labels"].type_as(logits))
            total_loss += loss.item()

            preds = torch.sigmoid(logits) > 0.5
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch["labels"].cpu().numpy())

    # Calculate metrics
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    print(f"ALL_PREDS: {all_preds[:5]}")
    print(f"ALL_LABELS: {all_labels[:5]}")

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(
        all_labels, all_preds, average="micro"
    )  # using micro average for multi-label
    hamming = hamming_loss(all_labels, all_preds)

    # Print metrics
    print(f"Validation Loss: {total_loss/len(validation_dataloader):.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Micro F1 Score: {f1:.4f}")
    print(f"Hamming Loss: {hamming:.4f}")

    # Print some sample predictions
    for i in range(2):  # printing 2 samples
        print("\nSample:", i + 1)
        print(
            "Predicted labels:",
            [LABEL_NAMES[j] for j in range(len(LABEL_NAMES)) if all_preds[i][j] == 1],
        )
        print(
            "Actual labels:",
            [LABEL_NAMES[j] for j in range(len(LABEL_NAMES)) if all_labels[i][j] == 1],
        )

    model.train()


# Train & Evaluate for 3 epochs as an example
num_epochs = 20
for epoch in range(num_epochs):
    train(epoch)

Epoch 0 Training:   1%|          | 51/4195 [01:14<1:39:54,  1.45s/it, loss=0.4661]

Epoch 0, Step 50, Loss: 0.4661


Epoch 0 Training:   2%|▏         | 101/4195 [02:26<1:38:45,  1.45s/it, loss=0.1499]

Epoch 0, Step 100, Loss: 0.1499


Epoch 0 Training:   4%|▎         | 151/4195 [03:38<1:37:33,  1.45s/it, loss=0.1826]

Epoch 0, Step 150, Loss: 0.1826


Epoch 0 Training:   5%|▍         | 201/4195 [04:51<1:36:33,  1.45s/it, loss=0.0842]

Epoch 0, Step 200, Loss: 0.0842


Epoch 0 Training:   6%|▌         | 251/4195 [06:03<1:35:13,  1.45s/it, loss=0.1909]

Epoch 0, Step 250, Loss: 0.1909


Epoch 0 Training:   7%|▋         | 301/4195 [07:16<1:33:59,  1.45s/it, loss=0.2163]

Epoch 0, Step 300, Loss: 0.2163


Epoch 0 Training:   8%|▊         | 351/4195 [08:28<1:32:44,  1.45s/it, loss=0.1591]

Epoch 0, Step 350, Loss: 0.1591


Epoch 0 Training:  10%|▉         | 401/4195 [09:41<1:31:34,  1.45s/it, loss=0.1525]

Epoch 0, Step 400, Loss: 0.1525


Epoch 0 Training:  11%|█         | 451/4195 [10:53<1:30:23,  1.45s/it, loss=0.1039]

Epoch 0, Step 450, Loss: 0.1039


Epoch 0 Training:  12%|█▏        | 500/4195 [12:05<1:29:07,  1.45s/it, loss=0.3169]

Epoch 0, Step 500, Loss: 0.3169
Evaluating...


Validating: 100%|██████████| 1049/1049 [09:23<00:00,  1.86it/s]
Epoch 0 Training:  12%|█▏        | 501/4195 [21:29<175:00:16, 170.55s/it, loss=0.3169]

ALL_PREDS: [[ True False False False False False]
 [ True False False False False False]
 [ True False False False False False]
 [ True False False False False False]
 [False False  True False False False]]
ALL_LABELS: [[1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]]
Validation Loss: 0.1874
Accuracy: 0.6687
Micro F1 Score: 0.8132
Hamming Loss: 0.0755

Sample: 1
Predicted labels: ['Computer Science']
Actual labels: ['Computer Science']

Sample: 2
Predicted labels: ['Computer Science']
Actual labels: ['Computer Science']


Epoch 0 Training:  13%|█▎        | 551/4195 [22:41<1:27:53,  1.45s/it, loss=0.1486]   

Epoch 0, Step 550, Loss: 0.1486


Epoch 0 Training:  14%|█▍        | 601/4195 [23:54<1:26:42,  1.45s/it, loss=0.1956]

Epoch 0, Step 600, Loss: 0.1956


Epoch 0 Training:  16%|█▌        | 651/4195 [25:06<1:25:30,  1.45s/it, loss=0.1102]

Epoch 0, Step 650, Loss: 0.1102


Epoch 0 Training:  17%|█▋        | 701/4195 [26:19<1:24:19,  1.45s/it, loss=0.0964]

Epoch 0, Step 700, Loss: 0.0964


Epoch 0 Training:  18%|█▊        | 751/4195 [27:31<1:23:05,  1.45s/it, loss=0.8945]

Epoch 0, Step 750, Loss: 0.8945


Epoch 0 Training:  19%|█▉        | 801/4195 [28:44<1:21:54,  1.45s/it, loss=0.1287]

Epoch 0, Step 800, Loss: 0.1287


Epoch 0 Training:  20%|██        | 851/4195 [29:56<1:20:40,  1.45s/it, loss=0.0449]

Epoch 0, Step 850, Loss: 0.0449


Epoch 0 Training:  21%|██▏       | 901/4195 [31:08<1:19:34,  1.45s/it, loss=0.0955]

Epoch 0, Step 900, Loss: 0.0955


Epoch 0 Training:  23%|██▎       | 951/4195 [32:21<1:18:20,  1.45s/it, loss=0.1003]

Epoch 0, Step 950, Loss: 0.1003


Epoch 0 Training:  24%|██▍       | 1000/4195 [33:33<1:17:19,  1.45s/it, loss=0.1613]

Epoch 0, Step 1000, Loss: 0.1613
Evaluating...


Validating: 100%|██████████| 1049/1049 [09:22<00:00,  1.86it/s]
Epoch 0 Training:  24%|██▍       | 1001/4195 [42:56<151:07:34, 170.34s/it, loss=0.1613]

ALL_PREDS: [[ True False False False False False]
 [ True False False False False False]
 [ True False False False False False]
 [ True False False False False False]
 [False False  True False False False]]
ALL_LABELS: [[1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]]
Validation Loss: 0.1971
Accuracy: 0.6589
Micro F1 Score: 0.7919
Hamming Loss: 0.0792

Sample: 1
Predicted labels: ['Computer Science']
Actual labels: ['Computer Science']

Sample: 2
Predicted labels: ['Computer Science']
Actual labels: ['Computer Science']


Epoch 0 Training:  25%|██▌       | 1051/4195 [44:09<1:15:54,  1.45s/it, loss=0.1001]   

Epoch 0, Step 1050, Loss: 0.1001


Epoch 0 Training:  26%|██▌       | 1101/4195 [45:21<1:14:41,  1.45s/it, loss=0.1549]

Epoch 0, Step 1100, Loss: 0.1549


Epoch 0 Training:  27%|██▋       | 1151/4195 [46:33<1:13:41,  1.45s/it, loss=0.0549]

Epoch 0, Step 1150, Loss: 0.0549


Epoch 0 Training:  29%|██▊       | 1201/4195 [47:46<1:12:18,  1.45s/it, loss=0.3044]

Epoch 0, Step 1200, Loss: 0.3044


Epoch 0 Training:  30%|██▉       | 1251/4195 [48:58<1:11:03,  1.45s/it, loss=0.1148]

Epoch 0, Step 1250, Loss: 0.1148


Epoch 0 Training:  31%|███       | 1301/4195 [50:11<1:09:54,  1.45s/it, loss=0.0555]

Epoch 0, Step 1300, Loss: 0.0555


Epoch 0 Training:  32%|███▏      | 1351/4195 [51:23<1:08:39,  1.45s/it, loss=0.5327]

Epoch 0, Step 1350, Loss: 0.5327


Epoch 0 Training:  33%|███▎      | 1401/4195 [52:36<1:07:25,  1.45s/it, loss=0.2284]

Epoch 0, Step 1400, Loss: 0.2284


Epoch 0 Training:  35%|███▍      | 1451/4195 [53:48<1:06:12,  1.45s/it, loss=0.2325]

Epoch 0, Step 1450, Loss: 0.2325


Epoch 0 Training:  36%|███▌      | 1500/4195 [55:00<1:05:02,  1.45s/it, loss=0.3650]

Epoch 0, Step 1500, Loss: 0.3650
Evaluating...


Validating: 100%|██████████| 1049/1049 [09:23<00:00,  1.86it/s]
Epoch 0 Training:  36%|███▌      | 1501/4195 [1:04:24<127:39:18, 170.59s/it, loss=0.3650]

ALL_PREDS: [[ True False False False False False]
 [ True False False False False False]
 [ True False False False False False]
 [ True False False False False False]
 [False False  True False False False]]
ALL_LABELS: [[1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]]
Validation Loss: 0.1792
Accuracy: 0.6632
Micro F1 Score: 0.8089
Hamming Loss: 0.0735

Sample: 1
Predicted labels: ['Computer Science']
Actual labels: ['Computer Science']

Sample: 2
Predicted labels: ['Computer Science']
Actual labels: ['Computer Science']


Epoch 0 Training:  37%|███▋      | 1551/4195 [1:05:37<1:03:50,  1.45s/it, loss=0.1251]   

Epoch 0, Step 1550, Loss: 0.1251


Epoch 0 Training:  38%|███▊      | 1601/4195 [1:06:49<1:02:39,  1.45s/it, loss=0.2229]

Epoch 0, Step 1600, Loss: 0.2229


Epoch 0 Training:  39%|███▉      | 1651/4195 [1:08:01<1:01:22,  1.45s/it, loss=0.0612]

Epoch 0, Step 1650, Loss: 0.0612


Epoch 0 Training:  41%|████      | 1701/4195 [1:09:14<1:00:15,  1.45s/it, loss=0.1982]

Epoch 0, Step 1700, Loss: 0.1982


Epoch 0 Training:  42%|████▏     | 1751/4195 [1:10:26<58:57,  1.45s/it, loss=0.3137]  

Epoch 0, Step 1750, Loss: 0.3137


Epoch 0 Training:  43%|████▎     | 1801/4195 [1:11:39<57:48,  1.45s/it, loss=0.2625]

Epoch 0, Step 1800, Loss: 0.2625


Epoch 0 Training:  44%|████▍     | 1851/4195 [1:12:51<56:41,  1.45s/it, loss=0.1335]

Epoch 0, Step 1850, Loss: 0.1335


Epoch 0 Training:  45%|████▌     | 1901/4195 [1:14:04<55:20,  1.45s/it, loss=0.1459]

Epoch 0, Step 1900, Loss: 0.1459


Epoch 0 Training:  47%|████▋     | 1951/4195 [1:15:16<54:12,  1.45s/it, loss=0.1074]

Epoch 0, Step 1950, Loss: 0.1074


Epoch 0 Training:  48%|████▊     | 2000/4195 [1:16:29<52:57,  1.45s/it, loss=0.1125]

Epoch 0, Step 2000, Loss: 0.1125
Evaluating...


Validating:   9%|▉         | 94/1049 [00:50<08:33,  1.86it/s]