In [1]:
import os
import gc
import sys
import json
import time
import torch
import joblib
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
import plotly.express as px
import matplotlib.pyplot as plt

# Notebook-wide pandas display settings: show up to 500 rows/columns, allow
# 1000 characters per printed line and render floats with four decimals.
pd.options.display.max_rows = 500
pd.options.display.max_columns = 500
pd.options.display.width = 1000
pd.options.display.float_format = lambda value: '%.4f' % value

# Params

In [2]:
# Folder holding the input files (the listing below shows train.json,
# test.json and sample_submission.csv).
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
data_path = Path(r"/database/kaggle/PII/data")
# List the folder contents to confirm the expected files are present.
os.listdir(data_path)

['train.json', 'test.json', 'sample_submission.csv']

In [3]:
# Load the training documents into a DataFrame.
df = pd.read_json(data_path/'train.json')
# (rows, columns) of the loaded frame.
df.shape

(6807, 5)

In [4]:
# Preview the first two documents.
# NOTE(review): the rendered rows include raw document text (names are visible
# in the output) — clear this output before sharing the notebook.
df.head(2)

Unnamed: 0,document,full_text,tokens,trailing_whitespace,labels
0,7,Design Thinking for innovation reflexion-Avril...,"[Design, Thinking, for, innovation, reflexion,...","[True, True, True, True, False, False, True, F...","[O, O, O, O, O, O, O, O, O, B-NAME_STUDENT, I-..."
1,10,Diego Estrada\n\nDesign Thinking Assignment\n\...,"[Diego, Estrada, \n\n, Design, Thinking, Assig...","[True, False, False, True, True, False, False,...","[B-NAME_STUDENT, I-NAME_STUDENT, O, O, O, O, O..."


# Data

In [11]:
import spacy
from spacy import displacy
from pylab import cm, matplotlib
import os

# Highlight colour (CSS hex string) for every entity type rendered by `visualize`.
# Every value needs the leading '#', otherwise it is not a valid CSS colour.
colors = {
    'NAME_STUDENT': '#8000ff',
    'EMAIL': '#2b7ff6',
    'USERNAME': '#2adddd',
    'ID_NUM': '#80ffb4',
    'PHONE_NUM': '#d4dd80',
    'URL_PERSONAL': '#ff8042',
    'STREET_ADDRESS': '#ff0000',
}


def visualize(full_text,offset_mapping,labels):
    """Render the labelled spans of one document with displaCy's entity view.

    Args:
        full_text: Text that the character offsets index into.
        offset_mapping: Iterable of (start, end) character offsets, one per label.
        labels: Iterable of tags such as "B-NAME_STUDENT". Tags without a "-"
            (e.g. "O") carry no entity type and are skipped instead of raising
            IndexError.

    Returns:
        None. The highlighted text is displayed in the notebook.
    """
    ents = []
    for offset, lab in zip(offset_mapping, labels):
        parts = str(lab).split('-')
        if len(parts) < 2:
            # No "<prefix>-<TYPE>" structure, i.e. nothing to highlight.
            continue
        ents.append({
            'start': int(offset[0]),
            'end': int(offset[1]),
            'label': parts[1],
        })

    # Manual mode: displaCy takes a plain dict instead of a spaCy Doc.
    doc2 = {
        "text": full_text,
        "ents": ents,
    }

    options = {"ents": list(colors.keys()), "colors": colors}
    displacy.render(doc2, style="ent", options=options, manual=True, jupyter=True)

In [5]:
import re
from difflib import SequenceMatcher

import codecs
import os
from collections import Counter
from typing import Dict, List, Tuple

from tqdm.auto import tqdm
import numpy as np
import pandas as pd
from text_unidecode import unidecode
import joblib
import torch

def replace_encoding_with_utf8(error: UnicodeError) -> Tuple[bytes, int]:
    """Codec error handler: encode the offending slice as UTF-8 instead.

    Returns the replacement bytes and the index at which encoding resumes.
    """
    offending = error.object[error.start : error.end]
    resume_at = error.end
    return offending.encode("utf-8"), resume_at


def replace_decoding_with_cp1252(error: UnicodeError) -> Tuple[str, int]:
    """Codec error handler: decode the offending bytes as cp1252 instead.

    Returns the replacement text and the index at which decoding resumes.
    """
    offending = error.object[error.start : error.end]
    resume_at = error.end
    return offending.decode("cp1252"), resume_at


# Expose both handlers by name so they can be selected through the `errors=`
# argument of str.encode / bytes.decode.
for _handler_name, _handler in (
    ("replace_encoding_with_utf8", replace_encoding_with_utf8),
    ("replace_decoding_with_cp1252", replace_decoding_with_cp1252),
):
    codecs.register_error(_handler_name, _handler)
del _handler_name, _handler



def resolve_encodings_and_normalize(text: str) -> str:
    """Repair mixed utf-8/cp1252 mojibake, then transliterate with unidecode.

    The text is round-tripped through bytes twice; whenever a step hits a
    sequence it cannot handle, the registered fallback handlers substitute the
    other encoding for just that slice.
    """
    raw_bytes = text.encode("raw_unicode_escape")
    first_pass = raw_bytes.decode("utf-8", errors="replace_decoding_with_cp1252")
    cp1252_bytes = first_pass.encode("cp1252", errors="replace_encoding_with_utf8")
    repaired = cp1252_bytes.decode("utf-8", errors="replace_decoding_with_cp1252")
    return unidecode(repaired)


def clean_text(text):
    """Return `text` with U+009D blanked, encodings repaired and edges stripped."""
    # U+009D is swapped for a space before the encoding repair runs.
    blanked = text.replace(u'\x9d', u' ')
    normalized = resolve_encodings_and_normalize(blanked)
    return normalized.strip()

def add_text_to_df(test_df,data_folder):
    """Attach essay text, discourse character offsets and an integer target.

    `test_df` is modified in place and also returned.

    Args:
        test_df: Frame with at least `essay_id` and `discourse_text` columns;
            `discourse_effectiveness` is optional.
        data_folder: Path-like folder containing one `<essay_id>.txt` per essay.

    Returns:
        The same frame with `discourse_ids`, `essay_text`, `st_ed`,
        `discourse_start`, `discourse_end`, `previous_discourse_end` and
        `target` columns added or overwritten.
    """
    mapper = {}
    for idx in tqdm(test_df.essay_id.unique()):
        with open(data_folder/f'{idx}.txt','r') as f:
            texte = clean_text(f.read())
        mapper[idx] = texte

    test_df['discourse_ids'] = np.arange(len(test_df))
    test_df['essay_text'] = test_df['essay_id'].map(mapper)
    test_df['discourse_text'] = test_df['discourse_text'].transform(clean_text)
    test_df['discourse_text'] = test_df['discourse_text'].str.strip()

    # Pass 1: locate every discourse searching from the start of its essay.
    test_df['previous_discourse_end'] = 0
    test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
    test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0])
    test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1])
    # Pass 2: search again from the end of the preceding discourse of the same
    # essay, so repeated phrases resolve to the later occurrence.
    test_df['previous_discourse_end'] = test_df.groupby("essay_id")['discourse_end'].transform(lambda x:x.shift(1).fillna(0)).astype(int)
    test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
    test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0])
    test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1])

    # The target is derived from `discourse_effectiveness`, so that is the
    # column whose presence must be tested (testing for 'target' gave labelled
    # frames a constant target of 1 and raised KeyError on unlabelled ones).
    if 'discourse_effectiveness' in test_df.columns:
        classe_mapper = {'Ineffective':0,"Adequate":1,"Effective":2}
        test_df['target'] = test_df['discourse_effectiveness'].map(classe_mapper)
    elif 'target' not in test_df.columns:
        # No labels available: fill with the placeholder class.
        test_df['target'] = 1

    return test_df

def get_essays(df,n_cpu=4):
    """Rebuild one essay per `id` group in parallel.

    Args:
        df: Event log with an `id` column plus the columns `_get_essay` reads.
        n_cpu: Number of joblib workers.

    Returns:
        DataFrame with columns `id` and `essay`, one row per group.
    """
    groups = list(df.groupby('id'))
    ids = [group_id for group_id, _ in groups]
    tasks = [joblib.delayed(_get_essay)(events) for _, events in groups]
    essays = joblib.Parallel(n_cpu)(tqdm(tasks))
    return pd.DataFrame({"id":ids,"essay":essays})

def _get_essay(df):
    text_recons = ''
    for i, (id_,row) in enumerate(df.iterrows()):
    #     print(row)
        activity = row["activity"]
        curs_pos = row["cursor_position"] # cursor position AFTER activity!
        text_change = row["text_change"]


        if activity == 'Input' or activity == 'Paste':
            text_recons = text_recons[:curs_pos - len(text_change)] + text_change + text_recons[curs_pos - len(text_change):]   
        if activity == 'Remove/Cut':
            text_recons = text_recons[:curs_pos] + text_recons[curs_pos + len(text_change):]
        if activity == 'Replace': # Combined remove and input operation
            cut, add = text_change.split(' => ')
            text_recons = text_recons[:curs_pos - len(add)] + add + text_recons[curs_pos - len(add) + len(cut):]

        if "Move" in activity:
            a, b, c, d = map(
                        int,
                        re.match(
                            r"Move From \[(\d+), (\d+)\] To \[(\d+), (\d+)\]",
                            activity,
                        ).groups(),
                    )

            if a != c:
                if a < c:
                    text_recons = text_recons[:a] + text_recons[b:d] + text_recons[a:b] + text_recons[d:]
                else:
                    text_recons = text_recons[:c] + text_recons[a:b] + text_recons[c:a] + text_recons[b:]
                    
    return text_recons


def get_text_start_end(txt, s, search_from=0):
    """Locate `s` inside `txt`, searching from `search_from`.

    Strategies are tried in order: exact substring search, `s` interpreted as
    a regular expression, and finally a fuzzy repair of `s` against the text
    (via difflib opcodes) followed by another regex / substring search.

    Args:
        txt: Text to search in.
        s: Token or phrase to locate.
        search_from: Offset at which the search starts.

    Returns:
        `(start, end)` offsets relative to the full `txt`. When nothing
        matches, both equal `search_from` (a zero-length span).
    """
    txt = txt[int(search_from):]

    def _regex_span(pattern):
        # A raw token such as "(" or "[" is not a valid pattern; treat that
        # as "no match" instead of letting re.error abort the whole lookup.
        try:
            match = re.search(pattern, txt)
        except re.error:
            return None
        return (match.start(0), match.end(0)) if match else None

    idx = txt.find(s)
    if idx >= 0:
        st, ed = idx, idx + len(s)
    else:
        span = _regex_span(s)
        if span is None:
            # Fuzzy repair: rewrite `s` towards the text using the edit script.
            for tag, i1, i2, j1, j2 in SequenceMatcher(None, s, txt).get_opcodes():
                if tag == 'replace':
                    s = s[:i1] + txt[j1:j2] + s[i2:]
                if tag == "delete":
                    s = s[:i1] + s[i2:]

            span = _regex_span(s)
            if span is None:
                idx = txt.find(s)
                span = (idx, idx + len(s)) if idx >= 0 else (0, 0)
        st, ed = span

    return st + search_from, ed + search_from


def get_offset_mapping(full_text, tokens):
    """Return one (start, end) character span per token, scanning left to right.

    Each token is searched for from the end of the previous token's span.
    """
    spans = []
    cursor = 0
    for tok in tokens:
        begin, finish = get_text_start_end(full_text, tok, search_from=cursor)
        spans.append((begin, finish))
        cursor = finish
    return spans

def get_start_end(col):
    """Build a row-wise locator for the text stored in column `col`.

    The returned callable reads `essay_text` and `previous_discourse_end` from
    the row and returns the span found by `get_text_start_end`.
    """
    def search_start_end(row):
        return get_text_start_end(
            row.essay_text,
            row[col],
            row.previous_discourse_end,
        )
    return search_start_end

def batch_to_device(batch, device):
    """Return a new dict with every value of `batch` moved via `.to(device)`."""
    moved = {}
    for name in batch:
        moved[name] = batch[name].to(device)
    return moved


def text_to_words(text):
    """Split `text` on whitespace and return the words with their character spans.

    Returns:
        (words, offsets) where offsets[i] is the (start, end) span of words[i].
    """
    words = text.split()
    offsets = []
    cursor = 0
    for current in words:
        begin = text.find(current, cursor)
        if begin == -1:
            raise NotImplementedError
        cursor = begin + len(current)
        offsets.append((begin, cursor))
    return words, offsets

def text_to_sentence(text):
    """Split `text` at sentence terminators and return pieces with their spans.

    A terminator is one of `. ? !` or a newline, optionally followed by closing
    quotes/brackets, with surrounding spaces swallowed. Empty pieces are dropped.
    """
    pieces = [
        piece
        for piece in re.split(r' *[\.\?!\n][\'"\)\]]* *', text)
        if piece != ""
    ]

    spans = []
    cursor = 0
    for piece in pieces:
        begin = text.find(piece, cursor)
        if begin == -1:
            raise NotImplementedError
        cursor = begin + len(piece)
        spans.append((begin, cursor))

    return pieces, spans

def text_to_paragraph(text):
    """Split `text` at newlines and return the paragraphs with their spans.

    A newline may be followed by closing quotes/brackets; surrounding spaces are
    swallowed by the separator. Empty pieces are dropped.
    """
    paragraphs = [
        piece
        for piece in re.split(r' *[\n][\'"\)\]]* *', text)
        if piece != ""
    ]

    spans = []
    cursor = 0
    for piece in paragraphs:
        begin = text.find(piece, cursor)
        if begin == -1:
            raise NotImplementedError
        cursor = begin + len(piece)
        spans.append((begin, cursor))

    return paragraphs, spans


def get_span_from_text(text,span_type="words"):
    """Split `text` into spans of the requested granularity.

    `span_type` is "words" or "sentences"; any other value selects paragraphs.
    Returns the spans and their (start, end) offsets.
    """
    if span_type=="words":
        splitter = text_to_words
    elif span_type=="sentences":
        splitter = text_to_sentence
    else:
        splitter = text_to_paragraph

    spans, spans_offset = splitter(text)
    return spans, spans_offset


def get_span_len_from_text(text,span_type="words"):
    """Count the spans `text` splits into at the requested granularity.

    `span_type` is "words" or "sentences"; any other value counts paragraphs.
    """
    if span_type=="words":
        return len(text.split())

    if span_type=="sentences":
        separator = r' *[\.\?!\n][\'"\)\]]* *'
    else:
        separator = r' *[\n][\'"\)\]]* *'

    # Empty pieces (e.g. after a trailing terminator) are not counted.
    return sum(1 for piece in re.split(separator, text) if piece != "")


def to_gpu(data,device):
    """Recursively move every tensor found in nested dicts/lists to `device`.

    Anything that is not a dict, list or tensor is returned unchanged.
    """
    if isinstance(data, dict):
        return {key: to_gpu(value, device) for key, value in data.items()}
    if isinstance(data, list):
        return [to_gpu(item, device) for item in data]
    if isinstance(data, torch.Tensor):
        return data.to(device)
    return data
    
def to_np(data):
    """Recursively convert every tensor in nested dicts/lists to a NumPy array.

    Anything that is not a dict, list or tensor is returned unchanged.
    """
    if isinstance(data, dict):
        return {key: to_np(value) for key, value in data.items()}
    if isinstance(data, list):
        return [to_np(item) for item in data]
    if isinstance(data, torch.Tensor):
        return data.cpu().numpy()
    return data

def get_start_end_offset(col):
    """Build a row-wise function mapping the tokens in `col` to character spans.

    The returned callable reads `full_text` from the row and delegates to
    `get_offset_mapping`.
    """
    def search_start_end(row):
        return get_offset_mapping(row.full_text, row[col])
    return search_start_end


In [95]:
import re
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
# from data.data_utils import clean_text,get_start_end,get_offset_mapping,get_start_end_offset

from tqdm.auto import tqdm


# Class index -> entity type; the untagged label 'O' comes last.
LABEL2TYPE = (
    'NAME_STUDENT',
    'EMAIL',
    'USERNAME',
    'ID_NUM',
    'PHONE_NUM',
    'URL_PERSONAL',
    'STREET_ADDRESS',
    'O',
)
# Entity type -> class index (inverse of LABEL2TYPE).
TYPE2LABEL = dict(zip(LABEL2TYPE, range(len(LABEL2TYPE))))


## =============================================================================== ##
class FeedbackDataset(Dataset):
    """One sample per document: token ids plus word-to-token alignment boxes.

    The frame must provide `document`, `full_text` and `tokens` columns, and
    `labels` when `train=True`.

    NOTE: the frame passed in is modified in place (text columns are cleaned
    and an `offset` column is added) and kept as `self.df`.

    Args:
        df: Source frame, one row per document.
        tokenizer: Tokenizer supporting `encode`, `__call__` with
            `return_offsets_mapping=True`, and `mask_token_id`.
        mask_prob: Probability of applying random token masking to a sample.
        mask_ratio: Fraction of tokens replaced by the mask token when masking
            is applied (at least one position is selected).
        train: When False, labels are not read and `gt_spans` is empty.
    """

    # Box emitted for a word that overlaps no token.
    _NO_TOKEN_BOX = [-100, 0, -99, 1]

    def __init__(self,
                 df,
                 tokenizer,
                 mask_prob=0.0,
                 mask_ratio=0.0,
                 train = True
                 ):
        assert 0 <= mask_prob <= 1
        assert 0 <= mask_ratio <= 1

        # Honour the caller's flag; hard-coding True made `train=False` a no-op
        # and forced a `labels` column to exist at inference time.
        self.train = train
        self.tokenizer = tokenizer
        self.mask_prob = mask_prob
        self.mask_ratio = mask_ratio

        # Presumably a length of 2 means only the special tokens were produced,
        # i.e. the tokenizer drops "\n\n" — TODO confirm. In that case paragraph
        # breaks are swapped for a visible " | " marker.
        if len(self.tokenizer.encode("\n\n"))==2:
            df["full_text"] = df['full_text'].str.replace("\n\n", " | ", regex=False)
            # Tokens are plain Python strings inside lists, so use str.replace.
            df["tokens"] = df['tokens'].map(
                lambda toks: [tok.replace("\n\n", " | ") for tok in toks])

        self.df = self.prepare_df(df)

        print(f'Loaded {len(self)} samples.')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return the text, token tensors, word boxes and word labels of one document."""
        row = self.df.iloc[index]
        text = row['full_text']
        text_id = row['document']
        labels = row['labels'] if self.train else []

        tokens = self.tokenizer(text, return_offsets_mapping=True)
        input_ids = torch.LongTensor(tokens['input_ids'])
        attention_mask = torch.LongTensor(tokens['attention_mask'])
        offset_mapping = np.array(tokens['offset_mapping'])

        # token slices of words
        word_boxes = torch.FloatTensor(
            self._word_boxes(np.array(row['offset']), offset_mapping))

        # one (word index, class id) pair per labelled word
        gt_spans = torch.LongTensor([
            [i, TYPE2LABEL[label.split('-')[1] if label != "O" else "O"]]
            for i, label in enumerate(labels)
        ])

        # random mask augmentation (first and last token are never masked)
        if np.random.random() < self.mask_prob:
            all_inds = np.arange(1, len(input_ids) - 1)
            n_mask = max(int(len(all_inds) * self.mask_ratio), 1)
            np.random.shuffle(all_inds)
            mask_inds = all_inds[:n_mask]
            input_ids[mask_inds] = self.tokenizer.mask_token_id

        return dict(text=text,
                    text_id=text_id,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    word_boxes=word_boxes,
                    gt_spans=gt_spans)

    @classmethod
    def _word_boxes(cls, word_offsets, token_offsets):
        """Map each word's character span to a `[first, 0, last + 1, 1]` token box.

        A token belongs to a word when their character spans overlap by at
        least one character. Only the overlap length is needed, which avoids
        the 0/0 division of an IoU between two zero-length spans (the NaN was
        counted as a match). Words without any overlapping token get
        `_NO_TOKEN_BOX`.
        """
        wx1, wx2 = word_offsets.reshape(-1, 2).T
        tx1, tx2 = token_offsets.reshape(-1, 2).T
        overlap = (np.minimum(wx2[:, None], tx2[None, :])
                   - np.maximum(wx1[:, None], tx1[None, :]))

        boxes = []
        for hits in overlap > 0:
            inds = hits.nonzero()[0]
            if len(inds):
                boxes.append([int(inds[0]), 0, int(inds[-1]) + 1, 1])
            else:
                boxes.append(list(cls._NO_TOKEN_BOX))
        return boxes

    def prepare_df(self,test_df):
        """Clean text and tokens and add per-token character offsets, in place."""
        test_df['full_text'] = test_df['full_text'].transform(clean_text)
        test_df['tokens'] = test_df['tokens'].transform(lambda x:[clean_text(i) for i in x])
        test_df['offset'] = test_df.apply(get_start_end_offset('tokens'),axis=1)
        return test_df

    def strip_offset_mapping(self, text, offset_mapping):
        """Shrink each span to its first run of non-whitespace characters.

        Spans containing only whitespace (or nothing) are kept unchanged.
        """
        ret = []
        for start, end in offset_mapping:
            match = list(re.finditer(r'\S+', text[start:end]))
            if len(match) == 0:
                ret.append((start, end))
            else:
                span_start, span_end = match[0].span()
                ret.append((start + span_start, start + span_end))
        return np.array(ret)

    def get_word_offsets(self, text):
        """Return the (start, end) span of every whitespace-delimited word."""
        matches = re.finditer(r"\S+", text)
        spans = []
        words = []
        for match in matches:
            span = match.span()
            word = match.group()
            spans.append(span)
            words.append(word)
        assert tuple(words) == tuple(text.split())
        return np.array(spans)
    
## =============================================================================== ##
class CustomCollator(object):
    """Collate exactly one dataset sample into a padded batch of size one.

    When the model config exposes `attention_window` (Longformer), the sequence
    is padded up to the next multiple of that window; otherwise no padding is
    added.
    """

    def __init__(self, tokenizer, model):
        self.pad_token_id = tokenizer.pad_token_id
        if not hasattr(model.config, 'attention_window'):
            self.attention_window = None
        else:
            # For longformer
            # https://github.com/huggingface/transformers/blob/v4.17.0/src/transformers/models/longformer/modeling_longformer.py#L1548
            window = model.config.attention_window
            # A per-layer list is reduced to its largest window.
            self.attention_window = window if isinstance(window, int) else max(window)

    def __call__(self, samples):
        batch_size = len(samples)
        assert batch_size == 1, f'Only batch_size=1 supported, got batch_size={batch_size}.'

        sample = samples[0]
        token_ids = sample['input_ids']
        n_tokens = len(token_ids)

        padded_length = n_tokens
        if self.attention_window is not None:
            # Distance to the next multiple of the window (0 if already aligned).
            padded_length += (-n_tokens) % self.attention_window

        shape = (1, padded_length)
        input_ids = torch.full(shape, self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros(shape, dtype=torch.long)
        input_ids[0, :n_tokens] = token_ids
        attention_mask[0, :n_tokens] = sample['attention_mask']

        # word_boxes / gt_spans / text are passed through without a batch axis.
        return {
            'text_id': sample['text_id'],
            'text': sample['text'],
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'word_boxes': sample['word_boxes'],
            'gt_spans': sample['gt_spans'],
        }

In [96]:
from transformers import AutoTokenizer, AutoModel, AutoConfig

In [97]:
# Nominal sequence-length budget; no other cell shown here references it.
max_length = 512
# Hugging Face checkpoint name, used for the tokenizer here and for the
# backbone when FeedbackModel is instantiated.
model_name = 'microsoft/deberta-large'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [98]:
# Build the dataset. Note: FeedbackDataset.prepare_df rewrites `full_text` and
# `tokens` and adds an `offset` column on `df` itself (in place).
ds = FeedbackDataset(df,tokenizer)

Loaded 6807 samples.


In [99]:
# Fetch the first sample. The tokenizer is called without truncation, hence the
# sequence-length warning printed below.
dx = ds[0]

Token indices sequence length is longer than the specified maximum sequence length for this model (867 > 512). Running this sequence through the model will result in indexing errors


In [120]:
# 1 where a word overlaps at least one token, 0 where it received the
# [-100, 0, -99, 1] placeholder box.
(dx["word_boxes"]!=-100)[:,0]*1

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0,
        1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [101]:
# One (word index, class id) row per word label.
dx["gt_spans"].shape

torch.Size([753, 2])

# Models

In [102]:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
import torch.utils.checkpoint
import torch.nn.functional as F
import gc

# from mmcv.cnn import bias_init_with_prob
from torchvision.ops import roi_align, nms

def aggregate_tokens_to_words(feat, word_boxes):
    """Pool per-token features into one vector per word box with `roi_align`.

    Args:
        feat: 3-D tensor; assumed to be (batch, tokens, channels) — TODO confirm.
        word_boxes: Boxes passed to `roi_align` as the single entry of a list.

    Returns:
        The pooled features with the two trailing size-1 axes removed.
    """
    # Swap the last two axes and insert a height-1 axis so the token axis acts
    # as the "width" of a one-pixel-high feature map.
    feature_map = feat.permute(0, 2, 1)[:, :, None, :]
    pooled = roi_align(feature_map, [word_boxes], 1, aligned=True)
    # Output size is 1x1, so both trailing axes are indexed away.
    return pooled[:, :, 0, 0]


def span_nms(start, end, score, nms_thr=0.5):
    """Run non-maximum suppression over 1-D spans.

    Each span becomes a unit-height box `[start, 0, end, 1]` so the 2-D `nms`
    operator can be used; the value returned by `nms` is passed through.
    """
    y_top = torch.zeros_like(start)
    y_bottom = torch.ones_like(start)
    boxes = torch.stack((start, y_top, end, y_bottom), dim=1).float()
    return nms(boxes, score, nms_thr)

class FeedbackModel(nn.Module):
    """Transformer backbone with a linear head, pooled from tokens to words.

    Args:
        model_name: Hugging Face model identifier for the backbone.
        num_labels: Output size of the classification head.
        config_path: Optional path to a config object saved with `torch.save`.
            When given, the backbone is built from that config without
            downloading pretrained weights.
        pretrained_path: Optional checkpoint whose backbone weights are loaded
            (best effort: a failure is reported as a warning).
        use_dropout: When False, hidden and attention dropout are set to 0.
        use_gradient_checkpointing: Enable gradient checkpointing on the backbone.
    """

    def __init__(self,
                 model_name,
                 num_labels = 8,
                 config_path=None,
                 pretrained_path = None,
                 use_dropout=False,
                 use_gradient_checkpointing = False
                 ):
        super().__init__()
        self.pretrained_path = pretrained_path
        # NOTE(review): torch.load unpickles arbitrary objects — only pass
        # config files from a trusted source.
        self.config = AutoConfig.from_pretrained(model_name, output_hidden_states=True) if not config_path else torch.load(config_path)

        self.use_dropout = use_dropout
        if not self.use_dropout:
            self.config.update(
                {
                    "hidden_dropout_prob": 0.0,
                    "attention_probs_dropout_prob": 0.0,
                }
            )

        self.backbone = AutoModel.from_pretrained(model_name,config=self.config) if not config_path else AutoModel.from_config(self.config)
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.fc = nn.Linear(self.config.hidden_size, num_labels)

        if self.pretrained_path:
            # Loading stays best-effort, but a failure is now visible instead
            # of silently leaving the backbone with its default weights.
            try:
                self.load_from_cp()
            except Exception as exc:
                import warnings
                warnings.warn(
                    f"Could not load pretrained weights from {self.pretrained_path}: {exc!r}"
                )
        if use_gradient_checkpointing:
            self.backbone.gradient_checkpointing_enable()

    def load_from_cp(self):
        """Load backbone weights from `self.pretrained_path`.

        Head weights (`fc.*`, `fc_seg.*`) are discarded when present and the
        remaining keys are renamed to match the bare backbone.

        Raises:
            Whatever `torch.load` / `load_state_dict(strict=True)` raise, e.g.
            on a missing file or mismatching keys.
        """
        print("Using Pretrained Weights")
        print(self.pretrained_path)
        state_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage)
        # Backbone-only checkpoints have no head entries; pop() tolerates that.
        state_dict.pop('fc.bias', None)
        state_dict.pop('fc.weight', None)

        if 'fc_seg.bias' in state_dict.keys():
            state_dict.pop('fc_seg.bias', None)
            state_dict.pop('fc_seg.weight', None)
            for key in list(state_dict.keys()):
                state_dict[key.replace('model.deberta.', '')] = state_dict.pop(key)
        else:
            for key in list(state_dict.keys()):
                state_dict[key.replace('backbone.', '')] = state_dict.pop(key)

        self.backbone.load_state_dict(state_dict, strict=True)
        print('Loading successed !')

    def forward(self,b):
        """Return head outputs pooled per word box.

        Args:
            b: Dict with `input_ids`, `attention_mask` and `word_boxes`.
        """
        x = self.backbone(input_ids=b["input_ids"],
                          attention_mask=b["attention_mask"]).last_hidden_state
        x = self.dropout(x)
        x = self.fc(x)
        # Token-level outputs -> one row per entry of b['word_boxes'].
        x = aggregate_tokens_to_words(x, b['word_boxes'])
        return x

In [103]:
# Instantiate with the defaults: 8 output labels, dropout disabled, no
# checkpoint loaded, so the `fc` head is freshly initialised.
model = FeedbackModel(model_name)

Some weights of the model checkpoint at microsoft/deberta-large were not used when initializing DebertaModel: ['lm_predictions.lm_head.dense.weight', 'lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.bias', 'lm_predictions.lm_head.dense.bias', 'lm_predictions.lm_head.LayerNorm.bias']
- This IS expected if you are initializing DebertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [106]:
from torch.utils.data import DataLoader

In [107]:
# CustomCollator asserts batch_size == 1, so the loader must use batch_size=1.
collator = CustomCollator(tokenizer,model)
train_loader = DataLoader(ds,batch_size=1,collate_fn=collator)

In [108]:
# Grab only the first batch; `data` keeps that batch after the loop exits.
for data in train_loader:
    break

In [128]:
# Forward pass on the single batch. No torch.no_grad() context is used here,
# so autograd state is tracked for this call.
y = model(data)

In [142]:
# Normalise the model outputs over the last (label) axis.
yy = y.softmax(-1)

In [146]:
# s: highest probability per row, i: index of the corresponding label.
s,i = yy.max(-1)

In [150]:
# Arg-max label index per row (see TYPE2LABEL for the index -> type mapping).
i

tensor([5, 2, 3, 2, 1, 6, 1, 2, 1, 1, 1, 0, 5, 6, 3, 0, 2, 3, 3, 3, 2, 3, 0, 1,
        3, 1, 3, 3, 2, 0, 2, 2, 5, 4, 5, 0, 3, 0, 2, 0, 2, 3, 1, 1, 3, 0, 2, 3,
        3, 6, 1, 2, 6, 1, 6, 6, 1, 5, 5, 2, 1, 1, 6, 0, 5, 1, 5, 1, 5, 5, 5, 5,
        5, 5, 5, 3, 0, 2, 6, 6, 5, 6, 1, 1, 1, 0, 1, 5, 6, 3, 6, 5, 0, 2, 2, 3,
        6, 6, 2, 4, 3, 0, 6, 2, 6, 3, 2, 5, 6, 0, 6, 1, 6, 2, 6, 6, 0, 6, 0, 6,
        6, 5, 6, 4, 6, 6, 6, 6, 5, 5, 5, 6, 6, 6, 6, 0, 6, 0, 6, 0, 6, 6, 5, 0,
        6, 0, 6, 6, 1, 6, 2, 2, 6, 0, 6, 0, 6, 6, 6, 1, 4, 6, 6, 6, 0, 6, 2, 2,
        2, 2, 2, 1, 2, 6, 2, 0, 1, 2, 0, 6, 0, 6, 6, 6, 3, 4, 5, 6, 6, 6, 6, 2,
        0, 6, 0, 6, 6, 5, 6, 6, 5, 0, 6, 0, 6, 6, 3, 2, 2, 5, 5, 2, 3, 6, 6, 0,
        6, 0, 6, 1, 0, 6, 0, 6, 6, 2, 5, 5, 0, 6, 0, 6, 6, 5, 5, 5, 2, 0, 1, 1,
        5, 2, 3, 1, 1, 1, 1, 2, 6, 2, 5, 1, 1, 1, 2, 2, 2, 1, 0, 2, 2, 5, 5, 0,
        2, 2, 2, 2, 2, 1, 2, 2, 2, 3, 2, 0, 2, 2, 0, 6, 0, 7, 2, 0, 2, 2, 2, 2,
        0, 1, 4, 5, 2, 1, 2, 2, 3, 5, 6,

In [151]:
# (rows, labels) of the probability matrix.
yy.shape

torch.Size([753, 8])

In [152]:
# Entity type -> class index lookup, shown for reading the indices above.
TYPE2LABEL

{'NAME_STUDENT': 0,
 'EMAIL': 1,
 'USERNAME': 2,
 'ID_NUM': 3,
 'PHONE_NUM': 4,
 'URL_PERSONAL': 5,
 'STREET_ADDRESS': 6,
 'O': 7}