In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchtext
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

from tqdm.auto import tqdm

In [2]:
def str_to_list(value):
    """Parse the textual form of a token list back into a list of strings.

    The CSV stores each tokenised sentence as the ``repr`` of a Python list,
    e.g. ``"['hello', 'world']"``. The value is parsed as a Python literal so
    that tokens containing ``', '`` or quote characters survive intact and an
    empty list (``"[]"``) yields ``[]`` rather than ``['']``.

    Args:
        value: Raw cell text read from the CSV.

    Returns:
        list[str]: The tokens, in their original order.
    """
    import ast  # local import keeps this cell self-contained

    text = value.strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a valid Python literal; handled by the fallback below.
        parsed = None
    if isinstance(parsed, (list, tuple)):
        return [item if isinstance(item, str) else str(item) for item in parsed]

    # Fallback for malformed cells: naive split that drops the first and last
    # character (the surrounding quotes) of every comma-separated item.
    inner = text.strip('[]')
    if not inner:
        return []
    return [item[1:-1] for item in inner.split(', ')]

In [3]:
# Load the corpus; the first CSV column becomes the index and every "text"
# cell is turned back into a list of tokens by str_to_list.
dataset = pd.read_csv(
    "../datasets/tonetags_wsd_1.csv",
    index_col=0,
    converters={"text": str_to_list},
)

In [4]:
# Distinct tag values; a tag's position in this list is its integer class id.
labels = dataset["tags"].unique().tolist()
# Replace every tag in the frame with its class id (overwrites the column).
dataset["tags"] = [labels.index(tag) for tag in dataset["tags"]]

In [5]:
# Token -> row-index mapping of the 50-dimensional GloVe 6B vectors.
vocab = torchtext.vocab.GloVe(name='6B', dim=50).stoi
# Add the two special tokens one after the other; each takes the mapping's
# size at the moment it is inserted as its id.
vocab.update({"<unk>": len(vocab)})
vocab.update({"<pad>": len(vocab)})

In [6]:
# Maximum number of tokens per sentence: longer sentences are dropped by
# myDataset and shorter ones are padded up to this length by collate_fn.
max_length: int = 4_096

In [7]:
class myDataset(Dataset):
    """Map-style dataset of token-id sequences paired with integer labels.

    Sentences longer than the length limit are dropped. The labels of the
    dropped rows are dropped as well, so ``self.data[i]`` and
    ``self.labels.iloc[i]`` always describe the same row of the input frame.

    Args:
        dataset: DataFrame with a ``text`` column (lists of tokens) and a
            ``tags`` column (integer class ids).
        token_to_id: Optional token -> id mapping; defaults to the notebook's
            global ``vocab``. Must contain ``"<unk>"``.
        max_len: Optional length limit; defaults to the notebook's global
            ``max_length``.
    """

    def __init__(self, dataset, token_to_id=None, max_len=None):
        # Fall back to the notebook-level globals when nothing is passed in.
        token_to_id = vocab if token_to_id is None else token_to_id
        max_len = max_length if max_len is None else max_len
        unk_id = token_to_id["<unk>"]

        self.data = []
        kept_rows = []  # positional indices of the rows that were kept
        for row, sentence in enumerate(dataset.text):
            if len(sentence) > max_len:
                continue
            # Out-of-vocabulary tokens map to the "<unk>" id.
            self.data.append([token_to_id.get(token, unk_id) for token in sentence])
            kept_rows.append(row)
        # Keep only the labels of surviving rows so data and labels stay aligned.
        self.labels = dataset.tags.iloc[kept_rows]

    def __len__(self):
        """Return the number of sentences that were kept."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_ids, label)`` for the sample at position ``idx``.

        ``token_ids`` is the stored list of ints; ``label`` is a tensor built
        from the matching entry of ``self.labels``.
        """
        return self.data[idx], torch.tensor(self.labels.iloc[idx])

In [8]:
def collate_fn(batch, pad_id=None, target_length=None):
    """Collate samples into fixed-length token-id lists plus their labels.

    Each sequence is right-padded with the pad id up to the target length.
    Padding is done on a copy, so the lists owned by the dataset are never
    modified. Sequences already at or above the target length are returned
    unchanged (they are not truncated).

    Args:
        batch: Iterable of samples where ``sample[0]`` is a list of token ids
            and ``sample[1]`` is the label.
        pad_id: Optional id used for padding; defaults to ``vocab["<pad>"]``.
        target_length: Optional padded length; defaults to the notebook's
            global ``max_length``.

    Returns:
        tuple[list[list[int]], list]: The padded sequences and the labels,
        both in batch order.
    """
    # Fall back to the notebook-level globals when nothing is passed in.
    pad_id = vocab["<pad>"] if pad_id is None else pad_id
    target_length = max_length if target_length is None else target_length

    data_ids = []
    labels = []
    for sample in batch:
        token_ids = sample[0]
        # A non-positive count multiplies to an empty list, so long
        # sequences are simply copied as they are.
        padding = [pad_id] * (target_length - len(token_ids))
        data_ids.append(list(token_ids) + padding)
        labels.append(sample[1])

    return data_ids, labels

In [9]:
# Hold out 20% of the rows for testing; shuffle=False keeps the frame's
# original row order instead of sampling rows at random.
train, test = train_test_split(
    dataset,
    test_size=0.2,
    shuffle=False,
)

In [10]:
# Turn both splits into token-id datasets (train first, then test).
train_dataset, test_dataset = (myDataset(split) for split in (train, test))

In [11]:
# Number of samples per batch for both dataloaders.
batch_size: int = 64

In [12]:
# Training batches are reshuffled on every pass over the data.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# The test loader must wrap the tokenised ``test_dataset``; the raw ``test``
# DataFrame does not yield the (token_ids, label) samples collate_fn expects.
# Shuffling is off so evaluation order is deterministic.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [13]:
# Smoke test: walk the training loader once to check that batching works.
for batch in tqdm(train_dataloader):
    embedded_tokens = batch[0]  # padded token-id lists produced by collate_fn
    # Named ``batch_labels`` so the global ``labels`` list of tag names
    # (built when the tags were encoded) is not overwritten.
    batch_labels = batch[1]

  0%|          | 0/1057 [00:00<?, ?it/s]

In [14]:
# Embedding matrix: the GloVe 6B 50-d vectors followed by two extra rows,
# all zeros then all ones. Presumably these rows serve the "<unk>" and
# "<pad>" ids added to ``vocab`` (in that order) - verify against that cell.
vec = np.vstack([
    torchtext.vocab.GloVe('6B', dim=50).vectors.numpy(),
    np.zeros((1, 50)),
    np.ones((1, 50)),
])

In [15]:
# Copy the NumPy embedding matrix into a 32-bit float tensor
# (torch.float is an alias of torch.float32).
embed_tensor = torch.tensor(data=vec, dtype=torch.float32)

In [16]:
# Embedding layer initialised from the pretrained matrix; freeze=True keeps
# the weights fixed (no gradient updates).
embed = nn.Embedding.from_pretrained(
    embeddings=embed_tensor,
    freeze=True,
)