In [1]:
import numpy as np
import random
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
import os
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import roc_curve, auc, log_loss, accuracy_score
from tqdm import tqdm

In [12]:
def gray_world_assumption(img):
    # Calculate the average color of the entire image
    mean_color = img.mean(dim=[1, 2])
    # Calculate the correction factors for each channel
    correction_factors = torch.tensor([0.5, 0.5, 0.5]) / mean_color
    # Apply the correction to the image
    corrected_img = correction_factors.view(3, 1, 1) * img
    #corrected_img = img - mean_color.view(3,1,1) + 0.5
    return corrected_img


data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((280,280)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((230,230)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = "./cell_images"
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
all_dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=64, shuffle=True) for x in ['train', 'val']}


In [3]:
image_datasets['train'].class_to_idx

{'Parasitized': 0, 'Uninfected': 1}

In [4]:
image_datasets['val'].class_to_idx

{'Parasitized': 0, 'Uninfected': 1}

In [5]:
model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False



In [20]:
class CustomCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Conv2d(3, 3, kernel_size=7, padding='same'),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2,2),
                                   nn.ReLU(),
                                   nn.Flatten(),
                                   nn.Linear(112*112*3, 2))
    
    def forward(self, X):
        return self.model(X)

In [21]:
# num_classes = 2
# model.fc = nn.Linear(model.fc.in_features, num_classes)
model = CustomCNN()
model.to('mps') #pytorch metal for M2

CustomCNN(
  (model): Sequential(
    (0): Conv2d(3, 3, kernel_size=(7, 7), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): ReLU()
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): Linear(in_features=37632, out_features=2, bias=True)
  )
)

In [22]:
criterion = nn.CrossEntropyLoss()
#optimizer = optim.Adam(model.fc.parameters(), lr=0.001) #change LR
optimizer = optim.Adam(model.parameters(), lr=0.001) #change LR

In [23]:
num_epochs = 1
device = 'mps'

# Training loop
for epoch in range(num_epochs):
    print('-' * 30)
    print(f"Epoch {epoch+1}/{num_epochs}")
    print('-' * 30)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
            dataloader = all_dataloaders['train']
        else:
            model.eval()
            dataloader = all_dataloaders['val']

        running_loss = 0.0
        running_corrects = 0
                
        # Iterate over data
        for inputs, labels in tqdm(dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # Backward pass and optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(dataloader.dataset)
        epoch_acc = running_corrects / len(dataloader.dataset)

        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

print("Training complete!")

------------------------------
Epoch 1/1
------------------------------


100%|██████████| 343/343 [00:50<00:00,  6.80it/s]


train Loss: 0.7131 Acc: 0.6292


100%|██████████| 88/88 [00:08<00:00, 10.51it/s]

val Loss: 0.6622 Acc: 0.6139
Training complete!





In [None]:
'''
./cell_images
    train
        Uninfected
        Parasitized
    val
        Uninfected
        Parasitized
'''

In [None]:
# uninfected_data = [os.path.join('./cell_images/Uninfected', x) for x in os.listdir('./cell_images/Uninfected')]
# parasitized_data = [os.path.join('./cell_images/Parasitized', x) for x in os.listdir('./cell_images/Parasitized')]

# try:
#     os.mkdir('./cell_images/train')
#     os.mkdir('./cell_images/val')
#     os.mkdir('./cell_images/train/Uninfected')
#     os.mkdir('./cell_images/train/Parasitized')
#     os.mkdir('./cell_images/val/Uninfected')
#     os.mkdir('./cell_images/val/Parasitized')
# except Exception: 
#     pass

# train_fraction = 0.8
# val_fraction = 0.2

# all_data = {'Uninfected':uninfected_data, 'Parasitized':parasitized_data}
# for key in all_data:
#     for image in all_data[key]:
#         #sort as train
#         if random.random() < train_fraction:
#             shutil.copy(image, f'./cell_images/train/{key}')
#         else:
#             shutil.copy(image, f'./cell_images/val/{key}')

In [24]:
# You can then proceed with evaluation, saving the model, etc.
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    progress_bar_eval = tqdm(all_dataloaders['val'], desc="Evaluating", ncols=1000)
    for inputs, labels in all_dataloaders['val']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())


accuracy = accuracy_score(all_labels, all_preds)
print(f"Accuracy: {accuracy:.4f}")

# Classification report
report = classification_report(all_labels, all_preds, target_names=["Uninfected", "Parasitized"])
print(report)

# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
print("Confusion Matrix:")
print(cm)

Evaluating:   0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                | 0/88 [36:49<?, ?it/s]

Accuracy: 0.6139
              precision    recall  f1-score   support

  Uninfected       0.58      0.79      0.67      2807
 Parasitized       0.68      0.44      0.53      2805

    accuracy                           0.61      5612
   macro avg       0.63      0.61      0.60      5612
weighted avg       0.63      0.61      0.60      5612

Confusion Matrix:
[[2218  589]
 [1578 1227]]


In [19]:
torch.save(model, 'classifier3.pth')