In [1]:
import torch
import pickle

import torch.nn as nn

from torch.utils.data import DataLoader
from ds.drivers.datasets import DriverDataset, DriversDataset, all_data

In [2]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Value forwarded to the dataset constructors as their ``step`` argument.
ds_step_size = 5
# Batch size used by the evaluation DataLoaders.
test_batch_size = 1

In [3]:
class ConvX(nn.Module):
    """1-D convolution followed by batch normalisation and ReLU.

    Args:
        in_planes: number of input channels given to ``nn.Conv1d``.
        out_planes: number of output channels; also the feature count of
            the ``nn.BatchNorm1d`` layer.
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: explicit padding; when ``None`` it defaults to
            ``kernel // 2``.
    """

    def __init__(
        self, in_planes, out_planes, kernel=3, stride=1, padding=None
    ):
        super().__init__()
        if padding is None:
            padding = kernel // 2
        # The convolution carries no bias term; batch norm follows directly.
        self.conv = nn.Conv1d(
            in_planes, out_planes,
            kernel_size=kernel, stride=stride, padding=padding, bias=False,
        )
        self.bn = nn.BatchNorm1d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.relu(normalised)

In [4]:
class Net6Conv2MaxPool(nn.Module):
    """Single-channel 1-D network: six ``ConvX`` blocks and two max-pools.

    The stack ends with a plain ``nn.Conv1d``; ``forward`` flattens the
    result and applies a softmax over the last dimension.
    """

    def __init__(self):
        super().__init__()

        def unpadded_conv():
            # Kernel-2 ConvX block with padding disabled.
            return ConvX(1, 1, kernel=2, padding=0)

        # Layer order is kept as-is so the Sequential indices (and hence the
        # state-dict keys) stay the same.
        layers = [
            unpadded_conv(),
            unpadded_conv(),
            nn.MaxPool1d(2),
            unpadded_conv(),
            unpadded_conv(),
            nn.MaxPool1d(2),
            unpadded_conv(),
            # No explicit padding here, so ConvX falls back to kernel // 2.
            ConvX(1, 1, kernel=2, stride=2),
            nn.Conv1d(1, 1, kernel_size=2),
        ]
        self.seq = nn.Sequential(*layers)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Run ``x`` through the stack and return softmax scores per sample."""
        features = self.seq(x)
        flat = torch.flatten(features, 1)
        return self.softmax(flat)

In [5]:
def get_net(data):
    """Build a ``Net6Conv2MaxPool`` from a checkpoint mapping.

    The value stored under ``"net_state_dict"`` in ``data`` is loaded into a
    freshly constructed network, which is then moved to the module-level
    ``device`` and returned.
    """
    model = Net6Conv2MaxPool()
    weights = data.get("net_state_dict")
    model.load_state_dict(weights)
    # Module.to() returns the module itself, so this is the same object.
    return model.to(device=device)

In [6]:
def val(model, dl_test, return_valid_count_only=True):
    """Count how many predictions of ``model`` match the labels in ``dl_test``.

    Each item yielded by ``dl_test`` is unpacked as ``(inputs, labels)``.
    Inputs get an extra dimension at position 1, are cast to float32 and moved
    to the module-level ``device``; the prediction is the index of the maximum
    along dimension 1 of the model output.

    Args:
        model: callable applied to each prepared input batch.
        dl_test: iterable of ``(inputs, labels)`` pairs.
        return_valid_count_only: when true, return only the match count.

    Returns:
        The number of matches (``0`` if ``dl_test`` yields nothing, otherwise
        a tensor), or ``(count, labels_per_batch, predictions_per_batch)``
        when ``return_valid_count_only`` is false.
    """
    seen_labels, seen_predictions = [], []
    matches = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in dl_test:
            seen_labels.append(batch_labels)
            prepared = batch_inputs.unsqueeze(1).to(torch.float32)
            prepared = prepared.to(device=device)
            # Compare on the CPU, where the labels were left.
            predictions = model(prepared).cpu().max(1).indices
            seen_predictions.append(predictions)
            matches += (predictions == batch_labels).sum()

    if return_valid_count_only:
        return matches
    return matches, seen_labels, seen_predictions

In [7]:
# Combined test split. ``step`` and ``batch_size`` come from the configuration
# cell so this loader stays in sync with the per-driver loaders.
ds_all_test = DriversDataset(
    all_data,
    ds_type="test",
    data_field="working_data_filtered",
    step=ds_step_size,
)
# Evaluation loader: no shuffling (the match count does not depend on order),
# and no dropped samples, because accuracy is later divided by
# len(ds_all_test) and a dropped partial batch would deflate it.
dl_all_test = DataLoader(
    ds_all_test,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=1,
    pin_memory=False,
    drop_last=False,
)

In [8]:
# One test-split DriverDataset per entry of all_data, keyed by the same keys.
drivers_test_datasets = dict(
    (
        key,
        DriverDataset(
            per_driver_data,
            ds_type="test",
            data_field="working_data_filtered",
            step=ds_step_size,
        ),
    )
    for key, per_driver_data in all_data.items()
)

In [9]:
# Names of the model dumps to evaluate; each one is read from
# models_dumps/<name>.pkl by the evaluation loop.
# NOTE(review): the names appear to encode optimiser, learning rate, betas and
# seed of the training run -- confirm against the training notebook.
nets = [
    "Net6Conv2MaxPool_Adam_lr_0.0001_betas_default_seed_1338",
    "Net6Conv2MaxPool_Adam_lr_0.001_betas_default_seed_1337",
    "Net6Conv2MaxPool_Adam_lr_0.0001_betas_0.85_0.995",
    "Net6Conv2MaxPool_Adam_lr_0.0001_betas_default_seed_1336",
    "Net6Conv2MaxPool_Adam_lr_0.0005_betas_default",
    "Net6Conv2MaxPool_Adam_lr_5e-05_betas_default",
    "Net6Conv2MaxPool_ASGD_lr_0.0001_betas_default",
    "Net6Conv2MaxPool_RMSprop_lr_0.0001_betas_default",
    "Net6Conv2MaxPool_SGD_lr_0.0001",
]

In [10]:
# Hard-coded driver ids. They are no longer used to label the per-driver
# results below (the keys of drivers_test_datasets are used instead).
# NOTE(review): kept only in case another cell reads ``drivers``.
drivers = [5, 6, 7, 8, 9, 10, 11, 12, 15, 16]
for net_name in nets:
    print(net_name)
    # NOTE(review): pickle.load can execute arbitrary code while reading the
    # file; only load dumps that come from a trusted source.
    with open(f"models_dumps/{net_name}.pkl", "rb") as f:
        data = pickle.load(f)

    # Build the model once per dump and reuse it for every evaluation below:
    # it is in eval mode and val() runs under torch.no_grad().
    net = get_net(data)
    net.eval()

    # Accuracy on the combined test set, next to the rate stored in the dump
    # (None when the dump has no "rate" entry).
    acc = val(net, dl_all_test)
    print(
        "Computed common rate:",
        acc / len(ds_all_test),
        "; file rate:",
        data.get("rate"),
    )

    # Per-driver accuracy. Ids are collected from the same iteration that
    # produces the accuracies, so each value is paired with the right driver.
    driver_ids = []
    accs = []
    for driver_id, ds_test in drivers_test_datasets.items():
        # drop_last=False: accuracy is divided by len(ds_test), so every
        # sample has to be evaluated whatever the batch size is.
        dl_test = DataLoader(
            ds_test,
            batch_size=test_batch_size,
            shuffle=False,
            num_workers=1,
            pin_memory=False,
            drop_last=False,
        )
        acc = val(net, dl_test)
        driver_ids.append(driver_id)
        accs.append(acc / len(ds_test))
    # Unweighted mean over drivers, followed by the (driver id, rate) pairs.
    print(
        "Mean: ",
        sum(accs) / len(accs),
        "; list:",
        list(zip(driver_ids, accs)),
    )
    print("-------------------------")

Net6Conv2MaxPool_Adam_lr_0.0001_betas_default_seed_1338
Computed common rate: tensor(0.8106) ; file rate: tensor(0.8106)
Mean:  tensor(0.7822) ; list: [(5, tensor(0.8835)), (6, tensor(0.0217)), (7, tensor(0.3175)), (8, tensor(0.9924)), (9, tensor(0.6341)), (10, tensor(1.)), (11, tensor(1.)), (12, tensor(1.)), (15, tensor(0.9730)), (16, tensor(1.))]
-------------------------
Net6Conv2MaxPool_Adam_lr_0.001_betas_default_seed_1337
Computed common rate: tensor(0.8053) ; file rate: None
Mean:  tensor(0.7788) ; list: [(5, tensor(0.8544)), (6, tensor(0.0217)), (7, tensor(0.3175)), (8, tensor(0.9773)), (9, tensor(0.7195)), (10, tensor(1.)), (11, tensor(0.9691)), (12, tensor(1.)), (15, tensor(0.9459)), (16, tensor(0.9825))]
-------------------------
Net6Conv2MaxPool_Adam_lr_0.0001_betas_0.85_0.995
Computed common rate: tensor(0.8074) ; file rate: tensor(0.8074)
Mean:  tensor(0.7781) ; list: [(5, tensor(0.9029)), (6, tensor(0.0217)), (7, tensor(0.3175)), (8, tensor(0.9924)), (9, tensor(0.5732)),