In [228]:
import pandas as pd

from utlis import *
from model import *
from dataset import *
from metrics import *

from collections import Counter
from collections import defaultdict
from torch.utils.data import DataLoader
from poprogress import simple_progress as simp

In [236]:
# load data
all_data = pd.read_csv("../data_preprocess/all-data.csv")
all_len = len(all_data)
print("all_len: ",all_len)

# split data
train_data, valid_data, test_data = split_dataset(all_data, 0.8, 0.1)
print("train_data_size: ",len(train_data))
print("valid_data_size: ",len(valid_data))
print("test_data_size: ",len(test_data))
print("Spliting data done")
print("-"*30)

# get unique labels
label_unique = sorted(get_label_unique(train_data))

# get dicts
label_to_id = {k: v for v,k in enumerate(label_unique)}
id_to_label = {k: v for k,v in enumerate(label_unique)}
print(label_to_id)
print(id_to_label)

# get seq
train_token_seq, train_label_seq = get_data_seq(train_data)
valid_token_seq, valid_label_seq = get_data_seq(valid_data)
test_token_seq, test_label_seq = get_data_seq(test_data)
print("Get sequences done")
print("-"*30)

# get token -> id and label -> id
token2cnt = Counter([token for sentence in train_token_seq for token in sentence])
label_set = sorted(set(label for sentence in train_label_seq for label in sentence))
token_to_id = get_token2id(token2cnt)
print("Encoding data done")
print("size: ",len(token_to_id))
print("-"*30)

# dataset
train_set = nerDataset(train_token_seq, train_label_seq, token_to_id, label_to_id, preprocess=True)
valid_set = nerDataset(valid_token_seq, valid_label_seq, token_to_id, label_to_id, preprocess=True)
test_set = nerDataset(test_token_seq, test_label_seq, token_to_id, label_to_id, preprocess=True)
print("Making datasets done")
print("-"*30)

# dataloader
train_coll_fn = nerCollator(token_to_id["<UNK>"], label_to_id["O"], 100)
valid_coll_fn = nerCollator(token_to_id["<UNK>"], label_to_id["O"], 100)
test_coll_fn = nerCollator(token_to_id["<UNK>"], label_to_id["O"], 100)
bz = 64
train_loader = DataLoader(dataset=train_set, batch_size=bz, shuffle=False, collate_fn=train_coll_fn)
valid_loader = DataLoader(dataset=valid_set, batch_size=bz, shuffle=False, collate_fn=valid_coll_fn)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False, collate_fn=test_coll_fn)
print("Making Dataloaders done")
print("-"*30)


all_len:  21363
train_data_size:  17099
valid_data_size:  2056
test_data_size:  2208
Spliting data done
------------------------------


100%|██████████| 17099/17099 [00:00<00:00, 78957.63it/s]


{'B-LOC': 0, 'B-MISC': 1, 'B-ORG': 2, 'B-PER': 3, 'I-LOC': 4, 'I-MISC': 5, 'I-ORG': 6, 'I-PER': 7, 'O': 8}
{0: 'B-LOC', 1: 'B-MISC', 2: 'B-ORG', 3: 'B-PER', 4: 'I-LOC', 5: 'I-MISC', 6: 'I-ORG', 7: 'I-PER', 8: 'O'}
Get sequences done
------------------------------
Encoding data done
size:  25324
------------------------------
Making datasets done
------------------------------
Making Dataloaders done
------------------------------


In [237]:
embedding_layer = Embedding(num_embeddings=len(token_to_id), embedding_dim=128)

rnn_layer = dynamicRNN(rnn_unit=torch.nn.LSTM, input_size=128, hidden_size=256, 
                num_layers=1, dropout=0, bidirectional=True)

linear_head = LinearHead(linear_head=torch.nn.Linear(in_features=(2*256), 
                                                     out_features=len(label_to_id)))

model = BiLSTM(embedding_layer=embedding_layer, rnn_layer=rnn_layer, linear_head=linear_head)#.to(device)
print("Setting models done")
print("-"*30)

Setting models done
------------------------------


In [240]:
criterion = torch.nn.CrossEntropyLoss(reduction="none")
optimizer_type = torch.optim.Adam
optimizer = optimizer_type(params=model.parameters(), lr=0.001, amsgrad=False)
print("Setting metrics done")
print("-"*30)

verbose = True
n_epoch = 2
clip_grad_norm = 0.9
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


Setting metrics done
------------------------------


In [239]:
def data_epoch(model, dataloader, criterion, mode, device, verbose=True):
    metrics = defaultdict(list)
    if mode == "train":
        model.train()
    else:
        model.eval()
        
    for tokens, labels, lengths in simp(dataloader):
        tokens, labels, lengths = (tokens.to(device), labels.to(device), lengths.to(device))

        mask = masking(lengths)

        # forward pass
        if mode == "train":
            logits = model(tokens, lengths) # bz,xxx,9
            loss_without_reduction = criterion(logits.transpose(-1, -2), labels)
            loss = torch.sum(loss_without_reduction * mask) / torch.sum(mask)
            # backward pass
            loss.backward()
            # gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm, norm_type=2)
            # update 
            optimizer.step()
            optimizer.zero_grad()
            
        else:
            with torch.no_grad():
                logits = model(tokens, lengths)
                loss_without_reduction = criterion(logits.transpose(-1, -2), labels)
                loss = torch.sum(loss_without_reduction * mask) / torch.sum(mask)

        # make predictions
        y_true = to_numpy(labels[mask])
        y_pred = to_numpy(logits.argmax(dim=-1)[mask])

        # calculate metrics
        metrics = calculate_metrics(
            metrics=metrics,
            loss=loss.item(),
            y_true=y_true,
            y_pred=y_pred,
            idx2label=id_to_label,
        )

    return metrics

In [241]:
for epoch in range(n_epoch):
    train_metrics = data_epoch(model, train_loader, criterion, "train", device, False)
    valid_metrics = data_epoch(model, valid_loader, criterion, "valid", device, False)
test_metrics = data_epoch(model, test_loader, criterion, "test", device, False)

  0%|          | 0/268 [00:00<?, ?it/s]

100%|██████████| 268/268 [02:19<00:00,  1.91it/s]
100%|██████████| 33/33 [00:02<00:00, 12.87it/s]
100%|██████████| 268/268 [02:47<00:00,  1.60it/s]
100%|██████████| 33/33 [00:02<00:00, 12.18it/s]
100%|██████████| 2208/2208 [00:24<00:00, 88.39it/s] 


In [242]:
print(train_metrics["loss"])
print(valid_metrics["loss"])
print(test_metrics["loss"])

[0.4717346131801605, 0.49730631709098816, 0.488316148519516, 0.47532713413238525, 0.4522187411785126, 0.4301983714103699, 0.5155901908874512, 0.6095348000526428, 0.29971954226493835, 0.45472252368927, 0.4529896378517151, 0.44182056188583374, 0.4766692817211151, 0.43818122148513794, 0.35954028367996216, 0.4266773760318756, 0.3407035768032074, 0.3767316937446594, 0.4268537759780884, 0.28083232045173645, 0.3857327103614807, 0.37716054916381836, 0.36280879378318787, 0.4677313566207886, 0.40527617931365967, 0.35508573055267334, 0.42261579632759094, 0.4845113456249237, 0.3577958643436432, 0.1186693087220192, 0.10933682322502136, 0.467134565114975, 0.6114150285720825, 0.5468419790267944, 0.376510351896286, 0.3260049521923065, 0.36694684624671936, 0.36829501390457153, 0.4361661374568939, 0.34739765524864197, 0.32046541571617126, 0.46215346455574036, 0.3823970556259155, 0.5380191206932068, 0.4424641728401184, 0.59922194480896, 0.39472144842147827, 0.503883421421051, 0.4375780522823334, 0.229334

In [251]:
sent = "Smicer pushed the ball home in injury time to lead his team to a 3-2 victory over Montpellier , who were leading 2-1 until Cameroon 's Marc-Vivien Foe equalised on a header in the 85th minute ."
sent = sent.lower().split()
sent_tokens = []
for x in sent:
    sent_tokens.append(token_to_id.get(x, 1))
model.eval()
in_tokens = torch.tensor(sent_tokens).unsqueeze(0).to(device)
in_length = torch.tensor([len(sent_tokens)]).to(device)

cc = np.argmax(to_numpy(model(in_tokens, in_length)[0]),1)
ou_labels = []
ou_dict = {}
for i,x in enumerate(cc):    
    # print({sent[i]: id_to_label[x]})
    ou_dict[sent[i]] = id_to_label[x]
    ou_labels.append(id_to_label[x])
ou_dict

{'smicer': 'B-PER',
 'pushed': 'O',
 'the': 'O',
 'ball': 'O',
 'home': 'O',
 'in': 'O',
 'injury': 'O',
 'time': 'O',
 'to': 'O',
 'lead': 'O',
 'his': 'O',
 'team': 'O',
 'a': 'O',
 '3-2': 'O',
 'victory': 'O',
 'over': 'O',
 'montpellier': 'B-PER',
 ',': 'O',
 'who': 'O',
 'were': 'O',
 'leading': 'O',
 '2-1': 'O',
 'until': 'O',
 'cameroon': 'B-LOC',
 "'s": 'O',
 'marc-vivien': 'O',
 'foe': 'I-PER',
 'equalised': 'I-PER',
 'on': 'O',
 'header': 'O',
 '85th': 'O',
 'minute': 'O',
 '.': 'O'}