# Пробуем простую полносвязную сеть

## Загрузка данных

In [22]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and
# later dropped the old names; prefer the new name and fall back to the old
# one on older Matplotlib releases.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from skimage import transform
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.preprocessing import OrdinalEncoder
from read_data import *
from read_data import good_cols
%matplotlib inline

In [None]:
# Build the normalized data frame used by every later cell.
# get_df, sep_by_len, cut_act and normalize_df come from the project module
# read_data; their exact semantics are not visible here.
all_df = get_df()
# Passed to cut_act below and reused for the first Linear layer's input size
# (31*TARGET_LEN) in the model cell.
TARGET_LEN = 2000
# presumably splits the recordings by a length threshold of 20000 -- TODO confirm
activities = sep_by_len(all_df, 20000)
# presumably cuts `count` segments of TARGET_LEN from each activity -- TODO confirm
cut_df = cut_act(activities, TARGET_LEN, count=4)
n_df = normalize_df(cut_df)
# Display one random row as a sanity check.
n_df.sample()

 11%|█████████▎                                                                          | 1/9 [00:03<00:30,  3.81s/it]

Получаем вектор признаков

In [None]:
# Every column except the first is treated as features (NumPy array).
X = n_df.iloc[:, 1:].values

Y = n_df.iloc[:, 0] # target variable (first column)
# Display both shapes as a sanity check.
X.shape, Y.shape

In [None]:
# get_flatten is a project helper (read_data); presumably it turns each sample
# into a single 1-D feature vector for the fully connected net -- TODO confirm.
X_flat = get_flatten(X)
X_flat.shape

In [None]:
def y_encode(y_data):
    """Encode class labels as consecutive integer ids.

    Each label is replaced by its position in the sorted array of unique
    labels (``np.unique(y_data)``), so ids run from 0 to n_classes - 1.

    Args:
        y_data: 1-D array-like of labels (NumPy array, pandas Series, list).

    Returns:
        np.ndarray of dtype int32 with one class id per input label.
    """
    # return_inverse yields, for every element, its index in the sorted unique
    # array. Unlike writing ids into np.zeros_like(y_data), this never stores
    # the ids in the label dtype, so ids >= 10 are not truncated when labels
    # are short fixed-width strings, and it avoids a list.index scan per item.
    _, codes = np.unique(np.asarray(y_data), return_inverse=True)
    return codes.astype(np.int32)

In [None]:
y_targ = y_encode(Y)
y_targ

In [None]:
# Convert features to float32 and labels to int64 tensors, then hold out 30%
# of the samples for validation. random_state=42 makes the split reproducible.
# No `stratify` argument is passed, so class proportions may differ between
# the two parts.
X_train_tensor, X_val_tensor, y_train_tensor, y_val_tensor = \
    train_test_split(torch.FloatTensor(X_flat), torch.LongTensor(y_targ), random_state=42, test_size = 0.3)
# Display shapes and the class ids present in each part.
X_train_tensor.shape, y_train_tensor.shape, np.unique(y_train_tensor), np.unique(y_val_tensor)

In [None]:
class SignalDataset(Dataset):
    """Map-style dataset that pairs each input sample with its label.

    Args:
        data_X: Indexable collection of input samples.
        data_Y: Indexable collection of labels, addressed with the same
            indices as ``data_X``.
    """

    def __init__(self, data_X, data_Y):
        self.X, self.Y = data_X, data_Y

    def __len__(self):
        # The dataset size is taken from the inputs alone.
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at ``idx``."""
        return self.X[idx], self.Y[idx]

In [None]:
# Wrap the train/validation tensors in datasets and mini-batch loaders.
BATCH_SIZE = 64

trainset = SignalDataset(X_train_tensor, y_train_tensor)
# Training batches are reshuffled every epoch; num_workers=0 loads data in the
# main process.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=0)

testset = SignalDataset(X_val_tensor, y_val_tensor)
# Validation order is kept fixed (no shuffling).
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=0)
# Display the number of training samples.
len(trainset)

## Подготовка модели

In [None]:
# Fully connected classifier: two hidden layers of 128 units with LeakyReLU,
# followed by a linear layer emitting one logit per class.
activation = nn.LeakyReLU

net = torch.nn.Sequential(
    # Input width must equal the length of one flattened sample.
    # NOTE(review): 31 is presumably the number of feature columns per time
    # step -- confirm against X_flat.shape[1].
    nn.Linear(31*TARGET_LEN, 128),
    activation(),
    nn.Linear(128,128),
    activation(),
    # One output per distinct encoded class id.
    nn.Linear(128,len(np.unique(y_targ)))
)

## Обучение

In [63]:
NUM_EPOCHS = 100

# Sum (rather than average) the per-sample losses of each batch so the epoch
# total can be divided by len(trainset) to get the mean loss per sample.
# reduction='sum' is the supported spelling of the deprecated
# size_average=False.
loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
losses = []  # mean training loss per epoch; read by the plotting cell

learning_rate = 3e-3
optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate)
ebar = tqdm(range(NUM_EPOCHS))

for epoch_num in ebar:
    running_loss = 0.0
    for X_batch, y_batch in trainloader:
        # forward pass with the current weights
        y_pred = net(X_batch)

        # summed cross-entropy over the batch
        loss = loss_fn(y_pred, y_batch)
        running_loss += loss.item()

        # clear gradients left over from the previous step
        optimizer.zero_grad()

        # backward pass (compute new gradients)
        loss.backward()

        # update the weights
        optimizer.step()

    epoch_loss = running_loss / len(trainset)
    # NOTE: the progress-bar text is zero-based (epoch_num) while the
    # validation line printed below is one-based (epoch_num+1).
    line = '[{}/{}] current loss: {}'.format(epoch_num, NUM_EPOCHS, epoch_loss)
    ebar.set_description(line)

    losses.append(epoch_loss)

    # Validate on the whole hold-out set at once, without tracking gradients.
    with torch.no_grad():
        y_pred = torch.softmax(net(X_val_tensor), 1)
        y = torch.argmax(y_pred, axis=1)
        line = '[{}/{}] current valid score: {}'.format(epoch_num+1, NUM_EPOCHS, f1_score(y_val_tensor, y, average='macro'))
        print(line)

[0/100] current loss: 25.463682469313706:   1%|▍                                       | 1/100 [00:00<01:02,  1.58it/s]

[1/100] current valid score: 0.02040816326530612


[1/100] current loss: 11.788713718817487:   2%|▊                                       | 2/100 [00:01<00:59,  1.65it/s]

[2/100] current valid score: 0.011494252873563218


[2/100] current loss: 4.663786477189723:   3%|█▏                                       | 3/100 [00:01<00:56,  1.71it/s]

[3/100] current valid score: 0.06330749354005168


[3/100] current loss: 2.508781774257257:   4%|█▋                                       | 4/100 [00:02<00:54,  1.75it/s]

[4/100] current valid score: 0.08922558922558922


[4/100] current loss: 1.9033783703315548:   5%|██                                      | 5/100 [00:02<00:57,  1.65it/s]

[5/100] current valid score: 0.111762789182144


[5/100] current loss: 1.5230194107303774:   6%|██▍                                     | 6/100 [00:03<00:57,  1.64it/s]

[6/100] current valid score: 0.1255299484466151


[6/100] current loss: 1.2038786042996539:   7%|██▊                                     | 7/100 [00:04<00:57,  1.63it/s]

[7/100] current valid score: 0.1737230362230362


[7/100] current loss: 0.9868795813583746:   8%|███▏                                    | 8/100 [00:04<00:54,  1.68it/s]

[8/100] current valid score: 0.15265151515151515


[8/100] current loss: 0.8728048510667754:   9%|███▌                                    | 9/100 [00:05<00:53,  1.70it/s]

[9/100] current valid score: 0.21928349428349428


[9/100] current loss: 0.7322418941714899:  10%|███▉                                   | 10/100 [00:05<00:51,  1.74it/s]

[10/100] current valid score: 0.1506072631072631


[10/100] current loss: 0.528968307060924:  11%|████▎                                  | 11/100 [00:06<00:50,  1.75it/s]

[11/100] current valid score: 0.21119528619528619


[11/100] current loss: 0.4416452326425692:  12%|████▌                                 | 12/100 [00:06<00:49,  1.77it/s]

[12/100] current valid score: 0.2057118807118807


[12/100] current loss: 0.35723919597098497:  13%|████▊                                | 13/100 [00:07<00:48,  1.78it/s]

[13/100] current valid score: 0.16911881977671453


[13/100] current loss: 0.3325507999435673:  14%|█████▎                                | 14/100 [00:08<00:47,  1.79it/s]

[14/100] current valid score: 0.3345504441092677


[14/100] current loss: 0.32311412764758596:  15%|█████▌                               | 15/100 [00:08<00:47,  1.80it/s]

[15/100] current valid score: 0.30304232804232806


[15/100] current loss: 0.3396395066889321:  16%|██████                                | 16/100 [00:09<00:47,  1.77it/s]

[16/100] current valid score: 0.20323979139768614


[16/100] current loss: 0.2377529493192347:  17%|██████▍                               | 17/100 [00:09<00:48,  1.72it/s]

[17/100] current valid score: 0.19365079365079366


[17/100] current loss: 0.2603406266468327:  18%|██████▊                               | 18/100 [00:10<00:49,  1.66it/s]

[18/100] current valid score: 0.12475429975429975


[18/100] current loss: 0.3204937805005205:  19%|███████▏                              | 19/100 [00:11<00:52,  1.55it/s]

[19/100] current valid score: 0.2661356209150327


[19/100] current loss: 0.18936655434166513:  20%|███████▍                             | 20/100 [00:11<00:50,  1.58it/s]

[20/100] current valid score: 0.24802789802789804


[20/100] current loss: 0.1262513970941063:  21%|███████▉                              | 21/100 [00:12<00:49,  1.60it/s]

[21/100] current valid score: 0.23349673202614382


[21/100] current loss: 0.13595409606530415:  22%|████████▏                            | 22/100 [00:13<00:48,  1.62it/s]

[22/100] current valid score: 0.2064407814407814


[22/100] current loss: 0.06599564935133709:  23%|████████▌                            | 23/100 [00:13<00:47,  1.63it/s]

[23/100] current valid score: 0.2715549557654821


[23/100] current loss: 0.025680499711656958:  24%|████████▋                           | 24/100 [00:14<00:45,  1.66it/s]

[24/100] current valid score: 0.24994172494172495


[24/100] current loss: 0.02397162376380548:  25%|█████████▎                           | 25/100 [00:14<00:44,  1.69it/s]

[25/100] current valid score: 0.24049961476432066


[25/100] current loss: 0.020750738922658007:  26%|█████████▎                          | 26/100 [00:15<00:43,  1.71it/s]

[26/100] current valid score: 0.24751944914988397


[26/100] current loss: 0.01300724444350576:  27%|█████████▉                           | 27/100 [00:15<00:42,  1.74it/s]

[27/100] current valid score: 0.3962962962962962


[27/100] current loss: 0.007379097788314509:  28%|██████████                          | 28/100 [00:16<00:42,  1.68it/s]

[28/100] current valid score: 0.28257080610021784


[28/100] current loss: 0.0049435582193659575:  29%|██████████▏                        | 29/100 [00:17<00:42,  1.68it/s]

[29/100] current valid score: 0.35579489954489957


[29/100] current loss: 0.004280445688381427:  30%|██████████▊                         | 30/100 [00:17<00:41,  1.69it/s]

[30/100] current valid score: 0.34960317460317464


[30/100] current loss: 0.003689160010194391:  31%|███████████▏                        | 31/100 [00:18<00:40,  1.72it/s]

[31/100] current valid score: 0.310515873015873


[31/100] current loss: 0.0033127291022035164:  32%|███████████▏                       | 32/100 [00:18<00:39,  1.72it/s]

[32/100] current valid score: 0.310515873015873


[32/100] current loss: 0.00298455148571875:  33%|████████████▏                        | 33/100 [00:19<00:38,  1.73it/s]

[33/100] current valid score: 0.2773809523809524


[33/100] current loss: 0.002851909087077389:  34%|████████████▏                       | 34/100 [00:19<00:37,  1.74it/s]

[34/100] current valid score: 0.3001984126984127


[34/100] current loss: 0.002764698254262529:  35%|████████████▌                       | 35/100 [00:20<00:37,  1.73it/s]

[35/100] current valid score: 0.32162698412698415


[35/100] current loss: 0.002666095847157928:  36%|████████████▉                       | 36/100 [00:21<00:36,  1.73it/s]

[36/100] current valid score: 0.3001984126984127


[36/100] current loss: 0.0025603166834367967:  37%|████████████▉                      | 37/100 [00:21<00:35,  1.77it/s]

[37/100] current valid score: 0.310515873015873


[37/100] current loss: 0.0024483795997088517:  38%|█████████████▎                     | 38/100 [00:22<00:34,  1.77it/s]

[38/100] current valid score: 0.310515873015873


[38/100] current loss: 0.0023481939290839484:  39%|█████████████▋                     | 39/100 [00:22<00:34,  1.78it/s]

[39/100] current valid score: 0.32162698412698415


[39/100] current loss: 0.0022763555218291475:  40%|██████████████                     | 40/100 [00:23<00:33,  1.78it/s]

[40/100] current valid score: 0.3113095238095238


[40/100] current loss: 0.0021925852520436775:  41%|██████████████▎                    | 41/100 [00:23<00:33,  1.78it/s]

[41/100] current valid score: 0.32162698412698415


[41/100] current loss: 0.0021081066773674354:  42%|██████████████▋                    | 42/100 [00:24<00:32,  1.76it/s]

[42/100] current valid score: 0.32162698412698415


[42/100] current loss: 0.002042111478806511:  43%|███████████████▍                    | 43/100 [00:25<00:32,  1.74it/s]

[43/100] current valid score: 0.32162698412698415


[43/100] current loss: 0.0019498102033768242:  44%|███████████████▍                   | 44/100 [00:25<00:31,  1.76it/s]

[44/100] current valid score: 0.32162698412698415


[44/100] current loss: 0.001857941091908672:  45%|████████████████▏                   | 45/100 [00:26<00:31,  1.76it/s]

[45/100] current valid score: 0.32162698412698415


[45/100] current loss: 0.0017956787856613717:  46%|████████████████                   | 46/100 [00:26<00:30,  1.75it/s]

[46/100] current valid score: 0.32162698412698415


[46/100] current loss: 0.0016984856664770988:  47%|████████████████▍                  | 47/100 [00:27<00:30,  1.76it/s]

[47/100] current valid score: 0.3255952380952381


[47/100] current loss: 0.001618803009330257:  48%|█████████████████▎                  | 48/100 [00:27<00:29,  1.77it/s]

[48/100] current valid score: 0.3248677248677249


[48/100] current loss: 0.0015629230566867969:  49%|█████████████████▏                 | 49/100 [00:28<00:29,  1.70it/s]

[49/100] current valid score: 0.3255952380952381


[49/100] current loss: 0.0014723967155063056:  50%|█████████████████▌                 | 50/100 [00:29<00:28,  1.73it/s]

[50/100] current valid score: 0.3208994708994709


[50/100] current loss: 0.0014636425180284958:  51%|█████████████████▊                 | 51/100 [00:29<00:28,  1.74it/s]

[51/100] current valid score: 0.3363095238095238


[51/100] current loss: 0.0013978041526747913:  52%|██████████████████▏                | 52/100 [00:30<00:27,  1.75it/s]

[52/100] current valid score: 0.3972120472120472


[52/100] current loss: 0.0013479503643948857:  53%|██████████████████▌                | 53/100 [00:30<00:26,  1.77it/s]

[53/100] current valid score: 0.34742063492063496


[53/100] current loss: 0.0013422704097337838:  54%|██████████████████▉                | 54/100 [00:31<00:25,  1.77it/s]

[54/100] current valid score: 0.40820105820105823


[54/100] current loss: 0.0012785316316214035:  55%|███████████████████▎               | 55/100 [00:31<00:25,  1.77it/s]

[55/100] current valid score: 0.3335317460317461


[55/100] current loss: 0.0012271057254051774:  56%|███████████████████▌               | 56/100 [00:32<00:24,  1.76it/s]

[56/100] current valid score: 0.3248677248677249


[56/100] current loss: 0.0012016967014689755:  57%|███████████████████▉               | 57/100 [00:33<00:25,  1.71it/s]

[57/100] current valid score: 0.4121693121693122


[57/100] current loss: 0.0011706946917423389:  58%|████████████████████▎              | 58/100 [00:33<00:24,  1.74it/s]

[58/100] current valid score: 0.41289682539682543


[58/100] current loss: 0.0011383400003357632:  59%|████████████████████▋              | 59/100 [00:34<00:23,  1.76it/s]

[59/100] current valid score: 0.406547619047619


[59/100] current loss: 0.0011121172124777382:  60%|█████████████████████              | 60/100 [00:34<00:22,  1.77it/s]

[60/100] current valid score: 0.41051587301587306


[60/100] current loss: 0.0011105092424081593:  61%|█████████████████████▎             | 61/100 [00:35<00:21,  1.78it/s]

[61/100] current valid score: 0.4121693121693122


[61/100] current loss: 0.00105512096751027:  62%|██████████████████████▉              | 62/100 [00:35<00:21,  1.75it/s]

[62/100] current valid score: 0.4121693121693122


[62/100] current loss: 0.0010288665676868058:  63%|██████████████████████             | 63/100 [00:36<00:22,  1.66it/s]

[63/100] current valid score: 0.4121693121693122


[63/100] current loss: 0.0010033136960573312:  64%|██████████████████████▍            | 64/100 [00:37<00:21,  1.66it/s]

[64/100] current valid score: 0.4121693121693122


[64/100] current loss: 0.0009758154884344194:  65%|██████████████████████▊            | 65/100 [00:37<00:20,  1.67it/s]

[65/100] current valid score: 0.4121693121693122


[65/100] current loss: 0.0009588868077087208:  66%|███████████████████████            | 66/100 [00:38<00:20,  1.67it/s]

[66/100] current valid score: 0.4121693121693122


[66/100] current loss: 0.0009238535720037251:  67%|███████████████████████▍           | 67/100 [00:39<00:19,  1.66it/s]

[67/100] current valid score: 0.4121693121693122


[67/100] current loss: 0.0009054137560410229:  68%|███████████████████████▊           | 68/100 [00:39<00:19,  1.61it/s]

[68/100] current valid score: 0.4121693121693122


[68/100] current loss: 0.0008850606930692022:  69%|████████████████████████▏          | 69/100 [00:40<00:18,  1.67it/s]

[69/100] current valid score: 0.4121693121693122


[69/100] current loss: 0.0008573209548868784:  70%|████████████████████████▌          | 70/100 [00:40<00:17,  1.71it/s]

[70/100] current valid score: 0.41031746031746036


[70/100] current loss: 0.0008390304079748751:  71%|████████████████████████▊          | 71/100 [00:41<00:16,  1.74it/s]

[71/100] current valid score: 0.3964285714285715


[71/100] current loss: 0.000829186525226124:  72%|█████████████████████████▉          | 72/100 [00:41<00:16,  1.74it/s]

[72/100] current valid score: 0.41031746031746036


[72/100] current loss: 0.0008018999511936331:  73%|█████████████████████████▌         | 73/100 [00:42<00:15,  1.74it/s]

[73/100] current valid score: 0.40818070818070823


[73/100] current loss: 0.0007935649056623621:  74%|█████████████████████████▉         | 74/100 [00:43<00:15,  1.73it/s]

[74/100] current valid score: 0.40818070818070823


[74/100] current loss: 0.0007757636892601727:  75%|██████████████████████████▎        | 75/100 [00:43<00:14,  1.77it/s]

[75/100] current valid score: 0.41003256003256006


[75/100] current loss: 0.0007548224922966181:  76%|██████████████████████████▌        | 76/100 [00:44<00:13,  1.75it/s]

[76/100] current valid score: 0.40818070818070823


[76/100] current loss: 0.000738044735044241:  77%|███████████████████████████▋        | 77/100 [00:44<00:13,  1.76it/s]

[77/100] current valid score: 0.40818070818070823


[77/100] current loss: 0.0007250743494104079:  78%|███████████████████████████▎       | 78/100 [00:45<00:12,  1.79it/s]

[78/100] current valid score: 0.3964285714285715


[78/100] current loss: 0.0007163527671520303:  79%|███████████████████████████▋       | 79/100 [00:45<00:11,  1.75it/s]

[79/100] current valid score: 0.40818070818070823


[79/100] current loss: 0.0006885158609810883:  80%|████████████████████████████       | 80/100 [00:46<00:11,  1.76it/s]

[80/100] current valid score: 0.3964285714285715


[80/100] current loss: 0.000689541489431044:  81%|█████████████████████████████▏      | 81/100 [00:46<00:10,  1.79it/s]

[81/100] current valid score: 0.40818070818070823


[81/100] current loss: 0.0006705929517624824:  82%|████████████████████████████▋      | 82/100 [00:47<00:10,  1.78it/s]

[82/100] current valid score: 0.40705636955636954


[82/100] current loss: 0.0006573124357113024:  83%|█████████████████████████████      | 83/100 [00:48<00:09,  1.79it/s]

[83/100] current valid score: 0.40705636955636954


[83/100] current loss: 0.0006496917108088974:  84%|█████████████████████████████▍     | 84/100 [00:48<00:08,  1.81it/s]

[84/100] current valid score: 0.40705636955636954


[84/100] current loss: 0.0006336284256199511:  85%|█████████████████████████████▊     | 85/100 [00:49<00:08,  1.74it/s]

[85/100] current valid score: 0.3964285714285715


[85/100] current loss: 0.0006208249549495011:  86%|██████████████████████████████     | 86/100 [00:49<00:07,  1.75it/s]

[86/100] current valid score: 0.3953042328042328


[86/100] current loss: 0.0006166975415213322:  87%|██████████████████████████████▍    | 87/100 [00:50<00:07,  1.77it/s]

[87/100] current valid score: 0.3849002849002849


[87/100] current loss: 0.0006008622259265039:  88%|██████████████████████████████▊    | 88/100 [00:50<00:06,  1.79it/s]

[88/100] current valid score: 0.3953042328042328


[88/100] current loss: 0.0006064532922838277:  89%|███████████████████████████████▏   | 89/100 [00:51<00:06,  1.70it/s]

[89/100] current valid score: 0.3953042328042328


[89/100] current loss: 0.0005777782487978296:  90%|███████████████████████████████▌   | 90/100 [00:52<00:06,  1.60it/s]

[90/100] current valid score: 0.3953042328042328


[90/100] current loss: 0.0005625990537426821:  91%|███████████████████████████████▊   | 91/100 [00:52<00:05,  1.66it/s]

[91/100] current valid score: 0.3953042328042328


[91/100] current loss: 0.0005598171591395285:  92%|████████████████████████████████▏  | 92/100 [00:53<00:04,  1.64it/s]

[92/100] current valid score: 0.3849002849002849


[92/100] current loss: 0.0005550880583260603:  93%|████████████████████████████████▌  | 93/100 [00:54<00:04,  1.64it/s]

[93/100] current valid score: 0.3953042328042328


[93/100] current loss: 0.0005416941480726246:  94%|████████████████████████████████▉  | 94/100 [00:54<00:03,  1.61it/s]

[94/100] current valid score: 0.3953042328042328


[94/100] current loss: 0.0005294445266083973:  95%|█████████████████████████████████▎ | 95/100 [00:55<00:03,  1.57it/s]

[95/100] current valid score: 0.3953042328042328


[95/100] current loss: 0.0005155467897714153:  96%|█████████████████████████████████▌ | 96/100 [00:55<00:02,  1.60it/s]

[96/100] current valid score: 0.3953042328042328


[96/100] current loss: 0.0005048839072507572:  97%|█████████████████████████████████▉ | 97/100 [00:56<00:01,  1.60it/s]

[97/100] current valid score: 0.3953042328042328


[97/100] current loss: 0.0005005555893710958:  98%|██████████████████████████████████▎| 98/100 [00:57<00:01,  1.57it/s]

[98/100] current valid score: 0.3953042328042328


[98/100] current loss: 0.000488453730183646:  99%|███████████████████████████████████▋| 99/100 [00:57<00:00,  1.58it/s]

[99/100] current valid score: 0.3953042328042328


[99/100] current loss: 0.0004893214123823294: 100%|██████████████████████████████████| 100/100 [00:58<00:00,  1.71it/s]

[100/100] current valid score: 0.3953042328042328





In [None]:
# Plot the mean training loss recorded for each epoch.
fig, axs = plt.subplots(1,1,figsize=(20,8))
# Use the number of recorded epochs rather than NUM_EPOCHS for the x axis so
# the plot still works when training was interrupted before the last epoch.
plt.plot(np.arange(len(losses)), losses)
plt.show()

In [None]:
# Per-class share of validation samples that the network classifies correctly.
# np.unique(Y) is sorted, and y_encode assigns each label its index in that
# array, so class id i corresponds to class_names[i].
class_names = np.unique(Y)

# One forward pass over the whole validation set; no gradients are needed.
# argmax is taken on the raw logits: squashing them with sigmoid first can
# saturate several outputs to exactly 1.0 and change the winning class.
with torch.no_grad():
    val_pred = torch.argmax(net(X_val_tensor), dim=1)

# Iterate over the class ids actually present in the validation labels, so a
# class missing from the split neither shifts the names nor divides by zero.
for i in torch.unique(y_val_tensor).tolist():
    mask = y_val_tensor == i
    print(class_names[i], (val_pred[mask] == i).sum().item() / mask.sum().item())

In [266]:
# Final evaluation on the validation set, without tracking gradients.
with torch.no_grad():
    # Softmax over dim 1 turns the logits into per-class probabilities, which
    # roc_auc_score takes as scores in the multiclass case.
    y_pred = torch.softmax(net(X_val_tensor), 1)
    # Hard class prediction for the F1 score.
    y = torch.argmax(y_pred, axis=1)
#     print(net(X_val_tensor).shape)
    # ROC AUC with the one-vs-one multiclass scheme.
    print(roc_auc_score(y_val_tensor, y_pred, multi_class='ovo'))
    # Macro-averaged F1 (unweighted mean over classes).
    print(f1_score(y_val_tensor, y, average='macro'))

0.9093150715513585
0.7409394506625
