# Bert Partition

## Imports

In [1]:
# ! pip install transformers[torch] datasets evaluate wandb minio tqdm scipy

In [2]:
import datasets
import pickle as pkl

from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    BertForMaskedLM,
    get_scheduler,
    TrainingArguments,
    Trainer
)
import evaluate
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD
# from maskedtensor import masked_tensor

from tqdm.auto import trange, tqdm
import pandas as pd
from collections import defaultdict
import io
from PairsDataset import PairsDataset

import wandb

from IPython.display import clear_output
import logging
import ema_swa_utils

### Runtime parameters

In [3]:
TRY_NAME = "naive_cosine_with_pretrained_bert_with_ema_0.999"

In [4]:
SEQ_LEN = 64
BATCH_SIZE = 16
MLM_PROB = 0.15

#DATA_PATH = '/content/drive/MyDrive/nnlp/bert/biblioteka_prikluchenij_both_agr.csv'
DATA_PATH = "data/train_dataset.csv"
TEST_PATH = "data/tda_test.csv"
MODEL_NAME = 'DeepPavlov/rubert-base-cased'
WEIGHTS_PATH = "ckpt/pretrained_bert_epoch_9.999976796259556.pt"

USE_SWA = False

USE_EMA = True
EMA_DECAY = 0.999

# whether to log the layers being changed (happens once per notebook restart)
LOG_LAYERS = True

In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [6]:
# number of passes over the training DataLoader
num_epochs = 15
# initial SGD learning rate (decayed by the MultiStepLR scheduler below)
learning_rate = 5e-4

# MLM updates per cycle; only enters the num_training_steps estimate
n_mlm = 1
# cycle length in batches: one MLM backward (and one scheduler step) every n_cosine batches
n_cosine = 10
# encoder layer made trainable; `train` is called with division_layer + 1
division_layer = 3
# weight of the saved MLM gradient; divided by n_cosine before training
weight_mlm = 1
# zero out hidden states of tokens whose input ids are identical in both texts before pooling
MEAN_OVER_CHANGED = True

In [7]:
DESCRIPTION = \
f'''
Model: based on rubert, additionally pretrained for 10 epochs;
Checkpoint: {WEIGHTS_PATH};
Context: {SEQ_LEN};
Batch size: {BATCH_SIZE};

Loss: classic cosine embedding with steps along the MLM gradient;
Loss only over diff in tokenization: {MEAN_OVER_CHANGED};
N_MLM: {n_mlm};
N_Cosine: {n_cosine};
Division_layer: {division_layer};
weight_mlm: {weight_mlm};
weight_cosine: polynomial decay, sum to 1;

LR_SCHEDULER: MultiStep;
Initial learning rate: {learning_rate};
Steps: 5, 
Decay: 0.5,
Epochs: {num_epochs};

Additional parameters and notes:
EMA: {USE_EMA};
EMA_DECAY: {EMA_DECAY};

SWA: {USE_SWA};

'''

### Logging

In [8]:
logging.basicConfig(filename=f"logs/{TRY_NAME}_grad_log", filemode="w")
logger = logging.getLogger(__name__)
logger.setLevel("INFO")

In [9]:
# minio handler to use remote data -- implements get and put methods with pickling option (view file)

from MinioHandler import MinioHandler

minio = MinioHandler()

In [10]:
# from google.colab import drive
# drive.mount('/content/drive')

In [11]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mxenomirant[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [12]:
wandb.init(
    project='grammar-bert-model1',
    entity='grammar-bert'
)

[34m[1mwandb[0m: Currently logged in as: [33mxenomirant[0m ([33mgrammar-bert[0m). Use [1m`wandb login --relogin`[0m to force relogin


### Data Preparation

In [13]:
# df = pd.read_csv(DATA_PATH, index_col = 0)
# df = df.drop(columns=['Unnamed: 0'])

In [14]:
# df = df[df.was_changed].reset_index(drop=True)

In [15]:
# df

In [16]:
# tqdm.pandas()

# idx_init = df.initial.progress_apply(lambda x: x.replace(' ', ''))
# idx_pol = df.polypers.progress_apply(lambda x: x.replace(' ', ''))
# idx = -(idx_init == idx_pol)
# df['was_changed'] = idx

In [17]:
# df.to_csv(DATA_PATH, index=False)

### Train test splitting

In [18]:
# from sklearn.model_selection import train_test_split

# TEST_SIZE = 0.1

In [19]:
# df = pd.read_csv(DATA_PATH, index_col = 0)

In [20]:
# df

In [21]:
# train, test = train_test_split(df, test_size=TEST_SIZE, stratify = df["was_changed"])

In [22]:
# train.to_csv("data/train_bpa.csv")
# test.to_csv("data/test_bpa.csv")

### Pick items from test for TDA and homology computation

In [23]:
# df = pd.read_csv("data/test_bpa.csv", index_col = 0)

# df = df[df.was_changed]

In [24]:
# tda_data = df.sample(n = 250, random_state=42)

In [25]:
# tda_data.to_csv("tda_test.csv")

In [26]:
# minio.put_object(tda_data, save_name="data/tda_test.pkl", pickle=True)

In [27]:
# put everything to minio -- also possible to use default minio functions from Minio class

# minio.minio.fput_object(file_path="data/test_dataset.csv", bucket_name="public",
#                       object_name="ModularLM/data/test_dataset.csv")

## Dataset and collator

In [28]:
def collate_func(batch):
    """Collate paired samples into one MLM-collated batch per pair position.

    `batch` is a sequence of per-sample tuples. `zip(*batch)` regroups it so
    that all first elements form one group, all second elements another, and
    so on. Each group is passed through the module-level `data_collator`
    (`torch_call`). The collated groups are returned as a list, in order.
    """
    collated = []
    for column in zip(*batch):
        collated.append(data_collator.torch_call(column))
    return collated

In [29]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

tokenizer.pad_token = '[SEP]'
tokenizer.eos_token = '[SEP]'
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=MLM_PROB)

## MLM Train

In [30]:
# dt = datasets.Dataset.from_csv(DATA_PATH)
# dt = dt.remove_columns(['polypers', 'was_changed']).rename_column('initial', 'text')

In [31]:
# N_samples = 10**5

In [32]:
# def tokenize_function(example):
#     return tokenizer(example['text'], truncation=True)

# tok_dt = dt.select(range(N_samples)).map(tokenize_function, batched=True)
# tok_dt = tok_dt.train_test_split(test_size=100,
#                          shuffle=True,
#                          seed=42)

In [33]:
# training_args = TrainingArguments(
#     report_to = 'wandb',
#     output_dir='part1-model',
#     learning_rate=1e-3,
#     per_device_train_batch_size=16,
#     num_train_epochs=1,
#     # evaluation_strategy='steps',
#     # eval_steps=20,
#     logging_steps=20,
#     logging_first_step=True
# )

In [34]:
# model = BertForMaskedLM.from_pretrained(MODEL_NAME)
# model.to(device)
# pass

In [35]:
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=tok_dt['train'],
#     # eval_dataset=tok_dt['test'],
#     tokenizer=tokenizer,
#     data_collator=data_collator
# )

In [36]:
# trainer.train()

## Model

In [37]:
def save_gradients(model, division_layer):
    """Snapshot gradients of trainable parameters preceding a given encoder layer.

    Walks `model.named_parameters()` in order and stops at the first name that
    starts with ``bert.encoder.layer.{division_layer}``. Every earlier
    parameter with ``requires_grad`` set and a non-None ``.grad`` contributes a
    detached copy of its gradient.

    Args:
        model: module whose parameters are scanned.
        division_layer: encoder-layer index at which scanning stops. `train`
            forwards the value it was called with (division_layer + 1 in
            this notebook).

    Returns:
        dict mapping parameter name -> cloned gradient tensor.
    """
    stop_prefix = f'bert.encoder.layer.{division_layer}'
    snapshot = {}
    for param_name, tensor in model.named_parameters():
        if param_name.startswith(stop_prefix):
            break
        # skip frozen parameters and those that received no gradient
        if not tensor.requires_grad or tensor.grad is None:
            continue
        snapshot[param_name] = tensor.grad.detach().clone()
    if LOG_LAYERS:
        logger.info(f"Saved layers: {str(snapshot.keys())}")
    return snapshot

In [38]:
def change_gradients(*, model, layers, 
                     division_layer,
                     weight_mlm=0.5, 
                     weight_cos=1, ):
    """Blend current (cosine) gradients with previously saved MLM gradients.

    For every parameter listed in `layers` that precedes the first parameter
    named ``bert.encoder.layer.{division_layer}...``, the gradient becomes
    ``weight_cos * param.grad + weight_mlm * layers[name]``. If the parameter
    has no current gradient, only the MLM term is used.

    Args:
        model: module whose ``.grad`` tensors are rewritten in place.
        layers: dict name -> saved MLM gradient (as returned by `save_gradients`).
        division_layer: encoder-layer index at which the scan stops.
        weight_mlm: multiplier of the saved MLM gradient.
        weight_cos: multiplier of the current gradient. Any object supporting
            ``weight_cos * tensor`` works, e.g. `CosWeightSum2One`. It is
            applied exactly once per blended parameter.

    Side effects:
        Sets the module-level ``LOG_LAYERS`` to False, so per-layer logging
        happens only on the first call.
    """
    global LOG_LAYERS

    stop_prefix = f'bert.encoder.layer.{division_layer}'
    for name, param in model.named_parameters():
        # division layer passed == division layer + 1 as is inside train
        if name.startswith(stop_prefix):
            break
        if name not in layers:
            continue
        # compute each contribution once: the log below must report the
        # pre-blend terms, and stateful weights must not be multiplied twice
        cos_grad = param.grad
        mlm_part = weight_mlm * layers[name]
        if cos_grad is None:
            # no cosine gradient reached this parameter -> MLM term only
            cos_part = torch.zeros_like(mlm_part)
        else:
            cos_part = weight_cos * cos_grad
        param.grad = cos_part + mlm_part
        if LOG_LAYERS:
            logger.info(f"Changed layer: {name}")
            logger.info(f"gradients changed. {cos_part.norm(), mlm_part.norm()}\n")
    LOG_LAYERS = False

In [39]:
class CosWeightDecay:
    '''
    Multiplicative cosine-loss weight that shrinks geometrically with every use.

    Each ``weight * x`` returns ``current * x`` and afterwards multiplies
    ``current`` by `step`. Successive products therefore use init_state,
    init_state * step, init_state * step**2, ...
    '''
    def __init__(self, init_state=1, step=0.5):
        self.cur_state, self.step = init_state, step

    def __mul__(self, other):
        # multiply first, then decay, so a failing multiplication leaves the state untouched
        product = self.cur_state * other
        self.cur_state = self.cur_state * self.step
        return product

    def __repr__(self):
        # shows the factor the *next* multiplication will use
        return str(self.cur_state)

In [40]:
class CosWeightSum2One:
    '''
    Step-indexed weight schedule for the cosine loss within one MLM cycle.

    The schedule holds `steps` weights:
      * default: ``init_coef * k**-1.5`` for k = 2, ..., steps + 1 (polynomial
        decay). With init_coef=1 and steps=10 these sum to about 1.02.
      * ``linear=True``: `steps` copies of `init_coef`.

    Usage: call `reset()` at the start of a cycle and `step()` before each use.
    ``weight * x`` and the `weight` property then use the current step's value.
    The counter starts at -1, so before the first `step()` the *last* weight
    is returned. Stepping past the end of the schedule holds the last weight.
    '''
    def __init__(self, init_coef = 100, steps: int = 10, linear = False):
        # -1 == "no step taken yet"; step() must be called before the first use
        self.counter = -1
        if linear:
            self.steps = init_coef * np.ones(steps)
        else:
            self.steps = init_coef * np.arange(2, steps + 2) ** -1.5

    def __mul__(self, other):
        return self.weight * other

    def step(self):
        # clamp: cycles longer than the schedule keep using the final weight
        # instead of crashing mid-training with an IndexError
        self.counter = min(self.counter + 1, len(self.steps) - 1)
        return None

    def __repr__(self):
        return str(self.weight)

    def reset(self):
        self.counter = -1
        return None

    @property
    def weight(self):
        return self.steps[self.counter]

In [41]:
class CosLoss:
    '''
    Cosine embedding loss with an optional directional regularizer.

        loss = CosineEmbeddingLoss(hid_ref, hid_cur, target)
               + alpha * CosineEmbeddingLoss(vector, hid_ref - hid_cur, +1)

    The second term is added only when `vector` is given. With target +1 it
    is smallest when the difference ``hid_ref - hid_cur`` points along `vector`.
    '''
    def __init__(self, vector=None, alpha=0):
        self.loss = nn.CosineEmbeddingLoss()
        self.alpha = alpha
        self.vector = vector

    def __call__(self, hid_ref, hid_cur, target):
        cos_loss = self.loss(hid_ref, hid_cur, target)
        if self.vector is not None:
            diff = hid_ref - hid_cur
            # build the +1 targets per call so they match the actual batch size
            # and device (no dependency on a global model or BATCH_SIZE)
            ones = torch.ones(diff.shape[0], device=diff.device, dtype=diff.dtype)
            cos_loss = cos_loss + self.alpha * self.loss(self.vector.to(diff.device), diff, ones)
        return cos_loss

In [42]:
def train(model, criteria, optimizer, lr_scheduler, data, hom_data, n_epochs=1,
          n_cosine=10, division_layer=4, weight_mlm=1,
          weight_cos=1, save_every_epoch=3, test_every=5000):
    """Alternate MLM and cosine-embedding updates on paired (base, polypers) batches.

    Every batch from `data` is a pair ``(batch[0], batch[1])`` of collated
    model inputs for the base and the polypers text. For batch index ``i``
    within an epoch:

    * ``i % test_every == 0``: pooled hidden states for every pair in
      `hom_data` are stored in the global `tda_save_dict` and uploaded via
      the global `minio` handler.
    * ``i % n_cosine == 0``: the MLM loss on the base text is back-propagated.
      Its gradients are snapshotted with `save_gradients`, and the global
      `avg_model` is updated when USE_EMA/USE_SWA is set. Then `weight_cos`
      is reset and `lr_scheduler` is stepped.
    * always: `criteria` is applied to the sequence-mean hidden states of both
      texts at index `division_layer`. The resulting gradients are blended
      with the MLM snapshot by `change_gradients`, and `optimizer` steps.

    Args:
        model: masked-LM model whose outputs expose ``.loss`` and ``.hidden_states``.
        criteria: callable ``(hid_ref, hid_cur, target) -> loss``; target is -1
            for every sample.
        optimizer: torch optimizer, stepped every batch.
        lr_scheduler: torch scheduler, stepped once per `n_cosine` batches.
        data: DataLoader of training pairs (batches assumed to have BATCH_SIZE rows).
        hom_data: DataLoader of pairs used for TDA feature extraction.
        n_epochs: number of passes over `data`.
        n_cosine: cycle length (in batches) between MLM backward passes.
        division_layer: index into ``hidden_states``. Also the
            ``bert.encoder.layer.{k}`` prefix where gradient snapshot/blending
            stops. This notebook passes the trained layer's index + 1.
        weight_mlm: multiplier of the MLM gradient in the blend.
        weight_cos: cosine weight schedule with ``reset()``, ``step()``,
            ``.weight`` and ``*`` (e.g. `CosWeightSum2One`).
        save_every_epoch: checkpoint period in epochs (the last epoch is
            always checkpointed).
        test_every: period (in batches) of TDA feature extraction.

    Side effects: uploads TDA features and checkpoints to MinIO, logs to
    wandb, and leaves `model` in eval mode.
    """

    # global mlm_losses, cosine_losses
    # change global loss tracking to local only -- for now it seems unnecessary
    global tda_save_dict

    tq_epoch = trange(n_epochs, desc='Epochs: ')
    tq_batch = tqdm(total=len(data))

    # target for cosine loss: -1 for every pair (loaders use drop_last=True)
    target = -torch.ones(BATCH_SIZE).to(model.device)
    grads = None

    # just initialization -- first few batches make no difference for tracking
    cos_loss = 0
    mlm_loss = 0
    hom_computed = 0
    # TODO -- optimize for gradual saving of dict intead of accumulation

    def save_tda_features(hom_data: torch.utils.data.DataLoader, save_dict: dict):
        """Store mean-pooled hidden states at `division_layer` for all TDA pairs.

        Runs `model` without gradients in eval mode over `hom_data`. The
        per-sample sequence means for the base (``batch[0]``) and polypers
        (``batch[1]``) texts are written to ``save_dict[hom_computed]`` under
        "base" and "polypers". The caller switches the model back to train mode.
        """
        nonlocal hom_computed
        hom_computed += 1
        base = []
        polypers = []

        with torch.no_grad():
            model.eval()
            for i, batch in tqdm(enumerate(hom_data)):
                # base embeddings after layer
                pred_base = model(**{k: v.to(model.device) for k, v in batch[0].items()},
                                  output_hidden_states=True)
                hid_ref = torch.mean(pred_base.hidden_states[division_layer], dim=1)
                base.extend(hid_ref.detach().cpu().numpy())

                # polypers embeddings after layer
                pred_new = model(**{k: v.to(model.device) for k, v in batch[1].items()},
                                 output_hidden_states=True)
                hid_cur = torch.mean(pred_new.hidden_states[division_layer], dim=1)
                polypers.extend(hid_cur.detach().cpu().numpy())

        save_dict[hom_computed]["base"] = base
        save_dict[hom_computed]["polypers"] = polypers
        return save_dict

    def gradient_norm():
        """L2 norm of all currently populated parameter gradients, concatenated."""
        grads = [
            param.grad.detach().flatten()
            for param in model.parameters()
            if param.grad is not None
        ]
        norm = torch.cat(grads).norm()
        return norm

    #########################################################
    # training loop 
    #########################################################
    for epoch in tq_epoch:
        tq_batch.reset()
        # running windows, seeded with the last smoothed value of the previous epoch
        cosine_losses = [cos_loss]
        mlm_losses = [mlm_loss]

        for i, batch in enumerate(data):
            # save data for TDA
            if i % test_every == 0:
                tda_save_dict = save_tda_features(hom_data, tda_save_dict)
                # also a point of optimization
                minio.put_object(tda_save_dict, 
                                 save_name=f"data/TDA_FEATURES/tda_save_dict_{TRY_NAME}.pkl", 
                                 pickle=True)
                model.train()
            # pred on base text
            pred = model(**{k: v.to(model.device) for k, v in batch[0].items()},
                         output_hidden_states=True, )

            # once every n_cosine steps, compute the MLM gradient and keep it aside
            if i % n_cosine == 0:
                # retain the graph: hid_ref below comes from this same forward pass
                pred.loss.backward(retain_graph=True)
                grads = save_gradients(
                    model=model, 
                    division_layer=division_layer
                )
                mlm_grad_norm = gradient_norm()

                if USE_EMA or USE_SWA:
                    avg_model.update_parameters(model)

                optimizer.zero_grad()

                mlm_losses.append(pred.loss.detach().cpu())
                # reset cosine weight every cycle
                weight_cos.reset()
                # take lr_step also every cycle
                lr_scheduler.step()

            # compute cosine anyway
            # pred on polypers text
            pred_new = model(**{k: v.to(model.device) for k, v in batch[1].items()},
                             output_hidden_states=True)

            hid_ref = pred.hidden_states[division_layer]
            hid_cur = pred_new.hidden_states[division_layer]

            # look for changed ids only
            if MEAN_OVER_CHANGED:
                # (batch, seq, 1) mask broadcasts over the hidden dimension, so the
                # hidden size does not have to be hard-coded
                changed = (batch[0]["input_ids"] != batch[1]["input_ids"]).unsqueeze(-1).to(model.device)
                # masked tensors don't support loss calculations as filling with 0 stops differentiation
                hid_ref = hid_ref * changed
                hid_cur = hid_cur * changed

            hid_ref = torch.mean(hid_ref, dim=1)
            hid_cur = torch.mean(hid_cur, dim=1)

            cos_loss = criteria(hid_ref, hid_cur, target)
            cos_loss.backward()
            cos_grad_norm = gradient_norm()

            # as indexing in weight_cos starts from -1, take a step before applying
            weight_cos.step()

            change_gradients(model=model, 
                             layers=grads, 
                             division_layer=division_layer,
                             weight_mlm=weight_mlm, 
                             weight_cos=weight_cos)

            optimizer.step()
            optimizer.zero_grad()

            cosine_losses.append(cos_loss.detach().cpu())

            # smoothed losses: mean over the last 30 recorded values
            cos_loss = (sum(cosine_losses[-30:]) / len(cosine_losses[-30:])).item()
            mlm_loss = (sum(mlm_losses[-30:]) / len(mlm_losses[-30:])).item()

            # TODO: possibly send not every batch -- however, it acts asynchrously, so doesn't seem to make much difference
            wandb.log({"MLM loss": mlm_loss,
                       "Cosine loss": cos_loss,
                       "MLM grad norm": mlm_grad_norm,
                       "Cos grad norm": cos_grad_norm,
                       "Cos_Weight": weight_cos.weight,
                       "learning_rate": lr_scheduler.get_last_lr()[0]
                      })
            tq_batch.set_postfix({
                    'MLM loss': mlm_loss,
                    'Cosine loss': cos_loss
                })

            tq_batch.update(1)

        # checkpoint every `save_every_epoch` epochs and always after the last one
        if epoch % save_every_epoch == 0 or epoch == n_epochs - 1:
            # Note -- we don't save the model class, only the weights
            print("Saving model checkpoint...")
            buffer = io.BytesIO()
            torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    # plain float: do not pickle a GPU tensor attached to the autograd graph
                    'mlm_loss': pred.loss.item(),
                    'cos_loss': cos_loss
                }, f=buffer)
            # TODO -- add custom hash to model instead of value
            minio.put_object(buffer.getvalue(), 
                             save_name=f"ckpt/{TRY_NAME}/model_epoch_{epoch}.pt")

    tq_batch.close()
    model.eval()

In [43]:
dt = PairsDataset(tokenizer, path=DATA_PATH)
dl = DataLoader(dt, batch_size=BATCH_SIZE, shuffle=True,
                collate_fn=collate_func, drop_last=True)

In [44]:
dt_tda = PairsDataset(tokenizer, path=TEST_PATH)
dl_tda = DataLoader(dt_tda, batch_size=BATCH_SIZE, shuffle=True,
                collate_fn=collate_func, drop_last=True)

In [45]:
model = BertForMaskedLM.from_pretrained(MODEL_NAME)
model.to(device)
pass

  return self.fget.__get__(instance, owner)()


Load weights from the last checkpoint

In [46]:
ckpt = minio.get_object(WEIGHTS_PATH, type="model")
model_dict = torch.load(ckpt)

In [47]:
model_dict.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict'])

In [48]:
model.load_state_dict(model_dict["model_state_dict"])

<All keys matched successfully>

In [49]:
model.train();

### Compute SVD on dataset

In [50]:
# from SVDmatrix import GetSVD

# get_svd = GetSVD(model=model, dataloader=dl, division_layer=3)

In [51]:
# get_svd.get_matrix()

In [52]:
# len(get_svd.matrix)

In [53]:
# get_svd.compute_svd()

In [54]:
# minio.put_object(get_svd.matrix, save_name="3rd_layer_diff_matrix.pkl", pickle=True)

In [55]:
# matrix = minio.get_object("3rd_layer_diff_matrix.pkl", unpickle=True)

In [56]:
# len(matrix)

In [57]:
# minio.put_object(get_svd.svd, save_name="3rd_layer_diff_svd.pkl", pickle=True)

In [58]:
# different solution (when library internal tobytes interface is implemented --
# this variant is prefered)

# minio.put_object(get_svd.svd[0].tobytes(), save_name="svd_test", pickle=False)

### Training parameters

In [59]:
# scale by number of steps in a cycle
weight_mlm /= n_cosine

weight_cos = CosWeightSum2One(init_coef=1)
#weight_cos = CosWeightSum2One(init_coef=100)

In [60]:
# Train only encoder layer `division_layer`. The trailing "." keeps e.g. layer 1
# from also matching layers 10 and 11.
trainable_prefix = f"bert.encoder.layer.{division_layer}."
for name, param in model.named_parameters():
    param.requires_grad = name.startswith(trainable_prefix)

In [61]:
# vec = torch.normal(0.5,
#                    0.1,
#                    size=(768, ),
#                    requires_grad=False).repeat(BATCH_SIZE, 1)

In [62]:
# SGD over all model parameters; only layer `division_layer` was left trainable above
optimizer = SGD(model.parameters(), lr=learning_rate)
# Alternative criterion with a directional term (the CosLoss keyword is `vector`, not `vec`):
# criterion = CosLoss(alpha=0.5, vector=vec)
criterion = nn.CosineEmbeddingLoss()

# Estimated number of lr_scheduler.step() calls: train() steps the scheduler once per n_cosine batches
num_training_steps = int(num_epochs * len(dl) / n_cosine * n_mlm)

In [63]:
# number of lr sheduler updates
num_training_steps

50460

### Learning rate & averaging

In [64]:
if not USE_SWA:

    # lr_scheduler = get_scheduler(
    #                 name="cosine", optimizer=optimizer, num_warmup_steps=0,
    #                 num_training_steps=num_training_steps
    #                 )

    # `train` calls lr_scheduler.step() once every n_cosine batches (at batch
    # indices 0, n_cosine, 2 * n_cosine, ...), i.e. ceil(len(dl) / n_cosine)
    # times per epoch. Milestones count scheduler steps, so they must use that
    # unit. Counting them in batches (len(dl) * i) postpones every decay by a
    # factor of n_cosine.
    scheduler_steps_per_epoch = (len(dl) + n_cosine - 1) // n_cosine
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        # halve the learning rate after each of the first 5 epochs
        milestones=[scheduler_steps_per_epoch * i for i in range(1, num_epochs)][:5],
        gamma=0.5,
    )

    # Note: Averaged model is applicable to custom modules -- not only full models,
    # so it can be used for module training as well
    if USE_EMA:
        avg_model = ema_swa_utils.AveragedModel(model,
                                                multi_avg_fn=ema_swa_utils.get_ema_multi_avg_fn(decay=EMA_DECAY)).to(device)
else:
    # The SWA path builds neither `lr_scheduler` nor `avg_model`, but `train`
    # and the final cell need both. Fail here with a clear message.
    raise NotImplementedError("USE_SWA=True is not supported yet: no scheduler or averaged model is configured")
    

### Training

In [65]:
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

In [66]:
print(DESCRIPTION)


Model: based on rubert, additionally pretrained for 10 epochs;
Checkpoint: ckpt/pretrained_bert_epoch_9.999976796259556.pt;
Context: 64;
Batch size: 16;

Loss: classic cosine embedding with steps along the MLM gradient;
Loss only over diff in tokenization: True;
N_MLM: 1;
N_Cosine: 10;
Division_layer: 3;
weight_mlm: 1;
weight_cosine: polynomial decay, sum to 1;

LR_SCHEDULER: MultiStep;
Initial learning rate: 0.0005;
Steps: 5, 
Decay: 0.5,
Epochs: 15;

Additional parameters and notes:
EMA: True;
EMA_DECAY: 0.999;

SWA: False;




In [67]:
logger.info(DESCRIPTION)

In [None]:
tda_save_dict = defaultdict(dict)

# train() indexes hidden_states and stops gradient blending at division_layer + 1
train(model=model, criteria=criterion, optimizer=optimizer,
      lr_scheduler=lr_scheduler, data=dl, hom_data=dl_tda,
      n_epochs=num_epochs,
      n_cosine=n_cosine, division_layer=division_layer+1,
      weight_mlm=weight_mlm, weight_cos=weight_cos,
      save_every_epoch=3, test_every=5000)

# `avg_model` exists only when averaging is enabled; otherwise save the trained model itself
use_averaging = USE_EMA or USE_SWA
if use_averaging:
    torch.optim.swa_utils.update_bn(dl, avg_model)
final_model = avg_model if use_averaging else model

logger.info("Saving model...")
buffer = io.BytesIO()
torch.save({
        'model_state_dict': final_model.state_dict(),
    }, f=buffer)

minio.put_object(buffer.getvalue(), 
                 save_name=f"ckpt/trained_models/{TRY_NAME}/model.pt")

minio.put_object(DESCRIPTION, 
                save_name=f"ckpt/trained_models/{TRY_NAME}/DESCRIPTION.txt", pickle=True)

Epochs:   0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/33640 [00:00<?, ?it/s]

0it [00:00, ?it/s]

ModularLM/data/TDA_FEATURES/tda_save_dict_naive_cosine_with_pretrained_bert_with_ema_0.999.pkl: |####################| 1.42 MB/1.42 MB 100% [elapsed: 00:00 left: 00:00, 4645.71 MB/sec]

In [None]:
# import gc
# gc.collect()
# torch.cuda.empty_cache()