In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix, balanced_accuracy_score

# Device that models and input batches are moved to via `.to(device)`;
# the notebook is pinned to the CPU.
device = torch.device("cpu")

In [None]:
# Tunable settings used throughout the notebook.
BATCH_SIZE = 64  # samples per DataLoader batch (train and test)
EPOCHS = 5  # number of passes over the training loader
LR = 1e-3  # learning rate passed to the Adam optimizer
IMG_SIZE = 128  # images are resized to IMG_SIZE x IMG_SIZE before the network

In [None]:
# Per-channel statistics used to normalise RGB inputs
# (the standard ImageNet mean / std values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# RGB preprocessing: resize to a square, convert to a float tensor,
# then normalise each channel.
rgb_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform_rgb = transforms.Compose(rgb_steps)

In [None]:
# Arguments shared by the train and test splits: same storage directory,
# download on first use, and the RGB preprocessing pipeline.
cifar_kwargs = dict(root="./data", download=True, transform=transform_rgb)

train_dataset = datasets.CIFAR100(train=True, **cifar_kwargs)
test_dataset = datasets.CIFAR100(train=False, **cifar_kwargs)

# Quick sanity check: show the first few class names.
print("Classes:", train_dataset.classes[:10])

In [None]:
# Look up the original CIFAR-100 indices of the two classes of interest.
bicycle_id, motorcycle_id = (
    train_dataset.class_to_idx[class_name]
    for class_name in ("bicycle", "motorcycle")
)


def remap_labels(target):
    """Map an original CIFAR-100 class index onto the 3-class problem.

    Returns 0 for "bicycle", 1 for "motorcycle" and 2 for every other class.
    """
    if target == bicycle_id:
        return 0
    if target == motorcycle_id:
        return 1
    return 2

# Collapse the 100 CIFAR classes into 3 at load time through the dataset's
# `target_transform` hook rather than overwriting `.targets` in place.
# Overwriting is not idempotent: re-running the cell would feed the already
# remapped labels (0/1/2) through `remap_labels` again and turn every sample
# into class 2 ("background"). With the hook, `.targets` keeps the original
# labels and this cell can be re-run safely.
train_dataset.target_transform = remap_labels
test_dataset.target_transform = remap_labels

# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Pretrained MobileNetV2 used as a fixed feature extractor.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)

# Freeze every parameter of the convolutional backbone so that only the
# classifier head receives gradients.
model.features.requires_grad_(False)

# Swap the final linear layer for a 3-way head
# (bicycle / motorcycle / background).
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, 3)

model = model.to(device)

In [None]:
# Multi-class cross-entropy over the 3 output logits.
criterion = nn.CrossEntropyLoss()

# Only the classifier head's parameters are handed to the optimizer.
head_params = model.classifier.parameters()
optimizer = optim.Adam(head_params, lr=LR)

In [None]:
from tqdm import tqdm

# Train for EPOCHS passes and report the mean batch loss after each epoch.
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0

    for batch_images, batch_labels in tqdm(train_loader):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    mean_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {mean_loss:.4f}")

In [None]:
# Evaluate the RGB model on the test split and collect per-sample
# predictions and ground-truth labels for the metrics below.
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)

        # Bring tensors back to host memory before converting:
        # Tensor.numpy() raises for non-CPU tensors, so without `.cpu()` this
        # cell would break as soon as `device` is anything other than the CPU.
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# `labels` pins the class-id -> name mapping so the report rows cannot be
# mislabelled (or the call fail) if some class never shows up.
print(classification_report(
    all_labels,
    all_preds,
    labels=[0, 1, 2],
    target_names=["bicycle", "motorcycle", "background"]
))

print("Balanced accuracy:",
      balanced_accuracy_score(all_labels, all_preds))

In [None]:
# Tick labels, in class-id order (0, 1, 2).
class_names = ["bicycle", "motorcycle", "background"]

cm = confusion_matrix(all_labels, all_preds)

# Row-normalise: each row (true class) sums to 1, so the diagonal is recall.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = cm.astype(float) / row_totals

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
    cm_norm,
    annot=True,
    fmt=".3f",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)

ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Normalized Confusion Matrix (Recall per class)")
fig.tight_layout()
plt.show()

### Grayscale

In [None]:
# Single-channel normalisation constants: maps [0, 1] pixel values to [-1, 1].
GRAY_MEAN = [0.5]
GRAY_STD = [0.5]

# Grayscale preprocessing: resize, collapse to one channel, convert to a
# float tensor, then normalise.
gray_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=GRAY_MEAN, std=GRAY_STD),
]
transform_gray = transforms.Compose(gray_steps)

In [None]:
# Build the grayscale model from the fine-tuned RGB `model`.
#
# The grayscale model is only evaluated, never trained, so it must carry the
# trained weights: a freshly created 3-class head is randomly initialised and
# would make the grayscale metrics meaningless. We therefore create the same
# architecture (MobileNetV2 + 3-class head) and load `model`'s state dict,
# which includes the trained classifier.
model_gray = models.mobilenet_v2(weights=None)

# Replace the classifier so the architecture matches `model`.
in_features = model_gray.classifier[1].in_features
model_gray.classifier[1] = nn.Linear(in_features, 3)

model_gray.load_state_dict(model.state_dict())

# Replace the 3-channel stem convolution with a 1-channel one.
old_conv = model_gray.features[0][0]
old_weights = old_conv.weight.detach()

new_conv = nn.Conv2d(
    1,
    old_conv.out_channels,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
    padding=old_conv.padding,
    bias=False
)

# Sum (not average) the RGB kernels: convolving one gray channel with the
# channel-summed kernel is exactly what the original stem computes for a gray
# image replicated into three channels. Averaging would scale the stem output
# by 1/3, which the following BatchNorm (fixed running statistics, no
# fine-tuning here) does not compensate for.
with torch.no_grad():
    new_conv.weight.copy_(old_weights.sum(dim=1, keepdim=True))

model_gray.features[0][0] = new_conv

model_gray = model_gray.to(device)

In [None]:
# Test split with the grayscale preprocessing pipeline. The data is expected
# to be on disk already (download=False).
test_dataset_gray = datasets.CIFAR100(
    root="./data",
    train=False,
    download=False,
    transform=transform_gray
)

# The dataset object is created fresh in this cell, so remapping its targets
# in place stays safe when the cell is re-run.
test_dataset_gray.targets = [remap_labels(t) for t in test_dataset_gray.targets]
test_loader_gray = DataLoader(test_dataset_gray, batch_size=BATCH_SIZE)

# Collect predictions of the grayscale model on the grayscale test images.
model_gray.eval()
all_preds_gray = []
all_labels_gray = []

with torch.no_grad():  # inference only
    for images, labels in tqdm(test_loader_gray):
        images = images.to(device)
        outputs = model_gray(images)
        preds = torch.argmax(outputs, dim=1)

        # `.cpu()` first: Tensor.numpy() raises for non-CPU tensors, so this
        # keeps the cell working if `device` is changed from the CPU.
        all_preds_gray.extend(preds.cpu().numpy())
        all_labels_gray.extend(labels.cpu().numpy())

print("Grayscale balanced accuracy:",
      balanced_accuracy_score(all_labels_gray, all_preds_gray))