In [11]:
import logging
import time
import torch 
import torch.nn as nn
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import json

In [12]:
# models
class BasicResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv + batch-norm stages with a skip connection.

    The skip path is an empty ``nn.Sequential`` (identity) unless ``stride != 1``
    or the channel count changes; in that case it is a strided 1x1 convolution
    followed by batch normalization so both branches have matching shapes.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main branch: conv -> bn -> relu -> conv -> bn.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip branch: project only when the main branch changes the tensor shape.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        activation = nn.functional.relu
        residual = self.shortcut(x)
        out = self.conv1(x)
        out = activation(self.bn1(out))
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        return activation(out)

class BottleneckResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 (reduce) -> 3x3 -> 1x1 (expand) with a skip connection.

    All convolutions are created with ``bias=False`` and are each followed by
    batch normalization. The skip path is an empty ``nn.Sequential`` (identity)
    unless ``stride != 1`` or the channel count changes, in which case it is a
    strided 1x1 convolution followed by batch normalization.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
        bottleneck_ratio: Divisor applied to ``out_channels`` to obtain the
            width of the inner (bottleneck) layers.

    Raises:
        ValueError: If ``bottleneck_ratio`` is not positive.
    """

    def __init__(self, in_channels, out_channels, stride=1, bottleneck_ratio=4):
        super().__init__()
        if bottleneck_ratio <= 0:
            raise ValueError(f"bottleneck_ratio must be positive, got {bottleneck_ratio}")
        # Clamp to at least one channel: integer division yields 0 when
        # out_channels < bottleneck_ratio, which would create zero-width layers.
        bottleneck_channels = max(1, out_channels // bottleneck_ratio)
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            # Projection so the skip branch matches the main branch's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = nn.functional.relu(self.bn1(self.conv1(x)))
        out = nn.functional.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return nn.functional.relu(out)

class WideResidualBlock(nn.Module):
    """Wide residual block: the hidden 3x3 layer is ``widen_factor`` times wider than the output.

    The first 3x3 convolution maps ``in_channels`` to
    ``out_channels * widen_factor`` channels, the second maps back down to
    ``out_channels``. The skip path is an empty ``nn.Sequential`` (identity)
    unless ``stride != 1`` or the channel count changes, in which case it is a
    strided 1x1 convolution followed by batch normalization.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution and of the projection shortcut.
        widen_factor: Multiplier for the width of the hidden layer.
    """

    def __init__(self, in_channels, out_channels, stride=1, widen_factor=3):
        super().__init__()
        hidden_width = out_channels * widen_factor

        self.conv1 = nn.Conv2d(in_channels, hidden_width, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(hidden_width)
        self.conv2 = nn.Conv2d(hidden_width, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        keeps_shape = stride == 1 and in_channels == out_channels
        if keeps_shape:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = self.bn1(self.conv1(x))
        hidden = nn.functional.relu(hidden)
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return nn.functional.relu(out)

class BasicResNet(nn.Module):
    """Small ResNet assembled from three ``BasicResidualBlock`` stages.

    Layout: 3x3 stem convolution (3 -> 16 channels) with ReLU, three stride-2
    residual stages (16 -> 32 -> 64 -> 128 channels), global average pooling
    to 1x1, and a linear classifier.

    Args:
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        widths = (16, 32, 64, 128)
        self.conv = nn.Conv2d(3, widths[0], kernel_size=3, padding=1)
        self.layer1 = BasicResidualBlock(widths[0], widths[1], stride=2)
        self.layer2 = BasicResidualBlock(widths[1], widths[2], stride=2)
        self.layer3 = BasicResidualBlock(widths[2], widths[3], stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(widths[3], num_classes)

    def forward(self, x):
        """Return the classifier output (one row per input sample)."""
        features = nn.functional.relu(self.conv(x))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = self.pool(features)
        # Collapse (N, C, 1, 1) into (N, C) before the linear layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

class BottleneckResNet(nn.Module):
    """Small ResNet assembled from three ``BottleneckResidualBlock`` stages.

    Layout: 3x3 stem convolution (3 -> 16 channels) with ReLU, three stride-2
    bottleneck stages (16 -> 32 -> 64 -> 128 channels), global average pooling
    to 1x1, and a linear classifier.

    Args:
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        stem_width = 16
        final_width = 128
        self.conv = nn.Conv2d(3, stem_width, kernel_size=3, padding=1)
        self.layer1 = BottleneckResidualBlock(stem_width, 32, stride=2)
        self.layer2 = BottleneckResidualBlock(32, 64, stride=2)
        self.layer3 = BottleneckResidualBlock(64, final_width, stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(final_width, num_classes)

    def forward(self, x):
        """Return the classifier output (one row per input sample)."""
        stem = nn.functional.relu(self.conv(x))
        features = self.layer3(self.layer2(self.layer1(stem)))
        pooled = self.pool(features)
        batch = pooled.size(0)
        # Collapse (N, C, 1, 1) into (N, C) before the linear layer.
        return self.fc(pooled.view(batch, -1))

class WideResNet(nn.Module):
    """Small ResNet assembled from three ``WideResidualBlock`` stages (widen factor 3).

    Layout: 3x3 stem convolution (3 -> 16 channels) with ReLU, three stride-2
    wide residual stages (16 -> 32 -> 64 -> 128 channels), global average
    pooling to 1x1, and a linear classifier.

    Args:
        num_classes: Size of the classifier output.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        widen = 3
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.layer1 = WideResidualBlock(16, 32, stride=2, widen_factor=widen)
        self.layer2 = WideResidualBlock(32, 64, stride=2, widen_factor=widen)
        self.layer3 = WideResidualBlock(64, 128, stride=2, widen_factor=widen)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the classifier output (one row per input sample)."""
        out = self.conv(x)
        out = nn.functional.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.pool(out)
        # Collapse (N, C, 1, 1) into (N, C) before the linear layer.
        out = out.view(out.size(0), -1)
        logits = self.fc(out)
        return logits

In [13]:
# utils
def get_data_loaders(dataset_name, batch_size=128):
    """Build train/test data loaders for a named dataset.

    Datasets are stored under (and, if missing, downloaded to) the ``data``
    directory. The training loader shuffles; the test loader does not.

    Args:
        dataset_name: ``'MNIST'`` or ``'CIFAR10'``.
        batch_size: Batch size used for both loaders.

    Returns:
        Tuple ``(train_loader, test_loader, class_names)``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name not in ('MNIST', 'CIFAR10'):
        raise ValueError(f"Неизвестный датасет: {dataset_name}")

    make_loader = torch.utils.data.DataLoader

    if dataset_name == 'MNIST':
        # Same tensor-only transform for both splits.
        to_tensor = transforms.Compose([transforms.ToTensor()])
        train_set = datasets.MNIST('data', train=True, download=True, transform=to_tensor)
        test_set = datasets.MNIST('data', train=False, download=True, transform=to_tensor)
        train_loader = make_loader(train_set, batch_size=batch_size, shuffle=True)
        test_loader = make_loader(test_set, batch_size=batch_size, shuffle=False)
        digit_names = [str(digit) for digit in range(10)]
        return train_loader, test_loader, digit_names

    # CIFAR10: per-channel normalization on both splits, random flip on train only.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.247, 0.243, 0.261)
    train_pipeline = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    test_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    train_set = datasets.CIFAR10('data', train=True, download=True, transform=train_pipeline)
    test_set = datasets.CIFAR10('data', train=False, download=True, transform=test_pipeline)
    train_loader = make_loader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = make_loader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader, train_set.classes

def plot_grad_flow(named_parameters, save_path):
    """Plot the mean absolute gradient of each layer and save the figure.

    Only parameters that require gradients, currently hold a gradient, and
    whose name does not contain ``"bias"`` are included.

    Args:
        named_parameters: Iterable of ``(name, parameter)`` pairs, e.g. the
            result of ``model.named_parameters()`` after a backward pass.
        save_path: File path the figure is written to.
    """
    tracked = [
        (name, param.grad.abs().mean().item())
        for name, param in named_parameters
        if param.requires_grad and param.grad is not None and "bias" not in name
    ]
    layer_names = [name for name, _ in tracked]
    mean_grads = [value for _, value in tracked]
    n_layers = len(mean_grads)

    plt.figure(figsize=(8, 4))
    plt.plot(mean_grads, alpha=0.7, color="b")
    # Zero baseline for visual reference.
    plt.hlines(0, 0, n_layers + 1, linewidth=1, color="k")
    plt.xticks(range(n_layers), layer_names, rotation="vertical", fontsize=8)
    plt.xlim(0, n_layers)
    plt.xlabel("Слои")
    plt.ylabel("Средний градиент")
    plt.title("Градиентный поток")
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def train(model, train_loader, test_loader, epochs=20, log_prefix="", grad_flow_plot=False, grad_flow_path=None):
    """Train ``model`` with Adam and cross-entropy, evaluating after every epoch.

    The optimizer is ``Adam(lr=0.001, weight_decay=1e-4)``. Training runs on
    whatever device the model's first parameter lives on.

    Args:
        model: Network to train; it must have at least one parameter.
        train_loader: Sized iterable of ``(data, target)`` training batches.
        test_loader: Loader passed to ``test`` after each epoch.
        epochs: Number of passes over ``train_loader``.
        log_prefix: Text prepended to the per-epoch log record.
        grad_flow_plot: If true (and ``grad_flow_path`` is set), save a
            gradient-flow figure every 5th epoch (epochs 0, 5, 10, ...).
        grad_flow_path: Base file path for the gradient-flow figures;
            ``_epoch{N}`` is inserted before the file extension.

    Returns:
        Dict with per-epoch lists under the keys ``'train_loss'``,
        ``'train_acc'``, ``'test_loss'`` and ``'test_acc'``.
    """
    logger = logging.getLogger(__name__)
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    num_batches = len(train_loader)
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        plot_this_epoch = bool(grad_flow_plot and grad_flow_path) and epoch % 5 == 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = torch.nn.functional.cross_entropy(output, target)
            loss.backward()
            # Render once per epoch, on the last batch. Plotting on every batch
            # only overwrote the same file repeatedly, so the saved figure is the
            # same while avoiding one matplotlib render per training step.
            if plot_this_epoch and batch_idx == num_batches - 1:
                # splitext also works for paths without a ".png" suffix, where
                # str.replace would silently leave the path unchanged.
                root, ext = os.path.splitext(grad_flow_path)
                plot_grad_flow(model.named_parameters(), f"{root}_epoch{epoch}{ext}")
            optimizer.step()
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
        # Mean of per-batch mean losses.
        avg_loss = total_loss / num_batches
        train_acc = correct / total
        test_loss, test_acc = test(model, test_loader)
        history['train_loss'].append(avg_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        logger.info(
            "%s epoch %d/%d: train_loss=%.4f train_acc=%.4f test_loss=%.4f test_acc=%.4f",
            log_prefix, epoch + 1, epochs, avg_loss, train_acc, test_loss, test_acc,
        )
    return history

def test(model, test_loader):
    device = next(model.parameters()).device
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)
    return loss, acc

def compare_models(histories, labels, save_path):
    """Write the training histories of several models to one JSON file.

    For every ``(label, history)`` pair the ``'train_loss'``, ``'test_loss'``
    and ``'test_acc'`` series are stored under the label. Missing parent
    directories of ``save_path`` are created.

    Args:
        histories: Iterable of history dicts holding the three keys above.
        labels: Iterable of model names, paired positionally with ``histories``.
        save_path: Destination path of the JSON file.
    """
    results = {
        label: {
            'train_loss': history['train_loss'],
            'test_loss': history['test_loss'],
            'test_acc': history['test_acc'],
        }
        for label, history in zip(labels, histories)
    }
    # Create the target directory first so a missing folder does not discard
    # the results of a finished run with a FileNotFoundError.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, 'w') as f:
        json.dump(results, f, indent=2)

def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters."""
    trainable_total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total


def plot_metrics(histories, labels, save_dir):
    """Plot each training metric for several models and save one PNG per metric.

    For every metric in ``train_loss``, ``train_acc``, ``test_loss`` and
    ``test_acc`` a figure with one curve per model is written to
    ``<save_dir>/<metric>.png``. ``save_dir`` is created if it does not exist.

    Args:
        histories: Iterable of history dicts holding the four metric keys.
        labels: Iterable of legend labels, paired positionally with ``histories``.
        save_dir: Directory the figures are written to.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for metric_name in ('train_loss', 'train_acc', 'test_loss', 'test_acc'):
        target_file = os.path.join(save_dir, f"{metric_name}.png")
        plt.figure()
        for run_history, run_label in zip(histories, labels):
            plt.plot(run_history[metric_name], label=run_label)
        plt.title(metric_name)
        plt.xlabel('Epoch')
        plt.ylabel(metric_name)
        plt.legend()
        plt.savefig(target_file)
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close()

# Задание 3: Кастомные слои и эксперименты

In [None]:
# Experiment: train the three residual architectures on CIFAR10 and compare
# parameter count, wall-clock training time and test accuracy.
SEED = 42
EPOCHS = 20
METRICS_PATH = 'PyTorch4/Homework_4/results/custom_layers/metrics_resblocks.json'
PLOTS_DIR = 'PyTorch4/plots/custom_layers/'

# Seed before the models are built so weight init and shuffling are repeatable.
torch.manual_seed(SEED)

# Create the output directories up front: a bad path should fail now, not
# after all three models have finished training.
os.makedirs(os.path.dirname(METRICS_PATH), exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader, _ = get_data_loaders('CIFAR10', batch_size=128)

models = [
    ("BasicResNet", BasicResNet().to(device)),
    ("BottleneckResNet", BottleneckResNet().to(device)),
    ("WideResNet", WideResNet().to(device)),
]

histories = []
param_counts = []
train_times = []
test_accs = []

for name, model in models:
    params = count_parameters(model)
    param_counts.append((name, params))

    # perf_counter is monotonic, so the measured duration is unaffected by
    # system clock adjustments. The duration includes the per-epoch evaluation
    # that train() performs.
    start_time = time.perf_counter()
    history = train(model, train_loader, test_loader, epochs=EPOCHS, log_prefix=name)
    train_duration = time.perf_counter() - start_time
    train_times.append((name, train_duration))

    # train() already evaluated the final weights after the last epoch, so
    # reuse that value instead of running another pass over the test set.
    test_acc = history['test_acc'][-1]
    test_accs.append((name, test_acc))

    histories.append(history)

model_names = [n for n, _ in models]
compare_models(histories, model_names, save_path=METRICS_PATH)
plot_metrics(histories, model_names, save_dir=PLOTS_DIR)

print("\nСравнение числа параметров:")
for name, params in param_counts:
    print(f"{name}: {params}")

print("\nВремя обучения (сек):")
for name, t_time in train_times:
    print(f"{name}: {t_time:.2f}")

print("\nТочность на тестовом множестве:")
for name, acc in test_accs:
    print(f"{name}: {acc:.4f}")