In [1]:
from __future__ import absolute_import, division, unicode_literals

import sys
import os
import io
import numpy as np
import logging

sys.path.insert(0, '..')
import senteval

In [2]:
import pytorch_lightning as pl
import torch

In [12]:
# Show the kernel's working directory; the relative paths used in this notebook
# ('..', '../') resolve against it. Plain Python instead of the IPython `pwd`
# automagic, so the cell does not depend on automagic being enabled.
os.getcwd()

'/mounts/Users/cisintern/jabbar/git/barlowtwins/pl_model/senteval'

In [13]:
# Set PATHs
# NOTE(review): PATH_TO_DATA is an absolute, machine-specific location; adjust
# it when running on another host.
PATH_TO_SENTEVAL = '../'
PATH_TO_DATA = '/mounts/Users/cisintern/jabbar/git/SentEval/data'
# Word-vector text file, stored below the data directory.
PATH_TO_VEC = PATH_TO_DATA + '/fasttext/glove.840B.300d.txt'
# PATH_TO_VEC = '../data/fasttext/crawl-300d-2M.vec'

In [14]:
# Create dictionary
def create_dictionary(sentences, threshold=0):
    """Build a frequency-ranked vocabulary from tokenised sentences.

    Args:
        sentences: iterable of sentences, each an iterable of hashable tokens.
        threshold: when greater than 0, tokens seen fewer than ``threshold``
            times are dropped before ranking.

    Returns:
        ``(id2word, word2id)`` where ``id2word`` lists the tokens by
        decreasing count and ``word2id`` maps each token to its index in
        that list. Tokens with equal counts keep first-seen order.
    """
    counts = {}
    for sentence in sentences:
        for token in sentence:
            if token in counts:
                counts[token] += 1
            else:
                counts[token] = 1

    if threshold > 0:
        counts = {tok: n for tok, n in counts.items() if n >= threshold}

    # The three marker tokens get huge pseudo-counts so that they are ranked
    # ahead of every real token, in the order <s>, </s>, <p>.
    counts.update([('<s>', 1e9 + 4), ('</s>', 1e9 + 3), ('<p>', 1e9 + 2)])

    # Stable descending sort: ties stay in insertion order.
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    id2word = [tok for tok, _ in ranked]
    word2id = {tok: idx for idx, tok in enumerate(id2word)}
    return id2word, word2id

In [15]:
# Get word vectors from vocabulary (glove, word2vec, fasttext ..)
def get_wordvec(path_to_vec, word2id):
    """Load the vectors of the words in ``word2id`` from a text embedding file.

    Every line of the file is read as ``<word> <v1> <v2> ...``: the text up to
    the first space is the word, the remainder is parsed as space-separated
    numbers.

    Args:
        path_to_vec: path of the UTF-8 encoded vector file.
        word2id: container of the wanted words (only membership is tested).

    Returns:
        dict mapping each wanted word found in the file to a numpy array.
        If a word occurs on several lines, the last one wins.
    """
    word_vec = {}

    with io.open(path_to_vec, 'r', encoding='utf-8') as f:
        # if word2vec or fasttext file : skip first line "next(f)"
        for line in f:
            word, sep, vec = line.rstrip('\n').partition(' ')
            if not sep:
                # Blank or single-token line (e.g. a trailing empty line):
                # there is no vector to parse, so skip it instead of raising.
                continue
            if word in word2id:
                word_vec[word] = np.fromstring(vec, sep=' ')

    # Single-line message: a backslash continuation inside the string literal
    # would embed the source indentation in the logged text.
    logging.info('Found %d words with word vectors, out of %d words',
                 len(word_vec), len(word2id))
    return word_vec

In [16]:
# SentEval prepare and batcher
def prepare(params, samples):
    """Attach the vocabulary and word vectors for ``samples`` to ``params``.

    Sets ``params.word2id`` (built from ``samples``), ``params.word_vec``
    (vectors read from ``PATH_TO_VEC`` for that vocabulary) and
    ``params.wvec_dim`` (fixed at 300). Returns ``None``.
    """
    vocabulary = create_dictionary(samples)
    params.word2id = vocabulary[1]  # keep only the word -> id mapping
    params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id)
    # NOTE(review): hard-coded dimension; presumably must match the vectors
    # stored in PATH_TO_VEC -- confirm when switching embedding files.
    params.wvec_dim = 300

In [17]:
def batcher(params, batch):
    """Embed each sentence of ``batch`` as the mean of its known word vectors.

    Args:
        params: object providing ``word_vec`` (word -> vector mapping) and
            ``wvec_dim`` (length of the fallback zero vector).
        batch: list of sentences, each a list of tokens.

    Returns:
        The per-sentence mean vectors stacked with ``np.vstack``, one row per
        sentence in input order.
    """
    rows = []
    for tokens in batch:
        # An empty sentence is treated as the single token '.'.
        if tokens == []:
            tokens = ['.']

        known = [params.word_vec[tok] for tok in tokens if tok in params.word_vec]
        if len(known) == 0:
            # No token has a vector: fall back to an all-zero embedding.
            known = [np.zeros(params.wvec_dim)]

        rows.append(np.mean(known, 0))

    return np.vstack(rows)

In [18]:
# Set params for SentEval
params_senteval = {
    'task_path': PATH_TO_DATA,
    'usepytorch': True,
    'kfold': 5,
    # Settings handed over under the 'classifier' key.
    'classifier': {
        'nhid': 0,
        'optim': 'rmsprop',
        'batch_size': 128,
        'tenacity': 3,
        'epoch_size': 2,
    },
}

# Set up logger
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s : %(message)s')

In [19]:
# Names of the tasks to evaluate, in the order they are handed to se.eval().
transfer_tasks = (
    ['STS12', 'STS13', 'STS14', 'STS15', 'STS16']
    + ['MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC']
    + ['SICKEntailment', 'SICKRelatedness', 'STSBenchmark']
    + ['Length', 'WordContent', 'Depth', 'TopConstituents',
       'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
       'OddManOut', 'CoordinationInversion']
)

# Evaluation engine, wired to the batcher and prepare callbacks defined above.
se = senteval.engine.SE(params_senteval, batcher, prepare)

In [20]:
# Run every task listed in transfer_tasks through the engine and print what it returns.
# NOTE(review): the output saved with this cell stops at a KeyboardInterrupt
# after the STS13 task had started, so it does not show a complete run.
results = se.eval(transfer_tasks)
print(results)

2021-08-19 14:17:13,458 : ***** Transfer task : STS12 *****


2021-08-19 14:17:19,623 : Found 7801 words with word vectors, out of         8130 words
2021-08-19 14:17:19,766 : MSRpar : pearson = 0.4250, spearman = 0.4514
2021-08-19 14:17:19,871 : MSRvid : pearson = 0.6621, spearman = 0.6750
2021-08-19 14:17:19,947 : SMTeuroparl : pearson = 0.4913, spearman = 0.5880
2021-08-19 14:17:20,055 : surprise.OnWN : pearson = 0.5703, spearman = 0.6106
2021-08-19 14:17:20,120 : surprise.SMTnews : pearson = 0.4627, spearman = 0.3388
2021-08-19 14:17:20,121 : ALL (weighted average) : Pearson = 0.5319,             Spearman = 0.5495
2021-08-19 14:17:20,121 : ALL (average) : Pearson = 0.5223,             Spearman = 0.5328

2021-08-19 14:17:20,122 : ***** Transfer task : STS13 (-SMT) *****




KeyboardInterrupt: 

In [17]:
# Inspect the engine class (signature and documentation). Uses the built-in
# help() instead of the IPython-only `?` suffix, which is not valid Python.
help(senteval.engine.SE)

Init signature: senteval.engine.SE(params, batcher, prepare=None)
Docstring:      <no docstring>
File:           ~/git/SentEval/senteval/engine.py
Type:           type
Subclasses:     


In [22]:
# Directory prefix used to build checkpoint paths; it ends with '/' so that
# run-relative paths can be appended by plain string concatenation.
# NOTE(review): absolute, machine-specific location.
models_dir = '/mounts/Users/cisintern/jabbar/git/barlowtwins/pl_model/lightning_tb_logs/'

In [25]:
# Path of the checkpoint file, formed by appending the run-relative path to models_dir.
ckpt = models_dir + (
    'bert_small_bs128_lr0001_ngpu2_20mil_nodiv_lambda0.0005_maxmeancls'
    '/checkpoints/epoch=0-step=62499.ckpt'
)

In [24]:
# NOTE(review): this binds the LightningModule class itself, not an instance,
# and nothing is loaded from `ckpt` here -- looks like an unfinished cell.
# Presumably the project's own LightningModule subclass should be restored
# from the checkpoint instead; confirm which class to use.
model = pl.LightningModule

['epoch=0-step=62499.ckpt']