In [1]:
# !pip install torch torchvision

# !pip install -U transformers
# !pip install -U albumentations
# !pip install -U opencv-python
# !pip install -U scikit-learn
# !pip install -U Pillow
# !pip install -U tqdm
# !pip install -U pandas
# !pip install -U torchsummary
# !pip install timm
# !pip install ipywidgets

In [2]:
import os
import random
import pandas as pd
import numpy as np
from PIL import Image, ImageOps
from tqdm.notebook import tqdm
from itertools import product

from sklearn.metrics import roc_auc_score, f1_score
from sklearn.model_selection import train_test_split

import albumentations as A

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import DeiTFeatureExtractor, DeiTForImageClassification

from torchsummary import summary

In [4]:
def seed_everything(seed=1234):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything()

In [18]:
384*384/(16*16)

576.0

# Подготовка данных для обучения

In [5]:
root_dir = 'dataset'
batch_size = 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [6]:
data = pd.DataFrame([
    {'image_path': os.path.join(directory, filename), 'label': os.path.basename(directory)}
    for directory, _, filenames in os.walk(root_dir)
    for filename in filenames
    if os.path.basename(directory).isdigit()
])

In [7]:
data

Unnamed: 0,image_path,label
0,dataset/8443/4e5f5bdd-5d3d-45b3-9408-03f8d1f33...,8443
1,dataset/8443/7b530781-7900-4ae2-b387-a1efdb521...,8443
2,dataset/8443/5af9cd6e-518b-43f8-8f8d-aebf0cdec...,8443
3,dataset/8443/65b140e2-a0b5-4ee6-8c45-23bcfcf20...,8443
4,dataset/8443/4d00724c-a300-48ab-91e6-7f2ba698b...,8443
...,...,...
8995,dataset/13866/78fcc893-7d03-4aff-bf27-5c07ad4c...,13866
8996,dataset/13866/a716dcbd-91b7-4f75-adf3-30404bbc...,13866
8997,dataset/13866/97418098-5215-4863-8dcd-7ba48935...,13866
8998,dataset/13866/e7f08476-fa39-47ff-ba38-ff59d6d5...,13866


In [8]:
label2id = {value: i for i, value in enumerate(data['label'].unique())}
data['label_id'] = data['label'].map(label2id)

In [9]:
train, val, _, _ = train_test_split(data, data['label_id'], test_size=0.1)

In [10]:
class ProductDataset(Dataset):

    def __init__(self, meta, transform=None):
        self.meta = meta
        self.transform = transform

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = self.meta['image_path'].iloc[idx]
        image = Image.open(img_name).convert('RGB')

        if self.transform:
            image = self.transform(image = np.asarray(image))['image']
            image = Image.fromarray(image)

        image = feature_extractor(images=image, return_tensors="pt")['pixel_values'][0]
            
        cls = self.meta['label_id'].iloc[idx]

        return image, cls

In [11]:
train_transform = A.Compose([
        # A.GaussNoise(p=0.2),
        # A.OneOf([
        #     A.MotionBlur(p=.2),
        #     A.MedianBlur(blur_limit=3, p=0.1),
        #     A.Blur(blur_limit=3, p=0.1),
        # ], p=0.2),
        A.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.2, rotate_limit=30, p=0.75),
        # A.OneOf([
        #     A.OpticalDistortion(p=0.3),
        #     A.GridDistortion(p=.1),
        #     A.PiecewiseAffine(p=0.3),
        # ], p=0.2),
        # A.OneOf([
        #     A.CLAHE(clip_limit=2),
        #     A.Sharpen(),
        #     A.Emboss(),
        #     A.RandomBrightnessContrast(),            
        # ], p=0.3),
        # A.HueSaturationValue(p=0.3),
])

transform = None

  original_init(self, **validated_kwargs)


In [12]:
feature_extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-small-distilled-patch16-224')



In [13]:
train_dataset = ProductDataset(meta=train, transform=train_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12, pin_memory=True)

val_dataset = ProductDataset(meta=val, transform=transform)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=12, pin_memory=True)

# Создание модели и обучение

![clip](https://assets-global.website-files.com/5d7b77b063a9066d83e1209c/639b1df59b5ec8f6e5fdb8cf_transformer%20gif.gif)

In [14]:
model = DeiTForImageClassification.from_pretrained('facebook/deit-small-distilled-patch16-224')
model.classifier = nn.Linear(384, len(label2id))
model.to(device)

Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-small-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DeiTForImageClassification(
  (deit): DeiTModel(
    (embeddings): DeiTEmbeddings(
      (patch_embeddings): DeiTPatchEmbeddings(
        (projection): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DeiTEncoder(
      (layer): ModuleList(
        (0-11): 12 x DeiTLayer(
          (attention): DeiTSdpaAttention(
            (attention): DeiTSdpaSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): DeiTSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): DeiTIntermediate(
            (dense): L

In [15]:
EPOCHS = 5
lr = 1e-5

optimizer = optim.AdamW(model.parameters(), lr)

criterion = nn.CrossEntropyLoss()

In [16]:
for epoch in range(EPOCHS):
    
    model.train()
            
        
    train_loss = []
    train_predictions = []
    train_targets = []
    for batch, targets in tqdm(train_dataloader, desc=f"Epoch: {epoch}"):
        optimizer.zero_grad()
        
        batch = batch.to(device)
        targets = targets.to(device)
        
        predictions = model(batch).logits
        
        loss = criterion(predictions, targets) 
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())
        
        train_predictions.extend(predictions.cpu().detach().numpy().argmax(axis=1))
        train_targets.extend(targets.cpu().detach().numpy())
        
    print('Training loss:', np.mean(train_loss))
    print('Train f1:', f1_score(train_targets, train_predictions, average='weighted'))
    
    model.eval()
        
    val_predictions = []
    val_targets = []
    for batch, targets in tqdm(val_dataloader, desc=f"Epoch: {epoch}"):
        
        with torch.no_grad():
        
            batch = batch.to(device)
            targets = targets.to(device)
            predictions = model(batch).logits
            

            val_predictions.extend(predictions.cpu().numpy().argmax(axis=1))
            val_targets.extend(targets.cpu().numpy())
        
    print('Val f1:', f1_score(val_targets, val_predictions, average='weighted'))
    # model.save_pretrained('model')

Epoch: 0:   0%|          | 0/64 [00:00<?, ?it/s]

Training loss: 1.66400555241853
Train f1: 0.5148115335566549


Epoch: 0:   0%|          | 0/8 [00:00<?, ?it/s]

Val f1: 0.8441375839893437


Epoch: 1:   0%|          | 0/64 [00:00<?, ?it/s]

KeyboardInterrupt: 