In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import copy
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Library versions used for this run
library_versions = (
    ("PyTorch", torch.__version__),
    ("Torchvision", torchvision.__version__),
    ("NumPy", np.__version__),
    ("Matplotlib", plt.matplotlib.__version__),
)
for library_name, library_version in library_versions:
    print(f"{library_name}: {library_version}")

## Установка Seed

In [None]:
# One seed for every RNG in play: Python, NumPy and PyTorch (CPU and all GPUs)
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

# Ask cuDNN for deterministic kernels and disable its autotuner
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

## Аугментация данных

In [3]:
class TwoTransform:
    """Callable wrapper that applies the same transform twice to one input.

    The wrapped transform is invoked two separate times, so a random
    transform produces two different views of the same image.
    """

    def __init__(self, base_transform):
        self.base = base_transform

    def __call__(self, x):
        # Two independent calls, first view evaluated before the second
        first_view = self.base(x)
        second_view = self.base(x)
        return first_view, second_view

# Augmentations: crop, rotation, shift/scale/shear, then normalisation

# Pipeline used for BYOL pre-training
mnist_aug = transforms.Compose([
    transforms.RandomResizedCrop(size=28, scale=(0.8, 1.0)),
    transforms.RandomRotation(degrees=20),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),  # MNIST pixel mean and std
])

# Pipeline used for the linear-head evaluation (no random augmentation)
linear_aug = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

## Пример исходного и аугментированных изображений

In [5]:
# MNIST training split with no transform attached; used only for the augmentation demo below
mnist_ds = datasets.MNIST(root='./data', train=True, transform=None, download=True)

In [6]:
# Take the first training sample and build two augmented views of it
img, label = mnist_ds[0]
v1, v2 = TwoTransform(mnist_aug)(img)

def unnormalize(tensor):
    """Invert the Normalize((0.1307,), (0.3081,)) step: multiply by std, add mean."""
    mnist_std, mnist_mean = 0.3081, 0.1307
    return tensor * mnist_std + mnist_mean

# Undo the normalisation on both views; squeeze() removes singleton dimensions before plotting
v1_img = unnormalize(v1).squeeze().numpy()
v2_img = unnormalize(v2).squeeze().numpy()
# The original image only goes through ToTensor (no Normalize), so nothing has to be undone
orig_img = transforms.ToTensor()(img).squeeze().numpy() 

In [None]:
# Original image next to its two augmented views
fig, axes = plt.subplots(1, 3, figsize=(9,3))
panels = (
    (orig_img, "Оригинальное"),
    (v1_img, "Аугментировано 1"),
    (v2_img, "Аугментировано 2"),
)
for ax, (panel_img, panel_title) in zip(axes, panels):
    ax.imshow(panel_img, cmap='gray')
    ax.set_title(panel_title)
    ax.axis('off')
plt.show()

## Энкодер

In [4]:
class Encoder(nn.Module):
    """Small convolutional encoder for single-channel images.

    Three Conv-BatchNorm-ReLU stages (32, 64, 128 channels); the first two
    are followed by 2x2 max-pooling. Global average pooling and a linear
    layer then produce a `rep_dim`-dimensional representation.
    """

    def __init__(self, rep_dim=128):
        super().__init__()
        layers = []
        in_channels = 1
        # For 28x28 inputs the spatial size goes 28 -> 14 -> 7 across the two pools
        for out_channels, downsample in ((32, True), (64, True), (128, False)):
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
            if downsample:
                layers.append(nn.MaxPool2d(2))
            in_channels = out_channels
        # Collapse the spatial grid to 1x1 and flatten to (batch, 128)
        layers += [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten()]
        self.net = nn.Sequential(*layers)
        self.fc = nn.Linear(in_channels, rep_dim)

    def forward(self, x):
        features = self.net(x)
        return self.fc(features)

## Проекционная и предикторная головы

In [5]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear.

    Used for both the projection and the prediction heads.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        hidden_stage = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
        ]
        output_stage = [nn.Linear(hidden_dim, out_dim)]
        self.net = nn.Sequential(*hidden_stage, *output_stage)

    def forward(self, x):
        return self.net(x)

## BYOL модель

In [6]:
class BYOL:
    """Container for the BYOL student and teacher networks.

    The student (encoder, projector, predictor) is trained by an external
    optimizer. The teacher (encoder, projector) starts as a copy of the
    student, is frozen, kept in eval mode, and only changes through the
    exponential moving average in `update_teacher`.

    Args:
        encoder: student encoder; must expose `fc.out_features`.
        projector_dim: output size of the projector and the predictor.
        pred_hidden: hidden size of both MLP heads.
        tau: EMA coefficient; each update does
            teacher = tau * teacher + (1 - tau) * student.
    """

    # Layer types whose running statistics are averaged into the teacher
    _BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

    def __init__(self, encoder, projector_dim, pred_hidden, tau):
        # Student encoder
        self.student_encoder = encoder
        # Student projection head
        self.student_projector = MLP(encoder.fc.out_features, pred_hidden, projector_dim)
        # Student prediction head; the teacher has no counterpart
        # (original author's note: training does not converge without it)
        self.student_predictor = MLP(projector_dim, pred_hidden, projector_dim)

        # Teacher encoder and projector start in the same state as the student's
        self.teacher_encoder = copy.deepcopy(encoder)
        self.teacher_projector = copy.deepcopy(self.student_projector)
        # Freeze the teacher parameters (they are updated only through the EMA)
        self._set_requires_grad(self.teacher_encoder, False)
        self._set_requires_grad(self.teacher_projector, False)
        self.teacher_encoder.eval()
        self.teacher_projector.eval()
        self.tau = tau

    @staticmethod
    def _set_requires_grad(model, req):
        """Set `requires_grad` to `req` on every parameter of `model`."""
        for p in model.parameters():
            p.requires_grad = req

    def to(self, device):
        """Move all five sub-networks to `device` and return `self` for chaining."""
        for module in (
            self.student_encoder,
            self.student_projector,
            self.student_predictor,
            self.teacher_encoder,
            self.teacher_projector,
        ):
            module.to(device)
        return self

    def student_forward(self, x):
        """Return (representation, projection, prediction) from the student."""
        y = self.student_encoder(x)
        z = self.student_projector(y)
        p = self.student_predictor(z)
        return y, z, p

    @torch.no_grad()
    def teacher_forward(self, x):
        """Return (representation, projection) from the teacher, without gradients."""
        y = self.teacher_encoder(x)
        z = self.teacher_projector(y)
        return y, z

    @torch.no_grad()
    def _ema_update(self, student, teacher):
        """Blend `student` into `teacher` in place with coefficient `self.tau`."""
        tau = self.tau
        # Learnable parameters
        for p_student, p_teacher in zip(student.parameters(), teacher.parameters()):
            p_teacher.mul_(tau).add_(p_student, alpha=1.0 - tau)
        # BatchNorm running statistics: the teacher stays in eval mode, so it
        # never refreshes these buffers on its own
        for m_student, m_teacher in zip(student.modules(), teacher.modules()):
            if isinstance(m_student, self._BN_TYPES) and m_student.running_mean is not None:
                m_teacher.running_mean.mul_(tau).add_(m_student.running_mean, alpha=1.0 - tau)
                m_teacher.running_var.mul_(tau).add_(m_student.running_var, alpha=1.0 - tau)

    @torch.no_grad()
    def update_teacher(self):
        """EMA-update the teacher encoder and projector from the student."""
        self._ema_update(self.student_encoder, self.teacher_encoder)
        self._ema_update(self.student_projector, self.teacher_projector)

def byol_loss(p, z_target):
    """Per-sample BYOL loss: 2 - 2 * cosine similarity of prediction and target.

    Both inputs are L2-normalised along the last dimension, so the result
    is 0 for aligned vectors and 4 for opposite ones.
    """
    pred_unit = F.normalize(p, dim=-1)
    target_unit = F.normalize(z_target, dim=-1)
    cosine = (pred_unit * target_unit).sum(dim=-1)
    return 2 - 2 * cosine

## Функции для проверки коллапса

In [7]:
def make_collapse_loader(n_samples=512):
    """Build a loader over a random `n_samples`-sized subset of the MNIST training set.

    Images go through `linear_aug`. The subset is drawn with
    `torch.randperm`, so it depends on the global torch RNG state at call
    time. Batches of 256 are served in a fixed order.
    """
    full_train = datasets.MNIST(root='./data', train=True, transform=linear_aug, download=True)
    chosen_idx = torch.randperm(len(full_train))[:n_samples]
    return DataLoader(
        torch.utils.data.Subset(full_train, chosen_idx),
        batch_size=256,
        shuffle=False,
    )

@torch.no_grad()
def compute_trace_cov(encoder, loader, device):
    """Return the trace of the sample covariance of the encoder outputs.

    The encoder is switched to eval mode (and left there), every batch of
    `loader` is encoded, and the covariance is taken over all collected
    feature vectors with the unbiased (n - 1) denominator. A trace close
    to zero means the representations have collapsed to a point.
    """
    encoder.eval()

    # Encode batch by batch and gather everything on the CPU
    collected = [encoder(batch.to(device)).cpu() for batch, _ in loader]
    feats = torch.cat(collected, dim=0)

    # Centre the features
    centered = feats - feats.mean(dim=0, keepdim=True)

    # Covariance matrix
    n_rows = centered.size(0)
    cov = (centered.T @ centered) / (n_rows - 1)

    return torch.trace(cov).item()


## Цикл обучения BYOL

In [8]:
def train_byol(byol, dataloader, optimizer, device, epochs, tau):
    """Pre-train a BYOL model and track the loss and a collapse indicator.

    Args:
        byol: BYOL container with student/teacher networks.
        dataloader: yields ((view1, view2), label) batches; labels are ignored.
        optimizer: optimizer over the student parameters.
        device: device to train on.
        epochs: number of passes over `dataloader`.
        tau: EMA coefficient; overwrites `byol.tau`.

    Returns:
        (loss_list, trace_list): per-epoch mean loss and per-epoch
        trace of the feature covariance of the student encoder.
    """
    byol.tau = tau
    byol.to(device)

    loss_list = []
    trace_list = []

    collapse_loader = make_collapse_loader()

    for epoch in range(1, epochs+1):

        # Student trains, teacher stays in eval mode (compute_trace_cov
        # switches the student encoder to eval at the end of each epoch)
        byol.student_encoder.train()
        byol.student_projector.train()
        byol.student_predictor.train()

        byol.teacher_encoder.eval()
        byol.teacher_projector.eval()

        running_loss = 0.0
        samples_seen = 0

        for (x1, x2), _ in tqdm(dataloader, desc=f"BYOL epoch {epoch}/{epochs}"):
            x1, x2 = x1.to(device), x2.to(device)
            _, z1, p1 = byol.student_forward(x1)
            _, z2, p2 = byol.student_forward(x2)
            with torch.no_grad():
                _, z1_t = byol.teacher_forward(x1)
                _, z2_t = byol.teacher_forward(x2)
            # Symmetrised loss: each view's prediction is matched to the
            # teacher projection of the other view
            loss = (byol_loss(p1, z2_t).mean() + byol_loss(p2, z1_t).mean()) * 0.5
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            byol.update_teacher()

            batch_size = x1.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

        # Average over the samples actually processed: with drop_last=True the
        # loader skips the final partial batch, so len(dataloader.dataset)
        # would overstate the denominator and bias the mean downwards
        avg_loss = running_loss / max(samples_seen, 1)
        loss_list.append(avg_loss)

        # Collapse check
        trace_val = compute_trace_cov(byol.student_encoder, collapse_loader, device)
        trace_list.append(trace_val)

        print(f"Epoch {epoch}: BYOL loss = {avg_loss:.4f}, trace(cov) = {trace_val:.2f}")
    return loss_list, trace_list

## Цикл обучения линейной головы

In [9]:
def train_linear(encoder, train_loader, test_loader, device, epochs, lr):
    """Fit a 10-way linear classifier on top of a frozen encoder.

    The encoder is put in eval mode and all of its parameters are frozen.
    After every epoch the head is scored on `test_loader`.

    Returns:
        (linear, acc_list): the trained head and the per-epoch test accuracy in percent.
    """
    encoder.eval()
    # Freeze every encoder parameter
    for param in encoder.parameters():
        param.requires_grad = False

    linear = nn.Linear(encoder.fc.out_features, 10).to(device)
    opt = torch.optim.Adam(linear.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    acc_list = []
    for epoch in range(1, epochs+1):
        linear.train()
        loss_sum = 0
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            # Features come from the frozen encoder, so no graph is needed for them
            with torch.no_grad():
                feats = encoder(x)
            loss = criterion(linear(feats), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_sum += loss.item() * x.size(0)
        avg = loss_sum / len(train_loader.dataset)
        acc = evaluate_encoder_linear(encoder, linear, test_loader, device)
        acc_list.append(acc)
        print(f"Linear eval epoch {epoch}: loss={avg:.4f}, acc={acc:.2f}%")
    return linear, acc_list

## Оценка точности линейной головы

In [10]:
def evaluate_encoder_linear(encoder, linear, test_loader, device):
    """Return the top-1 accuracy (in percent) of `linear(encoder(x))` on `test_loader`."""
    encoder.eval()
    linear.eval()

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)
            predicted = linear(encoder(images)).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)
    return 100.0 * n_correct / n_seen

## Загрузка MNIST и применение аугментаций

In [11]:
class MNISTTwoView(datasets.MNIST):
    """MNIST variant whose items are ((view1, view2), target).

    The parent dataset is created without a transform; `transform` is
    instead applied here and must return two views of the image.
    """

    def __init__(self, root, train, transform, download):
        super().__init__(root=root, train=train, transform=None, download=download)
        self.twotransform = transform

    def __getitem__(self, index):
        image, label = super().__getitem__(index)
        first_view, second_view = self.twotransform(image)
        return (first_view, second_view), label

def make_dataloaders(batch_size_pretrain=256, batch_size_eval=256, seed=42):
    """Create the loaders for BYOL pre-training and for linear evaluation.

    Both shuffling loaders share one `torch.Generator` seeded with `seed`.

    Returns:
        (pretrain_loader, train_loader, test_loader)
    """
    shuffle_rng = torch.Generator().manual_seed(seed)
    mnist_root = './data'

    # Two-view dataset for BYOL pre-training
    two_view_train = MNISTTwoView(root=mnist_root, train=True, transform=TwoTransform(mnist_aug), download=True)
    # Plain datasets for fitting and testing the linear head
    plain_train = datasets.MNIST(root=mnist_root, train=True, transform=linear_aug, download=True)
    plain_test = datasets.MNIST(root=mnist_root, train=False, transform=linear_aug, download=True)

    # num_workers=0 everywhere (the original comment attributes this to Windows)
    pretrain_loader = DataLoader(
        two_view_train,
        batch_size=batch_size_pretrain,
        shuffle=True,
        num_workers=0,
        drop_last=True,
        generator=shuffle_rng,
    )
    train_loader = DataLoader(
        plain_train,
        batch_size=batch_size_eval,
        shuffle=True,
        num_workers=0,
        generator=shuffle_rng,
    )
    test_loader = DataLoader(
        plain_test,
        batch_size=batch_size_eval,
        shuffle=False,
        num_workers=0,
    )

    return pretrain_loader, train_loader, test_loader

## Параметры

In [16]:
# Number of BYOL pre-training epochs
pretrain_epochs = 25
# Number of epochs for fitting the linear head
linear_epochs = 50
# EMA coefficient for the teacher update (passed to BYOL and train_byol as `tau`)
tau = 0.999

In [17]:
pretrain_loader, train_loader, test_loader = make_dataloaders()
encoder = Encoder(rep_dim=128)
byol = BYOL(encoder, projector_dim=64, pred_hidden=128, tau=tau)
# Only the student networks are optimised; the teacher is not handed to the optimizer
student_modules = (byol.student_encoder, byol.student_projector, byol.student_predictor)
student_params = [param for module in student_modules for param in module.parameters()]
optimizer = torch.optim.Adam(student_params, lr=1e-3, weight_decay=1e-6)

## Обучение BYOL

In [None]:
# Pre-train BYOL; keeps the per-epoch loss and trace(cov) values for the plots further down
loss_list, trace_list = train_byol(byol, pretrain_loader, optimizer, device, epochs=pretrain_epochs, tau=tau)

## Обучение линейной головы

In [None]:
# The student encoder is the one evaluated; train_linear freezes it and puts it in eval mode
frozen_encoder = byol.student_encoder
# NOTE(review): the trained head is bound to `_`, and the confusion-matrix cell later reads it back from `_`
_, acc_list = train_linear(frozen_encoder, train_loader, test_loader, device, epochs=linear_epochs, lr=1e-3)

## Обучение на случайно инициализированном энкодере

In [None]:
# Baseline: an untrained encoder with randomly initialised weights
random_encoder = Encoder(rep_dim=128).to(device)

# Freeze the encoder parameters (they are random and are not meant to change)
for p in random_encoder.parameters(): 
    p.requires_grad = False

# Fit a linear head on the random features with the same epochs and learning rate as the BYOL run
linear_random, acc_random_list = train_linear(random_encoder, train_loader, test_loader, device, epochs=linear_epochs, lr=1e-3)

print(f"Точность на случайном энкодере: {acc_random_list[-1]:.2f}%")

## Графики процесса обучения

In [None]:
plt.figure(figsize=(10,4))

# Left panel: BYOL pre-training loss per epoch
plt.subplot(1,2,1)
plt.plot(range(1, len(loss_list)+1), loss_list, marker='o')
plt.title("BYOL Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")

# Right panel: linear-head accuracy on the random (untrained) encoder.
# The x-axis is derived from the plotted series itself rather than from
# acc_list, so the two can never disagree in length.
plt.subplot(1,2,2)
plt.plot(range(1, len(acc_random_list)+1), acc_random_list, marker='o', color='green')
plt.title("Linear Random Encoder Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.tight_layout()
plt.show()

## Trace Cov

In [None]:
# Collapse indicator per epoch; epochs are numbered from 1 to match the loss plot
plt.figure(figsize=(6,4))
plt.plot(range(1, len(trace_list)+1), trace_list, marker='o')
plt.title("Коллапс-чек: trace(cov)")
plt.xlabel("Epoch")
plt.ylabel("Trace")
plt.grid(True)
plt.show()


## TSNE График

In [None]:
from sklearn.manifold import TSNE

def visualize_embeddings_tsne(encoder, dataloader, device, n_samples=2000):
    """Plot a 2-D t-SNE projection of encoder features, coloured by label.

    Batches are read from `dataloader` until at least `n_samples` items
    have been collected; the first `n_samples` of them are embedded.

    Args:
        encoder: feature extractor; switched to eval mode.
        dataloader: yields (images, labels) batches.
        device: device the encoder lives on.
        n_samples: maximum number of points to embed.
    """
    encoder.eval()
    feats_list = []
    labels_list = []
    n_collected = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            feats_list.append(encoder(x).cpu())
            labels_list.append(y.cpu())
            # Keep a running count instead of re-concatenating every batch
            n_collected += y.size(0)
            if n_collected >= n_samples:
                break

    feats = torch.cat(feats_list)[:n_samples]
    labels = torch.cat(labels_list)[:n_samples]

    tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, random_state=42)
    emb_2d = tsne.fit_transform(feats.numpy())

    plt.figure(figsize=(8, 8))
    scatter = plt.scatter(emb_2d[:, 0], emb_2d[:, 1], c=labels.numpy(), cmap='tab10', s=10, alpha=0.8)
    plt.legend(*scatter.legend_elements(), title="Цифры", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title("TSNE")
    plt.tight_layout()
    plt.show()

# t-SNE of the BYOL student encoder's features on batches from the test loader
visualize_embeddings_tsne(frozen_encoder, test_loader, device)

## Матрица ошибок

In [23]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [24]:
def plot_confusion_matrix(encoder, linear, test_loader, device):
    """Draw the 10-class confusion matrix of `linear(encoder(x))` on `test_loader`."""
    encoder.eval()
    linear.eval()

    pred_batches, label_batches = [], []
    with torch.no_grad():
        for images, targets in test_loader:
            images, targets = images.to(device), targets.to(device)
            logits = linear(encoder(images))
            pred_batches.append(logits.argmax(dim=1).cpu())
            label_batches.append(targets.cpu())

    y_pred = torch.cat(pred_batches)
    y_true = torch.cat(label_batches)

    # Rows are true labels, columns are predictions (sklearn's argument order)
    cm_display = ConfusionMatrixDisplay(
        confusion_matrix=confusion_matrix(y_true, y_pred),
        display_labels=list(range(10)),
    )
    cm_display.plot(cmap='Blues', xticks_rotation=45)
    plt.title("Матрица ошибок - BYOL")
    plt.show()



In [None]:
# NOTE(review): `_` is expected to be the linear head bound by the earlier
# `_, acc_list = train_linear(...)` cell; this only works while nothing rebinds `_`
# in between — a named variable would be safer
plot_confusion_matrix(frozen_encoder, _, test_loader, device)

## Прогон по нескольким сидам

In [29]:
import json
import os

def sweep_tau(tau_list, seeds=(42, 123, 456), pretrain_epochs=25, linear_epochs=50, save_path="tau_sweep_results.json"):
    """Run BYOL pre-training plus linear evaluation for every (tau, seed) pair.

    Args:
        tau_list: EMA coefficients to evaluate.
        seeds: random seeds; one full run per seed and tau.
        pretrain_epochs: BYOL pre-training epochs per run.
        linear_epochs: linear-head epochs per run.
        save_path: JSON file the results are written to.

    Returns:
        dict mapping str(tau) to {"seed_accuracies", "mean", "std"}, where the
        accuracies are the final-epoch linear-head test accuracies in percent.
    """
    # Make sure the directory the results go to exists
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

    results = {}

    for tau in tau_list:
        acc_per_seed = []
        print(f"\n=== Tau = {tau} ===")
        for seed in seeds:
            print(f"\nSeed {seed}...")
            # Seed every RNG
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

            # Data loaders; the seed is forwarded so the shuffle order differs
            # between seeds (the loaders use their own generator, which would
            # otherwise always be seeded with the default 42)
            pretrain_loader, train_loader, test_loader = make_dataloaders(seed=seed)

            # Model and BYOL
            encoder = Encoder(rep_dim=128)
            byol = BYOL(encoder, projector_dim=64, pred_hidden=128, tau=tau)
            student_params = list(byol.student_encoder.parameters()) + list(byol.student_projector.parameters()) + list(byol.student_predictor.parameters())
            optimizer = torch.optim.Adam(student_params, lr=1e-3, weight_decay=1e-6)

            # BYOL pre-training
            train_byol(byol, pretrain_loader, optimizer, device, epochs=pretrain_epochs, tau=tau)

            # Linear evaluation
            frozen_encoder = byol.student_encoder
            _, acc_list = train_linear(frozen_encoder, train_loader, test_loader, device, epochs=linear_epochs, lr=1e-3)
            final_acc = acc_list[-1]  # accuracy after the last epoch

            acc_per_seed.append(final_acc)
            print(f"Сид {seed} завершен: точность линейной головы = {final_acc:.2f}%")

        # Aggregate over seeds for this tau
        mean_acc = float(np.mean(acc_per_seed))
        std_acc = float(np.std(acc_per_seed))
        results[str(tau)] = {
            "seed_accuracies": acc_per_seed,
            "mean": mean_acc,
            "std": std_acc
        }

        print(f"\nTau={tau}: mean={mean_acc:.2f}%, std={std_acc:.2f}%")

    # Write everything to JSON
    with open(save_path, "w") as f:
        json.dump(results, f, indent=4)

    return results



In [None]:
# Sweep for tau = 0.98 with the default seeds; results go to results/tau_98.json
tau_values = [0.98]
results = sweep_tau(tau_values, save_path="results/tau_98.json")


In [None]:
# Sweep for tau = 0.995 with the default seeds; results go to results/tau_995.json
tau_values = [0.995]
results = sweep_tau(tau_values, save_path="results/tau_995.json")


In [None]:
# Sweep for tau = 0.999 with the default seeds; results go to results/tau_999.json
tau_values = [0.999]
results = sweep_tau(tau_values, save_path="results/tau_999.json")


## Построение графика по результатам прогонов

In [None]:
# Result files, keyed by tau. The sweep cells saved them under results/,
# so they have to be read from there as well.
files = {
    "0.98": "results/tau_98.json",
    "0.995": "results/tau_995.json",
    "0.999": "results/tau_999.json"
}

taus = []
means = []
stds = []

# Load the data; each file stores its results under the str(tau) key
for tau, path in files.items():
    with open(path, "r") as f:
        data = json.load(f)
        result = data[tau]
        taus.append(float(tau))
        means.append(result["mean"])
        stds.append(result["std"])

# Convert to arrays
taus = np.array(taus)
means = np.array(means)
stds = np.array(stds)

plt.figure(figsize=(8, 5))
plt.errorbar(taus, means, yerr=stds, fmt='o-', capsize=5)
plt.title("Зависимость точности от τ")
plt.xlabel("τ")
plt.ylabel("Точность (%)")
plt.grid(True)
plt.show()

## Абляции для аугментаций

In [21]:
# Weak augmentation: near-full crops and small rotations only
aug_soft = transforms.Compose([
    transforms.RandomResizedCrop(size=28, scale=(0.95, 1.0)),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Medium augmentation: wider crop range, larger rotations, plus an affine step
aug_medium = transforms.Compose([
    transforms.RandomResizedCrop(size=28, scale=(0.5, 1.0)),
    transforms.RandomRotation(degrees=30),
    transforms.RandomAffine(degrees=0, translate=(0.25, 0.25), scale=(0.65, 1.25), shear=15),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Strong augmentation: a single aggressive affine transform.
# translate=(0.9, 0.9) can shift the digit outside the image.
aug_hard = transforms.Compose([
    transforms.RandomAffine(degrees=40, translate=(0.9, 0.9), scale=(0.4, 1.5), shear=30),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

In [22]:
def ablation_augmentations(tau=0.98, 
                           aug_list=(('soft', aug_soft), ('med', aug_medium), ('hard', aug_hard)), 
                           seeds=(42, 43),
                           pretrain_epochs=25,
                           linear_epochs=50):
    """Compare augmentation strengths by linear-evaluation accuracy.

    For every (name, augmentation) pair and every seed, a fresh BYOL model
    is pre-trained with that augmentation and a linear head is fitted on
    the frozen student encoder. A bar chart of the mean accuracies is shown.

    Args:
        tau: EMA coefficient used for all runs.
        aug_list: (name, transform) pairs; the transform is wrapped in TwoTransform.
        seeds: random seeds, one full run per seed.
        pretrain_epochs: BYOL pre-training epochs per run.
        linear_epochs: linear-head epochs per run.

    Returns:
        dict mapping augmentation name to {'acc_mean', 'acc_std'} (percent).
    """
    table = {}

    for name, aug in aug_list:
        accs = []
        print(f"\n=== Аугментация: {name} ===")
        for s in seeds:
            print(f"Seed {s}...")
            # Seed every RNG
            random.seed(s)
            np.random.seed(s)
            torch.manual_seed(s)
            torch.cuda.manual_seed_all(s)

            # BYOL loader with the augmentation under test
            pretrain_ds = MNISTTwoView(root='./data', train=True, transform=TwoTransform(aug), download=True)
            pretrain_loader = DataLoader(pretrain_ds, batch_size=256, shuffle=True, num_workers=0, drop_last=True)

            # Loaders for the linear head
            train_ds = datasets.MNIST(root='./data', train=True, transform=linear_aug, download=True)
            test_ds = datasets.MNIST(root='./data', train=False, transform=linear_aug, download=True)
            train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=0)
            test_loader = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=0)

            # Model and BYOL
            encoder = Encoder(rep_dim=128)
            byol = BYOL(encoder, projector_dim=64, pred_hidden=128, tau=tau)
            student_params = list(byol.student_encoder.parameters()) + list(byol.student_projector.parameters()) + list(byol.student_predictor.parameters())
            optimizer = torch.optim.Adam(student_params, lr=1e-3, weight_decay=1e-6)

            # BYOL pre-training
            train_byol(byol, pretrain_loader, optimizer, device, epochs=pretrain_epochs, tau=tau)

            # Linear evaluation
            frozen_encoder = byol.student_encoder
            _, acc_list = train_linear(frozen_encoder, train_loader, test_loader, device, epochs=linear_epochs, lr=1e-3)
            final_acc = acc_list[-1]
            accs.append(final_acc)
            print(f"Seed {s} завершен: точность линейной головы = {final_acc:.2f}%")

        # The sample std (ddof=1) is undefined for a single run and would be NaN
        # with a runtime warning; report 0 spread in that case so the chart still renders
        acc_std = np.std(accs, ddof=1) if len(accs) > 1 else np.float64(0.0)
        table[name] = {'acc_mean': np.mean(accs), 'acc_std': acc_std}

    # Visualisation
    labels = list(table.keys())
    means = [table[k]['acc_mean'] for k in labels]
    stds = [table[k]['acc_std'] for k in labels]

    plt.figure(figsize=(6,4))
    plt.bar(labels, means, yerr=stds, capsize=5, color=['skyblue','orange','green'])
    plt.ylabel("Точность (%)")
    plt.title(f"Абляции: Сила аугментации (tau={tau})")
    plt.ylim(0, 100)
    plt.show()

    return table

In [None]:
# Run the augmentation ablation at tau = 0.98 with the default augmentations and seeds
results_ablation = ablation_augmentations(tau=0.98)
print(results_ablation)

Результаты {'soft': {'acc_mean': np.float64(96.4), 'acc_std': np.float64(0.6929646455628193)}, 'med': {'acc_mean': np.float64(98.20), 'acc_std': np.float64(0.1979898987322341)}, 'hard': {'acc_mean': np.float64(79.87), 'acc_std': np.float64(5.289158723275378)}}