# Read data

In [1]:
# Directory that holds the "*.train" corpus files read below.
# NOTE(review): absolute, machine-specific path — adjust it to your own setup.
data_path = "/home/maksym/da-corpora"

In [2]:
import pickle
from copy import deepcopy

import torch
import pandas as pd
import numpy as np

from transformers import AutoTokenizer, AutoModel

In [3]:
import glob, os

# Read every "*.train" file found in data_path into memory, one list of raw
# lines per file.
files = []
# NOTE: this changes the working directory for the rest of the notebook, so
# every relative path used later resolves inside data_path.
os.chdir(data_path)
for file in glob.glob("*.train"):
    print(file)
    # The context manager closes the handle as soon as the lines are read
    # instead of leaving it open until garbage collection.
    # NOTE(review): no explicit encoding is given, so the platform default is
    # used — consider encoding="utf-8" once the corpus encoding is confirmed.
    with open(file, 'r') as fh:
        files.append(fh.readlines())

cl-JRC-Acquis.en-et.docs.train
cl-OpenSubtitles.en-et.docs.train
cl-EMEA.en-et.docs.train
cl-Europarl.en-et.docs.train


In [4]:
# Reduce files to English sentences but keep doc ids.
# Every line is split on tabs: field 0 is taken as the doc id and field 1 as
# the sentence. `files` is rewritten in place (raw lines -> sentences), so this
# cell cannot be re-run without reloading the files first.

sent_index = []

for f in files:
    doc_ids = []
    for i, line in enumerate(f):
        fields = line.split('\t')  # split once and reuse both fields
        doc_ids.append(fields[0])
        f[i] = fields[1]
    # One list of doc ids per file, parallel to the sentences in `files`.
    sent_index.append(doc_ids)

In [5]:
# Flatten the per-file lists into two parallel flat lists: position k of
# `sentences` and position k of `sent_index` describe the same input line.
# `sent_index` is rebound from nested to flat here, so run this cell only once
# per load.
sentences = [sentence for per_file in files for sentence in per_file]
sent_index = [doc_id for per_file in sent_index for doc_id in per_file]

In [6]:
# Drop the nested per-file lists; their contents now live in the flat `sentences` list.
del files

In [7]:
# Sanity check: the two parallel lists must hold one doc id per sentence.
assert len(sentences) == len(sent_index)

In [8]:
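# Debug switch: uncomment to try the pipeline on a small subset.
# sent_index must be sliced the same way, otherwise the two lists no longer line up.
#sentences = sentences[0:100]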

# Embed sentences

In [9]:
# Load the pretrained 'xlm-roberta-base' encoder and the tokenizer that goes with it.
# Both must be loaded from the same checkpoint name.
model = AutoModel.from_pretrained('xlm-roberta-base')
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')

In [10]:
class Embedder:
    """Turn batches of sentences into fixed-size vectors with a transformer encoder.

    A sentence vector is the mean of one hidden layer's token states, taken
    over the real (non-padding) tokens only.
    """

    def __init__(self, model, tokeinzer):
        """Store the encoder and tokenizer and move the encoder to the compute device.

        Args:
            model: encoder that is called as ``model(input_ids,
                attention_mask=..., output_hidden_states=True, return_dict=True)``
                and returns a mapping with a ``'hidden_states'`` entry.
            tokeinzer: tokenizer callable paired with ``model``. The misspelled
                parameter name is kept so existing keyword callers keep working.
        """
        # Use the GPU when there is one; otherwise fall back to the CPU instead of failing.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # eval() switches dropout off so repeated calls give identical embeddings.
        self.model = model.to(self.device).eval()
        # Bind the argument that was passed in, not a same-named notebook global.
        self.tokenizer = tokeinzer

    def embed_batch(self, batch, layer=7, max_length=100):
        """Embed a batch of sentences.

        Args:
            batch: list of sentence strings.
            layer: index into the encoder's ``hidden_states`` to pool from.
            max_length: sentences are truncated to this many tokens.

        Returns:
            numpy array with one row per sentence (the pooled hidden state).
        """
        encoded = self.tokenizer(batch,
                                 return_tensors='pt',
                                 truncation=True,
                                 padding=True,
                                 max_length=max_length)
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)
        with torch.no_grad():
            # Passing the mask keeps padding positions from being attended to.
            res = self.model(input_ids,
                             attention_mask=attention_mask,
                             output_hidden_states=True,
                             return_dict=True)
        hidden = res['hidden_states'][layer]  # assumed (batch, tokens, dim) — pooled over axis 1
        # Average over real tokens only, so a sentence's vector does not depend
        # on how much padding the rest of the batch forced onto it.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
        pooled = (hidden * mask).sum(dim=1) / token_counts
        return pooled.cpu().numpy()


In [11]:
# Build the sentence embedder used below; constructing it also moves `model` to the compute device.
embedder = Embedder(model, tokenizer)

In [12]:
%%time

# Embed all sentences batch by batch, collecting one vector per sentence in
# the same order as `sentences`.
bs = 1000  # sentences handed to the embedder per call
chunks = [sentences[start:start + bs] for start in range(0, len(sentences), bs)]

sent_emb = []

for i, chunk in enumerate(chunks):
    print(f"{i} / {len(chunks)}")
    # extend() adds the rows of the returned batch one by one.
    sent_emb.extend(embedder.embed_batch(chunk))

0 / 1912
1 / 1912
2 / 1912
3 / 1912
4 / 1912
5 / 1912
6 / 1912
7 / 1912
8 / 1912
9 / 1912
10 / 1912
11 / 1912
12 / 1912
13 / 1912
14 / 1912
15 / 1912
16 / 1912
17 / 1912
18 / 1912
19 / 1912
20 / 1912
21 / 1912
22 / 1912
23 / 1912
24 / 1912
25 / 1912
26 / 1912
27 / 1912
28 / 1912
29 / 1912
30 / 1912
31 / 1912
32 / 1912
33 / 1912
34 / 1912
35 / 1912
36 / 1912
37 / 1912
38 / 1912
39 / 1912
40 / 1912
41 / 1912
42 / 1912
43 / 1912
44 / 1912
45 / 1912
46 / 1912
47 / 1912
48 / 1912
49 / 1912
50 / 1912
51 / 1912
52 / 1912
53 / 1912
54 / 1912
55 / 1912
56 / 1912
57 / 1912
58 / 1912
59 / 1912
60 / 1912
61 / 1912
62 / 1912
63 / 1912
64 / 1912
65 / 1912
66 / 1912
67 / 1912
68 / 1912
69 / 1912
70 / 1912
71 / 1912
72 / 1912
73 / 1912
74 / 1912
75 / 1912
76 / 1912
77 / 1912
78 / 1912
79 / 1912
80 / 1912
81 / 1912
82 / 1912
83 / 1912
84 / 1912
85 / 1912
86 / 1912
87 / 1912
88 / 1912
89 / 1912
90 / 1912
91 / 1912
92 / 1912
93 / 1912
94 / 1912
95 / 1912
96 / 1912
97 / 1912
98 / 1912
99 / 1912
100 / 1912

755 / 1912
756 / 1912
757 / 1912
758 / 1912
759 / 1912
760 / 1912
761 / 1912
762 / 1912
763 / 1912
764 / 1912
765 / 1912
766 / 1912
767 / 1912
768 / 1912
769 / 1912
770 / 1912
771 / 1912
772 / 1912
773 / 1912
774 / 1912
775 / 1912
776 / 1912
777 / 1912
778 / 1912
779 / 1912
780 / 1912
781 / 1912
782 / 1912
783 / 1912
784 / 1912
785 / 1912
786 / 1912
787 / 1912
788 / 1912
789 / 1912
790 / 1912
791 / 1912
792 / 1912
793 / 1912
794 / 1912
795 / 1912
796 / 1912
797 / 1912
798 / 1912
799 / 1912
800 / 1912
801 / 1912
802 / 1912
803 / 1912
804 / 1912
805 / 1912
806 / 1912
807 / 1912
808 / 1912
809 / 1912
810 / 1912
811 / 1912
812 / 1912
813 / 1912
814 / 1912
815 / 1912
816 / 1912
817 / 1912
818 / 1912
819 / 1912
820 / 1912
821 / 1912
822 / 1912
823 / 1912
824 / 1912
825 / 1912
826 / 1912
827 / 1912
828 / 1912
829 / 1912
830 / 1912
831 / 1912
832 / 1912
833 / 1912
834 / 1912
835 / 1912
836 / 1912
837 / 1912
838 / 1912
839 / 1912
840 / 1912
841 / 1912
842 / 1912
843 / 1912
844 / 1912
845 / 1912

1459 / 1912
1460 / 1912
1461 / 1912
1462 / 1912
1463 / 1912
1464 / 1912
1465 / 1912
1466 / 1912
1467 / 1912
1468 / 1912
1469 / 1912
1470 / 1912
1471 / 1912
1472 / 1912
1473 / 1912
1474 / 1912
1475 / 1912
1476 / 1912
1477 / 1912
1478 / 1912
1479 / 1912
1480 / 1912
1481 / 1912
1482 / 1912
1483 / 1912
1484 / 1912
1485 / 1912
1486 / 1912
1487 / 1912
1488 / 1912
1489 / 1912
1490 / 1912
1491 / 1912
1492 / 1912
1493 / 1912
1494 / 1912
1495 / 1912
1496 / 1912
1497 / 1912
1498 / 1912
1499 / 1912
1500 / 1912
1501 / 1912
1502 / 1912
1503 / 1912
1504 / 1912
1505 / 1912
1506 / 1912
1507 / 1912
1508 / 1912
1509 / 1912
1510 / 1912
1511 / 1912
1512 / 1912
1513 / 1912
1514 / 1912
1515 / 1912
1516 / 1912
1517 / 1912
1518 / 1912
1519 / 1912
1520 / 1912
1521 / 1912
1522 / 1912
1523 / 1912
1524 / 1912
1525 / 1912
1526 / 1912
1527 / 1912
1528 / 1912
1529 / 1912
1530 / 1912
1531 / 1912
1532 / 1912
1533 / 1912
1534 / 1912
1535 / 1912
1536 / 1912
1537 / 1912
1538 / 1912
1539 / 1912
1540 / 1912
1541 / 1912
1542

In [13]:
# Cache the sentence embeddings and their doc ids on disk. The relative paths
# resolve against the current working directory.
# `with` makes sure each file is flushed and closed as soon as it is written.
with open("sent_emb.pkl", 'wb') as fh:
    pickle.dump(sent_emb, fh)
with open("sent_index.pkl", 'wb') as fh:
    pickle.dump(sent_index, fh)

In [14]:
# Reload the cached sentence embeddings and doc ids, so the expensive embedding
# step can be skipped in a later session.
# NOTE: pickle.load can execute arbitrary code from the file — only load
# pickles that this notebook wrote itself.
with open("sent_emb.pkl", 'rb') as fh:
    sent_emb = pickle.load(fh)
with open("sent_index.pkl", 'rb') as fh:
    sent_index = pickle.load(fh)

# Embed docs

In [27]:
# Maps doc id -> list of the embeddings of that document's sentences (filled in below).
doc2embs = {}

In [28]:
# Give every doc id that occurs in sent_index its own fresh, empty list
# (any list already stored under that id is replaced).
doc2embs.update({doc_id: [] for doc_id in sent_index})

In [29]:
# Group the sentence embeddings by the document they belong to.
# sent_index[k] is the doc id of sent_emb[k], so the two lists must line up
# exactly; check this upfront rather than failing halfway through (or silently
# ignoring surplus embeddings) with doc2embs only partly filled.
# Re-running this cell without resetting doc2embs appends every embedding again.
assert len(sent_index) == len(sent_emb), (
    f"{len(sent_index)} doc ids but {len(sent_emb)} sentence embeddings"
)
for doc_id, emb in zip(sent_index, sent_emb):
    doc2embs[doc_id].append(emb)

In [38]:
# A document's embedding is the element-wise mean of its sentence embeddings.
doc2emb = {
    doc_id: np.mean(embs, axis=0)
    for doc_id, embs in doc2embs.items()
}

In [42]:
# Document embeddings as a list, in the dict's insertion order.
doc_emb = list(doc2emb.values())

In [45]:
# Doc ids in the dict's insertion order, i.e. the same order as list(doc2emb.values()).
doc_index = list(doc2emb.keys())

In [47]:
# Cache the document embeddings and their doc ids on disk. The relative paths
# resolve against the current working directory.
# `with` makes sure each file is flushed and closed as soon as it is written.
with open("doc_emb.pkl", 'wb') as fh:
    pickle.dump(doc_emb, fh)
with open("doc_index.pkl", 'wb') as fh:
    pickle.dump(doc_index, fh)

In [None]:
# Reload the cached document embeddings and doc ids.
# NOTE: pickle.load can execute arbitrary code from the file — only load
# pickles that this notebook wrote itself.
with open("doc_emb.pkl", 'rb') as fh:
    doc_emb = pickle.load(fh)
with open("doc_index.pkl", 'rb') as fh:
    doc_index = pickle.load(fh)