GridSearch (choose parameter to try, lr, dropout, max_epoch, patenice warm_start_param)
early stopping

Deactivating callbacks can be especially useful when you do a parameter search (say with sklearn GridSearchCV). If, for instance, you use a callback for learning rate scheduling (e.g. via LRScheduler) and want to test its usefulness, you can compare the performance once with and once without the callback.


implement early stopping

implement gridSearch

implement space_cnn

implement long core for cnn


# Choose Signal and Import Data

In [1]:
import os
import numpy as np
import scipy.io

In [2]:
mods = ['BPSK', 'DQPSK', 'GFSK', 'GMSK', 'OQPSK',
        'PAM4', 'PAM8', 'PSK8', 'QAM16', 'QAM64', 'QPSK']
class_num = len(mods)

In [3]:
data = scipy.io.loadmat(
    "D:/Archive/0006/"
    "batch100000_symbols128_sps8_baud1_snr5.dat",
)

In [4]:
def import_from_mat(data, size):
    features = []
    labels = []
    for mod in mods:
        real = np.array(data[mod].real[:size])
        imag = np.array(data[mod].imag[:size])
        signal = np.concatenate([real, imag], axis=1)
        features.append(signal)
        labels.append(mods.index(mod) * np.ones([size, 1]))

    features = np.concatenate(features, axis=0)
    labels = np.concatenate(labels, axis=0)
    
    return features, labels

In [5]:
features, labels = import_from_mat(data,1000)

In [6]:
features = features.astype(np.float32)
labels = labels.astype(np.int64)

In [7]:
X = features
y = labels.reshape(-1)

# Define Model

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [9]:
class Discriminator(nn.Module):
    """Define the model"""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(2, 256, 3, padding=1),  # batch, 256, 1024
            nn.BatchNorm1d(256),
            nn.ReLU(),
            # nn.Dropout2d()
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(256, 80, 3, padding=1),  # batch, 80, 1024
            nn.BatchNorm1d(80),
            nn.ReLU(),
            # nn.Dropout2d()
        )
        self.fc1 = nn.Sequential(
            nn.Linear(80 * 1024, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=0.6)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(256, class_num),
            nn.ReLU()
        )
        
    def forward(self, x, **kwargs):
        x = x.reshape((x.size(0), 2, -1))
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = F.softmax(x, dim=1)
        return x

# Define Classifier and Callback. 
The Callbacks are used to calculate score and print train process

In [10]:
from skorch import NeuralNetClassifier
from skorch.callbacks import Callback, EpochScoring, Checkpoint, EarlyStopping
from sklearn.metrics import confusion_matrix
from skorch.utils import data_from_dataset

In [11]:
class Score_ConfusionMatrix(EpochScoring):
    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        
        EpochScoring.on_epoch_end(self, net, dataset_train, dataset_valid)
        
        X_test, y_test = data_from_dataset(dataset_valid)
        y_pred = net.predict(X_test)
        cm = confusion_matrix(y_test, y_pred)
        history = net.history
        history.record("confusion_matrix", cm)

In [12]:
class Print_Score_CM(Callback):
    
    def __init__(self, size):
        self.batch_num = 0
        self.sample_num = 0
        self.size = size
        
    def on_batch_end(self, net, **kwargs):
        self.batch_num += 1
        self.sample_num += kwargs['X'].shape[0]
        if self.batch_num % 1 == 0:
            percent = self.sample_num/self.size
            history = net.history
            training = kwargs["training"]
            if training:
                print('processed: {0:.4f}, ' 
                      'train_batch: {1}, '
                      'train_loss:{2:.4f}'.format(
                          percent,
                          history[-1, "batches", -1, "train_batch_size"],
                          history[-1, "batches", -1, "train_loss"]
                      ))
            else:
                print('processed: {0:.4f}, ' 
                      'valid_batch: {1}, '
                      'valid_loss:{2:.4f}'.format(
                          percent,
                          history[-1, "batches", -1, "valid_batch_size"],
                          history[-1, "batches", -1, "valid_loss"]
                      ))
    def on_epoch_end(self, net, **kwargs):
        self.batch_num = 0
        self.sample_num = 0
        history = net.history
        result = history[-1].copy()
        result.pop("confusion_matrix")
        result.pop("batches")
        print("epoch: {0}, "
              "dur: {1:.2f}s, "
              "val_acc: {2:.4f}{3}, "
              "val_loss: {4:.4f}{5}, "
              "saved: {6}".format(
                  history[-1, "epoch"],
                  history[-1, "dur"],
                  history[-1, "accuracy"],
                  "(best)" if history[-1, "accuracy_best"] else "",
                  history[-1, "valid_loss"],
                  "(best)" if history[-1, "valid_loss_best"] else "",
                  history[-1, "event_cp"]
              ))
        print("Confusion matrix:\n {0}".format(
            history[-1, "confusion_matrix"]
        ))

In [13]:
cp = Checkpoint(dirname='best')
early_stop = EarlyStopping(patience=10)
net = NeuralNetClassifier(
    Discriminator,
    max_epochs=20,
    lr=0.01,
    device='cuda',
    callbacks=[('best',cp),
               ('early', early_stop)
              ],
    iterator_train__shuffle=True,
    iterator_valid__shuffle=False
)

score = Score_ConfusionMatrix(scoring="accuracy", lower_is_better=False)
pt = Print_Score_CM(X.shape[0])

net.set_params(callbacks__valid_acc=score)
net.set_params(callbacks__print_log=pt)


<class 'skorch.classifier.NeuralNetClassifier'>[uninitialized](
  module=<class '__main__.Discriminator'>,
)

# Train Classifier

In [14]:
net.fit(X,y)

processed: 0.0116, train_batch: 128, train_loss:2.4394
processed: 0.0233, train_batch: 128, train_loss:2.2913
processed: 0.0349, train_batch: 128, train_loss:2.1163
processed: 0.0465, train_batch: 128, train_loss:2.0574
processed: 0.0582, train_batch: 128, train_loss:1.9575
processed: 0.0698, train_batch: 128, train_loss:1.9119
processed: 0.0815, train_batch: 128, train_loss:1.8276
processed: 0.0931, train_batch: 128, train_loss:1.8506
processed: 0.1047, train_batch: 128, train_loss:1.8279
processed: 0.1164, train_batch: 128, train_loss:1.7150
processed: 0.1280, train_batch: 128, train_loss:1.7522
processed: 0.1396, train_batch: 128, train_loss:1.8290
processed: 0.1513, train_batch: 128, train_loss:1.7282
processed: 0.1629, train_batch: 128, train_loss:1.6961
processed: 0.1745, train_batch: 128, train_loss:1.6559
processed: 0.1862, train_batch: 128, train_loss:1.7497
processed: 0.1978, train_batch: 128, train_loss:1.7387
processed: 0.2095, train_batch: 128, train_loss:1.7077
processed:

processed: 0.5935, train_batch: 128, train_loss:1.0085
processed: 0.6051, train_batch: 128, train_loss:1.0766
processed: 0.6167, train_batch: 128, train_loss:1.0604
processed: 0.6284, train_batch: 128, train_loss:1.1555
processed: 0.6400, train_batch: 128, train_loss:0.9532
processed: 0.6516, train_batch: 128, train_loss:1.0461
processed: 0.6633, train_batch: 128, train_loss:0.9966
processed: 0.6749, train_batch: 128, train_loss:0.9487
processed: 0.6865, train_batch: 128, train_loss:0.9812
processed: 0.6982, train_batch: 128, train_loss:0.9808
processed: 0.7098, train_batch: 128, train_loss:1.0799
processed: 0.7215, train_batch: 128, train_loss:0.9886
processed: 0.7331, train_batch: 128, train_loss:1.0997
processed: 0.7447, train_batch: 128, train_loss:0.9527
processed: 0.7564, train_batch: 128, train_loss:0.9885
processed: 0.7680, train_batch: 128, train_loss:1.0152
processed: 0.7796, train_batch: 128, train_loss:0.9690
processed: 0.7913, train_batch: 128, train_loss:1.1078
processed:

processed: 0.0349, train_batch: 128, train_loss:0.7724
processed: 0.0465, train_batch: 128, train_loss:0.6793
processed: 0.0582, train_batch: 128, train_loss:0.6883
processed: 0.0698, train_batch: 128, train_loss:0.8232
processed: 0.0815, train_batch: 128, train_loss:0.7435
processed: 0.0931, train_batch: 128, train_loss:0.8040
processed: 0.1047, train_batch: 128, train_loss:0.8419
processed: 0.1164, train_batch: 128, train_loss:0.6994
processed: 0.1280, train_batch: 128, train_loss:0.8126
processed: 0.1396, train_batch: 128, train_loss:0.6827
processed: 0.1513, train_batch: 128, train_loss:0.7330
processed: 0.1629, train_batch: 128, train_loss:0.7242
processed: 0.1745, train_batch: 128, train_loss:0.8800
processed: 0.1862, train_batch: 128, train_loss:0.7780
processed: 0.1978, train_batch: 128, train_loss:0.7915
processed: 0.2095, train_batch: 128, train_loss:0.7397
processed: 0.2211, train_batch: 128, train_loss:0.8576
processed: 0.2327, train_batch: 128, train_loss:0.7223
processed:

processed: 0.6167, train_batch: 128, train_loss:0.5457
processed: 0.6284, train_batch: 128, train_loss:0.5144
processed: 0.6400, train_batch: 128, train_loss:0.5821
processed: 0.6516, train_batch: 128, train_loss:0.5692
processed: 0.6633, train_batch: 128, train_loss:0.4974
processed: 0.6749, train_batch: 128, train_loss:0.5234
processed: 0.6865, train_batch: 128, train_loss:0.6128
processed: 0.6982, train_batch: 128, train_loss:0.5748
processed: 0.7098, train_batch: 128, train_loss:0.6172
processed: 0.7215, train_batch: 128, train_loss:0.5470
processed: 0.7331, train_batch: 128, train_loss:0.6500
processed: 0.7447, train_batch: 128, train_loss:0.5854
processed: 0.7564, train_batch: 128, train_loss:0.4937
processed: 0.7680, train_batch: 128, train_loss:0.5399
processed: 0.7796, train_batch: 128, train_loss:0.4763
processed: 0.7913, train_batch: 128, train_loss:0.5281
processed: 0.8000, train_batch: 96, train_loss:0.5810
processed: 0.8116, valid_batch: 128, valid_loss:0.2885
processed: 

processed: 0.0582, train_batch: 128, train_loss:0.3716
processed: 0.0698, train_batch: 128, train_loss:0.4608
processed: 0.0815, train_batch: 128, train_loss:0.4067
processed: 0.0931, train_batch: 128, train_loss:0.3852
processed: 0.1047, train_batch: 128, train_loss:0.4767
processed: 0.1164, train_batch: 128, train_loss:0.3571
processed: 0.1280, train_batch: 128, train_loss:0.4156
processed: 0.1396, train_batch: 128, train_loss:0.3635
processed: 0.1513, train_batch: 128, train_loss:0.4139
processed: 0.1629, train_batch: 128, train_loss:0.4830
processed: 0.1745, train_batch: 128, train_loss:0.4697
processed: 0.1862, train_batch: 128, train_loss:0.3730
processed: 0.1978, train_batch: 128, train_loss:0.4117
processed: 0.2095, train_batch: 128, train_loss:0.4353
processed: 0.2211, train_batch: 128, train_loss:0.3221
processed: 0.2327, train_batch: 128, train_loss:0.4205
processed: 0.2444, train_batch: 128, train_loss:0.3776
processed: 0.2560, train_batch: 128, train_loss:0.3581
processed:

processed: 0.6400, train_batch: 128, train_loss:0.2818
processed: 0.6516, train_batch: 128, train_loss:0.3165
processed: 0.6633, train_batch: 128, train_loss:0.2790
processed: 0.6749, train_batch: 128, train_loss:0.3616
processed: 0.6865, train_batch: 128, train_loss:0.2805
processed: 0.6982, train_batch: 128, train_loss:0.3875
processed: 0.7098, train_batch: 128, train_loss:0.3143
processed: 0.7215, train_batch: 128, train_loss:0.3047
processed: 0.7331, train_batch: 128, train_loss:0.3031
processed: 0.7447, train_batch: 128, train_loss:0.3760
processed: 0.7564, train_batch: 128, train_loss:0.2665
processed: 0.7680, train_batch: 128, train_loss:0.3043
processed: 0.7796, train_batch: 128, train_loss:0.2884
processed: 0.7913, train_batch: 128, train_loss:0.3576
processed: 0.8000, train_batch: 96, train_loss:0.2775
processed: 0.8116, valid_batch: 128, valid_loss:0.1118
processed: 0.8233, valid_batch: 128, valid_loss:0.6962
processed: 0.8349, valid_batch: 128, valid_loss:1.4312
processed: 

processed: 0.0815, train_batch: 128, train_loss:0.2523
processed: 0.0931, train_batch: 128, train_loss:0.2395
processed: 0.1047, train_batch: 128, train_loss:0.2204
processed: 0.1164, train_batch: 128, train_loss:0.2325
processed: 0.1280, train_batch: 128, train_loss:0.2302
processed: 0.1396, train_batch: 128, train_loss:0.2292
processed: 0.1513, train_batch: 128, train_loss:0.2607
processed: 0.1629, train_batch: 128, train_loss:0.2057
processed: 0.1745, train_batch: 128, train_loss:0.2145
processed: 0.1862, train_batch: 128, train_loss:0.2068
processed: 0.1978, train_batch: 128, train_loss:0.2323
processed: 0.2095, train_batch: 128, train_loss:0.2138
processed: 0.2211, train_batch: 128, train_loss:0.2108
processed: 0.2327, train_batch: 128, train_loss:0.2105
processed: 0.2444, train_batch: 128, train_loss:0.2923
processed: 0.2560, train_batch: 128, train_loss:0.2368
processed: 0.2676, train_batch: 128, train_loss:0.2515
processed: 0.2793, train_batch: 128, train_loss:0.2498
processed:

processed: 0.6633, train_batch: 128, train_loss:0.1866
processed: 0.6749, train_batch: 128, train_loss:0.1771
processed: 0.6865, train_batch: 128, train_loss:0.1976
processed: 0.6982, train_batch: 128, train_loss:0.1600
processed: 0.7098, train_batch: 128, train_loss:0.2062
processed: 0.7215, train_batch: 128, train_loss:0.1853
processed: 0.7331, train_batch: 128, train_loss:0.1458
processed: 0.7447, train_batch: 128, train_loss:0.2223
processed: 0.7564, train_batch: 128, train_loss:0.2157
processed: 0.7680, train_batch: 128, train_loss:0.1921
processed: 0.7796, train_batch: 128, train_loss:0.2083
processed: 0.7913, train_batch: 128, train_loss:0.2508
processed: 0.8000, train_batch: 96, train_loss:0.1628
processed: 0.8116, valid_batch: 128, valid_loss:0.0336
processed: 0.8233, valid_batch: 128, valid_loss:0.8085
processed: 0.8349, valid_batch: 128, valid_loss:1.7814
processed: 0.8465, valid_batch: 128, valid_loss:0.2260
processed: 0.8582, valid_batch: 128, valid_loss:0.0236
processed: 

processed: 0.1047, train_batch: 128, train_loss:0.1457
processed: 0.1164, train_batch: 128, train_loss:0.1401
processed: 0.1280, train_batch: 128, train_loss:0.1126
processed: 0.1396, train_batch: 128, train_loss:0.1303
processed: 0.1513, train_batch: 128, train_loss:0.1656
processed: 0.1629, train_batch: 128, train_loss:0.1482
processed: 0.1745, train_batch: 128, train_loss:0.1181
processed: 0.1862, train_batch: 128, train_loss:0.1421
processed: 0.1978, train_batch: 128, train_loss:0.1254
processed: 0.2095, train_batch: 128, train_loss:0.1385
processed: 0.2211, train_batch: 128, train_loss:0.1268
processed: 0.2327, train_batch: 128, train_loss:0.1690
processed: 0.2444, train_batch: 128, train_loss:0.1404
processed: 0.2560, train_batch: 128, train_loss:0.1567
processed: 0.2676, train_batch: 128, train_loss:0.1423
processed: 0.2793, train_batch: 128, train_loss:0.1356
processed: 0.2909, train_batch: 128, train_loss:0.1444
processed: 0.3025, train_batch: 128, train_loss:0.1479
processed:

processed: 0.6865, train_batch: 128, train_loss:0.1136
processed: 0.6982, train_batch: 128, train_loss:0.1141
processed: 0.7098, train_batch: 128, train_loss:0.1128
processed: 0.7215, train_batch: 128, train_loss:0.1270
processed: 0.7331, train_batch: 128, train_loss:0.1207
processed: 0.7447, train_batch: 128, train_loss:0.1094
processed: 0.7564, train_batch: 128, train_loss:0.1385
processed: 0.7680, train_batch: 128, train_loss:0.1185
processed: 0.7796, train_batch: 128, train_loss:0.1367
processed: 0.7913, train_batch: 128, train_loss:0.1483
processed: 0.8000, train_batch: 96, train_loss:0.1399
processed: 0.8116, valid_batch: 128, valid_loss:0.0179
processed: 0.8233, valid_batch: 128, valid_loss:0.4837
processed: 0.8349, valid_batch: 128, valid_loss:1.0407
processed: 0.8465, valid_batch: 128, valid_loss:0.1482
processed: 0.8582, valid_batch: 128, valid_loss:0.0262
processed: 0.8698, valid_batch: 128, valid_loss:0.0273
processed: 0.8815, valid_batch: 128, valid_loss:0.3200
processed: 

processed: 0.1280, train_batch: 128, train_loss:0.0996
processed: 0.1396, train_batch: 128, train_loss:0.1329
processed: 0.1513, train_batch: 128, train_loss:0.1365
processed: 0.1629, train_batch: 128, train_loss:0.1187
processed: 0.1745, train_batch: 128, train_loss:0.0965
processed: 0.1862, train_batch: 128, train_loss:0.0826
processed: 0.1978, train_batch: 128, train_loss:0.0909
processed: 0.2095, train_batch: 128, train_loss:0.0892
processed: 0.2211, train_batch: 128, train_loss:0.0694
processed: 0.2327, train_batch: 128, train_loss:0.0924
processed: 0.2444, train_batch: 128, train_loss:0.0715
processed: 0.2560, train_batch: 128, train_loss:0.1056
processed: 0.2676, train_batch: 128, train_loss:0.0901
processed: 0.2793, train_batch: 128, train_loss:0.0767
processed: 0.2909, train_batch: 128, train_loss:0.0908
processed: 0.3025, train_batch: 128, train_loss:0.0906
processed: 0.3142, train_batch: 128, train_loss:0.1133
processed: 0.3258, train_batch: 128, train_loss:0.0977
processed:

processed: 0.7098, train_batch: 128, train_loss:0.0884
processed: 0.7215, train_batch: 128, train_loss:0.0962
processed: 0.7331, train_batch: 128, train_loss:0.0732
processed: 0.7447, train_batch: 128, train_loss:0.1022
processed: 0.7564, train_batch: 128, train_loss:0.0743
processed: 0.7680, train_batch: 128, train_loss:0.0589
processed: 0.7796, train_batch: 128, train_loss:0.0776
processed: 0.7913, train_batch: 128, train_loss:0.0729
processed: 0.8000, train_batch: 96, train_loss:0.1139
processed: 0.8116, valid_batch: 128, valid_loss:0.0855
processed: 0.8233, valid_batch: 128, valid_loss:1.0291
processed: 0.8349, valid_batch: 128, valid_loss:2.1641
processed: 0.8465, valid_batch: 128, valid_loss:0.2573
processed: 0.8582, valid_batch: 128, valid_loss:0.0119
processed: 0.8698, valid_batch: 128, valid_loss:0.0123
processed: 0.8815, valid_batch: 128, valid_loss:0.1746
processed: 0.8931, valid_batch: 128, valid_loss:0.1666
processed: 0.9047, valid_batch: 128, valid_loss:0.0170
processed: 

<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=Discriminator(
    (conv1): Sequential(
      (0): Conv1d(2, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv2): Sequential(
      (0): Conv1d(256, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (fc1): Sequential(
      (0): Linear(in_features=81920, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.6)
    )
    (fc2): Sequential(
      (0): Linear(in_features=256, out_features=11, bias=True)
      (1): ReLU()
    )
  ),
)

In [None]:
net.history

In [28]:
net.get_default_callbacks()

[('epoch_timer', <skorch.callbacks.logging.EpochTimer at 0x1a152026710>),
 ('train_loss', <skorch.callbacks.scoring.BatchScoring at 0x1a1520261d0>),
 ('valid_loss', <skorch.callbacks.scoring.BatchScoring at 0x1a152026470>),
 ('valid_acc', <skorch.callbacks.scoring.EpochScoring at 0x1a152026eb8>),
 ('print_log', <skorch.callbacks.logging.PrintLog at 0x1a152026f98>)]

In [30]:
net.callbacks_

[('epoch_timer', <skorch.callbacks.logging.EpochTimer at 0x1a151f57710>),
 ('train_loss', <skorch.callbacks.scoring.BatchScoring at 0x1a151f57748>),
 ('valid_loss', <skorch.callbacks.scoring.BatchScoring at 0x1a151f57fd0>),
 ('valid_acc', <__main__.Score_ConfusionMatrix at 0x1a1520f1cf8>),
 ('best', <skorch.callbacks.training.Checkpoint at 0x1a1520f1240>),
 ('early', <skorch.callbacks.training.EarlyStopping at 0x1a1520f1080>),
 ('print_log', <__main__.Print_Score_CM at 0x1a1520f1048>)]