In [1]:
import torch
import wandb

from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet
from data_generator import DataGenerator

from tqdm.auto import trange

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
%cd ..

/home/pasha/ml/mmp/neuralbayes/n4


In [3]:
device = torch.device('cuda')
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
model = ResNet(**classifier_args)
model.to(device)

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_func = torch.nn.CrossEntropyLoss()

In [4]:
datagen = DataGenerator(create_default_mnist_config())
train_generator = datagen.sample_train()

In [5]:
TOTAL_ITERS = 2_000
EVAL_FREQ = 500

### Обучите классификатор только на чистых картинках. Он понадобится нам для классификации условно сгенерированных картинок

In [6]:
model.train()

for iter_idx in trange(1, 1 + TOTAL_ITERS):
    
    """
    train
    """
    x, y = next(train_generator)
    x = x.to(device)
    y = y.to(device)
    
    logits = model(x)
    loss = loss_func(logits, y)
    
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if iter_idx % EVAL_FREQ == 0:
        valid_accuracy = 0
        """
        validate
        """
        model.eval()
        valid_count = 0
        valid_good = 0
        with torch.no_grad():
            for (x, y) in datagen.valid_loader:
                x = x.to(device)
                y = y.to(device)
                
                logits = model(x)
                pred = torch.max(logits, dim=1)[1]
                
                valid_count += x.shape[0]
                valid_good += torch.sum(pred == y).item()

        valid_accuracy = valid_good/ valid_count
        model.train()
        print('Clean MNIST classifier\'s accuracy:', valid_accuracy)
        
model.eval()

  0%|          | 0/2000 [00:00<?, ?it/s]

Clean MNIST classifier's accuracy: 0.9433
Clean MNIST classifier's accuracy: 0.9766
Clean MNIST classifier's accuracy: 0.994
Clean MNIST classifier's accuracy: 0.9893


ResNet(
  (conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResidualBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=

In [7]:
torch.save(model.state_dict(), './ddpm_checkpoints/clean_classifier.pth')