In [1]:
import numpy as np
import json
import tqdm
import torch


# Names of the dataset partitions; tape_reader is invoked once with each of these names below.
splits = ['train', 'valid', 'test']

def tape_reader(split):
    """Load one dataset partition as a list of sample records.

    Reads two files relative to the current working directory:

    * ``./data/tape/{split}.npz`` -- archive of per-sample embeddings, keyed
      ``'{split}-{idx:05d}'``.
    * ``./data/{split}.json`` -- a sequence of entries whose element 0 is
      stored under ``'seq'`` and element 1 under ``'labl'``.

    Args:
        split: Partition name used to build both file paths and the archive keys.

    Returns:
        A list with one ``{'seq', 'embed', 'labl'}`` dict per JSON entry, in
        JSON order.

    Raises:
        FileNotFoundError: If either file is missing.
        KeyError: If the archive lacks the key for some JSON index.
    """
    # NOTE(review): allow_pickle=True un-pickles objects stored in the archive,
    # which can execute arbitrary code -- only load archives from a trusted source.
    # Both handles are closed on exit; every array is materialised before that.
    with np.load(f'./data/tape/{split}.npz', allow_pickle=True) as archive:
        with open(f'./data/{split}.json', 'r') as handle:
            seq_labl = json.load(handle)

        return [
            {
                'seq': entry[0],
                'embed': archive[f'{split}-{idx:05d}'],
                'labl': entry[1],
            }
            for idx, entry in enumerate(seq_labl)
        ]

# Load every partition up front; the unpacking order follows `splits`
# (train, valid, test).
train, valid, test = (tape_reader(split_name) for split_name in splits)

In [3]:
# Helpers used by `predict`: Spearman correlation and L2 normalisation.
# `scipy.stats.stats` is a deprecated private namespace; the public
# `scipy.stats` module provides the same `spearmanr`, so `stats.spearmanr`
# keeps working for the rest of the notebook.
from scipy import stats
from torch.nn.functional import normalize

# Peek at one validation record: print its label and the label's Python type.
anchor = valid[0]
print(anchor['labl'], type(anchor['labl']))


3.8237006664276123 <class 'float'>


In [9]:
def predict(train, valid, test, use_valid=False, use_abs=False, embed_type='pooled',
            temperature=1.0):
    """Predict test labels as a similarity-weighted average of reference labels.

    Every reference embedding and every test embedding is L2-normalised; the
    cosine similarities between one test sample and all references are turned
    into weights with a softmax, and the prediction is the weighted sum of the
    reference labels.

    In relative mode (``use_abs=False``) the first training record acts as an
    anchor: its embedding is subtracted from every embedding and its label
    from every reference label before weighting, and the anchor label is added
    back to each prediction. In absolute mode the raw embeddings and labels
    are used.

    Args:
        train: Sequence of ``{'seq', 'embed', 'labl'}`` records. ``train[0]``
            is the anchor; ``train[1:]`` are references.
        valid: Records in the same format; ``valid[1:]`` are added to the
            references when ``use_valid`` is true.
        test: Records in the same format; ``test[1:]`` are evaluated.
        use_valid: Whether to include validation records as references.
        use_abs: Use raw embeddings/labels instead of anchor-relative ones.
        embed_type: ``'pooled'`` or ``'avg'`` -- key into each record's
            embedding dict (``sample['embed'].item()[embed_type]``).
        temperature: Positive divisor applied to the similarities before the
            softmax. The default ``1.0`` reproduces the un-scaled softmax;
            smaller values concentrate weight on the most similar references.

    Returns:
        ``(preds, labls, rho)``: float32 tensors of predictions and true
        labels for ``test[1:]``, and their Spearman correlation.

    Raises:
        AssertionError: If ``embed_type`` is not ``'pooled'`` or ``'avg'``.
        ValueError: If ``temperature`` is not positive.
    """
    assert embed_type in ['pooled', 'avg'], 'should choose pooled/avg'
    if temperature <= 0:
        raise ValueError('temperature must be positive')

    def _vector(sample):
        # `.item()` unwraps the 0-d object array holding the embedding dict;
        # flattening makes reference and test embeddings share one layout.
        return np.asarray(sample['embed'].item()[embed_type]).reshape(-1)

    anchor = train[0]
    anchor_embed = _vector(anchor)
    anchor_labl = anchor['labl']

    def _feature_matrix(samples):
        # One row per sample, anchor-relative unless use_abs is set.
        rows = [_vector(sample) for sample in samples]
        if not use_abs:
            rows = [row - anchor_embed for row in rows]
        if not rows:
            return torch.empty(0, anchor_embed.size)
        return torch.as_tensor(np.stack(rows), dtype=torch.float32)

    # NOTE(review): the anchor is taken from `train` only, yet the first
    # validation and first test record are skipped as well. Kept as-is so the
    # returned lengths and reported scores do not change -- confirm whether
    # dropping test[0] / valid[0] is intended.
    queries = test[1:]
    refs = (train[1:] + valid[1:]) if use_valid else train[1:]

    ref_labls = [sample['labl'] for sample in refs]
    if not use_abs:
        ref_labls = [labl - anchor_labl for labl in ref_labls]
    embed_val = torch.tensor(ref_labls, dtype=torch.float32)
    embed_mat = normalize(_feature_matrix(refs), p=2, dim=-1)

    labls = torch.tensor([sample['labl'] for sample in queries], dtype=torch.float32)
    query_mat = normalize(_feature_matrix(queries), p=2, dim=-1)

    # Score the queries in batches: one matmul per batch instead of one
    # mat-vec per sample, while bounding the (batch x n_refs) similarity
    # matrix held in memory at any time.
    chunk_size = 1024
    chunks = []
    for start in range(0, query_mat.shape[0], chunk_size):
        sims = query_mat[start:start + chunk_size] @ embed_mat.T
        weights = torch.softmax(sims / temperature, dim=-1)
        chunks.append(weights @ embed_val)
    preds = torch.cat(chunks) if chunks else torch.empty(0)

    if not use_abs:
        # Undo the label shift; add in float64 and round once to float32.
        preds = (preds.double() + anchor_labl).float()

    spr = stats.spearmanr(preds.numpy(), labls.numpy())
    return preds, labls, spr[0]

# Sweep every combination of reference set, label mode and embedding type,
# keeping each (predictions, labels, spearman) triple in `collect`.
settings = [
    (v, a, ebd)
    for v in (True, False)
    for a in (True, False)
    for ebd in ('pooled', 'avg')
]

collect = []
for v, a, ebd in settings:
    p, l, spr = predict(train, valid, test, use_valid=v, use_abs=a, embed_type=ebd)
    collect.append((p, l, spr))
    print(f'use_valid={v}, use_abs={a}, embed_type={ebd}, spr={spr}')

use_valid=True, use_abs=True, embed_type=pooled, spr=0.06398364741425792
use_valid=True, use_abs=True, embed_type=avg, spr=0.28939527949895416
use_valid=True, use_abs=False, embed_type=pooled, spr=0.39311592554980107
use_valid=True, use_abs=False, embed_type=avg, spr=0.44173396237417134
use_valid=False, use_abs=True, embed_type=pooled, spr=0.2376535288304412
use_valid=False, use_abs=True, embed_type=avg, spr=0.29011005996109673
use_valid=False, use_abs=False, embed_type=pooled, spr=0.3931455325007132
use_valid=False, use_abs=False, embed_type=avg, spr=0.4413894754063092
