In [43]:
import copy

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import classification_report, confusion_matrix

# Pick the compute device: CUDA when it is available, otherwise the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Device:", device)

Device: cpu


### Шаг 2. Загрузка CIFAR-100 и переразметка классов

In [44]:
# Classes of interest and the labels assigned to them
target_classes = dict(bicycle=0, motorcycle=1)

# Per-channel normalisation statistics shared by both splits
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]


def _make_transform():
    """Build the tensor-conversion + normalisation pipeline applied to each image."""
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]
    return transforms.Compose(steps)


# Load the CIFAR-100 training and test splits (stored under ./data)
train_dataset = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=_make_transform()
)
test_dataset = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=_make_transform()
)

# Map every original CIFAR-100 class name to its integer index
class_to_idx = dict(zip(train_dataset.classes, range(len(train_dataset.classes))))

# Every original index except bicycle and motorcycle counts as background
kept_indices = (class_to_idx['bicycle'], class_to_idx['motorcycle'])
background_classes = {idx for idx in class_to_idx.values() if idx not in kept_indices}

def custom_label_mapping(label):
    """Collapse an original class index into one of three labels.

    Returns 0 when ``label`` equals the bicycle index, 1 when it equals the
    motorcycle index (both looked up in ``class_to_idx``), and 2 (background)
    for anything else.
    """
    # The position in this tuple is the new label: bicycle -> 0, motorcycle -> 1
    for new_label, class_name in enumerate(('bicycle', 'motorcycle')):
        if label == class_to_idx[class_name]:
            return new_label
    return 2  # background

# Relabel both splits to the 3-class scheme (bicycle / motorcycle / background).
# The untouched CIFAR-100 labels are kept on the dataset the first time this
# runs, and the remap always starts from them. Without that, re-running the
# cell would feed already-remapped 0/1/2 labels into a mapping that compares
# against the original CIFAR-100 indices.
for _split in (train_dataset, test_dataset):
    if not hasattr(_split, 'original_targets'):
        _split.original_targets = list(_split.targets)
    _split.targets = [custom_label_mapping(label) for label in _split.original_targets]

Files already downloaded and verified
Files already downloaded and verified


### Шаг 3. DataLoader'ы

In [45]:
# Settings common to both loaders; only the training split is shuffled
loader_kwargs = dict(batch_size=64, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

### Шаг 4. Подготовка модели, функции потерь и оптимизатора

In [46]:
from torchvision import models

# Start from a ResNet18 with pretrained weights
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm the installed version before switching.
model = models.resnet18(pretrained=True)

# Replace the final fully connected layer with a 3-output head
num_classes = 3
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)



### Шаг 5. Функция для обучения

In [47]:

def train_model(model, train_loader, epochs=10, loss_fn=None, opt=None, target_device=None):
    """Run the main training loop and print the mean batch loss per epoch.

    Args:
        model: the network to train; switched to training mode each epoch.
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        epochs: number of passes over ``train_loader``.
        loss_fn: loss callable; defaults to the notebook-level ``criterion``.
        opt: optimizer; defaults to the notebook-level ``optimizer``.
        target_device: device batches are moved to; defaults to the
            notebook-level ``device``.

    Returns:
        None. Progress is reported through ``print``.
    """
    # Fall back to the notebook-level objects so existing calls keep working,
    # while letting callers pass their dependencies explicitly.
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    target_device = device if target_device is None else target_device

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            opt.zero_grad()

            loss = loss_fn(model(inputs), labels)
            loss.backward()
            opt.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches explicitly: an empty loader reports nan instead of
        # raising ZeroDivisionError, and the loader does not need __len__.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {mean_loss:.4f}")

    print("Finished Training")

# Training is slow: about 17 minutes per epoch on a Colab v2-8 TPU runtime.
# Reducing the run to 1-2 epochs still gives a reasonable picture.
num_epochs = 10
train_model(model, train_loader, epochs=num_epochs)

KeyboardInterrupt: 

### Шаг 6. Оценка модели

In [None]:
def evaluate_model(model, loader, class_names=('Bicycle', 'Motorcycle', 'Background'), target_device=None):
    """Evaluate ``model`` on ``loader``: accuracy, confusion matrix, per-class report.

    Args:
        model: the network to evaluate; switched to eval mode.
        loader: iterable yielding ``(inputs, labels)`` batches.
        class_names: display names, where position ``i`` names label ``i``.
        target_device: device batches are moved to; defaults to the
            notebook-level ``device``.

    Returns:
        None. Results are printed and the confusion matrix is plotted.
    """
    target_device = device if target_device is None else target_device
    class_names = list(class_names)
    # Pin the label set so the matrix and the report always have one row per
    # class, even when a class is absent from both labels and predictions.
    label_ids = list(range(len(class_names)))

    model.eval()
    correct = 0
    total = 0

    all_labels = []
    all_preds = []

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            predicted = model(inputs).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(predicted.cpu().tolist())

    # Nothing to report on an empty loader; avoids dividing by zero below
    if total == 0:
        print("No samples to evaluate")
        return

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.4f}")

    # Confusion matrix (rows: true labels, columns: predicted labels)
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    plt.show()

    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names))

# Evaluate the model on the test split
evaluate_model(model=model, loader=test_loader)

### Шаг 7. Анализ работы на Grayscale

In [None]:
# Convert input images to grayscale while keeping 3 channels
grayscale_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # 3-channel grayscale
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Work on shallow copies instead of overwriting `.transform` on the shared
# datasets. Otherwise the colour `train_loader` / `test_loader` built earlier
# would silently start yielding grayscale images. The copies reuse the same
# image data and remapped targets; only the transform differs.
grayscale_train_dataset = copy.copy(train_dataset)
grayscale_train_dataset.transform = grayscale_transform
grayscale_test_dataset = copy.copy(test_dataset)
grayscale_test_dataset.transform = grayscale_transform

grayscale_train_loader = torch.utils.data.DataLoader(grayscale_train_dataset, batch_size=64, shuffle=True, num_workers=2)
grayscale_test_loader = torch.utils.data.DataLoader(grayscale_test_dataset, batch_size=64, shuffle=False, num_workers=2)

# Evaluate the model on the grayscale test split
evaluate_model(model, grayscale_test_loader)