In [31]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torchvision
from datasets import load_dataset
import matplotlib.pyplot as plt

In [60]:
# prepare dataset
class CatsDataset(Dataset):
    def __init__(self, train=True):
        if train:
            temp_dataset = load_dataset("cifar10")['train']
        else:
            temp_dataset = load_dataset("cifar10")['test']
        
        main_dataset = []
        cats = []
        for i in range(len(temp_dataset)):
            # store cats
            if temp_dataset[i]['label'] == 3:
                img = torchvision.transforms.PILToTensor()(temp_dataset[i]['img']).expand((3, -1, -1)) / 255.0
                cats.append([img, torch.tensor([0])])

            # add other images to dataset
            else:
                img = torchvision.transforms.PILToTensor()(temp_dataset[i]['img']).expand((3, -1, -1)) / 255.0
                main_dataset.append([img, torch.tensor([1])])

        # Repeat dataset2 nine times
        repeated_datasets = [cats] * 9

        # Concatenate dataset1 with repeated dataset2
        self.dataset = ConcatDataset([main_dataset] + repeated_datasets)

    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        image = torchvision.transforms.RandomHorizontalFlip()(image)
        image = torchvision.transforms.RandomResizedCrop((32, 32), scale=(4/5, 5/4), ratio=(4/5, 5/4))(image)
        return (image, label)

In [61]:
train_dataset = CatsDataset(train=True)
len(train_dataset)

90000

In [62]:
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=256,
    shuffle=True
)

In [63]:
dev = torch.device("cpu")
if torch.cuda.is_available():
    dev = torch.device("cuda")
elif torch.backends.mps.is_available():
    dev = torch.device("mps")

print(f"Using device: {dev}")

Using device: mps


In [77]:
model = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.MaxPool2d((3, 3)),
    nn.LeakyReLU(),
    nn.Conv2d(64, 64, 3),
    nn.MaxPool2d((3, 3)),
    nn.LeakyReLU(),
    nn.Flatten(),
    nn.Linear(256, 64),
    nn.LeakyReLU(),
    nn.Linear(64, 2)
).to(dev)

In [78]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [81]:
for epoch in range(32):
    s = 0
    for i, data in enumerate(train_dataloader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to(dev)
        labels = labels.to(dev)
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)
        labels = labels.view(-1)
        # Compute the loss and its gradients
        loss = criterion(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        s += loss.item()
    
    print(f"Epoch {epoch+1} Loss: {s/len(train_dataloader)}")

torch.save(model.to(dev).state_dict(), "model.pt")

Epoch 1 Loss: 0.5127322578972037
Epoch 2 Loss: 0.5030352201482112
Epoch 3 Loss: 0.4938331706110727
Epoch 4 Loss: 0.48747569297186355
Epoch 5 Loss: 0.4799483444711024
Epoch 6 Loss: 0.47321122055026615
Epoch 7 Loss: 0.4705432698299939
Epoch 8 Loss: 0.46476283915002237
Epoch 9 Loss: 0.4587418803606521
Epoch 10 Loss: 0.45601269975304604
Epoch 11 Loss: 0.45132567250931804
Epoch 12 Loss: 0.4479885161430998
Epoch 13 Loss: 0.4442358175292611
Epoch 14 Loss: 0.43917044082825835
Epoch 15 Loss: 0.4390289630232887
Epoch 16 Loss: 0.43416007938371465
Epoch 17 Loss: 0.4320536175404083
Epoch 18 Loss: 0.4281757626343857
Epoch 19 Loss: 0.4269535530527884
Epoch 20 Loss: 0.42138694590804254
Epoch 21 Loss: 0.42259806572374975
Epoch 22 Loss: 0.4194783576002175
Epoch 23 Loss: 0.41586094645952637
Epoch 24 Loss: 0.41359927191991697
Epoch 25 Loss: 0.4129187828776511
Epoch 26 Loss: 0.40978814161975274
Epoch 27 Loss: 0.40799213522537187
Epoch 28 Loss: 0.40723997269841755
Epoch 29 Loss: 0.40505051375790074
Epoch 30

In [67]:
test_dataset = CatsDataset(train=False)

In [68]:
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=256,
    shuffle=True
)

In [82]:
good = 0
a = 0
for i, data in enumerate(test_dataloader):
    # Every data instance is an input + label pair
    inputs, labels = data
    inputs = inputs.to(dev)
    labels = labels.to(dev)
    outputs = model(inputs)
    labels = labels.view(-1)
    
    for i in range(len(outputs)):
        good += torch.argmax(outputs[i]) == labels[i]
        a += 1

print(good / a)

tensor(0.7975, device='mps:0')
