In [29]:
import random
import matplotlib.pyplot as plt
import torch
import time
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import cv2
from sklearn.model_selection import train_test_split
from collections import Counter
from PIL import Image
import torchvision.transforms as transforms
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import os
import numpy as np
from segmentation_models_pytorch import Unet
# Configure PyTorch's CUDA caching allocator through its environment variable.
# NOTE(review): this is set after `import torch`; it presumably still takes effect
# because no CUDA work has happened yet at this point in the notebook - confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

In [30]:
# Class id <-> class name lookup tables. The class names double as the
# sub-directory names that the dataset classes read labels from.
id_to_label = dict(enumerate(('Basophil', 'Eosinophil', 'Lymphocyte', 'Monocyte', 'Neutrophil')))
# Inverse mapping, derived so the two tables can never drift apart.
label_to_id = {name: class_id for class_id, name in id_to_label.items()}

class TrainDataset(Dataset):
    """Training dataset yielding augmented 4-channel (RGB + mask) tensors and class ids.

    Images are discovered as ``root_dir/<class name>/<file>.jpg``; the mask for an
    image is looked up at ``mask_dir/<class name>/<file>.jpg``. When no mask file
    exists an all-white mask is substituted.

    Args:
        root_dir: directory containing one sub-directory per class.
        mask_dir: directory mirroring ``root_dir`` that holds the masks.
        final_transform: transform applied to the concatenated 4-channel tensor.
        mask_transform: random transform applied, with the same seed, to both
            the image and the mask.
        color_transform: optional transform applied to the image only
            (``None`` disables it).
        resize: size passed to ``transforms.Resize`` before augmentation.

    Attributes:
        images: list of ``(image_path, mask_path_or_None)`` tuples.
        targets: list of integer class ids, aligned with ``images``.
    """

    def __init__(self, root_dir, mask_dir, final_transform, mask_transform, color_transform, resize):
        self.root_dir = root_dir
        self.mask_dir = mask_dir
        self.final_transform = final_transform
        self.mask_transform = mask_transform
        self.color_transform = color_transform
        self.images = []
        self.targets = []
        self.resize = resize

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample index -> file mapping reproducible across runs and machines.
        for subdir in sorted(os.listdir(root_dir)):
            class_dir = os.path.join(root_dir, subdir)
            if not os.path.isdir(class_dir):
                continue
            for filename in sorted(os.listdir(class_dir)):
                if not filename.endswith(".jpg"):
                    continue
                image_path = os.path.join(class_dir, filename)
                mask_path = os.path.join(mask_dir, subdir, filename)
                # None marks "no mask on disk"; __getitem__ substitutes a white mask.
                self.images.append((image_path, mask_path if os.path.exists(mask_path) else None))
                # The class name is the sub-directory name.
                self.targets.append(label_to_id[subdir])

    def __len__(self):
        """Return the number of discovered images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, augment and return ``(4-channel tensor, class id)`` for sample ``idx``."""
        image_path, mask_path = self.images[idx]
        # Context managers close the underlying files as soon as the pixels are copied.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if mask_path is not None:
            with Image.open(mask_path) as img:
                mask = img.convert("L")
        else:
            # White mask with the image's own size, so it goes through exactly the
            # same resize/crop geometry as the image (a fixed resize x resize mask
            # does not match non-square images).
            mask = Image.new("L", image.size, 255)

        image = transforms.ToTensor()(image)
        mask = transforms.ToTensor()(mask)
        seed = torch.randint(2147483647, (1,)).item()
        image = transforms.Resize(int(self.resize))(image)
        # Resize the mask to the image's exact H x W so the seeded random
        # transform below picks identical crop parameters for both.
        mask = transforms.Resize(list(image.shape[-2:]))(mask)
        # Re-seeding before each call makes the random transform draw the same
        # parameters for the image and for the mask.
        torch.manual_seed(seed)
        image = self.mask_transform(image)
        torch.manual_seed(seed)
        mask = self.mask_transform(mask)
        if self.color_transform is not None:
            image = self.color_transform(image)
        image = torch.cat((image, mask), dim=0)
        image = self.final_transform(image)

        return image, self.targets[idx]
    
class TestDataset(Dataset):
    """Evaluation dataset yielding 4-channel (RGB + mask) tensors and class ids.

    Images are discovered as ``root_dir/<class name>/<file>.jpg``; the mask for an
    image is looked up at ``mask_dir/<class name>/<file>.jpg``. When no mask file
    exists an all-white mask is substituted. No random augmentation is applied.

    Args:
        root_dir: directory containing one sub-directory per class.
        mask_dir: directory mirroring ``root_dir`` that holds the masks.
        final_transform: transform applied to the concatenated 4-channel tensor.
        resize: size passed to ``transforms.Resize`` for the image.

    Attributes:
        images: list of ``(image_path, mask_path_or_None)`` tuples.
    """

    def __init__(self, root_dir, mask_dir, final_transform, resize):
        self.root_dir = root_dir
        self.mask_dir = mask_dir
        self.final_transform = final_transform
        self.images = []
        self.resize = resize

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample index -> file mapping reproducible across runs and machines.
        for subdir in sorted(os.listdir(root_dir)):
            class_dir = os.path.join(root_dir, subdir)
            if not os.path.isdir(class_dir):
                continue
            for filename in sorted(os.listdir(class_dir)):
                if not filename.endswith(".jpg"):
                    continue
                image_path = os.path.join(class_dir, filename)
                mask_path = os.path.join(mask_dir, subdir, filename)
                # None marks "no mask on disk"; __getitem__ substitutes a white mask.
                self.images.append((image_path, mask_path if os.path.exists(mask_path) else None))

    def __len__(self):
        """Return the number of discovered images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and return ``(4-channel tensor, class id)`` for sample ``idx``."""
        image_path, mask_path = self.images[idx]
        # Context managers close the underlying files as soon as the pixels are copied.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if mask_path is not None:
            with Image.open(mask_path) as img:
                mask = img.convert("L")
        else:
            # White mask with the image's own size (a fixed resize x resize mask
            # does not match non-square images).
            mask = Image.new("L", image.size, 255)

        image = transforms.ToTensor()(image)
        mask = transforms.ToTensor()(mask)
        image = transforms.Resize(self.resize)(image)
        # The mask must share the image's H x W or torch.cat below raises;
        # previously only the image was resized.
        mask = transforms.Resize(list(image.shape[-2:]))(mask)

        image = torch.cat((image, mask), dim=0)
        image = self.final_transform(image)

        # The class name is the name of the directory that contains the image.
        label = label_to_id[os.path.basename(os.path.dirname(image_path))]
        return image, label

In [31]:
def init(porportion, resize):
    """Build the data loaders, per-class weights and device for one experiment.

    Args:
        porportion: identifier of the training subset; the training data is read
            from ``WBC_<porportion>/train/data`` and the masks from
            ``WBC_<porportion>/train/pred_mask``. (The misspelled name is kept so
            existing keyword callers keep working.)
        resize: edge length used by the resize/crop transforms.

    Returns:
        ``(train_dataloader, val_dataloader, class_weights, device)`` where
        ``class_weights[i]`` is the inverse sample count of class id ``i`` in the
        training set (``0.0`` for a class without samples).
    """
    # The body used to read `proportion`, which resolved to a notebook-level
    # global and silently ignored the argument. Bind the argument explicitly.
    proportion = porportion

    # Validation always uses the WBC_100 split.
    WBC_100_val_dir = 'WBC_100/val/data'
    WBC_100_mask_dir = 'WBC_100/val/mask'

    WBC_train_dir = 'WBC_' + str(proportion) + '/train/data'
    WBC_pred_mask_dir = 'WBC_' + str(proportion) + '/train/pred_mask'

    # Applied to the concatenated 4-channel tensor; the fourth mean/std pair
    # (0.0, 1.0) leaves the mask channel unchanged.
    final_transform = transforms.Compose([
        transforms.Resize((resize, resize)),
        transforms.Normalize([0.485, 0.456, 0.406, 0.0], [0.229, 0.224, 0.225, 1.0])
    ])

    # Random transforms that TrainDataset applies to image and mask with a shared seed.
    mask_transform = transforms.Compose([
        transforms.RandomResizedCrop(resize),
        transforms.RandomHorizontalFlip(),
    ])

    # Applied to the RGB image only.
    color_transform = transforms.Compose([
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    ])

    train_dataset = TrainDataset(root_dir=WBC_train_dir, mask_dir=WBC_pred_mask_dir,
                                 final_transform=final_transform, mask_transform=mask_transform,
                                 color_transform=color_transform, resize=resize)
    val_dataset = TestDataset(root_dir=WBC_100_val_dir, mask_dir=WBC_100_mask_dir,
                              final_transform=final_transform, resize=resize)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = 32
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Index the weights by class id. Iterating the Counter directly would yield
    # them in first-seen order and skip absent classes, so they would not line
    # up with the class indices that a weighted loss expects.
    label_counts = Counter(train_dataset.targets)
    class_weights = [
        1.0 / label_counts[class_id] if label_counts[class_id] else 0.0
        for class_id in sorted(id_to_label)
    ]

    return train_dataloader, val_dataloader, class_weights, device

In [32]:
def train(train_dataloader, val_dataloader, class_weights, model_load_path, model_save_path, device,
          epochs=None):
    """Train a ResNet-34 with a 4-input-channel stem, save it and plot the curves.

    Args:
        train_dataloader: iterable of ``(inputs, labels)`` batches used for optimisation.
        val_dataloader: iterable of ``(inputs, labels)`` batches evaluated after every epoch.
        class_weights: accepted for interface compatibility; currently unused
            because the loss is an unweighted ``CrossEntropyLoss``. To enable
            weighting, pass ``weight=torch.tensor(class_weights)`` to the loss.
        model_load_path: path of a state dict to start from, or ``""`` to skip loading.
        model_save_path: path the final state dict is written to.
        device: ``torch.device`` the model and every batch are moved to.
        epochs: number of training epochs. When ``None`` (default) it is derived
            as ``int(500 / proportion)`` from the notebook-level ``proportion``
            variable, which must then be defined.

    Side effects:
        Prints per-epoch metrics, writes ``model_save_path`` and shows a
        matplotlib figure with training loss and validation accuracy.
    """
    model = models.resnet34(weights=None)

    # Replace the stem with a 4-input-channel convolution: the first three input
    # channels reuse the existing conv1 weights, the fourth starts at zero.
    new_conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    with torch.no_grad():
        new_conv1.weight[:, :3, :, :] = model.conv1.weight
        new_conv1.weight[:, 3, :, :] = 0.0
    model.conv1 = new_conv1
    model.fc = nn.Linear(model.fc.in_features, 5)

    criterion = nn.CrossEntropyLoss().to(device)

    lr = 0.005
    if epochs is None:
        # Legacy behaviour: epoch budget derived from the notebook-level `proportion`.
        epochs = int(500 / proportion)
    opt = optim.SGD(model.parameters(), lr=lr, weight_decay=1e-3)
    # Halve the learning rate five times over the run. A step size of 0
    # (epochs < 5) would make StepLR fail with a modulo-by-zero, so clamp to 1.
    sch = optim.lr_scheduler.StepLR(opt, max(1, epochs // 5), 0.5)

    if model_load_path != "":
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        model.load_state_dict(torch.load(model_load_path, map_location=device))

    # Always move the model: comparing against torch.device("cuda") is False for
    # an indexed device such as "cuda:0", which left the model on the CPU.
    model = model.to(device)

    losses = []
    accuracies = []

    for epoch in range(epochs):
        model.train()
        start_time = time.time()
        io_time = 0  # time spent in host -> device transfers only
        train_time = 0
        eval_time = 0
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_dataloader:
            io_start_time = time.time()
            inputs, labels = inputs.to(device), labels.to(device)
            io_time += time.time() - io_start_time
            opt.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, labels)
            loss.backward()
            opt.step()
            running_loss += loss.item()
            _, predicted = torch.max(pred, 1)
            total += len(labels)
            correct += (predicted == labels).sum().item()
        sch.step()
        train_time = time.time() - start_time
        start_time = time.time()

        # Guard the averages so an empty loader reports 0 instead of crashing.
        epoch_loss = running_loss / max(len(train_dataloader), 1)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss}')
        accuracy = 100 * correct / total if total else 0.0
        print(f'Train Accuracy: {accuracy:.2f}%')
        losses.append(epoch_loss)

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_dataloader:
                io_start_time = time.time()
                inputs, labels = inputs.to(device), labels.to(device)
                io_time += time.time() - io_start_time
                pred = model(inputs)
                _, predicted = torch.max(pred, 1)
                total += len(labels)
                correct += (predicted == labels).sum().item()

        eval_time = time.time() - start_time
        accuracy = 100 * correct / total if total else 0.0
        print(f'Validation Accuracy: {accuracy:.2f}%')
        print('IO:', io_time, 'Train:', train_time, 'Eval:', eval_time)
        accuracies.append(accuracy)

    torch.save(model.state_dict(), model_save_path)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6))

    ax1.plot(range(epochs), losses, label='Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss Over Epochs')

    ax2.plot(range(epochs), accuracies, label='Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Validation Accuracy Over Epochs')

    plt.tight_layout()
    plt.show()

In [33]:
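# Disabled debugging aid: displays the RGB channels and the mask channel of every
# validation sample. Uncomment only after `val_dataloader` has been created.
# for images, labels in val_dataloader: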

#     for i in range(len(images)):
#         image = images[i, :3] 
#         mask = images[i, 3] 
#         label = labels[i]
#         print(f"Label: {label.item()}")

#         plt.imshow(image.permute(1, 2, 0))
#         plt.title("Original Image")
#         plt.show()

#         plt.imshow(mask, cmap='gray')
#         plt.title("Mask")
#         plt.show()

In [None]:
# --- Experiment configuration -------------------------------------------------
proportion = 100        # selects the WBC_<proportion> training subset
resize = 256            # edge length of the network input
model_load_path = ""    # empty string: do not load a checkpoint
model_save_path = f"ResNet34_WBC{proportion}_model.pth"

# --- Build the data pipeline, then train ---------------------------------------
train_dataloader, val_dataloader, class_weights, device = init(proportion, resize)
train(
    train_dataloader,
    val_dataloader,
    class_weights,
    model_load_path,
    model_save_path,
    device,
)