In [1]:
import os
import torch
import pandas as pd
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.optim import lr_scheduler

import transformers
import tokenizers
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from tqdm.autonotebook import tqdm
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(42)
import re
from torchcontrib.optim import SWA

# Utils

In [4]:
# --- Sequence length and batching -------------------------------------------
MAX_LEN = 192            # padded token length used by TweetDataset / process_data
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 16
EPOCHS = 5

# --- Pretrained checkpoint and tokenizer ------------------------------------
# Local directory holding the RoBERTa vocab/merges files, config and weights.
ROBERTA_PATH = "./roberta-base-squad2/"

# Byte-level BPE tokenizer built from the checkpoint's own vocab/merges files.
# Text is lower-cased by the tokenizer (lowercase=True); process_data also
# prepends a space to every tweet before encoding.
TOKENIZER = tokenizers.ByteLevelBPETokenizer(
    vocab_file=f"{ROBERTA_PATH}/vocab.json",
    merges_file=f"{ROBERTA_PATH}/merges.txt",
    lowercase=True,
    add_prefix_space=True,
)

In [2]:
class AverageMeter:
    """Track the latest value and the running, count-weighted mean of a metric.

    ``update(val, n)`` treats ``val`` as the mean over ``n`` samples, so ``avg``
    is the sample-weighted average of everything recorded since ``reset``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def jaccard(str1, str2):
    """Word-level Jaccard similarity between two strings.

    Both strings are lower-cased and split on whitespace; the score is
    |A ∩ B| / |A ∪ B| over the resulting word sets.

    When neither string contains any word the ratio is undefined (the old
    code raised ZeroDivisionError); 0.5 is returned instead, the value used by
    the Kaggle Tweet Sentiment Extraction evaluation snippet — change here if
    a different convention is wanted.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    if not a and not b:
        return 0.5
    c = a & b
    return float(len(c)) / (len(a) + len(b) - len(c))

# Preprocessing & Postprocessing

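1. Replace mis-encoded characters (e.g. `ï¿½` standing in for an apostrophe) with the intended characters before encoding, and restore them in the predictions afterwards.
2. Separate repeated punctuation such as `...`, `!!!` or `???` into individual marks (e.g. `! ! !`) before encoding, and re-join them in the predictions afterwards.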

In [3]:
def preprocess_im(df):
    """Repair mis-encoded "I'm"/"I'd" contractions in ``text`` and ``selected_text``.

    For each garbled form (``ï¿½`` where the apostrophe should be), rows whose
    ``text`` contains it have the form rewritten in ``selected_text`` and then
    in ``text``. Matching is case-sensitive. ``df`` is modified in place and
    also returned.
    """
    replacements = (
        ('iï¿½m', "i'm"),
        ('Iï¿½m', "I'm"),
        ('Iï¿½M', "I'M"),
        ('Iï¿½d', "I'd"),
        ('Iï¿½D', "I'D"),
    )

    for garbled, fixed in replacements:
        # Row selection depends only on `text`, which is rewritten last, so one
        # mask serves both columns.
        hit = df['text'].str.contains(garbled)
        for column in ('selected_text', 'text'):
            df.loc[hit, column] = df.loc[hit, column].apply(
                lambda s, pat=garbled, rep=fixed: re.sub(pat, rep, s)
            )

    return df

def preprocess_all(df):
    """Lower-case ``text``/``selected_text`` and repair mis-encoded apostrophes.

    Steps (``df`` is modified in place and returned):
    1. Lower-case both columns.
    2. For each garbled regex pattern below, rows whose ``text`` matches get
       the pattern replaced in ``selected_text`` and then in ``text``.
    3. Any remaining ``ï``, ``¿`` or ``½`` characters are deleted, each column
       using its own row selection.
    """
    df.loc[:, 'text'] = df.loc[:, 'text'].apply(lambda x: x.lower())
    df.loc[:, 'selected_text'] = df.loc[:, 'selected_text'].apply(lambda x: x.lower())

    # Regex patterns -> replacement, applied in this order. The last pattern is
    # a raw string: in a normal string "\*" is an invalid escape sequence that
    # Python only keeps by accident (and warns about).
    proc_dict = {
        'ï¿½s': "'s",
        'nï¿½t': "n't",
        'ï¿½ve': "'ve",
        'ï¿½ll': "'ll",
        'ï¿½re': "'re",
        'inï¿½': "ing",
        r"n`\*\*\*\*": "n't",
    }
    for key, item in proc_dict.items():
        # Select rows on `text` before it is rewritten; `selected_text` first.
        hit = df['text'].str.contains(key)
        for column in ('selected_text', 'text'):
            df.loc[hit, column] = df.loc[hit, column].apply(
                lambda x, pat=key, rep=item: re.sub(pat, rep, x)
            )

    # Strip leftover mojibake characters (a character class avoids the pandas
    # "match groups" warning the old "((ï|¿|½))" pattern triggered).
    leftover = '[ï¿½]'
    for column in ('selected_text', 'text'):
        hit = df[column].str.contains(leftover)
        df.loc[hit, column] = df.loc[hit, column].apply(lambda x: re.sub(leftover, '', x))

    return df

def preprocess_repeat(df):
    """Split runs of repeated '.', '!' or '?' into space-separated marks.

    Every mark that directly follows the same mark gets a space inserted
    before it, e.g. ``"wow!!!"`` -> ``"wow! ! !"``. Only rows whose ``text``
    contains a doubled mark are touched, and ``selected_text`` is rewritten
    before ``text`` so both use the same row selection. ``df`` is modified in
    place and returned.
    """
    for mark in ('.', '!', '?'):
        esc = re.escape(mark)
        # A mark immediately preceded by the same mark. The old trailing
        # `(?<!\w)` always held (the char just consumed is the mark itself),
        # so it is dropped; raw f-strings avoid invalid-escape warnings.
        split_pattern = rf'(?<={esc})({esc})'
        hit = df['text'].str.contains(rf'(?<={esc}){esc}')
        for column in ('selected_text', 'text'):
            df.loc[hit, column] = df.loc[hit, column].apply(
                lambda x, pat=split_pattern: re.sub(pat, r' \1', x)
            )
    return df

def postprocess_repeat(df):
    """Re-join '.', '!' and '?' runs that ``preprocess_repeat`` split apart.

    For rows whose raw ``text_x`` contains a doubled mark, each whitespace
    character in ``pred`` that follows that mark and is itself followed by the
    same mark or more whitespace is removed (``"wow! ! !"`` -> ``"wow!!!"``).
    ``df`` is modified in place and returned.
    """
    for mark in ('.', '!', '?'):
        esc = re.escape(mark)
        # Raw f-strings: the old non-raw patterns relied on invalid escapes.
        join_pattern = rf'(?<={esc})(\s)(?=({esc}|\s))'
        hit = df['text_x'].str.contains(rf'(?<={esc}){esc}')
        df.loc[hit, 'pred'] = df.loc[hit, 'pred'].apply(
            lambda x, pat=join_pattern: re.sub(pat, '', x)
        )
    return df


def postprocess_im(df):
    """Put the garbled "I'm"/"I'd" forms back into predictions.

    Inverse of ``preprocess_im`` for the ``pred`` column: for each garbled form
    present in the raw ``text_x`` (case-sensitive match), every lower-cased
    clean spelling (e.g. ``"i'm"``) in ``pred`` is replaced by that garbled
    form. ``df`` is modified in place and returned.
    """
    garbled_forms = {
        'iï¿½m': "i'm",
        'Iï¿½m': "I'm",
        'Iï¿½M': "I'M",
        'Iï¿½d': "I'd",
        'Iï¿½D': "I'D",
    }

    for garbled, clean in garbled_forms.items():
        rows = df['text_x'].str.contains(garbled)
        clean_lower = clean.lower()  # predictions are lower-cased text
        df.loc[rows, 'pred'] = df.loc[rows, 'pred'].apply(
            lambda p, pat=clean_lower, rep=garbled: re.sub(pat, rep, p)
        )

    return df

def postprocess_all(df):
    """Lower-case ``text_x``/``pred`` and put garbled apostrophe forms back.

    Inverse of ``preprocess_all`` for predictions: for each garbled form found
    in the (lower-cased) raw ``text_x``, every occurrence of its clean spelling
    in ``pred`` is replaced by the garbled form. This is a heuristic — *all*
    occurrences are replaced, including ones spelled correctly in the original
    tweet (e.g. every "ing" when the tweet contains "inï¿½").
    ``df`` is modified in place and returned.
    """
    df.loc[:, 'text_x'] = df.loc[:, 'text_x'].apply(lambda x: x.lower())
    df.loc[:, 'pred'] = df.loc[:, 'pred'].apply(lambda x: x.lower())

    # (garbled literal, clean spelling), applied in this order. Both sides are
    # plain text, so literal matching/replacement is used. The old code passed
    # the *regex* "n`\*\*\*\*" as the re.sub replacement, which leaked
    # backslashes into predictions ("n`\*\*\*\*"); the literal gives "n`****".
    restore_pairs = (
        ('ï¿½s', "'s"),
        ('nï¿½t', "n't"),
        ('ï¿½ve', "'ve"),
        ('ï¿½ll', "'ll"),
        ('ï¿½re', "'re"),
        ('inï¿½', "ing"),
        ('n`****', "n't"),
    )
    for garbled, clean in restore_pairs:
        rows = df['text_x'].str.contains(garbled, regex=False)
        df.loc[rows, 'pred'] = df.loc[rows, 'pred'].apply(
            lambda p, old=clean, new=garbled: p.replace(old, new)
        )

    return df

# Data Conversion

In [5]:
def process_data(tweet, selected_text, sentiment, tokenizer, max_len):
    """Encode one (tweet, selected_text, sentiment) example for span extraction.

    ``ids`` layout: ``[0, SENTIMENT_ID, 2, 2] + tweet ids + [2]``, right-padded
    with id ``1`` up to ``max_len`` (presumably RoBERTa's <s>, </s> and <pad>
    ids — confirm against the vocab). The targets are the positions of the
    first and last tweet tokens that overlap ``selected_text``.

    Args:
        tweet: raw tweet; whitespace is normalised and a leading space added.
        selected_text: span to locate inside ``tweet`` (same normalisation).
        sentiment: ``'positive'``, ``'negative'`` or ``'neutral'``.
        tokenizer: object whose ``encode(str)`` result has ``.ids`` and
            ``.offsets`` (expected to be char offsets into the encoded string).
        max_len: length to pad to. Longer inputs are NOT truncated.

    Returns:
        dict with ``ids``, ``mask``, ``token_type_ids``, ``targets_start``,
        ``targets_end``, ``orig_tweet``, ``orig_selected``, ``sentiment`` and
        ``offsets`` (one ``(start, end)`` pair per position, ``(0, 0)`` for the
        prefix, the final separator and padding).

    Raises:
        ValueError: if ``selected_text`` is empty, is not found in ``tweet``,
            or overlaps no kept token. (These cases used to print a debug line
            and then fail with an IndexError.)
        KeyError: if ``sentiment`` is not one of the three labels above.
    """
    tweet = " " + " ".join(str(tweet).split())
    selected_text = " " + " ".join(str(selected_text).split())
    len_st = len(selected_text) - 1  # span length without the leading space

    if len_st == 0:
        raise ValueError(f"empty selected_text for tweet {tweet!r}")

    # Character span [idx0, idx1] of the first occurrence of selected_text,
    # matched together with the space that precedes it.
    idx0 = idx1 = None
    for ind in (i for i, ch in enumerate(tweet) if ch == selected_text[1]):
        if " " + tweet[ind: ind + len_st] == selected_text:
            idx0 = ind
            idx1 = ind + len_st - 1
            break
    if idx0 is None:
        raise ValueError(f"selected_text {selected_text!r} not found in tweet {tweet!r}")

    char_targets = [0] * len(tweet)
    for ct in range(idx0, idx1 + 1):
        char_targets[ct] = 1

    tok_tweet = tokenizer.encode(tweet)
    # NOTE(review): token id 47341 is dropped from the tweet encoding; which
    # token that is cannot be seen here — confirm against vocab.json.
    kept = [(tid, off) for tid, off in zip(tok_tweet.ids, tok_tweet.offsets) if tid != 47341]
    input_ids_orig = [tid for tid, _ in kept]
    tweet_offsets = [off for _, off in kept]

    # Tweet tokens whose character range overlaps the selected span.
    target_idx = [
        j for j, (offset1, offset2) in enumerate(tweet_offsets)
        if sum(char_targets[offset1: offset2]) > 0
    ]
    if not target_idx:
        raise ValueError(f"selected_text {selected_text!r} covers no token of tweet {tweet!r}")

    # NOTE(review): presumably the vocab ids of the sentiment words — confirm.
    sentiment_id = {
        'positive': 1313,
        'negative': 2430,
        'neutral': 7974,
    }

    n_prefix = 4  # [0, sentiment, 2, 2] precede the tweet tokens
    input_ids = [0, sentiment_id[sentiment], 2, 2] + input_ids_orig + [2]
    token_type_ids = [0] * len(input_ids)
    mask = [1] * len(input_ids)
    tweet_offsets = [(0, 0)] * n_prefix + tweet_offsets + [(0, 0)]
    targets_start = target_idx[0] + n_prefix
    targets_end = target_idx[-1] + n_prefix

    padding_length = max_len - len(input_ids)
    if padding_length > 0:
        input_ids = input_ids + [1] * padding_length
        mask = mask + [0] * padding_length
        token_type_ids = token_type_ids + [0] * padding_length
        tweet_offsets = tweet_offsets + [(0, 0)] * padding_length

    return {
        'ids': input_ids,
        'mask': mask,
        'token_type_ids': token_type_ids,
        'targets_start': targets_start,
        'targets_end': targets_end,
        'orig_tweet': tweet,
        'orig_selected': selected_text,
        'sentiment': sentiment,
        'offsets': tweet_offsets
    }


# Dataset

In [6]:
class TweetDataset:
    """Map-style dataset yielding model-ready tensors for each tweet.

    Args:
        tweet: indexable sequence of tweet strings.
        sentiment: indexable sequence of sentiment labels aligned with ``tweet``.
        selected_text: indexable sequence of target spans aligned with ``tweet``.

    Encoding is done by ``process_data`` with the module-level ``TOKENIZER``
    and ``MAX_LEN``.
    """

    def __init__(self, tweet, sentiment, selected_text):
        self.tweet = tweet
        self.sentiment = sentiment
        self.selected_text = selected_text
        self.tokenizer = TOKENIZER
        self.max_len = MAX_LEN

    def __len__(self):
        return len(self.tweet)

    def __getitem__(self, item):
        encoded = process_data(
            self.tweet[item],
            self.selected_text[item],
            self.sentiment[item],
            self.tokenizer,
            self.max_len,
        )

        # Numeric fields become int64 tensors; text fields pass through as-is.
        tensor_keys = ('ids', 'mask', 'token_type_ids', 'targets_start', 'targets_end')
        sample = {key: torch.tensor(encoded[key], dtype=torch.long) for key in tensor_keys}
        for key in ('orig_tweet', 'orig_selected', 'sentiment'):
            sample[key] = encoded[key]
        sample['offsets'] = torch.tensor(encoded['offsets'], dtype=torch.long)
        return sample

# Model

In [7]:
class TweetModel(transformers.BertPreTrainedModel):
    """RoBERTa span-extraction model producing start/end logits per token.

    The last four hidden states are averaged element-wise. The result goes
    through five dropout branches that share one linear layer (multi-sample
    dropout), and the branch outputs are averaged into two logits per token.

    NOTE(review): subclasses ``BertPreTrainedModel`` while wrapping a
    ``RobertaModel``; kept as-is for compatibility with existing code.
    Checkpoints trained with the previous pooling (cat + AvgPool1d) have the
    same parameter names but were trained on different features — retrain.
    """

    def __init__(self, conf):
        super(TweetModel, self).__init__(conf)
        # forward() unpacks hidden states, so the caller sets
        # conf.output_hidden_states = True.
        self.roberta = transformers.RobertaModel.from_pretrained(ROBERTA_PATH, config=conf)
        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])
        self.l0 = nn.Linear(768, 2)
        torch.nn.init.normal_(self.l0.weight, std=0.02)

    def forward(self, ids, mask, token_type_ids):
        """Return ``(start_logits, end_logits)``, presumably (batch, seq_len) each."""
        # Assumes the tuple-style RobertaModel output (transformers 2.x/3.x),
        # whose third element is the tuple of hidden states — TODO confirm for
        # the installed version.
        _, _, hidden_states = self.roberta(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
        )

        # Element-wise mean of the last four layers. The previous
        # cat(dim=-1) + AvgPool1d(4) averaged groups of 4 *adjacent features
        # within each layer instead of averaging across layers.
        out = torch.stack(hidden_states[-4:], dim=0).mean(dim=0)

        # Multi-sample dropout: average the shared head over the dropout masks.
        logits = torch.stack(
            [self.l0(dropout(out)) for dropout in self.dropouts], dim=0
        ).mean(dim=0)

        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

# Distance Loss Function

In [8]:
def dist_between(start_logits, end_logits, device=None, max_seq_len=None):
    """Batch mean of (expected end position - expected start position).

    Position ``i`` is weighted by ``i / max_seq_len`` (so positions lie in
    [0, 1)) and multiplied by the per-position weights in each input.

    Args:
        start_logits, end_logits: (batch, max_seq_len) per-position weights,
            e.g. softmax probabilities or one-hot targets.
        device: device for the position weights; defaults to
            ``start_logits.device`` (previously hard-coded to ``'cuda'``).
        max_seq_len: number of positions; defaults to ``start_logits.size(1)``
            (previously hard-coded to 192).

    Returns:
        0-dim tensor with the batch-mean distance.
    """
    if device is None:
        device = start_logits.device
    if max_seq_len is None:
        max_seq_len = start_logits.size(1)

    positions = torch.tensor(
        np.linspace(0, 1, max_seq_len, endpoint=False), requires_grad=False
    ).to(device)

    start_pos = (start_logits * positions).sum(axis=1)
    end_pos = (end_logits * positions).sum(axis=1)

    diff = end_pos - start_pos
    return diff.sum(axis=0) / diff.size(0)


def dist_loss(start_logits, end_logits, start_positions, end_positions, device=None, max_seq_len=None, scale=2):
    """Penalty on the gap between predicted and true (batch-mean) span lengths.

    Span length is measured with ``dist_between``: softmax probabilities for
    the prediction, one-hot vectors for the ground truth. The loss is
    ``-scale * log(1 - |gt_len - pred_len|)``.

    Args:
        start_logits, end_logits: raw logits, shape (batch, max_seq_len).
        start_positions, end_positions: int64 gold indices, shape (batch,).
        device: device for the one-hot targets and position weights; defaults
            to ``start_logits.device`` (previously hard-coded to ``'cuda'``).
        max_seq_len: number of positions; defaults to ``start_logits.size(1)``
            (previously hard-coded to 192).
        scale: multiplier applied to the loss.

    Returns:
        0-dim tensor.
    """
    if device is None:
        device = start_logits.device
    if max_seq_len is None:
        max_seq_len = start_logits.size(1)

    start_probs = F.softmax(start_logits, dim=1)
    end_probs = F.softmax(end_logits, dim=1)

    start_one_hot = F.one_hot(start_positions, num_classes=max_seq_len).to(device)
    end_one_hot = F.one_hot(end_positions, num_classes=max_seq_len).to(device)

    pred_dist = dist_between(start_probs, end_probs, device, max_seq_len)
    gt_dist = dist_between(start_one_hot, end_one_hot, device, max_seq_len)

    # abs() instead of sqrt(diff * diff): same value, but sqrt has an infinite
    # derivative at 0 and produced NaN gradients when the distances matched.
    closeness = 1 - torch.abs(gt_dist - pred_dist)
    # |diff| can exceed 1 (pred_dist may be negative), which made log() NaN;
    # clamp to a small positive floor instead.
    closeness = closeness.clamp(min=1e-6)
    loss = -torch.log(closeness)  # 0 when lengths match, grows as they diverge

    return loss * scale

In [10]:
def loss_fn(start_logits, end_logits, start_positions, end_positions):
    """Training loss: start CE + end CE + span-length distance penalty."""
    ce_start = F.cross_entropy(start_logits, start_positions)
    ce_end = F.cross_entropy(end_logits, end_positions)
    length_penalty = dist_loss(start_logits, end_logits, start_positions, end_positions)
    return (ce_start + ce_end) + length_penalty

In [11]:
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Args:
        mode: ``'max'`` if larger is better, ``'min'`` if smaller is better.
        min_delta: improvement over ``best`` required to count as better
            (absolute, or a percentage of ``best`` when ``percentage`` is True).
        patience: consecutive non-improving ``step`` calls that trigger a stop;
            ``0`` disables stopping entirely.
        percentage: interpret ``min_delta`` as a percentage.

    Raises:
        ValueError: if ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, mode='max', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)
        if patience == 0:
            # Never stop: every value is "better" and step() always continues.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record ``metrics``; return True when training should stop.

        The first call only stores ``best``; a NaN metric stops immediately.
        """
        if self.best is None:
            self.best = metrics
            return False
        if np.isnan(metrics):
            return True
        if not self.is_better(metrics, self.best):
            self.num_bad_epochs += 1
        else:
            self.num_bad_epochs = 0
            self.best = metrics
        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Install ``self.is_better(a, best)`` for the chosen mode and threshold."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')

        def margin(best):
            return best * min_delta / 100 if percentage else min_delta

        if mode == 'min':
            self.is_better = lambda a, best: a < best - margin(best)
        else:
            self.is_better = lambda a, best: a > best + margin(best)

# Competition Metric Calculation

In [12]:
def jaccard(str1, str2):
    """Word-level Jaccard similarity between two strings.

    Re-defined in this cell (shadows any earlier definition once it runs).
    Both strings are lower-cased and split on whitespace; the score is
    |A ∩ B| / |A ∪ B| over the word sets. When neither string contains any
    word (previously a ZeroDivisionError), 0.5 is returned — the value used
    by the Kaggle Tweet Sentiment Extraction evaluation snippet.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    if not a and not b:
        return 0.5
    c = a & b
    return float(len(c)) / (len(a) + len(b) - len(c))

def calculate_jaccard_score(
    original_tweet,
    target_string,
    sentiment_val,
    idx_start,
    idx_end,
    offsets,
    verbose=False):
    """Decode a predicted token span to text and score it against the target.

    Args:
        original_tweet: text that ``offsets`` index into.
        target_string: gold span.
        sentiment_val: unused; kept for call compatibility. Neutral tweets are
            not special-cased by this function.
        idx_start, idx_end: predicted token indices (inclusive). If the end
            precedes the start, only the start token is used.
        offsets: per-token ``(char_start, char_end)`` pairs.
        verbose: unused.

    Returns:
        ``(jaccard_score, decoded_text)``.
    """
    if idx_end < idx_start:
        idx_end = idx_start

    pieces = []
    for ix in range(idx_start, idx_end + 1):
        pieces.append(original_tweet[offsets[ix][0]: offsets[ix][1]])
        # Re-insert a space where the next token does not start right here.
        if (ix + 1) < len(offsets) and offsets[ix][1] < offsets[ix + 1][0]:
            pieces.append(" ")
    filtered_output = "".join(pieces)

    jac = jaccard(target_string.strip(), filtered_output.strip())
    return jac, filtered_output

In [13]:
def train_fn(train_data_loader, model, optimizer, device, scheduler=None, epoch=None, log_every=300):
    """Run one training epoch over ``train_data_loader``.

    Args:
        train_data_loader: yields dicts with ``ids``, ``token_type_ids``,
            ``mask``, ``targets_start`` and ``targets_end`` tensors.
        model: called as ``model(ids=..., mask=..., token_type_ids=...)`` and
            returning ``(start_logits, end_logits)``.
        optimizer: stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: optional LR scheduler, stepped once per batch (the default
            ``None`` used to crash on ``scheduler.step()``).
        epoch: optional 0-based epoch index, used only in the progress log
            (previously read from a notebook-global ``epoch``).
        log_every: print the mean loss of each window of this many batches;
            ``0`` disables logging.
    """
    model.train()
    running_loss = 0.0
    n_batches = len(train_data_loader)
    for batch, d in enumerate(train_data_loader):
        ids = d["ids"].to(device, dtype=torch.long)
        token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
        mask = d["mask"].to(device, dtype=torch.long)
        targets_start = d["targets_start"].to(device, dtype=torch.long)
        targets_end = d["targets_end"].to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs_start, outputs_end = model(
            ids=ids,
            mask=mask,
            token_type_ids=token_type_ids
        )

        loss = loss_fn(outputs_start, outputs_end, targets_start, targets_end)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        running_loss += loss.item()
        # Log after every `log_every` batches so each window averages exactly
        # `log_every` losses (the old `batch % 300 == 0 and batch > 0` check
        # averaged 301 losses over 300 in the first window).
        if log_every and (batch + 1) % log_every == 0:
            prefix = f"Epoch: {epoch + 1} | " if epoch is not None else ""
            print(f"{prefix}{batch + 1:5d}/{n_batches:5d} batches | loss: {running_loss / log_every:.5f}")
            running_loss = 0.0

In [15]:
def eval_fn(valid_data_loader, model, device):
    """Evaluate ``model`` on ``valid_data_loader`` and return the mean Jaccard.

    The predicted span is the independent argmax of the start and end
    probabilities, decoded with ``calculate_jaccard_score``. Per-batch loss and
    Jaccard are averaged with batch-size weights, shown on a tqdm bar and
    printed at the end.
    """
    model.eval()
    loss_meter = AverageMeter()
    jaccard_meter = AverageMeter()

    with torch.no_grad():
        progress = tqdm(valid_data_loader, total=len(valid_data_loader))
        for d in progress:
            offsets = d["offsets"].numpy()
            ids = d["ids"].to(device, dtype=torch.long)
            token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
            mask = d["mask"].to(device, dtype=torch.long)
            targets_start = d["targets_start"].to(device, dtype=torch.long)
            targets_end = d["targets_end"].to(device, dtype=torch.long)

            outputs_start, outputs_end = model(
                ids=ids,
                mask=mask,
                token_type_ids=token_type_ids
            )
            loss = loss_fn(outputs_start, outputs_end, targets_start, targets_end)

            start_probs = torch.softmax(outputs_start, dim=1).cpu().detach().numpy()
            end_probs = torch.softmax(outputs_end, dim=1).cpu().detach().numpy()

            batch_scores = [
                calculate_jaccard_score(
                    original_tweet=tweet,
                    target_string=d["orig_selected"][px],
                    sentiment_val=d["sentiment"][px],
                    idx_start=np.argmax(start_probs[px, :]),
                    idx_end=np.argmax(end_probs[px, :]),
                    offsets=offsets[px]
                )[0]
                for px, tweet in enumerate(d["orig_tweet"])
            ]

            batch_size = ids.size(0)
            jaccard_meter.update(np.mean(batch_scores), batch_size)
            loss_meter.update(loss.item(), batch_size)
            progress.set_postfix(loss=loss_meter.avg, jaccard=jaccard_meter.avg)

    print(f"Jaccard = {jaccard_meter.avg}, Loss = {loss_meter.avg}")
    return jaccard_meter.avg

In [None]:
# 5-fold training: one model per fold; the checkpoint with the best validation
# Jaccard is saved as Roberta_Local_{fold+1}.pth.
total_jaccard = []
device = torch.device("cuda")

for fold in range(5):
    model_config = transformers.RobertaConfig.from_pretrained(ROBERTA_PATH)
    model_config.output_hidden_states = True  # TweetModel consumes hidden states
    model = TweetModel(conf=model_config)
    model.to(device)
    best_jaccard = 0

    # No weight decay on biases and LayerNorm weights.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]

    es = EarlyStopping(patience=2)
    es.step(0)  # seed the baseline so the first real score is compared with 0

    # NOTE: pickle files can execute code on load — only read trusted local files.
    train = pd.read_pickle('train_folds_42_clean_steph_v1.pkl')
    train = train.loc[(train.kfold != fold)].reset_index(drop=True)
    train = preprocess_im(train)
    train = preprocess_all(train)
    train = preprocess_repeat(train)

    valid = pd.read_csv('train_folds_42.csv')
    valid = valid.loc[(valid.kfold == fold)].reset_index(drop=True)
    valid = preprocess_im(valid)
    valid = preprocess_all(valid)
    valid = preprocess_repeat(valid)

    train_data = TweetDataset(tweet=train.text.values,
                              sentiment=train.sentiment.values,
                              selected_text=train.selected_text.values)

    valid_data = TweetDataset(tweet=valid.text.values,
                              sentiment=valid.sentiment.values,
                              selected_text=valid.selected_text.values)

    # Shuffle the training set each epoch (it was shuffle=False, so every
    # epoch saw identical batches in identical order).
    train_data_loader = torch.utils.data.DataLoader(
                            train_data,
                            shuffle=True,
                            batch_size=TRAIN_BATCH_SIZE,
                            num_workers=1
                        )

    valid_data_loader = torch.utils.data.DataLoader(
                            valid_data,
                            shuffle=False,
                            batch_size=VALID_BATCH_SIZE,
                            num_workers=1
                        )

    num_train_steps = int(len(train) / TRAIN_BATCH_SIZE * EPOCHS)
    optimizer = transformers.AdamW(optimizer_parameters, lr=3e-5)
    scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    # Use EPOCHS (the schedule above is sized with it) instead of a literal 5.
    for epoch in range(EPOCHS):
        print(f"Fold {fold + 1} | Epoch {epoch + 1}/{EPOCHS}")
        train_fn(train_data_loader, model, optimizer, device, scheduler)
        jaccard_score = eval_fn(valid_data_loader, model, device)

        if es.step(jaccard_score):
            break
        if jaccard_score > best_jaccard:
            torch.save(model.state_dict(), f'Roberta_Local_{fold+1}.pth')
            best_jaccard = jaccard_score

    total_jaccard.append(best_jaccard)

    # Release this fold's GPU memory before building the next model.
    del model, optimizer, scheduler
    torch.cuda.empty_cache()

print(f"Best Jaccard per fold: {total_jaccard} | mean = {np.mean(total_jaccard):.5f}")

# Inference

In [None]:
def get_best_start_end_idxs(_start_logits, _end_logits):
    """Return ``(start, end)`` with ``end >= start`` maximising ``start[s] + end[e]``.

    Ties go to the smallest start, then the smallest end, matching the
    previous O(n^2) search. This version runs in O(n) using a suffix argmax
    over the end scores. It always returns a pair for non-empty input (the old
    version returned None if every sum was <= -1000) and returns None for
    empty input.

    Args:
        _start_logits, _end_logits: 1-D arrays of equal length (logits or
            probabilities).
    """
    n = len(_start_logits)
    if n == 0:
        return None

    # best_end[i] = first index of the maximum of _end_logits[i:].
    best_end = [0] * n
    best_end[n - 1] = n - 1
    for i in range(n - 2, -1, -1):
        nxt = best_end[i + 1]
        best_end[i] = i if _end_logits[i] >= _end_logits[nxt] else nxt

    totals = [_start_logits[i] + _end_logits[best_end[i]] for i in range(n)]
    best_start = int(np.argmax(totals))  # first maximum -> smallest start
    return best_start, int(best_end[best_start])

In [16]:
def inf_fn(data_loader, model, device):
    """Predict a text span for every example in ``data_loader``.

    For each example, ``get_best_start_end_idxs`` picks the best start/end pair
    (end >= start) from the softmax probabilities, and the pair is decoded to
    text with ``calculate_jaccard_score``.

    Returns:
        ``(final_output, start_losses, end_losses)``: predicted span strings in
        loader order, and the unreduced per-example cross-entropy of the start
        and end logits against the targets, each stacked with ``np.hstack``.
    """
    model.eval()
    predictions = []
    start_loss_chunks = []
    end_loss_chunks = []
    per_example_ce = nn.CrossEntropyLoss(reduction='none')

    with torch.no_grad():
        for d in tqdm(data_loader, total=len(data_loader)):
            offsets = d["offsets"].numpy()
            ids = d["ids"].to(device, dtype=torch.long)
            token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
            mask = d["mask"].to(device, dtype=torch.long)
            targets_start = d["targets_start"].to(device, dtype=torch.long)
            targets_end = d["targets_end"].to(device, dtype=torch.long)

            outputs_start, outputs_end = model(
                ids=ids,
                mask=mask,
                token_type_ids=token_type_ids
            )
            start_loss_chunks.append(per_example_ce(outputs_start, targets_start).cpu().numpy())
            end_loss_chunks.append(per_example_ce(outputs_end, targets_end).cpu().numpy())

            start_probs = torch.softmax(outputs_start, dim=1).cpu().detach().numpy()
            end_probs = torch.softmax(outputs_end, dim=1).cpu().detach().numpy()

            for px, tweet in enumerate(d["orig_tweet"]):
                best_start, best_end = get_best_start_end_idxs(start_probs[px, :], end_probs[px, :])
                _, span_text = calculate_jaccard_score(
                    original_tweet=tweet,
                    target_string=d["orig_selected"][px],
                    sentiment_val=d["sentiment"][px],
                    idx_start=best_start,
                    idx_end=best_end,
                    offsets=offsets[px]
                )
                predictions.append(span_text)

    return predictions, np.hstack(start_loss_chunks), np.hstack(end_loss_chunks)

In [None]:
# Out-of-fold predictions: reload each fold's checkpoint, predict that fold's
# validation split, and concatenate the per-fold frames into `oof`.
device = torch.device("cuda")
fold_frames = []
for fold in range(1, 6):
    valid_all = pd.read_csv('train_folds_42.csv')
    valid_all = valid_all.loc[(valid_all.kfold == (fold - 1))].reset_index(drop=True)
    valid_all['original_selected_text'] = valid_all['selected_text'].copy()

    valid_all = preprocess_im(valid_all)
    valid_all = preprocess_all(valid_all)
    valid_all = preprocess_repeat(valid_all)
    valid = valid_all.copy()

    valid_data = TweetDataset(tweet=valid.text.values,
                              sentiment=valid.sentiment.values,
                              selected_text=valid.selected_text.values)
    valid_data_loader = torch.utils.data.DataLoader(
        valid_data,
        shuffle=False,
        batch_size=VALID_BATCH_SIZE,
        num_workers=1
    )

    model_config = transformers.RobertaConfig.from_pretrained(ROBERTA_PATH)
    model_config.output_hidden_states = True
    model = TweetModel(conf=model_config)
    model.to(device)
    # torch.load unpickles — only load checkpoints produced by this notebook.
    model.load_state_dict(torch.load(f'Roberta_Local_{fold}.pth', map_location=device))

    valid_final = valid.copy()
    valid_final['pred'], valid_final['start_loss'], valid_final['end_loss'] = inf_fn(valid_data_loader, model, device)
    valid_final['j_score'] = valid_final.apply(lambda x: jaccard(x['selected_text'], x['pred']), axis=1)
    print('fold', fold, ':', valid_final['j_score'].mean(), valid_final['start_loss'].mean(), valid_final['end_loss'].mean())
    fold_frames.append(valid_final)

    del model
    torch.cuda.empty_cache()

# DataFrame.append was removed in pandas 2.0; concatenate once instead.
oof = pd.concat(fold_frames, ignore_index=True)

# Postprocess Predictions 

In [None]:
# Join OOF predictions back onto the raw folds file. Columns present in both
# frames get pandas' default suffixes: `_x` for the raw file, `_y` for `oof`.
valid_real = pd.read_csv('train_folds_42.csv')
all_valid = valid_real.merge(oof, how='left', left_on='textID', right_on='textID')

all_valid = postprocess_im(all_valid)
all_valid = postprocess_all(all_valid)
all_valid = postprocess_repeat(all_valid)
# Turn any literal backslash-star left in predictions back into "*"
# ("\\*" spells the same string the old invalid escape "\*" produced).
all_valid['pred'] = all_valid.pred.apply(lambda x: x.replace("\\*", "*"))

# Shorten runs of '!' / '.' in single-word predictions, in this order.
for long_run, short_run in (('!!!!', '!!'), ('!!!', '!!'), ('....', '..'), ('...', '..')):
    all_valid['pred'] = all_valid['pred'].apply(
        lambda x, old=long_run, new=short_run: x.replace(old, new) if len(x.split()) == 1 else x
    )

# Neutral tweets: predict the whole raw tweet (lower-cased by postprocess_all).
is_neutral = all_valid['sentiment_x'] == 'neutral'
all_valid.loc[is_neutral, 'pred'] = all_valid.loc[is_neutral, 'text_x']

all_valid['j_score_real'] = all_valid.apply(lambda x: jaccard(x['selected_text_x'], x['pred']), axis=1)
all_valid.j_score_real.mean()