In [1]:
import torch
from torch.utils.data import DataLoader
# Force the 'spawn' start method for torch multiprocessing, overriding any
# start method that was set earlier in this process (force=True).
torch.multiprocessing.set_start_method('spawn', force=True)
# Run on the first CUDA device when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

import torchvision
from torchvision import datasets, transforms

import random
import numpy as np
import os

from models.ResNet import ResNet34

In [2]:
# cuDNN is enabled and its benchmark mode is switched on.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Both splits use the same preprocessing: ToTensor only, with no augmentation
# or normalization step.
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 train / test splits, stored under ../datasets and downloaded on
# first use.
clean_train_dataset = datasets.CIFAR10(
    root='../datasets', train=True, download=True, transform=train_transform,
)
clean_test_dataset = datasets.CIFAR10(
    root='../datasets', train=False, download=True, transform=test_transform,
)

# Loader settings shared by both splits. Batches are served in dataset order
# (shuffle=False) from the main process (num_workers=0), and the final partial
# batch is kept (drop_last=False).
# NOTE(review): the training loader is not shuffled either - confirm this is
# intentional (e.g. later cells relying on sample order) before changing it.
_loader_kwargs = dict(
    batch_size=512,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
    num_workers=0,
)
clean_train_loader = DataLoader(dataset=clean_train_dataset, **_loader_kwargs)
clean_test_loader = DataLoader(dataset=clean_test_dataset, **_loader_kwargs)

# Human-readable class names, indexed by integer class label.
_CLASS_ = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

Files already downloaded and verified
Files already downloaded and verified


In [3]:
from tqdm import tqdm

# ResNet-34 classifier placed on the selected device.
model = ResNet34().to(device)

# Cross-entropy loss on the raw logits.
criterion = torch.nn.CrossEntropyLoss()

# Plain SGD (learning rate 0.1) over all model parameters, with a cosine
# annealing schedule that spans 30 scheduler steps and decays towards 0.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=30,
    eta_min=0,
)

In [4]:
# Train the model epoch by epoch, evaluating on the clean test set after each
# epoch, until the stopping rule below fires.
#
# Stopping rule: stop once more than EPOCH_LIMIT epochs have completed (so at
# most EPOCH_LIMIT + 1 epochs run) or the clean test accuracy exceeds
# TARGET_TEST_ACC.
# NOTE(review): the stopping decision is made on test-set accuracy, so the
# final reported accuracy is selected on the same data it is measured on.
EPOCH_LIMIT = 20
TARGET_TEST_ACC = 0.8
GRAD_CLIP_NORM = 5.0

epoch = 0
condition = True
while condition:
    # ---- Train for one pass over the training loader ----
    model.train()
    pbar = tqdm(clean_train_loader, total=len(clean_train_loader))
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        # The optimizer was built from model.parameters(), so a single
        # zero_grad() call clears every gradient.
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        # Clip the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
        optimizer.step()

        # Progress bar shows the accuracy / loss of the current batch only.
        predicted = logits.detach().argmax(dim=1)
        batch_acc = (predicted == labels).sum().item() / labels.size(0)
        pbar.set_description("Acc %.2f Loss: %.2f" % (batch_acc*100, loss.item()))

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    # ---- Evaluate on the clean test set ----
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in clean_test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = correct / total
    tqdm.write('Epoch %d Clean Accuracy %.2f\n' % (epoch, acc*100))
    epoch += 1

    if epoch > EPOCH_LIMIT or acc > TARGET_TEST_ACC:
        condition = False
        

Acc 46.73 Loss: 1.55: 100%|██████████| 98/98 [00:34<00:00,  2.83it/s]


Epoch 0 Clean Accuracy 38.82



Acc 55.36 Loss: 1.30: 100%|██████████| 98/98 [00:29<00:00,  3.33it/s]


Epoch 1 Clean Accuracy 38.18



Acc 72.92 Loss: 0.72: 100%|██████████| 98/98 [00:29<00:00,  3.33it/s]


Epoch 2 Clean Accuracy 50.52



Acc 86.01 Loss: 0.43: 100%|██████████| 98/98 [00:29<00:00,  3.29it/s]


Epoch 3 Clean Accuracy 65.69



Acc 92.56 Loss: 0.23: 100%|██████████| 98/98 [00:29<00:00,  3.34it/s]


Epoch 4 Clean Accuracy 68.11



Acc 95.24 Loss: 0.17: 100%|██████████| 98/98 [00:29<00:00,  3.33it/s]


Epoch 5 Clean Accuracy 71.19



Acc 94.64 Loss: 0.14: 100%|██████████| 98/98 [00:29<00:00,  3.28it/s]


Epoch 6 Clean Accuracy 68.59



Acc 98.81 Loss: 0.05: 100%|██████████| 98/98 [00:29<00:00,  3.29it/s]


Epoch 7 Clean Accuracy 73.78



Acc 99.40 Loss: 0.03: 100%|██████████| 98/98 [00:29<00:00,  3.29it/s]


Epoch 8 Clean Accuracy 76.75



Acc 99.11 Loss: 0.04: 100%|██████████| 98/98 [00:29<00:00,  3.31it/s]


Epoch 9 Clean Accuracy 73.95



Acc 100.00 Loss: 0.00: 100%|██████████| 98/98 [00:29<00:00,  3.31it/s]


Epoch 10 Clean Accuracy 78.40



Acc 100.00 Loss: 0.00: 100%|██████████| 98/98 [00:29<00:00,  3.28it/s]


Epoch 11 Clean Accuracy 80.49



In [17]:
# Check the trained model's prediction on a single clean test image.
TARGET_INDEX = 886     # position of the image in the CIFAR-10 test split
INTENDED_CLASS = 3     # index into _CLASS_ ('cat')

# Put the model in eval mode explicitly so this cell does not depend on the
# mode an earlier cell happened to leave it in.
model.eval()

# Raw image from the dataset array, scaled from [0, 255] to [0, 1] and moved
# to the device; then reordered from (H, W, C) to a (1, C, H, W) batch.
clean_target = torch.tensor(clean_test_dataset.data[TARGET_INDEX].astype(np.float32)/255).to(device)
test_image = clean_target.permute([2, 0, 1]).unsqueeze(0)
print(test_image.shape)

# Pure inference: no autograd graph is needed for the prediction.
with torch.no_grad():
    logits = model(test_image)
pred = logits.argmax(dim=1)

# Read the ground-truth label from the dataset rather than hard-coding it.
true_label = int(clean_test_dataset.targets[TARGET_INDEX])
print("Original:", _CLASS_[true_label], "Intended: ", _CLASS_[INTENDED_CLASS], "Prediction:", _CLASS_[int(pred)])
print(logits)

torch.Size([1, 3, 32, 32])
Original: airplane Intended:  cat Prediction: airplane
tensor([[ 7.6116, -0.1931, -2.8654,  0.6235, -4.9685, -5.0432, -0.6924, -2.8483,
          4.5795,  3.8750]], device='cuda:0', grad_fn=<AddmmBackward0>)
