In [6]:
import torch
from torch import nn


class RadioTransformerNetwork(nn.Module):
    def __init__(self, in_channels, compressed_dim):
        super(RadioTransformerNetwork, self).__init__()

        self.in_channels = in_channels

        self.encoder = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels, compressed_dim),
        )

        self.decoder = nn.Sequential(
            nn.Linear(compressed_dim, compressed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(compressed_dim, in_channels)
        )

    def decode_signal(self, x):
        return self.decoder(x)

    def forward(self, x):
        x = self.encoder(x)

        # Normalization.
        x = (self.in_channels ** 2) * (x / x.norm(dim=-1)[:, None])

        # 7dBW to SNR.
        training_signal_noise_ratio = 5.01187

        # bit / channel_use
        communication_rate = 1

        # Simulated Gaussian noise.
        noise = Variable(torch.randn(*x.size()) / ((2 * communication_rate * training_signal_noise_ratio) ** 0.5))
        if USE_CUDA: noise = noise.cuda()
        x += noise

        x = self.decoder(x)

        return x
def get_iterator(mode):
        data = train_data if mode else test_data
        labels = train_labels if mode else test_labels
        tensor_dataset = tnt.dataset.TensorDataset([data, labels])

        return tensor_dataset.parallel(batch_size=BATCH_SIZE, num_workers=4, shuffle=mode)


def processor(sample):
        data, labels, training = sample

        data = Variable(data)
        labels = Variable(labels)

        if USE_CUDA:
            data = data.cuda()
            labels = labels.cuda()

        outputs = model(data)

        loss = loss_fn(outputs, labels)

        return loss, outputs


def reset_meters():
        meter_accuracy.reset()
        meter_loss.reset()
        confusion_meter.reset()


def on_sample(state):
        state['sample'].append(state['train'])


def on_forward(state):
        meter_accuracy.add(state['output'].data, torch.LongTensor(state['sample'][1]))
        confusion_meter.add(state['output'].data, torch.LongTensor(state['sample'][1]))
        meter_loss.add(state['loss'].item())


def on_start_epoch(state):
        reset_meters()
        state['iterator'] = tqdm(state['iterator'])


def on_end_epoch(state):
        print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
            state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))

        reset_meters()

        engine.test(processor, get_iterator(False))

        print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
            state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))
NUM_EPOCHS = 100
BATCH_SIZE = 256
CHANNEL_SIZE = 4
USE_CUDA = False

In [7]:
from tqdm import tqdm
from torchnet.engine import Engine
from torch.autograd import Variable
from torch.optim import Adam
import torchnet as tnt
import math

model = RadioTransformerNetwork(CHANNEL_SIZE, compressed_dim=int(math.log2(CHANNEL_SIZE)))
if USE_CUDA: model = model.cuda()

train_labels = (torch.rand(10000) * CHANNEL_SIZE).long()
train_data = torch.sparse.torch.eye(CHANNEL_SIZE).index_select(dim=0, index=train_labels)

test_labels = (torch.rand(1500) * CHANNEL_SIZE).long()
test_data = torch.sparse.torch.eye(CHANNEL_SIZE).index_select(dim=0, index=test_labels)

optimizer = Adam(model.parameters())

engine = Engine()
meter_loss = tnt.meter.AverageValueMeter()
meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
confusion_meter = tnt.meter.ConfusionMeter(CHANNEL_SIZE, normalized=True)

loss_fn = nn.CrossEntropyLoss()

In [8]:
engine.hooks['on_sample'] = on_sample
engine.hooks['on_forward'] = on_forward
engine.hooks['on_start_epoch'] = on_start_epoch
engine.hooks['on_end_epoch'] = on_end_epoch

engine.train(processor, get_iterator(True), maxepoch=NUM_EPOCHS, optimizer=optimizer)


100%|██████████| 40/40 [00:00<00:00, 207.45it/s]

[Epoch 1] Training Loss: 1.6789 (Accuracy: 24.98%)





[Epoch 1] Testing Loss: 1.3288 (Accuracy: 27.47%)


100%|██████████| 40/40 [00:00<00:00, 229.05it/s]

[Epoch 2] Training Loss: 1.2210 (Accuracy: 24.98%)





[Epoch 2] Testing Loss: 1.0718 (Accuracy: 27.47%)


100%|██████████| 40/40 [00:00<00:00, 229.44it/s]

[Epoch 3] Training Loss: 1.0347 (Accuracy: 38.27%)





[Epoch 3] Testing Loss: 0.9363 (Accuracy: 50.47%)


100%|██████████| 40/40 [00:00<00:00, 216.42it/s]

[Epoch 4] Training Loss: 0.8984 (Accuracy: 50.72%)





[Epoch 4] Testing Loss: 0.8307 (Accuracy: 50.47%)


100%|██████████| 40/40 [00:00<00:00, 218.82it/s]

[Epoch 5] Training Loss: 0.8136 (Accuracy: 50.72%)





[Epoch 5] Testing Loss: 0.7629 (Accuracy: 50.47%)


100%|██████████| 40/40 [00:00<00:00, 227.38it/s]

[Epoch 6] Training Loss: 0.7502 (Accuracy: 50.72%)





[Epoch 6] Testing Loss: 0.7082 (Accuracy: 50.47%)


100%|██████████| 40/40 [00:00<00:00, 220.97it/s]

[Epoch 7] Training Loss: 0.6974 (Accuracy: 50.72%)





[Epoch 7] Testing Loss: 0.6632 (Accuracy: 50.47%)


100%|██████████| 40/40 [00:00<00:00, 227.37it/s]

[Epoch 8] Training Loss: 0.6463 (Accuracy: 54.13%)





[Epoch 8] Testing Loss: 0.6138 (Accuracy: 75.07%)


100%|██████████| 40/40 [00:00<00:00, 230.70it/s]

[Epoch 9] Training Loss: 0.5976 (Accuracy: 74.82%)





[Epoch 9] Testing Loss: 0.5769 (Accuracy: 75.73%)


100%|██████████| 40/40 [00:00<00:00, 236.61it/s]

[Epoch 10] Training Loss: 0.5684 (Accuracy: 74.88%)





[Epoch 10] Testing Loss: 0.5514 (Accuracy: 75.73%)


100%|██████████| 40/40 [00:00<00:00, 217.21it/s]

[Epoch 11] Training Loss: 0.5451 (Accuracy: 74.89%)





[Epoch 11] Testing Loss: 0.5302 (Accuracy: 75.80%)


100%|██████████| 40/40 [00:00<00:00, 218.23it/s]

[Epoch 12] Training Loss: 0.5249 (Accuracy: 75.00%)





[Epoch 12] Testing Loss: 0.5108 (Accuracy: 75.80%)


100%|██████████| 40/40 [00:00<00:00, 221.60it/s]

[Epoch 13] Training Loss: 0.5021 (Accuracy: 75.19%)





[Epoch 13] Testing Loss: 0.4915 (Accuracy: 76.33%)


100%|██████████| 40/40 [00:00<00:00, 226.52it/s]

[Epoch 14] Training Loss: 0.4839 (Accuracy: 76.60%)





[Epoch 14] Testing Loss: 0.4740 (Accuracy: 78.47%)


100%|██████████| 40/40 [00:00<00:00, 225.32it/s]

[Epoch 15] Training Loss: 0.4692 (Accuracy: 97.55%)





[Epoch 15] Testing Loss: 0.4574 (Accuracy: 99.73%)


100%|██████████| 40/40 [00:00<00:00, 227.53it/s]

[Epoch 16] Training Loss: 0.4519 (Accuracy: 99.72%)





[Epoch 16] Testing Loss: 0.4410 (Accuracy: 99.60%)


100%|██████████| 40/40 [00:00<00:00, 226.83it/s]

[Epoch 17] Training Loss: 0.4380 (Accuracy: 99.77%)





[Epoch 17] Testing Loss: 0.4262 (Accuracy: 99.67%)


100%|██████████| 40/40 [00:00<00:00, 220.36it/s]

[Epoch 18] Training Loss: 0.4234 (Accuracy: 99.92%)





[Epoch 18] Testing Loss: 0.4112 (Accuracy: 99.80%)


100%|██████████| 40/40 [00:00<00:00, 221.96it/s]

[Epoch 19] Training Loss: 0.4083 (Accuracy: 99.93%)





[Epoch 19] Testing Loss: 0.3974 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 230.39it/s]

[Epoch 20] Training Loss: 0.3926 (Accuracy: 99.92%)





[Epoch 20] Testing Loss: 0.3845 (Accuracy: 99.93%)


100%|██████████| 40/40 [00:00<00:00, 221.96it/s]

[Epoch 21] Training Loss: 0.3854 (Accuracy: 99.97%)





[Epoch 21] Testing Loss: 0.3715 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 226.27it/s]

[Epoch 22] Training Loss: 0.3687 (Accuracy: 99.99%)





[Epoch 22] Testing Loss: 0.3595 (Accuracy: 99.93%)


100%|██████████| 40/40 [00:00<00:00, 226.07it/s]

[Epoch 23] Training Loss: 0.3583 (Accuracy: 99.98%)





[Epoch 23] Testing Loss: 0.3476 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 230.94it/s]

[Epoch 24] Training Loss: 0.3425 (Accuracy: 99.99%)





[Epoch 24] Testing Loss: 0.3363 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 224.89it/s]

[Epoch 25] Training Loss: 0.3355 (Accuracy: 100.00%)





[Epoch 25] Testing Loss: 0.3255 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 228.33it/s]

[Epoch 26] Training Loss: 0.3226 (Accuracy: 99.99%)





[Epoch 26] Testing Loss: 0.3150 (Accuracy: 99.93%)


100%|██████████| 40/40 [00:00<00:00, 226.19it/s]

[Epoch 27] Training Loss: 0.3129 (Accuracy: 99.99%)





[Epoch 27] Testing Loss: 0.3056 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 220.96it/s]

[Epoch 28] Training Loss: 0.3014 (Accuracy: 100.00%)





[Epoch 28] Testing Loss: 0.2959 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 217.55it/s]

[Epoch 29] Training Loss: 0.2912 (Accuracy: 100.00%)





[Epoch 29] Testing Loss: 0.2867 (Accuracy: 99.93%)


100%|██████████| 40/40 [00:00<00:00, 228.41it/s]

[Epoch 30] Training Loss: 0.2835 (Accuracy: 99.99%)





[Epoch 30] Testing Loss: 0.2776 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 229.19it/s]

[Epoch 31] Training Loss: 0.2767 (Accuracy: 100.00%)





[Epoch 31] Testing Loss: 0.2696 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 226.79it/s]

[Epoch 32] Training Loss: 0.2673 (Accuracy: 100.00%)





[Epoch 32] Testing Loss: 0.2613 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 225.98it/s]

[Epoch 33] Training Loss: 0.2613 (Accuracy: 100.00%)





[Epoch 33] Testing Loss: 0.2536 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 236.92it/s]

[Epoch 34] Training Loss: 0.2531 (Accuracy: 100.00%)





[Epoch 34] Testing Loss: 0.2456 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 219.56it/s]

[Epoch 35] Training Loss: 0.2419 (Accuracy: 100.00%)





[Epoch 35] Testing Loss: 0.2382 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 220.50it/s]

[Epoch 36] Training Loss: 0.2389 (Accuracy: 100.00%)





[Epoch 36] Testing Loss: 0.2318 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 223.71it/s]

[Epoch 37] Training Loss: 0.2314 (Accuracy: 99.99%)





[Epoch 37] Testing Loss: 0.2247 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 224.43it/s]

[Epoch 38] Training Loss: 0.2240 (Accuracy: 100.00%)





[Epoch 38] Testing Loss: 0.2182 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 231.97it/s]

[Epoch 39] Training Loss: 0.2191 (Accuracy: 100.00%)





[Epoch 39] Testing Loss: 0.2119 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 189.93it/s]

[Epoch 40] Training Loss: 0.2093 (Accuracy: 100.00%)





[Epoch 40] Testing Loss: 0.2057 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 202.48it/s]

[Epoch 41] Training Loss: 0.2046 (Accuracy: 100.00%)





[Epoch 41] Testing Loss: 0.2006 (Accuracy: 99.93%)


100%|██████████| 40/40 [00:00<00:00, 194.44it/s]

[Epoch 42] Training Loss: 0.2004 (Accuracy: 100.00%)





[Epoch 42] Testing Loss: 0.1942 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 176.80it/s]

[Epoch 43] Training Loss: 0.1926 (Accuracy: 100.00%)





[Epoch 43] Testing Loss: 0.1891 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 215.14it/s]

[Epoch 44] Training Loss: 0.1871 (Accuracy: 100.00%)





[Epoch 44] Testing Loss: 0.1838 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 193.19it/s]

[Epoch 45] Training Loss: 0.1823 (Accuracy: 100.00%)





[Epoch 45] Testing Loss: 0.1790 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 177.04it/s]

[Epoch 46] Training Loss: 0.1770 (Accuracy: 100.00%)





[Epoch 46] Testing Loss: 0.1740 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 223.36it/s]

[Epoch 47] Training Loss: 0.1726 (Accuracy: 100.00%)





[Epoch 47] Testing Loss: 0.1691 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 226.92it/s]

[Epoch 48] Training Loss: 0.1677 (Accuracy: 100.00%)





[Epoch 48] Testing Loss: 0.1647 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 217.39it/s]

[Epoch 49] Training Loss: 0.1630 (Accuracy: 100.00%)





[Epoch 49] Testing Loss: 0.1602 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 226.49it/s]

[Epoch 50] Training Loss: 0.1587 (Accuracy: 100.00%)





[Epoch 50] Testing Loss: 0.1558 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 233.37it/s]

[Epoch 51] Training Loss: 0.1548 (Accuracy: 100.00%)





[Epoch 51] Testing Loss: 0.1523 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 229.92it/s]

[Epoch 52] Training Loss: 0.1498 (Accuracy: 100.00%)





[Epoch 52] Testing Loss: 0.1483 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 231.96it/s]

[Epoch 53] Training Loss: 0.1467 (Accuracy: 100.00%)





[Epoch 53] Testing Loss: 0.1444 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 225.93it/s]

[Epoch 54] Training Loss: 0.1423 (Accuracy: 100.00%)





[Epoch 54] Testing Loss: 0.1406 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 220.14it/s]

[Epoch 55] Training Loss: 0.1397 (Accuracy: 100.00%)





[Epoch 55] Testing Loss: 0.1370 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 221.62it/s]

[Epoch 56] Training Loss: 0.1363 (Accuracy: 100.00%)





[Epoch 56] Testing Loss: 0.1335 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 217.45it/s]

[Epoch 57] Training Loss: 0.1320 (Accuracy: 100.00%)





[Epoch 57] Testing Loss: 0.1309 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 223.57it/s]

[Epoch 58] Training Loss: 0.1296 (Accuracy: 100.00%)





[Epoch 58] Testing Loss: 0.1269 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 226.41it/s]

[Epoch 59] Training Loss: 0.1261 (Accuracy: 100.00%)





[Epoch 59] Testing Loss: 0.1242 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 225.90it/s]

[Epoch 60] Training Loss: 0.1227 (Accuracy: 100.00%)





[Epoch 60] Testing Loss: 0.1210 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 189.42it/s]

[Epoch 61] Training Loss: 0.1203 (Accuracy: 100.00%)





[Epoch 61] Testing Loss: 0.1182 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 212.30it/s]

[Epoch 62] Training Loss: 0.1169 (Accuracy: 100.00%)





[Epoch 62] Testing Loss: 0.1153 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 224.79it/s]

[Epoch 63] Training Loss: 0.1142 (Accuracy: 100.00%)





[Epoch 63] Testing Loss: 0.1122 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 232.69it/s]

[Epoch 64] Training Loss: 0.1115 (Accuracy: 100.00%)





[Epoch 64] Testing Loss: 0.1099 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 225.16it/s]

[Epoch 65] Training Loss: 0.1088 (Accuracy: 100.00%)





[Epoch 65] Testing Loss: 0.1070 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 229.28it/s]

[Epoch 66] Training Loss: 0.1059 (Accuracy: 100.00%)





[Epoch 66] Testing Loss: 0.1040 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 230.69it/s]

[Epoch 67] Training Loss: 0.1042 (Accuracy: 100.00%)





[Epoch 67] Testing Loss: 0.1011 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 224.21it/s]

[Epoch 68] Training Loss: 0.1007 (Accuracy: 100.00%)





[Epoch 68] Testing Loss: 0.0982 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 231.77it/s]

[Epoch 69] Training Loss: 0.0978 (Accuracy: 100.00%)





[Epoch 69] Testing Loss: 0.0950 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 216.37it/s]

[Epoch 70] Training Loss: 0.0935 (Accuracy: 100.00%)





[Epoch 70] Testing Loss: 0.0919 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 225.22it/s]

[Epoch 71] Training Loss: 0.0908 (Accuracy: 100.00%)





[Epoch 71] Testing Loss: 0.0882 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 226.08it/s]

[Epoch 72] Training Loss: 0.0869 (Accuracy: 100.00%)





[Epoch 72] Testing Loss: 0.0848 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 224.53it/s]

[Epoch 73] Training Loss: 0.0839 (Accuracy: 100.00%)





[Epoch 73] Testing Loss: 0.0814 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 225.19it/s]

[Epoch 74] Training Loss: 0.0801 (Accuracy: 99.99%)





[Epoch 74] Testing Loss: 0.0782 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 236.45it/s]

[Epoch 75] Training Loss: 0.0765 (Accuracy: 100.00%)





[Epoch 75] Testing Loss: 0.0754 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 223.55it/s]

[Epoch 76] Training Loss: 0.0742 (Accuracy: 100.00%)





[Epoch 76] Testing Loss: 0.0721 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 229.12it/s]

[Epoch 77] Training Loss: 0.0705 (Accuracy: 100.00%)





[Epoch 77] Testing Loss: 0.0698 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 214.04it/s]

[Epoch 78] Training Loss: 0.0684 (Accuracy: 100.00%)





[Epoch 78] Testing Loss: 0.0668 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 227.25it/s]

[Epoch 79] Training Loss: 0.0653 (Accuracy: 100.00%)





[Epoch 79] Testing Loss: 0.0649 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 227.86it/s]

[Epoch 80] Training Loss: 0.0633 (Accuracy: 100.00%)





[Epoch 80] Testing Loss: 0.0619 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 224.95it/s]

[Epoch 81] Training Loss: 0.0607 (Accuracy: 100.00%)





[Epoch 81] Testing Loss: 0.0599 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 218.72it/s]

[Epoch 82] Training Loss: 0.0579 (Accuracy: 100.00%)





[Epoch 82] Testing Loss: 0.0574 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 225.98it/s]

[Epoch 83] Training Loss: 0.0565 (Accuracy: 100.00%)





[Epoch 83] Testing Loss: 0.0550 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 228.32it/s]

[Epoch 84] Training Loss: 0.0549 (Accuracy: 100.00%)





[Epoch 84] Testing Loss: 0.0545 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 222.12it/s]

[Epoch 85] Training Loss: 0.0531 (Accuracy: 100.00%)





[Epoch 85] Testing Loss: 0.0521 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 228.86it/s]

[Epoch 86] Training Loss: 0.0503 (Accuracy: 100.00%)





[Epoch 86] Testing Loss: 0.0500 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 222.69it/s]

[Epoch 87] Training Loss: 0.0490 (Accuracy: 100.00%)





[Epoch 87] Testing Loss: 0.0478 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 218.35it/s]

[Epoch 88] Training Loss: 0.0468 (Accuracy: 100.00%)





[Epoch 88] Testing Loss: 0.0467 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 232.78it/s]

[Epoch 89] Training Loss: 0.0453 (Accuracy: 100.00%)





[Epoch 89] Testing Loss: 0.0448 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 220.50it/s]

[Epoch 90] Training Loss: 0.0441 (Accuracy: 100.00%)





[Epoch 90] Testing Loss: 0.0436 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 222.86it/s]

[Epoch 91] Training Loss: 0.0427 (Accuracy: 100.00%)





[Epoch 91] Testing Loss: 0.0418 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 228.18it/s]

[Epoch 92] Training Loss: 0.0413 (Accuracy: 100.00%)





[Epoch 92] Testing Loss: 0.0409 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 221.58it/s]

[Epoch 93] Training Loss: 0.0401 (Accuracy: 100.00%)





[Epoch 93] Testing Loss: 0.0395 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 231.00it/s]

[Epoch 94] Training Loss: 0.0389 (Accuracy: 100.00%)





[Epoch 94] Testing Loss: 0.0379 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 222.12it/s]

[Epoch 95] Training Loss: 0.0372 (Accuracy: 100.00%)





[Epoch 95] Testing Loss: 0.0371 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 227.66it/s]

[Epoch 96] Training Loss: 0.0365 (Accuracy: 100.00%)





[Epoch 96] Testing Loss: 0.0358 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 229.00it/s]

[Epoch 97] Training Loss: 0.0352 (Accuracy: 100.00%)





[Epoch 97] Testing Loss: 0.0348 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 208.04it/s]

[Epoch 98] Training Loss: 0.0340 (Accuracy: 100.00%)





[Epoch 98] Testing Loss: 0.0337 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 230.46it/s]

[Epoch 99] Training Loss: 0.0334 (Accuracy: 100.00%)





[Epoch 99] Testing Loss: 0.0328 (Accuracy: 100.00%)


100%|██████████| 40/40 [00:00<00:00, 229.89it/s]

[Epoch 100] Training Loss: 0.0326 (Accuracy: 100.00%)





[Epoch 100] Testing Loss: 0.0318 (Accuracy: 100.00%)


{'network': <function __main__.processor(sample)>,
 'iterator': <tqdm.std.tqdm at 0x7f1ab2c62250>,
 'maxepoch': 100,
 'optimizer': Adam (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     differentiable: False
     eps: 1e-08
     foreach: None
     fused: None
     lr: 0.001
     maximize: False
     weight_decay: 0
 ),
 'epoch': 100,
 't': 4000,
 'train': True,
 'sample': [tensor([[0., 0., 1., 0.],
          [1., 0., 0., 0.],
          [0., 0., 1., 0.],
          [0., 0., 1., 0.],
          [1., 0., 0., 0.],
          [0., 0., 1., 0.],
          [0., 0., 1., 0.],
          [0., 1., 0., 0.],
          [0., 1., 0., 0.],
          [0., 1., 0., 0.],
          [1., 0., 0., 0.],
          [0., 0., 1., 0.],
          [0., 0., 0., 1.],
          [0., 0., 0., 1.],
          [0., 1., 0., 0.],
          [0., 0., 1., 0.]]),
  tensor([2, 0, 2, 2, 0, 2, 2, 1, 1, 1, 0, 2, 3, 3, 1, 2]),
  True],
 'output': None,
 'loss': None}

In [4]:
train_data[:4]

tensor([[0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.]])

In [5]:
train_labels[:10]

tensor([3, 1, 0, 1, 1, 2, 0, 0, 3, 2])