In [None]:
### Импорт библиотек и настройка

In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
import os
import time
import copy
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from PIL import Image
from sklearn.metrics import confusion_matrix, f1_score
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

### Настройки для быстрого обучения

In [None]:
data_dir = '/home/sky/nn-project/images/archive_sp/CUB_200_2011_sp/images'
input_size = 224
batch_size = 32  # batch size
num_workers = 4  # Количество процессов для загрузки данных
num_epochs = 12  #количество эпох

In [24]:
device = torch.device("cpu")

### Трансформации

In [25]:
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

### Загрузка полного датасета

In [27]:
full_dataset = datasets.ImageFolder(data_dir)

# Разделение на train/val с сохранением стратификации
train_idx, val_idx = train_test_split(
    list(range(len(full_dataset))), 
    test_size=0.2, 
    random_state=42,
    stratify=full_dataset.targets
)

train_dataset = Subset(full_dataset, train_idx)
val_dataset = Subset(full_dataset, val_idx)

# Применяем трансформации
train_dataset.dataset.transform = data_transforms['train']
val_dataset.dataset.transform = data_transforms['val']

In [45]:
full_dataset.classes

['001.Black_footed_Albatross',
 '002.Laysan_Albatross',
 '003.Sooty_Albatross',
 '004.Groove_billed_Ani',
 '005.Crested_Auklet',
 '006.Least_Auklet',
 '007.Parakeet_Auklet',
 '008.Rhinoceros_Auklet',
 '009.Brewer_Blackbird',
 '010.Red_winged_Blackbird',
 '011.Rusty_Blackbird',
 '012.Yellow_headed_Blackbird',
 '013.Bobolink',
 '014.Indigo_Bunting',
 '015.Lazuli_Bunting',
 '016.Painted_Bunting',
 '017.Cardinal',
 '018.Spotted_Catbird',
 '019.Gray_Catbird',
 '020.Yellow_breasted_Chat',
 '021.Eastern_Towhee',
 '022.Chuck_will_Widow',
 '023.Brandt_Cormorant',
 '024.Red_faced_Cormorant',
 '025.Pelagic_Cormorant',
 '026.Bronzed_Cowbird',
 '027.Shiny_Cowbird',
 '028.Brown_Creeper',
 '029.American_Crow',
 '030.Fish_Crow',
 '031.Black_billed_Cuckoo',
 '032.Mangrove_Cuckoo',
 '033.Yellow_billed_Cuckoo',
 '034.Gray_crowned_Rosy_Finch',
 '035.Purple_Finch',
 '036.Northern_Flicker',
 '037.Acadian_Flycatcher',
 '038.Great_Crested_Flycatcher',
 '039.Least_Flycatcher',
 '040.Olive_sided_Flycatcher',
 '

### Применяем трансформации

In [28]:
train_dataset.dataset.transform = data_transforms['train']
val_dataset.dataset.transform = data_transforms['val']

### DataLoader

In [None]:
dataloaders = {
    'train': torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size,
        shuffle=True, num_workers=num_workers,
        pin_memory=True if torch.cuda.is_available() else False
    ),
    'val': torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size,
        shuffle=False, num_workers=num_workers,
        pin_memory=True if torch.cuda.is_available() else False
    )
}

dataset_sizes = {
    'train': len(train_dataset),
    'val': len(val_dataset)
}

class_names = full_dataset.classes
num_classes = len(class_names)
print(f"Number of classes: {num_classes}")
print(f"Training set size: {dataset_sizes['train']}")
print(f"Validaфункция обученияtion set size: {dataset_sizes['val']}")

Number of classes: 200
Training set size: 9430
Validation set size: 2358


In [44]:
import json
with open('class_names.json', 'w') as f: json.dump(class_names, f) 
print(f"Number of classes: {num_classes}") 
print(f"Training set size: {dataset_sizes['train']}") 
print(f"Validation set size: {dataset_sizes['val']}")

Number of classes: 200
Training set size: 9430
Validation set size: 2358


### функция обучения

In [None]:
def train_model_efficient(model, criterion, optimizer, scheduler, num_epochs=12):
    since = time.time()
    
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs-1}')
        print('-' * 10)
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            
            if phase == 'train':
                scheduler.step()
            
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed//60:.0f}m {time_elapsed%60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    
    model.load_state_dict(best_model_wts)
    return model

### Используем EfficientNet-B0 

In [33]:
model_ft = models.efficientnet_b0(pretrained=True)



Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /home/sky/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:02<00:00, 9.68MB/s]


### Модифицируем последний слой

In [34]:
num_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(num_ftrs, num_classes)
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

### Оптимизатор адам

In [35]:
optimizer_ft = optim.AdamW(model_ft.parameters(), lr=0.001, weight_decay=0.01)

### Планировщик с warmup

In [36]:
scheduler = lr_scheduler.OneCycleLR(
    optimizer_ft, 
    max_lr=0.01,
    steps_per_epoch=len(dataloaders['train']),
    epochs=num_epochs,
    anneal_strategy='linear'
)

### Обучение модели

In [37]:
print("Starting training with full dataset...")
model_ft = train_model_efficient(model_ft, criterion, optimizer_ft, scheduler, num_epochs)

Starting training with full dataset...
Epoch 0/9
----------
train Loss: 3.1062 Acc: 0.3409
val Loss: 1.2125 Acc: 0.6713

Epoch 1/9
----------
train Loss: 1.0003 Acc: 0.7420
val Loss: 0.7835 Acc: 0.7841

Epoch 2/9
----------
train Loss: 0.5104 Acc: 0.8668
val Loss: 0.7326 Acc: 0.7880

Epoch 3/9
----------
train Loss: 0.3119 Acc: 0.9165
val Loss: 0.7419 Acc: 0.7939

Epoch 4/9
----------
train Loss: 0.1926 Acc: 0.9496
val Loss: 0.7527 Acc: 0.8003

Epoch 5/9
----------
train Loss: 0.1504 Acc: 0.9611
val Loss: 0.8435 Acc: 0.7765

Epoch 6/9
----------
train Loss: 0.1452 Acc: 0.9616
val Loss: 0.9498 Acc: 0.7663

Epoch 7/9
----------
train Loss: 0.1456 Acc: 0.9601
val Loss: 0.9193 Acc: 0.7668

Epoch 8/9
----------
train Loss: 0.1096 Acc: 0.9699
val Loss: 1.0037 Acc: 0.7663

Epoch 9/9
----------
train Loss: 0.1299 Acc: 0.9648
val Loss: 0.9656 Acc: 0.7680

Training complete in 76m 1s
Best val Acc: 0.8003


### Сохраняем модель

In [38]:
torch.save(model_ft.state_dict(), 'fast_bird_classifier_mobilenet.pth')
print("Model saved to fast_bird_classifier_mobilenet.pth")

Model saved to fast_bird_classifier_mobilenet.pth


### Функция для предсказания

In [39]:
def predict_image(image_path, model, class_names, transform):
    image = Image.open(image_path)
    image = transform(image).float()
    image = image.unsqueeze(0)  # Add batch dimension
    image = image.to(device)
    
    model.eval()
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output.data, 1)
        return class_names[predicted.item()]