In [None]:
!if [ -d HSE ]; then rm -Rf HSE; fi
!git clone https://github.com/Tikhon239/HSE
!cp -a /content/HSE/BayesML/DDPM/* .

In [None]:
!pip3 install wandb
!pip3 install ml_collections

In [3]:
import torch
import wandb

from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet
from data_generator import DataGenerator

from tqdm.auto import trange

In [4]:
device = torch.device('cuda')
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
model = ResNet(**classifier_args)
model.to(device)

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_func = torch.nn.CrossEntropyLoss()

In [None]:
datagen = DataGenerator(create_default_mnist_config())
train_generator = datagen.sample_train()

In [6]:
TOTAL_ITERS = 2_000
EVAL_FREQ = 500

### Обучите классификатор только на чистых картинках. Он понадобится нам для классификации условно сгенерированных картинок

In [7]:
model.train()

for iter_idx in trange(1, 1 + TOTAL_ITERS):
    
    (X, y) = next(train_generator)
    X = X.to(device)
    y = y.to(device)

    pred_labels = model(X)
    loss = loss_func(pred_labels, y)

    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if iter_idx % EVAL_FREQ == 0:
        valid_accuracy = 0
        valid_count = 0

        model.eval()
        with torch.no_grad():
            """
            validate classifier
            """
            for (X, y) in datagen.valid_loader:
                X = X.to(device)
                y = y.to(device)
                pred_labels = model(X)

                valid_accuracy += (torch.argmax(pred_labels, dim=1) == y).sum()
                valid_count += X.size(0)

        valid_accuracy = valid_accuracy / valid_count
        print('Clean MNIST classifier\'s accuracy:', valid_accuracy)
        model.train()

model.eval()

  0%|          | 0/2000 [00:00<?, ?it/s]

Clean MNIST classifier's accuracy: tensor(0.9871, device='cuda:0')
Clean MNIST classifier's accuracy: tensor(0.9931, device='cuda:0')
Clean MNIST classifier's accuracy: tensor(0.9885, device='cuda:0')
Clean MNIST classifier's accuracy: tensor(0.9833, device='cuda:0')


ResNet(
  (conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResidualBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=

In [9]:
torch.save(model.state_dict(), './ddpm_checkpoints/clean_classifier.pth')