<a href="https://colab.research.google.com/github/ArghyaPal/STIC/blob/main/Synthesize_It_Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction
In this work, we show the generative capability of an
image classifier network by synthesizing high-resolution,
photo-realistic, and diverse images at scale. The overall
methodology, called Synthesize-It-Classifier (STIC), does
not require an explicit generator network to estimate the
density of the data distribution and sample images from it.
Instead, it uses the classifier's knowledge of the class
boundary to perform gradient ascent with respect to the class
logits, and then synthesizes images with the Langevin algorithm
by drawing on a blank canvas. During training, the classifier
iteratively uses these synthesized images as fake samples and
re-estimates the class boundary in a recurrent fashion, improving
both the classification accuracy and the quality of the synthetic images.
STIC shows that mixing hard fake samples (i.e., those synthesized
with one-hot class conditioning) and soft fake samples (synthesized
with a convex combination of classes, i.e., a mixup of classes) improves
class interpolation. In this basic version, we demonstrate the work on CIFAR-10.
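The gradient-ascent-plus-Langevin step described above can be sketched as stochastic gradient Langevin dynamics (SGLD) on the classifier's input. This is a minimal sketch, not the STIC implementation: it assumes the quantity being ascended is a per-image class-weight vector `pi` dotted with the logits, so a one-hot `pi` produces a hard fake sample and a convex combination of classes produces a soft one. `sgld_sample` is a hypothetical helper, the toy linear classifier stands in for the WideResNet, and the uniform-noise starting canvas is an assumption; the defaults mirror the notebook's `--n_steps`, `--sgld_lr`, and `--sgld_std` flags.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def sgld_sample(f, x, pi, n_steps=20, step_size=1.0, noise_std=1e-2):
    """Hypothetical SGLD sampler that ascends the pi-weighted class logits of f.

    f:  classifier mapping images (B, C, H, W) to logits (B, n_classes)
    x:  starting canvas, shape (B, C, H, W)
    pi: per-image class weights (B, n_classes); one-hot rows give "hard"
        targets, convex combinations of classes give "soft" (mixup) targets
    """
    x = x.clone().detach()
    for _ in range(n_steps):
        x.requires_grad_(True)
        # Summing over the batch is fine: each image's gradient depends only
        # on its own score as long as f has no cross-batch layers.
        score = (f(x) * pi).sum()
        (grad,) = torch.autograd.grad(score, x)
        # Gradient ascent step plus Gaussian noise (the Langevin part).
        x = (x + step_size * grad + noise_std * torch.randn_like(x)).detach()
    return x


if __name__ == "__main__":
    torch.manual_seed(0)
    n_classes = 10
    # Toy stand-in for the WideResNet classifier, only for illustration.
    f = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, n_classes))
    x0 = torch.rand(4, 3, 32, 32) * 2 - 1  # assumed canvas: uniform noise in [-1, 1]
    hard = F.one_hot(torch.tensor([0, 1, 2, 3]), n_classes).float()
    soft = 0.5 * hard + 0.5 * F.one_hot(torch.tensor([1, 2, 3, 4]), n_classes).float()
    x_hard = sgld_sample(f, x0, hard)  # hard fake samples
    x_soft = sgld_sample(f, x0, soft)  # soft fake samples (class mixup)
```

Because `pi` only reweights the logits, the same sampler covers both kinds of fake samples; how STIC actually forms and mixes them is defined in the repository's own modules.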

**Note**:
Please note that this basic version doesn't include (i) Gram-matrix-corrected Langevin dynamics, (ii) Attentive-STIC, or (iii) Score-STIC. Enjoy the basic STIC!

import torch as t, torch.nn as nn, torch.nn.functional as tnnF, torch.distributions as tdist
from torch.utils.data import DataLoader, Dataset
import torchvision as tv, torchvision.transforms as tr
import os
import sys
import argparse
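# Local modules from the STIC repository (model, data loading, training loop,
# and sampling utilities); main(), called below, is expected to come from
# these wildcard imports.
from stic import *
from dataset import *
from train import *
import sampling
import numpy as np

# Let cuDNN benchmark and cache the fastest convolution algorithms
# (useful here because the input size is fixed).
t.backends.cudnn.benchmark = True
t.backends.cudnn.enabled = True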

# Train the Classifier and Sample from It
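The `--buffer_size`, `--reinit_freq`, and `--n_steps` flags (and the help text "20 works for PCD", i.e. persistent contrastive divergence) suggest that sampling starts from a replay buffer of earlier samples rather than from scratch at every iteration, as in JEM-style training. The sketch below assumes that design: each batch is drawn from the buffer, a `reinit_freq` fraction is reset to fresh noise, and the refined samples are written back. The helper names are hypothetical and the uniform [-1, 1] initialization is an assumption; the repository's `sampling` module may differ.

```python
import torch


def init_buffer(buffer_size, x_shape=(3, 32, 32)):
    """Fill a replay buffer with uniform noise in [-1, 1) (assumed initialization)."""
    return torch.rand(buffer_size, *x_shape) * 2 - 1


def sample_start_points(buffer, batch_size, reinit_freq):
    """Draw SGLD starting points: mostly buffer entries, a fraction reset to noise."""
    idx = torch.randint(0, buffer.size(0), (batch_size,))
    from_buffer = buffer[idx]
    fresh_noise = torch.rand_like(from_buffer) * 2 - 1
    # Each sample is independently re-initialized with probability reinit_freq.
    reinit = torch.rand(batch_size) < reinit_freq
    mask = reinit.view(-1, *([1] * (from_buffer.dim() - 1)))
    x0 = torch.where(mask, fresh_noise, from_buffer)
    return x0, idx


def write_back(buffer, idx, x_final):
    """Store refined samples so the next iteration continues from them (PCD)."""
    buffer[idx] = x_final.detach().to(buffer.device)


if __name__ == "__main__":
    # Defaults taken from the argparse flags below.
    buffer = init_buffer(10000)
    x0, idx = sample_start_points(buffer, batch_size=64, reinit_freq=0.05)
    # ... run n_steps=20 SGLD steps from x0 here ...
    write_back(buffer, idx, x0)
```

Keeping the chain persistent lets a short run of 20 SGLD steps per iteration keep improving samples across iterations, while the occasional reset keeps the buffer from collapsing onto a few modes.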

if __name__ == "__main__":
    parser = argparse.ArgumentParser("Synthesize It Classifier")
    parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "svhn"])
    parser.add_argument("--data_root", type=str, default="./data")
    # optimization
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--decay_epochs", nargs="+", type=int, default=[160, 180],
                        help="decay learning rate by decay_rate at these epochs")
    parser.add_argument("--decay_rate", type=float, default=.3,
                        help="learning rate decay multiplier")
    parser.add_argument("--clf_only", action="store_true", help="If set, then only train the classifier")
    parser.add_argument("--labels_per_class", type=int, default=-1,
                        help="number of labeled examples per class, if zero then use all labels")
    parser.add_argument("--optimizer", choices=["adam", "sgd"], default="adam")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--n_epochs", type=int, default=200)
    parser.add_argument("--warmup_iters", type=int, default=-1,
                        help="number of iters to linearly increase learning rate, if -1 then no warmup")
    # loss weighting
    parser.add_argument("--p_x_weight", type=float, default=1.)
    parser.add_argument("--p_y_given_x_weight", type=float, default=1.)
    parser.add_argument("--p_x_y_weight", type=float, default=0.)
    # regularization
    parser.add_argument("--dropout_rate", type=float, default=0.0)
    parser.add_argument("--sigma", type=float, default=3e-2,
                        help="stddev of gaussian noise to add to input, .03 works but .1 is more stable")
    parser.add_argument("--weight_decay", type=float, default=0.0)
    # network
    parser.add_argument("--norm", type=str, default=None, choices=[None, "norm", "batch", "instance", "layer", "act"],
                        help="norm to add to weights, none works fine")
    # EBM specific
    parser.add_argument("--n_steps", type=int, default=20,
                        help="number of steps of SGLD per iteration, 100 works for short-run, 20 works for PCD")
    parser.add_argument("--width", type=int, default=10, help="WideResNet Width")
    parser.add_argument("--depth", type=int, default=28, help="WideResNet Depth")
    parser.add_argument("--uncond", action="store_true", help="Classifier Parameter")
    parser.add_argument("--class_cond_p_x_sample", action="store_true",
                        help="Sample from STIC")
    parser.add_argument("--buffer_size", type=int, default=10000)
    parser.add_argument("--reinit_freq", type=float, default=.05)
    parser.add_argument("--sgld_lr", type=float, default=1.0)
    parser.add_argument("--sgld_std", type=float, default=1e-2)
    # logging + evaluation
    parser.add_argument("--save_dir", type=str, default='./experiment')
    parser.add_argument("--ckpt_every", type=int, default=10, help="Epochs between checkpoint save")
    parser.add_argument("--eval_every", type=int, default=1, help="Epochs between evaluation")
    parser.add_argument("--print_every", type=int, default=100, help="Iterations between print")
    parser.add_argument("--load_path", type=str, default=None)
    parser.add_argument("--print_to_log", action="store_true", help="If true, directs std-out to log file")
    parser.add_argument("--plot_cond", action="store_true", help="If set, save class-conditional samples")
    parser.add_argument("--plot_uncond", action="store_true", help="If set, save unconditional samples")
    parser.add_argument("--n_valid", type=int, default=5000)

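    # parse_args(args=[]) ignores the notebook kernel's own command-line
    # arguments, so the defaults above are used.
    args = parser.parse_args(args=[])
    args.n_classes = 10  # CIFAR-10 and SVHN both have 10 classes
    main(args)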

Files already downloaded and verified
| Wide-Resnet 28x10


0it [00:00, ?it/s]

P(x) | 0:0 f(x_p_d)=   2.306620121 f(x_q)=   2.308447361 d=  -0.001827240
P(y|x) 0:0 loss=   2.285281181, acc=   0.187500000


100it [06:00,  3.60s/it]

P(x) | 0:100 f(x_p_d)=  -9.885334015 f(x_q)=  -8.246473312 d=  -1.638860703
P(y|x) 0:100 loss=   2.368155003, acc=   0.171875000


200it [12:00,  3.60s/it]

P(x) | 0:200 f(x_p_d)= -10.812567711 f(x_q)=  -8.611837387 d=  -2.200730324
P(y|x) 0:200 loss=   2.221562624, acc=   0.187500000


300it [18:00,  3.60s/it]

P(x) | 0:300 f(x_p_d)=  -3.785247326 f(x_q)=  -3.218774796 d=  -0.566472530
P(y|x) 0:300 loss=   2.140653133, acc=   0.140625000


400it [24:00,  3.60s/it]

P(x) | 0:400 f(x_p_d)=  -2.511241913 f(x_q)=  -2.373368263 d=  -0.137873650
P(y|x) 0:400 loss=   2.007229328, acc=   0.218750000


500it [29:59,  3.60s/it]

P(x) | 0:500 f(x_p_d)=  -0.438657910 f(x_q)=  -0.175983787 d=  -0.262674123
P(y|x) 0:500 loss=   1.996925473, acc=   0.281250000


600it [35:59,  3.60s/it]

P(x) | 0:600 f(x_p_d)=  -0.998685300 f(x_q)=  -0.986736059 d=  -0.011949241
P(y|x) 0:600 loss=   1.984986544, acc=   0.250000000


700it [41:59,  3.60s/it]

P(x) | 0:700 f(x_p_d)=   1.795664907 f(x_q)=   1.679900646 d=   0.115764260
P(y|x) 0:700 loss=   1.934644341, acc=   0.343750000


703it [42:10,  3.60s/it]


Epoch 0: Valid Loss 1.9492214918136597, Valid Acc 0.2718000113964081
Best Valid!: 0.2718000113964081
Epoch 0: Test Loss 1.9351226091384888, Test Acc 0.2702000141143799


97it [05:49,  3.60s/it]

P(x) | 1:97 f(x_p_d)=   0.620563030 f(x_q)=   0.660329223 d=  -0.039766192
P(y|x) 1:800 loss=   1.873410463, acc=   0.281250000


197it [11:49,  3.60s/it]

P(x) | 1:197 f(x_p_d)=   1.029497504 f(x_q)=   1.015585423 d=   0.013912082
P(y|x) 1:900 loss=   1.822833419, acc=   0.296875000


297it [17:49,  3.60s/it]

P(x) | 1:297 f(x_p_d)=   1.441894054 f(x_q)=   1.532715321 d=  -0.090821266
P(y|x) 1:1000 loss=   1.858728886, acc=   0.328125000


368it [22:04,  3.60s/it]

KeyboardInterrupt: ignored
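Reading the training log: the printed values are consistent with `d = f(x_p_d) - f(x_q)` (at iteration 0:0, 2.306620121 - 2.308447361 = -0.001827240), where `f(x_p_d)` is presumably the score on real data and `f(x_q)` the score on synthesized samples. The 703 iterations per epoch match (50000 - 5000) // 64 if the `--n_valid` images are held out of CIFAR-10's 50,000 training images and the last partial batch is dropped (both assumptions), and the second counter on the `P(y|x)` lines (e.g. `1:800`) then reads as a global iteration index, 703 + 97. The sketch below shows how `--p_x_weight` and `--p_y_given_x_weight` could combine under a JEM-style objective in which f(x) is the log-sum-exp of the logits; that choice of f and the `combined_loss` helper are assumptions, not the notebook's code, and the `--p_x_y_weight` term (0 by default) is omitted.

```python
import torch
import torch.nn.functional as F


def combined_loss(logits_real, logits_fake, y, p_x_weight=1.0, p_y_given_x_weight=1.0):
    """Hypothetical JEM-style loss; f(x) is assumed to be logsumexp of the logits."""
    f_real = logits_real.logsumexp(dim=1).mean()  # analogue of f(x_p_d)
    f_fake = logits_fake.logsumexp(dim=1).mean()  # analogue of f(x_q)
    d = f_real - f_fake
    # Raising f on real images and lowering it on synthesized ones maximizes d.
    loss_p_x = -d
    # Standard classification term for P(y|x).
    loss_p_y_given_x = F.cross_entropy(logits_real, y)
    return p_x_weight * loss_p_x + p_y_given_x_weight * loss_p_y_given_x, d


# Sanity checks against the numbers printed in the log above.
d_logged = 2.306620121 - 2.308447361  # f(x_p_d) - f(x_q) at iteration 0:0
assert abs(d_logged - (-0.001827240)) < 1e-9
iters_per_epoch = (50_000 - 5_000) // 64
assert iters_per_epoch == 703
assert iters_per_epoch + 97 == 800  # matches "P(y|x) 1:800" at epoch 1, iteration 97
```

Under this reading, a negative `d` early in training means the synthesized samples score slightly higher than real images, which the generative term then pushes back down.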