### Imports

In [1]:
from sentence_transformers import SentenceTransformer
import torch, json, glob
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd
import numpy as np
import os
from pathlib import Path

In [3]:
import tqdm
from tqdm import tqdm

In [28]:
import pickle

In [4]:
import ijson

In [12]:
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')

### Model -> LaBSE

In [40]:
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
model = model.to(torch.float16)
model.eval() 

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

In [14]:
tweet_files = glob.glob('../Dataset/TwiBot-22/tweet_*.json')

In [19]:
# Data structures to accumulate
sum_embeds   = defaultdict(lambda: torch.zeros(768, device=device))
tweet_counts = defaultdict(int)


In [8]:
DATA_DIR = "../Dataset/TwiBot-22"

DATA_DIR = Path(DATA_DIR)
# make sure the data directory exists
if not DATA_DIR.exists():
    raise FileNotFoundError(f"Data directory {DATA_DIR} does not exist.")

def load_json_records(fname):
    """Load a JSON file of array- or line- delimited records."""
    path = DATA_DIR / fname
    with open(path, 'r', encoding='utf-8') as f:
        # if the file is a single large JSON array:
        try:
            data = json.load(f)
        except json.JSONDecodeError:
            # fallback: one JSON object per line
            f.seek(0)
            data = [json.loads(line) for line in f]
    return data

user_dicts = load_json_records('user.json')
users_df = pd.DataFrame(user_dicts)

ordered_uids = users_df['id'].astype(str).tolist()

In [42]:
# ─── Hyperparameters & Files ─────────────────────────────────────────────────
BATCH_SIZE           = 512
CHECKPOINT_INTERVAL  = 1_000_000        # save every 1M tweets
CHECKPOINT_FILE      = 'tweet_feats_checkpoint.pkl'
TWEET_GLOB_PATTERN   = '../Dataset/TwiBot-22/tweet_*.json'

### Sample

In [43]:
import time
sample_texts = []
for fn in glob.glob('../Dataset/TwiBot-22/tweet_*.json'):
    print(f"Loading {fn}")
    with open(tweet_files[0], 'r') as f:
        for tw in ijson.items(f, 'item'):
            text = tw.get('text','').strip()
            if text:
                sample_texts.append(text)
            if len(sample_texts) >= 1000:
                break
    break

# 2) Time the batch encode
t0 = time.time()
_  = model.encode(sample_texts,
                  convert_to_tensor=True,
                  batch_size=BATCH_SIZE,
                  show_progress_bar=False)
t1 = time.time()

throughput = len(sample_texts) / (t1 - t0)
print(f"Batch size: {BATCH_SIZE}")
print(device)
print(f"Encoded {len(sample_texts)} tweets in {t1-t0:.2f}s → {throughput:.1f} tweets/sec")

Loading ../Dataset/TwiBot-22/tweet_1.json
Batch size: 512
mps
Encoded 1000 tweets in 1.77s → 563.6 tweets/sec


### Main Code

In [34]:
if os.path.exists(CHECKPOINT_FILE):
    with open(CHECKPOINT_FILE, 'rb') as f:
        data = pickle.load(f)
        sum_embeds, tweet_counts, processed = (data['sum_embeds'],
                                               data['tweet_counts'],
                                               data['processed'])
        sum_embeds = sum_embeds.to(device)
    print(f"Resuming from checkpoint: {processed} tweets processed so far.")
else:
    sum_embeds   = defaultdict(lambda: torch.zeros(768, device=device))
    tweet_counts = defaultdict(int)
    processed    = 0

In [35]:
batch_uids, batch_texts = [], []
def flush_batch():
    """Encode batch_texts, accumulate into sum_embeds/tweet_counts,
       advance processed counter, and checkpoint if needed."""
    global processed
    if not batch_texts:
        return

    # 2a) Batch-encode
    embs = model.encode(
        batch_texts,
        convert_to_tensor=True,
        batch_size=BATCH_SIZE,
        show_progress_bar=False
    )
    # 2b) Accumulate
    for uid, emb in zip(batch_uids, embs):
        sum_embeds[uid]   += emb
        tweet_counts[uid] += 1

    # 2c) Update processed count & clear buffers
    processed += len(batch_texts)
    batch_uids.clear()
    batch_texts.clear()

    # 2d) Checkpoint?
    if processed % CHECKPOINT_INTERVAL < BATCH_SIZE:
        with open(CHECKPOINT_FILE, 'wb') as f:
            pickle.dump({
                'sum_embeds':   sum_embeds,
                'tweet_counts': tweet_counts,
                'processed':    processed
            }, f)
        print(f"Checkpoint saved at {processed} tweets.")


In [None]:
# ─── 3) Stream & process with tqdm ───────────────────────────────────────────
tweet_files = glob.glob(TWEET_GLOB_PATTERN)
for fn in tweet_files:
    with open(fn, 'r') as f:
        # ijson.items streams each JSON object in the top-level array
        for tw in tqdm(ijson.items(f, 'item'),
                       desc=f"Streaming {os.path.basename(fn)}",
                       leave=False):
            text = tw.get('text','').strip()
            if not text:
                continue

            batch_uids.append(tw['author_id'])
            batch_texts.append(text)

            if len(batch_texts) >= BATCH_SIZE:
                flush_batch()

Files:   0%|          | 0/9 [04:39<?, ?it/s]


KeyboardInterrupt: 

In [None]:
flush_batch()
with open(CHECKPOINT_FILE, 'wb') as f:
    pickle.dump({
        'sum_embeds':   sum_embeds,
        'tweet_counts': tweet_counts,
        'processed':    processed
    }, f)
print(f"✅ Done! Total tweets processed: {processed}")

In [None]:
user_tweet_feats = []
for uid in ordered_uids:
    cnt = tweet_counts.get(uid, 0)
    if cnt > 0:
        avg = sum_embeds[uid] / cnt
    else:
        avg = torch.zeros(768)     # no tweets → zero vector
    user_tweet_feats.append(avg)

# shape [num_users, 768]
tweets_tensor = torch.stack(user_tweet_feats, dim=0)

# sanity check
print("Tweet‐feature tensor size:", tweets_tensor.shape)
# should be (len(ordered_uids), 768)

Tweet‐feature tensor size: torch.Size([1000000, 768])


In [None]:
torch.save(tweets_tensor, 'tweets_tensor.pt')
