In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

# vectorize sequences as inputs 
def seq_to_tensor(raw_sample, dim_model=728):
    """Embed one sample as a tensor with one row per token.

    Each row is the concatenation of the token's word-embedding vector and
    the embedding vector of the POS tag found at the same position.

    The lookup tables ``word_embedding``, ``word2idx``, ``pos_embedding`` and
    ``pos2idx`` are module-level names defined outside this function.

    Args:
        raw_sample: Mapping with a ``"word"`` sequence and a ``"pos"``
            sequence that is indexed in parallel with ``"word"``.
        dim_model: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A tensor holding one concatenated embedding row per token. When the
        sample has no tokens, the empty tensor ``torch.tensor([])`` is
        returned.
    """
    rows = [
        np.concatenate(
            [
                word_embedding[word2idx[word]],
                pos_embedding[pos2idx[raw_sample["pos"][position]]],
            ]
        )
        for position, word in enumerate(raw_sample["word"])
    ]
    if not rows:
        # np.stack rejects an empty list; keep the original empty result.
        return torch.tensor([])
    # Stack into a single ndarray first: building a tensor directly from a
    # Python list of ndarrays goes through a slow element-by-element path.
    return torch.from_numpy(np.stack(rows))
    
# dynamic padding: sequences are padded to the maximum sequence length within each mini-batch
def collate_fn(batch):
    """Collate ``(sequence, label, feats)`` samples into batched tensors.

    Samples are reordered by descending sequence length (first dimension of
    the sequence tensor) and the sequences are zero-padded to the longest one
    in the batch. Labels and features follow the same reordering.

    Args:
        batch: Iterable of samples where item 0 is a sequence tensor, item 1
            a label and item 2 the additional features.

    Returns:
        A tuple ``(sequences_padded, labels, lengths, feats)``: the padded
        batch-first sequences, a ``LongTensor`` of labels, a ``LongTensor``
        of the unpadded lengths and a ``FloatTensor`` of features.
    """
    # Longest sequence first; the sort is stable, so ties keep input order.
    ordered = list(batch)
    ordered.sort(key=lambda item: item[0].size(0), reverse=True)

    seqs = [item[0] for item in ordered]
    lengths = torch.LongTensor([seq.size(0) for seq in seqs])
    labels = torch.LongTensor([item[1] for item in ordered])
    feats = torch.FloatTensor([item[2] for item in ordered])

    sequences_padded = pad_sequence(seqs, batch_first=True)
    return sequences_padded, labels, lengths, feats

class IronyDataset(Dataset):
    """Map-style dataset yielding ``(sample, label, feats)`` triples.

    Args:
        raw_data: Indexable collection of samples. When ``with_label`` is
            true, each sample must support ``sample["label"]``.
        transform: Optional callable applied to the raw sample before it is
            returned.
        addition_feats: Optional indexable collection of extra features,
            looked up with the same index as the sample.
        with_label: Whether to read the label from each sample.
    """

    def __init__(self, raw_data, transform=None, addition_feats=None, with_label=True):
        self.data = raw_data
        self.transform = transform
        self.addition_feats = addition_feats
        self.with_label = with_label

    def __len__(self):
        """Return the number of samples in the underlying data."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, label, feats)`` for ``index``.

        ``label`` and ``feats`` are empty lists when labels are disabled or
        no additional features were supplied.
        """
        record = self.data[index]

        # The label is read from the raw record, before any transform runs.
        label = record["label"] if self.with_label else []
        sample = record if self.transform is None else self.transform(record)
        feats = [] if self.addition_feats is None else self.addition_feats[index]

        return sample, label, feats