In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, models
import torch.nn.functional as F
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
batch_size = 128

# Определим аугментации для self-supervised learning
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.3, 0.1)], p=0.5),
        transforms.RandomGrayscale(p=0.5),
        transforms.RandomApply([transforms.GaussianBlur(3)], p=0.5),
    ]
)

# Датасет CIFAR-10 без разметки (только для обучения)
trainset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


In [4]:
def EMA(new, alpha=0.99, old=None):
    if old is None:
        return new
    else:
        return old * alpha + (1 - alpha) * new


def loss_fn(x, y):
    # L2 normalization
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


class MLP(nn.Module):
    def __init__(self, input_dim=512) -> None:  # Изменено на 512
        super(MLP, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 256),
        )

    def forward(self, x):
        return self.net(x)


class TargetModel(nn.Module):
    def __init__(self) -> None:
        super(TargetModel, self).__init__()
        self.encoder = torchvision.models.resnet18(weights=None)  # Используем ResNet-18
        self.encoder.conv1 = torch.nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.encoder.fc = nn.Identity()  # Убираем последний полносвязный слой
        self.encoder.maxpool = torch.nn.Identity()  # Убираем maxpool слой

        self.represent = MLP()  # Используем MLP для проекции

    def forward(self, x):
        x = self.encoder(x)
        x = self.represent(x)  # Передаем через MLP
        return x


class OnlineModel(nn.Module):
    def __init__(self) -> None:
        super(OnlineModel, self).__init__()
        self.encoder = torchvision.models.resnet18(weights=None)  # Используем ResNet-18
        self.encoder.conv1 = torch.nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.encoder.fc = nn.Identity()  # Убираем последний полносвязный слой
        self.encoder.maxpool = torch.nn.Identity()  # Убираем maxpool слой

        self.represent = MLP()  # Используем MLP для проекции

    def forward(self, x):
        x = self.encoder(x)
        x = self.represent(x)  # Передаем через MLP
        return x


class BYOL(nn.Module):
    def __init__(self, moving_average_decay=0.99) -> None:
        super(BYOL, self).__init__()

        self.student_model = OnlineModel()
        self.teacher_model = TargetModel()
        self.moving_average_decay = moving_average_decay
        self.student_predictor = MLP(input_dim=256)

    @torch.no_grad()
    def update_moving_average(self):
        assert self.teacher_model is not None, "Target model has not been created yet"
        for student_params, teacher_params in zip(
            self.student_model.parameters(), self.teacher_model.parameters()
        ):
            old_weight, up_weight = teacher_params.data, student_params.data
            teacher_params.data = EMA(
                old=old_weight, new=up_weight, alpha=self.moving_average_decay
            )

    def forward(self, image1, image2):
        # Проекции студента: спинальные слои + MLP
        student_proj_one = self.student_model(image1)
        student_proj_two = self.student_model(image2)

        # Дополнительный слой MLP для предсказаний
        student_pred_one = self.student_predictor(student_proj_one)
        student_pred_two = self.student_predictor(student_proj_two)

        with torch.no_grad():
            # Модель учителя обрабатывает изображения и делает проекции
            teacher_proj_one = self.teacher_model(image1).detach_()
            teacher_proj_two = self.teacher_model(image2).detach_()

        # Вычисление потерь
        loss_one = loss_fn(student_pred_one, teacher_proj_one)
        loss_two = loss_fn(student_pred_two, teacher_proj_two)

        return (loss_one + loss_two).mean()


In [5]:
torch.cuda.empty_cache()

In [6]:
byol = BYOL().to(device)
opt = torch.optim.Adam(byol.parameters(), lr=0.003)
epochs = 20
byol.train()

BYOL(
  (student_model): OnlineModel(
    (encoder): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): Identity()
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [9]:
losses = []
for epoch in range(epochs):
    print("Epoch: %d" % (epoch + 1))
    count = 0
    train_loss = 0
    for x, y in tqdm.tqdm(trainloader):
        count += 1
        opt.zero_grad()

        x1 = transform(x).to(device).float()
        x2 = transform(x).to(device).float()

        loss = byol(x1, x2)
        loss.backward()
        opt.step()
        byol.update_moving_average()

        train_loss += loss.item()

    losses.append(train_loss / count)
    print(f"train_loss: {train_loss / count}")
    torch.save(
        byol.student_model.encoder.state_dict(),
        f"./pretrained_feature_extractors/feature_extractor_{epoch + 1}",
    )

Epoch: 1


100%|██████████| 391/391 [02:28<00:00,  2.63it/s]


train_loss: 0.017667719834696147
Epoch: 2


100%|██████████| 391/391 [02:28<00:00,  2.63it/s]


train_loss: 0.01698606448781574
Epoch: 3


100%|██████████| 391/391 [02:29<00:00,  2.62it/s]


train_loss: 0.0031795616939311365
Epoch: 4


100%|██████████| 391/391 [02:25<00:00,  2.69it/s]


train_loss: 0.004426081967361443
Epoch: 5


100%|██████████| 391/391 [02:26<00:00,  2.68it/s]


train_loss: 0.003942917112994682
Epoch: 6


100%|██████████| 391/391 [02:26<00:00,  2.66it/s]


train_loss: 0.00722458845485583
Epoch: 7


100%|██████████| 391/391 [02:25<00:00,  2.69it/s]


train_loss: 0.004818015003009983
Epoch: 8


100%|██████████| 391/391 [02:24<00:00,  2.70it/s]


train_loss: 0.004555188498133436
Epoch: 9


100%|██████████| 391/391 [02:25<00:00,  2.69it/s]


train_loss: 0.00645019577490404
Epoch: 10


100%|██████████| 391/391 [02:26<00:00,  2.67it/s]


train_loss: 0.004670449006168738
Epoch: 11


100%|██████████| 391/391 [02:26<00:00,  2.68it/s]


train_loss: 0.0032598842018762664
Epoch: 12


100%|██████████| 391/391 [02:24<00:00,  2.70it/s]


train_loss: 0.006401929038498179
Epoch: 13


100%|██████████| 391/391 [02:25<00:00,  2.69it/s]


train_loss: 0.006062440081116031
Epoch: 14


100%|██████████| 391/391 [02:24<00:00,  2.71it/s]


train_loss: 0.01102415006607771
Epoch: 15


100%|██████████| 391/391 [02:24<00:00,  2.71it/s]


train_loss: 0.004931472762323478
Epoch: 16


100%|██████████| 391/391 [02:24<00:00,  2.70it/s]


train_loss: 0.005376499880443487
Epoch: 17


100%|██████████| 391/391 [02:26<00:00,  2.66it/s]


train_loss: 0.002896547212105845
Epoch: 18


100%|██████████| 391/391 [02:27<00:00,  2.65it/s]


train_loss: 0.012474577937363302
Epoch: 19


100%|██████████| 391/391 [02:24<00:00,  2.70it/s]


train_loss: 0.0017250548801658785
Epoch: 20


100%|██████████| 391/391 [02:26<00:00,  2.66it/s]

train_loss: 0.0022341918124033668



