In [None]:
# pip install wandb

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import wandb

from torch.autograd import Variable
from torchvision.models import resnet18, swin_v2_t

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

import PIL
import torch
import torchvision

import matplotlib.pyplot as plt
from sklearn import metrics
import tqdm
import json

import os

In [2]:
from PIL import Image
from sklearn import (manifold, datasets, decomposition, ensemble,
                     discriminant_analysis, random_projection)
import torchvision.transforms.functional as Function
from IPython.display import display
from time import time
from matplotlib import offsetbox
from sklearn.neighbors import DistanceMetric
%matplotlib inline

from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# !rm -rf -r /content/dataset/__MACOSX

In [4]:
def soft_max(array):
    return nn.Softmax(dim=1)(array)

class EmbedNet(nn.Module):
    def __init__(self, base_model, out_dim):
        super(EmbedNet, self).__init__()
        self.base_model = base_model
        self.base_model.fc = torch.nn.Linear(base_model.fc.in_features, 512)
        self.fc1 = torch.nn.Linear(512, 512)
        self.fc2 = torch.nn.Linear(512, 256)
        self.fc3 = torch.nn.Linear(256, out_dim)

    def forward(self, x):
        x = self.base_model(x)
        x = self.fc1(F.normalize(x))
        x = self.fc2(F.normalize(x))
        x = self.fc3(F.normalize(x))
        return soft_max(F.normalize(x))

In [5]:
data_dir = './data'

# Преобразования для изображений
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Случайное отражение по горизонтали
    transforms.RandomRotation(10),  # Случайное вращение на угол до 10 градусов
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Изменение яркости, контраста, насыщенности и оттенка
    transforms.RandomResizedCrop(224),  # Случайное изменение размера и обрезка изображения до 224x224
    transforms.ToTensor()
])

# Создание датасета из папки с изображениями
dataset = ImageFolder(root=data_dir, transform=transform)

# Разделение на тренировочный и тестовый наборы данных
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

batch_size = 128
# Создание DataLoader для тренировочного и тестового наборов данных
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

class_to_idx = dataset.class_to_idx
idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))

In [6]:
# Discard layers at the end of base network
encoder = resnet18(pretrained=True)

out_dim = 7

lr = 0.001

model = EmbedNet(encoder, out_dim).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)



In [7]:
def save_model(model, path):
    torch.save(model.state_dict(), path)

In [9]:
from tqdm import tqdm_notebook

In [10]:
num_epochs = 80

# wandb.init(
#     project="Gagarin-Hack",

#     config={
#     "learning_rate": lr,
#     "batch_size": batch_size,
#     "epochs": num_epochs,
#     }
# )

metrics_dict = {
    'Precision':[],
    'Recall':[],
    'F1-score':[],
    'Accuracy':[],
}

for epoch in range(num_epochs):
    model.train()
    for images, labels in tqdm_notebook(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

    # Оценка модели на тестовом наборе данных
    model.eval()
    predictions = []
    true_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    precision = precision_score(true_labels, predictions, average='weighted')
    recall = recall_score(true_labels, predictions, average='weighted')
    f1 = f1_score(true_labels, predictions, average='weighted')
    accuracy = accuracy_score(true_labels, predictions)


    print(f'Precision: {precision:.2f}')
    print(f'Recall: {recall:.2f}')
    print(f'F1-score: {f1:.2f}')
    print(f'Accuracy: {accuracy:.2f}')
    metrics_dict['Precision'].append(precision)
    metrics_dict['Recall'].append(recall)
    metrics_dict['F1-score'].append(f1)
    metrics_dict['Accuracy'].append(accuracy)

    save_model(model, f'./weights/tmp_weights_{epoch}_{precision:.2f}_{recall:.2f}_{f1:.2f}_{accuracy:.2f}.pt')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1/80, Loss: 1.8630971908569336
Precision: 0.12
Recall: 0.17
F1-score: 0.06
Accuracy: 0.17


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 2/80, Loss: 1.8237557411193848
Precision: 0.25
Recall: 0.17
F1-score: 0.07
Accuracy: 0.17


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 3/80, Loss: 1.8028260469436646
Precision: 0.75
Recall: 0.67
F1-score: 0.67
Accuracy: 0.67


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 4/80, Loss: 1.8025482892990112
Precision: 0.83
Recall: 0.43
F1-score: 0.43
Accuracy: 0.43


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 5/80, Loss: 1.7891885042190552
Precision: 0.82
Recall: 0.73
F1-score: 0.73
Accuracy: 0.73


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 6/80, Loss: 1.7822545766830444
Precision: 0.75
Recall: 0.54
F1-score: 0.50
Accuracy: 0.54


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 7/80, Loss: 1.7819074392318726
Precision: 0.80
Recall: 0.69
F1-score: 0.66
Accuracy: 0.69


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 8/80, Loss: 1.7892736196517944
Precision: 0.75
Recall: 0.69
F1-score: 0.68
Accuracy: 0.69


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 9/80, Loss: 1.7955609560012817
Precision: 0.81
Recall: 0.78
F1-score: 0.78
Accuracy: 0.78


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 10/80, Loss: 1.7769966125488281
Precision: 0.67
Recall: 0.43
F1-score: 0.42
Accuracy: 0.43


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 11/80, Loss: 1.7751200199127197
Precision: 0.83
Recall: 0.77
F1-score: 0.78
Accuracy: 0.77


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 12/80, Loss: 1.778002381324768
Precision: 0.77
Recall: 0.73
F1-score: 0.73
Accuracy: 0.73


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 13/80, Loss: 1.786521315574646
Precision: 0.75
Recall: 0.68
F1-score: 0.69
Accuracy: 0.68


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 14/80, Loss: 1.7805105447769165
Precision: 0.91
Recall: 0.88
F1-score: 0.89
Accuracy: 0.88


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 15/80, Loss: 1.7782549858093262
Precision: 0.76
Recall: 0.59
F1-score: 0.54
Accuracy: 0.59


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 16/80, Loss: 1.7878955602645874
Precision: 0.68
Recall: 0.69
F1-score: 0.65
Accuracy: 0.69


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 17/80, Loss: 1.783113718032837
Precision: 0.83
Recall: 0.66
F1-score: 0.66
Accuracy: 0.66


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 18/80, Loss: 1.7895444631576538
Precision: 0.83
Recall: 0.79
F1-score: 0.80
Accuracy: 0.79


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 19/80, Loss: 1.7748849391937256
Precision: 0.74
Recall: 0.62
F1-score: 0.60
Accuracy: 0.62


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 20/80, Loss: 1.7874232530593872
Precision: 0.57
Recall: 0.53
F1-score: 0.47
Accuracy: 0.53


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 21/80, Loss: 1.7785285711288452
Precision: 0.80
Recall: 0.71
F1-score: 0.67
Accuracy: 0.71


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 22/80, Loss: 1.8003644943237305
Precision: 0.90
Recall: 0.80
F1-score: 0.82
Accuracy: 0.80


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 23/80, Loss: 1.7822134494781494
Precision: 0.85
Recall: 0.80
F1-score: 0.81
Accuracy: 0.80


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 24/80, Loss: 1.7780746221542358
Precision: 0.74
Recall: 0.60
F1-score: 0.61
Accuracy: 0.60


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 25/80, Loss: 1.782075047492981
Precision: 0.67
Recall: 0.55
F1-score: 0.54
Accuracy: 0.55


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 26/80, Loss: 1.7821346521377563
Precision: 0.46
Recall: 0.39
F1-score: 0.33
Accuracy: 0.39


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 27/80, Loss: 1.787927508354187
Precision: 0.69
Recall: 0.66
F1-score: 0.64
Accuracy: 0.66


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 28/80, Loss: 1.7781058549880981
Precision: 0.85
Recall: 0.80
F1-score: 0.80
Accuracy: 0.80


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 29/80, Loss: 1.77950119972229
Precision: 0.87
Recall: 0.83
F1-score: 0.83
Accuracy: 0.83


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 30/80, Loss: 1.794848918914795
Precision: 0.77
Recall: 0.73
F1-score: 0.70
Accuracy: 0.73


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 31/80, Loss: 1.7790459394454956
Precision: 0.80
Recall: 0.77
F1-score: 0.77
Accuracy: 0.77


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 32/80, Loss: 1.7740669250488281
Precision: 0.81
Recall: 0.74
F1-score: 0.73
Accuracy: 0.74


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 33/80, Loss: 1.7897039651870728
Precision: 0.84
Recall: 0.76
F1-score: 0.76
Accuracy: 0.76


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 34/80, Loss: 1.7756924629211426
Precision: 0.90
Recall: 0.86
F1-score: 0.86
Accuracy: 0.86


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 35/80, Loss: 1.7939090728759766
Precision: 0.67
Recall: 0.52
F1-score: 0.53
Accuracy: 0.52


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 36/80, Loss: 1.781173825263977


In [None]:
model.eval()

In [None]:
img = Image.open('/content/driver_example.jpeg')
transformed_img = transform(img)

In [None]:
idx_to_class

In [None]:
model(transformed_img.to(device).unsqueeze(0))

In [None]:
torch.save(model.state_dict(), 'v1_weights.pt')