# Import Libraries

In [None]:
import re

import numpy as np
import pandas as pd
import spacy
import torch
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Print the installed torchtext version.
print(torchtext.__version__)

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch
    (CPU and all CUDA devices), and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: ``random`` is not among the notebook's top-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every CUDA device, the current one included.
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Apply a fixed seed (42) so the cells below are repeatable when run in order.
seed_everything(42)

# Load Data

In [None]:
# Each non-empty line is expected to hold "<id>\t<label>\t<tweet>".
with open("data/semeval-tweets/twitter-training-data.txt", encoding="utf8") as f:
    # maxsplit=2 keeps a tweet that itself contains tabs in one field (an
    # unbounded split would yield extra columns and make DataFrame() fail);
    # blank lines are skipped so they do not turn into rows with a None tweet.
    rows = [line.strip().split("\t", 2) for line in f if line.strip()]

# Built after the file is closed; nothing below needs the open handle.
data = pd.DataFrame(rows, columns=["id", "label", "tweet"])

In [None]:
# Preview the first rows of the loaded DataFrame (columns: id, label, tweet).
data.head()

# Preprocess Data

In [None]:
def remove_user_mentions(tweet: str):
    """Return *tweet* with every ``@handle`` mention deleted.

    A mention is ``@`` followed by one or more ASCII letters, digits or
    underscores. Surrounding whitespace is left untouched, so removing a
    mention can leave a double space behind.
    """
    return re.sub("(@[a-zA-Z0-9_]+)", "", tweet)

In [None]:
def remove_tweet_hashtag(tweet: str):
    """Return *tweet* with every hashtag (``#`` plus the following word) deleted.

    NOTE(review): the tag text is dropped together with the ``#``; if the
    word itself should be kept, substitute the captured group instead.
    """
    # Raw string: "\w" inside a normal literal is an invalid escape sequence
    # and makes Python emit a warning when the cell is compiled.
    hashtag_pattern = re.compile(r"#(\w+)")

    return hashtag_pattern.sub("", tweet)

In [None]:
def remove_url(tweet: str):
    """Return *tweet* with every ``http://`` / ``https://`` URL deleted.

    Matching stops at the first character outside the pattern's allowed
    set (whitespace, for instance), so the text around a URL is kept.
    """
    return re.sub(
        "http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+",
        "",
        tweet,
    )

In [None]:
def remove_special_characters(tweet: str):
    """Return *tweet* keeping only ASCII letters, digits and whitespace."""
    # Raw string: "\s" inside a normal literal is an invalid escape sequence
    # and makes Python emit a warning when the cell is compiled.
    special_characters_pattern = re.compile(r"[^a-zA-Z0-9\s]")

    return special_characters_pattern.sub("", tweet)

In [None]:
def remove_digits(tweet: str):
    """Return *tweet* with standalone numbers deleted.

    Only digit runs bounded by word boundaries on both sides are removed;
    digits embedded in a word (``b4``) are kept.
    """
    # A similar pattern for single-character words would be: \b\w{1}\b
    return re.sub(r"\b\d+\b", "", tweet)

In [None]:
# torchtext tokenizer backed by spaCy's `en_core_web_sm` model.
# NOTE(review): `tokenizer` is not referenced by any cell visible here;
# preprocess_tweet tokenizes through `nlp` instead. Confirm it is still needed.
tokenizer = get_tokenizer(tokenizer="spacy", language="en_core_web_sm")
# Alternative kept for reference: torchtext's "basic_english" tokenizer.
# tokenizer = get_tokenizer(tokenizer='basic_english')

In [None]:
# Load spaCy's `en_core_web_sm` pipeline; it is handed to preprocess_tweet,
# which uses it to tokenize and to read the stop-word/punctuation/space flags.
nlp = spacy.load("en_core_web_sm")

In [None]:
# Show the concrete class of the loaded spaCy pipeline object.
type(nlp)

In [None]:
def preprocess_tweet(tweet: str, nlp) -> list[str]:
    """Clean a raw tweet and split it into content tokens.

    Steps, in order: remove URLs, @mentions, hashtags and standalone
    numbers; collapse whitespace runs to a single space; trim; lowercase;
    tokenize with ``nlp`` and drop stop words, punctuation and whitespace
    tokens.

    Args:
        tweet: Raw tweet text.
        nlp: Callable spaCy pipeline that turns text into a token sequence.

    Returns:
        Text of every token that survived the filtering.
    """
    # TODO: e-mail addresses are not handled by any of these cleaners.
    # remove_special_characters is currently left out of the chain.
    for scrub in (
        remove_url,
        remove_user_mentions,
        remove_tweet_hashtag,
        remove_digits,
    ):
        tweet = scrub(tweet)

    normalized = re.sub(r"\s+", " ", tweet).strip().lower()

    kept_tokens = []
    for token in nlp(normalized):
        if token.is_stop or token.is_punct or token.is_space:
            continue
        kept_tokens.append(token.text)

    return kept_tokens

In [None]:
# Spot-check the pipeline on 10 randomly sampled tweets:
# raw text, resulting tokens, then a divider line.
divider = "===" * 20
for tweet in data["tweet"].sample(n=10):
    print(tweet)
    print(preprocess_tweet(tweet, nlp))
    print(divider)

In [None]:
# Run the full preprocessing on every tweet: one token list per row, in row order.
preprocessed_tweets = [preprocess_tweet(tweet, nlp) for tweet in data["tweet"]]

In [None]:
# Peek at the token lists produced for the first two tweets.
preprocessed_tweets[:2]

# Build vocabulary

In [None]:
# Special tokens: <unk> for out-of-vocabulary words, <pad> for batch padding.
special_tokens = ["<unk>", "<pad>"]
# Build the vocabulary from the token lists, registering the special tokens too.
vocab = build_vocab_from_iterator(preprocessed_tweets, specials=special_tokens)

In [None]:
# Resolve the indices of the two special tokens in a single statement.
unk_index, pad_index = vocab["<unk>"], vocab["<pad>"]

In [None]:
# Make lookups of tokens missing from the vocabulary return the <unk> index.
vocab.set_default_index(unk_index)

In [None]:
# Sanity check: look up a handful of sample words.
# NOTE(review): words missing from the vocabulary (presumably including those
# filtered out as stop words during preprocessing) should come back as
# unk_index. Verify against the output.
vocab.lookup_indices(["hello", "world", "trump", "this", "is", "good"])

In [None]:
def convert_tweet_to_ids(tweet: list[str]) -> list[int]:
    """Map a tokenized tweet to its vocabulary indices.

    Args:
        tweet: Tokens of one tweet, e.g. an element of
            ``preprocessed_tweets`` (a list of strings, not raw text).

    Returns:
        One index per token, looked up in the module-level ``vocab``.
    """
    return vocab.lookup_indices(tweet)

In [None]:
# One list of vocabulary indices per preprocessed tweet, in the same order.
tweet_ids = [convert_tweet_to_ids(tokens) for tokens in preprocessed_tweets]

In [None]:
# Show the index sequence of the first tweet.
tweet_ids[0]

In [None]:
# Show the index sequence of the second tweet.
tweet_ids[1]

In [None]:
# Round-trip check: map the first tweet's indices back to their tokens.
vocab.lookup_tokens(tweet_ids[0])

# Dataloader

In [None]:
from torch.utils.data import DataLoader, Dataset

In [None]:
class TweetDataSet(Dataset):
    """Map-style dataset pairing each tweet's token ids with its label.

    Args:
        tweet_ids_list: Sequence with one list of token ids per tweet.
        labels_list: Iterable with one label per tweet, in the same order.

    Raises:
        ValueError: If the two inputs do not have the same length.
    """

    def __init__(self, tweet_ids_list, labels_list):
        self.tweet_ids_list = tweet_ids_list
        # Materialise the labels as a plain list so ``labels_list[idx]`` is
        # always positional. A pandas Series is indexed by *label*, which
        # returns the wrong row (or raises KeyError) once its index is not
        # 0..n-1, e.g. after filtering, shuffling or a train/test split,
        # and it rejects negative positions.
        self.labels_list = list(labels_list)

        if len(self.tweet_ids_list) != len(self.labels_list):
            raise ValueError(
                "tweet_ids_list and labels_list must have the same length, got "
                f"{len(self.tweet_ids_list)} and {len(self.labels_list)}"
            )

    def __len__(self):
        """Return the number of tweets."""
        return len(self.tweet_ids_list)

    def __getitem__(self, idx):
        """Return ``(token_ids, label)`` for position (or slice) ``idx``."""
        return self.tweet_ids_list[idx], self.labels_list[idx]

In [None]:
# Pair every tweet's id sequence with the label column of the DataFrame.
ds = TweetDataSet(tweet_ids_list=tweet_ids, labels_list=data["label"])

In [None]:
# Slice the dataset: __getitem__ forwards the slice to both underlying
# containers, so this shows the first three id lists and their labels.
ds[:3]

In [None]:
def collate_fn(batch, padding_value=None):
    """Collate ``(token_ids, label)`` samples into one padded batch.

    Args:
        batch: Sequence of ``(token_ids, label)`` pairs, the shape
            ``TweetDataSet.__getitem__`` returns for an integer index.
        padding_value: Id used to fill sequences shorter than the longest
            one in the batch. Defaults to the module-level ``pad_index``.

    Returns:
        A ``(padded_ids, labels)`` tuple. ``padded_ids`` is a ``torch.long``
        tensor of shape ``(len(batch), longest_sequence)``; ``labels`` is a
        list holding the labels unchanged, in batch order.
    """
    from torch.nn.utils.rnn import pad_sequence

    if padding_value is None:
        # Resolved at call time so the function can be defined before the
        # vocabulary cell has run.
        padding_value = pad_index

    if not batch:
        # pad_sequence rejects an empty list; return an empty batch instead.
        return torch.empty((0, 0), dtype=torch.long), []

    sequences = [torch.tensor(ids, dtype=torch.long) for ids, _ in batch]
    labels = [label for _, label in batch]
    padded_ids = pad_sequence(
        sequences, batch_first=True, padding_value=padding_value
    )
    return padded_ids, labels
    