In [None]:
import torch
import wandb

from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet
from data_generator import DataGenerator

from tqdm.auto import trange

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [None]:
# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA instead of crashing here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNG so weight initialisation (and anything else drawing from the
# global torch generator afterwards) is reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Keyword arguments for the project's ResNet classifier.
# NOTE(review): `layers` is presumably the number of ResidualBlocks per stage
# (a ResNet-18-like [2, 2, 2, 2] layout) — confirm against models.classifier.
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
model = ResNet(**classifier_args)
model.to(device)

# AdamW optimiser over all classifier parameters; the model outputs class
# logits that are scored with cross-entropy against integer labels.
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_func = torch.nn.CrossEntropyLoss()

In [None]:
# Build the MNIST data pipeline from the project's default config.
# The training loop pulls batches from `train_generator` with next(), and the
# evaluation pass iterates over `datagen.valid_loader`.
mnist_config = create_default_mnist_config()
datagen = DataGenerator(mnist_config)
train_generator = datagen.sample_train()

In [None]:
# Training schedule: total number of optimisation steps, and how often
# (in steps) to run a full pass over the validation set.
TOTAL_ITERS = 2000
EVAL_FREQ = 500

### Обучите классификатор только на чистых картинках. Он понадобится нам для классификации условно сгенерированных картинок.

In [None]:
def evaluate(model, loader, loss_func, device):
    """Compute the mean loss and accuracy of ``model`` over every batch of ``loader``.

    Args:
        model: classifier returning per-class logits.
        loader: iterable of ``(images, labels)`` batches.
        loss_func: loss applied to ``(logits, labels)``; assumed to return a batch mean.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats. The loss is weighted by batch
        size so a short final batch does not skew the average.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss, total_correct, total_count = 0.0, 0, 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                batch_size = labels.shape[0]
                logits = model(images)
                total_loss += loss_func(logits, labels).item() * batch_size
                total_correct += (logits.argmax(-1) == labels).sum().item()
                total_count += batch_size
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    if total_count == 0:
        raise ValueError("evaluation loader yielded no samples")
    return total_loss / total_count, total_correct / total_count


wandb.init(project='sde', name='clear_classifier')
model.train()
try:
    for iter_idx in trange(1, 1 + TOTAL_ITERS):
        X, y = next(train_generator)
        X, y = X.to(device), y.to(device)
        preds = model(X)
        loss = loss_func(preds, y)
        loss.backward()
        optim.step()
        optim.zero_grad()
        # Log Python floats, not (CUDA) tensors.
        acc = (preds.argmax(-1) == y).float().mean().item()

        wandb.log({'accuracy/train': acc, 'loss/train': loss.item()}, step=iter_idx)

        if iter_idx % EVAL_FREQ == 0:
            valid_loss, valid_accuracy = evaluate(model, datagen.valid_loader, loss_func, device)
            print(f"Clean MNIST classifier's accuracy: {valid_accuracy:.4f}")
            wandb.log({'accuracy/valid': valid_accuracy, 'loss/valid': valid_loss}, step=iter_idx)
finally:
    # Always leave the model in eval mode and close the wandb run, even if the
    # loop raised or was interrupted, so the next wandb.init starts cleanly.
    model.eval()
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mdlishudi[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/2000 [00:00<?, ?it/s]

Clean MNIST classifier's accuracy: tensor(0.9899, device='cuda:0')
Clean MNIST classifier's accuracy: tensor(0.9910, device='cuda:0')
Clean MNIST classifier's accuracy: tensor(0.9929, device='cuda:0')
Clean MNIST classifier's accuracy: tensor(0.9891, device='cuda:0')


VBox(children=(Label(value='0.001 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.142710…

0,1
accuracy/train,▁▆▇▆█▇▇███▇███▇███▇██▇█████████████████▇
accuracy/valid,▂▅█▁
loss/train,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/valid,█▅▁█

0,1
accuracy/train,0.98828
accuracy/valid,0.9891
loss/train,0.02542
loss/valid,0.03494


In [None]:
# Persist the trained classifier weights. Create the checkpoint directory first
# so the save does not fail with FileNotFoundError on a fresh checkout.
CHECKPOINT_DIR = './ddpm_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'clean_classifier.pth'))