In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Download

In [2]:
!pip install pyvi

Collecting pyvi
  Downloading pyvi-0.1.1-py2.py3-none-any.whl (8.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.5/8.5 MB[0m [31m64.1 MB/s[0m eta [36m0:00:00[0m
Collecting sklearn-crfsuite (from pyvi)
  Downloading sklearn_crfsuite-0.3.6-py2.py3-none-any.whl (12 kB)
Collecting python-crfsuite>=0.8.3 (from sklearn-crfsuite->pyvi)
  Downloading python_crfsuite-0.9.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m78.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: python-crfsuite, sklearn-crfsuite, pyvi
Successfully installed python-crfsuite-0.9.10 pyvi-0.1.1 sklearn-crfsuite-0.3.6


## Import libraries

In [3]:
import re
import ast
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
from torch.optim import Adam
from torchtext.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader
from pyvi import ViTokenizer

# pre-processing

In [4]:
def normalize(text):
    t = text.replace('\n', ' ')
    t = t.lower()
    return t

def delete_hashtag(text):
    return re.sub(r'#\w+', '', text)

def delete_link(text):
    return re.sub(r'http\S+', '', text)

def remove_emojis(text):
    emoj = re.compile(r"""[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002702-\U000027B0\U000024C2-\U0001F251\U0001f926-\U0001f937\U00010000-\U0010ffff\u200d\u23cf\u23e9\u231a\ufe0f\u3030-]+(?<!\n)""", re.UNICODE)
    return re.sub(emoj, '', text)

def encode_number(text):
    t = text.split(' ')
    t = map(lambda x: '<number>' if bool(re.match(r'^[0-9]+(\.[0-9]+)?$', x)) else x, t)
    return ' '.join(t)

def delete_onelen_token(text):
    t = text.split(' ')
    t = filter(lambda x: len(x)>1, t)
    return ' '.join(t)

def preprocessing(text):
    t = normalize(text)
    t = delete_hashtag(t)
    t = delete_link(t)
    t = remove_emojis(t)
    t = ViTokenizer.tokenize(t)
    t = encode_number(t)
    t = delete_onelen_token(t)
    return t

#Model and Training

In [5]:
%cd /content/drive/MyDrive/CS114

/content/drive/.shortcut-targets-by-id/1Y8ECOyKvn31ywCsUN8dtq1YwMiMpEqep/CS114


## Define class dataset

In [6]:
class HashTag_Dataset(Dataset):
  def __init__(self, root, max_length=250):
    super(HashTag_Dataset, self).__init__()
    self.classes = ['#Q&A', '#cv', '#data', '#deep_learning', '#machine_learning', '#math', '#nlp', '#python', '#sharing']
    text, labels = [], []

    df = pd.read_csv(root, encoding='utf-8-sig')
    texts = df['text']
    labels = df["label"]
    self.texts = texts
    self.labels = labels
    self.vocab = self.make_vocab(texts)
    self.max_length = max_length
  def make_vocab(self, texts):
    vocab = dict()
    for text in texts:
      words = text.split()
      for word in words:
        if word not in vocab:
          vocab[word] = 1
        else:
          vocab[word] += 1
    vocab = list(dict(filter(lambda x: x[1]>3, vocab.items())).keys())
    vocab.append('<UNK>')
    vocab.append('<PAD>')
    return vocab
  def encode_text(self, text):
    words = text.split()
    if len(words) > self.max_length:
      words = words[:self.max_length]
    else:
      words += ['<PAD>']*(self.max_length-len(words))
    enc = [self.vocab.index(w) if w in self.vocab else self.vocab.index('<UNK>') for w in words]
    return enc
  def encode_label(self, label):
    enc = ast.literal_eval(label)
    enc = [1 if l in enc else 0 for l in self.classes]
    return enc
  def __len__(self):
    return len(self.labels)

  def len_vocab(self):
    return len(self.vocab)

  def num_classes(self):
    return len(self.classes)

  def __getitem__(self, idx):
    text = self.texts[idx]
    label = self.labels[idx]
    encode = self.encode_text(text)
    label = self.encode_label(label)
    encode = torch.tensor(encode, dtype=torch.long)
    label = torch.tensor(label, dtype=torch.float32)
    return encode, label

## Create the model

In [7]:
class HashtagRecommendation(nn.Module):
  def __init__(self, embedding_dim, hidden_dim, vocab_size, num_labels):
    super(HashtagRecommendation, self).__init__()
    self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
    self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
    self.fc1 = nn.Linear(hidden_dim, 128)
    self.fc2 = nn.Linear(128, 64)
    self.fc3 = nn.Linear(64, num_labels)
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()
    self.dropout = nn.Dropout(p=0.2)

  def forward(self, sentence):
    embeds = self.word_embeddings(sentence)
    gru_out, _ = self.gru(embeds)
    # Select the last output of GRU layer
    gru_out = gru_out[:, -1, :]
    x = self.relu(self.fc1(gru_out))
    x = self.dropout(x)
    x = self.relu(self.fc2(x))
    x = self.dropout(x)
    logits = self.fc3(x)
    probs = self.sigmoid(logits)
    return probs

# Training

In [8]:
def accuracy(y_true, y_pred):
    temp = 0
    for i in range(y_true.shape[0]):
        temp += sum(np.logical_and(y_true[i], y_pred[i])) / sum(np.logical_or(y_true[i], y_pred[i]))
    return temp / y_true.shape[0]

In [9]:
train_set = HashTag_Dataset(root = "train_data.csv")
test_set = HashTag_Dataset(root = "test_data.csv")
test_loader = DataLoader(test_set, batch_size=2)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2, drop_last=True)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
num_epochs = 250

In [10]:
model = HashtagRecommendation(embedding_dim=100,hidden_dim=256,num_labels=train_set.num_classes(), vocab_size=train_set.len_vocab())
model = model.to(device)
criterion = nn.BCELoss()
optimizer = Adam(model.parameters(), lr=0.0001)
num_iters = len(train_loader)

In [12]:
best_acc = 0

for epoch in range(num_epochs):
    model.train()
    progress_bar = tqdm(train_loader, colour='green')

    for iter, (texts, labels) in enumerate(progress_bar):
        texts = texts.to(device)
        labels = labels.to(dtype=torch.float).to(device)

        # forward
        outputs = model(texts)
        loss_value = criterion(outputs, labels)  # Assuming your model has a single output node
        progress_bar.set_description("Epoch {}/{}. Iteration {}/{}. Loss {:.5f}".format(epoch+1, num_epochs, iter+1, num_iters, loss_value.item()))

        # backward
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

    # evaluate
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for iter, (texts, labels) in enumerate(progress_bar):
            texts = texts.to(device)
            labels = labels.to(dtype=torch.float).to(device)
            outputs = np.array(model(texts).tolist())
            y_pred.extend((outputs >= 0.5).astype(np.float32))
            y_true.extend(labels.cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    acc = accuracy(y_true, y_pred)

    print(f"Accuracy: {acc}")

    if acc > best_acc:
        best_acc = acc
        torch.save(model, '/content/drive/MyDrive/CS114/best_model.pth')

Epoch 1/250. Iteration 262/262. Loss 0.30729: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.89it/s]


Accuracy: 0.3652989821882948


Epoch 2/250. Iteration 262/262. Loss 0.45523: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.44it/s]


Accuracy: 0.4973759541984724


Epoch 3/250. Iteration 262/262. Loss 0.45006: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.30it/s]


Accuracy: 0.5054866412213733


Epoch 4/250. Iteration 262/262. Loss 0.53755: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.56it/s]


Accuracy: 0.4770992366412206


Epoch 5/250. Iteration 262/262. Loss 0.32531: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.61it/s]


Accuracy: 0.5467557251908391


Epoch 6/250. Iteration 262/262. Loss 0.30727: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.70it/s]


Accuracy: 0.5648854961832058


Epoch 7/250. Iteration 262/262. Loss 0.30100: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.73it/s]


Accuracy: 0.566157760814249


Epoch 8/250. Iteration 262/262. Loss 0.25564: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.76it/s]


Accuracy: 0.5709287531806615


Epoch 9/250. Iteration 262/262. Loss 0.32242: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.59it/s]


Accuracy: 0.5760178117048343


Epoch 10/250. Iteration 262/262. Loss 0.43426: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.98it/s]


Accuracy: 0.5815839694656486


Epoch 11/250. Iteration 262/262. Loss 0.34762: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.82it/s]


Accuracy: 0.5855597964376589


Epoch 12/250. Iteration 262/262. Loss 0.24668: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.36it/s]


Accuracy: 0.5830152671755723


Epoch 13/250. Iteration 262/262. Loss 0.20848: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.47it/s]


Accuracy: 0.5760178117048341


Epoch 14/250. Iteration 262/262. Loss 0.41686: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.72it/s]


Accuracy: 0.5879452926208656


Epoch 15/250. Iteration 262/262. Loss 0.31108: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.70it/s]


Accuracy: 0.5854007633587786


Epoch 16/250. Iteration 262/262. Loss 0.26594: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.68it/s]


Accuracy: 0.5831743002544528


Epoch 17/250. Iteration 262/262. Loss 0.51067: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.60it/s]


Accuracy: 0.586673027989822


Epoch 18/250. Iteration 262/262. Loss 0.33523: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.38it/s]


Accuracy: 0.5888994910941472


Epoch 19/250. Iteration 262/262. Loss 0.30995: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.60it/s]


Accuracy: 0.5925572519083968


Epoch 20/250. Iteration 262/262. Loss 0.29108: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.21it/s]


Accuracy: 0.5944656488549619


Epoch 21/250. Iteration 262/262. Loss 0.36773: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.77it/s]


Accuracy: 0.5930343511450383


Epoch 22/250. Iteration 262/262. Loss 0.29299: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.86it/s]


Accuracy: 0.5954198473282446


Epoch 23/250. Iteration 262/262. Loss 0.41159: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.81it/s]


Accuracy: 0.575858778625954


Epoch 24/250. Iteration 262/262. Loss 0.25556: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.73it/s]


Accuracy: 0.5930343511450377


Epoch 25/250. Iteration 262/262. Loss 0.20133: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.68it/s]


Accuracy: 0.5949427480916031


Epoch 26/250. Iteration 262/262. Loss 0.29541: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.33it/s]


Accuracy: 0.5963740458015266


Epoch 27/250. Iteration 262/262. Loss 0.29249: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.46it/s]


Accuracy: 0.5951017811704836


Epoch 28/250. Iteration 262/262. Loss 0.42033: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.49it/s]


Accuracy: 0.5954198473282444


Epoch 29/250. Iteration 262/262. Loss 0.27097: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.49it/s]


Accuracy: 0.5978053435114504


Epoch 30/250. Iteration 262/262. Loss 0.37332: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.74it/s]


Accuracy: 0.5982824427480916


Epoch 31/250. Iteration 262/262. Loss 0.26466: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.73it/s]


Accuracy: 0.5973282442748092


Epoch 32/250. Iteration 262/262. Loss 0.29544: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.93it/s]


Accuracy: 0.5972487277353693


Epoch 33/250. Iteration 262/262. Loss 0.24729: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.77it/s]


Accuracy: 0.6016221374045805


Epoch 34/250. Iteration 262/262. Loss 0.44004: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.57it/s]


Accuracy: 0.6025763358778629


Epoch 35/250. Iteration 262/262. Loss 0.21743: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.70it/s]


Accuracy: 0.6020992366412213


Epoch 36/250. Iteration 262/262. Loss 0.14947: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.75it/s]


Accuracy: 0.6020992366412217


Epoch 37/250. Iteration 262/262. Loss 0.29458: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.54it/s]


Accuracy: 0.6016221374045799


Epoch 38/250. Iteration 262/262. Loss 0.37161: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.69it/s]


Accuracy: 0.60543893129771


Epoch 39/250. Iteration 262/262. Loss 0.18652: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.59it/s]


Accuracy: 0.6059160305343518


Epoch 40/250. Iteration 262/262. Loss 0.25152: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.42it/s]


Accuracy: 0.6025763358778626


Epoch 41/250. Iteration 262/262. Loss 0.41349: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.77it/s]


Accuracy: 0.6063931297709926


Epoch 42/250. Iteration 262/262. Loss 0.31307: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.11it/s]


Accuracy: 0.6059160305343514


Epoch 43/250. Iteration 262/262. Loss 0.29948: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.72it/s]


Accuracy: 0.6059160305343518


Epoch 44/250. Iteration 262/262. Loss 0.22681: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.51it/s]


Accuracy: 0.6035305343511452


Epoch 45/250. Iteration 262/262. Loss 0.68876: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.67it/s]


Accuracy: 0.60559796437659


Epoch 46/250. Iteration 262/262. Loss 0.20284: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.61it/s]


Accuracy: 0.6054389312977099


Epoch 47/250. Iteration 262/262. Loss 0.48941: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.57it/s]


Accuracy: 0.60543893129771


Epoch 48/250. Iteration 262/262. Loss 0.32796: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.55it/s]


Accuracy: 0.6054389312977101


Epoch 49/250. Iteration 262/262. Loss 0.25753: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.65it/s]


Accuracy: 0.6011450381679393


Epoch 50/250. Iteration 262/262. Loss 0.33112: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.63it/s]


Accuracy: 0.6059160305343516


Epoch 51/250. Iteration 262/262. Loss 0.20917: 100%|[32m██████████[0m| 262/262 [00:19<00:00, 13.17it/s]


Accuracy: 0.6059160305343508


Epoch 52/250. Iteration 262/262. Loss 0.42895: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.69it/s]


Accuracy: 0.6059160305343516


Epoch 53/250. Iteration 262/262. Loss 0.24655: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.70it/s]


Accuracy: 0.6059160305343515


Epoch 54/250. Iteration 262/262. Loss 0.28759: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.54it/s]


Accuracy: 0.6059160305343514


Epoch 55/250. Iteration 262/262. Loss 0.44729: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.71it/s]


Accuracy: 0.6059160305343507


Epoch 56/250. Iteration 262/262. Loss 0.25976: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.61it/s]


Accuracy: 0.6059160305343515


Epoch 57/250. Iteration 262/262. Loss 0.20709: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.82it/s]


Accuracy: 0.6059160305343512


Epoch 58/250. Iteration 262/262. Loss 0.31150: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.68it/s]


Accuracy: 0.606870229007634


Epoch 59/250. Iteration 262/262. Loss 0.32740: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.59it/s]


Accuracy: 0.6065521628498733


Epoch 60/250. Iteration 262/262. Loss 0.23226: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.53it/s]


Accuracy: 0.6068702290076334


Epoch 61/250. Iteration 262/262. Loss 0.37990: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.63it/s]


Accuracy: 0.6068702290076339


Epoch 62/250. Iteration 262/262. Loss 0.15699: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.53it/s]


Accuracy: 0.6068702290076338


Epoch 63/250. Iteration 262/262. Loss 0.35557: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.69it/s]


Accuracy: 0.603291984732825


Epoch 64/250. Iteration 262/262. Loss 0.24477: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.35it/s]


Accuracy: 0.6059160305343516


Epoch 65/250. Iteration 262/262. Loss 0.64306: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.62it/s]


Accuracy: 0.606870229007634


Epoch 66/250. Iteration 262/262. Loss 0.21283: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.67it/s]


Accuracy: 0.6068702290076337


Epoch 67/250. Iteration 262/262. Loss 0.14580: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.56it/s]


Accuracy: 0.6063931297709922


Epoch 68/250. Iteration 262/262. Loss 0.44874: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.68it/s]


Accuracy: 0.6083015267175569


Epoch 69/250. Iteration 262/262. Loss 0.23791: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.67it/s]


Accuracy: 0.6075063613231556


Epoch 70/250. Iteration 262/262. Loss 0.29594: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.71it/s]


Accuracy: 0.6070292620865143


Epoch 71/250. Iteration 262/262. Loss 0.26217: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.60it/s]


Accuracy: 0.6055184478371505


Epoch 72/250. Iteration 262/262. Loss 0.35729: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.64it/s]


Accuracy: 0.6083015267175573


Epoch 73/250. Iteration 262/262. Loss 0.31404: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.49it/s]


Accuracy: 0.608301526717557


Epoch 74/250. Iteration 262/262. Loss 0.37804: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.38it/s]


Accuracy: 0.6083015267175571


Epoch 75/250. Iteration 262/262. Loss 0.16844: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.68it/s]


Accuracy: 0.6073473282442746


Epoch 76/250. Iteration 262/262. Loss 0.37533: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.55it/s]


Accuracy: 0.607824427480916


Epoch 77/250. Iteration 262/262. Loss 0.35913: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.64it/s]


Accuracy: 0.607824427480916


Epoch 78/250. Iteration 262/262. Loss 0.32438: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.65it/s]


Accuracy: 0.6048027989821887


Epoch 79/250. Iteration 262/262. Loss 0.14107: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.59it/s]


Accuracy: 0.6073473282442752


Epoch 80/250. Iteration 262/262. Loss 0.22017: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.68it/s]


Accuracy: 0.6083015267175576


Epoch 81/250. Iteration 262/262. Loss 0.22122: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.61it/s]


Accuracy: 0.6083015267175574


Epoch 82/250. Iteration 262/262. Loss 0.39451: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.83it/s]


Accuracy: 0.6078244274809165


Epoch 83/250. Iteration 262/262. Loss 0.16084: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.63it/s]


Accuracy: 0.6092557251908397


Epoch 84/250. Iteration 262/262. Loss 0.30328: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.55it/s]


Accuracy: 0.6087786259541988


Epoch 85/250. Iteration 262/262. Loss 0.27503: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.58it/s]


Accuracy: 0.6092557251908397


Epoch 86/250. Iteration 262/262. Loss 0.43419: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.57it/s]


Accuracy: 0.6084605597964376


Epoch 87/250. Iteration 262/262. Loss 0.27293: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.28it/s]


Accuracy: 0.608778625954199


Epoch 88/250. Iteration 262/262. Loss 0.04775: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.59it/s]


Accuracy: 0.6097328244274808


Epoch 89/250. Iteration 262/262. Loss 0.39286: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.49it/s]


Accuracy: 0.6095737913486008


Epoch 90/250. Iteration 262/262. Loss 0.27201: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.66it/s]


Accuracy: 0.6067111959287528


Epoch 91/250. Iteration 262/262. Loss 0.30313: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.68it/s]


Accuracy: 0.6087786259541985


Epoch 92/250. Iteration 262/262. Loss 0.13162: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.54it/s]


Accuracy: 0.6087786259541984


Epoch 93/250. Iteration 262/262. Loss 0.29901: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.63it/s]


Accuracy: 0.6106870229007635


Epoch 94/250. Iteration 262/262. Loss 0.36005: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.70it/s]


Accuracy: 0.6106870229007635


Epoch 95/250. Iteration 262/262. Loss 0.35466: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.59it/s]


Accuracy: 0.6106870229007636


Epoch 96/250. Iteration 262/262. Loss 0.22736: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.49it/s]


Accuracy: 0.6116412213740459


Epoch 97/250. Iteration 262/262. Loss 0.29093: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.59it/s]


Accuracy: 0.6116412213740458


Epoch 98/250. Iteration 262/262. Loss 0.31597: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.22it/s]


Accuracy: 0.6121183206106874


Epoch 99/250. Iteration 262/262. Loss 0.25392: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.58it/s]


Accuracy: 0.611641221374046


Epoch 100/250. Iteration 262/262. Loss 0.36900: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.48it/s]


Accuracy: 0.6120388040712471


Epoch 101/250. Iteration 262/262. Loss 0.27211: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.68it/s]


Accuracy: 0.6108460559796436


Epoch 102/250. Iteration 262/262. Loss 0.24229: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.64it/s]


Accuracy: 0.6118002544529264


Epoch 103/250. Iteration 262/262. Loss 0.16949: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.67it/s]


Accuracy: 0.6125954198473286


Epoch 104/250. Iteration 262/262. Loss 0.31372: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.43it/s]


Accuracy: 0.6125954198473285


Epoch 105/250. Iteration 262/262. Loss 0.23141: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.69it/s]


Accuracy: 0.6122773536895674


Epoch 106/250. Iteration 262/262. Loss 0.22553: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.30it/s]


Accuracy: 0.612118320610687


Epoch 107/250. Iteration 262/262. Loss 0.23167: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.68it/s]


Accuracy: 0.5892970737913489


Epoch 108/250. Iteration 262/262. Loss 0.22928: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.44it/s]


Accuracy: 0.6125954198473285


Epoch 109/250. Iteration 262/262. Loss 0.30458: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.57it/s]


Accuracy: 0.6125954198473283


Epoch 110/250. Iteration 262/262. Loss 0.34392: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.34it/s]


Accuracy: 0.6125954198473281


Epoch 111/250. Iteration 262/262. Loss 0.18780: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.63it/s]


Accuracy: 0.6125954198473282


Epoch 112/250. Iteration 262/262. Loss 0.36585: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.52it/s]


Accuracy: 0.6125954198473285


Epoch 113/250. Iteration 262/262. Loss 0.28982: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.50it/s]


Accuracy: 0.6122773536895677


Epoch 114/250. Iteration 262/262. Loss 0.13593: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.43it/s]


Accuracy: 0.6121183206106876


Epoch 115/250. Iteration 262/262. Loss 0.16381: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.36it/s]


Accuracy: 0.6125954198473285


Epoch 116/250. Iteration 262/262. Loss 0.35940: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.18it/s]


Accuracy: 0.612118320610687


Epoch 117/250. Iteration 262/262. Loss 0.34369: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.54it/s]


Accuracy: 0.6121183206106874


Epoch 118/250. Iteration 262/262. Loss 0.27024: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.61it/s]


Accuracy: 0.6119592875318068


Epoch 119/250. Iteration 262/262. Loss 0.12793: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.47it/s]


Accuracy: 0.611641221374046


Epoch 120/250. Iteration 262/262. Loss 0.26364: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.55it/s]


Accuracy: 0.6124363867684484


Epoch 121/250. Iteration 262/262. Loss 0.22379: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.53it/s]


Accuracy: 0.6149013994910942


Epoch 122/250. Iteration 262/262. Loss 0.20212: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.59it/s]


Accuracy: 0.6086195928753182


Epoch 123/250. Iteration 262/262. Loss 0.19933: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.65it/s]


Accuracy: 0.6204675572519085


Epoch 124/250. Iteration 262/262. Loss 0.06802: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.36it/s]


Accuracy: 0.6241253180661578


Epoch 125/250. Iteration 262/262. Loss 0.19819: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.58it/s]


Accuracy: 0.6293734096692116


Epoch 126/250. Iteration 262/262. Loss 0.23552: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.31it/s]


Accuracy: 0.6324268447837149


Epoch 127/250. Iteration 262/262. Loss 0.25140: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.15it/s]


Accuracy: 0.6356870229007636


Epoch 128/250. Iteration 262/262. Loss 0.22777: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.43it/s]


Accuracy: 0.6394243002544528


Epoch 129/250. Iteration 262/262. Loss 0.22560: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.15it/s]


Accuracy: 0.630407124681934


Epoch 130/250. Iteration 262/262. Loss 0.24655: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.49it/s]


Accuracy: 0.6344147582697198


Epoch 131/250. Iteration 262/262. Loss 0.27301: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.45it/s]


Accuracy: 0.629993638676845


Epoch 132/250. Iteration 262/262. Loss 0.34697: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.56it/s]


Accuracy: 0.6273377862595423


Epoch 133/250. Iteration 262/262. Loss 0.10965: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.46it/s]


Accuracy: 0.6391857506361328


Epoch 134/250. Iteration 262/262. Loss 0.22781: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.57it/s]


Accuracy: 0.6471692111959294


Epoch 135/250. Iteration 262/262. Loss 0.18476: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.49it/s]


Accuracy: 0.6393447837150131


Epoch 136/250. Iteration 262/262. Loss 0.26444: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.57it/s]


Accuracy: 0.6513040712468197


Epoch 137/250. Iteration 262/262. Loss 0.16248: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.44it/s]


Accuracy: 0.6581424936386774


Epoch 138/250. Iteration 262/262. Loss 0.08862: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.44it/s]


Accuracy: 0.6435114503816792


Epoch 139/250. Iteration 262/262. Loss 0.12113: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.50it/s]


Accuracy: 0.6662054707379139


Epoch 140/250. Iteration 262/262. Loss 0.22161: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.46it/s]


Accuracy: 0.6588581424936394


Epoch 141/250. Iteration 262/262. Loss 0.10708: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.51it/s]


Accuracy: 0.662388676844784


Epoch 142/250. Iteration 262/262. Loss 0.14895: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.33it/s]


Accuracy: 0.6620388040712472


Epoch 143/250. Iteration 262/262. Loss 0.27312: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.11it/s]


Accuracy: 0.6788167938931301


Epoch 144/250. Iteration 262/262. Loss 0.16188: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.28it/s]


Accuracy: 0.6894720101781174


Epoch 145/250. Iteration 262/262. Loss 0.18310: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.39it/s]


Accuracy: 0.6717398218829523


Epoch 146/250. Iteration 262/262. Loss 0.07735: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.55it/s]


Accuracy: 0.6786100508905853


Epoch 147/250. Iteration 262/262. Loss 0.12994: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.51it/s]


Accuracy: 0.6899809160305341


Epoch 148/250. Iteration 262/262. Loss 0.46035: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.45it/s]


Accuracy: 0.7134701017811705


Epoch 149/250. Iteration 262/262. Loss 0.32716: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.62it/s]


Accuracy: 0.7077290076335877


Epoch 150/250. Iteration 262/262. Loss 0.14688: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.38it/s]


Accuracy: 0.7277989821882951


Epoch 151/250. Iteration 262/262. Loss 0.28204: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.50it/s]


Accuracy: 0.7313295165394406


Epoch 152/250. Iteration 262/262. Loss 0.20470: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.30it/s]


Accuracy: 0.729564249363868


Epoch 153/250. Iteration 262/262. Loss 0.17211: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.38it/s]


Accuracy: 0.7298505089058523


Epoch 154/250. Iteration 262/262. Loss 0.20945: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.46it/s]


Accuracy: 0.7453562340966926


Epoch 155/250. Iteration 262/262. Loss 0.18082: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.50it/s]


Accuracy: 0.742000636132316


Epoch 156/250. Iteration 262/262. Loss 0.06771: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.33it/s]


Accuracy: 0.725747455470738


Epoch 157/250. Iteration 262/262. Loss 0.20663: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.31it/s]


Accuracy: 0.7393447837150122


Epoch 158/250. Iteration 262/262. Loss 0.22958: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.28it/s]


Accuracy: 0.7575222646310432


Epoch 159/250. Iteration 262/262. Loss 0.19320: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.34it/s]


Accuracy: 0.7733301526717558


Epoch 160/250. Iteration 262/262. Loss 0.29196: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.25it/s]


Accuracy: 0.7775922391857507


Epoch 161/250. Iteration 262/262. Loss 0.23225: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.50it/s]


Accuracy: 0.776017811704835


Epoch 162/250. Iteration 262/262. Loss 0.12398: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.31it/s]


Accuracy: 0.7908237913486004


Epoch 163/250. Iteration 262/262. Loss 0.22658: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.10it/s]


Accuracy: 0.7796755725190839


Epoch 164/250. Iteration 262/262. Loss 0.28338: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.43it/s]


Accuracy: 0.7725031806615774


Epoch 165/250. Iteration 262/262. Loss 0.13142: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.44it/s]


Accuracy: 0.8078880407124678


Epoch 166/250. Iteration 262/262. Loss 0.07561: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.45it/s]


Accuracy: 0.7896946564885494


Epoch 167/250. Iteration 262/262. Loss 0.09492: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.33it/s]


Accuracy: 0.7983460559796438


Epoch 168/250. Iteration 262/262. Loss 0.15232: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.55it/s]


Accuracy: 0.8189567430025443


Epoch 169/250. Iteration 262/262. Loss 0.08516: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.46it/s]


Accuracy: 0.8239503816793885


Epoch 170/250. Iteration 262/262. Loss 0.06007: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.41it/s]


Accuracy: 0.8239185750636125


Epoch 171/250. Iteration 262/262. Loss 0.05999: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.36it/s]


Accuracy: 0.8127385496183201


Epoch 172/250. Iteration 262/262. Loss 0.11717: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.27it/s]


Accuracy: 0.8089853689567426


Epoch 173/250. Iteration 262/262. Loss 0.11678: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.16it/s]


Accuracy: 0.7995069974554703


Epoch 174/250. Iteration 262/262. Loss 0.07218: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.47it/s]


Accuracy: 0.8483619592875317


Epoch 175/250. Iteration 262/262. Loss 0.06116: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.41it/s]


Accuracy: 0.821358142493638


Epoch 176/250. Iteration 262/262. Loss 0.07203: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.59it/s]


Accuracy: 0.8407124681933836


Epoch 177/250. Iteration 262/262. Loss 0.20891: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.32it/s]


Accuracy: 0.8182729007633583


Epoch 178/250. Iteration 262/262. Loss 0.15032: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.38it/s]


Accuracy: 0.8134223918575056


Epoch 179/250. Iteration 262/262. Loss 0.15428: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.15it/s]


Accuracy: 0.8395038167938931


Epoch 180/250. Iteration 262/262. Loss 0.11330: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.26it/s]


Accuracy: 0.8469783715012716


Epoch 181/250. Iteration 262/262. Loss 0.11820: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.14it/s]


Accuracy: 0.8651240458015259


Epoch 182/250. Iteration 262/262. Loss 0.18450: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.82it/s]


Accuracy: 0.8589376590330781


Epoch 183/250. Iteration 262/262. Loss 0.16570: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.96it/s]


Accuracy: 0.863199745547073


Epoch 184/250. Iteration 262/262. Loss 0.26590: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.34it/s]


Accuracy: 0.8620069974554704


Epoch 185/250. Iteration 262/262. Loss 0.06396: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.15it/s]


Accuracy: 0.8931456743002534


Epoch 186/250. Iteration 262/262. Loss 0.10096: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.15it/s]


Accuracy: 0.8712309160305335


Epoch 187/250. Iteration 262/262. Loss 0.24120: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.01it/s]


Accuracy: 0.8316316793893124


Epoch 188/250. Iteration 262/262. Loss 0.12007: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.88it/s]


Accuracy: 0.8996660305343502


Epoch 189/250. Iteration 262/262. Loss 0.02710: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.08it/s]


Accuracy: 0.8738390585241725


Epoch 190/250. Iteration 262/262. Loss 0.02052: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.83it/s]


Accuracy: 0.9023695928753175


Epoch 191/250. Iteration 262/262. Loss 0.09462: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.19it/s]


Accuracy: 0.9033555979643757


Epoch 192/250. Iteration 262/262. Loss 0.24468: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.32it/s]


Accuracy: 0.8903307888040705


Epoch 193/250. Iteration 262/262. Loss 0.11165: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.33it/s]


Accuracy: 0.8719147582697196


Epoch 194/250. Iteration 262/262. Loss 0.28575: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.14it/s]


Accuracy: 0.8973123409669206


Epoch 195/250. Iteration 262/262. Loss 0.06168: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.35it/s]


Accuracy: 0.9186704834605591


Epoch 196/250. Iteration 262/262. Loss 0.10383: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.30it/s]


Accuracy: 0.8948632315521617


Epoch 197/250. Iteration 262/262. Loss 0.23033: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.34it/s]


Accuracy: 0.9149650127226459


Epoch 198/250. Iteration 262/262. Loss 0.06975: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.25it/s]


Accuracy: 0.9143447837150115


Epoch 199/250. Iteration 262/262. Loss 0.01802: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.34it/s]


Accuracy: 0.9176049618320602


Epoch 200/250. Iteration 262/262. Loss 0.19747: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.16it/s]


Accuracy: 0.8694497455470735


Epoch 201/250. Iteration 262/262. Loss 0.06329: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.17it/s]


Accuracy: 0.9238708651399482


Epoch 202/250. Iteration 262/262. Loss 0.18409: 100%|[32m██████████[0m| 262/262 [00:20<00:00, 12.61it/s]


Accuracy: 0.9189408396946552


Epoch 203/250. Iteration 262/262. Loss 0.08361: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.22it/s]


Accuracy: 0.906456743002544


Epoch 204/250. Iteration 262/262. Loss 0.13617: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.20it/s]


Accuracy: 0.8672073791348596


Epoch 205/250. Iteration 262/262. Loss 0.19413: 100%|[32m██████████[0m| 262/262 [00:18<00:00, 14.26it/s]


Accuracy: 0.9059955470737904


Epoch 206/250. Iteration 262/262. Loss 0.06755: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.26it/s]


Accuracy: 0.9156806615776073


Epoch 207/250. Iteration 262/262. Loss 0.04666: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.81it/s]


Accuracy: 0.9244274809160298


Epoch 208/250. Iteration 262/262. Loss 0.07779: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.39it/s]


Accuracy: 0.9268129770992354


Epoch 209/250. Iteration 262/262. Loss 0.17073: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.12it/s]


Accuracy: 0.9369910941475814


Epoch 210/250. Iteration 262/262. Loss 0.03241: 100%|[32m██████████[0m| 262/262 [00:16<00:00, 15.42it/s]


Accuracy: 0.9238708651399484


Epoch 211/250. Iteration 262/262. Loss 0.01549: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.04it/s]


Accuracy: 0.9328244274809149


Epoch 212/250. Iteration 262/262. Loss 0.08905: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.29it/s]


Accuracy: 0.9380407124681923


Epoch 213/250. Iteration 262/262. Loss 0.04913: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.14it/s]


Accuracy: 0.941046437659032


Epoch 214/250. Iteration 262/262. Loss 0.03648: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.72it/s]


Accuracy: 0.9385337150127218


Epoch 215/250. Iteration 262/262. Loss 0.21264: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.86it/s]


Accuracy: 0.8746501272264632


Epoch 216/250. Iteration 262/262. Loss 0.03831: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.05it/s]


Accuracy: 0.9344942748091594


Epoch 217/250. Iteration 262/262. Loss 0.10261: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.16it/s]


Accuracy: 0.9383746819338412


Epoch 218/250. Iteration 262/262. Loss 0.11923: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.18it/s]


Accuracy: 0.9431456743002538


Epoch 219/250. Iteration 262/262. Loss 0.02234: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.11it/s]


Accuracy: 0.9492366412213725


Epoch 220/250. Iteration 262/262. Loss 0.05695: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.06it/s]


Accuracy: 0.8999522900763353


Epoch 221/250. Iteration 262/262. Loss 0.02304: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.12it/s]


Accuracy: 0.9468511450381669


Epoch 222/250. Iteration 262/262. Loss 0.05793: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.27it/s]


Accuracy: 0.945419847328243


Epoch 223/250. Iteration 262/262. Loss 0.19319: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.23it/s]


Accuracy: 0.954627862595419


Epoch 224/250. Iteration 262/262. Loss 0.02028: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.25it/s]


Accuracy: 0.9111641221374037


Epoch 225/250. Iteration 262/262. Loss 0.03182: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.03it/s]


Accuracy: 0.94731234096692


Epoch 226/250. Iteration 262/262. Loss 0.02497: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.93it/s]


Accuracy: 0.9181933842239178


Epoch 227/250. Iteration 262/262. Loss 0.02680: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.35it/s]


Accuracy: 0.9493479643765893


Epoch 228/250. Iteration 262/262. Loss 0.05012: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.20it/s]


Accuracy: 0.9494751908396938


Epoch 229/250. Iteration 262/262. Loss 0.11271: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.27it/s]


Accuracy: 0.9562659033078871


Epoch 230/250. Iteration 262/262. Loss 0.01571: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.90it/s]


Accuracy: 0.9524491094147575


Epoch 231/250. Iteration 262/262. Loss 0.03687: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.80it/s]


Accuracy: 0.9526717557251899


Epoch 232/250. Iteration 262/262. Loss 0.03324: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.77it/s]


Accuracy: 0.9563136132315512


Epoch 233/250. Iteration 262/262. Loss 0.03476: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.90it/s]


Accuracy: 0.9475826972010163


Epoch 234/250. Iteration 262/262. Loss 0.01711: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.18it/s]


Accuracy: 0.9469624681933835


Epoch 235/250. Iteration 262/262. Loss 0.03491: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.12it/s]


Accuracy: 0.9489503816793885


Epoch 236/250. Iteration 262/262. Loss 0.00792: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.25it/s]


Accuracy: 0.9460559796437648


Epoch 237/250. Iteration 262/262. Loss 0.01583: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.07it/s]


Accuracy: 0.9562499999999985


Epoch 238/250. Iteration 262/262. Loss 0.00092: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.27it/s]


Accuracy: 0.9558365139949099


Epoch 239/250. Iteration 262/262. Loss 0.01373: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.13it/s]


Accuracy: 0.9623091603053427


Epoch 240/250. Iteration 262/262. Loss 0.05079: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.85it/s]


Accuracy: 0.9561227735368949


Epoch 241/250. Iteration 262/262. Loss 0.11872: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.86it/s]


Accuracy: 0.9489662849872766


Epoch 242/250. Iteration 262/262. Loss 0.04530: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.07it/s]


Accuracy: 0.9556933842239178


Epoch 243/250. Iteration 262/262. Loss 0.08032: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.98it/s]


Accuracy: 0.9611959287531803


Epoch 244/250. Iteration 262/262. Loss 0.05071: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.86it/s]


Accuracy: 0.9545642493638671


Epoch 245/250. Iteration 262/262. Loss 0.02897: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.78it/s]


Accuracy: 0.9505407124681926


Epoch 246/250. Iteration 262/262. Loss 0.02431: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.07it/s]


Accuracy: 0.9641380407124674


Epoch 247/250. Iteration 262/262. Loss 0.16571: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 14.85it/s]


Accuracy: 0.9578403307888033


Epoch 248/250. Iteration 262/262. Loss 0.00944: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.09it/s]


Accuracy: 0.9562977099236639


Epoch 249/250. Iteration 262/262. Loss 0.12386: 100%|[32m██████████[0m| 262/262 [00:17<00:00, 15.02it/s]


Accuracy: 0.9520038167938923


Epoch 250/250. Iteration 262/262. Loss 0.01448: 100%|[32m██████████[0m| 262/262 [00:18<00:00, 14.50it/s]


Accuracy: 0.9598759541984726


## Evaluate on test set

In [16]:
model = torch.load('/content/drive/MyDrive/CS114/best_model.pth')
model.eval()
y_true = []
y_pred = []
progress_bar = tqdm(test_loader, colour='yellow')
with torch.no_grad():
    for iter, (texts, labels) in enumerate(progress_bar):
        texts = texts.to(device)
        labels = labels.to(dtype=torch.float).to(device)
        outputs = np.array(model(texts).tolist())
        y_pred.extend((outputs >= 0.5).astype(np.float32))
        y_true.extend(labels.cpu().numpy())

y_true = np.array(y_true)
y_pred = np.array(y_pred)
acc = accuracy(y_true, y_pred)

print(f"Accuracy: {acc}")

100%|[33m██████████[0m| 117/117 [00:01<00:00, 79.47it/s]

Accuracy: 0.22542735042735054





#Testing

In [17]:
#load model
model = torch.load('/content/drive/MyDrive/CS114/best_model.pth')
print(model)

HashtagRecommendation(
  (word_embeddings): Embedding(2180, 100)
  (lstm): GRU(100, 256, batch_first=True)
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=9, bias=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
  (dropout): Dropout(p=0.2, inplace=False)
)


In [22]:
sentence = input()

giáo_trình và hướng_dẫn_giải bài_tập xác_suất thống_kê của bách_khoa hà_nội mình xin chia_sẻ với các bạn <number> quyển sách mình từng học học_kì thứ <number> tại đại_học bách_khoa hà_nội nhờ <number> quyển sách này mà sau_này khi học về data science và machine learning mình gần như_không phải ôn lại kiến_thức toán link to pdf


In [25]:
# test = preprocessing(sentence)
test = train_set.encode_text(sentence)
test = ToTensor()(test)[None, :]
test = test.to(device)
pred = model(test)[0].tolist()
pred = [train_set.classes[i] for i in range(len(pred)) if pred[i]>0.5]
print("Predict label: ", *pred)

Predict label:  #machine_learning #sharing
