In [1]:
from utils import tokenize, to_cuda, Logger, plot_results, HuggingMetric, freeze_model

import numpy as np

import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import datasets

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from tqdm.notebook import tqdm
import pickle

from IPython.display import clear_output
from typing import Tuple, List

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

%load_ext autoreload
%autoreload 2

# Global Params

In [2]:
MAX_LEN = 128


# Train Loop 

In [3]:
class ColaDataset(Dataset):
    def __init__(self, data) -> None:
        self.input_ids = torch.tensor(data['input_ids']).squeeze()
        self.attention_mask = torch.tensor(data['attention_mask']).squeeze()
        self.labels = torch.tensor(data['label']).squeeze()
        
    def __len__(self) -> int:
        return self.input_ids.shape[0]
    
    def __getitem__(self, idx: int) -> dict:
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx]
        }

In [4]:
def train_epoch(
    model: nn.Module, 
    optim: torch.optim, 
    train_loader: DataLoader, 
    valid_loader: DataLoader,
    logger: Logger,
    scheduler: torch.optim.lr_scheduler=None,
    log_step: int=None,
    valid_step: int=None,
    N_cycles: int=0,
    writer = None,
    epoch: int = None
)-> None:
    writer.add_scalar('N_cycles', N_cycles[0], epoch)
    
    if log_step is None:
        log_step = train_loader.__len__() // 20
    
    if valid_step is None:
        valid_step = train_loader.__len__() // 6
    
    train_metric, train_loss = [], []
        
    global_step = epoch * train_loader.__len__()
    for step, batch in tqdm(enumerate(train_loader)):
        model.train()
        batch = to_cuda(batch)
        out = model(**batch, N_cycles=[*N_cycles, global_step])
        
        loss = out['loss']
        loss.backward()
        
        optim.step()
        scheduler.step()
        
        optim.zero_grad()
        
        writer.add_scalar('LR', scheduler.get_last_lr()[0], global_step)
        writer.add_scalar('Epoch', epoch, global_step)

        train_metric.append(
            metric(
                out['logits'].detach().cpu(),
                batch['labels'].cpu()
            )
        )
        train_loss.append(loss.detach().cpu().numpy())
        global_step += 1
        
        if step % log_step == 0:
            logger.log({
                'train_loss': np.mean(train_loss),
                'train_metric': np.mean(train_metric)
            })
            writer.add_scalar('Train/loss', np.mean(train_loss), global_step)
            writer.add_scalar('Train/corr', np.mean(train_metric), global_step)

            
            train_metric, train_loss = [], []
                    
        if step % valid_step == 0:
            with torch.no_grad():
                model.eval()
                metrics, losses = [], []
                for batch in tqdm(valid_loader):
                    batch = to_cuda(batch)
                    out = model(**batch, N_cycles=[*N_cycles, global_step])
                    
                    loss = out['loss']
                    
                    metrics.append(
                        metric(
                            out['logits'].detach().cpu(),
                            batch['labels'].cpu()
                            
                        )
                    )
                    losses.append(loss.detach().cpu().numpy())
                
                logger.log({
                        'test_loss': np.mean(losses),
                        'test_metric': np.mean(metrics)
                })
                writer.add_scalar('Test/loss', np.mean(losses), global_step)
                writer.add_scalar('Test/corr', np.mean(metrics), global_step)

                print('Step=', step, ' test_loss=', np.mean(losses), ' test_metric=', np.mean(metrics))
                model.train()
                              


In [7]:
from IterativeBert import BertForSequenceClassificationOur


def prepare_model(model_name: str, strategy: str, bs: int=32, drop: float=0.1):
    if 'distilbert' in model_name:
        dropout = {'dropout': drop}
    else:
        dropout = {'hidden_dropout_prob': drop}
    
    model = BertForSequenceClassificationOur.from_pretrained('bert-base-uncased', strategy=strategy).cuda()
    
    # Example in PyTorch
    for param in model.parameters():
        param.requires_grad = True

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    dataset_train_tokenized = dataset_train.map(
        tokenize, fn_kwargs={'tokenizer': tokenizer, 'MAX_LEN': MAX_LEN}, num_proc=2,
    )
    dataset_valid_tokenized = dataset_valid.map(
        tokenize, fn_kwargs={'tokenizer': tokenizer, 'MAX_LEN': MAX_LEN}, num_proc=2,
    )
    
    train_data = ColaDataset(dataset_train_tokenized)
    train_loader = DataLoader(train_data, batch_size=bs, num_workers=8, shuffle=True, pin_memory=True, drop_last=True)

    valid_data = ColaDataset(dataset_valid_tokenized)
    valid_loader = DataLoader(valid_data, batch_size=bs * 2, num_workers=8, pin_memory=True)
    
    return model, train_loader, valid_loader

In [8]:
dataset_train = datasets.load_dataset('glue', 'cola', split='train')
dataset_valid = datasets.load_dataset('glue', 'cola', split='validation')

matthew = datasets.load_metric('matthews_correlation')

# Expirement

In [13]:
from torch.utils.tensorboard import SummaryWriter

NUM_EPOCH = 5

distilbert_loggers = dict()

def schedule(epoch):
    s = [10] * 7
    return s[epoch]
    # elif 5 <= epoch:
    #     return 3


for lr in [2e-5]:
    exp_path = f'./runs/lr={lr}_testrun_scheduler={"".join([str(schedule(i)) for i in range(NUM_EPOCH)])}'
    writer = SummaryWriter(log_dir=exp_path)

    model, train_loader, valid_loader = prepare_model("bert-base-uncased", strategy='last_update', bs=64)

    weight_decay = 1e-3
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optim, 
                                                step_size=5000, 
                                                gamma=0.9)
    metric = HuggingMetric(matthew)
    logger_distilbert = Logger(f'distilbert lr = {lr}')
    for epoch in range(NUM_EPOCH):
        print(f'Epoch {epoch} started...')
        train_epoch(
            model,
            optim,
            train_loader,
            valid_loader,
            logger_distilbert,
            scheduler=scheduler,
            N_cycles=[schedule(epoch), writer],
            writer=writer,
            epoch=epoch
        )
    
        
        distilbert_loggers[lr] = logger_distilbert
        # plot_results(distilbert_loggers.values())

Some weights of BertForSequenceClassificationOur were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 0 started...


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Step= 0  test_loss= 0.72210056  test_metric= 0.00868297279821629


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 22  test_loss= 0.61795557  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 44  test_loss= 0.6074059  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 66  test_loss= 0.61971873  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 88  test_loss= 0.61476374  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 110  test_loss= 0.6109324  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 132  test_loss= 0.6128113  test_metric= 0.0
Epoch 1 started...


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Step= 0  test_loss= 0.6110126  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 22  test_loss= 0.60744214  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 44  test_loss= 0.6080037  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 66  test_loss= 0.6036393  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 88  test_loss= 0.6056828  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 110  test_loss= 0.608114  test_metric= 0.015173476154046314


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 132  test_loss= 0.6120595  test_metric= 0.0
Epoch 2 started...


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Step= 0  test_loss= 0.608235  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 22  test_loss= 0.60759836  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 44  test_loss= 0.60510373  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 66  test_loss= 0.6128465  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 88  test_loss= 0.6094266  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 110  test_loss= 0.61106426  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

Step= 132  test_loss= 0.6095667  test_metric= 0.0
Epoch 3 started...


0it [00:00, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Step= 0  test_loss= 0.608668  test_metric= 0.0


  0%|          | 0/9 [00:00<?, ?it/s]

KeyboardInterrupt: 