In [1]:
import torch
import wandb

from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet
from data_generator import DataGenerator

from tqdm.auto import trange

import os
# Expose only CUDA device index 3 to this process (it then appears as cuda:0).
# NOTE: CUDA reads this variable only when it is first initialised in the process;
# torch initialises CUDA lazily, so this works as long as no CUDA call has run yet.
os.environ.update(CUDA_VISIBLE_DEVICES='3')

In [2]:
# Fall back to CPU when no GPU is visible so the notebook still runs end to end.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNGs so weight initialisation is repeatable across Restart & Run All
# (and batch order too, if the data pipeline draws from torch's global RNG).
SEED = 0
torch.manual_seed(SEED)

# Residual classifier; `layers` presumably gives the number of residual blocks
# in each of four stages -- see models.classifier.ResNet for the exact meaning.
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
model = ResNet(**classifier_args).to(device)

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
# CrossEntropyLoss takes raw logits and integer class-index targets.
loss_func = torch.nn.CrossEntropyLoss()

In [3]:
# Build the MNIST data pipeline from the default config and obtain the
# iterator that the training loop pulls (X, y) batches from.
mnist_config = create_default_mnist_config()
datagen = DataGenerator(mnist_config)
train_generator = datagen.sample_train()

In [4]:
# Training schedule: number of optimiser steps, and how often (in steps) to validate.
TOTAL_ITERS, EVAL_FREQ = 2_000, 500

### Обучите классификатор только на чистых картинках. Он понадобится нам для классификации условно сгенерированных картинок.

In [5]:
def _evaluate_accuracy(net, loader, device):
    """Return top-1 accuracy of ``net`` over every ``(X, y)`` batch in ``loader``.

    The forward passes run under ``torch.no_grad()`` with ``net`` in eval mode.
    The network's previous train/eval mode is restored afterwards, even if
    evaluation raises, so this is safe to call from inside a training loop.

    Args:
        net: classifier returning per-class scores; the predicted class is the
            argmax over dim 1.
        loader: iterable of ``(X, y)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        Accuracy as a plain Python float in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for X_batch, y_batch in loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                logits = net(X_batch)
                # .item() turns the device tensor into a Python int.
                correct += (logits.argmax(dim=1) == y_batch).sum().item()
                total += X_batch.size(0)
    finally:
        net.train(was_training)
    if total == 0:
        raise ValueError('validation loader yielded no samples')
    return correct / total


model.train()

for iter_idx in trange(1, 1 + TOTAL_ITERS):
    # One optimisation step on a batch of clean training images.
    # NOTE(review): assumes sample_train() yields batches indefinitely; a finite
    # generator would raise StopIteration before TOTAL_ITERS is reached.
    X, y = next(train_generator)
    X = X.to(device)
    y = y.to(device)

    pred_labels = model(X)
    loss = loss_func(pred_labels, y)

    optim.zero_grad()
    loss.backward()
    optim.step()

    # Periodic validation on the held-out loader.
    if iter_idx % EVAL_FREQ == 0:
        valid_accuracy = _evaluate_accuracy(model, datagen.valid_loader, device)
        print('Clean MNIST classifier\'s accuracy:', valid_accuracy)

model.eval()

  0%|          | 0/2000 [00:00<?, ?it/s]

Clean MNIST classifier's accuracy: 0.101
Clean MNIST classifier's accuracy: 0.9741
Clean MNIST classifier's accuracy: 0.9939
Clean MNIST classifier's accuracy: 0.9904


In [7]:
# Persist only the classifier weights; reload them later with model.load_state_dict(...).
CLEAN_CLASSIFIER_PATH = './ddpm_checkpoints/clean_classifier.pth'
# torch.save does not create missing parent directories, so make sure it exists.
os.makedirs(os.path.dirname(CLEAN_CLASSIFIER_PATH), exist_ok=True)
torch.save(model.state_dict(), CLEAN_CLASSIFIER_PATH)