In [1]:
import os
import copy
import cv2
import random
import shutil
import torch
import torchvision
import IPython
import json

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import albumentations as A

from albumentations.pytorch import ToTensorV2
from PIL import Image
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
from torchvision.models import resnet50, ResNet50_Weights
from torchmetrics.classification import BinaryAccuracy, BinaryConfusionMatrix, BinaryPrecision, BinaryRecall, BinaryF1Score
from IPython.display import clear_output
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from tqdm.notebook import trange
from math import ceil
from glob import glob

%matplotlib inline

In [2]:
device = torch.device('cuda:5') if torch.cuda.is_available() else torch.device('cpu')

----

### Model

In [4]:
class TimeDistributed(nn.Module):
    def __init__(self, module):
        super(TimeDistributed, self).__init__()
        self.module = module

    def forward(self, x):
        if len(x.size()) <= 2:
            return self.module(x)

        x_reshape = x.contiguous().view(-1, x.size(-1))

        y = self.module(x_reshape)
        y = y.contiguous().view(x.size(0), -1, y.size(-1))
        return y

In [5]:
class ImageEmbender(nn.Module):
    def __init__(self, emb_dim, need_freeze=False):
        super(ImageEmbender, self).__init__()

        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        resnet.fc = nn.Linear(in_features=2048, out_features=emb_dim, bias=True)

        if need_freeze:
            for param in resnet.parameters():
                param.requires_grad = False

        self.resnet = resnet

    def forward(self, x):
        x = self.resnet(x)
        return x


class TimeSeriesImageEncoder(nn.Module):
    def __init__(self, emb_dim, hidden_dim, n_layers, bidirectional, dropout):
        super(TimeSeriesImageEncoder, self).__init__()

        self.lstm = nn.LSTM(emb_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            dropout=dropout,
                            batch_first=True
                           )

    def forward(self, x):
        output, (hn, cn) = self.lstm(x)
        return output[:, -1, :]


class ClussifictionHead(nn.Module):
    def __init__(self, n_classes, input_dim, hidden_dim):
        super(ClussifictionHead, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, n_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return x


In [6]:
class TimeSeriesImagesClassificationModel(nn.Module):
    def __init__(self, emb_size, need_freeze_resnet,
                enc_hid_dim, enc_n_layers,
                enc_bidirectional, enc_dropout,
                dec_hid_dim, n_classes):
        super().__init__()

        self.embedding = ImageEmbender(
                                    emb_size,
                                    need_freeze_resnet
                                    )


        self.encoder = TimeSeriesImageEncoder(
                                    emb_size,
                                    enc_hid_dim,
                                    enc_n_layers,
                                    enc_bidirectional,
                                    enc_dropout
                                    )

        self.decoder = ClussifictionHead(
                                    n_classes,
                                    enc_hid_dim,
                                    dec_hid_dim
                                    )

    def forward(self, x):

        embs = []
        for i in range(x.shape[1]):
            embs.append(self.embedding(x[:, i, :, :].squeeze(1)))
        emb_out = torch.stack(embs, axis=1)

        enc_out = self.encoder(emb_out)
        dec_out = self.decoder(enc_out)
        return dec_out


In [7]:
test_model = TimeSeriesImagesClassificationModel(
    emb_size=128,
    need_freeze_resnet=False,
    enc_hid_dim=128,
    enc_n_layers=1,
    enc_bidirectional=False,
    enc_dropout=0.1,
    dec_hid_dim=256,
    n_classes=2
)



In [8]:
test_batch = torch.randn(16, 5, 3, 608, 208)

with torch.no_grad():
    logits = test_model(test_batch)

print('output shape:', logits.shape)
assert len(logits.shape) == 2
assert logits.shape[0] == 16
assert logits.shape[1] == 2

output shape: torch.Size([16, 2])


----

### Data

In [20]:
dataframe = pd.read_csv('./statistic_2022_12_19.csv', index_col=0)
dataframe = dataframe.reset_index()
dataframe.columns = ['add_date', 'is_touched', 'location', 's3_link', 'scan_result',
       'plt_dir']

In [21]:
dataframe.head()

Unnamed: 0,add_date,is_touched,location,s3_link,scan_result,plt_dir
0,2022-12-14 11:53:13.25442,0,K24-51C5,https://s3.mds.yandex.net/rms-cloud/440d6e51-e...,PLT11500719,data/PLT11500719/440d6e51-e9d0-4ea2-a7ab-ee684...
1,2022-11-09 13:55:14.50472,0,K23-32B2,https://s3.mds.yandex.net/rms-cloud/f14dec0d-c...,PLT11298653,data/PLT11298653/f14dec0d-c919-4591-a605-64cb8...
2,2022-12-15 16:24:56.845572,0,K19-46B4,https://s3.mds.yandex.net/rms-cloud/c23671e5-2...,PLT11488403,data/PLT11488403/c23671e5-28b1-4610-8aef-02e9a...
3,2022-11-23 15:41:36.947232,0,K19-49A5,https://s3.mds.yandex.net/rms-cloud/9ebbe98a-d...,PLT11461224,data/PLT11461224/9ebbe98a-dcbf-4a81-87d3-3a23e...
4,2022-12-14 15:24:00.09199,0,K32-08C4,https://s3.mds.yandex.net/rms-cloud/219efc5a-1...,PLT11487213,data/PLT11487213/219efc5a-18ad-4fc7-871d-a0421...


In [22]:
class PackageDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform, n_ts=5):
        self.dataset = dataset
        self.transform = transform
        self.n_ts = n_ts

        self.plts = list(dataset['scan_result'].unique())

    def __len__(self):
        return len(self.plts)

    def __getitem__(self, idx):
        plt_info = self.dataset[self.dataset['scan_result'] == self.plts[idx]]
        plt_info = plt_info.sort_values(['add_date'])


        n_ts = min(self.n_ts, len(plt_info))
        idxs = np.random.choice(range(len(plt_info)), self.n_ts)
        idxs = sorted(idxs)

        imgs_jpeg_path = list(plt_info.iloc[idxs]['plt_dir'])
        is_touched = list(plt_info.iloc[idxs]['is_touched'])

        images = []
        for img_jpeg_path in imgs_jpeg_path:
            image = np.array(Image.open(img_jpeg_path))
            if self.transform is not None:
                transformed = self.transform(image=image)
                image = transformed['image']

            images.append(image)

        images = torch.stack(images, axis=0)

        return {'images': images, 'targets': is_touched}

In [23]:
test_transform = A.Compose(
    [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        A.Resize(width=208, height=608),
        ToTensorV2()
    ]
)

train_transform = A.Compose(
    [
        A.Resize(width=208, height=608),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

In [24]:
test_dataset = PackageDataset(dataframe, train_transform)

print(test_dataset[1]['images'].shape)
assert len(test_dataset[1]['images'].shape) == 4

torch.Size([5, 3, 608, 208])


---

#### Анализ данных

In [25]:
dataframe[dataframe['scan_result'] == 'PLT11466151'].sort_values(['add_date'])

Unnamed: 0,add_date,is_touched,location,s3_link,scan_result,plt_dir
39203,2022-11-22 11:49:41.409188,0,K22-44B3,https://s3.mds.yandex.net/rms-cloud/24df7595-7...,PLT11466151,data/PLT11466151/24df7595-732e-4eff-b421-52b4f...
27315,2022-11-23 14:48:17.539217,0,K22-44B3,https://s3.mds.yandex.net/rms-cloud/63599d50-5...,PLT11466151,data/PLT11466151/63599d50-599d-42c1-988f-93d8a...
31509,2022-11-23 15:09:02.732171,0,K22-44B3,https://s3.mds.yandex.net/rms-cloud/1d2245ff-b...,PLT11466151,data/PLT11466151/1d2245ff-bc37-45e0-8f31-87d34...
1680,2022-11-24 12:00:19.117362,0,K22-44B3,https://s3.mds.yandex.net/rms-cloud/529de1b2-2...,PLT11466151,data/PLT11466151/529de1b2-2769-4c46-8917-2977b...
15223,2022-11-30 15:31:40.828355,1,K22-44B3,https://s3.mds.yandex.net/rms-cloud/fb78d58d-a...,PLT11466151,data/PLT11466151/fb78d58d-a13a-436d-9085-ad040...
58468,2022-11-30 15:45:19.524823,1,K22-44B3,https://s3.mds.yandex.net/rms-cloud/3bf9215f-e...,PLT11466151,data/PLT11466151/3bf9215f-edb1-426d-b6c4-8c83b...
55853,2022-12-01 09:28:26.021746,1,K22-44B3,https://s3.mds.yandex.net/rms-cloud/9468f872-8...,PLT11466151,data/PLT11466151/9468f872-8d66-4e37-853a-a6e3d...
46282,2022-12-02 12:29:30.655756,1,K22-44B3,https://s3.mds.yandex.net/rms-cloud/0393dd19-8...,PLT11466151,data/PLT11466151/0393dd19-8d4a-4456-bb02-1b0a4...
13257,2022-12-05 10:50:07.043269,0,K21-43B3,https://s3.mds.yandex.net/rms-cloud/cf76012d-1...,PLT11466151,data/PLT11466151/cf76012d-1b7d-42e1-b200-7218c...
58176,2022-12-05 13:16:01.56889,0,K21-43B3,https://s3.mds.yandex.net/rms-cloud/48d51af7-d...,PLT11466151,data/PLT11466151/48d51af7-d0ad-4cbd-aaf6-094be...


In [26]:
s3_links = list(dataframe[dataframe['scan_result'] == 'PLT11466151'].sort_values(['add_date'])['s3_link'])

##### Вывод: мы определяем не вскрытость паллеты, а её тронутость, для этого весьма важно смотреть на предыдущие значения и вот пример:

In [27]:
s3_links

['https://s3.mds.yandex.net/rms-cloud/24df7595-732e-4eff-b421-52b4fa3e1883.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/63599d50-599d-42c1-988f-93d8a33b175e.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/1d2245ff-bc37-45e0-8f31-87d346e80b65.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/529de1b2-2769-4c46-8917-2977b23d5cbd.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/fb78d58d-a13a-436d-9085-ad0407f029f0.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/3bf9215f-edb1-426d-b6c4-8c83b0c4f7fc.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/9468f872-8d66-4e37-853a-a6e3d27066b2.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/0393dd19-8d4a-4456-bb02-1b0a42a49416.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/cf76012d-1b7d-42e1-b200-7218c2a93226.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/48d51af7-d0ad-4cbd-aaf6-094be678ac9c.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/2cea570e-b5a1-4655-9385-cda517f0564a.jpeg',
 'https://s3.mds.yandex.net/rms-cloud/4660664d-ee2b-49f3-a15d-640992cf1c72.jpeg',
 'https://s3.mds

In [28]:
plts = test_dataset.plts
new_dataframe = pd.DataFrame(columns=dataframe.columns)
for plt in tqdm(plts):

    plt_info = dataframe[dataframe['scan_result'] == plt]
    plt_info = plt_info.sort_values(['add_date'])

    prev_location = None
    true_plt = []
    cnt = 0
    for idx, row in plt_info.iterrows():
        if not prev_location or prev_location != row['location']:
            prev_location = row['location']
            cnt += 1
        true_plt.append(row['scan_result'] + str(cnt))

    plt_info['true_plt'] = true_plt
    new_dataframe = pd.concat([new_dataframe, plt_info], ignore_index=True)


  0%|          | 0/10783 [00:00<?, ?it/s]

In [29]:
new_dataframe.head()

Unnamed: 0,add_date,is_touched,location,s3_link,scan_result,plt_dir,true_plt
0,2022-12-12 12:27:23.891609,0,K24-51B5,https://s3.mds.yandex.net/rms-cloud/d04d4a91-3...,PLT11500719,data/PLT11500719/d04d4a91-3a7d-487d-a40f-a248f...,PLT115007191
1,2022-12-12 13:56:21.41017,0,K24-51B5,https://s3.mds.yandex.net/rms-cloud/ea847f4b-8...,PLT11500719,data/PLT11500719/ea847f4b-8623-4acc-bd1d-354fb...,PLT115007191
2,2022-12-14 11:53:13.25442,0,K24-51C5,https://s3.mds.yandex.net/rms-cloud/440d6e51-e...,PLT11500719,data/PLT11500719/440d6e51-e9d0-4ea2-a7ab-ee684...,PLT115007192
3,2022-12-15 12:43:47.89228,0,K24-51B5,https://s3.mds.yandex.net/rms-cloud/2252a4c7-9...,PLT11500719,data/PLT11500719/2252a4c7-9f5f-4f80-9791-f52dd...,PLT115007193
4,2022-12-19 12:21:54.420098,0,K24-51A5,https://s3.mds.yandex.net/rms-cloud/b4fab748-6...,PLT11500719,data/PLT11500719/b4fab748-6600-49fd-b12d-51a0e...,PLT115007194


Изменим датасет и посмотрим на результат:

In [30]:
class PackageDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform, n_ts=5):
        self.dataset = dataset
        self.transform = transform
        self.n_ts = n_ts

        self.plts = list(dataset['true_plt'].unique())

    def __len__(self):
        return len(self.plts)

    def __getitem__(self, idx):
        plt_info = self.dataset[self.dataset['true_plt'] == self.plts[idx]]
        plt_info = plt_info.sort_values(['add_date'])


        n_ts = min(self.n_ts, len(plt_info))
        idxs = np.random.choice(range(len(plt_info)), self.n_ts)
        idxs = sorted(idxs)

        imgs_jpeg_path = list(plt_info.iloc[idxs]['plt_dir'])
        is_touched = 1 in list(plt_info.iloc[idxs]['is_touched'])

        images = []
        for img_jpeg_path in imgs_jpeg_path:
            image = np.array(Image.open(img_jpeg_path))
            if self.transform is not None:
                transformed = self.transform(image=image)
                image = transformed['image']

            images.append(image)

        images = torch.stack(images, axis=0)

        return {'images': images, 'targets': int(is_touched)}

In [31]:
test_dataset = PackageDataset(new_dataframe, train_transform)

In [32]:
def visualize_augmentations(dataset, samples=5):
    """The function visualizes an augmented image"""
    dataset = copy.deepcopy(dataset)
    dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize))])

    glob_cnt = 0
    for idx in range(samples):
        print('>>>>' * 10)
        images, targets = dataset[glob_cnt]['images'], dataset[glob_cnt]['targets']
        while (1 not in targets) or (0 not in targets):
            glob_cnt += 1
            images, targets = dataset[glob_cnt]['images'], dataset[glob_cnt]['targets']
        print(f'PLT idx = {glob_cnt}')

        for image, target in zip(images, targets):
            plt.imshow(image.permute(1, 2, 0))
            plt.title(f"Image, is_touched={target}")
            plt.tight_layout()
            plt.show()
        glob_cnt += 1


In [33]:
visualize_augmentations(test_dataset, samples=7)

>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>


TypeError: argument of type 'int' is not iterable

In [34]:
new_dataframe.head()

Unnamed: 0,add_date,is_touched,location,s3_link,scan_result,plt_dir,true_plt
0,2022-12-12 12:27:23.891609,0,K24-51B5,https://s3.mds.yandex.net/rms-cloud/d04d4a91-3...,PLT11500719,data/PLT11500719/d04d4a91-3a7d-487d-a40f-a248f...,PLT115007191
1,2022-12-12 13:56:21.41017,0,K24-51B5,https://s3.mds.yandex.net/rms-cloud/ea847f4b-8...,PLT11500719,data/PLT11500719/ea847f4b-8623-4acc-bd1d-354fb...,PLT115007191
2,2022-12-14 11:53:13.25442,0,K24-51C5,https://s3.mds.yandex.net/rms-cloud/440d6e51-e...,PLT11500719,data/PLT11500719/440d6e51-e9d0-4ea2-a7ab-ee684...,PLT115007192
3,2022-12-15 12:43:47.89228,0,K24-51B5,https://s3.mds.yandex.net/rms-cloud/2252a4c7-9...,PLT11500719,data/PLT11500719/2252a4c7-9f5f-4f80-9791-f52dd...,PLT115007193
4,2022-12-19 12:21:54.420098,0,K24-51A5,https://s3.mds.yandex.net/rms-cloud/b4fab748-6...,PLT11500719,data/PLT11500719/b4fab748-6600-49fd-b12d-51a0e...,PLT115007194


In [35]:
len(new_dataframe)

109808

In [36]:
len(new_dataframe.groupby(['scan_result'])['is_touched'].count())

10783

In [37]:
len(new_dataframe.groupby(['true_plt'])['is_touched'].count())

28089

In [38]:
np.sum(new_dataframe.groupby(['true_plt'])['is_touched'].sum() != new_dataframe.groupby(['true_plt'])['is_touched'].count()) - np.sum(new_dataframe.groupby(['true_plt'])['is_touched'].sum() == 0)

594

----

### Разбивка данных, тестирование

In [39]:
plts = np.unique(new_dataframe['scan_result'])
train_pallets, test_pallets = train_test_split(plts, test_size=0.2, random_state=42)

data_train, data_test = new_dataframe.loc[new_dataframe['scan_result'].isin(train_pallets)], new_dataframe.loc[new_dataframe['scan_result'].isin(test_pallets)]
len(data_train) / len(data_test)

3.5914032446897473

In [40]:
data_train.head()

Unnamed: 0,add_date,is_touched,location,s3_link,scan_result,plt_dir,true_plt
0,2022-12-12 12:27:23.891609,0,K24-51B5,https://s3.mds.yandex.net/rms-cloud/d04d4a91-3...,PLT11500719,data/PLT11500719/d04d4a91-3a7d-487d-a40f-a248f...,PLT115007191
1,2022-12-12 13:56:21.41017,0,K24-51B5,https://s3.mds.yandex.net/rms-cloud/ea847f4b-8...,PLT11500719,data/PLT11500719/ea847f4b-8623-4acc-bd1d-354fb...,PLT115007191
2,2022-12-14 11:53:13.25442,0,K24-51C5,https://s3.mds.yandex.net/rms-cloud/440d6e51-e...,PLT11500719,data/PLT11500719/440d6e51-e9d0-4ea2-a7ab-ee684...,PLT115007192
3,2022-12-15 12:43:47.89228,0,K24-51B5,https://s3.mds.yandex.net/rms-cloud/2252a4c7-9...,PLT11500719,data/PLT11500719/2252a4c7-9f5f-4f80-9791-f52dd...,PLT115007193
4,2022-12-19 12:21:54.420098,0,K24-51A5,https://s3.mds.yandex.net/rms-cloud/b4fab748-6...,PLT11500719,data/PLT11500719/b4fab748-6600-49fd-b12d-51a0e...,PLT115007194


In [41]:
train_dataset = PackageDataset(data_train, train_transform)

test_dataset = PackageDataset(data_test, test_transform)

In [42]:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True,
    batch_size=16, num_workers=1
)

val_dataloader = torch.utils.data.DataLoader(
    test_dataset, shuffle=False,
    batch_size=16, num_workers=1
)

In [43]:
data_test

Unnamed: 0,add_date,is_touched,location,s3_link,scan_result,plt_dir,true_plt
61,2022-12-07 14:19:20.234102,0,K19-46B4,https://s3.mds.yandex.net/rms-cloud/eb6d9f52-c...,PLT11488403,data/PLT11488403/eb6d9f52-ccb2-4294-a67e-e33ee...,PLT114884031
62,2022-12-08 14:17:51.891244,0,K19-46C4,https://s3.mds.yandex.net/rms-cloud/2f19d9d0-7...,PLT11488403,data/PLT11488403/2f19d9d0-74c2-47df-bd91-178e1...,PLT114884032
63,2022-12-09 17:54:46.370113,0,K19-46C4,https://s3.mds.yandex.net/rms-cloud/93ff1f4f-9...,PLT11488403,data/PLT11488403/93ff1f4f-96d4-45b9-8174-690de...,PLT114884032
64,2022-12-12 12:25:40.367969,0,K19-46B4,https://s3.mds.yandex.net/rms-cloud/82efec1e-5...,PLT11488403,data/PLT11488403/82efec1e-55d5-471b-8126-40580...,PLT114884033
65,2022-12-12 13:39:50.43196,0,K19-46C4,https://s3.mds.yandex.net/rms-cloud/a0a950dd-d...,PLT11488403,data/PLT11488403/a0a950dd-dae3-4679-92f4-85edb...,PLT114884034
...,...,...,...,...,...,...,...
109767,2022-12-19 12:26:30.157283,0,K28-05C4,https://s3.mds.yandex.net/rms-cloud/e15fedcb-b...,PLT11498964,data/PLT11498964/e15fedcb-b958-43ad-9554-b78b7...,PLT114989641
109768,2022-12-19 12:26:30.904178,0,K28-09B3,https://s3.mds.yandex.net/rms-cloud/05530e4d-6...,PLT11497791,data/PLT11497791/05530e4d-6063-4c61-a903-17f61...,PLT114977911
109773,2022-12-14 17:25:21.31825,0,K46-44C4,https://s3.mds.yandex.net/rms-cloud/175795eb-a...,PLT11461779,data/PLT11461779/175795eb-a0be-4228-b4c0-16be3...,PLT114617791
109780,2022-12-15 17:08:13.812858,0,K35-32C6,https://s3.mds.yandex.net/rms-cloud/cb852a0f-f...,PLT11508392,data/PLT11508392/cb852a0f-fde4-458b-b60c-f5f19...,PLT115083921


In [None]:
for batch in val_dataloader:
    batch
    with torch.no_grad():
        logits = test_model(batch['images'])
    break

----

### Обучение

In [44]:
def train_loop(model, train_loader, criterion, optimizer):
    train_loss = num_samples = 0.0
    sigmoid = nn.Sigmoid()
    for batch in tqdm(train_loader):
        batch_pred = model(batch['images'].to(device))
        batch_pred = batch_pred.squeeze(-1)
        loss = criterion(sigmoid(batch_pred), batch['targets'].to(device).float())

        num_samples += len(batch_pred)
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

    train_loss = train_loss / num_samples
    return model, optimizer, train_loss

In [45]:
def test_loop(model, val_loader, criterion, metrics):
    sigmoid = nn.Sigmoid()
    with torch.no_grad():
        logs_num = 0
        val_losses = 0.0
        accuracy = []
        precision = []
        recall = []
        f1_score = []
        confusion_matrix = torch.zeros((2, 2))

        for batch in tqdm(val_loader):
            preds = model(batch['images'].to(device))
            preds = sigmoid(preds.squeeze(-1))
            target = batch['targets'].to(device)

            # compute loss
            #
            loss = criterion(preds, batch['targets'].to(device).float())
            val_losses += loss.item()

            # compute metrics
            #
            accuracy.append(metrics['accuracy'](preds, target).cpu())
            precision.append(metrics['precision'](preds, target).cpu())
            recall.append(metrics['recall'](preds, target).cpu())
            f1_score.append(metrics['f1_score'](preds, target).cpu())
            confusion_matrix += metrics['confusion_matrix'](preds, target).cpu()

            logs_num += len(target)

    return val_losses / logs_num, np.mean(accuracy), np.mean(precision), np.mean(recall), np.mean(f1_score), confusion_matrix

In [46]:
def learning_loop(model, optimizer, train_loader, val_loader,
                loss_fn, metrics, epochs=10):

    plot_result = {
        'train_loss': [],
        'test_loss': [],
        'accuracy': [],
        'precision': [],
        'recall': [],
        'f1_score': [],
        'confusion_matrix': []
    }

    best_recall = float('inf')
    best_model = None

    for epoch in range(1, epochs+1):
        # train
        #
        print(f'#{epoch}/{epochs}:')
        model, optimizer, train_loss = train_loop(model, train_loader, loss_fn, optimizer)
        plot_result['train_loss'].append(train_loss)

        # test
        #
        test_losses, accuracy, precision, recall, f1_score, confusion_matrix = test_loop(model, val_loader, loss_fn, metrics)
        plot_result['test_loss'].append(test_losses)
        plot_result['accuracy'] += [accuracy]
        plot_result['precision'] += [precision]
        plot_result['recall'] += [recall]
        plot_result['f1_score'] += [f1_score]
        plot_result['confusion_matrix'] += [confusion_matrix]


        if recall < best_recall:
            best_recall = recall
            best_model = copy.deepcopy(model)

        # plot
        #
        clear_output(True)

        fig = make_subplots(rows=3, cols=2)

        for idx, (key, value) in enumerate(plot_result.items()):

            if key == 'confusion_matrix':
                print('Confusion Matrix:')
                for matrix in value:
                    cnfs_mtrx = '''
                    TP = {0}; FN = {1}
                    FP = {2}; TN = {3}
                    '''
                    cnfs_mtrx = cnfs_mtrx.format(matrix[1, 1], matrix[1, 0], matrix[0, 1], matrix[0, 0])
                    print(cnfs_mtrx)
                continue

            fig.add_trace(
                go.Scatter(y=value, name=key),
                row=idx // 2 + 1, col= idx % 2 + 1
            )

        fig.update_layout(height=600, width=800, title_text=f'#{epoch}/{epochs}:')
        fig.show()

    return model, optimizer, plot_result

In [47]:
def create_model_and_optimizer(net, lr=1e-4, beta1=0.9, beta2=0.999, device=device):
    model = net.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr, [beta1, beta2])
    return model, optimizer

In [48]:
metrics = {
    'accuracy': BinaryAccuracy().to(device),
    'precision': BinaryPrecision().to(device),
    'recall': BinaryRecall().to(device),
    'f1_score': BinaryF1Score().to(device),
    'confusion_matrix': BinaryConfusionMatrix().to(device)
}
loss_fn = nn.BCELoss()

model = TimeSeriesImagesClassificationModel(
    emb_size=128, 
    need_freeze_resnet=False, 
    enc_hid_dim=128, 
    enc_n_layers=1, 
    enc_bidirectional=False, 
    enc_dropout=0.1,
    dec_hid_dim=256, 
    n_classes=1
)

model, optimizer_model = create_model_and_optimizer(model)



In [50]:
model, optimizer_model, plot_result = learning_loop(
    model=model,
    optimizer=optimizer_model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    loss_fn=loss_fn,
    metrics=metrics,
    epochs=1
)

#1/1:


  0%|          | 0/1384 [00:00<?, ?it/s]

In [None]:
predictions = []
labels = []

In [None]:
conf = np.array(predictions)[:, 1]
labels = np.array(labels)

In [None]:
lr_precision, lr_recall, _ = precision_recall_curve(labels, conf)
lr_auc = auc(lr_recall, lr_precision)
# summarize scores
print('auc=%.3f' % (lr_auc))
# plot the precision-recall curves
no_skill = len(labels[labels==1]) / len(labels)
pyplot.plot([0, 1], [no_skill, no_skill], linestyle='--', label='No Skill')
pyplot.plot(lr_recall, lr_precision, marker='.', label='Trained')
# axis labels
pyplot.xlabel('Recall')
pyplot.ylabel('Precision')
# show the legend
pyplot.legend()
# show the plot
pyplot.show()