In [1]:
import ruamel.yaml as yaml
import argparse
from pathlib import Path
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset, SequentialSampler
from transformers import AutoTokenizer
from datasets import load_dataset

from src.model import DiffNetwork
from src.training_logger import TrainLogger
from src.metrics import accuracy

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [3]:
def get_ds(tokenizer, max_length: int = 128) -> "DatasetDict":
    """Load GLUE SST-2 and tokenize every split with ``tokenizer``.

    Args:
        tokenizer: a Hugging Face tokenizer; it is called in batched mode on
            the ``"sentence"`` column of each split.
        max_length: length every example is padded / truncated to. Defaults
            to 128, the value that was previously hard-coded.

    Returns:
        The ``DatasetDict`` produced by ``load_dataset(...).map(...)``, i.e.
        the raw splits with the tokenizer's output columns added. This is not
        a ``TensorDataset``; use ``get_ds_part`` to convert a single split.
    """
    ds = load_dataset("glue", "sst2", cache_dir="cache")
    return ds.map(
        lambda batch: tokenizer(
            batch["sentence"],
            padding="max_length",
            max_length=max_length,
            truncation=True,
        ),
        batched=True,
        # Always recompute the tokenized columns instead of reusing a cached
        # result of a previous map call.
        load_from_cache_file=False,
        desc="Running tokenizer on dataset",
    )
    

def get_ds_part(ds, part) -> TensorDataset:
    """Convert one tokenized split of ``ds`` into a ``TensorDataset``.

    The tensors are ordered (input_ids, token_type_ids, attention_mask,
    label). The three tokenizer columns are int64; the label is cast to
    float so it can be fed directly to the BCE-with-logits loss used in
    this notebook.
    """
    split = ds[part]
    token_columns = ("input_ids", "token_type_ids", "attention_mask")
    tensors = [torch.tensor(split[column], dtype=torch.long) for column in token_columns]
    tensors.append(torch.tensor(split["label"], dtype=torch.float))
    return TensorDataset(*tensors)

def batch_fn(batch):
    """Collate a list of (input_ids, token_type_ids, attention_mask, label) samples.

    Each field is stacked across the batch. Returns ``(model_inputs, labels)``
    where ``model_inputs`` maps the three tokenizer field names to their
    stacked tensors and ``labels`` is the stacked label tensor.
    """
    stacked = (torch.stack(column) for column in zip(*batch))
    input_ids, token_type_ids, attention_mask, labels = stacked
    model_inputs = dict(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
    )
    return model_inputs, labels

In [4]:
# Load hyperparameters from cfg.yml. ruamel.yaml's module-level `safe_load`
# was deprecated and has been removed in ruamel.yaml 0.18, so use the
# YAML-object API with the safe loader instead.
with open("cfg.yml", "r") as f:
    cfg = yaml.YAML(typ="safe").load(f)
# Expose the `train_config` section as attributes (args.batch_size, ...).
args = argparse.Namespace(**cfg["train_config"])

In [15]:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(args.model_name)
ds = get_ds(tokenizer)

# Build the loss module once instead of on every loss_fn call.
bce_with_logits = torch.nn.BCEWithLogitsLoss()


def pred_fn(y_hat):
    """Map raw logits to hard 0/1 predictions (sigmoid > 0.5).

    The result is flattened exactly like the logits in ``loss_fn``. A
    single-output head (as ``DiffNetwork(1, ...)`` suggests) presumably
    emits ``(batch, 1)``; flattening makes the predictions line up
    element-wise with the ``(batch,)`` label vector built by ``batch_fn``
    instead of relying on ``accuracy`` to handle the extra dimension.
    """
    return (torch.sigmoid(y_hat) > .5).long().flatten()


def loss_fn(y_hat, y):
    """Binary cross-entropy between flattened logits and float 0/1 labels."""
    return bce_with_logits(y_hat.flatten(), y)


metrics = {
    "acc": lambda x, y: accuracy(pred_fn(x), y),
    "balanced_acc": lambda x, y: accuracy(pred_fn(x), y, balanced=True),
}

# NOTE(review): no RNG seed is set anywhere visible, so the RandomSampler
# shuffle order differs between runs.
ds_train = get_ds_part(ds, "train")
train_loader = DataLoader(ds_train, sampler=RandomSampler(ds_train), batch_size=args.batch_size, collate_fn=batch_fn)
ds_eval = get_ds_part(ds, "validation")
eval_loader = DataLoader(ds_eval, sampler=SequentialSampler(ds_eval), batch_size=args.batch_size, collate_fn=batch_fn)

Reusing dataset glue (cache/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Running tokenizer on dataset:   0%|          | 0/68 [00:00<?, ?ba/s]

Running tokenizer on dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Running tokenizer on dataset:   0%|          | 0/2 [00:00<?, ?ba/s]

In [21]:
# Run name: "diff_pruning_<checkpoint basename>_<batch size>_<learning rate>".
name_parts = (
    "diff_pruning",
    args.model_name.split('/')[-1],
    args.batch_size,
    args.learning_rate,
)
logger_name = "_".join(map(str, name_parts))

train_logger = TrainLogger(
    log_dir=args.log_dir,
    logger_name=logger_name,
    logging_step=args.logging_step,
)

# One output unit, presumably a single logit for binary SST-2 (matches the
# BCE-with-logits loss defined earlier) — TODO confirm against DiffNetwork.
trainer = DiffNetwork(1, args.model_name)
trainer.to(DEVICE)

Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DiffNetwork(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=Tru

In [22]:
# DiffNetwork.fit receives everything positionally, so the order of this
# tuple must match its signature exactly.
# TODO(review): switch to keyword arguments once fit's parameter names are
# confirmed; a silent reorder here would swap hyperparameters.
fit_args = (
    train_loader,
    eval_loader,
    train_logger,
    loss_fn,
    metrics,
    args.num_epochs_finetune,
    args.num_epochs_fixmask,
    args.diff_pruning,
    args.alpha_init,
    args.concrete_lower,
    args.concrete_upper,
    args.structured_diff_pruning,
    args.sparsity_pen,
    args.fixmask_pct,
    args.weight_decay,
    args.learning_rate,
    args.learning_rate_alpha,
    args.adam_epsilon,
    args.warmup_steps,
    args.gradient_accumulation_steps,
    args.max_grad_norm,
    args.output_dir,
)
trainer.fit(*fit_args)

Epoch 0, model_state: ModelState.DIFFPRUNING, :   0%|          | 0/8 [00:00<?, ?it/s]
training - step 0, loss:     nan, loss without l0 pen:     nan:   0%|          | 0/1053 [00:00<?, ?it/s][A
training - step 0, loss: 1.32117, loss without l0 pen: 0.77659:   0%|          | 0/1053 [00:00<?, ?it/s][A
training - step 0, loss: 1.32117, loss without l0 pen: 0.77659:   0%|          | 1/1053 [00:00<10:30,  1.67it/s][A
training - step 1, loss: 1.24439, loss without l0 pen: 0.69982:   0%|          | 1/1053 [00:01<10:30,  1.67it/s][A
training - step 1, loss: 1.24439, loss without l0 pen: 0.69982:   0%|          | 2/1053 [00:01<09:57,  1.76it/s][A
training - step 2, loss: 1.24618, loss without l0 pen: 0.70161:   0%|          | 2/1053 [00:01<09:57,  1.76it/s][A
training - step 2, loss: 1.24618, loss without l0 pen: 0.70161:   0%|          | 3/1053 [00:01<09:31,  1.84it/s][A
training - step 3, loss: 1.23201, loss without l0 pen: 0.68744:   0%|          | 3/1053 [00:02<09:31,  1.84it/s][A
tr

KeyboardInterrupt: 