# In [None]:
import copy
import os
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, models, transforms
from tqdm import tqdm


def _denormalize(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo ``transforms.Normalize(mean, std)`` on a channel-first image tensor.

    The defaults are the per-channel statistics used by ``data_transforms``.
    The result is clamped to [0, 1], the valid float range for RGB input to
    ``plt.imshow``.
    """
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return (img * std_t + mean_t).clamp(0.0, 1.0)


def _label_name(dataset, label):
    """Return the class name for a ``label`` produced by ``dataset``.

    If ``dataset`` re-indexes labels through a ``class_mapping`` dict
    (original label -> new label, as ``CustomDataset`` does), the mapping is
    inverted so the name can be looked up in the wrapped dataset's ``classes``.
    Datasets without a ``.dataset`` attribute are looked up directly.
    """
    base = getattr(dataset, "dataset", dataset)
    mapping = getattr(dataset, "class_mapping", None)
    if mapping is not None:
        label = {new: orig for orig, new in mapping.items()}[label]
    return base.classes[label]


def show_example(dataset, index):
    """Display sample ``index`` of ``dataset`` with its class name as the title.

    Args:
        dataset: indexable dataset returning ``(image_tensor, label)`` where the
            image is a normalized CHW tensor.
        index: position of the sample to show.
    """
    img, label = dataset[index]
    # Invert the Normalize step instead of the (incorrect) img / 2 + 0.5.
    npimg = _denormalize(img).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(f"Label: {_label_name(dataset, label)}")
    plt.show()

# Preprocessing pipeline: resize every image to 224x224, convert it to a float
# tensor, then normalize each channel with the ImageNet mean/std statistics.
data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Download (if not already present) and load the CIFAR-100 train and test
# splits, both passed through the same preprocessing pipeline.
data_dir = './data'
full_trainset, full_testset = (
    datasets.CIFAR100(root=data_dir, train=is_train, download=True, transform=data_transforms)
    for is_train in (True, False)
)

# Keep only these ten animal classes out of CIFAR-100.
animal_classes = ['beaver', 'dolphin', 'otter', 'seal', 'fox', 'spider', 'elephant', 'bear', 'rabbit', 'tiger']
animal_class_indices = [
    full_trainset.class_to_idx[class_name]
    for class_name in animal_classes
]

# Map each original CIFAR-100 label to a contiguous new label 0..9,
# following the order of animal_classes.
class_mapping = dict(zip(animal_class_indices, range(len(animal_class_indices))))

# Dataset wrapper that keeps only selected classes and re-indexes their labels.
class CustomDataset(Dataset):
    """Filter a labelled dataset to a subset of classes and remap the labels.

    Args:
        dataset: indexable dataset yielding ``(sample, label)`` pairs.
        class_mapping: dict from original label to new label. Samples whose
            label is not a key are dropped; kept samples are returned with
            ``class_mapping[label]`` as their label.
    """

    def __init__(self, dataset, class_mapping):
        self.dataset = dataset
        self.class_mapping = class_mapping
        # Positions (in the wrapped dataset) of the samples that are kept.
        self.indices = [
            i for i, label in enumerate(self._labels(dataset)) if label in class_mapping
        ]

    @staticmethod
    def _labels(dataset):
        """Return an iterable of every sample's label, in index order."""
        targets = getattr(dataset, "targets", None)
        # ``targets`` (exposed e.g. by torchvision's CIFAR datasets) lets us read
        # labels without decoding and transforming every image. It only matches
        # __getitem__'s labels when no target_transform rewrites them.
        if targets is not None and getattr(dataset, "target_transform", None) is None:
            # Tensors/arrays -> plain Python ints so they hash like the mapping keys.
            return targets.tolist() if hasattr(targets, "tolist") else targets
        # Fallback: load every sample (runs the full transform pipeline).
        return (label for _, label in dataset)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(sample, new_label)`` for the ``idx``-th kept sample."""
        original_idx = self.indices[idx]
        img, label = self.dataset[original_idx]
        return img, self.class_mapping[label]

# Wrap both CIFAR-100 splits so that only the selected animal classes remain,
# with labels re-indexed through class_mapping.
trainset, testset = (
    CustomDataset(split, class_mapping) for split in (full_trainset, full_testset)
)

# Mini-batches of 32 images loaded by 4 worker processes; only the training
# data is shuffled.
trainloader = DataLoader(dataset=trainset, batch_size=32, num_workers=4, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=32, num_workers=4, shuffle=False)

# Start from ResNet-18 with its default pretrained weights and replace the
# final fully connected layer with a fresh one producing one logit per
# selected animal class.
weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=len(animal_classes))

# Training setup: cross-entropy loss, Adam with a small learning rate and
# weight decay, and a StepLR schedule that multiplies the learning rate by 0.1
# every 7 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-5, weight_decay=1e-5)
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

def _denormalize(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo ``transforms.Normalize(mean, std)`` on a channel-first image tensor.

    The defaults are the per-channel statistics used by ``data_transforms``.
    The result is clamped to [0, 1], the valid float range for RGB input to
    ``plt.imshow``.
    """
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return (img * std_t + mean_t).clamp(0.0, 1.0)


def show_example(dataset, index):
    """Display sample ``index`` of ``dataset`` titled with its animal class name.

    Args:
        dataset: indexable dataset returning ``(image_tensor, label)`` with a
            normalized CHW image and a label in ``range(len(animal_classes))``.
        index: position of the sample to show.
    """
    img, label = dataset[index]
    # Invert the Normalize step instead of the (incorrect) img / 2 + 0.5.
    npimg = _denormalize(img).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(f"Label: {animal_classes[label]}")
    plt.show()



def train_model(model, criterion, optimizer, scheduler, num_epochs=25,
                train_loader=None, val_loader=None):
    """Fine-tune ``model`` and return it loaded with its best validation weights.

    Each epoch runs a training pass and then an evaluation pass, prints the
    loss/accuracy of both, and steps ``scheduler`` once. Whenever validation
    accuracy improves, a snapshot of the weights is kept in memory and also
    saved to ``best_model_image.pth``.

    Args:
        model: network to train; moved to ``cuda:0`` when available, else CPU.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of epochs to run.
        train_loader: loader for the training pass; defaults to the
            module-level ``trainloader``.
        val_loader: loader for the evaluation pass; defaults to the
            module-level ``testloader``.

    Returns:
        ``model`` with the weights of the best-scoring validation epoch loaded.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Fall back to the module-level loaders so existing calls keep working.
    loaders = {
        'train': trainloader if train_loader is None else train_loader,
        'val': testloader if val_loader is None else val_loader,
    }

    # state_dict() returns references to the live tensors, which training keeps
    # updating in place; deep-copy so the snapshot is really frozen.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ('train', 'val'):
            is_train = phase == 'train'
            # train(False) is equivalent to eval(): switches layers such as
            # BatchNorm/Dropout to inference behaviour.
            model.train(is_train)
            dataloader = loaders[phase]

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloader, desc=f"{phase} Epoch {epoch+1}/{num_epochs}"):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients only during the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Assumes criterion averages over the batch (CrossEntropyLoss
                # default), so weight by batch size for a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            # max(..., 1) avoids division by zero on an empty dataset.
            num_samples = max(len(dataloader.dataset), 1)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Keep the best weights seen so far (in memory and on disk).
            # NOTE(review): by default the 'val' phase uses the test loader, so
            # model selection and the final accuracy report share the same data.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, 'best_model_image.pth')

        scheduler.step()

    print(f'Best val Acc: {best_acc:.4f}')

    # Restore the snapshot of the best validation epoch.
    model.load_state_dict(best_model_wts)
    return model

# Fine-tune the network for 15 epochs and rebind `model` to the returned model.
model = train_model(
    model,
    criterion,
    optimizer,
    scheduler=exp_lr_scheduler,
    num_epochs=15,
)

# Inference helper.
def predict(model, dataloader, device):
    """Run ``model`` over ``dataloader`` without gradients and collect predictions.

    Args:
        model: classifier already placed on ``device``; switched to eval mode.
        dataloader: yields ``(inputs, labels)`` batches.
        device: device the inputs (and labels) are moved to.

    Returns:
        Tuple ``(predictions, labels)`` of two lists, each holding one numpy
        scalar per sample, in loader order.
    """
    model.eval()
    predicted, targets = [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(dataloader, desc="Predicting"):
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            # Index of the highest logit = predicted class.
            batch_preds = torch.max(logits, 1)[1]
            predicted.extend(batch_preds.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    return predicted, targets

# Run inference on the test split.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
preds, labels = predict(model, testloader, device)

# Show the first test images with their predicted and true classes.
# testloader does not shuffle, so preds[i] / labels[i] belong to testset[i].
# (The original called show_example(trainset, ) without the required index.)
num_examples = min(10, len(preds))
for i in range(num_examples):
    print(f"Predicted: {animal_classes[preds[i]]}, actual: {animal_classes[labels[i]]}")
    show_example(testset, i)

# Overall test accuracy; 0.0 when there are no test samples.
accuracy = sum(p == l for p, l in zip(preds, labels)) / len(labels) if labels else 0.0
print(f'Accuracy: {accuracy:.4f}')