In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import math
import os
from PIL import Image
import torch.optim as optim
#Confusion Matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns  # For a nicer confusion matrix visualization
#Test data visualization
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import DataLoader, SubsetRandomSampler


# Device configuration: run on the GPU when CUDA is available, otherwise fall back
# to the CPU. The selected device is printed so the run log records which one was used.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

In [None]:
# Hyperparameters and run-wide configuration.
SEED = 42
# Seed torch's global RNG (CPU and CUDA). random_split, DataLoader shuffling and
# layer weight initialisation all draw from it, so Restart & Run All reproduces
# the same split, initial weights and batch order.
torch.manual_seed(SEED)

batch_size = 60
learning_rate = 1e-4
num_epochs = 15

# Per-epoch history; train_one_epoch / validate_one_epoch append one entry each epoch.
trainingEpoch_loss = []
validationEpoch_loss = []
trainingEpoch_accuracy = []
validationEpoch_accuracy = []

In [None]:
# Preprocessing applied to every image:
#   - resize to a fixed 224x224 (ConvNet's default fc1 size is derived from 224x224 input),
#   - convert to a float tensor in [0, 1],
#   - normalise each channel with mean 0.5 / std 0.5, mapping values to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Root folder containing one sub-folder per class.
DATA_ROOT = 'D:/Purdue/2024 Spring/BME 450/Project/Data_mix'
dataset = datasets.ImageFolder(root=DATA_ROOT, transform=transform)

# Human-readable class names, indexed by integer label.
categories = ('glioma', 'meningioma', 'notumor', 'pituitary')
# ImageFolder assigns label i to the i-th class folder in sorted order. The plots
# later name label i as categories[i], so the two orders must agree.
assert tuple(dataset.classes) == categories, (
    f'class folders {dataset.classes} do not match categories {categories}'
)

In [None]:
# Split the dataset roughly 80 / 10 / 10 into train / validation / test subsets.
n_total = len(dataset)
num_train = int(0.8 * n_total)
num_test = int(0.1 * n_total)
# random_split requires the lengths to sum exactly to len(dataset). Give the
# rounding remainder to the validation split so that always holds.
num_val = n_total - num_train - num_test

trainset, valset, test_dataset = torch.utils.data.random_split(
    dataset, [num_train, num_val, num_test]
)

print(
    f'The training dataset size is {len(trainset)}\n'
    f'The validation dataset size is {len(valset)}\n'
    f'The test dataset size is {len(test_dataset)}'
)

In [None]:
# Wrap each split in a DataLoader. Only the training split is shuffled; the
# validation and test loaders iterate their subsets in a fixed order.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
#Simple Version of model
class ConvNet(nn.Module):
    """Small three-stage CNN classifier for 3-channel images.

    Each stage is a 3x3 convolution (no padding) + ReLU, followed by 2x2 max
    pooling. The flattened features then pass through a 120-unit hidden layer
    and a linear output layer that produces raw logits (no softmax).

    Args:
        num_classes: number of output logits (default 4).
        image_size: side length of the square input the network is built for
            (default 224). It determines ``fc1.in_features``; inputs of any
            other size will not fit ``fc1``.

    Raises:
        ValueError: if ``image_size`` is too small to survive the three
            conv/pool stages.
    """

    def __init__(self, num_classes: int = 4, image_size: int = 224):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.flatten = nn.Flatten()

        # Spatial side after each stage: the 3x3 unpadded conv removes 2 pixels and
        # the 2x2 pool halves (floor). For 224 input: 224 -> 111 -> 54 -> 26.
        side = image_size
        for _ in range(3):
            side = (side - 2) // 2
        if side < 1:
            raise ValueError(f'image_size={image_size} is too small for three conv/pool stages')

        # in_features = channels of conv3 * final height * final width
        self.fc1 = nn.Linear(in_features=32 * side * side, out_features=120)
        self.out = nn.Linear(in_features=120, out_features=num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, 3, image_size, image_size).
        """
        # ReLU output is non-negative and max pooling preserves that, so no
        # second ReLU is needed after each pool.
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))

        x = self.flatten(x)
        x = F.relu(self.fc1(x))

        # Raw logits: nn.CrossEntropyLoss applies log-softmax itself.
        return self.out(x)

In [None]:
# Build the network on the selected device, then the loss and the optimizer.
# ConvNet.forward ends in a plain Linear layer (raw logits), which is what
# nn.CrossEntropyLoss expects because it applies log-softmax internally.
model = ConvNet().to(device)
loss_fn = nn.CrossEntropyLoss()
# Adam over all of the model's parameters.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
#Training
def train_one_epoch():
    """Run one optimisation pass over ``train_loader`` and record its metrics.

    Uses the notebook globals ``model``, ``train_loader``, ``optimizer``,
    ``loss_fn`` and ``device``. The per-sample mean loss and the accuracy (in
    percent) are appended to ``trainingEpoch_loss`` / ``trainingEpoch_accuracy``
    and printed.

    Returns:
        tuple[float, float]: (mean loss per sample, accuracy in percent).

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()  # switch the model to training mode
    loss_sum = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        # loss_fn (CrossEntropyLoss, default reduction) returns the batch mean.
        # Weight it by batch size so a smaller final batch is not over-counted.
        loss_sum += loss.item() * n
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += n

    if total == 0:
        raise ValueError('train_loader produced no samples')

    epoch_loss = loss_sum / total
    epoch_accuracy = 100 * correct / total
    trainingEpoch_loss.append(epoch_loss)
    trainingEpoch_accuracy.append(epoch_accuracy)
    print(f'Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')
    return epoch_loss, epoch_accuracy


In [None]:
# Validation
def validate_one_epoch():
    """Evaluate the model on ``val_loader`` without updating weights and record metrics.

    Uses the notebook globals ``model``, ``val_loader``, ``loss_fn`` and
    ``device``. The per-sample mean loss and the accuracy (in percent) are
    appended to ``validationEpoch_loss`` / ``validationEpoch_accuracy`` and printed.

    Returns:
        tuple[float, float]: (mean loss per sample, accuracy in percent).

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()  # switch the model to evaluation mode
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

            n = labels.size(0)
            # loss_fn (CrossEntropyLoss, default reduction) returns the batch mean.
            # Weight it by batch size so a smaller final batch is not over-counted.
            loss_sum += loss.item() * n
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += n

    if total == 0:
        raise ValueError('val_loader produced no samples')

    epoch_loss = loss_sum / total
    epoch_accuracy = 100 * correct / total
    validationEpoch_loss.append(epoch_loss)
    validationEpoch_accuracy.append(epoch_accuracy)
    print(f'Validation Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')
    print('***************************************************')
    return epoch_loss, epoch_accuracy

In [None]:
# Train CNN: each epoch runs one training pass followed by one validation pass.
# Both helpers print their own metrics and append to the per-epoch history lists.
for epoch in range(1, num_epochs + 1):
    print(f'Epoch: {epoch}\n')

    train_one_epoch()
    validate_one_epoch()

print('Finished Training')

In [None]:
# Plot training history: loss (left) and accuracy (right) per epoch for the
# training and validation splits.
Epochs = list(range(1, len(trainingEpoch_accuracy) + 1))

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(Epochs, trainingEpoch_loss, label='Training Loss')
ax_loss.plot(Epochs, validationEpoch_loss, label='Validation Loss')
ax_loss.set(title='Loss per epoch', xlabel='Epoch', ylabel='Loss')
ax_loss.legend()

ax_acc.plot(Epochs, trainingEpoch_accuracy, label='Training Accuracy')
ax_acc.plot(Epochs, validationEpoch_accuracy, label='Validation Accuracy')
ax_acc.set(title='Accuracy per epoch', xlabel='Epoch', ylabel='Accuracy (%)')
ax_acc.legend()

fig.tight_layout()
plt.show()

In [None]:
# Test: evaluate the trained model on the held-out test split and collect every
# (label, prediction) pair for the confusion matrix.
model.eval()
correct = 0
total = 0
label_batches = []
prediction_batches = []

with torch.no_grad():
    for images, labels in test_loader:
        # Inputs must be on the same device as the model's weights.
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # Keep whole batches on the CPU; concatenated once after the loop.
        label_batches.append(labels.cpu())
        prediction_batches.append(predicted.cpu())

# 1-D numpy arrays of integer labels / predictions, in test_loader order.
all_labels = torch.cat(label_batches).numpy()
all_predictions = torch.cat(prediction_batches).numpy()

print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f} %')

In [None]:
# Compute the confusion matrix: rows are true labels, columns are predicted
# labels, both in label order 0..len(categories)-1.
cm = confusion_matrix(all_labels, all_predictions, labels=list(range(len(categories))))
print(cm)

fig, ax = plt.subplots(figsize=(6, 5))
# annot=True writes the count into each cell; fmt='d' (integer format) only
# takes effect when annotations are drawn.
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=categories,
    yticklabels=categories,
    ax=ax,
)
ax.set(xlabel='Predicted Labels', ylabel='True Labels', title='Confusion Matrix')
plt.show()

In [None]:
# Plot 12 random test images with their true and predicted classes.
num_samples = 12
indices = torch.randperm(len(test_dataset))[:num_samples]
sampler = SubsetRandomSampler(indices)
test_loader_plot = DataLoader(test_dataset, batch_size=num_samples, sampler=sampler)

# Get one batch of data and move it to the same device as the model.
images, labels = next(iter(test_loader_plot))
images = images.to(device)
labels = labels.to(device)

model.eval()
with torch.no_grad():
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

# The preprocessing transform normalised each channel with mean 0.5 / std 0.5,
# so pixels are in [-1, 1]. to_pil_image scales floats by 255 before converting
# to uint8, which mis-renders negative values. Undo the normalisation and clamp
# to [0, 1] first.
NORM_MEAN, NORM_STD = 0.5, 0.5

fig, axes = plt.subplots(3, 4, figsize=(12, 9))
axes = axes.ravel()
for ax in axes:
    ax.axis('off')  # also hides any cells left empty if fewer images are drawn

for ax, img, label, pred in zip(axes, images.cpu(), labels.cpu(), predicted.cpu()):
    img_display = (img * NORM_STD + NORM_MEAN).clamp(0.0, 1.0)
    ax.imshow(to_pil_image(img_display))
    ax.set_title(f'True: {categories[int(label)]}\nPred: {categories[int(pred)]}')

fig.tight_layout()
plt.show()