In [1]:
import torch
import torch.nn as nn
from transformers import XLMTokenizer, RobertaModel


# extract features from texts
class TextFeatureExtractor(nn.Module):
    """Turn raw text into pooled embeddings from a pretrained encoder.

    Bundles a tokenizer and an embedding model so that calling the module on
    text returns the model's ``pooler_output`` as a float32 tensor.
    """

    def __init__(
            self,
            tokenizer="allegro/herbert-klej-cased-tokenizer-v1",
            embed_model="allegro/herbert-klej-cased-v1"
    ):
        """Load the pretrained tokenizer and embedding model.

        Args:
            tokenizer: identifier passed to ``XLMTokenizer.from_pretrained``.
            embed_model: identifier passed to ``RobertaModel.from_pretrained``.
        """
        super().__init__()

        self.tokenizer = XLMTokenizer.from_pretrained(tokenizer)
        self.embed_model = RobertaModel.from_pretrained(embed_model, return_dict=True)

        # The extractor is only used for inference, so start in eval mode.
        self.eval()

    def forward(self, x):
        """Tokenize ``x`` and return the model's pooled output.

        Args:
            x: text input handed straight to the tokenizer
                (presumably a string or a list of strings -- confirm with callers).

        Returns:
            The ``'pooler_output'`` entry of the model output, cast to float32.
        """
        batch = self.tokenizer(x, return_tensors='pt', padding=True)

        # The tokenizer output must live on the same device as the model weights.
        target_device = next(self.parameters()).device
        on_device = {}
        for field, tensor in batch.items():
            on_device[field] = tensor.to(target_device)

        outputs = self.embed_model(**on_device)
        return outputs['pooler_output'].float()
    
    

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import json


# Read both parsed corpora. The encoding is given explicitly because open()
# otherwise falls back to the platform locale, which can mis-decode non-ASCII
# text (or raise UnicodeDecodeError) on systems whose default is not UTF-8.
with open("positive_parsed.json", encoding="utf-8") as fp:
    positive = json.load(fp)

with open("negative_parsed.json", encoding="utf-8") as fp:
    negative = json.load(fp)
    

In [3]:
from tqdm import tqdm


def encode_texts(encoder, texts, desc=None):
    """Run ``encoder`` over every value of ``texts`` and collect the results.

    Args:
        encoder: callable module that maps a text entry to a tensor.
        texts: mapping of key -> text entry; each value is passed to ``encoder``
            unchanged (presumably a string or list of strings -- confirm
            against the parsed JSON files).
        desc: optional label shown next to the progress bar.

    Returns:
        A dict with the same keys as ``texts`` whose values are NumPy arrays
        copied back to host memory.
    """
    # Inference only: disable autograd so no computation graph is kept.
    with torch.no_grad():
        return {
            key: encoder(text).cpu().numpy()
            for key, text in tqdm(texts.items(), desc=desc)
        }


# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = TextFeatureExtractor().to(device).eval()

positive_features = encode_texts(encoder, positive, desc="positive")
negative_features = encode_texts(encoder, negative, desc="negative")


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 258/258 [00:01<00:00, 133.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 583/583 [00:03<00:00, 178.23it/s]


In [4]:
import pickle

# Persist each feature dictionary to its own pickle file, positive first.
for out_path, feature_map in (
    ("positive_features.pickle", positive_features),
    ("negative_features.pickle", negative_features),
):
    with open(out_path, "wb") as sink:
        pickle.dump(feature_map, sink)
    