This notebook shows how to apply the input-level information bottleneck to a
pretrained 4-layer LSTM model on IMDB. Before running this notebook, please make
sure that:

1. All the required packages are installed.

2. The pretrained weight file `path/to/informationbottleneck/pretrained/deep_lstm.pt`
 exists.
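Both prerequisites above can be checked before any heavy imports. The following is a minimal sketch; `check_prerequisites` is a hypothetical helper, and the package list is an assumption based on the imports used in this notebook.

```python
import importlib.util
import os

# Assumption: top-level packages imported by this notebook; adjust if needed.
REQUIRED_PACKAGES = ["torch", "mmcv", "torchtext", "iba"]
# Path of the pretrained weight, relative to the repository root.
PRETRAINED_WEIGHT = os.path.join("pretrained", "deep_lstm.pt")


def check_prerequisites(repo_root, packages=REQUIRED_PACKAGES,
                        weight=PRETRAINED_WEIGHT):
    """Hypothetical helper: report missing packages and whether the weight exists."""
    # find_spec returns None for a top-level package that is not installed
    missing = [name for name in packages
               if importlib.util.find_spec(name) is None]
    weight_found = os.path.isfile(os.path.join(repo_root, weight))
    return {"missing_packages": missing, "weight_found": weight_found}
```

For example, `check_prerequisites('path/to/informationbottleneck')` should report an empty `missing_packages` list and `weight_found=True` when the setup is complete.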

In [1]:
import torch
import mmcv
from torch.utils.data import DataLoader
from iba.datasets import build_dataset
from iba.models import build_attributor
import os

device = 'cuda:0'

Change the working directory to `path/to/informationbottleneck/`; modify
this if necessary.
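`os.chdir('..')` only works when the notebook is started from `tutorials/`. As an alternative, here is a minimal sketch that walks up the directory tree until it finds the config file used below; `find_repo_root` is a hypothetical helper, not part of `iba`.

```python
import os


def find_repo_root(start, marker=os.path.join("configs", "deep_lstm.py")):
    """Hypothetical helper: return the nearest ancestor of `start` containing `marker`.

    Raises FileNotFoundError if no ancestor (including `start`) contains it.
    """
    current = os.path.abspath(start)
    while True:
        if os.path.isfile(os.path.join(current, marker)):
            return current
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root
            raise FileNotFoundError(
                f"no ancestor of {start!r} contains {marker!r}")
        current = parent
```

Usage: `os.chdir(find_repo_root(os.getcwd()))` then works from either `tutorials/` or the repository root.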

In [2]:
# cwd switch from `informationbottleneck/tutorials/` to
# `informationbottleneck/`
os.chdir('..')
cfg_path = 'configs/deep_lstm.py'

In [3]:
cfg = mmcv.Config.fromfile(cfg_path)

In [4]:
dataset = build_dataset(cfg.data['attribution'])
datapoint = next(iter(dataset))

In [5]:
# examine one data point
print("Plain text: {}".format(datapoint['input_text']))
print("Processed text as tensor: {}".format(datapoint['input']))
print("Target class: {}".format(datapoint['target']))
print("File name: {}".format(datapoint['input_name']))
print("Text length: {}".format(datapoint['input_length']))

Plain text: Zentropa has much in common with The Third Man, another noir-like film set among the rubble of postwar Europe. Like TTM, there is much inventive camera work. There is an innocent American who gets emotionally involved with a woman he doesn't really understand, and whose naivety is all the more striking in contrast with the natives.<br /><br />But I'd have to say that The Third Man has a more well-crafted storyline. Zentropa is a bit disjointed in this respect. Perhaps this is intentional: it is presented as a dream/nightmare, and making it too coherent would spoil the effect. <br /><br />This movie is unrelentingly grim--"noir" in more than one sense; one never sees the sun shine. Grim, but intriguing, and frightening.
Processed text as tensor: tensor([13824,    52,    81,    12,  1125,    20,     2,   852,   135,     4,
          164,     0,    23,   293,   769,     2, 15259,     7, 13683,  2278,
            3,    45,     0,     4,    46,    10,    81,  4385,   391,   170,
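The processed tensor holds one vocabulary index per token. The sketch below shows that numericalization step with a toy vocabulary; the vocabulary and the choice of `0` for out-of-vocabulary tokens are assumptions for illustration, since the real vocabulary is built by the dataset.

```python
# Assumption: index reserved for out-of-vocabulary tokens.
UNK_INDEX = 0


def numericalize(tokens, stoi, unk_index=UNK_INDEX):
    """Map each token to its vocabulary index, falling back to unk_index."""
    return [stoi.get(token, unk_index) for token in tokens]


# Toy vocabulary (hypothetical indices, not the IMDB vocabulary).
toy_stoi = {"the": 1, "film": 2, ".": 3}
print(numericalize(["the", "film", "zentropa", "."], toy_stoi))
# -> [1, 2, 0, 3]
```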

# Information flow for generating input-level attribution maps for text data

In [6]:
attributor = build_attributor(cfg.attributor, default_args=dict(device=device))



In [7]:
from iba.datasets import nlp_collate_fn
dataloader = DataLoader(dataset,
                        collate_fn=nlp_collate_fn,
                        **cfg.data['data_loader'])
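Reviews have different lengths, so batching them requires a collate function. The following is a minimal sketch of what a padding collate function typically does; `pad_collate` and the padding index `1` are hypothetical, and the real `nlp_collate_fn` may differ (for example, by returning tensors or sorting by length).

```python
def pad_collate(batch, pad_index=1):
    """Hypothetical stand-in for a text collate function.

    Each item is a dict with 'input' (list of token indices) and
    'target' (int). Inputs are right-padded to the longest item.
    """
    lengths = [len(item["input"]) for item in batch]
    max_len = max(lengths)
    inputs = [item["input"] + [pad_index] * (max_len - len(item["input"]))
              for item in batch]
    targets = [item["target"] for item in batch]
    # Keeping the true lengths lets an LSTM ignore the padded positions.
    return {"input": inputs, "target": targets, "input_length": lengths}
```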

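Estimate the distribution of the feature maps at the layer where the information
bottleneck is inserted. The cell below also inspects the shape of the estimated mean.

This pass over the dataset may take a while.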

In [8]:
attributor.estimate(dataloader, cfg.estimation_cfg)
attributor.feat_iba.estimator.mean().shape

torch.Size([256])
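The estimation step accumulates statistics of the feature maps over the whole dataset. As a rough illustration, here is a minimal sketch of a streaming per-channel mean and standard deviation using Welford's algorithm; `RunningChannelStats` is hypothetical, and the library's estimator may be implemented differently.

```python
import math


class RunningChannelStats:
    """Per-channel running mean and sample std via Welford's algorithm."""

    def __init__(self, num_channels):
        self.n = 0
        self.mean = [0.0] * num_channels
        self._m2 = [0.0] * num_channels  # sum of squared deviations

    def update(self, feature_vector):
        """Add one feature vector holding one value per channel."""
        self.n += 1
        for c, x in enumerate(feature_vector):
            delta = x - self.mean[c]
            self.mean[c] += delta / self.n
            self._m2[c] += delta * (x - self.mean[c])

    def std(self):
        """Sample standard deviation per channel (needs at least two updates)."""
        return [math.sqrt(m2 / (self.n - 1)) for m2 in self._m2]
```

With 256-dimensional features, `len(stats.mean)` would be 256, matching the shape printed above.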

## Train the attributor on a sample text
The training pipeline is integrated into the *attributor* class.

In [9]:
datapoint = next(iter(dataset))
target = datapoint['target']
input_text = datapoint['input_text']
input_tensor = datapoint['input']

The feature IBA is trained on batches of size `batch_size`, so the target is
repeated `batch_size` times (giving a tensor of shape `(batch_size, 1)`) to match.

In [10]:
input_tensor = input_tensor.to(device)
feat_iba_batch_size = cfg.attribution_cfg['feat_iba']['batch_size']
target = torch.tensor([target]).expand(
    (feat_iba_batch_size, -1)).to(torch.float32)
target = target.to(device)

attributor.set_text(input_text)
attributor.make_attribution(input_tensor,
                            target,
                            attribution_cfg=cfg.attribution_cfg)
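`make_attribution` hides the optimization itself. As a rough guide to what the feature-level bottleneck optimizes, here is a minimal sketch of the per-feature objective from Information Bottleneck Attribution (Schulz et al., 2020), assuming features are standardized with the estimated mean and standard deviation and the mask is parameterized as `lam = sigmoid(alpha)`. The helper names are hypothetical, and the library's actual loss (its weighting of `beta`, clamping, or averaging) may differ.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def bottleneck_kl(r_norm, alpha):
    """KL( N(lam * r, (1 - lam)^2) || N(0, 1) ) for one standardized feature.

    The bottleneck keeps a fraction lam of the feature and replaces the rest
    with noise drawn from the estimated feature distribution.
    """
    lam = sigmoid(alpha)
    mu_z = lam * r_norm
    var_z = (1.0 - lam) ** 2
    return 0.5 * (mu_z ** 2 + var_z - 1.0 - math.log(var_z))


def iba_loss(cross_entropy, r_norms, alphas, beta):
    """Model loss plus beta times the mean information term over features."""
    kls = [bottleneck_kl(r, a) for r, a in zip(r_norms, alphas)]
    return cross_entropy + beta * sum(kls) / len(kls)
```

A strongly negative `alpha` lets almost no information through (KL near zero), while `alpha` near 0 passes half of the signal; `beta` trades the information term off against the classification loss.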

## Display the feature mask from the IBA (already summed over channels)

We highlight tokens with different colors based on their attribution values:
dark red means the token is very important for the model's decision, while
lighter colors mean the token matters less.
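The color coding described above can be sketched as a mapping from attribution values to color intensity. The following minimal sketch min-max scales the scores to `[0, 1]` and uses them as the opacity of a red background in HTML; `highlight_tokens` and the exact color are hypothetical, not what `show_feat_mask` renders internally.

```python
import html


def highlight_tokens(tokens, scores):
    """Hypothetical helper: HTML spans whose red opacity follows the score."""
    if len(tokens) != len(scores):
        raise ValueError("need exactly one score per token")
    lo, hi = min(scores), max(scores)
    span = (hi - lo) or 1.0  # avoid division by zero when all scores are equal
    parts = []
    for token, score in zip(tokens, scores):
        alpha = (score - lo) / span  # 1.0 for the most important token
        parts.append(
            f'<span style="background-color: rgba(178, 0, 0, {alpha:.2f})">'
            f"{html.escape(token)}</span>")
    return " ".join(parts)
```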

In [11]:
# a tokenizer is needed to split the text into tokens,
# so that an attribution value can be assigned to each token
from torchtext.data.utils import get_tokenizer
tokenizer = get_tokenizer('basic_english')
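To see roughly how the review is split into tokens, here is a simplified approximation of a basic English tokenizer: lowercase the text, drop HTML line breaks and double quotes, split common punctuation into separate tokens, and split on whitespace. `simple_tokenize` is hypothetical and its rules only approximate torchtext's `basic_english`; use `get_tokenizer('basic_english')` in practice.

```python
import re

# Simplified normalization rules (an approximation, not torchtext's exact list).
_RULES = [
    (re.compile(r"<br\s*/?>"), " "),        # drop HTML line breaks in IMDB reviews
    (re.compile(r'"'), ""),                 # drop double quotes
    (re.compile(r"([.,()!?])"), r" \1 "),   # punctuation becomes its own token
    (re.compile(r"[;:]"), " "),             # treat ; and : as separators
]


def simple_tokenize(text):
    text = text.lower()
    for pattern, replacement in _RULES:
        text = pattern.sub(replacement, text)
    return text.split()


print(simple_tokenize('Grim, but intriguing.<br /><br />"Noir"'))
# -> ['grim', ',', 'but', 'intriguing', '.', 'noir']
```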

In [12]:
attributor.show_feat_mask(tokenizer=tokenizer, show=True)
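The feature mask is summed over channels to obtain one value per token. Here is a minimal sketch of that reduction and of listing the highest-scoring tokens, assuming the mask is laid out as sequence length × channels; both helpers are hypothetical and not part of the attributor API.

```python
def token_scores_from_feature_mask(feature_mask):
    """Sum a (seq_len x channels) mask over channels: one score per token."""
    return [sum(channel_values) for channel_values in feature_mask]


def top_tokens(tokens, scores, k=3):
    """Return the k (token, score) pairs with the highest scores."""
    if len(tokens) != len(scores):
        raise ValueError("need exactly one score per token")
    ranked = sorted(zip(tokens, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


toy_mask = [[0.1, 0.0], [0.2, 0.1], [0.9, 0.8]]  # toy data: 3 tokens, 2 channels
scores = token_scores_from_feature_mask(toy_mask)
print(top_tokens(["this", "movie", "grim"], scores, k=1))
```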

## Display the final word-level input mask learned by the input IB

In [13]:
attributor.show_input_mask(tokenizer=tokenizer, show=True)