In [1]:
import torch
torch.cuda.get_device_name()

'NVIDIA GeForce RTX 3090'

In [2]:
import torchvision

In [3]:
import cv2

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, Dataset

In [5]:
from glob import glob

In [6]:
images = glob("/home/e/Downloads/Dataset/IDC_regular_ps50_idx5/**/*.png", recursive=True)

In [7]:
len(images)

277524

In [8]:
images[0]

'/home/e/Downloads/Dataset/IDC_regular_ps50_idx5/9383/1/9383_idx5_x1801_y651_class1.png'

In [9]:
import fnmatch
negative = fnmatch.filter(images, "*class0.png")
positive = fnmatch.filter(images, "*class1.png")


In [10]:
from tqdm import tqdm

In [11]:
y = []
for img in tqdm(images):
    if img[-5]=="0":
        y.append(0)
    elif img[-5] == "1":
        y.append(1)
    else:
        print(img, img[-5])

100%|██████████████████████████████| 277524/277524 [00:00<00:00, 3094078.09it/s]


In [12]:
len(y)

277524

In [13]:
import pandas as pd

In [14]:
images_dataset = pd.DataFrame()
images_dataset["images"] = images
images_dataset["labels"] = y
images_dataset.head()

Unnamed: 0,images,labels
0,/home/e/Downloads/Dataset/IDC_regular_ps50_idx...,1
1,/home/e/Downloads/Dataset/IDC_regular_ps50_idx...,1
2,/home/e/Downloads/Dataset/IDC_regular_ps50_idx...,1
3,/home/e/Downloads/Dataset/IDC_regular_ps50_idx...,1
4,/home/e/Downloads/Dataset/IDC_regular_ps50_idx...,1


In [15]:
images_dataset.groupby("labels")["labels"].count()

labels
0    198738
1     78786
Name: labels, dtype: int64

In [16]:
from sklearn.model_selection import train_test_split

In [17]:
train, val = train_test_split(images_dataset, stratify=images_dataset.labels, test_size=0.2)
print(len(train), len(val))

222019 55505


In [18]:
class MyDataset(Dataset):
    def __init__(self, df_data,transform=None):
        super().__init__()
        self.df = df_data.values
        
        self.transform = transform

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        img_path,label = self.df[index]
        
        image = cv2.imread(img_path)
        image = cv2.resize(image, (50,50))
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [19]:
epochs = 10
classes = 2
batch_size = 128
alpha = 0.002

# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [20]:
trans_train = transforms.Compose([transforms.ToPILImage(),
                                  transforms.Pad(64, padding_mode='reflect'),
                                  transforms.RandomHorizontalFlip(), 
                                  transforms.RandomVerticalFlip(),
                                  transforms.RandomRotation(20), 
                                  transforms.ToTensor(),
                                  transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])])

trans_valid = transforms.Compose([transforms.ToPILImage(),
                                  transforms.Pad(64, padding_mode='reflect'),
                                  transforms.ToTensor(),
                                  transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])])

dataset_train = MyDataset(df_data=train, transform=trans_train)
dataset_valid = MyDataset(df_data=val,transform=trans_valid)

loader_train = DataLoader(dataset = dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)
loader_valid = DataLoader(dataset = dataset_valid, batch_size=batch_size//2, shuffle=False, num_workers=0)

In [21]:
class CNN(nn.Module):
    def __init__(self):
        # ancestor constructor call
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=2)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=2)
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(512)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.avg = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * 1 * 1, 2) # !!!
        
    def forward(self, x):
        x = self.pool(F.leaky_relu(self.bn1(self.conv1(x)))) # first convolutional layer then batchnorm, then activation then pooling layer.
        x = self.pool(F.leaky_relu(self.bn2(self.conv2(x))))
        x = self.pool(F.leaky_relu(self.bn3(self.conv3(x))))
        x = self.pool(F.leaky_relu(self.bn4(self.conv4(x))))
        x = self.pool(F.leaky_relu(self.bn5(self.conv5(x))))
        x = self.avg(x)
        x = x.view(-1, 512 * 1 * 1) # !!!
        x = self.fc(x)
        return x

In [22]:
model = CNN().to(device)

In [23]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adamax(model.parameters(), lr=alpha)

In [24]:
total_step = len(loader_train)
for epoch in range(epochs):
    for i, (images, labels) in tqdm(enumerate(loader_train)):
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, epochs, i+1, total_step, loss.item()))

100it [00:20,  4.97it/s]

Epoch [1/10], Step [100/1735], Loss: 0.3642


200it [00:39,  5.02it/s]

Epoch [1/10], Step [200/1735], Loss: 0.3346


301it [00:59,  4.78it/s]

Epoch [1/10], Step [300/1735], Loss: 0.3387


401it [01:19,  5.03it/s]

Epoch [1/10], Step [400/1735], Loss: 0.2934


501it [01:39,  5.20it/s]

Epoch [1/10], Step [500/1735], Loss: 0.3024


601it [01:59,  4.77it/s]

Epoch [1/10], Step [600/1735], Loss: 0.3755


701it [02:19,  5.14it/s]

Epoch [1/10], Step [700/1735], Loss: 0.3589


800it [02:39,  4.85it/s]

Epoch [1/10], Step [800/1735], Loss: 0.3951


900it [03:00,  4.74it/s]

Epoch [1/10], Step [900/1735], Loss: 0.3603


1001it [03:19,  4.81it/s]

Epoch [1/10], Step [1000/1735], Loss: 0.4008


1101it [03:40,  4.92it/s]

Epoch [1/10], Step [1100/1735], Loss: 0.2157


1201it [04:00,  3.88it/s]

Epoch [1/10], Step [1200/1735], Loss: 0.2762


1301it [04:20,  4.92it/s]

Epoch [1/10], Step [1300/1735], Loss: 0.2803


1401it [04:41,  4.61it/s]

Epoch [1/10], Step [1400/1735], Loss: 0.2808


1501it [05:02,  4.98it/s]

Epoch [1/10], Step [1500/1735], Loss: 0.3425


1600it [05:21,  4.31it/s]

Epoch [1/10], Step [1600/1735], Loss: 0.3527


1701it [05:42,  4.83it/s]

Epoch [1/10], Step [1700/1735], Loss: 0.3205


1735it [05:49,  5.52it/s]
101it [00:17,  5.50it/s]

Epoch [2/10], Step [100/1735], Loss: 0.2495


201it [00:34,  5.47it/s]

Epoch [2/10], Step [200/1735], Loss: 0.4029


301it [00:51,  5.67it/s]

Epoch [2/10], Step [300/1735], Loss: 0.3246


401it [01:08,  5.10it/s]

Epoch [2/10], Step [400/1735], Loss: 0.3275


501it [01:25,  5.60it/s]

Epoch [2/10], Step [500/1735], Loss: 0.3589


601it [01:42,  5.58it/s]

Epoch [2/10], Step [600/1735], Loss: 0.3237


701it [01:58,  5.61it/s]

Epoch [2/10], Step [700/1735], Loss: 0.3964


801it [02:15,  5.70it/s]

Epoch [2/10], Step [800/1735], Loss: 0.3207


901it [02:32,  5.66it/s]

Epoch [2/10], Step [900/1735], Loss: 0.2870


1001it [02:48,  5.61it/s]

Epoch [2/10], Step [1000/1735], Loss: 0.2814


1101it [03:05,  5.73it/s]

Epoch [2/10], Step [1100/1735], Loss: 0.1894


1201it [03:21,  5.65it/s]

Epoch [2/10], Step [1200/1735], Loss: 0.3946


1301it [03:38,  5.64it/s]

Epoch [2/10], Step [1300/1735], Loss: 0.2437


1401it [03:55,  5.69it/s]

Epoch [2/10], Step [1400/1735], Loss: 0.3734


1501it [04:11,  5.70it/s]

Epoch [2/10], Step [1500/1735], Loss: 0.2647


1601it [04:28,  5.73it/s]

Epoch [2/10], Step [1600/1735], Loss: 0.2714


1701it [04:44,  5.61it/s]

Epoch [2/10], Step [1700/1735], Loss: 0.2255


1735it [04:50,  5.97it/s]
101it [00:16,  5.66it/s]

Epoch [3/10], Step [100/1735], Loss: 0.3371


201it [00:33,  5.66it/s]

Epoch [3/10], Step [200/1735], Loss: 0.3564


301it [00:49,  5.68it/s]

Epoch [3/10], Step [300/1735], Loss: 0.2675


401it [01:06,  5.72it/s]

Epoch [3/10], Step [400/1735], Loss: 0.2508


501it [01:22,  5.64it/s]

Epoch [3/10], Step [500/1735], Loss: 0.2961


601it [01:39,  5.72it/s]

Epoch [3/10], Step [600/1735], Loss: 0.2335


701it [01:55,  5.64it/s]

Epoch [3/10], Step [700/1735], Loss: 0.3347


801it [02:12,  5.68it/s]

Epoch [3/10], Step [800/1735], Loss: 0.3462


901it [02:29,  5.68it/s]

Epoch [3/10], Step [900/1735], Loss: 0.2689


1001it [02:45,  5.67it/s]

Epoch [3/10], Step [1000/1735], Loss: 0.2202


1101it [03:02,  5.70it/s]

Epoch [3/10], Step [1100/1735], Loss: 0.2311


1201it [03:19,  5.43it/s]

Epoch [3/10], Step [1200/1735], Loss: 0.2654


1300it [03:38,  4.75it/s]

Epoch [3/10], Step [1300/1735], Loss: 0.2244


1401it [03:57,  4.96it/s]

Epoch [3/10], Step [1400/1735], Loss: 0.2053


1501it [04:15,  5.48it/s]

Epoch [3/10], Step [1500/1735], Loss: 0.2796


1601it [04:32,  5.56it/s]

Epoch [3/10], Step [1600/1735], Loss: 0.2493


1701it [04:48,  5.32it/s]

Epoch [3/10], Step [1700/1735], Loss: 0.3044


1735it [04:54,  5.88it/s]
101it [00:17,  5.31it/s]

Epoch [4/10], Step [100/1735], Loss: 0.2300


201it [00:35,  5.42it/s]

Epoch [4/10], Step [200/1735], Loss: 0.2514


301it [00:52,  5.45it/s]

Epoch [4/10], Step [300/1735], Loss: 0.1899


401it [01:10,  5.39it/s]

Epoch [4/10], Step [400/1735], Loss: 0.2118


501it [01:27,  5.56it/s]

Epoch [4/10], Step [500/1735], Loss: 0.2345


601it [01:43,  5.62it/s]

Epoch [4/10], Step [600/1735], Loss: 0.2553


701it [02:00,  5.53it/s]

Epoch [4/10], Step [700/1735], Loss: 0.3130


801it [02:17,  5.64it/s]

Epoch [4/10], Step [800/1735], Loss: 0.2108


901it [02:34,  5.65it/s]

Epoch [4/10], Step [900/1735], Loss: 0.2585


1001it [02:50,  5.57it/s]

Epoch [4/10], Step [1000/1735], Loss: 0.3305


1101it [03:07,  5.56it/s]

Epoch [4/10], Step [1100/1735], Loss: 0.2080


1201it [03:24,  5.57it/s]

Epoch [4/10], Step [1200/1735], Loss: 0.2753


1301it [03:41,  5.44it/s]

Epoch [4/10], Step [1300/1735], Loss: 0.2894


1401it [03:58,  5.62it/s]

Epoch [4/10], Step [1400/1735], Loss: 0.2437


1501it [04:14,  5.64it/s]

Epoch [4/10], Step [1500/1735], Loss: 0.2115


1601it [04:31,  5.63it/s]

Epoch [4/10], Step [1600/1735], Loss: 0.2742


1701it [04:48,  5.61it/s]

Epoch [4/10], Step [1700/1735], Loss: 0.3177


1735it [04:53,  5.91it/s]
101it [00:16,  5.56it/s]

Epoch [5/10], Step [100/1735], Loss: 0.2497


201it [00:33,  5.55it/s]

Epoch [5/10], Step [200/1735], Loss: 0.2106


301it [00:49,  5.54it/s]

Epoch [5/10], Step [300/1735], Loss: 0.2310


401it [01:06,  5.61it/s]

Epoch [5/10], Step [400/1735], Loss: 0.2257


500it [01:26,  4.47it/s]

Epoch [5/10], Step [500/1735], Loss: 0.3767


601it [01:47,  4.83it/s]

Epoch [5/10], Step [600/1735], Loss: 0.2240


700it [02:06,  4.69it/s]

Epoch [5/10], Step [700/1735], Loss: 0.2569


801it [02:26,  4.93it/s]

Epoch [5/10], Step [800/1735], Loss: 0.2234


901it [02:45,  4.92it/s]

Epoch [5/10], Step [900/1735], Loss: 0.2484


1001it [03:04,  5.02it/s]

Epoch [5/10], Step [1000/1735], Loss: 0.2374


1101it [03:23,  4.92it/s]

Epoch [5/10], Step [1100/1735], Loss: 0.2630


1201it [03:42,  5.02it/s]

Epoch [5/10], Step [1200/1735], Loss: 0.2554


1301it [04:01,  5.01it/s]

Epoch [5/10], Step [1300/1735], Loss: 0.2727


1400it [04:21,  3.94it/s]

Epoch [5/10], Step [1400/1735], Loss: 0.2601


1501it [04:41,  5.00it/s]

Epoch [5/10], Step [1500/1735], Loss: 0.3474


1601it [04:59,  5.03it/s]

Epoch [5/10], Step [1600/1735], Loss: 0.2044


1701it [05:18,  5.17it/s]

Epoch [5/10], Step [1700/1735], Loss: 0.3500


1735it [05:25,  6.05it/s]
101it [00:18,  5.05it/s]

Epoch [6/10], Step [100/1735], Loss: 0.3345


201it [00:37,  5.12it/s]

Epoch [6/10], Step [200/1735], Loss: 0.3461


301it [00:55,  5.03it/s]

Epoch [6/10], Step [300/1735], Loss: 0.2973


401it [01:13,  5.03it/s]

Epoch [6/10], Step [400/1735], Loss: 0.2003


501it [01:32,  5.32it/s]

Epoch [6/10], Step [500/1735], Loss: 0.2110


601it [01:50,  5.01it/s]

Epoch [6/10], Step [600/1735], Loss: 0.2439


701it [02:09,  5.09it/s]

Epoch [6/10], Step [700/1735], Loss: 0.3311


801it [02:27,  5.12it/s]

Epoch [6/10], Step [800/1735], Loss: 0.1965


900it [02:46,  4.81it/s]

Epoch [6/10], Step [900/1735], Loss: 0.3206


1001it [03:05,  5.07it/s]

Epoch [6/10], Step [1000/1735], Loss: 0.2734


1101it [03:23,  5.02it/s]

Epoch [6/10], Step [1100/1735], Loss: 0.2445


1201it [03:42,  5.03it/s]

Epoch [6/10], Step [1200/1735], Loss: 0.1996


1301it [04:00,  5.05it/s]

Epoch [6/10], Step [1300/1735], Loss: 0.2427


1401it [04:19,  5.13it/s]

Epoch [6/10], Step [1400/1735], Loss: 0.2912


1501it [04:38,  4.95it/s]

Epoch [6/10], Step [1500/1735], Loss: 0.2641


1601it [04:56,  5.12it/s]

Epoch [6/10], Step [1600/1735], Loss: 0.2336


1701it [05:15,  4.93it/s]

Epoch [6/10], Step [1700/1735], Loss: 0.2550


1735it [05:22,  5.39it/s]
101it [00:18,  4.97it/s]

Epoch [7/10], Step [100/1735], Loss: 0.2194


201it [00:37,  4.91it/s]

Epoch [7/10], Step [200/1735], Loss: 0.2322


301it [00:56,  4.99it/s]

Epoch [7/10], Step [300/1735], Loss: 0.2529


401it [01:15,  5.14it/s]

Epoch [7/10], Step [400/1735], Loss: 0.2359


501it [01:33,  4.96it/s]

Epoch [7/10], Step [500/1735], Loss: 0.3149


601it [01:52,  5.07it/s]

Epoch [7/10], Step [600/1735], Loss: 0.2488


701it [02:11,  5.08it/s]

Epoch [7/10], Step [700/1735], Loss: 0.1945


800it [02:30,  4.19it/s]

Epoch [7/10], Step [800/1735], Loss: 0.2539


901it [02:48,  5.10it/s]

Epoch [7/10], Step [900/1735], Loss: 0.3014


1001it [03:06,  5.20it/s]

Epoch [7/10], Step [1000/1735], Loss: 0.2168


1100it [03:26,  4.48it/s]

Epoch [7/10], Step [1100/1735], Loss: 0.2350


1201it [03:46,  5.05it/s]

Epoch [7/10], Step [1200/1735], Loss: 0.2068


1301it [04:05,  5.23it/s]

Epoch [7/10], Step [1300/1735], Loss: 0.1962


1400it [04:25,  4.55it/s]

Epoch [7/10], Step [1400/1735], Loss: 0.3555


1501it [04:45,  5.04it/s]

Epoch [7/10], Step [1500/1735], Loss: 0.2954


1601it [05:04,  5.21it/s]

Epoch [7/10], Step [1600/1735], Loss: 0.2547


1701it [05:22,  5.22it/s]

Epoch [7/10], Step [1700/1735], Loss: 0.2420


1735it [05:28,  5.28it/s]
101it [00:19,  5.16it/s]

Epoch [8/10], Step [100/1735], Loss: 0.2117


201it [00:37,  5.24it/s]

Epoch [8/10], Step [200/1735], Loss: 0.2677


301it [00:55,  5.19it/s]

Epoch [8/10], Step [300/1735], Loss: 0.1467


400it [01:13,  4.65it/s]

Epoch [8/10], Step [400/1735], Loss: 0.2683


500it [01:33,  4.63it/s]

Epoch [8/10], Step [500/1735], Loss: 0.2198


600it [01:54,  4.26it/s]

Epoch [8/10], Step [600/1735], Loss: 0.3562


701it [02:16,  4.92it/s]

Epoch [8/10], Step [700/1735], Loss: 0.2396


800it [02:35,  4.40it/s]

Epoch [8/10], Step [800/1735], Loss: 0.2857


901it [02:54,  4.95it/s]

Epoch [8/10], Step [900/1735], Loss: 0.2634


1001it [03:14,  4.93it/s]

Epoch [8/10], Step [1000/1735], Loss: 0.2289


1101it [03:32,  5.12it/s]

Epoch [8/10], Step [1100/1735], Loss: 0.3074


1200it [03:51,  4.46it/s]

Epoch [8/10], Step [1200/1735], Loss: 0.1446


1301it [04:11,  5.08it/s]

Epoch [8/10], Step [1300/1735], Loss: 0.2164


1401it [04:30,  5.13it/s]

Epoch [8/10], Step [1400/1735], Loss: 0.1775


1501it [04:48,  5.18it/s]

Epoch [8/10], Step [1500/1735], Loss: 0.2311


1601it [05:07,  4.52it/s]

Epoch [8/10], Step [1600/1735], Loss: 0.1819


1701it [05:28,  5.01it/s]

Epoch [8/10], Step [1700/1735], Loss: 0.1759


1735it [05:33,  5.20it/s]
101it [00:17,  5.20it/s]

Epoch [9/10], Step [100/1735], Loss: 0.3078


201it [00:36,  5.10it/s]

Epoch [9/10], Step [200/1735], Loss: 0.2786


301it [00:54,  5.13it/s]

Epoch [9/10], Step [300/1735], Loss: 0.2637


401it [01:13,  5.11it/s]

Epoch [9/10], Step [400/1735], Loss: 0.1464


501it [01:32,  4.85it/s]

Epoch [9/10], Step [500/1735], Loss: 0.2154


601it [01:54,  4.44it/s]

Epoch [9/10], Step [600/1735], Loss: 0.1449


701it [02:16,  4.60it/s]

Epoch [9/10], Step [700/1735], Loss: 0.1885


800it [02:36,  4.41it/s]

Epoch [9/10], Step [800/1735], Loss: 0.2558


900it [02:57,  4.22it/s]

Epoch [9/10], Step [900/1735], Loss: 0.2648


1001it [03:17,  4.93it/s]

Epoch [9/10], Step [1000/1735], Loss: 0.2890


1100it [03:39,  3.18it/s]

Epoch [9/10], Step [1100/1735], Loss: 0.2442


1200it [04:01,  4.39it/s]

Epoch [9/10], Step [1200/1735], Loss: 0.1883


1301it [04:22,  4.64it/s]

Epoch [9/10], Step [1300/1735], Loss: 0.2774


1401it [04:42,  5.02it/s]

Epoch [9/10], Step [1400/1735], Loss: 0.2386


1501it [05:00,  5.09it/s]

Epoch [9/10], Step [1500/1735], Loss: 0.1765


1601it [05:19,  4.85it/s]

Epoch [9/10], Step [1600/1735], Loss: 0.2069


1701it [05:36,  5.44it/s]

Epoch [9/10], Step [1700/1735], Loss: 0.1777


1735it [05:42,  5.06it/s]
101it [00:17,  5.38it/s]

Epoch [10/10], Step [100/1735], Loss: 0.2250


201it [00:34,  5.34it/s]

Epoch [10/10], Step [200/1735], Loss: 0.1987


301it [00:52,  5.45it/s]

Epoch [10/10], Step [300/1735], Loss: 0.2059


401it [01:09,  5.46it/s]

Epoch [10/10], Step [400/1735], Loss: 0.2907


501it [01:26,  5.53it/s]

Epoch [10/10], Step [500/1735], Loss: 0.2135


601it [01:43,  5.44it/s]

Epoch [10/10], Step [600/1735], Loss: 0.3682


701it [02:01,  5.41it/s]

Epoch [10/10], Step [700/1735], Loss: 0.2108


801it [02:18,  5.41it/s]

Epoch [10/10], Step [800/1735], Loss: 0.2617


901it [02:35,  5.42it/s]

Epoch [10/10], Step [900/1735], Loss: 0.2233


1001it [02:53,  4.76it/s]

Epoch [10/10], Step [1000/1735], Loss: 0.1762


1101it [03:10,  5.20it/s]

Epoch [10/10], Step [1100/1735], Loss: 0.1885


1201it [03:28,  5.48it/s]

Epoch [10/10], Step [1200/1735], Loss: 0.1430


1301it [03:45,  5.27it/s]

Epoch [10/10], Step [1300/1735], Loss: 0.2180


1401it [04:03,  5.38it/s]

Epoch [10/10], Step [1400/1735], Loss: 0.2340


1501it [04:20,  5.64it/s]

Epoch [10/10], Step [1500/1735], Loss: 0.1373


1601it [04:38,  5.48it/s]

Epoch [10/10], Step [1600/1735], Loss: 0.2851


1701it [04:55,  5.58it/s]

Epoch [10/10], Step [1700/1735], Loss: 0.2264


1735it [05:00,  5.77it/s]


In [25]:
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
confusion_matrix = torch.zeros(2, 2)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in loader_valid:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        for t, p in zip(labels.view(-1), predicted.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
                 
    print('Test Accuracy of the model on the test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

Test Accuracy of the model on the test images: 90.53959102783533 %


In [26]:
print(confusion_matrix)

tensor([[37440.,  2308.],
        [ 2943., 12814.]])
