In [None]:
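# CIFAR-10 dataset download link (Python version)
# https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
#
# NOTE: this archive extracts to `cifar-10-batches-py/` (pickled data batches), NOT to
# `train/<class>/*.png` and `test/<class>/*.png` image folders as `ImageFolder` expects.
# `torchvision.datasets.CIFAR10` can read the extracted batches directly.

# import tarfile

# with tarfile.open('/path/to/cifar-10-python.tar.gz', 'r:gz') as tar:
#     tar.extractall(path="./data/cifar10")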

In [None]:

from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor

from pathlib import Path

from torchvision.datasets import CIFAR10

# Dataset root. Two on-disk layouts are supported:
#   * image folders: <root>/train/<class>/<img> and <root>/test/<class>/<img> -> ImageFolder
#   * the Python tarball extracted into <root> (cifar-10-batches-py/)      -> CIFAR10
# The tarball linked in the download cell only produces the second layout, so
# ImageFolder alone would fail with "directory not found" after following it.
data_root = Path("./data/cifar10")

if (data_root / "train").is_dir() and (data_root / "test").is_dir():
    train_ds = ImageFolder(str(data_root / "train"), transform=ToTensor())
    test_ds = ImageFolder(str(data_root / "test"), transform=ToTensor())
else:
    # download=True skips the download when the extracted batches already pass
    # torchvision's integrity check, so a manual extraction is reused as-is.
    train_ds = CIFAR10(str(data_root), train=True, download=True, transform=ToTensor())
    test_ds = CIFAR10(str(data_root), train=False, download=True, transform=ToTensor())

# Class names, indexed by label id.
img_cls = train_ds.classes

In [None]:
from torch.utils.data.dataloader import DataLoader

# Batches of 100 images; the training order is reshuffled every epoch,
# while the test loader keeps the dataset's fixed order.
train_dl = DataLoader(train_ds, batch_size=100, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=100)

In [None]:
import torch.nn as nn

class CifarModel(nn.Module):
    """Small CNN classifier for 3x32x32 images (CIFAR-10 sized).

    Three conv -> ReLU -> 2x2 max-pool stages (channels 3->16->32->64,
    spatial size 32->16->8->4) feed a 1024->128->num_classes MLP.
    ``forward`` returns softmax probabilities, not logits.

    Args:
        num_classes: width of the output layer (default 10, as in CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.network = nn.Sequential(
            # 3 x 32 x 32
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # 16 x 16 x 16
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # 32 x 8 x 8
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # 64 x 4 x 4
            nn.Flatten(),
            # 64*4*4 = 1024 (fixes the expected input size at 32x32)
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(),
            # 128
            nn.Linear(128, num_classes),
            # num_classes
        )

    def forward(self, data):
        """Map a (batch, 3, 32, 32) tensor to (batch, num_classes) class probabilities."""
        logits = self.network(data)
        # nn.Softmax is a module class (build it with nn.Softmax(dim=...), then call it);
        # passing the tensor to its constructor raises TypeError. Apply the
        # functional form directly instead.
        return nn.functional.softmax(logits, dim=-1)
    



In [None]:
# torch must be imported here: this cell runs before the cell that imports it,
# so a fresh "Restart & Run All" would otherwise fail with NameError.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = CifarModel().to(device)
# CifarModel.forward ends in a softmax, so its outputs lie in [0, 1] as BCELoss
# requires; fit() pairs them with one-hot targets.
loss_fn = nn.BCELoss()
opt = torch.optim.SGD(model.parameters(), lr=0.001)

In [None]:

import numpy as np
import torch

def fit(epochs, net=None, train_loader=None, test_loader=None,
        criterion=None, optimizer=None):
    """Train a classifier for ``epochs`` epochs, reporting loss and test accuracy.

    Each epoch makes one pass over the training loader (one-hot targets, one
    optimizer step per batch) and prints the mean per-sample training loss.
    It then prints top-1 accuracy (argmax of the model output) on the test loader.

    Args:
        epochs: number of passes over the training data.
        net: model to train; defaults to the notebook-level ``model``. Its output
            must have one column per class, in a form ``criterion`` accepts
            against one-hot float targets.
        train_loader: iterable of ``(images, labels)`` batches with integer
            labels; defaults to ``train_dl``.
        test_loader: evaluation loader of the same form; defaults to ``test_dl``.
        criterion: loss called as ``criterion(preds, one_hot_targets)``;
            defaults to ``loss_fn``.
        optimizer: optimizer over ``net``'s parameters; defaults to ``opt``.
    """
    net = model if net is None else net
    train_loader = train_dl if train_loader is None else train_loader
    test_loader = test_dl if test_loader is None else test_loader
    criterion = loss_fn if criterion is None else criterion
    optimizer = opt if optimizer is None else optimizer
    # Run batches on whatever device the model's weights live on.
    device = next(net.parameters()).device

    for epoch in range(epochs):
        print("Epoch: ", epoch)

        net.train()
        loss_sum, n_train = 0.0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            preds = net(images)
            # One-hot encode on the device; the width follows the model's output.
            targets = torch.nn.functional.one_hot(
                labels, num_classes=preds.size(-1)).to(preds.dtype)

            loss = criterion(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch doesn't skew the mean.
            loss_sum += loss.item() * labels.size(0)
            n_train += labels.size(0)

        mean_loss = loss_sum / n_train if n_train else float("nan")
        print("Loss: ", round(mean_loss, 4))

        net.eval()
        correct, n_test = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                labels = labels.to(device)
                preds = net(images.to(device))
                # Top-1 accuracy: the highest-scoring class is the prediction,
                # whether or not its probability exceeds 0.5.
                correct += (preds.argmax(dim=1) == labels).sum().item()
                n_test += labels.size(0)

        acc = correct / n_test if n_test else float("nan")
        print("Accuracy: ", round(acc * 100, 2))

In [None]:

from PIL import Image
import matplotlib.pyplot as plt

def predict(path, net=None, classes=None):
    """Classify a single image file, display it, and print the predicted class.

    The image is converted to RGB and resized to 32x32 when needed. It is then
    scaled to [0, 1] and laid out as (1, C, H, W), the layout ``ToTensor``
    produces for the training images, before it is fed to the model.

    Args:
        path: path to an image file readable by PIL.
        net: model to use; defaults to the notebook-level ``model``.
        classes: class names indexed by output position; defaults to ``img_cls``.

    Returns:
        The predicted class name (also printed).
    """
    net = model if net is None else net
    classes = img_cls if classes is None else classes

    # convert("RGB") loads the pixels (so the file can be closed) and guarantees
    # 3 channels for grayscale / RGBA / palette inputs.
    with Image.open(path) as src:
        img = src.convert("RGB")
    if img.size != (32, 32):
        # PIL's resize returns a new image; it does not modify `img` in place.
        img = img.resize((32, 32))

    img_arr = np.array(img)  # (H, W, C) uint8
    plt.imshow(img_arr)

    # (H, W, C) uint8 -> (1, C, H, W) float32 in [0, 1], on the model's device.
    device = next(net.parameters()).device
    img_tsr = (torch.tensor(img_arr).permute(2, 0, 1).float().div(255)
               .unsqueeze(0).to(device))

    net.eval()
    with torch.no_grad():
        pred = net(img_tsr)
    # argmax returns the first maximal index, matching the previous tie-breaking.
    pred_index = int(pred[0].argmax())

    print("Prediction: ", classes[pred_index])
    return classes[pred_index]
    
    