In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import random
from copy import deepcopy

In [3]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(4242)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
epochs = 45
tau = 0.98  # EMA коэффициент

# Аугментация изображений

In [5]:
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
])

transform_plain = transforms.ToTensor()  # Для линейной оценки

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform_plain)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Архитектура энкодера

In [6]:
class Encoder(nn.Module):
    def __init__(self, feature_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 2, 1),  # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),  # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, feature_dim)
        )
        
    def forward(self, x):
        x = self.conv(x)
        x = F.normalize(x, dim=-1)
        return x

# Проекционная и предсказательная голова

In [7]:
class ProjectionHead(nn.Module):
    def __init__(self, in_dim=128, hidden_dim=256, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    def forward(self, x):
        return self.net(x)

class PredictionHead(nn.Module):
    def __init__(self, in_dim=128, hidden_dim=256, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    def forward(self, x):
        return self.net(x)

# Определение модели

In [8]:
class BYOL(nn.Module):
    def __init__(self, encoder, tau=0.99):
        super().__init__()
        self.online_encoder = encoder
        self.target_encoder = deepcopy(encoder)
        self.projector_online = ProjectionHead()
        self.projector_target = deepcopy(self.projector_online)
        self.predictor = PredictionHead()
        self.tau = tau
        
        for p in self.target_encoder.parameters():
            p.requires_grad = False
        for p in self.projector_target.parameters():
            p.requires_grad = False
    
    @torch.no_grad()
    def _update_target(self):
        for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            param_k.data = self.tau * param_k.data + (1 - self.tau) * param_q.data
        for param_q, param_k in zip(self.projector_online.parameters(), self.projector_target.parameters()):
            param_k.data = self.tau * param_k.data + (1 - self.tau) * param_q.data

    def forward(self, x1, x2):
        # online
        q1 = self.predictor(self.projector_online(self.online_encoder(x1)))
        q2 = self.predictor(self.projector_online(self.online_encoder(x2)))
        # target
        with torch.no_grad():
            z1 = self.projector_target(self.target_encoder(x1))
            z2 = self.projector_target(self.target_encoder(x2))
        return q1, q2, z1.detach(), z2.detach()

In [9]:
def loss_fn(q, z):
    q = F.normalize(q, dim=-1)
    z = F.normalize(z, dim=-1)
    return 2 - 2 * (q * z).sum(dim=-1).mean()

# Обучение модели

In [10]:
import matplotlib.pyplot as plt
byol_losses = []
linear_losses = []
linear_accs = []

In [11]:
encoder = Encoder().to(device)
model = BYOL(encoder, tau).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, epochs + 1):
    model.train()
    total_loss = 0
    for (x1, _), (x2, _) in zip(train_loader, train_loader):
        x1, x2 = x1.to(device), x2.to(device)

        q1, q2, z1, z2 = model(x1, x2)
        loss = loss_fn(q1, z2) / 2 + loss_fn(q2, z1) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model._update_target()

        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_loader)
    byol_losses.append(avg_loss)
    print(f"Эпоха {epoch}, BYOL Loss = {avg_loss:.4f}")

Эпоха 1, BYOL Loss = 0.7912
Эпоха 2, BYOL Loss = 0.6209
Эпоха 3, BYOL Loss = 0.8714
Эпоха 4, BYOL Loss = 0.6629
Эпоха 5, BYOL Loss = 0.4309
Эпоха 6, BYOL Loss = 0.3961
Эпоха 7, BYOL Loss = 0.4984
Эпоха 8, BYOL Loss = 0.8659
Эпоха 9, BYOL Loss = 0.8070
Эпоха 10, BYOL Loss = 0.7648
Эпоха 11, BYOL Loss = 0.8293
Эпоха 12, BYOL Loss = 0.9720
Эпоха 13, BYOL Loss = 0.9365
Эпоха 14, BYOL Loss = 1.0937
Эпоха 15, BYOL Loss = 1.3475
Эпоха 16, BYOL Loss = 1.0642
Эпоха 17, BYOL Loss = 0.7580
Эпоха 18, BYOL Loss = 0.5762
Эпоха 19, BYOL Loss = 0.4906
Эпоха 20, BYOL Loss = 0.4757
Эпоха 21, BYOL Loss = 0.4628
Эпоха 22, BYOL Loss = 0.4283
Эпоха 23, BYOL Loss = 0.4222
Эпоха 24, BYOL Loss = 0.3697
Эпоха 25, BYOL Loss = 0.3480
Эпоха 26, BYOL Loss = 0.4619
Эпоха 27, BYOL Loss = 0.4077
Эпоха 28, BYOL Loss = 0.4415
Эпоха 29, BYOL Loss = 0.5137
Эпоха 30, BYOL Loss = 0.5171
Эпоха 31, BYOL Loss = 0.6851
Эпоха 32, BYOL Loss = 0.6804


# Заморозка эмбеддингов и линейная голова

In [None]:
class LinearClassifier(nn.Module):
    def __init__(self, encoder, num_classes=10):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.fc = nn.Linear(128, num_classes)
    def forward(self, x):
        with torch.no_grad():
            features = self.encoder(x)
        return self.fc(features)

linear_model = LinearClassifier(model.online_encoder).to(device)
optimizer = optim.Adam(linear_model.fc.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, 31):
    linear_model.train()
    total_loss = 0
    correct = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        output = linear_model(x)
        loss = criterion(output, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (output.argmax(1) == y).sum().item()

    acc = 100. * correct / len(train_loader.dataset)
    linear_losses.append(total_loss / len(train_loader))
    linear_accs.append(acc)
    print(f"[Linear Eval] Эпоха {epoch}, Loss={total_loss/len(train_loader):.4f}, Точность={acc:.2f}%")


In [None]:
import matplotlib.pyplot as plt

# График BYOL loss
plt.figure(figsize=(8, 4))
plt.plot(byol_losses, marker='o')
plt.title("Сходимость BYOL Loss")
plt.xlabel("Эпоха")
plt.ylabel("BYOL Loss")
plt.grid(True)
plt.show()

# Графики линейной головы
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(linear_losses, marker='o', color='orange')
plt.title("Сходимость Loss линейной головы")
plt.xlabel("Эпоха")
plt.ylabel("Loss")
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(linear_accs, marker='o', color='green')
plt.title("Точность линейной головы")
plt.xlabel("Эпоха")
plt.ylabel("Accuracy (%)")
plt.grid(True)

plt.tight_layout()
plt.show()


# TSNE Распределение эмбеддингов

In [None]:
from sklearn.manifold import TSNE
import torch

# -----------------------------
# 1. Собираем эмбеддинги для тестового датасета
# -----------------------------
model.online_encoder.eval()
features = []
labels_list = []

with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        feats = model.online_encoder(x)  # эмбеддинги размерности 128
        features.append(feats.cpu())
        labels_list.append(y)

features = torch.cat(features).numpy()  # (N, 128)
labels = torch.cat(labels_list).numpy()  # (N,)

# -----------------------------
# 2. t-SNE
# -----------------------------
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
features_2d = tsne.fit_transform(features)

# -----------------------------
# 3. Визуализация
# -----------------------------
plt.figure(figsize=(8, 8))
scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels, cmap='tab10', s=5)
plt.legend(*scatter.legend_elements(), title="Digits")
plt.title("t-SNE визуализация эмбеддингов BYOL на MNIST")
plt.show()


50e theta 0.99 -  0.4760
50e theta 0.98 -  0.1897