# Introduction

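This notebook illustrates how to use `XLM-T` models to encode a dataset read from a text file (one tweet per line) into tweet embeddings. It then uses those embeddings to retrieve similar tweets via cosine similarity.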

# Installs and imports

In [None]:
!pip install --upgrade pip
!pip install sentencepiece
!pip install transformers

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import AutoModelForSequenceClassification

# Data

In [None]:
def preprocess(corpus):
  """Anonymise user mentions and links in a collection of tweets.

  Each text is split on single spaces. Every token that starts with '@' and
  is longer than one character becomes '@user'. Every token that starts with
  'http' becomes 'http'. A lone '@' is kept unchanged.

  Args:
    corpus: iterable of tweet strings.

  Returns:
    A new list of normalised strings, in the same order as `corpus`.
  """
  def _normalise_token(token):
    # Mentions first; a replaced '@user' can never match the 'http' rule below.
    if token.startswith('@') and len(token) > 1:
      token = '@user'
    if token.startswith('http'):
      token = 'http'
    return token

  return [" ".join(_normalise_token(tok) for tok in text.split(" ")) for text in corpus]

In [None]:
# Download the multilingual sentiment test split used as the example corpus (saved as ./test_text.txt).
!wget https://raw.githubusercontent.com/cardiffnlp/xlm-t/main/data/sentiment/all/test_text.txt

--2021-04-26 22:29:21--  https://raw.githubusercontent.com/cardiffnlp/xlm-t/main/data/sentiment/all/test_text.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 654172 (639K) [text/plain]
Saving to: ‘test_text.txt’


2021-04-26 22:29:21 (35.6 MB/s) - ‘test_text.txt’ saved [654172/654172]



In [None]:
dataset_path = './test_text.txt'
# One tweet per line. Read explicitly as UTF-8, because the corpus contains
# non-Latin scripts (e.g. Arabic) that the platform default codec may reject.
# The `with` block closes the file handle.
with open(dataset_path, encoding='utf-8') as f:
  dataset = f.read().split('\n')
# A trailing newline would otherwise yield an empty final "tweet".
if dataset and dataset[-1] == '':
  dataset.pop()

In [None]:
# The dataset holds 8 languages, apparently stored in consecutive blocks of
# 870 tweets each; show the first tweet of every block.
for example in range(0, 8 * 870, 870):
  print(dataset[example])

نوال الزغبي (الشاب خالد ليس عالمي) هههههههه أتفرجي على ها الفيديو يا مبتدئة http vía @user
Trying to have a conversation with my dad about vegetarianism is the most pointless infuriating thing ever #caveman 
Royal: le président n'aime pas les pauvres? "c'est n'importe quoi" http …
@user korrekt! Verstehe sowas nicht...
CONGRESS na ye party kabhi bani hoti na india ka partition hota nd na hi humari country itni khokhli hoti   @ 
@user @user Ma Ferrero? il compagno Ferrero? ma il suo partito esiste ancora? allora stiamo proprio frecati !!!
todos os meus favoritos na prova de eliminação #MasterChefBR
@user jajajaja dale, hacete la boluda vos jajaja igual a vos nunca se te puede tomar en serio te mando un abrazo desde Perú!


# Model

In [None]:
# Use the GPU when one is available (Colab: Runtime -> Change runtime type -> GPU).
# Otherwise fall back to CPU instead of crashing on `.to('cuda')`.
# Set CUDA = False manually to force CPU.
CUDA = torch.cuda.is_available()
BATCH_SIZE = 32
MODEL = "cardiffnlp/twitter-xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
config = AutoConfig.from_pretrained(MODEL)  # model hyper-parameters (e.g. hidden_size)
model = AutoModel.from_pretrained(MODEL)
if CUDA:
  model = model.to('cuda')
_ = model.eval()  # evaluation mode: disables dropout for deterministic embeddings

## Encode

In [None]:
def encode(text, cuda=True, pooling='max'):
  """Encode a batch of tweets into fixed-size sentence embeddings.

  Args:
    text: list of raw tweet strings; they are normalised with `preprocess`.
    cuda: move the tokenized batch to the 'cuda' device before the forward
      pass (the model must already be on that device).
    pooling: 'max' (default, the original behaviour) or 'mean' over the
      token vectors of each tweet.

  Returns:
    numpy array of shape (len(text), hidden_size).

  Raises:
    ValueError: if `pooling` is not 'max' or 'mean'.
  """
  text = preprocess(text)
  encoded_input = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
  if cuda:
    encoded_input = encoded_input.to('cuda')
  # Inference only: skip autograd bookkeeping to save memory and time.
  with torch.no_grad():
    output = model(**encoded_input)
  token_embeddings = output[0].cpu().numpy()  # (batch, seq_len, hidden_size)
  # Padded positions still get hidden states. Exclude them from pooling, or a
  # tweet's embedding would depend on the longest tweet in its batch.
  mask = encoded_input['attention_mask'].cpu().numpy().astype(bool)[:, :, None]
  if pooling == 'max':
    # Every row has at least the special tokens unmasked, so no -inf survives.
    return np.where(mask, token_embeddings, -np.inf).max(axis=1)
  if pooling == 'mean':
    return (token_embeddings * mask).sum(axis=1) / mask.sum(axis=1)
  raise ValueError(f"unknown pooling: {pooling!r}")

In [None]:
dl = DataLoader(dataset, batch_size=BATCH_SIZE)
# One row per line of the dataset. The width comes from the model config rather
# than a hard-coded 768, so other checkpoints (e.g. large models) also work.
all_embeddings = np.zeros([len(dataset), config.hidden_size])
for idx, batch in enumerate(dl):
  print('Batch ', idx+1, ' of ', len(dl))
  # `encode` already applies `preprocess`, so pass the raw batch.
  embeddings = encode(batch, cuda=CUDA)
  start = idx * BATCH_SIZE
  # Size the slice by the real batch length (the last batch may be shorter).
  all_embeddings[start:start + len(embeddings), :] = embeddings

## Cosine similarity between all embeddings, and retrieval of similar tweets

In [None]:
# L2-normalise every row so that a plain dot product gives cosine similarity.
norms = np.linalg.norm(all_embeddings, axis=1)
all_embeddings_unit = all_embeddings / norms.reshape(-1, 1)
# Square matrix: entry [i, j] is the cosine similarity of tweets i and j.
all_embeddings_sim = np.dot(all_embeddings_unit, all_embeddings_unit.T)

In [None]:
def get_most_sim(sim):
  """Return the indices of `sim` ordered from the highest value to the lowest.

  This is `np.argsort` (ascending) reversed. Equal values therefore appear in
  the reverse of argsort's order.
  """
  return np.argsort(sim)[::-1]

In [None]:
# Retrieve among the English tweets only. The query index lies inside that
# range, so the query itself is expected to rank first.
query = 1111
a, b = 870, 1740  # English tweets occupy dataset[870:1740]
tmp_data = dataset[a:b]
tmp_sim = all_embeddings_sim[a:b, query]  # similarity of each English tweet to the query
s = get_most_sim(tmp_sim)  # positions into tmp_data / tmp_sim, most similar first

In [None]:
# Show the tweet used as the retrieval query.
query_text = dataset[query]
print('QUERY: ', query_text)

QUERY:  This means they believe it to be a legitimate non-violent movement based on a concern for human rights in #Palestine. #queensu #ygk 


In [None]:
TOP_K = 10  # number of tweets to show at each end of the ranking


def print_ranked(indices, sims, texts, k=TOP_K):
  """Print the first `k` (similarity, text) pairs, visiting `indices` in order.

  Args:
    indices: ranking of positions into `sims` / `texts`.
    sims: similarity score per position.
    texts: tweet text per position.
    k: maximum number of rows to print.
  """
  # The old manual counter checked `< 0` after printing, so it showed 12 rows
  # instead of the intended 10; slicing states the limit directly.
  for i in indices[:k]:
    print(sims[i], texts[i])


print(' ----- Most similar ----- ')
print_ranked(s, tmp_sim, tmp_data)

print(' ----- Least similar ----- ')
print_ranked(s[::-1], tmp_sim, tmp_data)

 ----- Most similar ----- 
0.9999999999999998 This means they believe it to be a legitimate non-violent movement based on a concern for human rights in #Palestine. #queensu #ygk 
0.964109671884958 @user aint in support with Israel nor Palestine! Hope this fire is settled soon & there's no more massacre in #Palestine either... 
0.9612606761750646 Israel deems comatose Gaza man who needs treatment in West Bank  a security threat. #Palestine  via @user 
0.9593051201529168 #latestnews 4 #newmexico #politics + #nativeamerican + #Israel + #Palestine  -  Protesting Rise Of Alt-Right At... 
0.9588319060541266 UK Govt reject criticism on Libya saying its involvement saved lives-... wishing UK to enjoy post Gadafi Libya fate. #UK #libya 
0.9583803569594294 @user Megyn, Please interview Halderman from the Univ of Michigan re:discrepancy in the results in counties with e-voting machines. 
0.9579723960580191 Saakashvili is pushing his own agenda here.The Ukrainian economy is growing, although corru