In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Tutorial dataset
- classification

In [None]:
training_data = [
    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])
]

word_to_ix = {}
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
            
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}

# Mine: test dataset
- LM

In [35]:
txt = "one of the other reviewers has mentioned that after watching just 1 oz episode you'll be hooked. they are right, as this is exactly what happened with me.the first thing that struck me about oz was its brutality and unflinching scenes of violence, which set in right from the word go. trust me, this is not a show for the faint hearted or timid. this show pulls no punches with regards to drugs, sex or violence. its is hardcore, in the classic use of the word.it is called oz as that is the nickname given to the oswald maximum security state penitentary. it focuses mainly on emerald city, an experimental section of the prison where all the cells have glass fronts and face inwards, so privacy is not high on the agenda. em city is home to many..aryans, muslims, gangstas, latinos, christians, italians, irish and more....so scuffles, death stares, dodgy dealings and shady agreements are never far away.i would say the main appeal of the show is due to the fact that it goes where other shows wouldn't dare. forget pretty pictures painted for mainstream audiences, forget charm, forget romance...oz doesn't mess around. the first episode i ever saw struck me as so nasty it was surreal, i couldn't say i was ready for it, but as i watched more, i developed a taste for oz, and got accustomed to the high levels of graphic violence. not just violence, but injustice (crooked guards who'll be sold out for a nickel, inmates who'll kill on order and get away with it, well mannered, middle class inmates being turned into prison bitches due to their lack of street skills or prison experience) watching oz, you may become comfortable with what is uncomfortable viewing....thats if you can get in touch with your darker side.\n city is home to many..aryans, muslims, gangstas, latinos, christians, italians, irish and more....so scuffles, death stares, dodgy dealings and shady agreements are never far away.i would say the main appeal of the show is due to the fact that it goes where other shows wouldn't dare. forget pretty pictures painted for mainstream audiences, forget charm, forget romance...oz doesn't mess around. the first episode i ever saw struck me as so nasty it was surreal, i couldn't say i was ready for it, but as i watched more, i developed a taste for oz, and got accustomed to the high levels of graphic violence. not just violence, but injustice (crooked guards who'll be sold out for a nickel, inmates who'll kill on order and get away with it, well mannered, middle class inmates being turned into prison bitches due to their lack of street skills or prison experience) watching oz, you may become comfortable with what is uncomfortable viewing....thats if you can get in touch with your darker side."

In [40]:
import spacy
nlp = spacy.load('en_core_web_sm')
text = []
for p in txt.split('\n'):
    p = [t.text for t in nlp(p) if t is not None]
    text.append(p)

In [47]:
vocab = {'<eos>': 0}
for para in text:
    for token in para:
        if token not in vocab:
            vocab[token] = len(vocab)
        else:
            continue

In [49]:
print('# of voc', len(vocab))

# of voc 200


In [164]:
training_data = [(sent_ids[0][0:-2], sent_ids[0][1:-1]), (sent_ids[1][0:-2], sent_ids[1][1:-1])]

---
# Model: Tutorial 

In [None]:
class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

# Model: Mine 
- preserve state information

In [327]:
class LSTM_simple_net(nn.Module):
    def __init__(self, emb_dim, hidn_dim, voc_size):
        super().__init__()
        self.hidn_dim = hidn_dim
        self.emb_dim = emb_dim
        self.word_embd = nn.Embedding(voc_size, self.emb_dim)
        self.lstm = nn.LSTM(input_size=emb_dim, hidden_size=hidn_dim, num_layers=1)
        self.final = nn.Linear(hidn_dim, voc_size)
    
    def init_state(self):
        return (torch.zeros(1, 1, self.hidn_dim), torch.zeros(1, 1, self.hidn_dim))  # h_0: (num_layers * num_directions, batch, hidden_size) 
    
    def forward(self, sentence_id, prev_state):
        embeds = self.word_embd(sentence_id)
        lstm_out, state = self.lstm(embeds.view(len(sentence_id), 1, -1), prev_state) #out: (seq_len, batch, num_directions * hidden_size); hidn_state:(num_layers * num_directions, batch, hidden_size)
        
        #final_in: (seq_len, num_directions * hidden_size)
        finalout = self.final(lstm_out)
        score = F.log_softmax(finalout, dim=2)
        return score, state

def init_weight(m):
    for name, param in m.named_parameters():
        if isinstance(m, nn.Embedding):
            continue
        if 'weight' in name:
            nn.init.kaiming_normal_(param.data)
        if 'bias' in name:
            nn.init.constant_(param.data, 0)

---
# Train:Tutorial

In [144]:
EMBEDDING_DIM = 6
HIDDEN_DIM = 6

# model = LSTM_simple_net(emb_dim=EMBEDDING_DIM, hidn_dim=HIDDEN_DIM, voc_size=len(vocab), tagset_size=2)
# model.apply(init_weight)

model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))
print(model)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(300):
    for sentence, tags in training_data:
        model.zero_grad()
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)

        tag_scores = model(sentence_in) # (seq_len, num_tag)

        loss = loss_function(tag_scores, targets)
        loss.backward()
        optimizer.step()


LSTMTagger(
  (word_embeddings): Embedding(9, 6)
  (lstm): LSTM(6, 6)
  (hidden2tag): Linear(in_features=6, out_features=3, bias=True)
)


- tutorial result

In [154]:
# See what the scores are after training
with torch.no_grad():
    inputs = prepare_sequence(training_data[1][0], word_to_ix)
    tag_scores = model(inputs)
    print('num of seq:{}; num of tag:{};'.format(tag_scores.shape[0], tag_scores.shape[1]))
    print(tag_scores)
    print(torch.max((tag_scores), dim=1))
    # The sentence is "the dog ate the apple".  i,j corresponds to score for tag j
    # for word i. The predicted tag is the maximum scoring tag.
    # Here, we can see the predicted sequence below is 0 1 2 0 1
    # since 0 is index of the maximum value of row 1,
    # 1 is the index of maximum value of row 2, etc.
    # Which is DET NOUN VERB DET NOUN, the correct sequence!

num of seq:4; num of tag:3;
tensor([[-6.0556, -0.0260, -3.7594],
        [-4.0851, -3.7192, -0.0419],
        [-0.0201, -4.5532, -4.6755],
        [-4.2887, -0.0317, -4.0475]])
torch.return_types.max(
values=tensor([-0.0260, -0.0419, -0.0201, -0.0317]),
indices=tensor([1, 2, 0, 1]))


# Train: Mine
- When predicting the next word, we just need **the last word** and **the last state** for input

In [329]:
EMBEDDING_DIM = 50
HIDDEN_DIM = 50

model = LSTM_simple_net(emb_dim=EMBEDDING_DIM, hidn_dim=HIDDEN_DIM, voc_size=len(vocab))
model.apply(init_weight)

LSTM_simple_net(
  (word_embd): Embedding(200, 50)
  (lstm): LSTM(50, 50)
  (final): Linear(in_features=50, out_features=200, bias=True)
)

In [352]:
# model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))
print(model)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10000):
    state_h, state_c = model.init_state()
    for sentence, trg in training_data:
        sentence, trg = torch.tensor(sentence), torch.tensor(trg)
        model.zero_grad()
        out, (state_h, state_c) = model(sentence, (state_h, state_c)) # (seq_len, num_tag)
        loss = loss_function(out.view(-1, len(vocab)), trg)
        
        state_h.detach_()
        state_c.detach_()
        
        loss.backward()
        optimizer.step()
print(loss)

LSTM_simple_net(
  (word_embd): Embedding(200, 50)
  (lstm): LSTM(50, 50)
  (final): Linear(in_features=50, out_features=200, bias=True)
)
tensor(0.0098, grad_fn=<NllLossBackward>)


In [359]:
force_t_text = []
nforce_t_text = []

with torch.no_grad():
    ############# force teaching #############

    inputs = torch.tensor(training_data[0][0])
    state_h, state_c = model.init_state()
    out, (state_h, state_c) = model(inputs, (state_h, state_c))
    print('Query\n', ' '.join([id2voc[i.item()] for i in inputs]))
    print()
    print('Answer\n', ' '.join([id2voc[i.item()] for i in torch.tensor(training_data[0][0][40:])]))
    print()
    
    print('w/i force teaching')
    for w_idx in training_data[0][0][40:]: 
        out, (state_h, state_c) = model(torch.tensor(w_idx).view(1,), (state_h, state_c))
        out = torch.argmax(out).item()
        force_t_text.append(out)
    print('Gen\n', ' '.join([id2voc[i] for i in force_t_text]))
    
    ############# w/o force teaching #############    
    inputs = torch.tensor(training_data[0][0])
    state_h, state_c = model.init_state()
    out, (state_h, state_c) = model(inputs, (state_h, state_c))
    out = torch.argmax(out[-1]).item()

    print('\nw/o force teaching')
    for _ in range(1000): 
        out, (state_h, state_c) = model(torch.tensor(out).view(1,), (state_h, state_c))
        out = torch.argmax(out).item()
        nforce_t_text.append(out)
    print('Gen\n', ' '.join([id2voc[i] for i in nforce_t_text]))

Query
 one of the other reviewers has mentioned that after watching just 1 oz episode you 'll be hooked . they are right , as this is exactly what happened with me.the first thing that struck me about oz was its brutality and unflinching scenes of violence , which set in right from the word go . trust me , this is not a show for the faint hearted or timid . this show pulls no punches with regards to drugs , sex or violence . its is hardcore , in the classic use of the word.it is called oz as that is the nickname given to the oswald maximum security state penitentary . it focuses mainly on emerald city , an experimental section of the prison where all the cells have glass fronts and face inwards , so privacy is not high on the agenda . em city is home to many .. aryans , muslims , gangstas , latinos , christians , italians , irish and more .... so scuffles , death stares , dodgy dealings and shady agreements are never far away.i would say the main appeal of the show is due to the fact t