In [19]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.utils
import torch
from torch import optim
import torch.nn.functional as F
from PIL import Image

from torch.autograd import Variable
from tqdm import tqdm 
import torch.nn as nn

In [20]:
# import imgshow and plot show
#impot class SiameseNetworkDataset 
from utils.plot_helpers import imshow, show_plot
import os

In [21]:
from torchvision.models import resnet50, ResNet50_Weights

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

In [22]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


## Preparing the data

In [23]:
class COVIDDataset(Dataset):
    """
    This class is used to create a dataset for the Siamese Network.

    Args:
        Dataset (torch.utils.data.Dataset): Pytorch Dataset class
    Returns:
        image0, image1, label    
    
    """
    def __init__(self,root_path,transform=None):
        """
        Initialize the dataset
        Args:
            imageFolderDataset (torchvision.datasets.ImageFolder): Pytorch ImageFolder class
            transform (torchvision.transforms): Pytorch transforms
        """
        self.root_path = root_path    
        self.transform = transform
        self.classes = os.listdir(root_path)
        self.classes_to_idx = {class_name:i for i, class_name in enumerate(self.classes)} 
        self.img_paths, self.labels = self._make_samples()

    def _make_samples(self):
        samples_images = []
        samples_classes = []
        for cls_name in self.classes:
            class_dir = os.path.join(self.root_path, cls_name)
            for file_name in os.listdir(class_dir):
                sample_path = os.path.join(class_dir, file_name)
                samples_images.append(sample_path)
                samples_classes.append(self.classes_to_idx[cls_name])
        return samples_images, samples_classes 

    def __getitem__(self,idx):
        path_img0 = self.img_paths[idx] 
        label0 = self.labels[idx]

        img0 = Image.open(path_img0)

        if self.transform is not None:
            img0 = self.transform(img0)
        
        return img0, torch.tensor(label0)
    
    def __len__(self):
        return len(self.img_paths)

In [24]:
# Load the training dataset

folder_datatrain = COVIDDataset(root_path=os.path.join('dataset_jpg1','train'), 
                                         transform=preprocess)

train_size = int(0.8 * len(folder_datatrain))
val_size = len(folder_datatrain) - train_size

siamese_datatrain, siamese_dataval = torch.utils.data.random_split(folder_datatrain, [train_size, val_size])


# Create a simple dataloader just for simple visualization-train
train_loader = DataLoader(siamese_datatrain,
                        shuffle=True,
                        num_workers=0,
                        batch_size=8)

# Create a simple dataloader just for simple visualization-val
val_loader = DataLoader(siamese_dataval,
                        shuffle=True,
                        num_workers=0,
                        batch_size=8)

In [38]:
def train(model, train_loader, val_loader, optimizer, criterion, num_epochs):
    for epoch in range(num_epochs):
        # Initialize metrics
        train_loss = 0.0
        train_acc = 0.0
        val_loss = 0.0
        val_acc = 0.0
        num_train_batches = len(train_loader)
        num_val_batches = len(val_loader)
        min_val = 10000

        # Training loop
        model.train()
        for i, (inputs, labels) in tqdm(enumerate(train_loader), total=num_train_batches, desc=f'Epoch {epoch + 1}/{num_epochs} - Training'):
            inputs = Variable(inputs).to(device)
            labels = Variable(labels).to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Update metrics
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_acc += (predicted == labels).sum().item()

        # Validation loop
        model.eval()
        with torch.no_grad():
            for i, (inputs, labels) in tqdm(enumerate(val_loader), total=num_val_batches, desc=f'Epoch {epoch + 1}/{num_epochs} - Validation'):
                inputs = Variable(inputs).to(device)
                labels = Variable(labels).to(device)
                
                outputs = model(inputs)
            
                loss = criterion(outputs, labels)

                # Update metrics
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_acc += (predicted == labels).sum().item()

        if val_loss < min_val:
            torch.save(model.state_dict(), os.path.join('ckpt','best_model.pth'))

        # Calculate metrics
        train_loss /= num_train_batches
        train_acc /= len(train_loader.dataset)
        val_loss /= num_val_batches
        val_acc /= len(val_loader.dataset)

        # Print metrics and progress bar
        tqdm.write(f'Epoch {epoch + 1}/{num_epochs} - Training accuracy: {train_acc:.4f} - Training loss: {train_loss:.4f} - Validation accuracy: {val_acc:.4f} - Validation loss: {val_loss:.4f}')

In [39]:
lr = 0.0001
epochs = 50

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss().to(device)


In [40]:
model.fc = nn.Linear(in_features = model.fc.in_features, out_features=4)

In [41]:
model.to(device);

In [42]:
train(model, train_loader, val_loader, optimizer, criterion, epochs)

Epoch 1/50 - Training: 100%|██████████| 147/147 [00:10<00:00, 14.59it/s]
Epoch 1/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.74it/s]


Epoch 1/50 - Training accuracy: 0.8481 - Training loss: 0.5972 - Validation accuracy: 0.9693 - Validation loss: 0.1202


Epoch 2/50 - Training: 100%|██████████| 147/147 [00:08<00:00, 16.43it/s]
Epoch 2/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.85it/s]


Epoch 2/50 - Training accuracy: 0.9838 - Training loss: 0.0953 - Validation accuracy: 0.9795 - Validation loss: 0.0784


Epoch 3/50 - Training: 100%|██████████| 147/147 [00:08<00:00, 16.39it/s]
Epoch 3/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.15it/s]


Epoch 3/50 - Training accuracy: 0.9778 - Training loss: 0.0782 - Validation accuracy: 0.9795 - Validation loss: 0.0728


Epoch 4/50 - Training: 100%|██████████| 147/147 [00:08<00:00, 16.42it/s]
Epoch 4/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.33it/s]


Epoch 4/50 - Training accuracy: 0.9855 - Training loss: 0.0583 - Validation accuracy: 0.9727 - Validation loss: 0.0843


Epoch 5/50 - Training: 100%|██████████| 147/147 [00:08<00:00, 16.39it/s]
Epoch 5/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.00it/s]


Epoch 5/50 - Training accuracy: 0.9829 - Training loss: 0.0701 - Validation accuracy: 0.9795 - Validation loss: 0.0878


Epoch 6/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.23it/s]
Epoch 6/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.54it/s]


Epoch 6/50 - Training accuracy: 0.9915 - Training loss: 0.0379 - Validation accuracy: 0.9829 - Validation loss: 0.0574


Epoch 7/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.16it/s]
Epoch 7/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.10it/s]


Epoch 7/50 - Training accuracy: 0.9932 - Training loss: 0.0268 - Validation accuracy: 0.9761 - Validation loss: 0.1171


Epoch 8/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.18it/s]
Epoch 8/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.38it/s]


Epoch 8/50 - Training accuracy: 0.9957 - Training loss: 0.0125 - Validation accuracy: 0.9795 - Validation loss: 0.1096


Epoch 9/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.20it/s]
Epoch 9/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.36it/s]


Epoch 9/50 - Training accuracy: 0.9966 - Training loss: 0.0100 - Validation accuracy: 0.9795 - Validation loss: 0.0585


Epoch 10/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.16it/s]
Epoch 10/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.52it/s]


Epoch 10/50 - Training accuracy: 0.9974 - Training loss: 0.0123 - Validation accuracy: 0.9829 - Validation loss: 0.0596


Epoch 11/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.13it/s]
Epoch 11/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.70it/s]


Epoch 11/50 - Training accuracy: 0.9940 - Training loss: 0.0116 - Validation accuracy: 0.9761 - Validation loss: 0.0958


Epoch 12/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.19it/s]
Epoch 12/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.34it/s]


Epoch 12/50 - Training accuracy: 0.9923 - Training loss: 0.0209 - Validation accuracy: 0.9727 - Validation loss: 0.1632


Epoch 13/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.13it/s]
Epoch 13/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.24it/s]


Epoch 13/50 - Training accuracy: 0.9949 - Training loss: 0.0117 - Validation accuracy: 0.9659 - Validation loss: 0.1441


Epoch 14/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.01it/s]
Epoch 14/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.42it/s]


Epoch 14/50 - Training accuracy: 0.9949 - Training loss: 0.0248 - Validation accuracy: 0.9795 - Validation loss: 0.1491


Epoch 15/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.17it/s]
Epoch 15/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.19it/s]


Epoch 15/50 - Training accuracy: 0.9889 - Training loss: 0.0383 - Validation accuracy: 0.9761 - Validation loss: 0.0819


Epoch 16/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 15.59it/s]
Epoch 16/50 - Validation: 100%|██████████| 37/37 [00:02<00:00, 15.25it/s]


Epoch 16/50 - Training accuracy: 0.9855 - Training loss: 0.0608 - Validation accuracy: 0.9795 - Validation loss: 0.1076


Epoch 17/50 - Training: 100%|██████████| 147/147 [00:12<00:00, 12.21it/s]
Epoch 17/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.42it/s]


Epoch 17/50 - Training accuracy: 0.9957 - Training loss: 0.0281 - Validation accuracy: 0.9659 - Validation loss: 0.0951


Epoch 18/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.11it/s]
Epoch 18/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.28it/s]


Epoch 18/50 - Training accuracy: 0.9983 - Training loss: 0.0146 - Validation accuracy: 0.9761 - Validation loss: 0.1171


Epoch 19/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.05it/s]
Epoch 19/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.66it/s]


Epoch 19/50 - Training accuracy: 0.9966 - Training loss: 0.0094 - Validation accuracy: 0.9795 - Validation loss: 0.0674


Epoch 20/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.05it/s]
Epoch 20/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.04it/s]


Epoch 20/50 - Training accuracy: 0.9974 - Training loss: 0.0056 - Validation accuracy: 0.9795 - Validation loss: 0.1180


Epoch 21/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.15it/s]
Epoch 21/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 28.87it/s]


Epoch 21/50 - Training accuracy: 0.9983 - Training loss: 0.0059 - Validation accuracy: 0.9829 - Validation loss: 0.0956


Epoch 22/50 - Training: 100%|██████████| 147/147 [00:15<00:00,  9.24it/s]
Epoch 22/50 - Validation: 100%|██████████| 37/37 [00:02<00:00, 18.07it/s]


Epoch 22/50 - Training accuracy: 0.9974 - Training loss: 0.0058 - Validation accuracy: 0.9829 - Validation loss: 0.1222


Epoch 23/50 - Training: 100%|██████████| 147/147 [00:15<00:00,  9.39it/s]
Epoch 23/50 - Validation: 100%|██████████| 37/37 [00:02<00:00, 18.11it/s]


Epoch 23/50 - Training accuracy: 0.9957 - Training loss: 0.0126 - Validation accuracy: 0.9795 - Validation loss: 0.1327


Epoch 24/50 - Training: 100%|██████████| 147/147 [00:14<00:00, 10.02it/s]
Epoch 24/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.92it/s]


Epoch 24/50 - Training accuracy: 0.9940 - Training loss: 0.0143 - Validation accuracy: 0.9795 - Validation loss: 0.0817


Epoch 25/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.11it/s]
Epoch 25/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.98it/s]


Epoch 25/50 - Training accuracy: 0.9940 - Training loss: 0.0220 - Validation accuracy: 0.9795 - Validation loss: 0.0700


Epoch 26/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.01it/s]
Epoch 26/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.00it/s]


Epoch 26/50 - Training accuracy: 0.9949 - Training loss: 0.0147 - Validation accuracy: 0.9829 - Validation loss: 0.0739


Epoch 27/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.07it/s]
Epoch 27/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.88it/s]


Epoch 27/50 - Training accuracy: 0.9974 - Training loss: 0.0101 - Validation accuracy: 0.9795 - Validation loss: 0.0971


Epoch 28/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.03it/s]
Epoch 28/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.24it/s]


Epoch 28/50 - Training accuracy: 0.9974 - Training loss: 0.0058 - Validation accuracy: 0.9795 - Validation loss: 0.1066


Epoch 29/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.07it/s]
Epoch 29/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.57it/s]


Epoch 29/50 - Training accuracy: 0.9974 - Training loss: 0.0046 - Validation accuracy: 0.9795 - Validation loss: 0.1040


Epoch 30/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.03it/s]
Epoch 30/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.91it/s]


Epoch 30/50 - Training accuracy: 0.9983 - Training loss: 0.0034 - Validation accuracy: 0.9795 - Validation loss: 0.1158


Epoch 31/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.01it/s]
Epoch 31/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.91it/s]


Epoch 31/50 - Training accuracy: 0.9974 - Training loss: 0.0031 - Validation accuracy: 0.9795 - Validation loss: 0.1403


Epoch 32/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.03it/s]
Epoch 32/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.56it/s]


Epoch 32/50 - Training accuracy: 0.9974 - Training loss: 0.0030 - Validation accuracy: 0.9795 - Validation loss: 0.1351


Epoch 33/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.07it/s]
Epoch 33/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 29.76it/s]


Epoch 33/50 - Training accuracy: 0.9966 - Training loss: 0.0034 - Validation accuracy: 0.9795 - Validation loss: 0.1354


Epoch 34/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.12it/s]
Epoch 34/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.11it/s]


Epoch 34/50 - Training accuracy: 0.9974 - Training loss: 0.0032 - Validation accuracy: 0.9795 - Validation loss: 0.1387


Epoch 35/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.22it/s]
Epoch 35/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.37it/s]


Epoch 35/50 - Training accuracy: 0.9966 - Training loss: 0.0032 - Validation accuracy: 0.9795 - Validation loss: 0.1358


Epoch 36/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.24it/s]
Epoch 36/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 29.85it/s]


Epoch 36/50 - Training accuracy: 0.9974 - Training loss: 0.0029 - Validation accuracy: 0.9795 - Validation loss: 0.1769


Epoch 37/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.27it/s]
Epoch 37/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.78it/s]


Epoch 37/50 - Training accuracy: 0.9966 - Training loss: 0.0029 - Validation accuracy: 0.9795 - Validation loss: 0.1853


Epoch 38/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.05it/s]
Epoch 38/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.16it/s]


Epoch 38/50 - Training accuracy: 0.9974 - Training loss: 0.0029 - Validation accuracy: 0.9795 - Validation loss: 0.1748


Epoch 39/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.08it/s]
Epoch 39/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 31.33it/s]


Epoch 39/50 - Training accuracy: 0.9966 - Training loss: 0.0031 - Validation accuracy: 0.9795 - Validation loss: 0.1649


Epoch 40/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.06it/s]
Epoch 40/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.97it/s]


Epoch 40/50 - Training accuracy: 0.9983 - Training loss: 0.0029 - Validation accuracy: 0.9795 - Validation loss: 0.1694


Epoch 41/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.06it/s]
Epoch 41/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.68it/s]


Epoch 41/50 - Training accuracy: 0.9966 - Training loss: 0.0073 - Validation accuracy: 0.9727 - Validation loss: 0.1463


Epoch 42/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 15.94it/s]
Epoch 42/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.62it/s]


Epoch 42/50 - Training accuracy: 0.9787 - Training loss: 0.0944 - Validation accuracy: 0.9829 - Validation loss: 0.0688


Epoch 43/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.04it/s]
Epoch 43/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.23it/s]


Epoch 43/50 - Training accuracy: 0.9898 - Training loss: 0.0347 - Validation accuracy: 0.8498 - Validation loss: 1.1950


Epoch 44/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.03it/s]
Epoch 44/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.18it/s]


Epoch 44/50 - Training accuracy: 0.9957 - Training loss: 0.0103 - Validation accuracy: 0.9795 - Validation loss: 0.0660


Epoch 45/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.14it/s]
Epoch 45/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 29.57it/s]


Epoch 45/50 - Training accuracy: 0.9983 - Training loss: 0.0087 - Validation accuracy: 0.9795 - Validation loss: 0.1144


Epoch 46/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.17it/s]
Epoch 46/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 29.70it/s]


Epoch 46/50 - Training accuracy: 0.9974 - Training loss: 0.0054 - Validation accuracy: 0.9795 - Validation loss: 0.0881


Epoch 47/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.09it/s]
Epoch 47/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 29.74it/s]


Epoch 47/50 - Training accuracy: 0.9983 - Training loss: 0.0049 - Validation accuracy: 0.9795 - Validation loss: 0.1169


Epoch 48/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.07it/s]
Epoch 48/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.15it/s]


Epoch 48/50 - Training accuracy: 0.9974 - Training loss: 0.0052 - Validation accuracy: 0.8191 - Validation loss: 3.8430


Epoch 49/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.21it/s]
Epoch 49/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.21it/s]


Epoch 49/50 - Training accuracy: 0.9974 - Training loss: 0.0037 - Validation accuracy: 0.8294 - Validation loss: 2.0143


Epoch 50/50 - Training: 100%|██████████| 147/147 [00:09<00:00, 16.13it/s]
Epoch 50/50 - Validation: 100%|██████████| 37/37 [00:01<00:00, 30.27it/s]


Epoch 50/50 - Training accuracy: 0.9983 - Training loss: 0.0033 - Validation accuracy: 0.9761 - Validation loss: 0.1684


In [43]:
# Load the training dataset

folder_datatest = COVIDDataset(root_path=os.path.join('dataset_jpg1','test'), 
                                         transform=preprocess)

# Create a simple dataloader just for simple visualization-train
test_loader = DataLoader(folder_datatest,
                        shuffle=True,
                        num_workers=0,
                        batch_size=8)

In [45]:
def test(model, loader):
    # Initialize metrics
    val_acc = 0.0
    
    model.eval()
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(loader):
            inputs = Variable(inputs).to(device)
            labels = Variable(labels).to(device)
            
            outputs = model(inputs)
        
            _, predicted = torch.max(outputs.data, 1)
            val_acc += (predicted == labels).sum().item()

    val_acc /= len(val_loader.dataset)
    return val_acc


In [51]:
model.load_state_dict(torch.load(os.path.join('ckpt','best_model.pth')));

In [50]:
test_accuracy = test(model, test_loader)

print(f'Test accuracy = {round(test_accuracy,4)}')

Test accuracy = 0.5904


### Data Visualization