In [1]:
import torch
import imp
#import TransformerTrainer
import MyTransformer
import PruningTrainer
imp.reload(PruningTrainer)
imp.reload(MyTransformer)
from PruningTrainer import BaseDataModule, MyPruningTrainer
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor
import random
import numpy as np
import utils
from torch import nn

In [2]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
%env CUBLAS_WORKSPACE_CONFIG :16:8

env: CUBLAS_WORKSPACE_CONFIG=:16:8


In [3]:
DEVICE = "cuda"
BATCH_SIZE = 64
MAX_LEN = 50

torch.manual_seed(SEED)
torch.set_deterministic(True)

data_module = BaseDataModule(
    batch_size=BATCH_SIZE,
    device = DEVICE,
    data_path="./data/eng_rus.txt",
    seed=SEED
)

data_module.prepare_data()

In [4]:
model_params = {
    "src_vocab_size": data_module.src_vocab_len,
    "trg_vocab_size": data_module.trg_vocab_len,
    "d_model": 512,
    "n_enc_layers": 6,
    "n_dec_layers": 6,
    "n_enc_heads": 8,
    "n_dec_heads": 8,
    "enc_dropout": 0.1,
    "dec_dropout": 0.1
}

In [5]:
model = MyTransformer.Transformer(**model_params)
checkpoint = torch.load("models/transformer_model.pt")
model.load_state_dict(checkpoint)
model.to(DEVICE)
pass

In [6]:
plmodel = MyPruningTrainer(
    model, data_module.src_pad_idx, data_module.trg_pad_idx, 1e-4
)
plmodel.to(DEVICE)
pass

In [8]:
N_EPOCHS = 10
CLIP = 1
plmodel.lr = 1e-4

tb_logger = pl_loggers.TensorBoardLogger('./logs/')
lr_monitor = LearningRateMonitor(logging_interval='step')
early_stop_callback = EarlyStopping(
   monitor='total_val_loss',
   min_delta=0.01,
   patience=2,
   verbose=False,
   mode='mean'
)
trainer = Trainer(
    max_epochs=N_EPOCHS,
    gradient_clip_val=CLIP,
    progress_bar_refresh_rate=1,
    callbacks=[early_stop_callback, lr_monitor], 
    logger=tb_logger,
    log_every_n_steps=20
)
data_module.setup('fit')
trainer.fit(plmodel, data_module)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | model     | Transformer      | 57.1 M
2 | pruner    | Pruner           | 180   
-----------------------------------------------
57.1 M    Trainable params
36        Non-trainable params
57.1 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

In [None]:
import Pruner
import imp
imp.reload(Pruner)
from Pruner import Pruner

pruner = 

In [9]:
plmodel.pruner.enc_attn_weights[0]
plmodel.pruner.get_probs()

In [72]:
a = plmodel.pruner.get_all_gates()["enc_gates"]
for i in range(100):
    a += plmodel.pruner.get_all_gates()["enc_gates"]

a

tensor([[65., 73., 76., 80., 72., 69., 70., 78.],
        [71., 70., 74., 72., 66., 75., 82., 81.],
        [82., 78., 73., 75., 77., 74., 69., 72.],
        [74., 71., 76., 77., 81., 79., 67., 79.],
        [76., 77., 82., 69., 77., 75., 69., 68.],
        [78., 83., 73., 77., 73., 80., 82., 70.]], device='cuda:0')

In [22]:
 with torch.no_grad():
    mask = plmodel.pruner.get_all_gates()["enc_gates"]
    d_size = plmodel.pruner.enc_attn_weights[0].shape[0]
    for m, w in zip(mask, plmodel.pruner.enc_attn_weights):
        w_new = w.view(d_size, 8, -1) * m.view(1, -1, 1)
        w.copy_(w_new.view(d_size, d_size))

In [40]:
(plmodel.model.encoder.layers[0].attn.fc_out.weight.view(d_size, 8, -1).permute(1, 0 , -1)[6] == 0).sum()

tensor(32768, device='cuda:0')

In [30]:
(plmodel.model.encoder.layers[0].attn.fc_out.weight == 0).sum()

tensor(32768, device='cuda:0')

In [13]:
plmodel.pruner.get_total_sparsity_rate()

{'enc_gates': 0.8333333333333334, 'dec_gates': 0.7708333333333334}