In [1]:
import time
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from sklearn import metrics 
import wandb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
def calc_metrics(model_predictions, true_values, threshold=0.5):
    # print(model_predictions)
    model_lables = [1 if i > threshold else 0 for i in model_predictions]
    acc = metrics.accuracy_score(model_lables, true_values)
    precision = metrics.precision_score(model_lables, true_values, average='weighted')
    recall = metrics.recall_score(model_lables, true_values, average='weighted')

    return {"accuracy": acc, "recall": recall, "precision": precision}

In [3]:
# Задаем путь к папке с изображениями
data_dir = './dataset'

# Задаем размер изображений
image_size = 224

# Определяем трансформации, которые будут применяться к изображениям
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)), 
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Создаем датасеты для каждой папки
logos_dataset = ImageFolder(root=data_dir, transform=transform)

# Задаем размеры для каждого набора данных
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 1 - train_ratio - val_ratio

# Разделяем данные на train, val и test наборы
logos_train_data, logos_val_data, logos_test_data = random_split(logos_dataset, [train_ratio, val_ratio, test_ratio])

# Создаем загрузчики данных
# train_loader = DataLoader(logos_train_data, batch_size=32, shuffle=True)
# val_loader = DataLoader(logos_val_data, batch_size=32, shuffle=False)
# test_loader = DataLoader(logos_test_data, batch_size=32, shuffle=False)

sets = ['train', 'val', 'test']

datasets = {
    'train': logos_train_data,
    'val': logos_val_data,
    'test': logos_test_data
}

dataloaders = {i: DataLoader(datasets[i], batch_size=32, shuffle=True) for i in datasets.keys()}

In [4]:
class LinearNet(nn.Module):
    def __init__(self, backbone):
        super(LinearNet, self).__init__()
        
        self.backbone = backbone
        self.backbone.fc = nn.Linear(512, 512)
        self.fc = nn.Linear(512, 1)
        self.sig = nn.Sigmoid()
        
    def forward(self, x):
        x = self.backbone(F.normalize(x))
        x = self.fc(F.normalize(x))
        return self.sig(x)

In [5]:
resnet = torchvision.models.resnet18(pretrained=True)
model = LinearNet(resnet)



In [8]:
criterion = nn.BCEWithLogitsLoss()
lr = 0.01
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.95)
epochs = 10

# pbar = tqdm.tqdm(total=wandb_config['epochs'])
best_val_loss = 1e9

for epoch in range(epochs):
  # обучение
  print(f"Epoch {epoch}/{epochs}...")
  epoch_loss = 0
  model.train()
  for train_image, train_label in tqdm(dataloaders['train']):
    train_image = train_image.to(device)
    train_label = train_label.float().to(device).view(-1, 1)
    output_train = model(train_image)
    loss = criterion(output_train, train_label)
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()
    epoch_loss += train_image.size(0) * loss.item()
  print(f"Train loss: {epoch_loss/len(dataloaders['train'])}")
  
  model.eval()
  # валидация
  val_loss = 0
  predictions = []
  lables = []
  for val_image, val_label in tqdm(dataloaders['val']):
    val_image = val_image.to(device)
    val_label = val_label.float().to(device).view(-1, 1)
    output_val = model(val_image)
  
    predictions += output_val.squeeze().tolist()
    lables += val_label.squeeze().tolist()
    
    loss_ = criterion(output_val, val_label)
    val_loss += val_image.size(0) * loss_
  val_loss /= len(dataloaders['val'])

  metrics_dict = calc_metrics(predictions, lables, threshold=0.5)
  metrics_dict['loss'] = epoch_loss/len(dataloaders['train'])
  metrics_dict['val_loss'] = val_loss
  # wandb.log(metrics_dict)

  print(f"Val loss: {val_loss}")
  print(f"Val metrics: {metrics}")
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(model.state_dict(), f'./weights/best_weihts_{val_loss}.pt')
  if (epoch+1)%5==0:
    torch.save(model.state_dict(), f'./weights/{epoch}_weihts_{val_loss}.pt')

  
    

Epoch 0/10...


100%|██████████| 26/26 [01:01<00:00,  2.37s/it]


Train loss: 22.225483542451492


100%|██████████| 4/4 [00:04<00:00,  1.02s/it]
  _warn_prf(average, modifier, msg_start, len(result))


Val loss: 17.649717330932617
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>
Epoch 1/10...


100%|██████████| 26/26 [01:00<00:00,  2.34s/it]


Train loss: 21.832736278955753


100%|██████████| 4/4 [00:04<00:00,  1.02s/it]
  _warn_prf(average, modifier, msg_start, len(result))


Val loss: 17.32621192932129
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>
Epoch 2/10...


100%|██████████| 26/26 [01:01<00:00,  2.35s/it]


Train loss: 21.33001951987927


100%|██████████| 4/4 [00:04<00:00,  1.02s/it]
  _warn_prf(average, modifier, msg_start, len(result))


Val loss: 16.868223190307617
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>
Epoch 3/10...


100%|██████████| 26/26 [01:00<00:00,  2.33s/it]


Train loss: 20.663371425408585


100%|██████████| 4/4 [00:04<00:00,  1.08s/it]


Val loss: 16.185617446899414
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>
Epoch 4/10...


100%|██████████| 26/26 [00:59<00:00,  2.30s/it]


Train loss: 19.75419828066459


100%|██████████| 4/4 [00:03<00:00,  1.03it/s]


Val loss: 15.277521133422852
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>
Epoch 5/10...


100%|██████████| 26/26 [01:01<00:00,  2.35s/it]


Train loss: 18.63598100726421


100%|██████████| 4/4 [00:03<00:00,  1.02it/s]


Val loss: 14.52774715423584
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>
Epoch 6/10...


100%|██████████| 26/26 [01:01<00:00,  2.35s/it]


Train loss: 17.827528155767002


100%|██████████| 4/4 [00:04<00:00,  1.03s/it]


Val loss: 14.04560661315918
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>
Epoch 7/10...


100%|██████████| 26/26 [01:01<00:00,  2.36s/it]


Train loss: 17.29706501502257


100%|██████████| 4/4 [00:03<00:00,  1.01it/s]


Val loss: 13.688220977783203
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>
Epoch 8/10...


100%|██████████| 26/26 [01:06<00:00,  2.56s/it]


Train loss: 16.903294989695915


100%|██████████| 4/4 [00:06<00:00,  1.50s/it]


Val loss: 13.429288864135742
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>
Epoch 9/10...


100%|██████████| 26/26 [01:21<00:00,  3.14s/it]


Train loss: 16.654782658586136


100%|██████████| 4/4 [00:05<00:00,  1.49s/it]


Val loss: 13.251705169677734
Val metrics: <module 'sklearn.metrics' from '/Users/kuprik01/Projects/LogoGenaPipline/.venv/lib/python3.11/site-packages/sklearn/metrics/__init__.py'>


wandb: Network error (ConnectionError), entering retry loop.


In [6]:
model.load_state_dict(torch.load('weights/best_weihts_13.251705169677734.pt'))

<All keys matched successfully>

In [7]:
model.eval()

predictions = []
lables = []

for val_image, val_label in tqdm(dataloaders['test']):
    val_image = val_image.to(device)
    val_label = val_label.float().to(device).view(-1, 1)
    output_val = model(val_image)
    predictions += output_val.squeeze().tolist()
    lables += val_label.squeeze().tolist()

calc_metrics(predictions, lables, threshold=0.5)

100%|██████████| 4/4 [00:04<00:00,  1.01s/it]


{'accuracy': 1.0, 'recall': 1.0, 'precision': 1.0}

In [8]:
from PIL import Image
import os

def calc_score(ex):
    return model(transform(ex).unsqueeze(0).to(device))

scores = []
for filename in os.listdir('./0'):
    ex = Image.open(os.path.join('./0', filename)).convert('RGB')
    scores.append((model(transform(ex).unsqueeze(0).to(device))[0][0].item(), filename))

sorted(scores)

[(0.06779826432466507, 'example8.png'),
 (0.09797636419534683, 'example3.png'),
 (0.10543155670166016, 'example2.png'),
 (0.1177564412355423, 'example7.png'),
 (0.11834253370761871, 'example1.png'),
 (0.3665633201599121, 'example9.png'),
 (0.6575329303741455, 'example4.png'),
 (0.7490559816360474, 'example5.png'),
 (0.7507413625717163, 'example6.png')]