In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [4]:
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:02<00:00, 80921827.11it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
bicycle_class = 19
motorcycle_class = 22
background_classes = [i for i in range(100) if i not in [bicycle_class, motorcycle_class]]

def transform_labels(y):
    if y == bicycle_class:
        return 0  # bicycle_class
    elif y == motorcycle_class:
        return 1  # motorcycle_class
    else:
        return 2  # background

In [6]:
trainset.targets = [transform_labels(y) for y in trainset.targets if y in [bicycle_class, motorcycle_class] + background_classes]
trainset.data = trainset.data[[i for i in range(len(trainset.targets)) if trainset.targets[i] in [0, 1, 2]]]

testset.targets = [transform_labels(y) for y in testset.targets if y in [bicycle_class, motorcycle_class] + background_classes]
testset.data = testset.data[[i for i in range(len(testset.targets)) if testset.targets[i] in [0, 1, 2]]]

In [7]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

In [8]:
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 3)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 169MB/s] 


In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10

In [10]:
def train(num_epochs, trainloader, model, optimizer, criterion):
    
    for epoch in tqdm(range(num_epochs)):
        model.train()
        running_loss = 0.0
        
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
    
        print(f'Epochs {epoch + 1}, Loss: {running_loss / len(trainloader)}')

In [11]:
train(num_epochs, trainloader, model, optimizer, criterion)

 10%|█         | 1/10 [00:36<05:25, 36.22s/it]

Epochs 1, Loss: 0.11977897435742516


 20%|██        | 2/10 [01:11<04:43, 35.49s/it]

Epochs 2, Loss: 0.12131107340299274


 30%|███       | 3/10 [01:46<04:08, 35.44s/it]

Epochs 3, Loss: 0.11733444864007805


 40%|████      | 4/10 [02:22<03:33, 35.56s/it]

Epochs 4, Loss: 0.1189348922553219


 50%|█████     | 5/10 [02:57<02:57, 35.56s/it]

Epochs 5, Loss: 0.1006656582413631


 60%|██████    | 6/10 [03:33<02:22, 35.57s/it]

Epochs 6, Loss: 0.09627858376975384


 70%|███████   | 7/10 [04:09<01:46, 35.60s/it]

Epochs 7, Loss: 0.08814508800817294


 80%|████████  | 8/10 [04:44<01:11, 35.61s/it]

Epochs 8, Loss: 0.09271421253511959


 90%|█████████ | 9/10 [05:20<00:35, 35.59s/it]

Epochs 9, Loss: 0.08845755943292019


100%|██████████| 10/10 [05:55<00:00, 35.60s/it]

Epochs 10, Loss: 0.08569533542252633





In [15]:
def valid(testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f'Accuracy on test set: {100 * correct / total}%')

In [16]:
valid(testloader)

Accuracy on test set: 97.9%
