# Setup Environment

In [12]:
# External Imports
import time, os, json
import numpy as np
import matplotlib.pyplot as plt
import datasets
import torch
import random
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from pprint import pprint
from scipy.stats import spearmanr
from torch.nn import functional as F


# Internal Imports
import mltoolkit as mltk
from mltoolkit import (
    cfg_reader,
    models,
)
from mltoolkit.utils import (
    strings,
    files,
    display,
)

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load Config, Tokenizer and Model Checkpoint

In [55]:
# select checkpoint
checkpoints = [
    f'{files.project_root()}/checkpoints/20231012-223600-text-autoencoding/best_model.pt', # no masking, 1-layer enc-dec
    f'{files.project_root()}/checkpoints/20231017-144245-text-autoencoding/best_model.pt', # masking with .1, 1-layer enc-dec
    f'{files.project_root()}/checkpoints/20231017-162325-text-autoencoding/best_model.pt', # masking with .15, 1-layer enc-dec
    f'{files.project_root()}/checkpoints/20231017-221729-text-autoencoding/best_model.pt', # masking with .15, 3-layer enc-dec
]

ckpt_path = checkpoints[3]

# read config
cfg_path = f'{files.project_root()}/cfg/nlp/text_autoencoding/dev_config.yaml'
cfg, keywords = cfg_reader.load(cfg_path)

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(cfg.data['tokenizer_name'])

# assign config values for model
cfg.model['vocab_size'] = len(tokenizer)
cfg.model['pad_token_id'] = tokenizer.pad_token_id

# load model
model = models.TextAutoencoder(cfg).to('cuda:3')
model.load_state_dict(torch.load(ckpt_path))
model.eval()

# display model info
print(f'device is: {model.enc_embeddings.weight.device}')
print()
print(model)

device is: cuda:3

TextAutoencoder(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.05, inplace=False)
        (linear2): Linear(in_features=768, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.05, inplace=False)
        (dropout2): Dropout(p=0.05, inplace=False)
      )
    )
  )
  (enc_embeddings): Embedding(30527, 768, padding_idx=1)
  (dec_embeddings): Embedding(30527, 768, padding_idx=1)
  (deconv): Sequential(
    (0): Sequential(
      (0): Conv1d(1, 16, kernel_size=(33,), stride=(1,), padding=(16,))
      (1): ReLU()
    )
    (1)

# Load Text Data

In [58]:
ds = Dataset.from_csv('/data/john/projects/mltoolkit/data/synthetic_overlap/synthetic_overlap_data.csv')
#ds = datasets.load_dataset('ptb_text_only')['test']
pprint(ds)

Dataset({
    features: ['overlap', 's1', 's2'],
    num_rows: 39263
})


# Encode Some Text

In [59]:
sample_size = 10
sample_indices = np.random.choice(
    np.arange(len(ds)), 
    (sample_size,)
)

sentences = ds[sample_indices]['s1']
enc = model.encode(sentences)

print(enc)
print(enc.shape)

tensor([[-0.0833, -0.0631,  0.0502,  ..., -0.1938,  0.1250,  0.1620],
        [-0.0941,  0.0610, -0.0236,  ..., -0.2886,  0.1770,  0.3198],
        [-0.0519, -0.0501,  0.0676,  ..., -0.1609,  0.0941,  0.1051],
        ...,
        [-0.1627,  0.0996, -0.0095,  ..., -0.3014,  0.2292,  0.2978],
        [-0.0732, -0.0144,  0.0468,  ..., -0.1544,  0.0203,  0.0842],
        [-0.1123,  0.0230, -0.0161,  ..., -0.1778,  0.1749,  0.1980]],
       device='cuda:3')
torch.Size([10, 768])


# Decode the Encodings

In [60]:
dec = model.decode(enc)

for i, s in enumerate(dec):
    print(f'\nx:\t{sentences[i]}\nx_hat:\t{s}')




x:	After being shot down, he was rescued by a U.S. submarine, also a congressman, ambassador, and CIA director.
x_hat:	after being shot down, he was rescued by a u. s. submarine, also a craftsman, and for the playstation, etc.

x:	Moore questioned why the bartender didn't write 'sexy n****' on the check, despite believing she only checked Facebook comments after getting caught.
x_hat:	moore questioned why the bartender didn't write'sexy n * * * *'for the checkout, and his last time, and his mother had been.

x:	West Ham's promising start to the season has been overshadowed by a significant dip in form and booing of captain Kevin Nolan.
x_hat:	west ham's promising start to the season has been assassinated by a significant intruder in form and booze of the st. louis.

x:	The men used a deodorant can and lighter to set an Australian quokka on fire, as shown in a horrifying video.
x_hat:	the men used a deodorant can and lighter to set an australian traveler propeller on fire, and in a vas

# Test on My Own Writing

In [17]:
my_sents = [
    'can you tell me how to find the gas station?',
    'i really cannot stand washing the dishes.',
]

my_enc = model.encode(my_sents)
my_dec = model.decode(my_enc)

print(my_dec)

['can you tell me how to find the gas station?', 'i really cannot stand washing the dishes.']


# Test on Overlap Dataset

In [27]:
# read in dataset
overlap_ds = Dataset.from_csv('/data/john/projects/mltoolkit/data/synthetic_overlap/synthetic_overlap_data.csv')

# get encodings
s1_enc = model.encode(overlap_ds[:10]['s1'])
s2_enc = model.encode(overlap_ds[:10]['s2'])
overlap_enc = model.encode(overlap_ds[:10]['overlap'])

# get intermediate encodings and decodings
middle_enc = (s1_enc + s2_enc)/2
middle_dec = model.decode(middle_enc)

# get euclidean distances between middle encodings and overlap encodings
l2_dist = torch.sum((middle_enc * overlap_enc)**2, dim=-1)
print(f'l2: {l2_dist}')

# compare cosine similarity
overlap_cos_sim = F.cosine_similarity(middle_enc, overlap_enc, dim=-1)
print(f'cosine similarities: {overlap_cos_sim}')

# compare decodings with overlap
for i, s in enumerate(middle_dec):
    print(f'\ny:\t{overlap_ds[i]["overlap"]}\ny_hat:\t{s}')



l2: tensor([0.0113, 0.0232, 0.2107, 0.8175, 0.0220, 0.0885, 0.0406, 0.0616, 0.0214,
        0.0262], device='cuda:3')
cosine similarities: tensor([0.8394, 0.8285, 0.8796, 0.8536, 0.8639, 0.8192, 0.8628, 0.8421, 0.8409,
        0.8029], device='cuda:3')

x:	She said: 'Mike and I have a great sex life.
x_hat:	once lakese's life might can of he or moapeency because hayden are slower for their demeanor life is.

x:	Travellers will learn about the ship in a series of lectures in addition to the dive .
x_hat:	following the expansion factor sidedrs the tenants column of time trapping the thrill the on damp mph ( s. ) were located the battleship to the west.

x:	The crew ordered an evacuation after the Boeing 737-800 came to a stop, and passengers slid down the plane’s inflatable slides to safety.
x_hat:	taped the illumination'messedjeeels attendants deployed as is crew including tiny legions, sides'the stepbrothered us in the united states, the latter who had been a victim for them.

x:	Jaili

# What Does $dec(enc(dec(\frac{enc(s_1) + enc(s_2)}{2})))$ Look Like?

In [31]:
double_enc = model.encode(middle_dec)
double_dec = model.decode(double_enc)

double_cos_sim = F.cosine_similarity(double_enc, middle_enc, dim=-1)

print(f'cos_sim(double_enc, middle_enc): {double_cos_sim}')

# compare double decodings with original decodings
for i, s in enumerate(middle_dec):
    print(f'\nx:\t{double_dec[i]}\nx_hat:\t{s}')

cos_sim(double_enc, middle_enc): tensor([0.9722, 0.9816, 0.9773, 0.9879, 0.9695, 0.9753, 0.9776, 0.9811, 0.9754,
        0.9816], device='cuda:3')

x:	once lakese's life might can of he or moapeency because hayden are animals, their struggle is not the source.
x_hat:	once lakese's life might can of he or moapeency because hayden are slower for their demeanor life is.

x:	following the expansion factorrandrs the cadets column of time trapping the thrill the on damp mph ( s. ) were located the west of the pacific.
x_hat:	following the expansion factor sidedrs the tenants column of time trapping the thrill the on damp mph ( s. ) were located the battleship to the west.

x:	taped the verses'messedjeeels attendants deployed as is crew including tiny legions, sides'the booted by the bar for a week, the two @ - @ yard touchdown in the morning.
x_hat:	taped the illumination'messedjeeels attendants deployed as is crew including tiny legions, sides'the stepbrothered us in the united states, the 

# Evaluate Embeddings On STS

In [56]:

# load dataset
sts = datasets.load_dataset('mteb/sts16-sts').with_format('torch')
a = sts['test']['sentence1']
b = sts['test']['sentence2']
labels = sts['test']['score']

# encode sentences
enc_a = model.encode(a)
enc_b = model.encode(b)

# compute similarity scores
sim_scores = F.cosine_similarity(enc_a, enc_b, dim=-1)
dist_scores = torch.sum((enc_a-enc_b)**2, dim=-1)

# compute spearman correlation for cosine similarity scores and negative distance scores
rho_cos = spearmanr(sim_scores.cpu().numpy(), labels)
rho_dist = spearmanr(-dist_scores.cpu().numpy(), labels)

print(f'spearman score (cos): {rho_cos.statistic:.2f}')
print(f'spearman score (dist): {rho_dist.statistic:.2f}')

spearman score (cos): 0.17
spearman score (dist): 0.20


In [48]:
for ss, ds, l in zip(sim_scores, dist_scores, labels):
    print(f'sim score: {ss:.2f}, dist score: {ds:.2f}, label: {l}')

sim score: 0.98, dist score: 0.13, label: 3.0
sim score: 0.77, dist score: 0.48, label: 3.0
sim score: 0.81, dist score: 0.42, label: 0.0
sim score: 0.80, dist score: 0.76, label: 0.0
sim score: 0.77, dist score: 0.27, label: 2.0
sim score: 0.85, dist score: 0.81, label: 0.0
sim score: 0.72, dist score: 0.32, label: 5.0
sim score: 0.80, dist score: 0.19, label: 3.0
sim score: 0.82, dist score: 1.11, label: 2.0
sim score: 0.86, dist score: 0.84, label: 0.0
sim score: 0.77, dist score: 0.50, label: 2.0
sim score: 0.74, dist score: 0.26, label: 5.0
sim score: 0.78, dist score: 0.20, label: 1.0
sim score: 0.88, dist score: 0.35, label: 4.0
sim score: 0.98, dist score: 0.02, label: 4.0
sim score: 0.88, dist score: 0.76, label: 1.0
sim score: 0.79, dist score: 1.10, label: 1.0
sim score: 0.82, dist score: 0.47, label: 1.0
sim score: 0.67, dist score: 0.28, label: 4.0
sim score: 0.86, dist score: 0.76, label: 1.0
sim score: 0.87, dist score: 0.24, label: 5.0
sim score: 0.73, dist score: 0.23,