# Test sampled CRF normed version

* Decision from experiments: limit the difference between CRF potentials to less than 10; otherwise logsumexp degenerates into a plain max (terms more than ~10 below the maximum contribute almost nothing).
    * This could be achieved by normalizing and rescaling: `x = scale * (x - x.mean()) / (x.max() - x.min())`

In [6]:
import torch 
import sklearn

import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns

from torch import nn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from pandas import DataFrame

from collections import Counter
from tqdm import tqdm 
from data_utils import News20Data
from transformers import BertModel, BertTokenizer
from nltk.corpus import stopwords
from matplotlib.pyplot import figure
from matplotlib import collections as mc
from matplotlib.colors import ListedColormap

from frtorch.structure.linear_chain_crf import LinearChainCRF

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Tokenizer and encoder both come from the same pretrained checkpoint name.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Module.to() returns the module itself, so the move to the GPU is chained here.
bert = BertModel.from_pretrained('bert-base-uncased').to('cuda')

# Disabled alternative: a randomly initialised state matrix on the GPU.
# state_matrix = torch.normal(size=[2000, 768], mean=0.0, std=0.01)
# state_matrix = state_matrix.to('cuda')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [39]:
# Load the saved checkpoint and extract the state matrix that later cells
# combine with the BERT output to build emission and transition potentials.
ckpt_path = '../models/bertnet_0.0.5.0/ckpt-e0.ptmodel.pt'
# map_location pins the checkpoint tensors to the GPU explicitly, instead of
# relying on the device they were saved from; the BERT output they are
# multiplied with lives on 'cuda'.
ckpt = torch.load(ckpt_path, map_location='cuda')
state_matrix = ckpt['state_matrix']

In [None]:
# Build the linear-chain CRF object used by the forward_sum / forward_approx
# calls further down.
# NOTE(review): the arguments presumably select min-max normalisation with a
# scale of 10 (see the note at the top of the notebook) -- confirm against the
# LinearChainCRF constructor.
# NOTE(review): this cell has no execution count in the saved notebook even
# though later cells use `crf`; verify it runs on a fresh kernel.
crf = LinearChainCRF('minmax', 10)

In [4]:
# Stop-word set: NLTK's English list plus punctuation marks, BERT special
# tokens and two filler words.
STOPWORDS = set(stopwords.words('english')) | {
    '"', "'", '.', ',', '?', '!', '-', '[CLS]', '[SEP]', '[PAD]',
    ':', '@', '/', '[', ']', '(', ')', 'would', 'like',
}

# Dataset wrapper; the same batch size is handed to News20Data.
batch_size = 20
dataset = News20Data(batch_size=batch_size)

Processing dataset ...
Reading data ...
... 0 seconds
Tokenizing and sorting train data ...
... 110 seconds
Tokenizing and sorting dev data ...
... 36 seconds
Tokenizing and sorting test data ...
... 37 seconds


In [55]:
# Materialise every batch of the validation loader into a list.
# NOTE(review): `batches` is reassigned by the training-loader cell that
# follows, so only one of the two cells determines what later cells index into.
dev_loader = dataset.val_dataloader()
batches = list(dev_loader)

In [81]:
# Collect the leading training batches: the loop stops right after storing
# the batch with index 1000, i.e. at most 1001 batches are kept.
data_loader = dataset.train_dataloader()
batches = []

for i, batch in enumerate(data_loader):
    batches.append(batch)
    if i >= 1000:
        break

In [86]:
# Pick one stored batch and run it through BERT on the GPU.
bi = 300
# Input ids first, attention mask second -- passed positionally in that order.
bert_inputs = [batches[bi][key].to('cuda') for key in ('input_ids', 'attention_mask')]
# Element 0 of the model output is kept as the embeddings
# (presumably the last hidden states -- confirm for the installed transformers version).
# NOTE(review): this call is not wrapped in torch.no_grad(), so the result
# carries an autograd graph (the emission statistics printed below show grad_fn).
emb = bert(*bert_inputs)[0]

In [87]:
# Potentials for the first sequence of the batch only (slice [:1] keeps the batch axis).
# Emission: dot product of every position embedding with every row of state_matrix.
emission = torch.einsum('bij,kj->bik', emb[:1], state_matrix)
# Transition: dot products between all pairs of state_matrix rows.
transition = state_matrix @ state_matrix.t()
# Length entry for that same sequence, moved to the GPU.
lens = batches[bi]['sent_lens'][:1].to('cuda')
print(lens)

tensor([8], device='cuda:0')


In [88]:
# (max, min, mean) of the transition potentials followed by the same three
# statistics of the emission potentials, shown as one flat 6-tuple.
tuple(
    stat()
    for potential in (transition, emission)
    for stat in (potential.max, potential.min, potential.mean)
)

(tensor(4.5014, device='cuda:0'),
 tensor(-0.9256, device='cuda:0'),
 tensor(0.2110, device='cuda:0'),
 tensor(4.3941, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(-2.1628, device='cuda:0', grad_fn=<MinBackward1>),
 tensor(0.6683, device='cuda:0', grad_fn=<MeanBackward0>))

In [89]:
# Run the CRF's forward_sum on the potentials without tracking gradients,
# then split its two return values.
with torch.no_grad():
    exact_result = crf.forward_sum(transition, emission, lens)
alpha, log_z_exact = exact_result

In [90]:
# Show the second value returned by crf.forward_sum above.
log_z_exact

tensor([82.7498], device='cuda:0')

In [101]:
# Options for the approximate forward pass, gathered in one place.
approx_kwargs = dict(
    sum_size=100,
    proposal='softmax',
    transition_proposal='none',
    sample_size=1,
    return_sampled_idx=False,
)

# NOTE(review): if forward_approx draws random samples (sample_size suggests
# it does), no seed is set in the visible cells, so the estimate may vary
# between runs -- confirm.
with torch.no_grad():
    log_Z_est = crf.forward_approx(state_matrix, emission, lens, **approx_kwargs)

In [102]:
# Show the value returned by crf.forward_approx above, for comparison with log_z_exact.
log_Z_est

tensor([82.5742], device='cuda:0')