<a href="https://colab.research.google.com/github/AoShuang92/PhD_tutorial/blob/main/all_cnn_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# All CNNs for CIFAR-10
Source: https://github.com/huyvnphan/PyTorch_CIFAR10/tree/641cac24371b17052b9bb6e56af1c83b5e97cd7f <br>
Download the pretrained weights and clone the repository that defines the model architectures.

In [None]:
# Fetch the pretrained state dicts (a zip hosted on Google Drive) and extract them into
# the working directory. Each step is skipped when its output already exists, so the
# cell can be re-run (e.g. Restart & Run All) without re-downloading the archive or
# hitting an interactive "replace file?" prompt, which `unzip` shows without `-o`.
import zipfile
from pathlib import Path

import gdown

url = 'https://drive.google.com/uc?id=17fmN8eQdLpq2jIMQ_X0IXDPXfI9oVWgq'
weights_zip = Path('state_dicts.zip')
weights_dir = Path('state_dicts')  # directory the archive extracts to

if not weights_dir.is_dir():
    if not weights_zip.is_file():
        downloaded = gdown.download(url, str(weights_zip), quiet=True)
        # Some gdown versions return None on failure instead of raising; fail loudly here
        # rather than later with a confusing "file not found" from zipfile.
        if downloaded is None:
            raise RuntimeError(f'Failed to download pretrained weights from {url}')
    # Same effect as `unzip state_dicts.zip` in the current directory.
    with zipfile.ZipFile(weights_zip) as archive:
        archive.extractall('.')

In [None]:
!git clone https://github.com/huyvnphan/PyTorch_CIFAR10.git

Cloning into 'PyTorch_CIFAR10'...
remote: Enumerating objects: 690, done.[K
remote: Counting objects: 100% (66/66), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 690 (delta 52), reused 36 (delta 36), pack-reused 624[K
Receiving objects: 100% (690/690), 6.58 MiB | 18.66 MiB/s, done.
Resolving deltas: 100% (269/269), done.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import models
import torchvision.transforms as transforms
# from PyTorch_CIFAR10.cifar10_models.densenet import densenet121
from PyTorch_CIFAR10.cifar10_models.densenet import densenet121, densenet161, densenet169
from PyTorch_CIFAR10.cifar10_models.googlenet import googlenet
from PyTorch_CIFAR10.cifar10_models.inception import inception_v3
from PyTorch_CIFAR10.cifar10_models.mobilenetv2 import mobilenet_v2
from PyTorch_CIFAR10.cifar10_models.resnet import resnet18, resnet34, resnet50
from PyTorch_CIFAR10.cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn



# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Architecture name -> model instance, each built with default constructor arguments.
# Every entry is constructed here, up front, in the order listed (so all of them occupy
# memory even if only one is evaluated).
all_classifiers = dict(
    vgg11_bn=vgg11_bn(),
    vgg13_bn=vgg13_bn(),
    vgg16_bn=vgg16_bn(),
    vgg19_bn=vgg19_bn(),
    resnet18=resnet18(),
    resnet34=resnet34(),
    resnet50=resnet50(),
    densenet121=densenet121(),
    densenet161=densenet161(),
    densenet169=densenet169(),
    mobilenet_v2=mobilenet_v2(),
    googlenet=googlenet(),
    inception_v3=inception_v3(),
)

def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return correct / total

def main(arch='densenet121', num_classes=10, ckpt=None, test_loader=None):
    """Load a checkpoint into ``all_classifiers[arch]`` and return its test accuracy.

    Args:
        arch: key into ``all_classifiers`` (e.g. ``'resnet18'``).
        num_classes: currently unused. The model is taken pre-built from
            ``all_classifiers``; the argument is kept so existing calls keep working.
        ckpt: path to a saved ``state_dict``. Defaults to ``state_dicts/<arch>.pt``.
        test_loader: iterable of ``(inputs, targets)`` batches to evaluate on (required).

    Returns:
        float: top-1 accuracy as computed by ``test``.

    Raises:
        KeyError: if ``arch`` is not a key of ``all_classifiers``.
        ValueError: if ``test_loader`` is not given.
    """
    if arch not in all_classifiers:
        raise KeyError(f'unknown arch {arch!r}; expected one of {sorted(all_classifiers)}')
    if test_loader is None:
        raise ValueError('test_loader is required')
    if ckpt is None:
        ckpt = f'state_dicts/{arch}.pt'

    model = all_classifiers[arch]
    # map_location lets a checkpoint saved from a GPU load on a CPU-only runtime.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    state_dict = torch.load(ckpt, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    return test(model, test_loader)

# Per-channel (R, G, B) normalisation statistics applied to the CIFAR-10 test images.
mean_cifar, std_cifar = (0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean_cifar, std_cifar),
])

test_dataset10 = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
# shuffle=False: order does not affect accuracy, and a fixed order keeps runs reproducible.
test_loader10 = torch.utils.data.DataLoader(test_dataset10, batch_size=2048, shuffle=False, num_workers=2)

arch = 'densenet121'
# Derive the checkpoint from `arch` so switching architectures also switches weights
# (a hard-coded densenet121 path would fail to load into any other model).
acc = main(arch=arch, num_classes=10, ckpt=f'state_dicts/{arch}.pt', test_loader=test_loader10)
print(f'{arch} : {acc}')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:07<00:00, 21782221.42it/s]


Extracting data/cifar-10-python.tar.gz to data
densenet121 : 0.9406


In [None]:
# Parameter counts for the architecture evaluated above. `model` only exists inside
# `main`, so take the same instance from `all_classifiers`; otherwise this cell raises
# NameError on a fresh kernel.
model = all_classifiers[arch]
pytorch_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(f'{arch}: {pytorch_total_params:,} parameters ({pytorch_trainable_params:,} trainable)')