In [3]:
import os, sys
# Expose only GPU 1 to this process. CUDA reads this variable when it initialises,
# so it is set here, before torch is imported further down the notebook.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})
# Make project modules under ./src importable.
sys.path += ["src"]

# Change data

In [None]:
import re
import pandas as pd
import torch
from torch.utils.data import Dataset


# Token id inserted for each sentiment label. Presumably the ids of " positive",
# " negative" and " neutral" in the RoBERTa BPE vocabulary -- confirm against the
# tokenizer actually used.
_SENTIMENT_TOKEN_ID = {
    'positive': 1313,
    'negative': 2430,
    'neutral': 7974,
}


def process_data(tweet, selected_text, sentiment, tokenizer, max_len):
    """
    Build model inputs and span targets for a single tweet.

    Note: there are some differences between this and the BERT version
    (bert-base-uncased), mostly due to differences in token ids and special tokens.

    Args:
        tweet: raw tweet text; coerced with ``str`` and whitespace-normalised.
        selected_text: the sub-span of ``tweet`` to predict; normalised the same way.
        sentiment: one of 'positive', 'negative', 'neutral'.
        tokenizer: object whose ``encode(text)`` result exposes ``.ids`` and
            ``.offsets`` (character ``(start, end)`` pairs into ``text``).
        max_len: length sequences are right-padded to. Longer sequences are
            returned as-is (neither padded nor truncated), which would break
            fixed-size batching.

    Returns:
        dict with
          - 'ids', 'mask', 'token_type_ids', 'offsets', 'new_words': per-position
            lists laid out as ``<s> sentiment </s> </s> tweet </s>`` then padding;
          - 'targets_start', 'targets_end': inclusive indices into 'ids' of the
            first/last token overlapping the selected span;
          - 'orig_tweet', 'orig_selected': the normalised, space-prefixed strings;
          - 'sentiment': the input label.

    Raises:
        ValueError: if ``selected_text`` is empty or does not occur in ``tweet``,
            so no target span can be built.
        KeyError: if ``sentiment`` is not one of the three supported labels.
    """
    # Collapse whitespace and prefix a space so the first word looks like every
    # other word (presumably to match a byte-level BPE's space-prefixed words).
    tweet = " " + " ".join(str(tweet).split())
    selected_text = " " + " ".join(str(selected_text).split())

    # Character-level mask of the selected span. selected_text carries an extra
    # leading space, so match tweet[i:i+len_st] with " " prepended.
    len_st = len(selected_text) - 1
    char_targets = [0] * len(tweet)
    if len_st > 0:  # empty selection: nothing to match, handled by the check below
        first_char = selected_text[1]
        for ind, ch in enumerate(tweet):
            if ch == first_char and " " + tweet[ind:ind + len_st] == selected_text:
                char_targets[ind:ind + len_st] = [1] * len_st
                break

    encoded = tokenizer.encode(tweet)
    input_ids_orig = list(encoded.ids)
    tweet_offsets = list(encoded.offsets)
    # 1 where a token starts a new whitespace-delimited word, else 0.
    new_words = [
        1 if start == 0 or tweet[start - 1] == " " else 0
        for start, _ in tweet_offsets
    ]

    # Tokens that overlap at least one selected character.
    target_idx = [
        j for j, (start, end) in enumerate(tweet_offsets)
        if any(char_targets[start:end])
    ]
    if not target_idx:
        raise ValueError(
            f"selected_text {selected_text[1:]!r} not found in tweet {tweet[1:]!r}"
        )

    # Four prefix positions: ids 0 and 2 are presumably RoBERTa's <s> and </s>.
    prefix_len = 4
    input_ids = [0, _SENTIMENT_TOKEN_ID[sentiment], 2, 2] + input_ids_orig + [2]
    token_type_ids = [0] * len(input_ids)
    mask = [1] * len(input_ids)
    tweet_offsets = [(0, 0)] * prefix_len + tweet_offsets + [(0, 0)]
    new_words = [0] * prefix_len + new_words + [0]
    targets_start = target_idx[0] + prefix_len
    targets_end = target_idx[-1] + prefix_len

    padding_length = max_len - len(input_ids)
    if padding_length > 0:
        # Pad id 1 is presumably RoBERTa's <pad>.
        input_ids = input_ids + [1] * padding_length
        mask = mask + [0] * padding_length
        token_type_ids = token_type_ids + [0] * padding_length
        tweet_offsets = tweet_offsets + [(0, 0)] * padding_length
        new_words = new_words + [0] * padding_length

    return {
        'ids': input_ids,
        'mask': mask,
        'token_type_ids': token_type_ids,
        'targets_start': targets_start,
        'targets_end': targets_end,
        'orig_tweet': tweet,
        'orig_selected': selected_text,
        'sentiment': sentiment,
        'offsets': tweet_offsets,
        'new_words': new_words
    }


class TweetDataset(Dataset):
    """
    Map-style dataset over the rows of a fold-annotated CSV.

    The CSV must provide 'text', 'selected_text', 'sentiment' and 'fold' columns.
    Only rows whose 'fold' is in ``folds`` are kept, and each item is encoded on
    the fly with ``process_data``.
    """

    def __init__(self, df_path, folds, tokenizer,
                 max_len=192, max_num_samples=None):
        """
        Args:
            df_path: path (or anything ``pd.read_csv`` accepts) to the CSV.
            folds: a single fold id, or a list-like of fold ids, to select.
            tokenizer: passed through to ``process_data``.
            max_len: padded sequence length passed to ``process_data``.
            max_num_samples: if given, keep only the first this many selected rows.
        """
        # Wrap any scalar so ``isin`` receives a list-like. A bare
        # ``isinstance(folds, int)`` check would miss numpy integer scalars.
        if not pd.api.types.is_list_like(folds):
            folds = [folds]
        df = pd.read_csv(df_path)
        df = df[df['fold'].isin(folds)]
        if max_num_samples is not None:
            df = df.iloc[:max_num_samples]
        self.tweet = df['text'].to_numpy()
        self.sentiment = df['sentiment'].to_numpy()
        self.selected_text = df['selected_text'].to_numpy()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Number of selected rows."""
        return len(self.tweet)

    def __getitem__(self, item):
        """
        Encode row ``item`` and return model-ready tensors plus the raw strings.

        'start_positions'/'end_positions' are the inclusive token indices of the
        selected span. ``process_data`` also computes 'new_words', which is not
        forwarded here.
        """
        data = process_data(
            self.tweet[item],
            self.selected_text[item],
            self.sentiment[item],
            self.tokenizer,
            self.max_len
        )

        return {
            'ids': torch.tensor(data['ids'], dtype=torch.long),
            'mask': torch.tensor(data['mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(data['token_type_ids'], dtype=torch.long),
            'start_positions': torch.tensor(data['targets_start'], dtype=torch.long),
            'end_positions': torch.tensor(data['targets_end'], dtype=torch.long),
            'orig_tweet': data['orig_tweet'],
            'orig_selected': data['orig_selected'],
            'sentiment': data['sentiment'],
            'offsets': torch.tensor(data['offsets'], dtype=torch.long)
        }