In [4]:
import pandas as pd
import json
import os

In [5]:
datapath = "data/train/"

def get_tweets():
    return pd.read_csv(os.path.join(datapath, "tweets-train-dev.tsv"), sep="\t", names=["id","tweet"])

def get_qrels():
    conn_names = ["tweet_id", "tweet_num", "claim_id", "claim_num"]
    train_conns = pd.read_csv(os.path.join(datapath, "qrels-train.tsv"), sep="\t", names=conn_names)
    dev_conns = pd.read_csv(os.path.join(datapath, "qrels-dev.tsv"), sep="\t", names=conn_names)
    return train_conns, dev_conns

def get_claims():
    # claimpaths = [os.path.join(datapath, "vclaims", f"{claim_id}.json") for claim_id in claim_ids]
    claimpath = os.path.join(datapath, "vclaims")
    claimpaths = [os.path.join(claimpath, f) for f in os.listdir(claimpath)]
    def load_claim(path):
        with open(path) as f:
            return json.load(f)
    
    claims = [load_claim(path) for path in claimpaths]
    return pd.DataFrame(claims, columns=["title","subtitle","author","date","vclaim_id","vclaim"])

In [6]:
tweets = get_tweets()
train_conns, dev_conns = get_qrels()
claims = get_claims()#[*train_conns.claim_id.to_list(), *dev_conns.claim_id.to_list()])

In [77]:
len(claims), len(train_conns), len(dev_conns)

(13825, 999, 200)

In [76]:
[c for c in claims.vclaim if 'wombat' in c.lower()]

['Wombats are herding animals and inviting them into their burrows in order to escape the wildfires in Australia.',
 'In January 2021, scientists discovered that the exceptionally rectangular droppings of wombats are the result of an uniquely evolved gastrointestinal system and not a cube-shaped anus as was previously proposed.']

In [68]:
tweets.head()

Unnamed: 0,id,tweet
0,tweet-sno-0,How are butterflies surviving the #AustralianF...
1,tweet-sno-1,Trump needs to immediately divest from his bus...
2,tweet-sno-2,A number of fraudulent text messages informing...
3,tweet-sno-3,Fact check: The U.S. Army is NOT contacting an...
4,tweet-sno-4,The US drone attack on #Soleimani caught on ca...


In [58]:
print(tweets[tweets.id == train_conns.tweet_id[1]].tweet.to_numpy()[0])
print('----')
print(claims[claims.vclaim_id == train_conns.claim_id[1]].vclaim.to_numpy()[0])

Trump needs to immediately divest from his businesses and comply with the emoluments clause. Iran could threaten Trump hotels *worldwide* and he could provoke war over the loss of revenue from skittish guests.  His business interests should not be driving military decisions. — Ilhan Omar (@IlhanMN) January 6, 2020
----
In January 2020, U.S. Rep. Ilhan Omar advised Iran to attack Trump-branded hotels in the world, thus committing treason.


In [64]:
dev_conns.claim_num.min()

1

In [78]:
from sentence_transformers import SentenceTransformer

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model = SentenceTransformer("sentence-transformers/sentence-t5-large")

Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.02k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.39k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/461 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/670M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.92k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.15M [00:00<?, ?B/s]

In [79]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: T5EncoderModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Dense({'in_features': 1024, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Normalize()
)

In [107]:
sentences = ["what is 2+2?", "2"]
embeddings = model.encode(sentences)
print(embeddings)

[[-0.01550025  0.01297178  0.00584198 ... -0.02914877  0.01487338
   0.06720603]
 [-0.00851915 -0.0226616   0.0019202  ... -0.03072338  0.0023656
   0.03755347]]


In [108]:
embeddings.shape

(2, 768)

In [109]:
np.dot(embeddings[0], embeddings[1])

0.73845553

In [123]:
np.array([model.tokenizer.encode(sentences[0])]).astype(int).shape

(1, 7)

In [152]:
import torch

device = torch.device("cuda")

inpt = dict(
    input_ids = torch.Tensor(np.array(model.tokenizer.encode(sentences[0]))).view(1,-1).int().to(device)
)
inpt["attention_mask"] = torch.Tensor(np.ones(inpt["input_ids"].shape)).view(1,-1).int().to(device)

In [153]:
for k, v in inpt.items():
    print(k, v.shape)

input_ids torch.Size([1, 7])
attention_mask torch.Size([1, 7])


In [154]:
model(inpt)

{'input_ids': tensor([[ 125,   19,  204, 1220,  357,   58,    1]], device='cuda:0',
        dtype=torch.int32),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]], device='cuda:0', dtype=torch.int32),
 'token_embeddings': tensor([[[-0.0928,  0.0343, -0.4449,  ..., -0.0307, -0.0905,  0.0133],
          [-0.0167, -0.0861, -0.2169,  ..., -0.0458, -0.0277, -0.1231],
          [-0.0129, -0.1005,  0.3698,  ..., -0.0945,  0.0824,  0.0355],
          ...,
          [ 0.0079,  0.1546,  0.2256,  ..., -0.2101,  0.0292,  0.0914],
          [-0.0702,  0.1371, -0.2202,  ..., -0.0265, -0.2608,  0.0448],
          [-0.0512,  0.1165,  0.0685,  ...,  0.0782,  0.0275, -0.0778]]],
        device='cuda:0', grad_fn=<MulBackward0>),
 'sentence_embedding': tensor([[-1.5500e-02,  1.2972e-02,  5.8420e-03,  8.6660e-03,  1.8796e-02,
           3.4288e-02,  2.9658e-02,  4.9993e-02, -2.0949e-03,  1.0967e-02,
           2.8683e-02,  1.5631e-02,  3.6811e-02,  6.2192e-02, -4.3116e-02,
          -2.7372e-02, -4.2383e-02

In [122]:
model.forward?

[0;31mSignature:[0m [0mmodel[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0minput[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Defines the computation performed at every call.

Should be overridden by all subclasses.

.. note::
    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.
[0;31mFile:[0m      ~/anaconda3/envs/torch/lib/python3.7/site-packages/torch/nn/modules/container.py
[0;31mType:[0m      method


In [132]:
model.tokenize(["try 1", "try 2 try 3 try 4"])

{'input_ids': tensor([[653, 209,   1,   0,   0,   0,   0],
         [653, 204, 653, 220, 653, 314,   1]]),
 'attention_mask': tensor([[1, 1, 1, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1]])}

In [157]:
embs = model.encode(claims.vclaim.to_list())

In [158]:
embs.shape

(13825, 768)