# In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import random
import torchvision





import torch
import torch.nn as nn
import torch.nn.functional as F


def label_smoothing_loss(inputs, targets, alpha):
    """Return the per-sample label-smoothed cross-entropy loss.

    The result is ``(1 - alpha) * nll + alpha * uniform`` where ``nll`` is the
    negative log-likelihood of the target class and ``uniform`` is the mean
    negative log-probability over all classes (the cross-entropy against a
    uniform target distribution).

    Args:
        inputs: Logits; the softmax is taken over dimension 1.
        targets: Class indices in the form accepted by ``F.nll_loss``.
        alpha: Weight of the uniform term; ``0`` gives plain cross-entropy.

    Returns:
        The unreduced loss (``reduction="none"``), one value per sample.
    """
    # Use the public signature only; ``_stacklevel`` is a private argument of
    # ``log_softmax`` and has no effect on the computed values.
    log_probs = F.log_softmax(inputs, dim=1)
    uniform_xent = -log_probs.mean(dim=1)
    target_xent = F.nll_loss(log_probs, targets, reduction="none")
    return (1 - alpha) * target_xent + alpha * uniform_xent


class GhostBatchNorm(nn.BatchNorm2d):
    """Batch normalization that normalizes ``num_splits`` sub-batches separately.

    In training mode the input of shape ``(n, c, h, w)`` is reinterpreted as
    ``(n / num_splits, c * num_splits, h, w)``, so samples ``k, k + num_splits,
    k + 2 * num_splits, ...`` form one "ghost" batch with its own statistics.
    Running statistics are therefore kept per split (``num_features *
    num_splits`` entries) and averaged across splits when the module is
    switched from training to evaluation mode.

    Args:
        num_features: Number of channels ``c`` of the input.
        num_splits: Number of ghost batches; the batch size must be divisible
            by it whenever the training-mode path is taken.
        **kw: Forwarded to ``nn.BatchNorm2d`` (e.g. ``eps``, ``momentum``,
            ``affine``).
    """

    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)

        running_mean = torch.zeros(num_features * num_splits)
        running_var = torch.ones(num_features * num_splits)

        # The scale is kept fixed. With ``affine=False`` the parent class sets
        # ``weight`` to None, so there is nothing to freeze.
        if self.weight is not None:
            self.weight.requires_grad = False
        self.num_splits = num_splits
        # Replace the parent's per-channel buffers with per-split ones.
        self.register_buffer("running_mean", running_mean)
        self.register_buffer("running_var", running_var)

    def train(self, mode=True):
        """Switch mode; average per-split running stats when leaving training."""
        if (self.training is True) and (mode is False):
            # lazily collate stats when we are going to use them
            self.running_mean = torch.mean(
                self.running_mean.view(self.num_splits, self.num_features), dim=0
            ).repeat(self.num_splits)
            self.running_var = torch.mean(
                self.running_var.view(self.num_splits, self.num_features), dim=0
            ).repeat(self.num_splits)
        return super().train(mode)

    def forward(self, input):
        """Normalize ``input`` of shape ``(n, c, h, w)``; the shape is preserved."""
        n, c, h, w = input.shape
        if self.training or not self.track_running_stats:
            assert n % self.num_splits == 0, f"Batch size ({n}) must be divisible by num_splits ({self.num_splits}) of GhostBatchNorm"
            # Affine parameters are shared by all splits; they are None when
            # the module was built with ``affine=False``.
            weight = None if self.weight is None else self.weight.repeat(self.num_splits)
            bias = None if self.bias is None else self.bias.repeat(self.num_splits)
            # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs
            # such as channel slices; for contiguous inputs it is the same view.
            return F.batch_norm(
                input.reshape(-1, c * self.num_splits, h, w),
                self.running_mean,
                self.running_var,
                weight,
                bias,
                True,
                self.momentum,
                self.eps,
            ).reshape(n, c, h, w)
        else:
            # After ``train(False)`` every split holds the same averaged
            # statistics, so the first ``num_features`` entries are enough.
            return F.batch_norm(
                input,
                self.running_mean[: self.num_features],
                self.running_var[: self.num_features],
                self.weight,
                self.bias,
                False,
                self.momentum,
                self.eps,
            )


def conv_bn_relu(c_in, c_out, kernel_size=(3, 3), padding=(1, 1), num_splits=16):
    """Build a bias-free convolution -> ghost batch norm -> CELU block.

    Despite the name, the activation is ``nn.CELU(alpha=0.3)``, not ReLU.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        num_splits: Number of ghost batches used by the normalization layer;
            training batch sizes must be divisible by it. Defaults to 16, the
            value that used to be hard-coded.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=padding, bias=False),
        GhostBatchNorm(c_out, num_splits=num_splits),
        nn.CELU(alpha=0.3),
    )


def conv_pool_norm_act(c_in, c_out, num_splits=16):
    """Build a 3x3 convolution -> 2x2 max-pool -> ghost batch norm -> CELU block.

    The convolution has no bias and uses padding 1; the stride-2 max-pool
    halves the spatial resolution before normalization.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        num_splits: Number of ghost batches used by the normalization layer;
            training batch sizes must be divisible by it. Defaults to 16, the
            value that used to be hard-coded.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=(3, 3), padding=(1, 1), bias=False),
        nn.MaxPool2d(kernel_size=2, stride=2),
        GhostBatchNorm(c_out, num_splits=num_splits),
        nn.CELU(alpha=0.3),
    )


def patch_whitening(data, patch_size=(3, 3)):
    """Compute whitening convolution filters from image patches.

    All stride-1 patches of ``patch_size`` are extracted from ``data``, their
    (uncentred) second-moment matrix is eigendecomposed, and every eigenvector
    is scaled by ``1 / sqrt(eigenvalue + 1e-2)``. The aim is that
    ``torch.std(F.conv2d(data, weights), dim=(2, 3))`` is close to 1.

    Args:
        data: Image batch indexed as (sample, channel, row, column).
        patch_size: ``(height, width)`` of the patches / filters.

    Returns:
        A float32 tensor of shape ``(c * h * w, c, h, w)``, with filters
        ordered from the largest eigenvalue to the smallest.
    """
    patch_h, patch_w = patch_size
    channels = data.size(1)

    # Slide a window over rows, then columns, and gather one patch per position.
    windows = data.unfold(2, patch_h, 1).unfold(3, patch_w, 1)
    windows = windows.transpose(1, 3).reshape(-1, channels, patch_h, patch_w)
    windows = windows.to(torch.float32)

    num_patches = windows.shape[0]
    patch_dim = channels * patch_h * patch_w

    # Scaling the rows by 1 / sqrt(n - 1) makes flat^T @ flat the covariance estimate.
    flat = windows.reshape(num_patches, patch_dim) / (num_patches - 1) ** 0.5
    covariance = flat.t() @ flat

    spectrum, basis = torch.linalg.eigh(covariance)

    # eigh returns eigenvalues in ascending order; reverse both results.
    spectrum = spectrum.flip(0)
    filters = basis.t().reshape(patch_dim, channels, patch_h, patch_w).flip(0)

    scale = torch.sqrt(spectrum + 1e-2).view(-1, 1, 1, 1)
    return filters / scale


class ResNetBagOfTricks(nn.Module):
    """Small residual CNN with a fixed, pre-computed first convolution.

    Args:
        first_layer_weights: Weights for the first 3x3 convolution; their
            first dimension sets the number of output channels of that layer.
            They are installed as-is and excluded from training.
        c_in: Number of input image channels.
        c_out: Number of output classes.
        scale_out: Constant factor applied to the logits.
    """

    def __init__(self, first_layer_weights, c_in, c_out, scale_out):
        super().__init__()

        first_channels = first_layer_weights.size(0)

        # Frozen first layer carrying the supplied weights.
        self.conv1 = nn.Conv2d(
            c_in, first_channels, kernel_size=(3, 3), padding=(1, 1), bias=False
        )
        self.conv1.weight.data = first_layer_weights
        self.conv1.weight.requires_grad = False

        # 1x1 projection of the first-layer responses down to 64 channels.
        self.conv2 = conv_bn_relu(first_channels, 64, kernel_size=(1, 1), padding=0)

        # Stage 1: downsample to 128 channels, followed by a residual pair.
        self.conv3 = conv_pool_norm_act(64, 128)
        self.conv4 = conv_bn_relu(128, 128)
        self.conv5 = conv_bn_relu(128, 128)

        # Stages 2 and 3: two downsampling blocks, followed by a residual pair.
        self.conv6 = conv_pool_norm_act(128, 256)
        self.conv7 = conv_pool_norm_act(256, 512)
        self.conv8 = conv_bn_relu(512, 512)
        self.conv9 = conv_bn_relu(512, 512)

        # Classifier head.
        self.pool10 = nn.MaxPool2d(kernel_size=4, stride=4)
        self.linear11 = nn.Linear(512, c_out, bias=False)
        self.scale_out = scale_out

    def forward(self, x):
        """Return scaled logits of shape ``(batch, c_out)``.

        The reshape before the linear layer only succeeds when the final
        pooling leaves a 1x1 spatial map, which is the case for 32x32 inputs
        (three stride-2 pools followed by the 4x4 pool).
        """
        out = self.conv3(self.conv2(self.conv1(x)))
        out = out + self.conv5(self.conv4(out))
        out = self.conv7(self.conv6(out))
        out = out + self.conv9(self.conv8(out))
        out = self.pool10(out)
        out = out.reshape(out.size(0), out.size(1))
        return self.scale_out * self.linear11(out)


# Short alias used by the training code.
Model = ResNetBagOfTricks









# Hyper-parameters
learning_rate = 0.001  # lowered for CIFAR-10
batch_size = 128
epochs = 40  # CIFAR-10 needs more epochs
num_seeds = 1  # number of runs with different seeds
c_1 = 1  # weight of the loss-difference term in train()
c_11 = 0.1  # weight of the squared gradient norm term in train()

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Training transforms for CIFAR-10 (random augmentation + normalization)
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Evaluation transforms: same normalization but no random augmentation, so
# that the reported validation accuracy is deterministic and measured on the
# unmodified test images.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Datasets
train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root="./data", train=False, transform=test_transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

def preprocess_data(data, device, dtype):
    """Normalize a batch of NHWC images and return it in NCHW layout.

    Args:
        data: Array-like image batch with channels last (three channels);
            presumably raw 0-255 pixel values, given the statistics below.
        device: Device on which the tensors are created.
        dtype: Floating-point dtype of the result.

    Returns:
        A tensor of ``dtype`` on ``device`` with the channel axis moved to
        position 1.
    """
    # Fixed per-channel statistics, broadcast over the trailing channel axis.
    channel_mean = torch.tensor([125.31, 122.95, 113.87], device=device).to(dtype)
    channel_std = torch.tensor([62.99, 62.09, 66.70], device=device).to(dtype)

    images = torch.tensor(data, device=device).to(dtype)
    normalized = (images - channel_mean) / channel_std

    # NHWC -> NCHW
    return normalized.permute(0, 3, 1, 2)

def load_cifar10(device, dtype, data_dir="~/data"):
    """Load CIFAR-10 as whole-dataset tensors on ``device``.

    The training split is downloaded to ``data_dir`` if necessary; the
    validation split is read from the same directory. Both are normalized by
    ``preprocess_data``, and the training images are additionally
    reflection-padded by 4 pixels on every side (32x32 -> 40x40).

    Args:
        device: Device that receives the tensors.
        dtype: Floating-point dtype of the image tensors.
        data_dir: Root directory of the dataset.

    Returns:
        ``(train_data, train_targets, valid_data, valid_targets)``.
    """
    train_set = torchvision.datasets.CIFAR10(root=data_dir, download=True)
    valid_set = torchvision.datasets.CIFAR10(root=data_dir, train=False)

    train_targets = torch.tensor(train_set.targets).to(device)
    valid_targets = torch.tensor(valid_set.targets).to(device)

    valid_data = preprocess_data(valid_set.data, device, dtype)

    # Only the training images are padded.
    pad = nn.ReflectionPad2d(4)
    train_data = pad(preprocess_data(train_set.data, device, dtype))

    return train_data, train_targets, valid_data, valid_targets

# Accuracy computation
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader`` in percent.

    The model is switched to evaluation mode and left in that mode. Batches
    are moved to the module-level ``device``. An empty ``loader`` raises
    ``ZeroDivisionError``.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch, labels in loader:
            batch = batch.to(device)
            labels = labels.to(device)
            # Index of the largest logit is the predicted class.
            predicted = torch.max(model(batch), 1)[1]
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
    return 100 * hits / seen

# Prediction with a saved parameter snapshot
def predict(model, parameters, inputs):
    """Load ``parameters`` into ``model`` and return its outputs for ``inputs``.

    Only model parameters that currently hold a gradient (``p.grad is not
    None``) are overwritten, paired in order with the entries of
    ``parameters``. The overwrite is permanent: the model keeps the loaded
    values after the call. Everything runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        trainable = (p for p in model.parameters() if p.grad is not None)
        for live, snapshot in zip(trainable, parameters):
            live.data.copy_(snapshot)
        return model(inputs)

# Training
def train(seed):
    """Train the model once and return per-epoch inner-product diagnostics.

    On the last batch of every epoch the gradients, the parameters (before
    the optimizer step), the loss value and the batch itself are recorded.
    After training, for every epoch ``k`` with gradient ``g_k``, parameters
    ``x_k`` and reference point ``x*`` (the snapshot of the last epoch) two
    quantities are computed:

    * ``<g_k, x_k - x*> - c_1 * loss_k + c_1 * replayed_loss_k``
    * ``<g_k, x_k - x*> - c_11 * ||g_k||^2``

    Validation accuracy is printed after every epoch.

    Args:
        seed: Seed for the torch, numpy and ``random`` generators.

    Returns:
        ``(inner_products, inner_products2)``: two lists of floats with one
        entry per epoch.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dtype = torch.float16 if device.type != "cpu" else torch.float32
    # Only the training images are needed here (for the whitening filters);
    # the training loop itself reads from ``train_loader``.
    train_data, _, _, _ = load_cifar10(device, dtype)
    # Drop the 4-pixel border on each side before extracting patches.
    weights = patch_whitening(train_data[:10000, :, 4:-4, 4:-4])

    model = Model(weights, c_in=3, c_out=10, scale_out=0.125).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    epoch_losses = []
    epoch_gradients = []
    epoch_params = []
    inputs_saved = []
    targets_saved = []

    for epoch in range(epochs):
        model.train()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            if batch_idx == len(train_loader) - 1:
                # Snapshot the state of the last batch *before* the step.
                gradients = [param.grad.clone() for param in model.parameters() if param.grad is not None]
                epoch_gradients.append(gradients)
                epoch_losses.append(loss.item())
                params = [param.clone().detach() for param in model.parameters() if param.grad is not None]
                epoch_params.append(params)
                inputs_saved.append(inputs)
                targets_saved.append(targets)

            optimizer.step()

        # Report validation accuracy after every epoch
        valid_acc = evaluate(model, test_loader)
        print(f"Epoch {epoch+1}/{epochs}, Validation Accuracy: {valid_acc:.2f}%")

    # Inner products against the reference point x* (last recorded snapshot).
    # The reference vector does not depend on the epoch, so build it once.
    x_star_vector = torch.cat([p.view(-1) for p in epoch_params[-1]])

    inner_products = []
    inner_products2 = []
    snapshots = zip(epoch_gradients, epoch_params, epoch_losses, inputs_saved, targets_saved)
    for gradients, params, loss_value, saved_inputs, saved_targets in snapshots:
        grad_vector = torch.cat([g.view(-1) for g in gradients])
        param_vector = torch.cat([p.view(-1) for p in params])
        # <g_k, x_k - x*> is shared by both diagnostics.
        directional = torch.dot(grad_vector, param_vector - x_star_vector)

        # NOTE(review): the loss is re-evaluated at the *same* snapshot
        # ``params`` rather than at x*, and the model is in eval mode at this
        # point (left so by evaluate()), whereas ``loss_value`` was recorded
        # in training mode. Confirm this is the intended quantity.
        replayed_loss = criterion(predict(model, params, saved_inputs), saved_targets)

        inner_product = directional - c_1 * loss_value + c_1 * replayed_loss
        inner_products.append(inner_product.item())
        inner_product2 = directional - c_11 * torch.norm(grad_vector) ** 2
        inner_products2.append(inner_product2.item())

    return inner_products, inner_products2

# Run training for several seeds (seeds start at 10)
all_inner_products = []
all_inner_products2 = []
for seed in range(10, 10 + num_seeds):
    print(f"Training with seed {seed}")
    inner_products, inner_products2 = train(seed)
    all_inner_products.append(inner_products)
    all_inner_products2.append(inner_products2)
    
import numpy as np
import matplotlib.pyplot as plt

def _as_numpy_rows(rows):
    """Build a 2-D array from nested lists, moving any tensor entries to NumPy."""
    return np.array([
        [entry.cpu().numpy() if isinstance(entry, torch.Tensor) else entry
         for entry in row]
        for row in rows
    ])


# One row per seed, one column per epoch.
all_inner_products = _as_numpy_rows(all_inner_products)
all_inner_products2 = _as_numpy_rows(all_inner_products2)

# Per-epoch statistics across seeds
mean_inner_products = np.mean(all_inner_products, axis=0)
min_inner_products = np.min(all_inner_products, axis=0)
max_inner_products = np.max(all_inner_products, axis=0)

mean_inner_products2 = np.mean(all_inner_products2, axis=0)
min_inner_products2 = np.min(all_inner_products2, axis=0)
max_inner_products2 = np.max(all_inner_products2, axis=0)

print(min_inner_products)
print(min_inner_products2)

# Notebook cell output: RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.