In [1]:
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
sns.set()

In [2]:
import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import random_split, Dataset, DataLoader

In [3]:
# Load the preprocessed corpus. Pickle support is requested through np.load's
# documented keyword instead of mutating the NpzFile attribute after the fact.
# NOTE(security): allow_pickle=True can execute arbitrary code while reading the
# archive; only load files that come from a trusted source.
datasets = np.load("final set.npz", allow_pickle=True)
my_mapping = datasets["use_word2idx"]
dataset = datasets["dataset"]

# `dataset` wraps a single Python object; unwrap it once and reuse the result
# instead of calling .item() for every key.
poems_by_key = dataset.item()
yongwu = torch.from_numpy(poems_by_key["dufu"])
shuqing = torch.from_numpy(poems_by_key["sushi"])

# Class 0 <- rows stored under "dufu", class 1 <- rows stored under "sushi".
# The labels are created as int64 up front, so no later cast is needed
# (the default dtype of torch.full for an integer fill value differs across
# PyTorch versions).
label_yongwu = torch.full((len(yongwu), ), fill_value=0, dtype=torch.long)
label_shuqing = torch.full((len(shuqing), ), fill_value=1, dtype=torch.long)

# Stack both classes; row i of final_dataset is labelled by final_label[i].
final_label = torch.cat((label_yongwu, label_shuqing), dim=0)
final_dataset = torch.cat((yongwu, shuqing), dim=0)

In [4]:
# Sanity check: total number of labels after concatenating both classes.
final_label.shape

torch.Size([2300])

In [5]:
# Sanity check: the first dimension should equal the number of labels.
final_dataset.shape

torch.Size([2300, 1210])

In [6]:
# Hold out 10% of the rows for evaluation; the fixed random_state keeps the
# split identical across runs.
dataset_train, dataset_test, label_train, label_test = train_test_split(
    final_dataset,
    final_label,
    test_size=0.1,
    random_state=1,
)

In [7]:
# Shape of the held-out split produced with test_size=0.1.
dataset_test.shape

torch.Size([230, 1210])

In [8]:
# Shape of the training split.
dataset_train.shape

torch.Size([2070, 1210])

### Dataset

In [9]:
class MyDataset(Dataset):
    """Pairs an indexable collection of samples with a parallel collection of labels."""

    def __init__(self, dataset, label):
        """Keep references to the samples and labels; nothing is copied or validated."""
        self.datasets, self.labels = dataset, label

    def __len__(self):
        """Report the number of labels, which defines the size of the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.datasets[idx]
        target = self.labels[idx]
        return sample, target


### Model

In [10]:
class PoemClassifier(nn.Module):
    """Embedding -> multi-layer LSTM -> linear classifier over token-id sequences.

    ``forward`` returns raw (un-normalised) class scores. ``self.softmax`` is
    kept as an attribute for backward compatibility but is not applied inside
    ``forward``.
    """

    def __init__(self, words_num, embedding_size, hidden_size, classes, num_layers,
                    batch_size, sequence_length):
        """Build the layers.

        Args:
            words_num: Number of rows in the embedding table (vocabulary size).
            embedding_size: Dimension of each embedding vector.
            hidden_size: Hidden size of the LSTM.
            classes: Number of output classes.
            num_layers: Number of stacked LSTM layers.
            batch_size: Stored on the instance only; ``forward`` reads the
                batch size from its input instead.
            sequence_length: Stored on the instance only; ``forward`` reads the
                sequence length from its input instead.
        """
        super().__init__()
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.words_num = words_num
        self.sequence_length = sequence_length
        self.emb = nn.Embedding(words_num, embedding_size)
        self.LSTM = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, hidden=None):
        """Classify a batch of token-id sequences.

        Args:
            x: 2-D integer tensor of token ids, ``(batch_size, sequence_length)``.
            hidden: Optional ``(h, c)`` initial LSTM state, each of shape
                ``(num_layers, batch_size, hidden_size)``. Zeros are used when
                omitted.

        Returns:
            ``(scores, hidden)`` where ``scores`` is ``(batch_size, classes)``
            and ``hidden`` is the final ``(h, c)`` state of the LSTM.
        """
        batch_size, _ = x.shape  # also enforces that x is 2-D
        if hidden is None:
            h, c = self.init_hidden(x, batch_size)
        else:
            # The original read ``h. c = hidden``, an attribute assignment on
            # an unbound local, so passing any state crashed.
            h, c = hidden
        embedded = self.emb(x)  # (batch_size, sequence_length, embedding_size)
        # batch_first=True, so the output is (batch_size, sequence_length, hidden_size).
        out, hidden = self.LSTM(embedded, (h, c))
        # Classify from the output at the last time step.
        # NOTE(review): if the inputs are right-padded, this position is padding
        # - confirm against the preprocessing.
        last_step = out[:, -1, :]
        return self.fc1(last_step), hidden

    def init_hidden(self, ipt, batch_size):
        """Return zero-initialised ``(h, c)`` float states on the same device as ``ipt``."""
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        h = torch.zeros(state_shape, dtype=torch.float, device=ipt.device)
        c = torch.zeros(state_shape, dtype=torch.float, device=ipt.device)
        return (h, c)


In [11]:
# Length of a single training row, i.e. the second dimension of dataset_train.
len(dataset_train[1])

1210

### Trainer

In [19]:
# Training hyper-parameters.
batch_size = 32
epoch = 40

# NOTE(review): sequence_length=300 does not match the row length shown above
# (1210); the model only stores this value and never uses it in forward().
model = PoemClassifier(
    words_num=len(my_mapping),
    embedding_size=128,
    hidden_size=128,
    classes=2,
    num_layers=2,
    batch_size=batch_size,
    sequence_length=300,
)
# Move the model to the GPU *before* building the optimizer so the optimizer is
# constructed from the parameters that will actually be trained.
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=0.003)
# CrossEntropyLoss expects raw class scores, which is what the model returns.
criterion = nn.CrossEntropyLoss()

# NOTE: `datasets` previously referred to the loaded .npz archive; from here on
# it is the training Dataset. The name is kept so later cells still work.
datasets = MyDataset(dataset_train, label_train)
# drop_last=True keeps every batch at exactly `batch_size` samples.
data_loader = DataLoader(dataset=datasets, batch_size=batch_size, shuffle=True, drop_last=True,
                                    num_workers=4)

In [20]:
# Train on the GPU; the held-out tensors are moved once, outside the loop.
model = model.cuda()
dataset_test = dataset_test.cuda()
label_test = label_test.cuda()
for e in range(1, epoch + 1):
    for idx, item in enumerate(data_loader):
        data, labels = item
        data = data.cuda()
        labels = labels.cuda()

        # Each batch starts from a fresh zero LSTM state (no state is carried over).
        out, h = model(data)
        loss = criterion(out, labels)
        # Training accuracy on the current batch.
        p = (torch.max(out, dim=1)[1] == labels).sum().item() / labels.size(0)

        # Accuracy on the held-out set with the weights as they are *before*
        # this step's update. Gradients are disabled so that no autograd graph
        # is built over the whole test set on every step.
        # NOTE: this evaluates the full test set on every step; evaluate less
        # often if training speed matters.
        model.eval()
        with torch.no_grad():
            test_pred = torch.max(model(dataset_test)[0], dim=1)[1]
            acc = (test_pred == label_test).sum().item() / len(dataset_test)
        model.train()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch [{e}/{epoch}] step [{idx + 1}/{len(data_loader)}] loss = {loss.item()} accuracy = {p} val_acc = {round(acc, 4)}")

Epoch [1/40] step [1/64] loss = 0.6950701475143433 accuracy = 0.5 val_acc = 0.4652
Epoch [1/40] step [2/64] loss = 0.700562596321106 accuracy = 0.375 val_acc = 0.5348
Epoch [1/40] step [3/64] loss = 0.6771875023841858 accuracy = 0.625 val_acc = 0.4652
Epoch [1/40] step [4/64] loss = 0.7552219033241272 accuracy = 0.40625 val_acc = 0.4652
Epoch [1/40] step [5/64] loss = 0.7341263294219971 accuracy = 0.40625 val_acc = 0.4652
Epoch [1/40] step [6/64] loss = 0.689287543296814 accuracy = 0.53125 val_acc = 0.4652
Epoch [1/40] step [7/64] loss = 0.6860019564628601 accuracy = 0.5625 val_acc = 0.4652
Epoch [1/40] step [8/64] loss = 0.6886114478111267 accuracy = 0.5625 val_acc = 0.5174
Epoch [1/40] step [9/64] loss = 0.6978358030319214 accuracy = 0.40625 val_acc = 0.4783
Epoch [1/40] step [10/64] loss = 0.6975655555725098 accuracy = 0.46875 val_acc = 0.5174
Epoch [1/40] step [11/64] loss = 0.6952385306358337 accuracy = 0.34375 val_acc = 0.5739
Epoch [1/40] step [12/64] loss = 0.6896302700042725 a

Epoch [2/40] step [32/64] loss = 0.40214401483535767 accuracy = 0.8125 val_acc = 0.8304
Epoch [2/40] step [33/64] loss = 0.3747028708457947 accuracy = 0.84375 val_acc = 0.8304
Epoch [2/40] step [34/64] loss = 0.14821483194828033 accuracy = 0.96875 val_acc = 0.8435
Epoch [2/40] step [35/64] loss = 0.3574621081352234 accuracy = 0.875 val_acc = 0.8435
Epoch [2/40] step [36/64] loss = 0.3312839865684509 accuracy = 0.84375 val_acc = 0.8391
Epoch [2/40] step [37/64] loss = 0.6273157000541687 accuracy = 0.8125 val_acc = 0.8348
Epoch [2/40] step [38/64] loss = 0.2531633973121643 accuracy = 0.9375 val_acc = 0.8435
Epoch [2/40] step [39/64] loss = 0.4517996311187744 accuracy = 0.78125 val_acc = 0.8391
Epoch [2/40] step [40/64] loss = 0.4405096769332886 accuracy = 0.8125 val_acc = 0.8391
Epoch [2/40] step [41/64] loss = 0.22872130572795868 accuracy = 1.0 val_acc = 0.8348
Epoch [2/40] step [42/64] loss = 0.3070167303085327 accuracy = 0.875 val_acc = 0.8348
Epoch [2/40] step [43/64] loss = 0.318165

Epoch [3/40] step [62/64] loss = 0.19796335697174072 accuracy = 0.90625 val_acc = 0.8652
Epoch [3/40] step [63/64] loss = 0.3225436210632324 accuracy = 0.84375 val_acc = 0.8478
Epoch [3/40] step [64/64] loss = 0.42845314741134644 accuracy = 0.8125 val_acc = 0.8522
Epoch [4/40] step [1/64] loss = 0.24546653032302856 accuracy = 0.90625 val_acc = 0.8565
Epoch [4/40] step [2/64] loss = 0.19940194487571716 accuracy = 0.96875 val_acc = 0.8522
Epoch [4/40] step [3/64] loss = 0.37474945187568665 accuracy = 0.875 val_acc = 0.8478
Epoch [4/40] step [4/64] loss = 0.12676256895065308 accuracy = 1.0 val_acc = 0.8652
Epoch [4/40] step [5/64] loss = 0.09775039553642273 accuracy = 1.0 val_acc = 0.8522
Epoch [4/40] step [6/64] loss = 0.2993592619895935 accuracy = 0.84375 val_acc = 0.8609
Epoch [4/40] step [7/64] loss = 0.2623694837093353 accuracy = 0.875 val_acc = 0.8739
Epoch [4/40] step [8/64] loss = 0.13594619929790497 accuracy = 0.96875 val_acc = 0.8826
Epoch [4/40] step [9/64] loss = 0.30001705884

Epoch [5/40] step [28/64] loss = 0.1368108093738556 accuracy = 0.96875 val_acc = 0.8565
Epoch [5/40] step [29/64] loss = 0.3461135923862457 accuracy = 0.90625 val_acc = 0.8565
Epoch [5/40] step [30/64] loss = 0.23326006531715393 accuracy = 0.9375 val_acc = 0.8913
Epoch [5/40] step [31/64] loss = 0.34448525309562683 accuracy = 0.875 val_acc = 0.8783
Epoch [5/40] step [32/64] loss = 0.7559595108032227 accuracy = 0.75 val_acc = 0.8348
Epoch [5/40] step [33/64] loss = 0.24548941850662231 accuracy = 0.90625 val_acc = 0.8652
Epoch [5/40] step [34/64] loss = 0.23674117028713226 accuracy = 0.9375 val_acc = 0.8826
Epoch [5/40] step [35/64] loss = 0.15773971378803253 accuracy = 0.96875 val_acc = 0.8913
Epoch [5/40] step [36/64] loss = 0.21524935960769653 accuracy = 0.9375 val_acc = 0.8609
Epoch [5/40] step [37/64] loss = 0.3105583190917969 accuracy = 0.875 val_acc = 0.8522
Epoch [5/40] step [38/64] loss = 0.13755634427070618 accuracy = 0.96875 val_acc = 0.8435
Epoch [5/40] step [39/64] loss = 0.

Epoch [6/40] step [58/64] loss = 0.35061463713645935 accuracy = 0.84375 val_acc = 0.8348
Epoch [6/40] step [59/64] loss = 0.32967931032180786 accuracy = 0.875 val_acc = 0.8478
Epoch [6/40] step [60/64] loss = 0.2337658405303955 accuracy = 0.96875 val_acc = 0.8478
Epoch [6/40] step [61/64] loss = 0.46704643964767456 accuracy = 0.8125 val_acc = 0.8522
Epoch [6/40] step [62/64] loss = 0.4419253468513489 accuracy = 0.78125 val_acc = 0.8739
Epoch [6/40] step [63/64] loss = 0.296466201543808 accuracy = 0.84375 val_acc = 0.8522
Epoch [6/40] step [64/64] loss = 0.3935905396938324 accuracy = 0.78125 val_acc = 0.8304
Epoch [7/40] step [1/64] loss = 0.23068606853485107 accuracy = 0.9375 val_acc = 0.8391
Epoch [7/40] step [2/64] loss = 0.27700066566467285 accuracy = 0.84375 val_acc = 0.8522
Epoch [7/40] step [3/64] loss = 0.19526931643486023 accuracy = 0.9375 val_acc = 0.8739
Epoch [7/40] step [4/64] loss = 0.21278579533100128 accuracy = 0.9375 val_acc = 0.8652
Epoch [7/40] step [5/64] loss = 0.28

Epoch [8/40] step [24/64] loss = 0.0510939322412014 accuracy = 0.96875 val_acc = 0.913
Epoch [8/40] step [25/64] loss = 0.18950310349464417 accuracy = 0.9375 val_acc = 0.9043
Epoch [8/40] step [26/64] loss = 0.13671019673347473 accuracy = 0.9375 val_acc = 0.887
Epoch [8/40] step [27/64] loss = 0.13582459092140198 accuracy = 0.96875 val_acc = 0.887
Epoch [8/40] step [28/64] loss = 0.12431630492210388 accuracy = 0.96875 val_acc = 0.887
Epoch [8/40] step [29/64] loss = 0.06951388716697693 accuracy = 0.96875 val_acc = 0.8826
Epoch [8/40] step [30/64] loss = 0.16894039511680603 accuracy = 0.9375 val_acc = 0.887
Epoch [8/40] step [31/64] loss = 0.02231770008802414 accuracy = 1.0 val_acc = 0.9043
Epoch [8/40] step [32/64] loss = 0.10298610478639603 accuracy = 0.96875 val_acc = 0.9043
Epoch [8/40] step [33/64] loss = 0.16002941131591797 accuracy = 0.90625 val_acc = 0.9043
Epoch [8/40] step [34/64] loss = 0.4529799818992615 accuracy = 0.875 val_acc = 0.8913
Epoch [8/40] step [35/64] loss = 0.08

Epoch [9/40] step [54/64] loss = 0.35011351108551025 accuracy = 0.84375 val_acc = 0.8609
Epoch [9/40] step [55/64] loss = 0.27516061067581177 accuracy = 0.875 val_acc = 0.8826
Epoch [9/40] step [56/64] loss = 0.19669121503829956 accuracy = 0.90625 val_acc = 0.9
Epoch [9/40] step [57/64] loss = 0.14388513565063477 accuracy = 1.0 val_acc = 0.8913
Epoch [9/40] step [58/64] loss = 0.28283074498176575 accuracy = 0.84375 val_acc = 0.887
Epoch [9/40] step [59/64] loss = 0.22882971167564392 accuracy = 0.875 val_acc = 0.8826
Epoch [9/40] step [60/64] loss = 0.1165349930524826 accuracy = 0.96875 val_acc = 0.8913
Epoch [9/40] step [61/64] loss = 0.19816091656684875 accuracy = 0.90625 val_acc = 0.9
Epoch [9/40] step [62/64] loss = 0.1166234090924263 accuracy = 0.9375 val_acc = 0.887
Epoch [9/40] step [63/64] loss = 0.05569807067513466 accuracy = 1.0 val_acc = 0.887
Epoch [9/40] step [64/64] loss = 0.44462209939956665 accuracy = 0.90625 val_acc = 0.8783
Epoch [10/40] step [1/64] loss = 0.1777465045

Epoch [11/40] step [20/64] loss = 0.14088965952396393 accuracy = 0.96875 val_acc = 0.8696
Epoch [11/40] step [21/64] loss = 0.1970645636320114 accuracy = 0.9375 val_acc = 0.8435
Epoch [11/40] step [22/64] loss = 0.2390706092119217 accuracy = 0.90625 val_acc = 0.8261
Epoch [11/40] step [23/64] loss = 0.16790804266929626 accuracy = 0.9375 val_acc = 0.8435
Epoch [11/40] step [24/64] loss = 0.3420006334781647 accuracy = 0.875 val_acc = 0.8739
Epoch [11/40] step [25/64] loss = 0.14164036512374878 accuracy = 0.96875 val_acc = 0.9087
Epoch [11/40] step [26/64] loss = 0.1211383044719696 accuracy = 1.0 val_acc = 0.9
Epoch [11/40] step [27/64] loss = 0.12873294949531555 accuracy = 0.96875 val_acc = 0.8913
Epoch [11/40] step [28/64] loss = 0.34736281633377075 accuracy = 0.8125 val_acc = 0.8478
Epoch [11/40] step [29/64] loss = 0.22219456732273102 accuracy = 0.875 val_acc = 0.8913
Epoch [11/40] step [30/64] loss = 0.10999097675085068 accuracy = 0.96875 val_acc = 0.8957
Epoch [11/40] step [31/64] l

Epoch [12/40] step [50/64] loss = 0.384107381105423 accuracy = 0.875 val_acc = 0.9
Epoch [12/40] step [51/64] loss = 0.33939695358276367 accuracy = 0.875 val_acc = 0.8957
Epoch [12/40] step [52/64] loss = 0.5048911571502686 accuracy = 0.8125 val_acc = 0.9
Epoch [12/40] step [53/64] loss = 0.0658034160733223 accuracy = 0.96875 val_acc = 0.9043
Epoch [12/40] step [54/64] loss = 0.2095564603805542 accuracy = 0.9375 val_acc = 0.8957
Epoch [12/40] step [55/64] loss = 0.1867656409740448 accuracy = 0.96875 val_acc = 0.887
Epoch [12/40] step [56/64] loss = 0.5829001069068909 accuracy = 0.75 val_acc = 0.8957
Epoch [12/40] step [57/64] loss = 0.23210817575454712 accuracy = 0.90625 val_acc = 0.8913
Epoch [12/40] step [58/64] loss = 0.1779087632894516 accuracy = 0.9375 val_acc = 0.9
Epoch [12/40] step [59/64] loss = 0.23354752361774445 accuracy = 0.9375 val_acc = 0.9043
Epoch [12/40] step [60/64] loss = 0.15207645297050476 accuracy = 0.96875 val_acc = 0.9
Epoch [12/40] step [61/64] loss = 0.225261

Epoch [14/40] step [16/64] loss = 0.012629315257072449 accuracy = 1.0 val_acc = 0.9
Epoch [14/40] step [17/64] loss = 0.02744976058602333 accuracy = 1.0 val_acc = 0.9
Epoch [14/40] step [18/64] loss = 0.11494327336549759 accuracy = 0.9375 val_acc = 0.9043
Epoch [14/40] step [19/64] loss = 0.04408998787403107 accuracy = 0.96875 val_acc = 0.9
Epoch [14/40] step [20/64] loss = 0.33178308606147766 accuracy = 0.9375 val_acc = 0.9
Epoch [14/40] step [21/64] loss = 0.02703152969479561 accuracy = 1.0 val_acc = 0.9
Epoch [14/40] step [22/64] loss = 0.04474257677793503 accuracy = 0.96875 val_acc = 0.9
Epoch [14/40] step [23/64] loss = 0.19533053040504456 accuracy = 0.9375 val_acc = 0.8957
Epoch [14/40] step [24/64] loss = 0.1684800684452057 accuracy = 0.96875 val_acc = 0.8957
Epoch [14/40] step [25/64] loss = 0.18441739678382874 accuracy = 0.9375 val_acc = 0.8913
Epoch [14/40] step [26/64] loss = 0.07684195041656494 accuracy = 0.96875 val_acc = 0.9087
Epoch [14/40] step [27/64] loss = 0.11320413

Epoch [15/40] step [46/64] loss = 0.08688084036111832 accuracy = 1.0 val_acc = 0.9087
Epoch [15/40] step [47/64] loss = 0.16979917883872986 accuracy = 0.9375 val_acc = 0.9
Epoch [15/40] step [48/64] loss = 0.15288807451725006 accuracy = 0.90625 val_acc = 0.9
Epoch [15/40] step [49/64] loss = 0.09315794706344604 accuracy = 1.0 val_acc = 0.9
Epoch [15/40] step [50/64] loss = 0.14138169586658478 accuracy = 0.9375 val_acc = 0.9
Epoch [15/40] step [51/64] loss = 0.1785963773727417 accuracy = 0.9375 val_acc = 0.8957
Epoch [15/40] step [52/64] loss = 0.16287414729595184 accuracy = 0.9375 val_acc = 0.8913
Epoch [15/40] step [53/64] loss = 0.11166852712631226 accuracy = 0.96875 val_acc = 0.9087
Epoch [15/40] step [54/64] loss = 0.18944285809993744 accuracy = 0.9375 val_acc = 0.9087
Epoch [15/40] step [55/64] loss = 0.1314748227596283 accuracy = 0.96875 val_acc = 0.9174
Epoch [15/40] step [56/64] loss = 0.21570159494876862 accuracy = 0.9375 val_acc = 0.9043
Epoch [15/40] step [57/64] loss = 0.02

Epoch [17/40] step [12/64] loss = 0.23949553072452545 accuracy = 0.90625 val_acc = 0.9087
Epoch [17/40] step [13/64] loss = 0.2052285373210907 accuracy = 0.96875 val_acc = 0.9043
Epoch [17/40] step [14/64] loss = 0.2658844590187073 accuracy = 0.96875 val_acc = 0.9087
Epoch [17/40] step [15/64] loss = 0.19530639052391052 accuracy = 0.90625 val_acc = 0.9174
Epoch [17/40] step [16/64] loss = 0.06290113925933838 accuracy = 1.0 val_acc = 0.9217
Epoch [17/40] step [17/64] loss = 0.04922329634428024 accuracy = 1.0 val_acc = 0.913
Epoch [17/40] step [18/64] loss = 0.050358183681964874 accuracy = 1.0 val_acc = 0.8957
Epoch [17/40] step [19/64] loss = 0.09923303127288818 accuracy = 0.9375 val_acc = 0.8913
Epoch [17/40] step [20/64] loss = 0.10426738858222961 accuracy = 0.9375 val_acc = 0.9043
Epoch [17/40] step [21/64] loss = 0.1455317884683609 accuracy = 0.9375 val_acc = 0.9043
Epoch [17/40] step [22/64] loss = 0.057925060391426086 accuracy = 1.0 val_acc = 0.8957
Epoch [17/40] step [23/64] loss

Epoch [18/40] step [42/64] loss = 0.24242016673088074 accuracy = 0.9375 val_acc = 0.9087
Epoch [18/40] step [43/64] loss = 0.09656500816345215 accuracy = 0.9375 val_acc = 0.913
Epoch [18/40] step [44/64] loss = 0.12173771858215332 accuracy = 0.96875 val_acc = 0.9174
Epoch [18/40] step [45/64] loss = 0.1503465473651886 accuracy = 0.96875 val_acc = 0.9174
Epoch [18/40] step [46/64] loss = 0.06111714988946915 accuracy = 0.96875 val_acc = 0.9174
Epoch [18/40] step [47/64] loss = 0.03281095623970032 accuracy = 1.0 val_acc = 0.913
Epoch [18/40] step [48/64] loss = 0.19781722128391266 accuracy = 0.90625 val_acc = 0.913
Epoch [18/40] step [49/64] loss = 0.06824499368667603 accuracy = 0.96875 val_acc = 0.8913
Epoch [18/40] step [50/64] loss = 0.11697547137737274 accuracy = 0.96875 val_acc = 0.9
Epoch [18/40] step [51/64] loss = 0.04428780823945999 accuracy = 1.0 val_acc = 0.8957
Epoch [18/40] step [52/64] loss = 0.12489831447601318 accuracy = 0.96875 val_acc = 0.8957
Epoch [18/40] step [53/64] 

Epoch [20/40] step [8/64] loss = 0.01062341034412384 accuracy = 1.0 val_acc = 0.8957
Epoch [20/40] step [9/64] loss = 0.014225110411643982 accuracy = 1.0 val_acc = 0.8913
Epoch [20/40] step [10/64] loss = 0.04162248969078064 accuracy = 0.96875 val_acc = 0.9043
Epoch [20/40] step [11/64] loss = 0.004600532352924347 accuracy = 1.0 val_acc = 0.9087
Epoch [20/40] step [12/64] loss = 0.009483382105827332 accuracy = 1.0 val_acc = 0.913
Epoch [20/40] step [13/64] loss = 0.017559845000505447 accuracy = 1.0 val_acc = 0.913
Epoch [20/40] step [14/64] loss = 0.3997851610183716 accuracy = 0.9375 val_acc = 0.8957
Epoch [20/40] step [15/64] loss = 0.017971130087971687 accuracy = 1.0 val_acc = 0.8913
Epoch [20/40] step [16/64] loss = 0.23718303442001343 accuracy = 0.9375 val_acc = 0.8913
Epoch [20/40] step [17/64] loss = 0.027943890541791916 accuracy = 1.0 val_acc = 0.9
Epoch [20/40] step [18/64] loss = 0.1509312093257904 accuracy = 0.9375 val_acc = 0.8696
Epoch [20/40] step [19/64] loss = 0.59300887

Epoch [21/40] step [38/64] loss = 0.14233864843845367 accuracy = 0.96875 val_acc = 0.8609
Epoch [21/40] step [39/64] loss = 0.13331347703933716 accuracy = 0.96875 val_acc = 0.8696
Epoch [21/40] step [40/64] loss = 0.06543200463056564 accuracy = 0.96875 val_acc = 0.8783
Epoch [21/40] step [41/64] loss = 0.04039393737912178 accuracy = 1.0 val_acc = 0.887
Epoch [21/40] step [42/64] loss = 0.3745347261428833 accuracy = 0.875 val_acc = 0.9043
Epoch [21/40] step [43/64] loss = 0.2612435817718506 accuracy = 0.875 val_acc = 0.9087
Epoch [21/40] step [44/64] loss = 0.07781859487295151 accuracy = 0.96875 val_acc = 0.913
Epoch [21/40] step [45/64] loss = 0.25456780195236206 accuracy = 0.875 val_acc = 0.9
Epoch [21/40] step [46/64] loss = 0.1908901184797287 accuracy = 0.9375 val_acc = 0.8957
Epoch [21/40] step [47/64] loss = 0.031239069998264313 accuracy = 1.0 val_acc = 0.9
Epoch [21/40] step [48/64] loss = 0.0860246792435646 accuracy = 0.96875 val_acc = 0.9217
Epoch [21/40] step [49/64] loss = 0.

Epoch [23/40] step [4/64] loss = 0.18596871197223663 accuracy = 0.90625 val_acc = 0.9
Epoch [23/40] step [5/64] loss = 0.059489693492650986 accuracy = 1.0 val_acc = 0.9043
Epoch [23/40] step [6/64] loss = 0.07662860304117203 accuracy = 1.0 val_acc = 0.9087
Epoch [23/40] step [7/64] loss = 0.1178198754787445 accuracy = 0.96875 val_acc = 0.9043
Epoch [23/40] step [8/64] loss = 0.034066133201122284 accuracy = 1.0 val_acc = 0.9087
Epoch [23/40] step [9/64] loss = 0.06404181569814682 accuracy = 1.0 val_acc = 0.8783
Epoch [23/40] step [10/64] loss = 0.10919502377510071 accuracy = 0.9375 val_acc = 0.8609
Epoch [23/40] step [11/64] loss = 0.08771362900733948 accuracy = 0.96875 val_acc = 0.8652
Epoch [23/40] step [12/64] loss = 0.028086379170417786 accuracy = 1.0 val_acc = 0.8696
Epoch [23/40] step [13/64] loss = 0.17589306831359863 accuracy = 0.9375 val_acc = 0.8783
Epoch [23/40] step [14/64] loss = 0.02668713964521885 accuracy = 1.0 val_acc = 0.9087
Epoch [23/40] step [15/64] loss = 0.0273091

Epoch [24/40] step [34/64] loss = 0.019442154094576836 accuracy = 1.0 val_acc = 0.9087
Epoch [24/40] step [35/64] loss = 0.1328851580619812 accuracy = 0.96875 val_acc = 0.9087
Epoch [24/40] step [36/64] loss = 0.08969520032405853 accuracy = 0.96875 val_acc = 0.913
Epoch [24/40] step [37/64] loss = 0.09412319958209991 accuracy = 0.96875 val_acc = 0.9043
Epoch [24/40] step [38/64] loss = 0.012698091566562653 accuracy = 1.0 val_acc = 0.9043
Epoch [24/40] step [39/64] loss = 0.0948052704334259 accuracy = 0.96875 val_acc = 0.9
Epoch [24/40] step [40/64] loss = 0.059222105890512466 accuracy = 0.96875 val_acc = 0.9043
Epoch [24/40] step [41/64] loss = 0.12039745599031448 accuracy = 0.9375 val_acc = 0.9
Epoch [24/40] step [42/64] loss = 0.1335776001214981 accuracy = 0.96875 val_acc = 0.9174
Epoch [24/40] step [43/64] loss = 0.10158379375934601 accuracy = 0.96875 val_acc = 0.9087
Epoch [24/40] step [44/64] loss = 0.04225229099392891 accuracy = 1.0 val_acc = 0.9043
Epoch [24/40] step [45/64] los

Epoch [25/40] step [64/64] loss = 0.07175139337778091 accuracy = 0.96875 val_acc = 0.8913
Epoch [26/40] step [1/64] loss = 0.017256416380405426 accuracy = 1.0 val_acc = 0.887
Epoch [26/40] step [2/64] loss = 0.010415777564048767 accuracy = 1.0 val_acc = 0.887
Epoch [26/40] step [3/64] loss = 0.13039538264274597 accuracy = 0.96875 val_acc = 0.8913
Epoch [26/40] step [4/64] loss = 0.020654739812016487 accuracy = 1.0 val_acc = 0.9
Epoch [26/40] step [5/64] loss = 0.04039305821061134 accuracy = 1.0 val_acc = 0.9
Epoch [26/40] step [6/64] loss = 0.009744402021169662 accuracy = 1.0 val_acc = 0.9
Epoch [26/40] step [7/64] loss = 0.019193993881344795 accuracy = 1.0 val_acc = 0.9043
Epoch [26/40] step [8/64] loss = 0.09920210391283035 accuracy = 0.96875 val_acc = 0.8957
Epoch [26/40] step [9/64] loss = 0.028803769499063492 accuracy = 1.0 val_acc = 0.8957
Epoch [26/40] step [10/64] loss = 0.09110578894615173 accuracy = 0.9375 val_acc = 0.8957
Epoch [26/40] step [11/64] loss = 0.01447171717882156

Epoch [27/40] step [31/64] loss = 0.01057112030684948 accuracy = 1.0 val_acc = 0.887
Epoch [27/40] step [32/64] loss = 0.014980686828494072 accuracy = 1.0 val_acc = 0.887
Epoch [27/40] step [33/64] loss = 0.009890180081129074 accuracy = 1.0 val_acc = 0.8826
Epoch [27/40] step [34/64] loss = 0.004746720194816589 accuracy = 1.0 val_acc = 0.887
Epoch [27/40] step [35/64] loss = 0.05623125657439232 accuracy = 0.96875 val_acc = 0.887
Epoch [27/40] step [36/64] loss = 0.17739799618721008 accuracy = 0.96875 val_acc = 0.8826
Epoch [27/40] step [37/64] loss = 0.02073528617620468 accuracy = 1.0 val_acc = 0.887
Epoch [27/40] step [38/64] loss = 0.009983539581298828 accuracy = 1.0 val_acc = 0.8826
Epoch [27/40] step [39/64] loss = 0.012585103511810303 accuracy = 1.0 val_acc = 0.8783
Epoch [27/40] step [40/64] loss = 0.009287849068641663 accuracy = 1.0 val_acc = 0.8783
Epoch [27/40] step [41/64] loss = 0.030978946015238762 accuracy = 1.0 val_acc = 0.8783
Epoch [27/40] step [42/64] loss = 0.01975654

Epoch [28/40] step [61/64] loss = 0.23628456890583038 accuracy = 0.9375 val_acc = 0.9087
Epoch [28/40] step [62/64] loss = 0.02894553542137146 accuracy = 1.0 val_acc = 0.9
Epoch [28/40] step [63/64] loss = 0.133626326918602 accuracy = 0.96875 val_acc = 0.8826
Epoch [28/40] step [64/64] loss = 0.09811890870332718 accuracy = 0.96875 val_acc = 0.8783
Epoch [29/40] step [1/64] loss = 0.17459964752197266 accuracy = 0.90625 val_acc = 0.8826
Epoch [29/40] step [2/64] loss = 0.040965162217617035 accuracy = 1.0 val_acc = 0.9043
Epoch [29/40] step [3/64] loss = 0.023819275200366974 accuracy = 1.0 val_acc = 0.9174
Epoch [29/40] step [4/64] loss = 0.0652523934841156 accuracy = 0.96875 val_acc = 0.8913
Epoch [29/40] step [5/64] loss = 0.1095077395439148 accuracy = 0.9375 val_acc = 0.887
Epoch [29/40] step [6/64] loss = 0.1736486256122589 accuracy = 0.96875 val_acc = 0.887
Epoch [29/40] step [7/64] loss = 0.18480977416038513 accuracy = 0.9375 val_acc = 0.8826
Epoch [29/40] step [8/64] loss = 0.02895

Epoch [30/40] step [27/64] loss = 0.02083795703947544 accuracy = 1.0 val_acc = 0.8913
Epoch [30/40] step [28/64] loss = 0.011325124651193619 accuracy = 1.0 val_acc = 0.8957
Epoch [30/40] step [29/64] loss = 0.015858445316553116 accuracy = 1.0 val_acc = 0.9043
Epoch [30/40] step [30/64] loss = 0.025796452537178993 accuracy = 1.0 val_acc = 0.9
Epoch [30/40] step [31/64] loss = 0.024345552548766136 accuracy = 1.0 val_acc = 0.8957
Epoch [30/40] step [32/64] loss = 0.18742844462394714 accuracy = 0.96875 val_acc = 0.8913
Epoch [30/40] step [33/64] loss = 0.03963567689061165 accuracy = 0.96875 val_acc = 0.9087
Epoch [30/40] step [34/64] loss = 0.0520319864153862 accuracy = 0.96875 val_acc = 0.9043
Epoch [30/40] step [35/64] loss = 0.050742119550704956 accuracy = 0.96875 val_acc = 0.8957
Epoch [30/40] step [36/64] loss = 0.01873440109193325 accuracy = 1.0 val_acc = 0.887
Epoch [30/40] step [37/64] loss = 0.04270181804895401 accuracy = 0.96875 val_acc = 0.8783
Epoch [30/40] step [38/64] loss = 

Epoch [31/40] step [57/64] loss = 0.013382047414779663 accuracy = 1.0 val_acc = 0.8783
Epoch [31/40] step [58/64] loss = 0.06550107896327972 accuracy = 0.9375 val_acc = 0.887
Epoch [31/40] step [59/64] loss = 0.005160488188266754 accuracy = 1.0 val_acc = 0.8783
Epoch [31/40] step [60/64] loss = 0.00650058314204216 accuracy = 1.0 val_acc = 0.8783
Epoch [31/40] step [61/64] loss = 0.009198013693094254 accuracy = 1.0 val_acc = 0.8826
Epoch [31/40] step [62/64] loss = 0.1557750552892685 accuracy = 0.96875 val_acc = 0.8783
Epoch [31/40] step [63/64] loss = 0.024545324966311455 accuracy = 1.0 val_acc = 0.8739
Epoch [31/40] step [64/64] loss = 0.008945073932409286 accuracy = 1.0 val_acc = 0.8739
Epoch [32/40] step [1/64] loss = 0.007158942520618439 accuracy = 1.0 val_acc = 0.8739
Epoch [32/40] step [2/64] loss = 0.18676692247390747 accuracy = 0.96875 val_acc = 0.8739
Epoch [32/40] step [3/64] loss = 0.0040550678968429565 accuracy = 1.0 val_acc = 0.8739
Epoch [32/40] step [4/64] loss = 0.00438

Epoch [33/40] step [23/64] loss = 0.014635240659117699 accuracy = 1.0 val_acc = 0.887
Epoch [33/40] step [24/64] loss = 0.011682864278554916 accuracy = 1.0 val_acc = 0.887
Epoch [33/40] step [25/64] loss = 0.03793981298804283 accuracy = 0.96875 val_acc = 0.887
Epoch [33/40] step [26/64] loss = 0.027724022045731544 accuracy = 1.0 val_acc = 0.8739
Epoch [33/40] step [27/64] loss = 0.003862708806991577 accuracy = 1.0 val_acc = 0.887
Epoch [33/40] step [28/64] loss = 0.00676611065864563 accuracy = 1.0 val_acc = 0.8913
Epoch [33/40] step [29/64] loss = 0.0157974474132061 accuracy = 1.0 val_acc = 0.8739
Epoch [33/40] step [30/64] loss = 0.0346919521689415 accuracy = 1.0 val_acc = 0.8739
Epoch [33/40] step [31/64] loss = 0.015912923961877823 accuracy = 1.0 val_acc = 0.8783
Epoch [33/40] step [32/64] loss = 0.0026829056441783905 accuracy = 1.0 val_acc = 0.8739
Epoch [33/40] step [33/64] loss = 0.024551106616854668 accuracy = 1.0 val_acc = 0.8826
Epoch [33/40] step [34/64] loss = 0.182188883423

Epoch [34/40] step [53/64] loss = 0.3070136606693268 accuracy = 0.875 val_acc = 0.8826
Epoch [34/40] step [54/64] loss = 0.018434926867485046 accuracy = 1.0 val_acc = 0.8913
Epoch [34/40] step [55/64] loss = 0.014402784407138824 accuracy = 1.0 val_acc = 0.8913
Epoch [34/40] step [56/64] loss = 0.03642605245113373 accuracy = 0.96875 val_acc = 0.8783
Epoch [34/40] step [57/64] loss = 0.039119902998209 accuracy = 0.96875 val_acc = 0.8826
Epoch [34/40] step [58/64] loss = 0.01322004571557045 accuracy = 1.0 val_acc = 0.8913
Epoch [34/40] step [59/64] loss = 0.02902037277817726 accuracy = 1.0 val_acc = 0.8957
Epoch [34/40] step [60/64] loss = 0.02945265732705593 accuracy = 1.0 val_acc = 0.8957
Epoch [34/40] step [61/64] loss = 0.014081835746765137 accuracy = 1.0 val_acc = 0.8913
Epoch [34/40] step [62/64] loss = 0.011831443756818771 accuracy = 1.0 val_acc = 0.8957
Epoch [34/40] step [63/64] loss = 0.04765511304140091 accuracy = 0.96875 val_acc = 0.9
Epoch [34/40] step [64/64] loss = 0.044831

Epoch [36/40] step [19/64] loss = 0.009354652836918831 accuracy = 1.0 val_acc = 0.8826
Epoch [36/40] step [20/64] loss = 0.0402546301484108 accuracy = 0.96875 val_acc = 0.8739
Epoch [36/40] step [21/64] loss = 0.09978456795215607 accuracy = 0.96875 val_acc = 0.8783
Epoch [36/40] step [22/64] loss = 0.004175417125225067 accuracy = 1.0 val_acc = 0.8913
Epoch [36/40] step [23/64] loss = 0.004564113914966583 accuracy = 1.0 val_acc = 0.8913
Epoch [36/40] step [24/64] loss = 0.009792938828468323 accuracy = 1.0 val_acc = 0.8957
Epoch [36/40] step [25/64] loss = 0.012795854359865189 accuracy = 1.0 val_acc = 0.887
Epoch [36/40] step [26/64] loss = 0.014572573825716972 accuracy = 1.0 val_acc = 0.8826
Epoch [36/40] step [27/64] loss = 0.02730073779821396 accuracy = 1.0 val_acc = 0.8826
Epoch [36/40] step [28/64] loss = 0.27960214018821716 accuracy = 0.9375 val_acc = 0.8826
Epoch [36/40] step [29/64] loss = 0.09400539845228195 accuracy = 0.96875 val_acc = 0.8826
Epoch [36/40] step [30/64] loss = 0

Epoch [37/40] step [50/64] loss = 0.25365254282951355 accuracy = 0.9375 val_acc = 0.887
Epoch [37/40] step [51/64] loss = 0.007806640118360519 accuracy = 1.0 val_acc = 0.8913
Epoch [37/40] step [52/64] loss = 0.16939249634742737 accuracy = 0.9375 val_acc = 0.8957
Epoch [37/40] step [53/64] loss = 0.04092438519001007 accuracy = 0.96875 val_acc = 0.8957
Epoch [37/40] step [54/64] loss = 0.02177496813237667 accuracy = 1.0 val_acc = 0.8957
Epoch [37/40] step [55/64] loss = 0.011348966509103775 accuracy = 1.0 val_acc = 0.8913
Epoch [37/40] step [56/64] loss = 0.007246591150760651 accuracy = 1.0 val_acc = 0.887
Epoch [37/40] step [57/64] loss = 0.013019215315580368 accuracy = 1.0 val_acc = 0.887
Epoch [37/40] step [58/64] loss = 0.047011855989694595 accuracy = 1.0 val_acc = 0.887
Epoch [37/40] step [59/64] loss = 0.006664972752332687 accuracy = 1.0 val_acc = 0.8783
Epoch [37/40] step [60/64] loss = 0.09820472449064255 accuracy = 0.96875 val_acc = 0.8826
Epoch [37/40] step [61/64] loss = 0.04

Epoch [39/40] step [16/64] loss = 0.004915241152048111 accuracy = 1.0 val_acc = 0.8826
Epoch [39/40] step [17/64] loss = 0.006025038659572601 accuracy = 1.0 val_acc = 0.8783
Epoch [39/40] step [18/64] loss = 0.02449754625558853 accuracy = 1.0 val_acc = 0.8696
Epoch [39/40] step [19/64] loss = 0.009921632707118988 accuracy = 1.0 val_acc = 0.8696
Epoch [39/40] step [20/64] loss = 0.11701343953609467 accuracy = 0.96875 val_acc = 0.8739
Epoch [39/40] step [21/64] loss = 0.016533780843019485 accuracy = 1.0 val_acc = 0.8783
Epoch [39/40] step [22/64] loss = 0.009200721979141235 accuracy = 1.0 val_acc = 0.8826
Epoch [39/40] step [23/64] loss = 0.013823190703988075 accuracy = 1.0 val_acc = 0.8913
Epoch [39/40] step [24/64] loss = 0.1771210879087448 accuracy = 0.96875 val_acc = 0.8957
Epoch [39/40] step [25/64] loss = 0.07442275434732437 accuracy = 0.96875 val_acc = 0.9
Epoch [39/40] step [26/64] loss = 0.08327968418598175 accuracy = 0.96875 val_acc = 0.8826
Epoch [39/40] step [27/64] loss = 0.

Epoch [40/40] step [46/64] loss = 0.46507540345191956 accuracy = 0.84375 val_acc = 0.8391
Epoch [40/40] step [47/64] loss = 0.01914723590016365 accuracy = 1.0 val_acc = 0.9043
Epoch [40/40] step [48/64] loss = 0.03192981332540512 accuracy = 1.0 val_acc = 0.8783
Epoch [40/40] step [49/64] loss = 0.04738268628716469 accuracy = 1.0 val_acc = 0.8739
Epoch [40/40] step [50/64] loss = 0.10124287009239197 accuracy = 0.9375 val_acc = 0.8609
Epoch [40/40] step [51/64] loss = 0.10047164559364319 accuracy = 1.0 val_acc = 0.8565
Epoch [40/40] step [52/64] loss = 0.09173228591680527 accuracy = 0.96875 val_acc = 0.8739
Epoch [40/40] step [53/64] loss = 0.042511604726314545 accuracy = 1.0 val_acc = 0.887
Epoch [40/40] step [54/64] loss = 0.038880567997694016 accuracy = 1.0 val_acc = 0.887
Epoch [40/40] step [55/64] loss = 0.06810162961483002 accuracy = 0.96875 val_acc = 0.8957
Epoch [40/40] step [56/64] loss = 0.036480486392974854 accuracy = 1.0 val_acc = 0.8783
Epoch [40/40] step [57/64] loss = 0.02

In [None]:
# NOTE(review): this cell has no execution count (In [None]), i.e. it had not
# been run when the notebook was saved.
# Move the full (unsplit) data and labels to the GPU.
final_dataset = final_dataset.cuda()
final_label = final_label.cuda()
# NOTE(review): `sys` is not used by any other visible cell.
import sys

In [35]:
# Accuracy on one batch-sized slice of the held-out set, starting at row `s`.
s = 198
# Inference only: disable autograd so no graph is kept for the forward pass.
with torch.no_grad():
    slice_pred = torch.max(model(dataset_test[s:s + batch_size])[0], dim=1)[1]
# Tensor.sum() replaces the builtin sum(), which iterates element by element.
slice_acc = (slice_pred == label_test[s:s + batch_size]).sum().item() / len(dataset_test[s:s + batch_size])
slice_acc

0.9375

In [23]:
# Persist only the learned parameters (state_dict), not the whole module object.
torch.save(model.state_dict(), "PoemClassify.pth")

In [30]:
# Number of held-out samples.
len(dataset_test)

230