In [1]:
import json
import random
import time
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch
from torch.utils.data import DataLoader

from slot_dataset import SeqClsDataset
# from slot_model import SeqClassifier
from utils import Vocab

TRAIN = "train"
DEV = "eval"
SPLITS = [TRAIN, DEV]
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
data_dir = "./data/slot/"
cache_dir = "./cache/slot/"
ckpt_dir = "./ckpt/slot/"
max_len = 35

In [3]:
with open(Path(cache_dir + "vocab.pkl"), "rb") as f:
    vocab: Vocab = pickle.load(f)

tag_idx_path = Path(cache_dir + "tag2idx.json")
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text())

data_paths = {split: Path(data_dir + "%s.json" %split) for split in SPLITS}
data = {split: json.loads(path.read_text()) for split, path in data_paths.items()}
datasets: Dict[str, SeqClsDataset] = {
    split: SeqClsDataset(split_data, vocab, tag2idx, max_len)
    for split, split_data in data.items()
}
    
batch_size = 128
train_loader = DataLoader(datasets["train"], batch_size=batch_size, shuffle=True, collate_fn=datasets["train"].collate_fn)
val_loader = DataLoader(datasets["eval"], batch_size=batch_size, shuffle=False, collate_fn=datasets["eval"].collate_fn)

In [17]:
# max_len = 0
# for i in range(len(datasets["train"].data)):
#     sentence = datasets["train"].data[i]["tokens"]
#     if len(sentence) > max_len:
#         max_len = len(sentence)
# print(max_len)

In [18]:
def argmax(vec):
    # return the argmax as a python int
    _, idx = torch.max(vec, 1)
    return idx.item()

# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

In [42]:
from typing import Dict

import torch
import torch.nn as nn
from torch.nn import Embedding

class SeqClassifier(torch.nn.Module):
    def __init__(
        self,
        tag2idx,
        embeddings: torch.tensor,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        num_class: int,
    ) -> None:
        super(SeqClassifier, self).__init__()
        self.tag2idx = tag2idx
        self.num_class = num_class
        self.embedding_dim = embeddings.size(1)
        self.embed = Embedding.from_pretrained(embeddings, freeze=False)
        self.rnn = nn.LSTM(input_size=self.embedding_dim, hidden_size=hidden_size, num_layers=num_layers, 
                          dropout=dropout, bidirectional=bidirectional, batch_first=True)
        if bidirectional:
            self.fc = nn.Sequential(
                nn.Linear(hidden_size*2, hidden_size),
                nn.BatchNorm1d(35),
                nn.Linear(hidden_size, num_class),
                nn.Softmax(dim=2)
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(hidden_size, num_class),
                nn.Softmax(dim=2)
            )
            
        # Matrix of transition parameters.  Entry i,j is the score of
        # transitioning *to* i *from* j.
        self.transitions = nn.Parameter(
            torch.randn(self.num_class, self.num_class))
        
        # These two statements enforce the constraint that we never transfer
        # to the start tag and we never transfer from the stop tag
        self.transitions.data[tag2idx[START_TAG], :] = -10000
        self.transitions.data[:, tag2idx[STOP_TAG]] = -10000
        
    def _forward_alg(self, feats):
        # Do the forward algorithm to compute the partition function
        init_alphas = torch.full((1, self.tagset_size), -10000.)
        # START_TAG has all of the score.
        init_alphas[0][self.tag2idx[START_TAG]] = 0.

        # Wrap in a variable so that we will get automatic backprop
        forward_var = init_alphas

        # Iterate through the sentence
        for fea in feas:
            alphas_t = []  # The forward tensors at this timestep
            for next_tag in range(self.num_class):
                # broadcast the emission score: it is the same regardless of
                # the previous tag
                emit_score = fea[next_tag].view(
                    1, -1).expand(1, self.num_class)
                # the ith entry of trans_score is the score of transitioning to
                # next_tag from i
                trans_score = self.transitions[next_tag].view(1, -1)
                # The ith entry of next_tag_var is the value for the
                # edge (i -> next_tag) before we do log-sum-exp
                next_tag_var = forward_var + trans_score + emit_score
                # The forward variable for this tag is log-sum-exp of all the
                # scores.
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)
        terminal_var = forward_var + self.transitions[self.tag2idx[STOP_TAG]]
        alpha = log_sum_exp(terminal_var)
        return alpha

    def _get_lstm_features(self, batch):
        out = self.embed(batch)
        out, _ = self.rnn(out)
        out = self.fc(out)
        return out
    
    def _score_sentence(self, feats, tags):
        # Gives the score of a provided tag sequence
        score = torch.zeros(1)
        tags = torch.cat([torch.tensor([self.tag2idx[START_TAG]], dtype=torch.long), tags])
        for i, fea in enumerate(feas):
            score = score + \
                self.transitions[tags[i + 1], tags[i]] + fea[tags[i + 1]]
        score = score + self.transitions[self.tag2idx[STOP_TAG], tags[-1]]
        return score
    
    
    def _viterbi_decode(self, feas):
        backpointers = []

        # Initialize the viterbi variables in log space
        init_vvars = torch.full((1, self.num_class), -10000.)
        init_vvars[0][self.tag2idx[START_TAG]] = 0
        
        # forward_var at step i holds the viterbi variables for step i-1
        forward_var = init_vvars
        for fea in feas:
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step
            
            for next_tag in range(self.num_class):
                # next_tag_var[i] holds the viterbi variable for tag i at the
                # previous step, plus the score of transitioning
                # from tag i to next_tag.
                # We don't include the emission scores here because the max
                # does not depend on them (we add them in below)
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Now add in the emission scores, and assign forward_var to the set
            # of viterbi variables we just computed
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)
        
        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.tag2idx[STOP_TAG]]
        best_tag_id = argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]
        
        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we dont want to return that to the caller)
        start = best_path.pop()
        assert start == self.tag2idx[START_TAG]  # Sanity check
        best_path.reverse()
        return path_score, best_path
                
    def neg_log_likelihood(self, sentence, tags):
        feats = self._get_lstm_features(sentence)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score    
                
    def forward(self, batch):
        lstm_feas = self._get_lstm_features(batch)
        score, tag_seq = self._viterbi_decode(lstm_feas)
        return score, tagseq

In [49]:
transitions = nn.Parameter(torch.randn(150, 150))

In [55]:
transitions.data[:, tag2idx[STOP_TAG]] = -10000

In [43]:
embeddings = torch.load(cache_dir + "embeddings.pt")

In [45]:
START_TAG = "<START>"
STOP_TAG = "<STOP>"
tag2idx.update({START_TAG: 9, STOP_TAG: 10})

In [48]:
model = SeqClassifier(tag2idx=tag_idx_path, embeddings=embeddings, hidden_size=256, num_layers=2, dropout=0, 
                      bidirectional=True, num_class=11)
for i, batch in enumerate(train_loader):
    data = batch[0]
    label = batch[1]
    pred = model(data)
#     pclass = pred.argmax(dim=2)
    break

TypeError: 'PosixPath' object is not subscriptable

In [7]:
def cal_joint_acc(pred, label):
    pclass = pred.argmax(dim=2)
    correct = 0
    for i in range(len(label)):
        s_label = label[i][label[i]!=-100]
        length = len(s_label)
        if (s_label==pclass[i][:length]).all():
            correct += 1
    jacc = correct / len(label)
    return jacc

In [8]:
embeddings = torch.load(cache_dir + "embeddings.pt")

model = SeqClassifier(embeddings=embeddings, hidden_size=256, num_layers=2, dropout=0.2, bidirectional=True, num_class=9)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
epochs = 100
best_jacc = 0.7
# epoch_pbar = trange(epochs, desc="Epoch")
for epoch in range(epochs):
    epoch_start_time = time.time()
    
    train_loss = 0
    train_jacc = 0
    train_len = 0
    val_loss = 0
    val_jacc = 0
    val_len = 0
    
    model.train()
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        data = batch[0].to(device)
        label = batch[1].to(device)
        pred = model(data)
        _pred = pred.permute(0, 2, 1)
        loss = criterion(_pred, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_jacc += cal_joint_acc(pred, label)
    
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            data = batch[0].to(device)
            label = batch[1].to(device)
            pred = model(data)
            _pred = pred.permute(0, 2, 1)
            loss = criterion(_pred, label)
            val_loss += loss.item()
            val_jacc += cal_joint_acc(pred, label)
        
    print('[%03d/%03d] %2.2f sec(s) Train Loss: %.4f Joint_Acc: %.4f| Val loss: %.4f Joint_Acc: %.4f' % \
            (epoch + 1, epochs, time.time()-epoch_start_time, \
             train_loss/train_loader.__len__(), train_jacc/train_loader.__len__(), \
             val_loss/val_loader.__len__(), val_jacc/val_loader.__len__()))
    
    if val_jacc/val_loader.__len__() >= best_jacc:
        best_jacc = val_jacc/val_loader.__len__()
        torch.save(model.state_dict(), "/data/NFS/andy/course/ADL/hw1/slot_weights2.pt")
        print("saving model with acc:%.4f" %(best_jacc))

[001/100] 5.72 sec(s) Train Loss: 1.5706 Joint_Acc: 0.3945| Val loss: 1.5539 Joint_Acc: 0.4056
[002/100] 5.58 sec(s) Train Loss: 1.5571 Joint_Acc: 0.4015| Val loss: 1.5539 Joint_Acc: 0.4056
[003/100] 5.71 sec(s) Train Loss: 1.5575 Joint_Acc: 0.4018| Val loss: 1.5538 Joint_Acc: 0.4056
[004/100] 5.72 sec(s) Train Loss: 1.5555 Joint_Acc: 0.4015| Val loss: 1.5452 Joint_Acc: 0.4056
[005/100] 5.50 sec(s) Train Loss: 1.5236 Joint_Acc: 0.4407| Val loss: 1.4874 Joint_Acc: 0.5192
[006/100] 5.70 sec(s) Train Loss: 1.4893 Joint_Acc: 0.5383| Val loss: 1.4731 Joint_Acc: 0.5507
[007/100] 5.45 sec(s) Train Loss: 1.4586 Joint_Acc: 0.6283| Val loss: 1.4502 Joint_Acc: 0.6503
[008/100] 5.69 sec(s) Train Loss: 1.4466 Joint_Acc: 0.6789| Val loss: 1.4484 Joint_Acc: 0.6409
[009/100] 5.85 sec(s) Train Loss: 1.4407 Joint_Acc: 0.7021| Val loss: 1.4313 Joint_Acc: 0.6711
[010/100] 5.87 sec(s) Train Loss: 1.4154 Joint_Acc: 0.7576| Val loss: 1.4216 Joint_Acc: 0.7283
saving model with acc:0.7283
[011/100] 5.64 sec(s)

[084/100] 5.58 sec(s) Train Loss: 1.3855 Joint_Acc: 0.9167| Val loss: 1.4064 Joint_Acc: 0.8007
saving model with acc:0.8007
[085/100] 5.69 sec(s) Train Loss: 1.3855 Joint_Acc: 0.9170| Val loss: 1.4078 Joint_Acc: 0.7907
[086/100] 5.82 sec(s) Train Loss: 1.3855 Joint_Acc: 0.9170| Val loss: 1.4078 Joint_Acc: 0.7903
[087/100] 5.60 sec(s) Train Loss: 1.3849 Joint_Acc: 0.9209| Val loss: 1.4079 Joint_Acc: 0.7921
[088/100] 5.40 sec(s) Train Loss: 1.3853 Joint_Acc: 0.9179| Val loss: 1.4078 Joint_Acc: 0.7912
[089/100] 5.49 sec(s) Train Loss: 1.3851 Joint_Acc: 0.9198| Val loss: 1.4076 Joint_Acc: 0.7917
[090/100] 5.50 sec(s) Train Loss: 1.3852 Joint_Acc: 0.9193| Val loss: 1.4074 Joint_Acc: 0.8000
[091/100] 5.55 sec(s) Train Loss: 1.3853 Joint_Acc: 0.9182| Val loss: 1.4098 Joint_Acc: 0.7766
[092/100] 5.63 sec(s) Train Loss: 1.3855 Joint_Acc: 0.9176| Val loss: 1.4074 Joint_Acc: 0.7915
[093/100] 5.51 sec(s) Train Loss: 1.3852 Joint_Acc: 0.9185| Val loss: 1.4077 Joint_Acc: 0.7956
[094/100] 5.53 sec(s)