In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import pandas as pd
from sklearn.model_selection import train_test_split

from models import CNN, RNN
from tools import LABEL_ENCODER, transforms_train, transforms_test
from dataset import CnnDataset
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

import matplotlib.pyplot as plt


In [None]:
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
# Let cuDNN pick the fastest convolution algorithms for fixed-size inputs.
torch.backends.cudnn.benchmark = True
# Only query the GPU name when a GPU exists; get_device_name(0) fails on CPU-only machines.
device_name = torch.cuda.get_device_name(0) if cuda else "CPU"
print("Using {}: {}".format(device, device_name))

In [None]:
# Data locations, relative to the notebook's working directory.
IMAGES_DIR, TRAIN_CSV, VALID_CSV = (
    "input/images/",
    "input/split/train.csv",
    "input/split/valid.csv",
)
# Number of samples per DataLoader batch.
BATCH_SIZE = 256




In [None]:
def get_predictions(model, iterator, device):
    """Run ``model`` over every batch in ``iterator`` and collect the results.

    The model is switched to eval mode and run under ``torch.no_grad()``.

    Args:
        model: a ``torch.nn.Module``; its output is passed through a softmax
            over the last dimension (presumably class logits).
        iterator: iterable of ``(x, y)`` batches, e.g. a ``DataLoader``.
        device: device the inputs ``x`` are moved to before the forward pass.

    Returns:
        ``(images, labels, probs)``: the inputs, the targets and the softmax
        probabilities, each concatenated along dim 0 and moved to the CPU.
        NOTE(review): every input batch is kept, so memory grows with the
        size of the dataset.
    """
    model.eval()

    images, labels, probs = [], [], []

    with torch.no_grad():
        for x, y in iterator:
            x = x.to(device)
            logits = model(x)

            # Softmax over the last dimension turns raw scores into probabilities.
            probs.append(F.softmax(logits, dim = -1).cpu())
            images.append(x.cpu())
            labels.append(y.cpu())

    return (
        torch.cat(images, dim = 0),
        torch.cat(labels, dim = 0),
        torch.cat(probs, dim = 0),
    )

def plot_confusion_matrix(labels, pred_labels, classes):
    """Plot a row-normalised confusion matrix as whole-number percentages.

    Row ``i`` shows how the samples whose true class is ``i`` were spread over
    the predicted classes, as integer percentages (rounded down).

    Args:
        labels: true class indices. Assumed to be encoded as
            ``0 .. len(classes) - 1`` (as produced by a label encoder whose
            ``classes_`` is passed as ``classes``) — confirm at call sites.
        pred_labels: predicted class indices, same length as ``labels``.
        classes: display names, one per class index, in index order.
    """
    n_classes = len(classes)

    fig = plt.figure(figsize = (15, 15))
    ax = fig.add_subplot(1, 1, 1)
    # Fix the label set so the matrix is always n_classes x n_classes and lines
    # up with `classes`, even if a class never appears in either array.
    cm = confusion_matrix(labels, pred_labels, labels = list(range(n_classes)))

    # Row-normalise with integer arithmetic: exact floor of the percentage,
    # and rows with no true samples stay all-zero instead of dividing by zero.
    row_totals = cm.sum(axis = 1, keepdims = True)
    row_totals[row_totals == 0] = 1
    cm_percent = cm * 100 // row_totals

    disp = ConfusionMatrixDisplay(cm_percent, display_labels = classes)
    # Rotate on the display's own axes rather than via the pyplot state machine.
    disp.plot(values_format = 'd', cmap = 'Blues', ax = ax, xticks_rotation = 20)

# Load CNN

In [None]:
cnn_model = CNN()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine too.
cnn_model.load_state_dict(torch.load('./output/cnn-model.pt', map_location=device))
# Use the device chosen during setup instead of hard-coding CUDA.
cnn_model.to(device)

def create_cnn_iterator(csv_path, test_size=0.2, random_state=11):
    """Split a labelled CSV into train/validation DataLoaders for the CNN.

    The CSV must provide ``id`` and ``category`` columns. Rows are split with
    ``train_test_split``, so the same ``random_state`` reproduces the same
    hold-out rows.

    Args:
        csv_path: path to the CSV file.
        test_size: fraction of rows held out for validation (default 0.2).
        random_state: seed for the split (default 11). NOTE(review): presumably
            this must match the split used during training so the evaluated
            rows were not trained on — confirm.

    Returns:
        ``(train_dataloader, test_dataloader)``. Only the training loader
        shuffles; the evaluation loader keeps a fixed order.
    """
    df = pd.read_csv(csv_path)

    X = list(df['id'])
    y = list(df['category'])

    X_train, X_valid, y_train, y_valid = train_test_split(
        X, y, test_size=test_size, random_state=random_state)

    # Presumably maps category names to integer class indices — see tools.LABEL_ENCODER.
    y_train, y_valid = LABEL_ENCODER.transform(y_train), LABEL_ENCODER.transform(y_valid)

    train_dataset = CnnDataset(transforms_train, X_train, y_train)
    test_dataset = CnnDataset(transforms_test, X_valid, y_valid)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Evaluation does not benefit from shuffling; a fixed order keeps runs reproducible.
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_dataloader, test_dataloader

# Keep only the validation loader. Indexing avoids binding `_`, which would
# shadow IPython's last-output variable for the rest of the session.
cnn_test_dataloader = create_cnn_iterator("input/train.csv")[1]

images, labels, probs = get_predictions(cnn_model, cnn_test_dataloader, device)




In [None]:
# Predicted class = index of the highest probability in each row.
pred_labels = probs.argmax(dim=1)
plot_confusion_matrix(labels, pred_labels, LABEL_ENCODER.classes_)

# Load RNN