In [24]:
import time
import torch
import json 
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import os

In [25]:
# layers
class CustomConv2d(torch.nn.Module):
    """Conv2d whose output is multiplied by a constant scale factor.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded
            positionally to ``torch.nn.Conv2d``.
        scale: constant multiplier applied to the convolution output.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, scale=1.0):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # Plain Python attribute, not a Parameter/buffer: it is not trained and
        # does not appear in state_dict().
        self.scale = scale

    def forward(self, x):
        """Apply the convolution, then multiply its output by ``self.scale``."""
        return self.conv(x) * self.scale

class SimpleSpatialAttention(torch.nn.Module):
    """Spatial attention gate for CNN feature maps.

    A 1x1 convolution collapses the channels into a single map, a sigmoid
    squashes every position of that map into (0, 1) independently, and the
    input is rescaled by it (the one-channel map broadcasts over all channels).
    This is a per-position sigmoid gate, not a softmax over spatial positions.

    Args:
        in_channels: number of channels of the incoming feature map.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        gate = self.conv(x).sigmoid()
        return gate * x

class CustomActivation(torch.nn.Module):
    """Swish activation with a learnable slope: f(x) = x * sigmoid(beta * x).

    ``beta`` is a single trainable scalar initialised to 1.0; at that value the
    function equals SiLU, x * sigmoid(x).
    """
    def __init__(self):
        super().__init__()
        initial_beta = torch.tensor(1.0)
        self.beta = torch.nn.Parameter(initial_beta)

    def forward(self, x):
        gate = torch.sigmoid(self.beta * x)
        return x * gate

class CustomMaxAbsPool2d(torch.nn.Module):
    """Pooling that keeps, per window, the element with the largest absolute value.

    The selected element is returned with its original sign, e.g. the window
    [-5, 1, 2, 0] pools to -5. Ties resolve to whichever position
    ``max_pool2d`` reports. Overlapping windows (stride < kernel_size) and
    padding are supported. Padding is handled by ``max_pool2d`` itself and
    never wins a window.

    Args:
        kernel_size: pooling window size (int or pair).
        stride: window step; defaults to ``kernel_size`` when falsy.
        padding: implicit padding on each side, as in ``max_pool2d``.
    """
    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding

    def forward(self, x):
        # Locate the max-|x| element of every window. The indices address the
        # flattened (H*W) spatial plane of each (batch, channel) slice of x.
        _, indices = torch.nn.functional.max_pool2d(
            x.abs(), self.kernel_size, self.stride, self.padding, return_indices=True)
        # Read the signed values at those positions. Flattening only the last two
        # dims works for both batched (N, C, H, W) and unbatched (C, H, W) input.
        # gather keeps the gradient path to the selected elements of x.
        picked = x.flatten(-2).gather(-1, indices.flatten(-2))
        return picked.view_as(indices)

In [26]:
# models
class ResidualBlock(torch.nn.Module):
    """Two-convolution residual block for ResNet-style networks.

    Main path: conv3x3(stride) -> norm -> ReLU -> dropout -> conv3x3 -> norm.
    The shortcut is the identity when shape is preserved. Otherwise it is a 1x1
    projection (+ norm) matching ``out_channels`` and ``stride``. The two paths
    are summed and passed through a final ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the block.
        stride: stride of the first conv and of the projection shortcut.
        dropout: Dropout2d probability after the first activation (0 disables it).
        use_bn: use BatchNorm2d after each conv; otherwise Identity.
    """
    def __init__(self, in_channels, out_channels, stride=1, dropout=0.0, use_bn=True):
        super().__init__()

        def make_norm(channels):
            # BatchNorm or a no-op placeholder, depending on use_bn.
            return torch.nn.BatchNorm2d(channels) if use_bn else torch.nn.Identity()

        self.use_bn = use_bn
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = make_norm(out_channels)
        self.relu = torch.nn.ReLU(inplace=True)
        self.dropout = torch.nn.Dropout2d(dropout) if dropout > 0 else torch.nn.Identity()
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        self.bn2 = make_norm(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            # Projection so the shortcut output can be added to the main path.
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                make_norm(out_channels),
            )
        else:
            # An empty Sequential returns its input unchanged.
            self.shortcut = torch.nn.Sequential()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        hidden = self.dropout(self.relu(self.bn1(self.conv1(x))))
        merged = self.bn2(self.conv2(hidden))
        merged += self.shortcut(x)
        return self.relu(merged)

class ShallowCNN(torch.nn.Module):
    """Shallow CNN: two 3x3 conv layers, global average pooling, linear classifier.

    Args:
        in_channels: channels of the input images (default 3, e.g. RGB CIFAR-10).
        num_classes: number of output logits (default 10).
    """
    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        # padding=1 with 3x3 kernels and stride 1 keeps the spatial size.
        self.conv1 = torch.nn.Conv2d(in_channels, 32, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(64, num_classes)

    def forward(self, x):
        """Map a batch of images (N, in_channels, H, W) to logits (N, num_classes)."""
        x = torch.nn.functional.relu(self.conv1(x))
        x = torch.nn.functional.relu(self.conv2(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

class MediumCNN(torch.nn.Module):
    """Medium CNN: four 3x3 conv layers, global average pooling, linear classifier.

    Args:
        in_channels: channels of the input images (default 3, e.g. RGB CIFAR-10).
        num_classes: number of output logits (default 10).
    """
    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        # padding=1 with 3x3 kernels and stride 1 keeps the spatial size.
        self.conv1 = torch.nn.Conv2d(in_channels, 32, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = torch.nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = torch.nn.Conv2d(128, 128, 3, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch of images (N, in_channels, H, W) to logits (N, num_classes)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = torch.nn.functional.relu(conv(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

class DeepCNN(torch.nn.Module):
    """Deep CNN: six 3x3 conv layers, global average pooling, linear classifier.

    Args:
        in_channels: channels of the input images (default 3, e.g. RGB CIFAR-10).
        num_classes: number of output logits (default 10).
    """
    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        # padding=1 with 3x3 kernels and stride 1 keeps the spatial size.
        self.conv1 = torch.nn.Conv2d(in_channels, 32, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = torch.nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = torch.nn.Conv2d(128, 128, 3, padding=1)
        self.conv5 = torch.nn.Conv2d(128, 128, 3, padding=1)
        self.conv6 = torch.nn.Conv2d(128, 128, 3, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch of images (N, in_channels, H, W) to logits (N, num_classes)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6):
            x = torch.nn.functional.relu(conv(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

class ResidualCNN(torch.nn.Module):
    """CNN with residual connections: a stem conv followed by 4 residual blocks.

    Args:
        in_channels: channels of the input images (default 3, e.g. RGB CIFAR-10).
        num_classes: number of output logits (default 10).
    """
    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, 32, 3, padding=1)
        self.res1 = ResidualBlock(32, 64)
        self.res2 = ResidualBlock(64, 128)
        self.res3 = ResidualBlock(128, 128)
        self.res4 = ResidualBlock(128, 128)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch of images (N, in_channels, H, W) to logits (N, num_classes)."""
        x = torch.nn.functional.relu(self.conv(x))
        for block in (self.res1, self.res2, self.res3, self.res4):
            x = block(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [27]:
# utils
def get_data_loaders(dataset_name, batch_size=128, root='data', mnist_channels=3):
    """Download a dataset if needed and wrap it in train/test DataLoaders.

    Args:
        dataset_name: 'MNIST' or 'CIFAR10'.
        batch_size: samples per batch for both loaders.
        root: directory where torchvision stores/caches the data.
        mnist_channels: number of channels for MNIST images (1 or 3). Every
            model in this notebook starts with ``Conv2d(3, ...)``, so the
            grayscale digits are replicated to 3 identical channels by default.
            Pass 1 for raw single-channel tensors.

    Returns:
        (train_loader, test_loader, class_names). The train loader shuffles;
        the test loader keeps a fixed order.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    if dataset_name == 'MNIST':
        # With 1-channel tensors the first Conv2d(3, ...) of every model fails with
        # a channel-mismatch error, so channels are expanded before ToTensor.
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=mnist_channels),
            transforms.ToTensor(),
        ])
        train_set = datasets.MNIST(root, train=True, download=True, transform=transform)
        test_set = datasets.MNIST(root, train=False, download=True, transform=transform)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
        class_names = [str(i) for i in range(10)]
        return train_loader, test_loader, class_names

    if dataset_name == 'CIFAR10':
        # Presumably the CIFAR-10 per-channel mean/std; kept exactly as before.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
        transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # augmentation only for training
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        train_set = datasets.CIFAR10(root, train=True, download=True, transform=transform_train)
        test_set = datasets.CIFAR10(root, train=False, download=True, transform=transform_test)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
        class_names = train_set.classes
        return train_loader, test_loader, class_names

    raise ValueError(f"Неизвестный датасет: {dataset_name}")

def plot_grad_flow(named_parameters, save_path):
    """Save a line plot of the mean absolute gradient of each non-bias parameter.

    Call after ``loss.backward()`` and before ``optimizer.zero_grad()``, while
    ``.grad`` is populated. Parameters that are frozen, have no gradient, or
    whose name contains "bias" are skipped.

    Args:
        named_parameters: iterable of (name, parameter) pairs, e.g.
            ``model.named_parameters()``.
        save_path: output image path; missing parent directories are created.
    """
    layer_names = []
    mean_abs_grads = []
    for name, param in named_parameters:
        if param.requires_grad and param.grad is not None and "bias" not in name:
            layer_names.append(name)
            mean_abs_grads.append(param.grad.abs().mean().item())

    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        # savefig does not create directories; without this the first call fails.
        os.makedirs(parent_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.plot(mean_abs_grads, alpha=0.7, color="b")
        ax.hlines(0, 0, len(mean_abs_grads) + 1, linewidth=1, color="k")
        ax.set_xticks(range(len(mean_abs_grads)))
        ax.set_xticklabels(layer_names, rotation="vertical", fontsize=8)
        ax.set_xlim(left=0, right=len(mean_abs_grads))
        ax.set_xlabel("Слои")
        ax.set_ylabel("Средний градиент")
        ax.set_title("Градиентный поток")
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Always release the figure; this function is called repeatedly during training.
        plt.close(fig)

def train(model, train_loader, test_loader, epochs=20, log_prefix="", grad_flow_plot=False, grad_flow_path=None,
          lr=0.001, weight_decay=1e-4, verbose=False):
    """Train ``model`` with Adam + cross-entropy, evaluating on ``test_loader`` after each epoch.

    Args:
        model: module to train; batches are moved to the device of its first parameter.
        train_loader: iterable of (inputs, integer class targets) used for training.
        test_loader: loader passed to ``test()`` after every epoch.
        epochs: number of passes over ``train_loader``.
        log_prefix: label printed before each per-epoch summary (used only when verbose).
        grad_flow_plot: if True and ``grad_flow_path`` is set, save one gradient-flow
            plot (first batch) on every 5th epoch: 0, 5, 10, ...
        grad_flow_path: base image path; ``_epoch{N}`` is inserted before the extension.
        lr, weight_decay: Adam hyper-parameters (defaults equal the former hard-coded values).
        verbose: print a one-line summary after each epoch.

    Returns:
        dict of per-epoch lists 'train_loss', 'train_acc', 'test_loss', 'test_acc'.
        Both losses are mean per-sample cross-entropy.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    grad_root, grad_ext = os.path.splitext(grad_flow_path or "")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        # One snapshot per selected epoch. Plotting on every batch would render
        # hundreds of figures per epoch into the same file.
        snapshot_epoch = bool(grad_flow_plot and grad_flow_path) and epoch % 5 == 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = torch.nn.functional.cross_entropy(output, target)
            loss.backward()
            if snapshot_epoch and batch_idx == 0:
                # Gradients are populated here: after backward(), before step().
                plot_grad_flow(model.named_parameters(), f"{grad_root}_epoch{epoch}{grad_ext}")
            optimizer.step()

            batch_size = target.size(0)
            # Weight by batch size so the (possibly smaller) last batch doesn't skew
            # the mean; this matches the per-sample average computed by test().
            total_loss += loss.item() * batch_size
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("train_loader yielded no samples")
        avg_loss = total_loss / total
        train_acc = correct / total
        test_loss, test_acc = test(model, test_loader)

        history['train_loss'].append(avg_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        if verbose:
            print(f"{log_prefix} | эпоха {epoch + 1}/{epochs} | "
                  f"train_loss={avg_loss:.4f} train_acc={train_acc:.4f} | "
                  f"test_loss={test_loss:.4f} test_acc={test_acc:.4f}")
    return history

def test(model, test_loader):
    device = next(model.parameters()).device
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)
    return loss, acc

def get_feature_maps(model, img_tensor):
    """Return the output of the model's first ``Conv2d`` applied to ``img_tensor``.

    "First" means first in ``model.modules()`` order, i.e. registration order.
    Only that convolution is applied, directly to the raw input: no activation
    and no earlier non-conv layers. The model is switched to eval mode and left there.

    Args:
        model: module containing at least one ``torch.nn.Conv2d``.
        img_tensor: input for that conv, on the model's device (e.g. one image
            with a batch dimension).

    Returns:
        numpy array with the conv output, moved to CPU.

    Raises:
        ValueError: if the model contains no ``torch.nn.Conv2d``.
    """
    model.eval()
    first_conv = next((m for m in model.modules() if isinstance(m, torch.nn.Conv2d)), None)
    if first_conv is None:
        # Previously this silently returned None, which then crashed later in plotting.
        raise ValueError("model contains no torch.nn.Conv2d layer")
    with torch.no_grad():
        return first_conv(img_tensor).cpu().numpy()


def compare_models(histories, labels, save_path):
    """Write the per-epoch training histories of several models to a JSON file.

    Args:
        histories: list of dicts as returned by ``train()``.
        labels: model names, one per history (same order).
        save_path: output JSON path; missing parent directories are created.

    The file maps each label to its 'train_loss', 'train_acc', 'test_loss'
    and 'test_acc' lists (keys missing from a history are omitted).

    Raises:
        ValueError: if ``labels`` and ``histories`` differ in length (zip would
            otherwise silently drop or mis-pair entries).
    """
    if len(labels) != len(histories):
        raise ValueError(f"got {len(labels)} labels for {len(histories)} histories")
    metric_keys = ('train_loss', 'train_acc', 'test_loss', 'test_acc')
    results = {
        label: {key: history[key] for key in metric_keys if key in history}
        for label, history in zip(labels, histories)
    }
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, 'w') as f:
        json.dump(results, f, indent=2)

def count_parameters(model):
    """Return the total number of scalar elements in the model's trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# import matplotlib
# matplotlib.use('Agg')
def plot_metrics(histories, labels, save_dir):
    """Save one line chart per training metric, overlaying all models.

    Writes ``train_loss.png``, ``train_acc.png``, ``test_loss.png`` and
    ``test_acc.png`` into ``save_dir``, creating the directory if needed.

    Args:
        histories: list of dicts as returned by ``train()``.
        labels: legend label for each history (same order).
        save_dir: output directory.
    """
    os.makedirs(save_dir, exist_ok=True)
    for metric in ('train_loss', 'train_acc', 'test_loss', 'test_acc'):
        fig, ax = plt.subplots()
        for label, history in zip(labels, histories):
            ax.plot(history[metric], label=label)
        ax.set_title(metric)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.legend()
        fig.savefig(os.path.join(save_dir, f"{metric}.png"))
        plt.close(fig)

def plot_feature_maps(feature_maps, save_path, max_channels=8):
    """Save the first few channels of ``feature_maps[0]`` as a row of images.

    Args:
        feature_maps: array-like whose first element is (channels, H, W), e.g. the
            (batch, channels, H, W) output of ``get_feature_maps``. Only the first
            sample is plotted.
        save_path: output image path; missing parent directories are created.
        max_channels: maximum number of channels to draw (default 8).

    Raises:
        ValueError: if there is no channel to plot.
    """
    fmap = feature_maps[0]
    n = min(fmap.shape[0], max_channels)
    if n == 0:
        raise ValueError("feature_maps has no channels to plot")

    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # squeeze=False keeps a 2-D array of Axes even when n == 1; otherwise
    # plt.subplots returns a bare Axes and indexing it raises TypeError.
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2), squeeze=False)
    try:
        for i, ax in enumerate(axes[0]):
            ax.imshow(fmap[i], cmap='magma')
            ax.axis('off')
            ax.set_title(f'Channel {i}')
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        plt.close(fig)

# Задание 1: Сравнение CNN и полносвязных сетей

In [28]:
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible on Restart & Run All (torch.manual_seed also seeds CUDA).
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

models = [
    ("ShallowCNN", ShallowCNN().to(device)),
    ("MediumCNN", MediumCNN().to(device)),
    ("DeepCNN", DeepCNN().to(device)),
    ("ResidualCNN", ResidualCNN().to(device)),
]

# Result accumulators: lists of histories and (model name, value) pairs.
histories = []
train_times = []
param_counts = []
test_accs = []

## 1.1 Сравнение на MNIST

In [None]:
# Fresh, seeded models: the comparison must not depend on which cells ran before.
# Reusing the shared `models` objects would keep training already-trained weights
# whenever this cell is re-run.
torch.manual_seed(42)
mnist_models = [(cls.__name__, cls().to(device))
                for cls in (ShallowCNN, MediumCNN, DeepCNN, ResidualCNN)]

# Dataset-specific output locations, so the CIFAR-10 run cannot overwrite these files.
mnist_plot_dir = 'PyTorch4/plots/architecture_analysis/mnist'
mnist_results_dir = 'PyTorch4/results/architecture_analysis'
os.makedirs(mnist_plot_dir, exist_ok=True)
os.makedirs(mnist_results_dir, exist_ok=True)

# Cell-local accumulators (rebound on every run, so re-running is idempotent).
mnist_histories = []
mnist_train_times = []
mnist_param_counts = []
mnist_test_accs = []

train_loader, test_loader, _ = get_data_loaders('MNIST', batch_size=128)

# test_loader does not shuffle, so this is the same image the per-model fetch used to
# return; fetch it once instead of building a new iterator for every model.
sample_images, _ = next(iter(test_loader))
sample_img = sample_images[0:1].to(device)

for name, model in mnist_models:
    mnist_param_counts.append((name, count_parameters(model)))

    start_time = time.time()
    history = train(model, train_loader, test_loader, epochs=10, log_prefix=name,
        grad_flow_plot=True, grad_flow_path=os.path.join(mnist_plot_dir, f'grad_{name}.png'))
    mnist_train_times.append((name, time.time() - start_time))

    # train() evaluates on test_loader after every epoch; the last entry is the
    # final model's test accuracy, so no extra evaluation pass is needed.
    mnist_test_accs.append((name, history['test_acc'][-1]))
    mnist_histories.append(history)

    fmap = get_feature_maps(model, sample_img)
    plot_feature_maps(fmap, save_path=os.path.join(mnist_plot_dir, f'fmap_{name}.png'))

mnist_labels = [name for name, _ in mnist_models]
compare_models(mnist_histories, mnist_labels,
               save_path=os.path.join(mnist_results_dir, 'metrics_depth_mnist.json'))
plot_metrics(mnist_histories, mnist_labels, save_dir=mnist_plot_dir)

print("\nСравнение числа параметров:")
for name, params in mnist_param_counts:
    print(f"{name}: {params}")

print("\nВремя обучения (с):")
for name, t_time in mnist_train_times:
    print(f"{name}: {t_time:.2f}")

print("\nТочность на тестовом:")
for name, acc in mnist_test_accs:
    print(f"{name}: {acc:.4f}")

## 1.2 Сравнение на CIFAR-10

In [None]:
# Fresh, seeded models. Previously this cell continued training the models that
# had already been trained on MNIST, contaminating the CIFAR-10 comparison.
torch.manual_seed(42)
cifar_models = [(cls.__name__, cls().to(device))
                for cls in (ShallowCNN, MediumCNN, DeepCNN, ResidualCNN)]

# Dataset-specific output locations, so the MNIST results are not overwritten.
cifar_plot_dir = 'PyTorch4/plots/architecture_analysis/cifar10'
cifar_results_dir = 'PyTorch4/results/architecture_analysis'
os.makedirs(cifar_plot_dir, exist_ok=True)
os.makedirs(cifar_results_dir, exist_ok=True)

# Cell-local accumulators (rebound on every run, so re-running is idempotent).
cifar_histories = []
cifar_train_times = []
cifar_param_counts = []
cifar_test_accs = []

train_loader, test_loader, _ = get_data_loaders('CIFAR10', batch_size=128)

# Fetch the visualisation image once: each new iterator over this loader starts
# worker processes (num_workers=2). test_loader does not shuffle.
sample_images, _ = next(iter(test_loader))
sample_img = sample_images[0:1].to(device)

for name, model in cifar_models:
    params = count_parameters(model)
    cifar_param_counts.append((name, params))
    print(f"\nОбучение {name} параметров: {params}")

    start_time = time.time()
    history = train(model, train_loader, test_loader, epochs=10, log_prefix=name,
        grad_flow_plot=True, grad_flow_path=os.path.join(cifar_plot_dir, f'grad_{name}.png'))
    train_duration = time.time() - start_time
    cifar_train_times.append((name, train_duration))
    print(f"Время обучения: {train_duration:.2f} сек")

    # train() evaluates on test_loader after every epoch; reuse the final value.
    test_acc = history['test_acc'][-1]
    cifar_test_accs.append((name, test_acc))
    print(f"Точность тестовой: {test_acc:.4f}")

    cifar_histories.append(history)

    fmap = get_feature_maps(model, sample_img)
    plot_feature_maps(fmap, save_path=os.path.join(cifar_plot_dir, f'fmap_{name}.png'))

cifar_labels = [name for name, _ in cifar_models]
compare_models(cifar_histories, cifar_labels,
               save_path=os.path.join(cifar_results_dir, 'metrics_depth_cifar10.json'))
plot_metrics(cifar_histories, cifar_labels, save_dir=cifar_plot_dir)

print("\nСравнение числа параметров:")
for name, params in cifar_param_counts:
    print(f"{name}: {params}")

print("\nВремя обучения (с):")
for name, t_time in cifar_train_times:
    print(f"{name}: {t_time:.2f}")

print("\nТочность на тестовом:")
for name, acc in cifar_test_accs:
    print(f"{name}: {acc:.4f}")