In [1]:
import gensim
import codecs
import numpy as np
import torch
from torch.nn import init
from torch.nn.parameter import Parameter
from tqdm.notebook import tqdm
from argparse import Namespace
from sklearn.cluster.k_means_ import MiniBatchKMeans



In [2]:
class Sentences:
    def __init__(self, filename: str):
        self.filename = filename
        self.num_lines = sum(1 for line in open(filename))

    def __iter__(self):
        for line in tqdm(
            codecs.open(self.filename, "r", encoding="utf-8"), 
            self.filename, 
            self.num_lines
        ):
            yield line.strip().split()

In [3]:
def read_data_batches(path, batch_size=50, minlength=5):
    """
        Reading batched texts of given min. length
    :param path: path to the text file ``one line -- one normalized sentence''
    :return: batches iterator
    """
    batch = []

    for line in open(path, encoding="utf-8"):
        line = line.strip().split()

        # lines with less than `minlength` words are omitted
        if len(line) >= minlength:
            batch.append(line)
            if len(batch) >= batch_size:
                yield batch
                batch = []

    if len(batch) > 0:
        yield batch

In [4]:
def get_num_batches(path, batch_size=50, minlength=5):
    count = 0
    batch_count = 0
    
    for line in open(path, encoding="utf-8"):

        if len(line) >= minlength:
            batch_count += 1
            if batch_count >= batch_size:
                count += 1
                batch_count = 0
    
    return count

In [5]:
def text2vectors(text, w2v_model, maxlen, vocabulary):
    """
        Token sequence -- to a list of word vectors;
        if token not in vocabulary, it is skipped; the rest of
        the slots up to `maxlen` are replaced with zeroes
    :param text: list of tokens
    :param w2v_model: gensim w2v model
    :param maxlen: max. length of the sentence; the rest is just cut away
    :return:
    """

    acc_vecs = []

    for word in text:
        if word in w2v_model and (vocabulary is None or word in vocabulary):
            acc_vecs.append(w2v_model.wv[word])

    # padding for consistent length with ZERO vectors
    if len(acc_vecs) < maxlen:
        acc_vecs.extend([np.zeros(w2v_model.vector_size)] * (maxlen - len(acc_vecs)))

    return acc_vecs

In [6]:
def read_data_tensors(
    path, 
    batch_size=50, 
    vocabulary=None,
    maxlen=100, 
    pad_value=0, 
    minsentlength=5,
    w2v_model=None,
):
    """
        Data for training the NN -- from text file to word vectors sequences batches
    :param path:
    :param batch_size:
    :param vocabulary:
    :param maxlen:
    :param pad_value:
    :param minsentlength:
    :return:
    """
    for batch in read_data_batches(path, batch_size, minsentlength):
        batch_vecs = []
        batch_texts = []

        for text in batch:
            vectors_as_list = text2vectors(text, w2v_model, maxlen, vocabulary)
            batch_vecs.append(np.asarray(vectors_as_list[:maxlen], dtype=np.float32))
            batch_texts.append(text)

        yield np.stack(batch_vecs, axis=0), batch_texts

In [7]:
def get_centroids(w2v_model, aspects_count):
    """
        Clustering all word vectors with K-means and returning L2-normalizes
        cluster centroids; used for ABAE aspects matrix initialization
    """

    km = MiniBatchKMeans(n_clusters=aspects_count, verbose=0, n_init=100)
    m = []

    for k in w2v_model.wv.vocab:
        m.append(w2v_model.wv[k])

    m = np.matrix(m)

    km.fit(m)
    clusters = km.cluster_centers_

    # L2 normalization
    norm_aspect_matrix = clusters / np.linalg.norm(clusters, axis=-1, keepdims=True)

    return norm_aspect_matrix

In [8]:
class SelfAttention(torch.nn.Module):
    def __init__(self, wv_dim: int, maxlen: int):
        super(SelfAttention, self).__init__()
        self.wv_dim = wv_dim

        # max sentence length -- batch 2nd dim size
        self.maxlen = maxlen
        self.M = Parameter(torch.empty(size=(wv_dim, wv_dim)))
        init.kaiming_uniform(self.M.data)

        # softmax for attending to wod vectors
        self.attention_softmax = torch.nn.Softmax()

    def forward(self, input_embeddings):
        # (b, wv, 1)
        
        input_embeddings = input_embeddings.to(args.device)
        mean_embedding = torch.mean(input_embeddings, (1,)).unsqueeze(2).to(args.device)

        # (wv, wv) x (b, wv, 1) -> (b, wv, 1)
        product_1 = torch.matmul(self.M, mean_embedding)

        # (b, maxlen, wv) x (b, wv, 1) -> (b, maxlen, 1)
        product_2 = torch.matmul(input_embeddings, product_1).squeeze(2)

        results = self.attention_softmax(product_2)

        return results

    def extra_repr(self):
        return f'wv_dim={self.wv_dim}, maxlen={self.maxlen}'

In [27]:
class ABAE(torch.nn.Module):
    """
        The model described in the paper ``An Unsupervised Neural Attention Model for Aspect Extraction''
        by He, Ruidan and  Lee, Wee Sun  and  Ng, Hwee Tou  and  Dahlmeier, Daniel, ACL2017
        https://aclweb.org/anthology/papers/P/P17/P17-1036/
    """

    def __init__(
        self, 
        wv_dim: int = 200, 
        asp_count: int = 30,
        ortho_reg: float = 0.1, 
        maxlen: int = 201, 
        init_aspects_matrix=None
    ):
        """Initializing the model
        
        :param wv_dim: word vector size
        :param asp_count: number of aspects
        :param ortho_reg: coefficient for tuning the ortho-regularizer's influence
        :param maxlen: sentence max length taken into account
        :param init_aspects_matrix: None or init. matrix for aspects
        """
        super(ABAE, self).__init__()
        self.wv_dim = wv_dim
        self.asp_count = asp_count
        self.ortho = ortho_reg
        self.maxlen = maxlen

        self.attention = SelfAttention(wv_dim, maxlen)
        self.linear_transform = torch.nn.Linear(self.wv_dim, self.asp_count)
        self.softmax_aspects = torch.nn.Softmax()
        self.aspects_embeddings = Parameter(torch.empty(size=(wv_dim, asp_count)))

        if init_aspects_matrix is None:
            torch.nn.init.xavier_uniform(self.aspects_embeddings)
        else:
            self.aspects_embeddings.data = torch.from_numpy(init_aspects_matrix.T)

    def get_aspects_importances(self, text_embeddings):
        """Takes embeddings of a sentence as input, returns attention weights

        """

        # compute attention scores, looking at text embeddings average
        attention_weights = self.attention(text_embeddings)

        # multiplying text embeddings by attention scores -- and summing
        # (matmul: we sum every word embedding's coordinate with attention weights)
        weighted_text_emb = torch.matmul(attention_weights.unsqueeze(1),  # (batch, 1, sentence)
                                         text_embeddings  # (batch, sentence, wv_dim)
                                         ).squeeze()

        # encoding with a simple feed-forward layer (wv_dim) -> (aspects_count)
        raw_importances = self.linear_transform(weighted_text_emb)

        # computing 'aspects distribution in a sentence'
        aspects_importances = self.softmax_aspects(raw_importances)

        return attention_weights, aspects_importances, weighted_text_emb

    def forward(self, text_embeddings, negative_samples_texts):

        # negative samples are averaged
        averaged_negative_samples = torch.mean(negative_samples_texts, dim=2)

        # encoding: words embeddings -> sentence embedding, aspects importances
        _, aspects_importances, weighted_text_emb = self.get_aspects_importances(text_embeddings)

        # decoding: aspects embeddings matrix, aspects_importances -> recovered sentence embedding
        recovered_emb = torch.matmul(self.aspects_embeddings, aspects_importances.unsqueeze(2)).squeeze()
        print(f'recovered_emb: {recovered_emb.shape}')

        # loss
        reconstruction_triplet_loss = ABAE._reconstruction_loss(
            weighted_text_emb,
            recovered_emb,
            averaged_negative_samples,
        )
        print(f'reconstruction_triplet_loss: {reconstruction_triplet_loss.shape}')
        
        max_margin = torch.max(reconstruction_triplet_loss, torch.zeros_like(reconstruction_triplet_loss))
        reconstruction_triplet_loss

        return self.ortho * self._ortho_regularizer() + max_margin

    @staticmethod
    def _reconstruction_loss(text_emb, recovered_emb, averaged_negative_emb):

        positive_dot_products = torch.matmul(text_emb.unsqueeze(1), recovered_emb.unsqueeze(2)).squeeze()
        negative_dot_products = torch.matmul(averaged_negative_emb, recovered_emb.unsqueeze(2)).squeeze()
        reconstruction_triplet_loss = torch.sum(1 - positive_dot_products.unsqueeze(1) + negative_dot_products, dim=1)

        return reconstruction_triplet_loss

    def _ortho_regularizer(self):
        return torch.norm(
            torch.matmul(self.aspects_embeddings.t(), self.aspects_embeddings) \
            - torch.eye(self.asp_count).to(args.device))

    def get_aspect_words(self, w2v_model, topn=15):
        words = []

        # getting aspects embeddings
        aspects = self.aspects_embeddings.cpu().detach().numpy()

        # getting scalar products of word embeddings and aspect embeddings;
        # to obtain the ``probabilities'', one should also apply softmax
        words_scores = w2v_model.wv.syn0.dot(aspects)

        for row in range(aspects.shape[1]):
            argmax_scalar_products = np.argsort(- words_scores[:, row])[:topn]
            # print([w2v_model.wv.index2word[i] for i in argmax_scalar_products])
            # print([w for w, dist in w2v_model.similar_by_vector(aspects.T[row])[:topn]])
            words.append([w2v_model.wv.index2word[i] for i in argmax_scalar_products])

        return words

In [28]:
args = Namespace(
    data_json='Electronics_5.json',
    
    w2v_file='Electronics_5.w2v',
    w2v_size=200,
    w2v_window=5,
    w2v_min_count=5,
    w2v_workers=7,
    w2v_sg=1,
    w2v_negative=5,
    w2v_iter=1,
    w2v_max_vocab_size=20000,
    
    batch_size=50,
    aspects_number=40,
    ortho_reg=0.1,
    epochs=1,
    optimizer='adam',
    neg_samples=5,
    maxlen=201,
    
    cuda=True,
    reload_from_files=True,
)

args.device = torch.device("cuda" if args.cuda else "cpu")
    
print("Using CUDA: {}".format(args.cuda))

Using CUDA: True


In [29]:
if args.reload_from_files:
    print("Loading vectorizer")
    pass
else:
    print("Loading dataset and creating vectorizer")
    sentences = Sentences(args.data_json)
    w2v = gensim.models.Word2Vec(
        sentences, 
        size=args.w2v_size, 
        window=args.w2v_window, 
        min_count=args.w2v_min_count, 
        workers=args.w2v_workers, 
        sg=args.w2v_sg,
        negative=args.w2v_negative, 
        iter=args.w2v_iter, 
        max_vocab_size=args.w2v_max_vocab_size,
    )
    w2v.save(args.w2v_file)
    print(f'{args.w2v_file} saved')
    
vectorizer = gensim.models.Word2Vec.load(args.w2v_file)

Loading vectorizer


In [30]:
for word in ["he", "love", "looks", "buy", "laptop"]:
    if word in vectorizer.wv.vocab:
        print(word, [w for w, c in vectorizer.wv.similar_by_word(word=word)])
    else:
        print(word, "not in vocab")

he ['she', 'He', 'his', 'She', 'son', 'husband', 'him', 'dad', 'daughter', 'wife']
love ['LOVE', 'Love', 'loved', '"Love', 'enjoy', 'hate', 'loves', 'appreciate', 'enjoyed', 'like']
looks ['feels', 'Looks', 'look', 'looked', 'sleek', 'sounds', 'matches', 'finish', 'appearance', 'look.']
buy ['purchase', 'buying', 'purchasing', 'sell', 'ordering', 'buy,', 'invest', 'try', 'buy.', 'order']
laptop ['notebook', 'netbook', 'computer', 'laptop,', 'machine', 'laptop.', 'PC', 'desktop', 'tablet', 'pc']


In [31]:
wv_dim = vectorizer.vector_size
y = torch.zeros(args.batch_size, 1).to(args.device)

In [32]:
model = ABAE(
    wv_dim=wv_dim,
    asp_count=args.aspects_number,
    init_aspects_matrix=get_centroids(vectorizer, aspects_count=args.aspects_number)
)
model.to(args.device)

  init.kaiming_uniform(self.M.data)


ABAE(
  (attention): SelfAttention(
    wv_dim=200, maxlen=201
    (attention_softmax): Softmax(dim=None)
  )
  (linear_transform): Linear(in_features=200, out_features=40, bias=True)
  (softmax_aspects): Softmax(dim=None)
)

In [33]:
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters())

In [34]:
epoch_bar = tqdm(
    desc='training routine', 
    total=args.epochs,
    position=0
)

train_bar = tqdm(
    desc='train',
    total=get_num_batches(args.data_json, args.batch_size, args.maxlen), 
    position=1, 
    leave=True
)


for t in range(args.epochs):

    print("Epoch %d/%d" % (t + 1, args.epochs))

    data_iterator = read_data_tensors(
        args.data_json,
        batch_size=args.batch_size, 
        maxlen=args.maxlen,
        w2v_model=vectorizer,
    )

    for item_number, (x, texts) in enumerate(data_iterator):
        if x.shape[0] < args.batch_size:  # pad with 0 if smaller than batch size
            x = np.pad(x, ((0, args.batch_size - x.shape[0]), (0, 0), (0, 0)))

        x = torch.from_numpy(x).to(args.device)

        # extracting bad samples from the very same batch; not sure if this is OK, so todo
        negative_samples = torch.stack(
            tuple([x[torch.randperm(x.shape[0])[:args.neg_samples]] 
                   for _ in range(args.batch_size)])
        ).to(args.device)

        # prediction
        y_pred = model(x, negative_samples)

        # error computation
        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if item_number % 1000 == 0:

            print(item_number, "batches, and LR:", optimizer.param_groups[0]['lr'])

            for i, aspect in enumerate(model.get_aspect_words(vectorizer)):
                print(i + 1, " ".join([a for a in aspect]))

            print("Loss:", loss.item())
            print()

        train_bar.update()
    epoch_bar.update()

HBox(children=(FloatProgress(value=0.0, description='training routine', max=1.0, style=ProgressStyle(descripti…

HBox(children=(FloatProgress(value=0.0, description='train', max=33782.0, style=ProgressStyle(description_widt…

Epoch 1/1
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
0 batches, and LR: 0.001
1 25 40 80 20 30 50 200 45 100 15 60 1000 18 12 1.5
2 comments reviews reviews, reviewers questions reviews. posted review, written complained complaining comment reviewer agree experiences
3 Phillips Philips Miller", "Kenneth "Bruce "Bill Johnson", "Stephen "B007WTAJTO", "Dr. "Jerry M. "L. Smith", "Chris
4 sturdy, sturdy. sturdy bulky, durable. flimsy, lightweight, lightweight durable, compact, stylish comfortable, heavy, plastic, durable
5 GB 64 PCI Drive SATA 3.0 SDHC Card MB Desktop Vista XP Flash 2.0 Gigabit
6 "B003ES5ZUU", "B002WE6D44", "B003ELYQGG", mode, determine (you "B002V88HFE", annoying, suspect it.The issue, enables means often, causes
7 Stars", "Amazon "very "it "Five "works "You "There "Do

  if word in w2v_model and (vocabulary is None or word in vocabulary):
  results = self.attention_softmax(product_2)
  aspects_importances = self.softmax_aspects(raw_importances)
  return F.mse_loss(input, target, reduction=self.reduction)
  words_scores = w2v_model.wv.syn0.dot(aspects)


attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([50])
attention weights: torch.Size([50, 201])
text embeddings: torch.Size([50, 201, 200])
aspects importances: torch.Size([50, 40])
weighted_text_emb: torch.Size([50, 200])
recovered_emb: torch.Size([50, 200])
reconstruction_triplet_loss: torch.Size([5

KeyboardInterrupt: 