In [1]:
import torch
import numpy as np
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
from datasets import load_cifar10_choosen
from attack import attack, test_model
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [2]:
# Build the loaders for the CIFAR-10 subset made of classes 0-3.
# NOTE(review): `test_dataloader_all` presumably covers all ten classes --
# confirm against datasets.load_cifar10_choosen.
(
    train_dataloader,
    test_dataloader,
    test_dataloader_all,
) = load_cifar10_choosen(choosen_classes=[0, 1, 2, 3])


Files already downloaded and verified
Files already downloaded and verified


In [3]:
def train_choosen_classes(train_dataloader, test_dataloader, epochs=10, lr=0.001, model=None):
    """Fine-tune a ResNet-50 with plain SGD and checkpoint the best epoch.

    Args:
        train_dataloader: Iterable of ``(inputs, labels)`` batches. Used for the
            gradient updates and, via ``test_model``, for the per-epoch train
            accuracy.
        test_dataloader: Loader handed to ``test_model`` for the per-epoch test
            accuracy that decides which checkpoint is kept.
        epochs: Number of full passes over ``train_dataloader``.
        lr: Step size of the SGD update (no momentum, no weight decay).
        model: Optional network to keep training. When ``None``, an
            ImageNet-pretrained ResNet-50 (``IMAGENET1K_V2``) is created and its
            classification head is replaced by a 10-way linear layer.

    Returns:
        None. Accuracies are printed after every epoch, and the state dict of
        the epoch with the highest test accuracy is written to
        ``weights/best_model.pth``.
    """
    import os  # local import: only needed to make sure the checkpoint dir exists

    if model is None:
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        model.fc = nn.Linear(2048, 10)
    # Also move caller-supplied models, so inputs and weights always share a device.
    model = model.to(device)

    loss_func = nn.CrossEntropyLoss()
    # Vanilla SGD performs the same `p -= lr * grad` update as a manual loop,
    # but skips parameters without a gradient (e.g. frozen layers) instead of
    # failing on `None`, and does not write through `param.data`.
    optimizer = optim.SGD(model.parameters(), lr=lr)

    checkpoint_path = "weights/best_model.pth"
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    best_acc = -np.inf
    for epoch in range(epochs):
        model.train()
        for x, y in tqdm(train_dataloader):
            x, y = x.to(device), y.to(device)
            # Clear gradients *before* backward so stale grads on a supplied
            # model cannot leak into the first update.
            optimizer.zero_grad()
            loss = loss_func(model(x), y)
            loss.backward()
            optimizer.step()

        model.eval()
        train_correct, train_total = test_model(model, train_dataloader)
        test_correct, test_total = test_model(model, test_dataloader)
        train_acc = train_correct / train_total
        test_acc = test_correct / test_total
        print(f"Epoch {epoch + 1} train acc: {train_acc}")
        print(f"Epoch {epoch + 1} test acc: {test_acc}")
        # Keep only the weights of the best epoch seen so far.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), checkpoint_path)


In [4]:
# Fine-tune for 10 epochs with a step size of 0.01; the best checkpoint is
# written to weights/best_model.pth by the training function.
train_choosen_classes(
    train_dataloader,
    test_dataloader,
    epochs=10,
    lr=0.01,
)


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1 train acc: 0.71515
Epoch 1 test acc: 0.6795


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 2 train acc: 0.8259
Epoch 2 test acc: 0.77675


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 3 train acc: 0.88675
Epoch 3 test acc: 0.824


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 4 train acc: 0.90675
Epoch 4 test acc: 0.83975


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 5 train acc: 0.9368
Epoch 5 test acc: 0.8505


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 6 train acc: 0.96075
Epoch 6 test acc: 0.85925


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 7 train acc: 0.96165
Epoch 7 test acc: 0.8605


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 8 train acc: 0.95555
Epoch 8 test acc: 0.8675


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 9 train acc: 0.9828
Epoch 9 test acc: 0.87275


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 10 train acc: 0.981
Epoch 10 test acc: 0.87225


In [5]:
def load_model(weights_path="weights/best_model.pth"):
    """Rebuild the 10-class ResNet-50 and restore saved weights for inference.

    Args:
        weights_path: Path of the state-dict checkpoint to restore. Defaults to
            the file written by ``train_choosen_classes``.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    model = resnet50(num_classes=10)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved from a CUDA run still loads on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


In [6]:
# Evaluate the best checkpoint on `test_dataloader_all`.
# NOTE(review): that loader presumably spans all ten CIFAR-10 classes while
# training used only classes 0-3, so a much lower accuracy is expected here --
# confirm against datasets.load_cifar10_choosen.
model = load_model()
# `total` rather than `all`: binding `all` at notebook scope would shadow the
# builtin all() for every cell executed afterwards in this kernel.
correct, total = test_model(model, test_dataloader_all)
print(f"test acc: {correct / total}")


test acc: 0.3491
