In [None]:
# pip install wandb

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import wandb

from torch.autograd import Variable
from torchvision.models import resnet18, swin_v2_t

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

import PIL
import torch
import torchvision

import matplotlib.pyplot as plt
from sklearn import metrics
import tqdm
import json

import os

In [2]:
from PIL import Image
from sklearn import (manifold, datasets, decomposition, ensemble,
                     discriminant_analysis, random_projection)
import torchvision.transforms.functional as Function
from IPython.display import display
from time import time
from matplotlib import offsetbox
from sklearn.neighbors import DistanceMetric
%matplotlib inline

from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# !rm -rf -r /content/dataset/__MACOSX

In [4]:
def soft_max(array):
    return nn.Softmax(dim=1)(array)

class EmbedNet(nn.Module):
    def __init__(self, base_model, out_dim):
        super(EmbedNet, self).__init__()
        self.base_model = base_model
        self.base_model.fc = torch.nn.Linear(base_model.fc.in_features, 512)
        self.fc1 = torch.nn.Linear(512, 512)
        self.fc2 = torch.nn.Linear(512, 256)
        self.fc3 = torch.nn.Linear(256, out_dim)

    def forward(self, x):
        x = self.base_model(x)
        x = self.fc1(F.normalize(x))
        x = self.fc2(F.normalize(x))
        x = self.fc3(F.normalize(x))
        return soft_max(F.normalize(x))

In [5]:
data_dir = './data'

# Преобразования для изображений
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Случайное отражение по горизонтали
    transforms.RandomRotation(10),  # Случайное вращение на угол до 10 градусов
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Изменение яркости, контраста, насыщенности и оттенка
    transforms.RandomResizedCrop(224),  # Случайное изменение размера и обрезка изображения до 224x224
    transforms.ToTensor()
])

# Создание датасета из папки с изображениями
dataset = ImageFolder(root=data_dir, transform=transform)

# Разделение на тренировочный и тестовый наборы данных
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

batch_size = 128
# Создание DataLoader для тренировочного и тестового наборов данных
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

class_to_idx = dataset.class_to_idx
idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))

In [6]:
# Discard layers at the end of base network
encoder = resnet18(pretrained=True)

out_dim = 7

lr = 0.001

model = EmbedNet(encoder, out_dim).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)



In [7]:
def save_model(model, path):
    torch.save(model.state_dict(), path)

In [8]:
from tqdm import tqdm_notebook

In [9]:
num_epochs = 80

# wandb.init(
#     project="Gagarin-Hack",

#     config={
#     "learning_rate": lr,
#     "batch_size": batch_size,
#     "epochs": num_epochs,
#     }
# )

metrics_dict = {
    'Precision':[],
    'Recall':[],
    'F1-score':[],
    'Accuracy':[],
}

for epoch in range(num_epochs):
    model.train()
    for images, labels in tqdm_notebook(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

    # Оценка модели на тестовом наборе данных
    model.eval()
    predictions = []
    true_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    precision = precision_score(true_labels, predictions, average='weighted')
    recall = recall_score(true_labels, predictions, average='weighted')
    f1 = f1_score(true_labels, predictions, average='weighted')
    accuracy = accuracy_score(true_labels, predictions)


    print(f'Precision: {precision:.2f}')
    print(f'Recall: {recall:.2f}')
    print(f'F1-score: {f1:.2f}')
    print(f'Accuracy: {accuracy:.2f}')
    metrics_dict['Precision'].append(precision)
    metrics_dict['Recall'].append(recall)
    metrics_dict['F1-score'].append(f1)
    metrics_dict['Accuracy'].append(accuracy)

    save_model(model, f'./weights/tmp_weights_{epoch}_{precision:.2f}_{recall:.2f}_{f1:.2f}_{accuracy:.2f}.pt')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1/80, Loss: 1.84670090675354
Precision: 0.43
Recall: 0.37
F1-score: 0.28
Accuracy: 0.37


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 2/80, Loss: 1.8169664144515991
Precision: 0.63
Recall: 0.64
F1-score: 0.59
Accuracy: 0.64


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 3/80, Loss: 1.7936105728149414
Precision: 0.87
Recall: 0.85
F1-score: 0.85
Accuracy: 0.85


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 4/80, Loss: 1.8036980628967285
Precision: 0.80
Recall: 0.69
F1-score: 0.69
Accuracy: 0.69


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 5/80, Loss: 1.7825534343719482
Precision: 0.46
Recall: 0.40
F1-score: 0.26
Accuracy: 0.40


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 6/80, Loss: 1.788064956665039
Precision: 0.76
Recall: 0.70
F1-score: 0.66
Accuracy: 0.70


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 7/80, Loss: 1.7889610528945923
Precision: 0.77
Recall: 0.68
F1-score: 0.67
Accuracy: 0.68


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 8/80, Loss: 1.783137321472168
Precision: 0.81
Recall: 0.70
F1-score: 0.69
Accuracy: 0.70


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 9/80, Loss: 1.7871310710906982
Precision: 0.77
Recall: 0.62
F1-score: 0.59
Accuracy: 0.62


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 10/80, Loss: 1.7817275524139404
Precision: 0.48
Recall: 0.36
F1-score: 0.31
Accuracy: 0.36


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 11/80, Loss: 1.7779988050460815
Precision: 0.74
Recall: 0.62
F1-score: 0.59
Accuracy: 0.62


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 12/80, Loss: 1.7854958772659302
Precision: 0.75
Recall: 0.64
F1-score: 0.64
Accuracy: 0.64


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 13/80, Loss: 1.7827168703079224
Precision: 0.81
Recall: 0.79
F1-score: 0.79
Accuracy: 0.79


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 14/80, Loss: 1.7840162515640259
Precision: 0.88
Recall: 0.86
F1-score: 0.86
Accuracy: 0.86


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 15/80, Loss: 1.7838174104690552
Precision: 0.86
Recall: 0.83
F1-score: 0.83
Accuracy: 0.83


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 16/80, Loss: 1.7825819253921509
Precision: 0.86
Recall: 0.81
F1-score: 0.82
Accuracy: 0.81


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 17/80, Loss: 1.773749589920044
Precision: 0.85
Recall: 0.80
F1-score: 0.81
Accuracy: 0.80


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 18/80, Loss: 1.7926571369171143
Precision: 0.81
Recall: 0.76
F1-score: 0.75
Accuracy: 0.76


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 19/80, Loss: 1.7797003984451294
Precision: 0.81
Recall: 0.70
F1-score: 0.68
Accuracy: 0.70


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 20/80, Loss: 1.7867021560668945
Precision: 0.83
Recall: 0.74
F1-score: 0.72
Accuracy: 0.74


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 21/80, Loss: 1.7809836864471436
Precision: 0.82
Recall: 0.69
F1-score: 0.70
Accuracy: 0.69


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 22/80, Loss: 1.7867751121520996
Precision: 0.86
Recall: 0.83
F1-score: 0.83
Accuracy: 0.83


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 23/80, Loss: 1.7829197645187378
Precision: 0.87
Recall: 0.85
F1-score: 0.85
Accuracy: 0.85


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 24/80, Loss: 1.7829855680465698
Precision: 0.82
Recall: 0.79
F1-score: 0.78
Accuracy: 0.79


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 25/80, Loss: 1.7835559844970703
Precision: 0.88
Recall: 0.87
F1-score: 0.87
Accuracy: 0.87


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 26/80, Loss: 1.7801543474197388
Precision: 0.86
Recall: 0.84
F1-score: 0.84
Accuracy: 0.84


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 27/80, Loss: 1.7859997749328613
Precision: 0.83
Recall: 0.81
F1-score: 0.80
Accuracy: 0.81


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 28/80, Loss: 1.785724401473999
Precision: 0.81
Recall: 0.76
F1-score: 0.76
Accuracy: 0.76


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 29/80, Loss: 1.7738405466079712
Precision: 0.83
Recall: 0.74
F1-score: 0.73
Accuracy: 0.74


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 30/80, Loss: 1.7789286375045776
Precision: 0.76
Recall: 0.69
F1-score: 0.67
Accuracy: 0.69


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 31/80, Loss: 1.7771087884902954
Precision: 0.78
Recall: 0.73
F1-score: 0.71
Accuracy: 0.73


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 32/80, Loss: 1.8015270233154297
Precision: 0.82
Recall: 0.74
F1-score: 0.72
Accuracy: 0.74


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 33/80, Loss: 1.7742921113967896
Precision: 0.68
Recall: 0.69
F1-score: 0.66
Accuracy: 0.69


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 34/80, Loss: 1.769040584564209
Precision: 0.85
Recall: 0.73
F1-score: 0.72
Accuracy: 0.73


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 35/80, Loss: 1.7787376642227173
Precision: 0.81
Recall: 0.61
F1-score: 0.61
Accuracy: 0.61


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 36/80, Loss: 1.790182113647461
Precision: 0.61
Recall: 0.47
F1-score: 0.42
Accuracy: 0.47


  _warn_prf(average, modifier, msg_start, len(result))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 37/80, Loss: 1.7848687171936035
Precision: 0.72
Recall: 0.60
F1-score: 0.56
Accuracy: 0.60


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 38/80, Loss: 1.7902636528015137
Precision: 0.71
Recall: 0.58
F1-score: 0.53
Accuracy: 0.58


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 39/80, Loss: 1.7714089155197144
Precision: 0.79
Recall: 0.73
F1-score: 0.73
Accuracy: 0.73


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 40/80, Loss: 1.7816380262374878
Precision: 0.88
Recall: 0.84
F1-score: 0.84
Accuracy: 0.84


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 41/80, Loss: 1.7750623226165771
Precision: 0.67
Recall: 0.60
F1-score: 0.57
Accuracy: 0.60


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 42/80, Loss: 1.775952696800232
Precision: 0.82
Recall: 0.68
F1-score: 0.65
Accuracy: 0.68


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 43/80, Loss: 1.779378056526184
Precision: 0.81
Recall: 0.66
F1-score: 0.65
Accuracy: 0.66


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 44/80, Loss: 1.7909910678863525
Precision: 0.92
Recall: 0.87
F1-score: 0.88
Accuracy: 0.87


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 45/80, Loss: 1.7849984169006348
Precision: 0.90
Recall: 0.88
F1-score: 0.87
Accuracy: 0.88


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 46/80, Loss: 1.773758888244629
Precision: 0.83
Recall: 0.79
F1-score: 0.80
Accuracy: 0.79


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 47/80, Loss: 1.7830780744552612
Precision: 0.91
Recall: 0.91
F1-score: 0.91
Accuracy: 0.91


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 48/80, Loss: 1.7781997919082642
Precision: 0.91
Recall: 0.91
F1-score: 0.91
Accuracy: 0.91


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 49/80, Loss: 1.784157156944275
Precision: 0.88
Recall: 0.88
F1-score: 0.88
Accuracy: 0.88


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 50/80, Loss: 1.7679858207702637
Precision: 0.82
Recall: 0.80
F1-score: 0.80
Accuracy: 0.80


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 51/80, Loss: 1.7725638151168823
Precision: 0.88
Recall: 0.87
F1-score: 0.87
Accuracy: 0.87


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 52/80, Loss: 1.7758312225341797
Precision: 0.90
Recall: 0.89
F1-score: 0.89
Accuracy: 0.89


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 53/80, Loss: 1.785468339920044
Precision: 0.89
Recall: 0.86
F1-score: 0.86
Accuracy: 0.86


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 54/80, Loss: 1.7716472148895264
Precision: 0.81
Recall: 0.60
F1-score: 0.59
Accuracy: 0.60


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 55/80, Loss: 1.7792022228240967
Precision: 0.74
Recall: 0.64
F1-score: 0.61
Accuracy: 0.64


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 56/80, Loss: 1.773242473602295
Precision: 0.83
Recall: 0.82
F1-score: 0.82
Accuracy: 0.82


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 57/80, Loss: 1.7759723663330078
Precision: 0.83
Recall: 0.75
F1-score: 0.74
Accuracy: 0.75


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 58/80, Loss: 1.7788724899291992
Precision: 0.86
Recall: 0.82
F1-score: 0.82
Accuracy: 0.82


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 59/80, Loss: 1.7822034358978271
Precision: 0.86
Recall: 0.83
F1-score: 0.84
Accuracy: 0.83


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 60/80, Loss: 1.7819104194641113
Precision: 0.82
Recall: 0.80
F1-score: 0.80
Accuracy: 0.80


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 61/80, Loss: 1.7790107727050781
Precision: 0.78
Recall: 0.70
F1-score: 0.69
Accuracy: 0.70


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 62/80, Loss: 1.785627841949463
Precision: 0.78
Recall: 0.69
F1-score: 0.67
Accuracy: 0.69


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 63/80, Loss: 1.7806072235107422
Precision: 0.67
Recall: 0.60
F1-score: 0.56
Accuracy: 0.60


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 64/80, Loss: 1.7793469429016113
Precision: 0.79
Recall: 0.64
F1-score: 0.66
Accuracy: 0.64


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 65/80, Loss: 1.7758866548538208
Precision: 0.85
Recall: 0.80
F1-score: 0.81
Accuracy: 0.80


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 66/80, Loss: 1.7752060890197754
Precision: 0.83
Recall: 0.76
F1-score: 0.74
Accuracy: 0.76


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 67/80, Loss: 1.7754265069961548
Precision: 0.79
Recall: 0.69
F1-score: 0.68
Accuracy: 0.69


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 68/80, Loss: 1.779064416885376
Precision: 0.86
Recall: 0.83
F1-score: 0.82
Accuracy: 0.83


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 69/80, Loss: 1.7799791097640991
Precision: 0.89
Recall: 0.87
F1-score: 0.87
Accuracy: 0.87


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 70/80, Loss: 1.7788808345794678
Precision: 0.91
Recall: 0.88
F1-score: 0.88
Accuracy: 0.88


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 71/80, Loss: 1.778250813484192
Precision: 0.86
Recall: 0.83
F1-score: 0.82
Accuracy: 0.83


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 72/80, Loss: 1.774131178855896
Precision: 0.94
Recall: 0.93
F1-score: 0.93
Accuracy: 0.93


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 73/80, Loss: 1.7817431688308716
Precision: 0.90
Recall: 0.88
F1-score: 0.88
Accuracy: 0.88


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 74/80, Loss: 1.774276852607727
Precision: 0.90
Recall: 0.88
F1-score: 0.88
Accuracy: 0.88


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 75/80, Loss: 1.7828173637390137
Precision: 0.89
Recall: 0.88
F1-score: 0.89
Accuracy: 0.88


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 76/80, Loss: 1.7767152786254883
Precision: 0.90
Recall: 0.85
F1-score: 0.86
Accuracy: 0.85


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 77/80, Loss: 1.7797940969467163
Precision: 0.84
Recall: 0.74
F1-score: 0.74
Accuracy: 0.74


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 78/80, Loss: 1.7728550434112549
Precision: 0.85
Recall: 0.64
F1-score: 0.66
Accuracy: 0.64


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 79/80, Loss: 1.7787647247314453
Precision: 0.90
Recall: 0.88
F1-score: 0.87
Accuracy: 0.88


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, labels in tqdm_notebook(train_loader):


  0%|          | 0/4 [00:00<?, ?it/s]

Epoch 80/80, Loss: 1.7836090326309204
Precision: 0.90
Recall: 0.88
F1-score: 0.88
Accuracy: 0.88


In [None]:
model.eval()

In [None]:
img = Image.open('/content/driver_example.jpeg')
transformed_img = transform(img)

In [None]:
idx_to_class

In [None]:
model(transformed_img.to(device).unsqueeze(0))

In [None]:
torch.save(model.state_dict(), 'v1_weights.pt')