In [51]:
import torch
import imp
import MyTransformer
import PruningTrainer
imp.reload(PruningTrainer)
imp.reload(MyTransformer)
from PruningTrainer import MyPruningTrainer
from DataModule import BaseDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor
import random
import numpy as np
import utils
from torch import nn
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display, Image

In [2]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
%env CUBLAS_WORKSPACE_CONFIG :16:8

env: CUBLAS_WORKSPACE_CONFIG=:16:8


In [3]:
DEVICE = "cuda"
BATCH_SIZE = 64
MAX_LEN = 50

torch.manual_seed(SEED)
torch.set_deterministic(True)

data_module = BaseDataModule(
    batch_size=BATCH_SIZE,
    device = DEVICE,
    data_path="./data/eng_rus.txt",
    seed=SEED
)

data_module.prepare_data()

In [4]:
model_params = {
    "src_vocab_size": data_module.src_vocab_len,
    "trg_vocab_size": data_module.trg_vocab_len,
    "d_model": 512,
    "n_enc_layers": 6,
    "n_dec_layers": 6,
    "n_enc_heads": 8,
    "n_dec_heads": 8,
    "enc_dropout": 0.1,
    "dec_dropout": 0.1
}

In [8]:
model = MyTransformer.Transformer(**model_params)
checkpoint = torch.load("models/transformer_model_pruned.pt")
model.load_state_dict(checkpoint)
model.to(DEVICE)
pass

In [9]:
plmodel = MyPruningTrainer(
    model, data_module.src_pad_idx, data_module.trg_pad_idx, 1e-4
)
plmodel.to(DEVICE)
pass

In [7]:
N_EPOCHS = 6
CLIP = 1
plmodel.lr = 1e-4

tb_logger = pl_loggers.TensorBoardLogger('./logs/')
lr_monitor = LearningRateMonitor(logging_interval='step')
early_stop_callback = EarlyStopping(
   monitor='total_val_loss',
   min_delta=0.01,
   patience=2,
   verbose=False,
   mode='mean'
)
trainer = Trainer(
    max_epochs=N_EPOCHS,
    gradient_clip_val=CLIP,
    progress_bar_refresh_rate=1,
    callbacks=[lr_monitor], 
    logger=tb_logger,
    log_every_n_steps=20
)
data_module.setup('fit')
with torch.autograd.set_detect_anomaly(True):
    trainer.fit(plmodel, data_module)

GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | model     | Transformer      | 57.1 M
2 | pruner    | Pruner           | 57.1 M
-----------------------------------------------
27.1 M    Trainable params
30.0 M    Non-trainable params
57.1 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




In [10]:
#torch.save(plmodel.model.state_dict(), 'models/transformer_model_pruned.pt')

In [36]:
plmodel.pruner.get_total_sparsity_rate()

0.604


### В модели осталось меньше 2/3 от изначального числа голов
### Посчитаем теперь bleu на модели с удаленными головами

In [37]:
utils.calculate_bleu(
    data = data_module.test_iter, 
    src_field = data_module.src_field, 
    trg_field = data_module.trg_field,
    model = plmodel.model,
    device=DEVICE
)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7500.0), HTML(value='')))




0.24293770822074123

## 0.242 bleu у срезанной модели против 0.261 у оригинальной

## Визуализируем как выключались головы по мере обучения

In [11]:
all_data = []
with open("./data/probs.txt", 'r') as file:
    for line in file.readlines():
        all_data.extend([float(item) for item in line.split()])

all_data = np.array(all_data).reshape([-1, 6, 8])

In [33]:
for i, data in enumerate(all_data):
    fig,ax = plt.subplots(figsize=(8, 5.5))
    fig.suptitle(f"batch: {i*10}", fontsize=16)
    heatmap = sns.heatmap(data, vmin=0, vmax=1, cmap="YlGnBu_r")
    plt.xlabel('head', fontsize=14)
    plt.ylabel('layer', fontsize=14)
    ax.set_xticklabels(range(1, 9))
    ax.set_yticklabels(range(1, 7))
    plt.savefig(f"images/image{i}.png")
    plt.close()

In [34]:
from PIL import Image as PilImage

im1 = PilImage.open("./images/image0.png")
imgs = (PilImage.open(f"./images/image{i}.png") for i in range(377))

im1.save(fp="./gifs/pruning.gif", format='GIF', append_images=imgs,
         save_all=True, duration=50, loop=0)

In [45]:
!gifsicle -O3 --colors 256 --lossy=30 -o ./temp/test2.gif ./temp/test.gif


In [53]:
with open('./gifs/pruning.gif','rb') as f:
    display(Image(data=f.read(), format='png'))
