Grid search (choose the parameters to try, e.g. lr and dropout)
Early stopping

Deactivating callbacks can be especially useful when you do a parameter search (say with sklearn GridSearchCV). If, for instance, you use a callback for learning rate scheduling (e.g. via LRScheduler) and want to test its usefulness, you can compare the performance once with and once without the callback.


implement early stopping

implement gridSearch

implement space_cnn

implement long core for cnn


# Choose Signal and Import Data

In [5]:
import os
import numpy as np
import scipy.io

In [6]:
# Modulation schemes to classify. A scheme's position in this list is used
# as its integer class label when the data set is built.
mods = ['BPSK', 'DQPSK', 'GFSK', 'GMSK', 'OQPSK',
        'PAM4', 'PAM8', 'PSK8', 'QAM16', 'QAM64', 'QPSK']
class_num = len(mods)  # number of target classes

In [4]:
# Load the recorded signals with scipy's MATLAB-file reader.
# NOTE(review): the path is a hard-coded absolute location on a Windows
# drive; adjust it before running on another machine.
data = scipy.io.loadmat(
    "D:/Archive/0006/"
    "batch100000_symbols128_sps8_baud1_snr5.dat",
)

In [7]:
def import_from_mat(data, size, mod_names=None):
    """Build a feature matrix and a label column from per-modulation signals.

    For every modulation name, the first ``size`` rows of ``data[name]`` are
    taken and their real and imaginary parts are concatenated along axis 1
    (real part first, imaginary part second). The blocks of all modulations
    are then stacked along axis 0.

    Args:
        data: Mapping from modulation name to a 2-D complex-valued array
            whose first axis indexes the samples.
        size: Maximum number of rows to take per modulation.
        mod_names: Sequence of modulation names to import. A name's position
            in the sequence is its class label. Defaults to the module-level
            ``mods`` list.

    Returns:
        A ``(features, labels)`` tuple. ``features`` holds one row per
        selected sample; ``labels`` is a float array of shape ``(n_rows, 1)``
        with the class index of each row.
    """
    if mod_names is None:
        mod_names = mods

    features = []
    labels = []
    for label, mod in enumerate(mod_names):
        samples = np.asarray(data[mod])[:size]
        signal = np.concatenate([samples.real, samples.imag], axis=1)
        features.append(signal)
        # Size the labels from the rows actually taken, so features and
        # labels stay aligned when a modulation has fewer than `size` rows.
        labels.append(np.full((signal.shape[0], 1), label, dtype=np.float64))

    features = np.concatenate(features, axis=0)
    labels = np.concatenate(labels, axis=0)

    return features, labels

In [8]:
# Keep the first 1000 rows of every modulation.
features, labels = import_from_mat(data,1000)

In [9]:
# Cast to fixed dtypes: float32 features and int64 labels
# (presumably the dtypes the torch model and its loss expect — confirm).
features = features.astype(np.float32)
labels = labels.astype(np.int64)

In [10]:
X = features
y = labels.reshape(-1)  # flatten the label column into a 1-D vector

# Define Model

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [12]:
class Discriminator(nn.Module):
    """1-D convolutional classifier over two-channel signals.

    The network applies two Conv1d/BatchNorm/ReLU stages followed by two
    fully connected stages and returns a softmax over the classes.

    Args:
        num_classes: Number of output classes. Defaults to the module-level
            ``class_num``.
        signal_len: Number of samples per channel; it fixes the input width
            of the first fully connected layer. Defaults to 1024.
        dropout: Dropout probability used after the first fully connected
            layer. Defaults to 0.6.
    """

    def __init__(self, num_classes=None, signal_len=1024, dropout=0.6):
        super().__init__()
        if num_classes is None:
            # Fall back to the class count defined at module level.
            num_classes = class_num

        self.conv1 = nn.Sequential(
            nn.Conv1d(2, 256, 3, padding=1),  # -> (batch, 256, signal_len)
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(256, 80, 3, padding=1),  # -> (batch, 80, signal_len)
            nn.BatchNorm1d(80),
            nn.ReLU(),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(80 * signal_len, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )
        # NOTE(review): the ReLU after the final linear layer clamps negative
        # scores to zero before the softmax. Kept as in the original;
        # confirm that this is intended.
        self.fc2 = nn.Sequential(
            nn.Linear(256, num_classes),
            nn.ReLU(),
        )

    def forward(self, x, **kwargs):
        """Return class probabilities for a batch of flattened signals.

        Args:
            x: Tensor whose first axis is the batch. Each sample is reshaped
                into two channels, so the first half of a flattened row
                becomes channel 0 and the second half channel 1.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(batch, num_classes)`` whose rows sum to 1.
        """
        x = x.reshape((x.size(0), 2, -1))
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten channels and time for fc1
        x = self.fc1(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)

# Define Classifier and Callbacks
The callbacks are used to compute the scores and print the training progress.

In [13]:
from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring, Callback, Checkpoint
from sklearn.metrics import confusion_matrix
from skorch.utils import data_from_dataset

In [14]:
class Score_ConfusionMatrix(EpochScoring):
    """Epoch scoring callback that also records a validation confusion matrix.

    After the regular ``EpochScoring`` work is done, the net predicts on the
    validation data and the resulting confusion matrix is stored in the
    current history row under the key ``"confusion_matrix"``.
    """

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        """Run the parent scoring, then record the confusion matrix.

        Args:
            net: The net being trained; used for prediction and for its
                ``history``.
            dataset_train: Training dataset, passed through to the parent.
            dataset_valid: Validation dataset. When it is ``None`` no
                confusion matrix is recorded.
            **kwargs: Extra callback arguments, passed through to the parent.
        """
        # Forward the extra arguments so the parent sees the full call.
        super().on_epoch_end(net, dataset_train, dataset_valid, **kwargs)

        # Without validation data (e.g. a net configured with no validation
        # split) there is nothing to predict on.
        if dataset_valid is None:
            return

        X_valid, y_valid = data_from_dataset(dataset_valid)
        y_pred = net.predict(X_valid)
        net.history.record("confusion_matrix", confusion_matrix(y_valid, y_pred))

In [32]:
class Print_Score_CM(Callback):
    """Print per-batch progress and a per-epoch summary with confusion matrix.

    Args:
        size: Total number of samples, used as the denominator of the
            ``processed`` fraction. Training and validation batches both
            count towards the numerator.
        print_every: Print a progress line every ``print_every`` batches.
            Must be a positive integer. Defaults to 1 (every batch).
    """

    def __init__(self, size, print_every=1):
        self.size = size
        self.print_every = print_every
        self.batch_num = 0
        self.sample_num = 0

    def on_batch_end(self, net, **kwargs):
        """Update the counters and print the loss of the batch just finished.

        NOTE(review): relies on the caller passing the batch inputs as the
        ``X`` keyword and a ``training`` flag; confirm against the installed
        skorch version.
        """
        self.batch_num += 1
        self.sample_num += kwargs['X'].shape[0]
        if self.batch_num % self.print_every != 0:
            return

        percent = self.sample_num/self.size
        history = net.history
        # The history keys differ only by this prefix between the two phases.
        phase = "train" if kwargs["training"] else "valid"
        print('processed: {0:.4f}, '
              '{1}_batch: {2}, '
              '{1}_loss:{3:.4f}'.format(
                  percent,
                  phase,
                  history[-1, "batches", -1, phase + "_batch_size"],
                  history[-1, "batches", -1, phase + "_loss"]
              ))

    def on_epoch_end(self, net, **kwargs):
        """Reset the counters and print the summary of the finished epoch.

        Reads ``epoch``, ``dur``, ``accuracy``, ``accuracy_best``,
        ``valid_loss``, ``valid_loss_best``, ``event_cp`` and
        ``confusion_matrix`` from the last history row, so the callbacks that
        record those keys must have run before this one.
        """
        self.batch_num = 0
        self.sample_num = 0
        history = net.history
        print("epoch: {0}, "
              "dur: {1:.2f}s, "
              "val_acc: {2:.4f}{3}, "
              "val_loss: {4:.4f}{5}, "
              "saved: {6}".format(
                  history[-1, "epoch"],
                  history[-1, "dur"],
                  history[-1, "accuracy"],
                  "(best)" if history[-1, "accuracy_best"] else "",
                  history[-1, "valid_loss"],
                  "(best)" if history[-1, "valid_loss_best"] else "",
                  history[-1, "event_cp"]
              ))
        print("Confusion matrix:\n {0}".format(
            history[-1, "confusion_matrix"]
        ))

In [33]:
# Checkpoint callback writing into the 'best' directory.
cp = Checkpoint(dirname='best')
net = NeuralNetClassifier(
    Discriminator,
    max_epochs=20,
    lr=0.01,
    # Use the GPU when one is available, otherwise fall back to the CPU so
    # the notebook still runs on machines without CUDA.
    device='cuda' if torch.cuda.is_available() else 'cpu',
    callbacks=[('best',cp)],
    iterator_train__shuffle=True,
    iterator_valid__shuffle=False
)

# Accuracy scoring (higher is better) that also records a confusion matrix.
score = Score_ConfusionMatrix(scoring="accuracy", lower_is_better=False)
# Progress printer; the total sample count is the denominator of "processed".
pt = Print_Score_CM(X.shape[0])

# Install the custom callbacks under the names 'valid_acc' and 'print_log'.
net.set_params(callbacks__valid_acc=score)
net.set_params(callbacks__print_log=pt)


<class 'skorch.classifier.NeuralNetClassifier'>[uninitialized](
  module=<class '__main__.Discriminator'>,
)

# Train Classifier

In [34]:
# Train the classifier; the progress lines below come from the print_log callback.
net.fit(X,y)

processed: 0.0116, train_batch: 128, train_loss:2.4105
processed: 0.0233, train_batch: 128, train_loss:2.2767
processed: 0.0349, train_batch: 128, train_loss:2.1414
processed: 0.0465, train_batch: 128, train_loss:2.0217
processed: 0.0582, train_batch: 128, train_loss:1.9109
processed: 0.0698, train_batch: 128, train_loss:1.9174
processed: 0.0815, train_batch: 128, train_loss:1.9534
processed: 0.0931, train_batch: 128, train_loss:1.8866
processed: 0.1047, train_batch: 128, train_loss:1.8917
processed: 0.1164, train_batch: 128, train_loss:1.9158
processed: 0.1280, train_batch: 128, train_loss:1.8214
processed: 0.1396, train_batch: 128, train_loss:1.7970
processed: 0.1513, train_batch: 128, train_loss:1.7986
processed: 0.1629, train_batch: 128, train_loss:1.6825
processed: 0.1745, train_batch: 128, train_loss:1.6705
processed: 0.1862, train_batch: 128, train_loss:1.6144
processed: 0.1978, train_batch: 128, train_loss:1.8149
processed: 0.2095, train_batch: 128, train_loss:1.5611
processed:

processed: 0.5935, train_batch: 128, train_loss:0.9130
processed: 0.6051, train_batch: 128, train_loss:0.8682
processed: 0.6167, train_batch: 128, train_loss:1.0253
processed: 0.6284, train_batch: 128, train_loss:0.8996
processed: 0.6400, train_batch: 128, train_loss:0.9039
processed: 0.6516, train_batch: 128, train_loss:0.9317
processed: 0.6633, train_batch: 128, train_loss:1.0310
processed: 0.6749, train_batch: 128, train_loss:1.0097
processed: 0.6865, train_batch: 128, train_loss:0.9063
processed: 0.6982, train_batch: 128, train_loss:0.9246
processed: 0.7098, train_batch: 128, train_loss:0.9216
processed: 0.7215, train_batch: 128, train_loss:0.8086
processed: 0.7331, train_batch: 128, train_loss:0.8693
processed: 0.7447, train_batch: 128, train_loss:1.0695
processed: 0.7564, train_batch: 128, train_loss:0.9365
processed: 0.7680, train_batch: 128, train_loss:1.0242
processed: 0.7796, train_batch: 128, train_loss:0.8350
processed: 0.7913, train_batch: 128, train_loss:1.0278
processed:

processed: 0.0349, train_batch: 128, train_loss:0.6356
processed: 0.0465, train_batch: 128, train_loss:0.6657
processed: 0.0582, train_batch: 128, train_loss:0.6457
processed: 0.0698, train_batch: 128, train_loss:0.6219
processed: 0.0815, train_batch: 128, train_loss:0.6704
processed: 0.0931, train_batch: 128, train_loss:0.6253
processed: 0.1047, train_batch: 128, train_loss:0.6675
processed: 0.1164, train_batch: 128, train_loss:0.6486
processed: 0.1280, train_batch: 128, train_loss:0.8034
processed: 0.1396, train_batch: 128, train_loss:0.6971
processed: 0.1513, train_batch: 128, train_loss:0.6493
processed: 0.1629, train_batch: 128, train_loss:0.5789
processed: 0.1745, train_batch: 128, train_loss:0.5804
processed: 0.1862, train_batch: 128, train_loss:0.7363
processed: 0.1978, train_batch: 128, train_loss:0.6381
processed: 0.2095, train_batch: 128, train_loss:0.6477
processed: 0.2211, train_batch: 128, train_loss:0.5806
processed: 0.2327, train_batch: 128, train_loss:0.6759
processed:

processed: 0.6167, train_batch: 128, train_loss:0.5390
processed: 0.6284, train_batch: 128, train_loss:0.5357
processed: 0.6400, train_batch: 128, train_loss:0.5173
processed: 0.6516, train_batch: 128, train_loss:0.5049
processed: 0.6633, train_batch: 128, train_loss:0.6019
processed: 0.6749, train_batch: 128, train_loss:0.5122
processed: 0.6865, train_batch: 128, train_loss:0.5314
processed: 0.6982, train_batch: 128, train_loss:0.4668
processed: 0.7098, train_batch: 128, train_loss:0.5055
processed: 0.7215, train_batch: 128, train_loss:0.4443
processed: 0.7331, train_batch: 128, train_loss:0.4294
processed: 0.7447, train_batch: 128, train_loss:0.5334
processed: 0.7564, train_batch: 128, train_loss:0.4252
processed: 0.7680, train_batch: 128, train_loss:0.4284
processed: 0.7796, train_batch: 128, train_loss:0.4646
processed: 0.7913, train_batch: 128, train_loss:0.5556
processed: 0.8000, train_batch: 96, train_loss:0.4524
processed: 0.8116, valid_batch: 128, valid_loss:0.3709
processed: 

processed: 0.0582, train_batch: 128, train_loss:0.4453
processed: 0.0698, train_batch: 128, train_loss:0.3851
processed: 0.0815, train_batch: 128, train_loss:0.3826
processed: 0.0931, train_batch: 128, train_loss:0.4758
processed: 0.1047, train_batch: 128, train_loss:0.4590
processed: 0.1164, train_batch: 128, train_loss:0.3361
processed: 0.1280, train_batch: 128, train_loss:0.4143
processed: 0.1396, train_batch: 128, train_loss:0.3700
processed: 0.1513, train_batch: 128, train_loss:0.4177
processed: 0.1629, train_batch: 128, train_loss:0.3916
processed: 0.1745, train_batch: 128, train_loss:0.4021
processed: 0.1862, train_batch: 128, train_loss:0.3602
processed: 0.1978, train_batch: 128, train_loss:0.3917
processed: 0.2095, train_batch: 128, train_loss:0.4535
processed: 0.2211, train_batch: 128, train_loss:0.3893
processed: 0.2327, train_batch: 128, train_loss:0.4310
processed: 0.2444, train_batch: 128, train_loss:0.3561
processed: 0.2560, train_batch: 128, train_loss:0.3577
processed:

processed: 0.6400, train_batch: 128, train_loss:0.3283
processed: 0.6516, train_batch: 128, train_loss:0.3274
processed: 0.6633, train_batch: 128, train_loss:0.3013
processed: 0.6749, train_batch: 128, train_loss:0.3342
processed: 0.6865, train_batch: 128, train_loss:0.3244
processed: 0.6982, train_batch: 128, train_loss:0.3310
processed: 0.7098, train_batch: 128, train_loss:0.2903
processed: 0.7215, train_batch: 128, train_loss:0.3333
processed: 0.7331, train_batch: 128, train_loss:0.2878
processed: 0.7447, train_batch: 128, train_loss:0.3992
processed: 0.7564, train_batch: 128, train_loss:0.3479
processed: 0.7680, train_batch: 128, train_loss:0.2796
processed: 0.7796, train_batch: 128, train_loss:0.3394
processed: 0.7913, train_batch: 128, train_loss:0.3186
processed: 0.8000, train_batch: 96, train_loss:0.3822
processed: 0.8116, valid_batch: 128, valid_loss:0.1366
processed: 0.8233, valid_batch: 128, valid_loss:0.7778
processed: 0.8349, valid_batch: 128, valid_loss:1.6181
processed: 

processed: 0.0815, train_batch: 128, train_loss:0.3055
processed: 0.0931, train_batch: 128, train_loss:0.2380
processed: 0.1047, train_batch: 128, train_loss:0.2404
processed: 0.1164, train_batch: 128, train_loss:0.2592
processed: 0.1280, train_batch: 128, train_loss:0.3135
processed: 0.1396, train_batch: 128, train_loss:0.2379
processed: 0.1513, train_batch: 128, train_loss:0.2033
processed: 0.1629, train_batch: 128, train_loss:0.2137
processed: 0.1745, train_batch: 128, train_loss:0.2124
processed: 0.1862, train_batch: 128, train_loss:0.2056
processed: 0.1978, train_batch: 128, train_loss:0.2529
processed: 0.2095, train_batch: 128, train_loss:0.2276
processed: 0.2211, train_batch: 128, train_loss:0.3242
processed: 0.2327, train_batch: 128, train_loss:0.2758
processed: 0.2444, train_batch: 128, train_loss:0.2331
processed: 0.2560, train_batch: 128, train_loss:0.3272
processed: 0.2676, train_batch: 128, train_loss:0.2470
processed: 0.2793, train_batch: 128, train_loss:0.1988
processed:

processed: 0.6633, train_batch: 128, train_loss:0.2710
processed: 0.6749, train_batch: 128, train_loss:0.1818
processed: 0.6865, train_batch: 128, train_loss:0.2378
processed: 0.6982, train_batch: 128, train_loss:0.2412
processed: 0.7098, train_batch: 128, train_loss:0.2062
processed: 0.7215, train_batch: 128, train_loss:0.2049
processed: 0.7331, train_batch: 128, train_loss:0.2187
processed: 0.7447, train_batch: 128, train_loss:0.2222
processed: 0.7564, train_batch: 128, train_loss:0.1972
processed: 0.7680, train_batch: 128, train_loss:0.2459
processed: 0.7796, train_batch: 128, train_loss:0.1983
processed: 0.7913, train_batch: 128, train_loss:0.1572
processed: 0.8000, train_batch: 96, train_loss:0.2228
processed: 0.8116, valid_batch: 128, valid_loss:0.7541
processed: 0.8233, valid_batch: 128, valid_loss:1.1303
processed: 0.8349, valid_batch: 128, valid_loss:1.6345
processed: 0.8465, valid_batch: 128, valid_loss:0.2020
processed: 0.8582, valid_batch: 128, valid_loss:0.0250
processed: 

processed: 0.1047, train_batch: 128, train_loss:0.1342
processed: 0.1164, train_batch: 128, train_loss:0.1482
processed: 0.1280, train_batch: 128, train_loss:0.1187
processed: 0.1396, train_batch: 128, train_loss:0.1206
processed: 0.1513, train_batch: 128, train_loss:0.1270
processed: 0.1629, train_batch: 128, train_loss:0.1501
processed: 0.1745, train_batch: 128, train_loss:0.1322
processed: 0.1862, train_batch: 128, train_loss:0.1732
processed: 0.1978, train_batch: 128, train_loss:0.1213
processed: 0.2095, train_batch: 128, train_loss:0.1273
processed: 0.2211, train_batch: 128, train_loss:0.1431
processed: 0.2327, train_batch: 128, train_loss:0.1450
processed: 0.2444, train_batch: 128, train_loss:0.1485
processed: 0.2560, train_batch: 128, train_loss:0.1407
processed: 0.2676, train_batch: 128, train_loss:0.1360
processed: 0.2793, train_batch: 128, train_loss:0.1288
processed: 0.2909, train_batch: 128, train_loss:0.1507
processed: 0.3025, train_batch: 128, train_loss:0.1689
processed:

processed: 0.6865, train_batch: 128, train_loss:0.0949
processed: 0.6982, train_batch: 128, train_loss:0.0917
processed: 0.7098, train_batch: 128, train_loss:0.1240
processed: 0.7215, train_batch: 128, train_loss:0.1291
processed: 0.7331, train_batch: 128, train_loss:0.1073
processed: 0.7447, train_batch: 128, train_loss:0.1190
processed: 0.7564, train_batch: 128, train_loss:0.1297
processed: 0.7680, train_batch: 128, train_loss:0.1264
processed: 0.7796, train_batch: 128, train_loss:0.1545
processed: 0.7913, train_batch: 128, train_loss:0.1061
processed: 0.8000, train_batch: 96, train_loss:0.1004
processed: 0.8116, valid_batch: 128, valid_loss:0.0578
processed: 0.8233, valid_batch: 128, valid_loss:0.6692
processed: 0.8349, valid_batch: 128, valid_loss:1.4019
processed: 0.8465, valid_batch: 128, valid_loss:0.1679
processed: 0.8582, valid_batch: 128, valid_loss:0.0183
processed: 0.8698, valid_batch: 128, valid_loss:0.0201
processed: 0.8815, valid_batch: 128, valid_loss:0.1550
processed: 

processed: 0.1280, train_batch: 128, train_loss:0.0834
processed: 0.1396, train_batch: 128, train_loss:0.1155
processed: 0.1513, train_batch: 128, train_loss:0.0970
processed: 0.1629, train_batch: 128, train_loss:0.0837
processed: 0.1745, train_batch: 128, train_loss:0.1136
processed: 0.1862, train_batch: 128, train_loss:0.0918
processed: 0.1978, train_batch: 128, train_loss:0.1090
processed: 0.2095, train_batch: 128, train_loss:0.0978
processed: 0.2211, train_batch: 128, train_loss:0.0935
processed: 0.2327, train_batch: 128, train_loss:0.0910
processed: 0.2444, train_batch: 128, train_loss:0.0929
processed: 0.2560, train_batch: 128, train_loss:0.1010
processed: 0.2676, train_batch: 128, train_loss:0.0929
processed: 0.2793, train_batch: 128, train_loss:0.1003
processed: 0.2909, train_batch: 128, train_loss:0.1079
processed: 0.3025, train_batch: 128, train_loss:0.1047
processed: 0.3142, train_batch: 128, train_loss:0.0736
processed: 0.3258, train_batch: 128, train_loss:0.0944
processed:

processed: 0.7098, train_batch: 128, train_loss:0.0847
processed: 0.7215, train_batch: 128, train_loss:0.0623
processed: 0.7331, train_batch: 128, train_loss:0.0725
processed: 0.7447, train_batch: 128, train_loss:0.0853
processed: 0.7564, train_batch: 128, train_loss:0.0967
processed: 0.7680, train_batch: 128, train_loss:0.0930
processed: 0.7796, train_batch: 128, train_loss:0.0655
processed: 0.7913, train_batch: 128, train_loss:0.0844
processed: 0.8000, train_batch: 96, train_loss:0.0678
processed: 0.8116, valid_batch: 128, valid_loss:0.0478
processed: 0.8233, valid_batch: 128, valid_loss:0.7678
processed: 0.8349, valid_batch: 128, valid_loss:1.6092
processed: 0.8465, valid_batch: 128, valid_loss:0.1858
processed: 0.8582, valid_batch: 128, valid_loss:0.0120
processed: 0.8698, valid_batch: 128, valid_loss:0.0140
processed: 0.8815, valid_batch: 128, valid_loss:0.1307
processed: 0.8931, valid_batch: 128, valid_loss:0.1178
processed: 0.9047, valid_batch: 128, valid_loss:0.0144
processed: 

processed: 0.1513, train_batch: 128, train_loss:0.0670
processed: 0.1629, train_batch: 128, train_loss:0.0639
processed: 0.1745, train_batch: 128, train_loss:0.0950
processed: 0.1862, train_batch: 128, train_loss:0.0819
processed: 0.1978, train_batch: 128, train_loss:0.0737
processed: 0.2095, train_batch: 128, train_loss:0.0673
processed: 0.2211, train_batch: 128, train_loss:0.0684
processed: 0.2327, train_batch: 128, train_loss:0.0649
processed: 0.2444, train_batch: 128, train_loss:0.0809
processed: 0.2560, train_batch: 128, train_loss:0.1041
processed: 0.2676, train_batch: 128, train_loss:0.0712
processed: 0.2793, train_batch: 128, train_loss:0.0552
processed: 0.2909, train_batch: 128, train_loss:0.0915
processed: 0.3025, train_batch: 128, train_loss:0.0596
processed: 0.3142, train_batch: 128, train_loss:0.0569
processed: 0.3258, train_batch: 128, train_loss:0.0582
processed: 0.3375, train_batch: 128, train_loss:0.0795
processed: 0.3491, train_batch: 128, train_loss:0.0833
processed:

processed: 0.7331, train_batch: 128, train_loss:0.0703
processed: 0.7447, train_batch: 128, train_loss:0.0648
processed: 0.7564, train_batch: 128, train_loss:0.0551
processed: 0.7680, train_batch: 128, train_loss:0.0966
processed: 0.7796, train_batch: 128, train_loss:0.0736
processed: 0.7913, train_batch: 128, train_loss:0.0575
processed: 0.8000, train_batch: 96, train_loss:0.0487
processed: 0.8116, valid_batch: 128, valid_loss:0.0360
processed: 0.8233, valid_batch: 128, valid_loss:0.5970
processed: 0.8349, valid_batch: 128, valid_loss:1.2452
processed: 0.8465, valid_batch: 128, valid_loss:0.1389
processed: 0.8582, valid_batch: 128, valid_loss:0.0090
processed: 0.8698, valid_batch: 128, valid_loss:0.0122
processed: 0.8815, valid_batch: 128, valid_loss:0.1551
processed: 0.8931, valid_batch: 128, valid_loss:0.1526
processed: 0.9047, valid_batch: 128, valid_loss:0.0112
processed: 0.9164, valid_batch: 128, valid_loss:0.0092
processed: 0.9280, valid_batch: 128, valid_loss:0.1730
processed: 

<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=Discriminator(
    (conv1): Sequential(
      (0): Conv1d(2, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv2): Sequential(
      (0): Conv1d(256, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (fc1): Sequential(
      (0): Linear(in_features=81920, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.6)
    )
    (fc2): Sequential(
      (0): Linear(in_features=256, out_features=11, bias=True)
      (1): ReLU()
    )
  ),
)

In [21]:
# Display the recorded training history (per-epoch entries with their batches).
net.history

[{'accuracy': 0.635,
  'accuracy_best': True,
  'batches': [{'train_batch_size': 128, 'train_loss': 2.402630567550659},
   {'train_batch_size': 128, 'train_loss': 2.276116132736206},
   {'train_batch_size': 128, 'train_loss': 2.0675370693206787},
   {'train_batch_size': 128, 'train_loss': 1.944294810295105},
   {'train_batch_size': 128, 'train_loss': 1.9311270713806152},
   {'train_batch_size': 128, 'train_loss': 1.9333882331848145},
   {'train_batch_size': 128, 'train_loss': 1.8721981048583984},
   {'train_batch_size': 128, 'train_loss': 1.8092904090881348},
   {'train_batch_size': 128, 'train_loss': 1.820557951927185},
   {'train_batch_size': 128, 'train_loss': 1.8004343509674072},
   {'train_batch_size': 128, 'train_loss': 1.8286545276641846},
   {'train_batch_size': 128, 'train_loss': 1.7345839738845825},
   {'train_batch_size': 128, 'train_loss': 1.7587316036224365},
   {'train_batch_size': 128, 'train_loss': 1.7093651294708252},
   {'train_batch_size': 128, 'train_loss': 1.699425