# Classifier for scoring

## Imports

In [1]:
import torch
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from diffusion_score import *

## Read in datasets

In [2]:
# Preprocessing applied to every image: resize to 32, convert to a tensor and
# rescale with x -> 2*x - 1.
# NOTE(review): assumes Compose/Resize/ToTensor/Lambda are torchvision's
# transforms re-exported by diffusion_score (ToTensor then yields values in
# [0, 1], so the Lambda maps them to [-1, 1]) -- confirm.
transforms = Compose([
    Resize(32),
    ToTensor(),
    Lambda(lambda x: 2 * x - 1),
])

# Loader settings shared by the train and test splits.
loader_kwargs = dict(batch_size=128, num_workers=2, pin_memory=True)

train_dataset = FashionMNIST("fashion_mnist", train=True, download=True, transform=transforms)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# The test split is only used for evaluation, so it is iterated in a fixed
# order: shuffling it brings no benefit and makes the order of the collected
# predictions differ from run to run.
test_dataset = FashionMNIST("fashion_mnist", train=False, download=True, transform=transforms)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

## Train and save classifier

In [3]:
# Seed PyTorch before the model is built so that weight initialisation and
# data shuffling are repeatable across "Restart & Run All".
# (Some GPU kernels may still be non-deterministic.)
torch.manual_seed(0)

# Train on the GPU when one is present, otherwise fall back to the CPU instead
# of failing on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

classifier = fashionMNIST_CNN(img_size=32, n_layers=4, start_channels=32)
optimizer = torch.optim.Adam(classifier.parameters(), lr=3e-4)

# train_classifier returns the model that is persisted below; training runs
# for at most 100 epochs with an early-stopping patience of 10.
classifier = train_classifier(
    train_dataloader,
    test_dataloader,
    classifier,
    optimizer,
    epochs=100,
    device=device,
    verbose=1,
    early_stop_patience=10,
)

# Persist only the weights; the architecture is rebuilt from code when loading.
torch.save(classifier.state_dict(), 'fashion_mnist_classifier.pt')

Start training 1/100


  return F.conv2d(input, weight, bias, self.stride,
100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 202.58it/s]


Epoch 1/100 - training loss: 0.3756 - Accuracy: 0.8937000036239624 - 0m 2.690687s
Start training 2/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 219.32it/s]


Epoch 2/100 - training loss: 0.2348 - Accuracy: 0.8963000178337097 - 0m 2.496536s
Start training 3/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 226.52it/s]


Epoch 3/100 - training loss: 0.1868 - Accuracy: 0.9124000072479248 - 0m 2.409895s
Start training 4/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 212.88it/s]


Epoch 4/100 - training loss: 0.1473 - Accuracy: 0.9117000102996826 - 0m 2.573133s
Start training 5/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 211.76it/s]


Epoch 5/100 - training loss: 0.1176 - Accuracy: 0.9120000004768372 - 0m 2.565457s
Start training 6/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 215.45it/s]


Epoch 6/100 - training loss: 0.0894 - Accuracy: 0.9153000116348267 - 0m 2.537372s
Start training 7/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 214.47it/s]


Epoch 7/100 - training loss: 0.0659 - Accuracy: 0.9154000282287598 - 0m 2.565667s
Start training 8/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 194.50it/s]


Epoch 8/100 - training loss: 0.0493 - Accuracy: 0.916700005531311 - 0m 2.752192s
Start training 9/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 222.99it/s]


Epoch 9/100 - training loss: 0.0340 - Accuracy: 0.9140999913215637 - 0m 2.446674s
Start training 10/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 220.39it/s]


Epoch 10/100 - training loss: 0.0246 - Accuracy: 0.9186999797821045 - 0m 2.495135s
Start training 11/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 215.86it/s]


Epoch 11/100 - training loss: 0.0201 - Accuracy: 0.9128000140190125 - 0m 2.515111s
Start training 12/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 220.11it/s]


Epoch 12/100 - training loss: 0.0170 - Accuracy: 0.9136000275611877 - 0m 2.486109s
Start training 13/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 214.09it/s]


Epoch 13/100 - training loss: 0.0144 - Accuracy: 0.9187999963760376 - 0m 2.547503s
Start training 14/100


100%|████████████████████████████████████████████████████████████████████████████████| 469/469 [00:02<00:00, 225.86it/s]


Epoch 14/100 - training loss: 0.0171 - Accuracy: 0.9136000275611877 - 0m 2.418829s


## Test the saved classifier

In [4]:
# Rebuild the architecture with the same hyperparameters that were used for
# training, then restore the saved weights into it.
classifier = fashionMNIST_CNN(img_size=32, n_layers=4, start_channels=32)

# map_location='cpu' lets a checkpoint that holds CUDA tensors be read on a
# machine without a GPU and avoids allocating a second copy of the weights on
# the GPU; load_state_dict copies the values into the freshly built model.
state_dict = torch.load('fashion_mnist_classifier.pt', map_location='cpu')
classifier.load_state_dict(state_dict)

<All keys matched successfully>

In [5]:
# Evaluate on the GPU when one is present, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pred, truth = evaluate_classifier(test_dataloader, classifier, device)

# pred and truth are per-batch sequences of tensors; join each one once.
scores = torch.cat(pred, dim=0)
labels = torch.cat(truth, dim=0)

# Test accuracy. The arg-max is taken on the classifier outputs directly:
# softmax preserves the ordering of its inputs, so applying it first does not
# change which class is selected.
(torch.sum(torch.argmax(scores, dim=-1) == labels) / len(labels)).item()

0.9124000072479248

## Calculate the scores on the test dataset

### Inception score

In [6]:
# Inception score of the real test images under the trained classifier,
# computed over a single split.
# NOTE(review): contains_labels=True presumably signals that the loader yields
# (image, label) batches -- confirm against diffusion_score.
inception_score(
    test_dataloader,
    classifier,
    splits=1,
    contains_labels=True,
)

8.038381

### FID score

In [7]:
# Extract features on the GPU when one is present, otherwise fall back to the
# CPU instead of failing on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Classifier feature vectors for the train and test splits.
train = get_feature_vectors(train_dataloader, classifier, contains_labels=True, device=device)
test = get_feature_vectors(test_dataloader, classifier, contains_labels=True, device=device)

# FID between the two sets of real images.
calculate_fid(train, test)

0.016007007368596185