In [1]:
import sys
sys.path.append('../../datasets/')
from prepare_sequences import prepare, germanBats
import matplotlib.pyplot as plt

classes = germanBats

In [2]:
num_bands = 257
patch_len = 44                               # = 250ms ~ 25ms
patch_skip = patch_len / 2                   # = 150ms ~ 15ms

resize = None

mode = 'slide'
options = {
    'seq_len': 60,                            # = 500ms with ~ 5 calls
    'seq_skip': 15,
}

X_test, Y_test = prepare("../../datasets/prepared.h5", classes, patch_len, patch_skip,
                                                         options, mode, resize, only_test=True)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:37<00:00,  2.07s/it]


In [3]:
print("Total sequences:", len(X_test))
print(X_test.shape, Y_test.shape)

Total sequences: 4979
(4979, 60, 44, 257) (4979,)


In [4]:
import time
import datetime
import numpy as np
import tqdm
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

In [5]:
def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

def mixup(x, y, num_classes):
    x_flipped = x.flip(0)
    x.add_(x_flipped)
    y1 = one_hot(y, num_classes, device=x.device)
    y2 = one_hot(y.flip(0), num_classes, device=x.device)
    return x, y1 + y2

In [6]:
batch_size = 1
num_classes = len(list(classes))

test_len = len(X_test) - len(X_test) % 2
test_data = TensorDataset(torch.Tensor(X_test[:test_len]), torch.from_numpy(Y_test[:test_len]))
test_loader = DataLoader(test_data, batch_size=batch_size)

In [7]:
'''def stitch(a, r):
    return a[::r]
  
def plot_sequence(seq, y):
    plt.figure(figsize = (20, 2.5))
    stitched = stitch(seq, int(patch_len / patch_skip))
    spec = np.rot90(np.concatenate(stitched))
    plt.imshow(spec, interpolation='nearest', aspect='auto', cmap='inferno')
    plt.colorbar()
    label_list = []
    if(len(y.shape) > 0):
        for i in np.argsort(-y)[:2]:
            label_list.append(list(classes)[i])
        plt.title(", ".join(label_list))
    else:
        plt.title(list(classes)[y])

k = 3
X1, Y1 = next(iter(test_loader))
print(X1.shape, Y1.shape)
plot_sequence(X1[k].detach().numpy(), Y1[k].detach().numpy())
plot_sequence(X1[-k-1].detach().numpy(), Y1[-k-1].detach().numpy())

X1, Y1 = mixup(X1, Y1, num_classes=num_classes)
plot_sequence(X1[k].detach().numpy(), Y1[k].detach().numpy())'''

'def stitch(a, r):\n    return a[::r]\n  \ndef plot_sequence(seq, y):\n    plt.figure(figsize = (20, 2.5))\n    stitched = stitch(seq, int(patch_len / patch_skip))\n    spec = np.rot90(np.concatenate(stitched))\n    plt.imshow(spec, interpolation=\'nearest\', aspect=\'auto\', cmap=\'inferno\')\n    plt.colorbar()\n    label_list = []\n    if(len(y.shape) > 0):\n        for i in np.argsort(-y)[:2]:\n            label_list.append(list(classes)[i])\n        plt.title(", ".join(label_list))\n    else:\n        plt.title(list(classes)[y])\n\nk = 3\nX1, Y1 = next(iter(test_loader))\nprint(X1.shape, Y1.shape)\nplot_sequence(X1[k].detach().numpy(), Y1[k].detach().numpy())\nplot_sequence(X1[-k-1].detach().numpy(), Y1[-k-1].detach().numpy())\n\nX1, Y1 = mixup(X1, Y1, num_classes=num_classes)\nplot_sequence(X1[k].detach().numpy(), Y1[k].detach().numpy())'

In [8]:
from bat_2 import Net

max_len = 60
d_model = 64 

nhead = 2
dim_feedforward = 32
num_layers = 2
dropout = 0.3
classifier_dropout = 0.3

model = Net(
    max_len=max_len,
    patch_dim=resize[0]*resize[1] if resize is not None else patch_len * num_bands, # patch_len * num_bands, # 44 * 257 = 11,308
    d_model=d_model,
    num_classes=len(list(classes)),
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    num_layers=num_layers,
    dropout=dropout,
    classifier_dropout=classifier_dropout,
)
model.load_state_dict(torch.load('bat_2_convnet.pth'))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model, device_ids=[0, 1])
    
model.to(device)
print(device)

cuda:0


In [9]:
call_nocall_model = torch.jit.load('../call_nocall/call_nocall.pt')
call_nocall_model.to(device)

RecursiveScriptModule(
  original_name=ResNet
  (conv1): RecursiveScriptModule(original_name=Conv2d)
  (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
  (relu): RecursiveScriptModule(original_name=ReLU)
  (maxpool): RecursiveScriptModule(original_name=MaxPool2d)
  (layer1): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(
      original_name=Block
      (conv1): RecursiveScriptModule(original_name=Conv2d)
      (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
      (conv2): RecursiveScriptModule(original_name=Conv2d)
      (bn2): RecursiveScriptModule(original_name=BatchNorm2d)
      (conv3): RecursiveScriptModule(original_name=Conv2d)
      (bn3): RecursiveScriptModule(original_name=BatchNorm2d)
      (relu): RecursiveScriptModule(original_name=ReLU)
      (identity_downsample): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(original_name=Conv2d)
        (1): RecursiveScriptModule(original_na

In [13]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

corrects = 0.0
mixed_corrects = 0.0

model.eval()
call_nocall_model.eval()

# iterate over test data
for inputs, labels in tqdm.tqdm(test_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    
    cnc_outputs = call_nocall_model(inputs[0].unsqueeze(1))
    cnc_pred = torch.argmax(cnc_outputs, 1) # call indices
    
    if cnc_pred.nonzero().shape[0] > 1:
        output = model(inputs) # Feed Network
        prediction = torch.argmax(output, 1)
        corrects += (prediction == labels).sum().item()
    
print("Test acc:", corrects / (len(test_data)))

# iterate over test data
for inputs, labels in tqdm.tqdm(test_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    inputs, labels = mixup(inputs, labels, num_classes=num_classes)
    
    output = model(inputs.to(device)) # Feed Network
    prediction = torch.argsort(output, 1)[:,-2:]
    target = torch.argsort(labels, 1)[:,-2:]
    mixed_corrects += (torch.count_nonzero(prediction == target, dim=1) + \
                       torch.count_nonzero(prediction.flip(1) == target, dim=1)).sum().item()
    
print("Mixed test acc:", mixed_corrects / (len(test_data) * 2))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4978/4978 [01:48<00:00, 45.74it/s]
  0%|▎                                                                                                                                                     | 10/4978 [00:00<00:50, 97.75it/s]

Test acc: 0.8049417436721575


 66%|█████████████████████████████████████████████████████████████████████████████████████████████████                                                   | 3264/4978 [00:38<00:20, 85.34it/s]


KeyboardInterrupt: 