Бинарная классификация:
1. Терминал с дефектом
2. Терминал без дефекта

В качестве модели используется трансформер.

# Загрузка библиотек

In [1]:
!pip -q install vit_pytorch linformer

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.0/92.0 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision import datasets, transforms, models
from collections import OrderedDict
from PIL import Image
import PIL
from torch.optim import lr_scheduler
import os
import random
from sklearn.model_selection import train_test_split
import shutil
import gc
from tqdm import tqdm
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, Dataset
from IPython.display import clear_output
from vit_pytorch.efficient import ViT
from linformer import Linformer
from torch.optim.lr_scheduler import StepLR

In [3]:
RANDOM_STATE = 42

In [4]:
seed = RANDOM_STATE
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [6]:
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Загрузка данных

In [12]:
!sudo apt install unar
!unar 'drive/MyDrive/sorted_data_merged.rar'

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
  sorted_data_merged/receipts/ЗНО0301571566_TS903168069-05e1da94-c5f2-4898-8217-a14426b35713.jpg  (132891 B)... OK.
  sorted_data_merged/receipts/ЗНО0301571586_TS903168077-86f0f279-1da6-425e-9000-d9d0a674fdb3.jpg  (123380 B)... OK.
  sorted_data_merged/receipts/ЗНО0301571593_TS903168182-762ca820-ff6f-4389-9208-fe5da8c8e659.jpg  (191494 B)... OK.
  sorted_data_merged/receipts/ЗНО0301571737_TS903168191-aff6ca2b-c35d-42a3-babe-14e00daf4f2b.jpg  (240333 B)... OK.
  sorted_data_merged/receipts/ЗНО0301571779_TS903168194-ef758884-591e-464d-9532-2ceab6c1e7e2.jpg  (117505 B)... OK.
  sorted_data_merged/receipts/ЗНО0301571802_TS903168193-82157d01-f6e6-49f3-8e14-14e2de89bcfc.jpg  (130805 B)... OK.
  sorted_data_merged/receipts/ЗНО0301571971_TS903168176-f67d0422-8a7c-489e-89d9-9586e50e6165.jpg  (96245 B)... OK.
  sorted_data_merged/receipts/ЗНО0301571977_TS903168276-71f39c41-e268-4031-9721-2ae845054b9c.jpg  (99370 B)

In [13]:
df = pd.read_pickle('drive/MyDrive/df_markup.pkl')
df.columns

Index(['file_name', 'quality_photo', 'terminal', 'receipt', 'terminal_damaged',
       'terminal_undamaged', 'terminal_unrecognized_defect', 'other',
       'anticlass', 'hash'],
      dtype='object')

In [14]:
df = pd.read_pickle('drive/MyDrive/df_markup.pkl')
df = df.loc[
    (df['terminal']==1) & (df['quality_photo']==1) &
     (df['terminal_unrecognized_defect']==0),
    ['file_name', 'terminal', 'terminal_damaged']
]
df['file_name'] = df['file_name'].apply(lambda x: x.split('..\\data\\')[1].replace('\\', '/'))

In [15]:
df.shape

(4084, 3)

In [16]:
from PIL import UnidentifiedImageError
for file in tqdm(df['file_name']):
    try:
        img = Image.open(file)
    except UnidentifiedImageError:
        df = df[df['file_name']!=file]

100%|██████████| 4084/4084 [00:00<00:00, 5634.90it/s]


In [17]:
df.shape

(3921, 3)

In [18]:
train_data, test_data = train_test_split(
    df, stratify=df['terminal_damaged'],
    test_size=0.25
  )

In [19]:
# Augment train data
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

# Don't augment test data, only reshape
test_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

In [20]:
class CustomDataset(Dataset):
    def __init__(self, dataframe, target_column, transform=None):
        self.data = dataframe
        self.transform = transform
        self.target_column = target_column

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label = self.data.iloc[idx][self.target_column]
        img = Image.open(self.data.iloc[idx]['file_name']).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

In [21]:
NUM_WORKERS = os.cpu_count()
BATCH_SIZE = 64

train_dataset = CustomDataset(train_data, target_column='terminal_damaged', transform=train_transforms)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

test_dataset = CustomDataset(test_data, target_column='terminal_damaged', transform=test_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Efficient Attention

## Linformer

In [36]:
efficient_transformer = Linformer(
    dim=128,
    seq_len=65,  # 7x7 patches + 1 cls-token
    depth=12,
    heads=8,
    k=64
)

## Visual Transformer

In [37]:
model = ViT(
    dim=128,
    image_size=256,
    patch_size=32,
    num_classes=2,
    transformer=efficient_transformer,
    channels=3,
).to(device)

# Training

In [38]:
# Training settings
batch_size = 32
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42

In [39]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [41]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    epoch_f1 = 0

    for data, label in tqdm(train_dataloader, desc='Train: '):
        label = label.type(torch.LongTensor)
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        f1 = f1_score(label.cpu().numpy(), output.argmax(dim=1).cpu().numpy(), average='macro')

        epoch_f1 += f1 / len(train_dataloader)
        epoch_accuracy += acc / len(train_dataloader)
        epoch_loss += loss / len(train_dataloader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        epoch_val_f1 = 0
        for data, label in tqdm(test_dataloader, desc='Test: '):
            label = label.type(torch.LongTensor)
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            f1 = f1_score(label.cpu().numpy(), val_output.argmax(dim=1).cpu().numpy(), average='macro')
            epoch_val_accuracy += acc / len(test_dataloader)
            epoch_val_f1 += f1 / len(test_dataloader)
            epoch_val_loss += val_loss / len(test_dataloader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
    print(f"f1_train:\t{epoch_f1}\tf1_val:{epoch_val_f1}\n")

Train: 100%|██████████| 46/46 [01:28<00:00,  1.92s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.73s/it]


Epoch : 1 - loss : 0.5855 - acc: 0.7231 - val_loss : 0.5867 - val_acc: 0.7184

f1_train:	0.41931967605326903	f1_val:0.4176752674552263



Train: 100%|██████████| 46/46 [01:30<00:00,  1.97s/it]
Test: 100%|██████████| 16/16 [00:26<00:00,  1.65s/it]


Epoch : 2 - loss : 0.5744 - acc: 0.7231 - val_loss : 0.5872 - val_acc: 0.7184

f1_train:	0.4188408986226147	f1_val:0.4171827884570492



Train: 100%|██████████| 46/46 [01:29<00:00,  1.95s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.71s/it]


Epoch : 3 - loss : 0.5638 - acc: 0.7239 - val_loss : 0.5751 - val_acc: 0.7184

f1_train:	0.4235721083149791	f1_val:0.41739030799652055



Train: 100%|██████████| 46/46 [01:26<00:00,  1.87s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.74s/it]


Epoch : 4 - loss : 0.5443 - acc: 0.7299 - val_loss : 0.5934 - val_acc: 0.7243

f1_train:	0.47565539167657955	f1_val:0.43958898243388095



Train: 100%|██████████| 46/46 [01:24<00:00,  1.83s/it]
Test: 100%|██████████| 16/16 [00:28<00:00,  1.76s/it]


Epoch : 5 - loss : 0.5248 - acc: 0.7523 - val_loss : 0.5645 - val_acc: 0.7234

f1_train:	0.570881235506579	f1_val:0.5833982096113087



Train: 100%|██████████| 46/46 [01:28<00:00,  1.93s/it]
Test: 100%|██████████| 16/16 [00:26<00:00,  1.65s/it]


Epoch : 6 - loss : 0.5228 - acc: 0.7555 - val_loss : 0.5656 - val_acc: 0.7261

f1_train:	0.6009456331440886	f1_val:0.5907547438618169



Train: 100%|██████████| 46/46 [01:26<00:00,  1.88s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.72s/it]


Epoch : 7 - loss : 0.5010 - acc: 0.7701 - val_loss : 0.5619 - val_acc: 0.7390

f1_train:	0.6277964563045239	f1_val:0.5594162193227101



Train: 100%|██████████| 46/46 [01:26<00:00,  1.88s/it]
Test: 100%|██████████| 16/16 [00:26<00:00,  1.68s/it]


Epoch : 8 - loss : 0.4818 - acc: 0.7785 - val_loss : 0.5498 - val_acc: 0.7430

f1_train:	0.6615374738262929	f1_val:0.5870773231553928



Train: 100%|██████████| 46/46 [01:28<00:00,  1.92s/it]
Test: 100%|██████████| 16/16 [00:26<00:00,  1.68s/it]


Epoch : 9 - loss : 0.4695 - acc: 0.7867 - val_loss : 0.5782 - val_acc: 0.7360

f1_train:	0.6748222449703455	f1_val:0.6099007531712181



Train: 100%|██████████| 46/46 [01:24<00:00,  1.84s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.72s/it]


Epoch : 10 - loss : 0.4351 - acc: 0.8102 - val_loss : 0.5837 - val_acc: 0.7157

f1_train:	0.7286785387755579	f1_val:0.6281133101348765



Train: 100%|██████████| 46/46 [01:24<00:00,  1.85s/it]
Test: 100%|██████████| 16/16 [00:28<00:00,  1.76s/it]


Epoch : 11 - loss : 0.4234 - acc: 0.8142 - val_loss : 0.5972 - val_acc: 0.7399

f1_train:	0.7373095777486589	f1_val:0.5578564127241538



Train: 100%|██████████| 46/46 [01:26<00:00,  1.88s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.72s/it]


Epoch : 12 - loss : 0.3893 - acc: 0.8330 - val_loss : 0.6219 - val_acc: 0.7321

f1_train:	0.7647779927278063	f1_val:0.5839244769207402



Train: 100%|██████████| 46/46 [01:25<00:00,  1.87s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.73s/it]


Epoch : 13 - loss : 0.3651 - acc: 0.8486 - val_loss : 0.6562 - val_acc: 0.6882

f1_train:	0.7857525006332617	f1_val:0.6101212317235195



Train: 100%|██████████| 46/46 [01:25<00:00,  1.85s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.74s/it]


Epoch : 14 - loss : 0.3598 - acc: 0.8504 - val_loss : 0.6319 - val_acc: 0.7370

f1_train:	0.7896997498900751	f1_val:0.6004313711536168



Train: 100%|██████████| 46/46 [01:25<00:00,  1.87s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.74s/it]


Epoch : 15 - loss : 0.3413 - acc: 0.8503 - val_loss : 0.7647 - val_acc: 0.6216

f1_train:	0.7968801517589402	f1_val:0.5840528180676533



Train: 100%|██████████| 46/46 [01:25<00:00,  1.85s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.73s/it]


Epoch : 16 - loss : 0.2961 - acc: 0.8755 - val_loss : 0.7578 - val_acc: 0.7322

f1_train:	0.832265020230101	f1_val:0.610724136215598



Train: 100%|██████████| 46/46 [01:25<00:00,  1.87s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.73s/it]


Epoch : 17 - loss : 0.3080 - acc: 0.8727 - val_loss : 0.7169 - val_acc: 0.6980

f1_train:	0.8283063958679356	f1_val:0.623316940981075



Train: 100%|██████████| 46/46 [01:25<00:00,  1.85s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.74s/it]


Epoch : 18 - loss : 0.2640 - acc: 0.8933 - val_loss : 0.8056 - val_acc: 0.6997

f1_train:	0.8568057792319761	f1_val:0.6263857250408307



Train: 100%|██████████| 46/46 [01:24<00:00,  1.84s/it]
Test: 100%|██████████| 16/16 [00:27<00:00,  1.71s/it]


Epoch : 19 - loss : 0.2435 - acc: 0.9000 - val_loss : 0.7717 - val_acc: 0.6891

f1_train:	0.8682479941781156	f1_val:0.619214112113869



Train: 100%|██████████| 46/46 [01:25<00:00,  1.86s/it]
Test: 100%|██████████| 16/16 [00:28<00:00,  1.77s/it]

Epoch : 20 - loss : 0.2049 - acc: 0.9255 - val_loss : 0.9761 - val_acc: 0.6921

f1_train:	0.9020570092468378	f1_val:0.615394605084994




