# training_part

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torch.utils.data import DataLoader, random_split
from PIL import ImageEnhance, ImageOps, Image
from sklearn.metrics import f1_score
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import PIL
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

# Data augmentations
class NightTransform:
    """Darken a PIL image to imitate a night-time shot.

    Args:
        factor: Brightness multiplier passed to
            ``ImageEnhance.Brightness(...).enhance``. The default of 0.3 is
            the value this augmentation always used (a darker image).
    """

    def __init__(self, factor=0.3):
        self.factor = factor

    def __call__(self, img):
        """Return a copy of ``img`` with its brightness scaled by ``factor``."""
        return ImageEnhance.Brightness(img).enhance(self.factor)

    def __repr__(self):
        # Makes the transform readable when a pipeline containing it is printed.
        return f"{type(self).__name__}(factor={self.factor})"

class BlackWhiteTransform:
    """Turn a PIL image into a black-and-white picture that keeps three channels."""

    def __call__(self, img):
        """Return the grayscale version of ``img`` converted back to RGB mode."""
        # Drop the colour information first ...
        grey = ImageOps.grayscale(img)
        # ... then go back to "RGB" so later steps still receive 3 channels.
        return grey.convert("RGB")

# Training pipeline: geometric augmentation, two optional photometric
# augmentations, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.RandomResizedCrop(224),
    # Each of the two custom augmentations is applied with probability 20%.
    # RandomApply is used instead of Lambda(lambda ...) + np.random.rand():
    # lambdas cannot be pickled for DataLoader workers started via "spawn",
    # and NumPy's global RNG state may be duplicated across worker processes,
    # whereas RandomApply draws from torch's per-worker seeded generator.
    transforms.RandomApply([BlackWhiteTransform()], p=0.2),
    transforms.RandomApply([NightTransform()], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Validation pipeline: deterministic resize + center crop, same normalization.
transform_val = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Helper used to filter dataset files
def is_image_file(filename):
    """Return True if ``filename`` ends with .jpeg, .jpg or .png (case-insensitive)."""
    return filename.lower().endswith(('.jpeg', '.jpg', '.png'))

# Custom Dataset class that tolerates unreadable image files
class CustomImageFolder(datasets.ImageFolder):
    """``ImageFolder`` that substitutes the next readable sample for a broken one."""

    def __getitem__(self, index):
        """Return the sample at ``index``, or the next one that can be loaded.

        When loading raises ``PIL.UnidentifiedImageError`` or ``OSError`` the
        following index (wrapping around at the end) is tried instead. At most
        ``len(self)`` samples are attempted, so a dataset whose files are all
        unreadable fails with a clear error instead of recursing without bound.

        Raises:
            RuntimeError: if none of the attempted samples could be loaded.
        """
        attempts_left = len(self)
        candidate = index
        while True:
            try:
                # The first attempt uses ``index`` untouched, so an out-of-range
                # index still surfaces exactly as it does in the parent class.
                return super().__getitem__(candidate)
            except (PIL.UnidentifiedImageError, OSError) as err:
                print(f"Corrupted image detected at index {candidate}. Skipping.")
                attempts_left -= 1
                if attempts_left <= 0:
                    raise RuntimeError(
                        "No readable image found: every attempted sample failed to load."
                    ) from err
                candidate = (candidate + 1) % len(self)

# Data loading
data_dir = 'train_dataset'
batch_size = 32

# Two views over the same files, filtered with is_image_file: one applies the
# training augmentations, the other the deterministic validation transform.
# A single shared dataset cannot serve both splits, because both subsets
# returned by random_split point at the same underlying object and therefore
# at the same ``transform`` attribute.
full_dataset = CustomImageFolder(root=data_dir, transform=transform_train, is_valid_file=is_image_file)
full_dataset_val = CustomImageFolder(root=data_dir, transform=transform_val, is_valid_file=is_image_file)

# 80/20 split into training and validation samples
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_split = random_split(full_dataset, [train_size, val_size])

# Reuse the validation indices on the view that carries transform_val.
# NOTE(review): assumes both scans of data_dir list the files in the same
# order (no files added or removed in between).
val_dataset = torch.utils.data.Subset(full_dataset_val, val_split.indices)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Model definition
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build EfficientNet-B0 from torchvision's default weight set.
weights = EfficientNet_B0_Weights.DEFAULT
model = efficientnet_b0(weights=weights)
# Replace the last classifier layer with a linear layer that has 3 outputs.
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 3)  # 3 classes
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training and validation functions
def train_model(model, criterion, optimizer, train_loader, val_loader, device,
                num_epochs=10, save_path='model.pth'):
    """Train ``model``, validate after every epoch and checkpoint the best weights.

    Args:
        model: Network to optimise; switched between train and eval mode here.
        criterion: Loss called as ``criterion(outputs, labels)``. The per-batch
            value is weighted by the batch size, i.e. it is assumed to be a
            batch mean.
        optimizer: Optimizer stepping the parameters of ``model``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
        save_path: File the best ``state_dict`` is written to.

    Returns:
        Tuple ``(train_losses, val_losses, train_accs, val_accs)`` holding one
        float per epoch.
    """
    best_f1 = 0.0
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Training pass
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            n_batch = inputs.size(0)
            loss_sum += loss.item() * n_batch
            n_correct += (preds == labels).sum().item()
            n_seen += n_batch

        # Average over the samples actually processed, so the metrics stay
        # right for loaders that drop batches or expose no ``.dataset``.
        epoch_loss = loss_sum / max(n_seen, 1)
        epoch_acc = n_correct / max(n_seen, 1)
        train_losses.append(epoch_loss)
        train_accs.append(epoch_acc)

        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Validation pass
        model.eval()
        loss_sum, n_correct, n_seen = 0.0, 0, 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                n_batch = inputs.size(0)
                loss_sum += loss.item() * n_batch
                n_correct += (preds == labels).sum().item()
                n_seen += n_batch

                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        epoch_loss = loss_sum / max(n_seen, 1)
        epoch_acc = n_correct / max(n_seen, 1)
        epoch_f1 = f1_score(all_labels, all_preds, average='weighted')
        val_losses.append(epoch_loss)
        val_accs.append(epoch_acc)

        print(f'Val Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} F1: {epoch_f1:.4f}')

        # Always checkpoint the first epoch so ``save_path`` exists even when
        # the validation F1 never rises above 0; afterwards keep the best one.
        if epoch == 0 or epoch_f1 > best_f1:
            best_f1 = epoch_f1
            torch.save(model.state_dict(), save_path)

    print(f'Best val F1: {best_f1:.4f}')

    return train_losses, val_losses, train_accs, val_accs

def evaluate_model(model, dataloader, device):
    """Print the accuracy and weighted F1 of ``model`` over ``dataloader``.

    Args:
        model: Network to evaluate; put into eval mode here.
        dataloader: Loader yielding ``(inputs, labels)`` batches; its
            ``.dataset`` length is the accuracy denominator.
        device: Device the batches are moved to.
    """
    model.eval()
    correct = 0
    predicted, expected = [], []

    with torch.no_grad():
        for batch, targets in dataloader:
            batch, targets = batch.to(device), targets.to(device)

            # Index of the largest logit in each row is the predicted class.
            preds = torch.max(model(batch), 1).indices
            correct += (preds == targets).sum()

            predicted.extend(preds.cpu().numpy())
            expected.extend(targets.cpu().numpy())

    accuracy = correct.double() / len(dataloader.dataset)
    f1 = f1_score(expected, predicted, average='weighted')
    print(f'Accuracy: {accuracy * 100:.2f}% F1: {f1:.4f}')

# Function for plotting the training curves
def plot_metrics(train_losses, val_losses, train_accs, val_accs):
    """Show loss and accuracy curves side by side, one point per epoch.

    Args:
        train_losses: Per-epoch training loss values.
        val_losses: Per-epoch validation loss values.
        train_accs: Per-epoch training accuracy values.
        val_accs: Per-epoch validation accuracy values.
    """
    epochs = range(1, len(train_losses) + 1)

    # One row per panel: training series + label, validation series + label,
    # panel title, y-axis label.
    panels = [
        (train_losses, 'Training loss', val_losses, 'Validation loss',
         'Training and validation loss', 'Loss'),
        (train_accs, 'Training accuracy', val_accs, 'Validation accuracy',
         'Training and validation accuracy', 'Accuracy'),
    ]

    plt.figure(figsize=(12, 4))

    for position, panel in enumerate(panels, start=1):
        train_values, train_label, val_values, val_label, title, y_label = panel
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_values, 'b', label=train_label)
        plt.plot(epochs, val_values, 'r', label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.show()

# Train the model
history = train_model(
    model, criterion, optimizer, train_loader, val_loader, device, num_epochs=10
)
train_losses, val_losses, train_accs, val_accs = history

# Plot the collected metrics
plot_metrics(train_losses, val_losses, train_accs, val_accs)

# Загрузка и оценка лучшей модели

In [None]:
# Restore the checkpoint written by train_model and score it on the validation set.
# map_location keeps the load working when the checkpoint was saved on a
# different device than the one currently selected.
model.load_state_dict(torch.load('model.pth', map_location=device))
evaluate_model(model, val_loader, device)


# Визуализация предиктов

In [None]:
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from PIL import Image

# Helper for displaying an image
def imshow(image, title=None):
    """Draw ``image`` on the current matplotlib axes, with an optional title.

    Args:
        image: Anything accepted by ``plt.imshow``.
        title: Text for the plot title; nothing is set when it is ``None``.
    """
    plt.imshow(image)
    wants_title = title is not None
    if wants_title:
        plt.title(title)
    # Short pause so that the figure gets a chance to refresh.
    plt.pause(0.001)

# Helper that converts a class index into the class name
def index_to_class_str(index, dataset):
    """Return the class name stored at ``index`` in ``dataset.classes``.

    Wrapper datasets that do not expose ``classes`` themselves but keep the
    wrapped dataset in a ``.dataset`` attribute (as the subsets produced by
    ``random_split`` do) are unwrapped until a ``classes`` list is found.

    Args:
        index: Integer class index.
        dataset: Dataset with a ``classes`` sequence, or a wrapper around one.

    Raises:
        AttributeError: if no ``classes`` attribute can be reached.
        IndexError: if ``index`` is outside ``classes``.
    """
    while not hasattr(dataset, 'classes') and hasattr(dataset, 'dataset'):
        dataset = dataset.dataset
    return dataset.classes[index]

# Initialise EfficientNet-B0 without pretrained weights and give it a 3-output head
model = efficientnet_b0(weights=None)
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 3)  # 3 classes

# Choose the device (GPU or CPU) first so the checkpoint can be mapped onto it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the weights of the best model; map_location lets a checkpoint saved on
# another device (e.g. a GPU) be restored on the device selected above.
model.load_state_dict(torch.load('model.pth', map_location=device))
model = model.to(device)
model.eval()

# Preprocessing for inference. It mirrors the validation pipeline used during
# training (resize, center crop, tensor conversion, normalization): the model
# needs a normalized tensor, and a PIL image has no ``unsqueeze`` method.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Path to the image
# NOTE(review): this value looks like a directory name, while Image.open
# needs the path of a single image file - confirm before running.
image_path = "data_to_check_streamlit"

# Load the image and show it before any preprocessing
image = Image.open(image_path).convert("RGB")
imshow(image, title="Original Image")

# Preprocess and add the batch dimension expected by the model
image = transform(image)
image = image.unsqueeze(0).to(device)

# Run the model; gradients are not needed for a prediction
with torch.no_grad():
    outputs = model(image)
_, preds = torch.max(outputs, 1)
predicted_class = preds.item()

# Print the predicted class. val_dataset is a subset wrapper without its own
# ``classes`` list, so the names are taken from the dataset it wraps.
class_source = getattr(val_dataset, 'dataset', val_dataset)
print(f"Predicted class: {index_to_class_str(predicted_class, class_source)}")

plt.show()
