In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
import glob
from PIL import Image

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [3]:
class WBCDataset(Dataset):
    """Image-folder dataset with one sub-folder per white-blood-cell type.

    Every ``*.jpg`` file found under ``<root_dir>/<data_dir>/<cell_type>/`` is
    indexed, and its label is the position of ``cell_type`` in ``cell_types``.
    ``data_dir`` is ``'data'`` except when ``mode == 'val'`` and ``subset`` is
    false; in that case the class folders are looked up directly under
    ``root_dir``.

    Args:
        root_dir: Directory that holds the class folders (or the ``data`` folder).
        transform: Optional callable applied to each RGB ``PIL.Image``.
        mode: One of ``'train'``, ``'val'`` or ``'test'``.
        subset: When true, always look inside the ``data`` sub-folder.
    """

    def __init__(self, root_dir, transform=None, mode='train', subset=False):
        assert mode in ['train', 'val', 'test'], "Mode should be 'train', 'val', or 'test'"

        self.root_dir = root_dir
        self.transform = transform
        self.mode = mode
        self.subset = subset

        # The list position of each name is the integer class label.
        self.cell_types = ['Basophil', 'Eosinophil', 'Lymphocyte', 'Monocyte', 'Neutrophil']
        self.image_paths = []
        self.labels = []

        if self.subset:
            data_dir = 'data'
        else:
            data_dir = '' if mode == 'val' else 'data'

        for idx, cell_type in enumerate(self.cell_types):
            pattern = os.path.join(root_dir, data_dir, cell_type, '*.jpg')
            # glob returns matches in arbitrary, filesystem-dependent order;
            # sorting keeps the index -> sample mapping reproducible across runs.
            type_image_paths = sorted(glob.glob(pattern))
            self.image_paths.extend(type_image_paths)
            self.labels.extend([idx] * len(type_image_paths))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and return ``(img, label)``."""
        img_path = self.image_paths[idx]
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read, so the handle can be closed right after.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        label = self.labels[idx]
        return img, label

In [4]:
# Evaluation-time preprocessing: resize every image to 128x128 (the input size
# that SimpleCNN's first fully connected layer is dimensioned for) and convert
# it to a tensor.
wbc_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

wbc_100_val_dataset = WBCDataset(root_dir='WBC_100/val', transform=wbc_transform, mode='val')
# Evaluation must see every sample: drop_last=True would discard a trailing
# partial batch whenever the dataset size is not a multiple of batch_size.
val_loader_100 = DataLoader(wbc_100_val_dataset, batch_size=32, shuffle=False, drop_last=False)

print(len(wbc_100_val_dataset))

1728


In [5]:
class SimpleCNN(nn.Module):
    """Three conv / batch-norm / ReLU / max-pool stages plus a two-layer classifier.

    Each pooling stage halves height and width, so the ``64 * 16 * 16`` input
    size of ``fc1`` corresponds to 3-channel 128x128 images.

    Args:
        num_classes: Number of output logits produced by ``fc2``.
    """

    def __init__(self, num_classes=5):
        super().__init__()

        # Attribute names are part of the checkpoint format (state_dict keys);
        # keep them stable so saved weights continue to load.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Shared 2x2 pooling layer, applied after every conv stage.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.fc1 = nn.Linear(64 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, num_classes)

        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for the image batch ``x``.

        Raises:
            ValueError: If the flattened feature size per sample does not match
                ``fc1.in_features`` (i.e. the input images are not 128x128).
        """
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))

        # Flatten per sample, keeping the batch dimension. A view(-1, 64*16*16)
        # would silently fold a wrong spatial size into the batch dimension and
        # return predictions for a different number of "samples".
        x = x.flatten(start_dim=1)
        if x.shape[1] != self.fc1.in_features:
            raise ValueError(
                "SimpleCNN expects inputs that flatten to %d features per sample "
                "(3x128x128 images), got %d" % (self.fc1.in_features, x.shape[1])
            )
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x
    
# Instantiate the network once with the default five output classes.
# NOTE(review): the evaluation loop visible in this notebook builds its own
# SimpleCNN instances, so this object looks unused here — confirm before removing.
simple_cnn = SimpleCNN()

In [6]:
def test_model(model, val_loader):
    val_accuracy_history = []

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_corrects = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            val_corrects += torch.sum(preds == labels.data)

    val_accuracy = val_corrects.double() / len(val_loader.dataset)

    return val_accuracy

In [7]:
# Names of the training-set variants to evaluate. Each name is also the prefix of
# the checkpoint files "weights/<name>_random.pth" and "weights/<name>_pretrained.pth".
# NOTE(review): the numeric suffix presumably is the percentage of training data
# used for that variant — confirm.
dataset_name = ['WBC_100', 'WBC_50', 'WBC_10', 'WBC_1']

In [8]:
def _load_simple_cnn(weights_path):
    """Build a SimpleCNN on ``device`` and restore its parameters from ``weights_path``."""
    model = SimpleCNN().to(device)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source; on PyTorch versions
    # that support it, consider passing weights_only=True.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    return model


# Evaluate both checkpoints of every dataset variant on the WBC_100 validation set.
# NOTE(review): "random" / "pretrained" presumably refer to how the weights were
# initialised before training — confirm.
for name in dataset_name:
    model_random = _load_simple_cnn("weights/"+name+"_random.pth")
    model_pretrained = _load_simple_cnn("weights/"+name+"_pretrained.pth")

    random_acc = test_model(model_random, val_loader_100)
    print("Accuracy of random-weight for "+name+": ", random_acc.item())
    pretrained_acc = test_model(model_pretrained, val_loader_100)
    print("Accuracy of pretrained-weight for "+name+": ", pretrained_acc.item())

Accuracy of random-weight for WBC_100:  0.9623842592592593
Accuracy of pretrained-weight for WBC_100:  0.9652777777777778
Accuracy of random-weight for WBC_50:  0.9577546296296297
Accuracy of pretrained-weight for WBC_50:  0.9606481481481481
Accuracy of random-weight for WBC_10:  0.9207175925925926
Accuracy of pretrained-weight for WBC_10:  0.9253472222222222
Accuracy of random-weight for WBC_1:  0.6128472222222222
Accuracy of pretrained-weight for WBC_1:  0.6458333333333334
