In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as datautils

from tqdm import tqdm

from utils.utils import predict, normalize, produce_vocab, proc_set, init_weights, accuracy
from utils.model import LSTMTagger

import argparse
import os


# Seed torch's RNG so any stochastic torch ops in this notebook are repeatable.
torch.manual_seed(1234)
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory holding the saved settings and model weights.
checkpoint = 'checkpoint'
# Raw text file to pseudo-label. Set RAW_STORYTELLING_FILE to point at the
# corpus on another machine; the default is the original author's local path.
rawFile = os.environ.get(
    'RAW_STORYTELLING_FILE',
    r"C:\Users\sherw\OneDrive\Desktop\Thesis\scribbling-speech\nico-nocon\raw_storytelling_tl.txt",
)


In [14]:
# Load the saved vocabularies and hyperparameters. The handle is closed as soon
# as the tuple has been read rather than staying open while the model is built.
# NOTE(review): torch.load unpickles its input -- only load checkpoints that
# come from a trusted source.
with open(checkpoint + '/settings.bin', 'rb') as settings_file:
    (word_vocab, word2idx, idx2word,
     tags_vocab, tag2idx, idx2tag,
     msl, embedding_dim, hidden_dim, dropout,
     bidirectional, num_layers, recur_dropout) = torch.load(
        settings_file, map_location=torch.device('cpu'))

# Produce a blank model with the same architecture settings as the checkpoint.
model = LSTMTagger(word_vocab_sz=len(word_vocab),
                   tag_vocab_sz=len(tags_vocab),
                   embedding_dim=embedding_dim,
                   hidden_dim=hidden_dim,
                   dropout=dropout,
                   num_layers=num_layers,
                   recur_dropout=recur_dropout,
                   bidirectional=bidirectional)

# Load the trained weights onto the CPU and put the model in eval mode.
with open(checkpoint + '/model.bin', 'rb') as weights_file:
    model.load_state_dict(torch.load(weights_file, map_location=torch.device('cpu')))
model = model.cpu()
# Trailing ';' keeps the notebook from echoing the module repr as cell output.
model.eval();

In [52]:
# Read the whole raw corpus up front; readlines() keeps each line's trailing
# newline, so the lines can be written back out verbatim below.
with open(rawFile, 'r', encoding='utf-8') as f:
    textList = f.readlines()

# Open both outputs in a single `with` so they are flushed and closed even if
# predict() raises part-way through the corpus.
with open('data/tl_storytelling_wordset.txt', 'w', encoding='utf-8') as wordSet, \
        open('data/tl_storytelling_tagset.txt', 'w', encoding='utf-8') as tagSet:
    for text in textList:
        # presumably predict() returns one tag string per token -- confirm in utils.utils
        preds = predict(text, word2idx, idx2tag, word_vocab, msl, model)
        # Line i of the word file pairs with line i of the tag file.
        wordSet.write(text)
        tagSet.write(' '.join(preds) + "\n")

# Reported only after both files have been closed.
print("Done psuedolabeling...")

Done psuedolabeling...
