In [75]:
import torch

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GPT2Model

  from .autonotebook import tqdm as notebook_tqdm


# Import ASG

In [2]:
NTDBGPT2_lm = AutoModelForCausalLM.from_pretrained('dracoglacius/NTDB-GPT2')
NTDBGPT2_tokenizer = AutoTokenizer.from_pretrained('dracoglacius/NTDB-GPT2')
NTDBGPT2_embed = GPT2Model.from_pretrained('dracoglacius/NTDB-GPT2')

Downloading: 100%|██████████████████████████████████████████████████| 992/992 [00:00<00:00, 411kB/s]
Downloading: 100%|███████████████████████████████████████████████| 374M/374M [00:29<00:00, 13.4MB/s]
Downloading: 100%|██████████████████████████████████████████████████| 421/421 [00:00<00:00, 749kB/s]
Downloading: 100%|███████████████████████████████████████████████| 152k/152k [00:00<00:00, 1.99MB/s]
Downloading: 100%|███████████████████████████████████████████████| 68.0/68.0 [00:00<00:00, 23.3kB/s]
Some weights of the model checkpoint at dracoglacius/NTDB-GPT2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing 

# Helper Functions

In [227]:
def decode_sequence(sequences):
    return NTDBGPT2_tokenizer.decode(sequences[0])

In [228]:
def get_n_layer_hidden_embeddings(hidden_states, layer_n=12):
    # Default to getting the last layer (12)
    
    # The start sequence embeddings are in the first tuple element
    _start = torch.flatten(hidden_states[0][layer_n], start_dim=-1)
    
    # The rest of the sequence embeddings are obtained
    _hs = [x[layer_n] for x in hidden_states[1:]]
    _hs = torch.concat(_hs, dim=1)
    
    return torch.concat([_start, _hs], dim=1).squeeze(dim=0)

# Testing Scenario Generation

In [215]:
SEARCH_ECODES = ['E885.9','E812.0','E966.0']

ecode_key = SEARCH_ECODES[0]

input_seq = f"<START> {ecode_key} <DSTART>"

seq_ids = NTDBGPT2_tokenizer.encode(input_seq, return_tensors='pt')

#### Generate Scenario from Stem

In [226]:
out_wft = NTDBGPT2_lm.generate(
    seq_ids, 
    do_sample=True,
    min_length=10,
    max_length=12,
    #top_p=0.9, 
    top_k=0,
    return_dict_in_generate=True,
    forced_eos_token_id=NTDBGPT2_tokenizer.eos_token_id,
    #repetition_penalty=3.0,
    #length_penalty=1.0,
    #num_return_seqs=1,
    output_hidden_states=True
)

In [229]:
out_wft.sequences

tensor([[  0,   9,   1,  18,  25,   3, 919,  24,   4,   5,   2]])

In [230]:
len(out_wft.sequences[0])

11

In [231]:
decode_sequence(out_wft.sequences)

'<START> E885.9 <DSTART> 805.4 805.2 <PSTART> 81.66 88.26 88.38 87.03 <END>'

In [232]:
em = get_n_layer_hidden_embeddings(out_wft.hidden_states)

In [284]:
em.shape

torch.Size([10, 768])

#### Shape Discrepency

`<END>` is being treated as `eos` and therefore the output is ignored. The final token's value is therefore thought to contain the "best" representation of the whole sequence.

In [243]:
em

tensor([[ 0.6226,  0.1586, -0.0762,  ...,  0.3346,  0.4395, -0.2872],
        [ 0.1316, -0.3097, -0.6136,  ..., -0.1229,  0.0251,  0.0816],
        [ 0.0113, -0.5432, -1.4481,  ..., -0.1895,  0.4446,  0.2490],
        ...,
        [ 0.6987,  0.1755, -0.5157,  ...,  0.0493, -0.4683, -0.3290],
        [ 0.7893,  0.2013, -0.4699,  ..., -0.0489, -0.4174, -0.2398],
        [ 1.0890,  0.3541, -0.3962,  ..., -0.2973, -0.5481, -0.1599]])

In [247]:
em[:5,:]

tensor([[ 0.6226,  0.1586, -0.0762,  ...,  0.3346,  0.4395, -0.2872],
        [ 0.1316, -0.3097, -0.6136,  ..., -0.1229,  0.0251,  0.0816],
        [ 0.0113, -0.5432, -1.4481,  ..., -0.1895,  0.4446,  0.2490],
        [ 0.5896, -1.0609, -0.2501,  ...,  0.1293,  0.2424, -0.3202],
        [ 0.5839, -0.6312,  0.1603,  ...,  0.2545,  0.4496, -0.2992]])

In [233]:
em.shape

torch.Size([10, 768])

#### Generate Second Scenario

In [276]:
out_wft2 = NTDBGPT2_lm.generate(
    NTDBGPT2_tokenizer.encode('<START> E885.9', return_tensors='pt'), 
    do_sample=True,
    min_length=10,
    max_length=12,
    #top_p=0.9, 
    top_k=0,
    return_dict_in_generate=True,
    forced_eos_token_id=NTDBGPT2_tokenizer.eos_token_id,
    #repetition_penalty=3.0,
    #length_penalty=1.0,
    #num_return_seqs=1,
    output_hidden_states=True
)

In [277]:
decode_sequence(out_wft2.sequences)

'<START> E885.9 <DSTART> 958.4 860 807.04 <PSTART> 88.38 93.23 88.77 99.04 <END>'

In [278]:
em2 = get_n_layer_hidden_embeddings(out_wft2.hidden_states)

In [279]:
em2

tensor([[ 0.6226,  0.1586, -0.0762,  ...,  0.3346,  0.4395, -0.2872],
        [ 0.1316, -0.3097, -0.6136,  ..., -0.1229,  0.0251,  0.0816],
        [ 0.0114, -0.5432, -1.4481,  ..., -0.1895,  0.4446,  0.2490],
        ...,
        [ 0.4292, -0.2142, -0.6003,  ...,  0.0190, -0.6618, -0.4300],
        [ 0.0876, -0.0530, -0.2207,  ...,  0.0445, -0.8208, -0.4123],
        [ 0.1489,  0.2430, -0.7810,  ...,  0.0136, -0.6779, -0.3698]])

In [280]:
em2[:5,:]

tensor([[ 0.6226,  0.1586, -0.0762,  ...,  0.3346,  0.4395, -0.2872],
        [ 0.1316, -0.3097, -0.6136,  ..., -0.1229,  0.0251,  0.0816],
        [ 0.0114, -0.5432, -1.4481,  ..., -0.1895,  0.4446,  0.2490],
        [ 0.2790, -0.7283, -1.2843,  ...,  0.1576,  0.8097,  0.1383],
        [ 0.9445, -0.2814, -1.3363,  ..., -0.4390,  0.0595, -0.1846]])

## Assess Using Cosine Similarity

* Sequence 1: `<START> E885.9 <DSTART> 805.4 805.2 <PSTART> 81.66 88.26 88.38 87.03 <END>`
* Sequence 2: `<START> E885.9 <DSTART> 958.4 860 807.04 <PSTART> 88.38 93.23 88.77 99.04 <END>`

#### Observations

* Only the first three tokens overlap.
* Therefore using cosine similarity, we expect only the first three tokens to be the same

In [281]:
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

In [282]:
cos(em[:5,:], em2[:5,:])

tensor([1.0000, 1.0000, 1.0000, 0.6948, 0.4636])

Results

* As we thought, only the first three tokens are perfectly aligned.
* Diagnostic codes are not considered very similar
* Diagnostic code and `<PSTART>` is considered to be very dissimilar
* However `<PSTART>` and the first procedure code is considered similar

In [283]:
cos(em[:10,:], em2[:10,:])

tensor([ 1.0000,  1.0000,  1.0000,  0.6948,  0.4636, -0.5227,  0.9388,  0.9521,
         0.9006,  0.9390])