In [1]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from poutyne.framework import Model
from poutyne.framework.callbacks import ModelCheckpoint
from torchmetrics.classification.accuracy import MulticlassAccuracy
from models import *
from data import *

In [2]:
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
net = VT_CNN2(
    n_classes=11,
    dropout=0.5
)

In [4]:
dataset = RadioML2016()

Loading dataset from file...
Processing dataset


100%|██████████| 200/200 [00:00<00:00, 15386.58it/s]


In [5]:
total = len(dataset)
lengths = [int(len(dataset)*0.5)]
lengths.append(total - lengths[0])
print("Splitting into {} train and {} val".format(lengths[0], lengths[1]))
train_set, val_set = random_split(dataset, lengths)

Splitting into 100000 train and 100000 val


In [6]:
train_dataloader = DataLoader(train_set, batch_size=512)
val_dataloader = DataLoader(val_set, batch_size=512)

In [7]:
os.makedirs("models", exist_ok=True)
checkpoint = ModelCheckpoint(
    filename=os.path.join("models", "vtcnn2.pt"),
    monitor="val_loss",
    save_best_only=True
)
callbacks = [checkpoint]

In [8]:
top3 = MulticlassAccuracy(num_classes=11, top_k=3)
top5 = MulticlassAccuracy(num_classes=11, top_k=5)
metrics = ["acc", top3, top5]

In [9]:
model = Model(
    network=net,
    optimizer="AdamW",
    loss_function=nn.CrossEntropyLoss(),
    batch_metrics=metrics
)

In [10]:
model.cuda()
model.fit_generator(
    train_dataloader,
    val_dataloader,
    epochs=100,
    callbacks=callbacks
)

[35mEpoch: [36m  1/100 [35mTrain steps: [36m196 [35mVal steps: [36m196 [32m8.53s [35mloss:[94m 2.315781[35m acc:[94m 9.991000[35m multiclass_accuracy1:[94m 0.301557[35m multiclass_accuracy2:[94m 0.503567[35m val_loss:[94m 2.305203[35m val_acc:[94m 10.021000[35m val_multiclass_accuracy1:[94m 0.300000[35m val_multiclass_accuracy2:[94m 0.499960[0m
[35mEpoch: [36m  2/100 [35mTrain steps: [36m196 [35mVal steps: [36m196 [32m8.27s [35mloss:[94m 2.306737[35m acc:[94m 9.906000[35m multiclass_accuracy1:[94m 0.299037[35m multiclass_accuracy2:[94m 0.498688[35m val_loss:[94m 2.303653[35m val_acc:[94m 10.030000[35m val_multiclass_accuracy1:[94m 0.300000[35m val_multiclass_accuracy2:[94m 0.500000[0m
[35mEpoch: [36m  3/100 [35mTrain steps: [36m196 [35mVal steps: [36m196 [32m8.20s [35mloss:[94m 2.304571[35m acc:[94m 9.989000[35m multiclass_accuracy1:[94m 0.302671[35m multiclass_accuracy2:[94m 0.502142[35m val_loss:[94m 2.301173[35m val_a

[{'epoch': 1,
  'time': 8.526078299968503,
  'loss': 2.3157811140441895,
  'acc': 9.991,
  'multiclass_accuracy1': 0.30155712366104126,
  'multiclass_accuracy2': 0.5035674571990967,
  'val_loss': 2.305203116378784,
  'val_acc': 10.021,
  'val_multiclass_accuracy1': 0.30000001192092896,
  'val_multiclass_accuracy2': 0.49995970726013184},
 {'epoch': 2,
  'time': 8.26578909996897,
  'loss': 2.3067371719360352,
  'acc': 9.906,
  'multiclass_accuracy1': 0.2990373373031616,
  'multiclass_accuracy2': 0.49868762493133545,
  'val_loss': 2.3036532899475097,
  'val_acc': 10.03,
  'val_multiclass_accuracy1': 0.30000001192092896,
  'val_multiclass_accuracy2': 0.5},
 {'epoch': 3,
  'time': 8.198971000034362,
  'loss': 2.3045709114837645,
  'acc': 9.989,
  'multiclass_accuracy1': 0.3026713728904724,
  'multiclass_accuracy2': 0.5021424293518066,
  'val_loss': 2.301172536239624,
  'val_acc': 10.041,
  'val_multiclass_accuracy1': 0.30000001192092896,
  'val_multiclass_accuracy2': 0.5002883076667786},
 {