# In [None]:
from dataset import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import models
from torchvision.models import ResNet50_Weights
from torchvision import transforms
import torchvision
import dataset
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import f1_score
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

import random
import seaborn as sns

# Fix every RNG the notebook relies on so evaluation runs are reproducible.
# `np` was previously only defined if `from dataset import *` happened to
# re-export it; import it explicitly so this cell does not depend on that.
import numpy as np

SEED = 0

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable
random.seed(SEED)
np.random.seed(SEED)

# In [None]:
# Evaluate a fine-tuned ResNet50 binary classifier on the COVID test split,
# report accuracy / precision / recall / F1 / MCC, and plot the confusion matrix.

# NOTE(review): there is no transforms.Normalize here, although
# ResNet50_Weights.DEFAULT expects ImageNet mean/std normalisation. This is only
# correct if the checkpoint was trained with this exact transform — confirm
# against the training script.
TRANSFORM = transforms.Compose([
    transforms.Resize((232, 232)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder: set this to the checkpoint to evaluate. A plain relative path
# (instead of '.\\name.pt') resolves the same way on Windows and POSIX.
MODEL_PATH = 'INSERTMODELNAME.pt'
# e.g. '10KN_10KP/1E4LR_BS128/LINEAR_RELU_NODROPOUT/covid_model_178_RESNET_LINEAR_RELU_BS128_1E4LR_NODROPOUT_acc.pt'

test_images = COVIDDataset(txt_file='dataset//encoded_test.txt', root_dir='dataset//test', transform=TRANSFORM)

# Optional quick check on a random 500-image subset:
# test_truncated_idxs = random.sample(range(0, len(test_images)), 500)
# test_images = Subset(test_images, test_truncated_idxs)

# No metric below depends on sample order, so shuffling is unnecessary.
test_loader = DataLoader(test_images, batch_size=128, shuffle=False)

# Rebuild the training architecture: ResNet50 backbone with an MLP head that
# outputs 2 logits. No pretrained weights are downloaded because the strict
# load_state_dict below overwrites every parameter and buffer anyway.
model = models.resnet50(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(),
    nn.Dropout(0.25),  # inactive in eval(); kept so the fc.3.* state_dict keys line up
    nn.Linear(512, 2),
)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# Test the model: gather predicted and true labels on the CPU.
all_pred = []
all_labels = []
with torch.no_grad():
    for data, labels in test_loader:
        outputs = model(data.to(DEVICE))
        # Index of the largest logit; ordering matches display_labels below.
        predicted = torch.argmax(outputs, dim=1)
        all_pred.extend(predicted.cpu().numpy())
        # Labels are only needed for metrics, so they never go to DEVICE.
        all_labels.extend(labels.cpu().numpy())

acc = accuracy_score(all_labels, all_pred) * 100
cm = confusion_matrix(all_labels, all_pred, labels=[0, 1])
prec = precision_score(all_labels, all_pred)
rec = recall_score(all_labels, all_pred)
f1 = f1_score(all_labels, all_pred)
mcc = matthews_corrcoef(all_labels, all_pred)

# Total number of samples counted in the confusion matrix.
total = cm.sum()
print(total)
print(f'Accuracy: {acc:.3f}%  Precision: {prec:.3f}  Recall: {rec:.3f}  F1: {f1:.3f}  MCC: {mcc:.3f}')

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['negative', 'positive'])
disp.plot()
# plt.title('Default ResNet50 Model')
plt.xlabel(f'Predicted Label\nAccuracy: {acc:.3f}%')
plt.tight_layout()
# plt.savefig('.//figs//confusion_matrix_5KP_10KN.png')