In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import torchvision
import torchvision.transforms as transforms
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from datetime import datetime

In [None]:
# Augmentation pipeline for the training set only:
# - random 32x32 crops from a 4-pixel zero-padded image,
# - random horizontal flips,
# - up to 10% random translation,
# - then conversion to a float tensor.
# The test set is only converted with ToTensor (no augmentation).
train_transformer = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.CIFAR10(
    root=".",
    train=True,
    # Apply the augmentation pipeline above to every training sample.
    transform=train_transformer,
    download=True
)

test_dataset = torchvision.datasets.CIFAR10(
    root=".",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

In [None]:
# Quick look at the raw training data. Printing `.targets` in full would dump one
# label per training image (Python list reprs are not truncated), so show the
# array shape/dtype and a few labels instead.
print("data shape:", train_dataset.data.shape, "dtype:", train_dataset.data.dtype)
print("first 10 targets:", train_dataset.targets[:10])

# Number of distinct classes; used as the size of the model's output layer.
k = len(set(train_dataset.targets))
print("number of classes: ", k)

In [None]:
batch_size = 128

# Reshuffle the training set every epoch. Keep the test set in a fixed order so
# predictions can later be lined up index-by-index with test_dataset.targets.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False
)

In [None]:
# Inspect a single training sample. ToTensor maps the pixel values into [0, 1]
# and moves the colour channel in front of height/width.
sample_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=1, shuffle=True
)

x, y = next(iter(sample_loader))
print(x)
print(x.shape)
print(y)

In [None]:
class CNN(nn.Module):
    """Small VGG-style CNN classifier for 32x32 RGB images.

    Three stages of (conv -> ReLU -> BatchNorm) x 2 followed by a 2x2 max-pool
    shrink the spatial size 32 -> 16 -> 8 -> 4. The flattened feature vector
    therefore has 128 * 4 * 4 entries, which is what ``fc1`` expects; inputs of
    another spatial size will not fit ``fc1``.

    Args:
        k: number of output classes (length of the logit vector).
    """

    def __init__(self, k):
        super(CNN, self).__init__()

        self.conv1 = self._conv_stage(3, 32)
        self.conv2 = self._conv_stage(32, 64)
        self.conv3 = self._conv_stage(64, 128)

        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, k)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Two 3x3 same-padding convs (each followed by ReLU and BatchNorm), then a 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Return class logits of shape (batch, k) for images x of shape (batch, 3, 32, 32)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = x.view(x.size(0), -1)

        # f.dropout defaults to training=True. Tie it to the module's mode so
        # dropout is switched off by model.eval().
        x = f.dropout(x, p=0.5, training=self.training)
        x = f.relu(self.fc1(x))
        x = f.dropout(x, p=0.2, training=self.training)
        return self.fc2(x)

In [None]:
# Seed torch's global RNG so that the following are repeatable across Restart & Run All
# (with single-process data loading):
# - weight initialisation,
# - DataLoader shuffling,
# - the random augmentations.
torch.manual_seed(42)
model = CNN(k)

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU,
# and move the model's parameters/buffers there.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
model.to(device)

In [None]:
# Multi-class loss on the raw logits returned by CNN.forward; Adam with its
# default learning rate spelled out.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def batch_gd(model, criterion, optimizer, train_loader, test_loader, epochs, device=None):
    """Train ``model`` for ``epochs`` epochs and evaluate on ``test_loader`` after each one.

    Args:
        model: the nn.Module to train (updated in place).
        criterion: loss called as ``criterion(outputs, targets)``. It is assumed to
            return the mean loss over the batch (the default for nn.CrossEntropyLoss).
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(inputs, targets)`` batches used for training.
        test_loader: iterable of ``(inputs, targets)`` batches used for evaluation.
        epochs: number of passes over ``train_loader``.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter, or the CPU if the model has none.

    Returns:
        ``(train_losses, test_losses)``: lists with one mean per-sample loss per
        epoch. The value is NaN if a loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    train_losses = []
    test_losses = []

    for i in range(epochs):
        t0 = datetime.now()

        model.train()
        loss_sum, n_samples = 0.0, 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch doesn't skew the epoch mean.
            loss_sum += loss.item() * targets.size(0)
            n_samples += targets.size(0)
        train_loss = loss_sum / n_samples if n_samples else float("nan")

        model.eval()
        loss_sum, n_samples = 0.0, 0
        # No autograd graph is needed for evaluation; skipping it saves memory and time.
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                loss = criterion(model(inputs), targets)
                loss_sum += loss.item() * targets.size(0)
                n_samples += targets.size(0)
        test_loss = loss_sum / n_samples if n_samples else float("nan")

        train_losses.append(train_loss)
        test_losses.append(test_loss)

        dt = datetime.now() - t0

        print(f"epoch {i + 1}/{epochs}, train_loss: {train_loss:.4f}, test_loss: {test_loss:.4f}, duration: {dt}")

    return train_losses, test_losses

In [None]:
# Full training run (80 epochs) -- this is the expensive cell of the notebook.
train_losses, test_losses = batch_gd(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=80,
)

In [None]:
# Per-epoch train vs. test loss; a widening gap indicates overfitting.
fig, ax = plt.subplots()
ax.plot(train_losses, label="train_loss")
ax.plot(test_losses, label="test_loss")
ax.set(title="Loss per epoch", xlabel="epoch (0-based)", ylabel="loss")
ax.legend()
plt.show()

In [None]:
# Put the network into inference mode (BatchNorm layers use their running
# statistics) before the accuracy and confusion-matrix cells below.
model.eval()

def calculate_acc(data_loader, net=None, device=None):
    """Return the fraction of samples in ``data_loader`` that ``net`` classifies correctly.

    Args:
        data_loader: iterable of ``(inputs, targets)`` batches with integer class targets.
        net: model producing logits of shape (batch, n_classes). Defaults to the
            notebook-level ``model`` so existing ``calculate_acc(loader)`` calls keep
            working. It is switched to eval mode.
        device: where each batch is moved. Defaults to the device of ``net``'s
            first parameter, or the CPU if it has none.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if net is None:
        net = model  # notebook global, kept for backward compatibility
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    net.eval()
    n_correct = 0
    n_total = 0
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = net(inputs).argmax(dim=1)
            n_correct += (predictions == targets).sum().item()
            n_total += targets.shape[0]

    if n_total == 0:
        raise ValueError("data_loader yielded no samples")
    return n_correct / n_total

In [None]:
def plot_confusion_matrix_sns(cm,
                              labels,
                              normalized=False,
                              title="confusion_matrix",
                              cmap=plt.cm.Blues):
    """Print confusion matrix ``cm`` and draw it as an annotated seaborn heatmap.

    Args:
        cm: square integer matrix with rows = true class and columns = predicted
            class (the layout returned by ``sklearn.metrics.confusion_matrix``).
        labels: class names for both axes, in row/column order.
        normalized: if True, divide each row by its total so each cell shows the
            fraction of that true class. Rows with no samples are left at 0
            instead of becoming NaN.
        title: figure title.
        cmap: colormap passed to ``sns.heatmap``.
    """
    if normalized:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Guard against 0/0 (NaN plus a RuntimeWarning) for classes absent from y_true.
        cm = np.divide(cm.astype(float), row_sums,
                       out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        print("normalized confusion matrix")
    else:
        print("confusion_matrix without normalization")

    print(cm)

    fmt = '.2f' if normalized else 'd'

    sns.heatmap(cm, cmap=cmap, annot=True, fmt=fmt, xticklabels=labels, yticklabels=labels)
    plt.title(title)
    plt.ylabel("True label")
    plt.xlabel("predicted label")
    # Called after the labels exist so the layout makes room for them.
    plt.tight_layout()
    plt.show()

In [None]:
# CIFAR-10 class names, indexed by integer class id.
labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

In [None]:
# Collect test-set predictions.
# - torchvision's CIFAR10 keeps `.data` as a numpy array and `.targets` as a
#   Python list, so neither has a `.numpy()` method.
# - test_loader was built with shuffle=False, so batches come out in dataset
#   order and line up with y_test.
model.eval()
x_test = test_dataset.data
y_test = np.asarray(test_dataset.targets)

batch_predictions = []
with torch.no_grad():
    for inputs, _ in test_loader:
        outputs = model(inputs.to(device))
        batch_predictions.append(outputs.argmax(dim=1).cpu().numpy())
# A single concatenate (instead of growing an array per batch) keeps integer class ids.
p_test = np.concatenate(batch_predictions)

# One row/column per named class, in the same order as `labels`.
cm = confusion_matrix(y_test, p_test, labels=np.arange(len(labels)))
plot_confusion_matrix_sns(cm, labels)

In [None]:
# Show one random misclassified test image with its true and predicted class.
misclassified_index = np.where(p_test != y_test)[0]
if misclassified_index.size == 0:
    print("no misclassified test images")
else:
    i = np.random.choice(misclassified_index)
    # CIFAR-10 images are 32x32 RGB (height x width x channel). Show them
    # directly rather than reshaping to a 28x28 grayscale layout.
    plt.imshow(x_test[i])
    plt.title(f"true label: {labels[int(y_test[i])]}, predicted: {labels[int(p_test[i])]}")
    plt.axis("off")
    plt.show()

In [None]:
# Layer-by-layer output shapes and parameter counts for a 3x32x32 input.
summary(model, input_size=(3, 32, 32))