In [240]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from datasets import load_dataset
import matplotlib.pyplot as plt

In [241]:
# Load the dataset
dataset = load_dataset("cifar10")

In [303]:
# prepare dataset
class CatsDataset(Dataset):
    def __init__(self, train=True):
        if train:
            temp_dataset = load_dataset("cifar10")['train']
        else:
            temp_dataset = load_dataset("cifar10")['test']
        
        self.dataset = []
        for i in range(len(temp_dataset)):
            if temp_dataset[i]['label'] == 3:
                img = torchvision.transforms.PILToTensor()(temp_dataset[i]['img']) / 255.0
                self.dataset.append([img, torch.tensor([0])])

            elif torch.randint(0, 90, (1, )) < 10:
                img = torchvision.transforms.PILToTensor()(temp_dataset[i]['img']) / 255.0
                self.dataset.append([img, torch.tensor([1])])
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        return (image, label)

In [304]:
train_dataset = CatsDataset()

In [305]:
dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True
)

In [306]:
model = nn.Sequential(
    nn.Conv2d(3, 16, 5),
    nn.MaxPool2d((5, 5)),
    nn.LeakyReLU(),
    nn.Flatten(),
    nn.Linear(400, 64),
    nn.LeakyReLU(),
    nn.Linear(64, 2)
)

In [307]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [308]:
for epoch in range(40):
    s = 0
    for i, data in enumerate(dataloader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)
        labels = labels.view(-1)
        # Compute the loss and its gradients
        loss = criterion(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        s += loss.item()
    
    print(f"Epoch {epoch+1} Loss: {s/len(dataloader)}")

torch.save(model.state_dict(), "model.pt")

Epoch 1 Loss: 0.6152230309452981
Epoch 2 Loss: 0.5393341550022174
Epoch 3 Loss: 0.5180603446094854
Epoch 4 Loss: 0.5040903991195047
Epoch 5 Loss: 0.4942670022226443
Epoch 6 Loss: 0.48011620276293177
Epoch 7 Loss: 0.46641658350920223
Epoch 8 Loss: 0.4589494303533226
Epoch 9 Loss: 0.4604894281572597
Epoch 10 Loss: 0.4480480821269333
Epoch 11 Loss: 0.43683091622249337
Epoch 12 Loss: 0.430411545524172
Epoch 13 Loss: 0.42318472645844624
Epoch 14 Loss: 0.4164942047398561
Epoch 15 Loss: 0.40561415871996787
Epoch 16 Loss: 0.39270503183079375
Epoch 17 Loss: 0.3869354738171693
Epoch 18 Loss: 0.3814699295789573
Epoch 19 Loss: 0.3723717454322584
Epoch 20 Loss: 0.36314954309706476
Epoch 21 Loss: 0.3544482316370982
Epoch 22 Loss: 0.3511567739354577
Epoch 23 Loss: 0.34058864775356973
Epoch 24 Loss: 0.3296292709886648
Epoch 25 Loss: 0.3203633677238112
Epoch 26 Loss: 0.3105700579798146
Epoch 27 Loss: 0.30285927852627575
Epoch 28 Loss: 0.2879538503801747
Epoch 29 Loss: 0.2892795727131473
Epoch 30 Loss: 

In [309]:
test_dataset = CatsDataset(train=False)

In [310]:
dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True
)

In [311]:
good = 0
a = 0
for i, data in enumerate(dataloader):
    # Every data instance is an input + label pair
    inputs, labels = data
    outputs = model(inputs)
    labels = labels.view(-1)
    
    for i in range(len(outputs)):
        good += torch.argmax(outputs[i]) == labels[i]
        a += 1

print(good / a)

tensor(0.9447)
