In [None]:
import os
import copy
import time

import torch
import torch.nn as nn
import torchvision

from torchvision import datasets, models
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
# from torchsummary import summary

from sklearn.metrics import classification_report, confusion_matrix

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [None]:
# Const variables

# Dataset root
DATASET_DIR_ROOT = "./dataset"

# Number of epochs
n_epochs = 80

# Input images size
image_size = 224

# Number of batches
batch_size = 32

# Number of workers for dataloaders
num_workers = 4

# **DATA LOADING**

In [None]:
# Define transforms for each dataset separately
transform = transforms.Compose([
    transforms.Resize(256),  # 调整大小
    transforms.CenterCrop(224),  # 中心裁剪
    transforms.ToTensor(),  # 转为Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])

In [None]:
# Load datasets
train_set = datasets.ImageFolder(
    root=os.path.join(DATASET_DIR_ROOT, "train"),
    transform=transform
)

validation_set = datasets.ImageFolder(
    root=os.path.join(DATASET_DIR_ROOT, "validation"),
    transform=transform
)

test_set = datasets.ImageFolder(
    root=os.path.join(DATASET_DIR_ROOT, "test"),
    transform=transform
)

In [None]:
# Instantiate DataLoaders

train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers
)

validation_loader = DataLoader(
    dataset=validation_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

test_loader = DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

# **DATA VISUALIZATION** 


In [None]:
sns.set_style('darkgrid')

In [None]:
def grid_display(dataloader):
    """Plots a single batch of a dataloader. Denormalizes images for better visualization.

    :param dataloader: a DataLoader object that we want to display its images
    """
    for images, labels in dataloader:
        fig, ax = plt.subplots(figsize = (16,12))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(make_grid(images, nrow=8).permute(1,2,0))
        break
        
# grid_display(train_loader)

In [None]:
def label_distribution(dataset):
    """Counts the number of samples per label(class) in the dataset.
    
    :param dataset: the purpose dataset
    :type dataset: ImageFolder
    """
    encoded_labels = {v:k for k, v in dataset.class_to_idx.items()} # {0: 'cataract', 1: 'diabetic_retinopathy', 2: 'glaucoma', 3: 'normal'}
    labels_count = {k:0 for k in dataset.class_to_idx} # {'cataract': 0, 'diabetic_retinopathy': 0, 'glaucoma': 0, 'normal': 0}

    for label_code in encoded_labels:
        labels_count[encoded_labels[label_code]] = dataset.targets.count(label_code)
    return labels_count

In [None]:
def plot_from_dict(dict_obj: dict, plot_title: str, **kwargs):
    """Plots a bar chart from a dictionry. keys: x_axis, values: y_axis
    
    :param dict_obj: the dictionary that would be plotted
    :param plot_title: title of the plot
    """
    df = pd.DataFrame.from_dict([dict_obj]).melt()
    df.rename(columns={'variable': 'Dataset Labels', 'value': 'Number of samples'}, inplace=True)
    return sns.barplot(
        data=df,
        x="Dataset Labels",
        y="Number of samples",
        hue="Dataset Labels",
        **kwargs
    ).set_title(label=plot_title)

In [None]:
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(10, 20))
plot_from_dict(label_distribution(train_set), plot_title="Train Set", ax=axes[0])
plot_from_dict(label_distribution(validation_set), plot_title="Validation Set", ax=axes[1])
plot_from_dict(label_distribution(test_set), plot_title="Test Set", ax=axes[2])

# **MODEL**

In [None]:
class MyResnet(nn.Module):
    def __init__(self, num_class, frozen=True):
        super(MyResnet, self).__init__()
        self.num_class = num_class
        self.model = models.resnet50(weights='IMAGENET1K_V1')
        num_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, self.num_class)
        )
        if frozen:
            for param in self.model.parameters():
                param.requires_grad = False  # 冻结所有层

            # 只训练最后的全连接层
            for param in self.model.fc.parameters():
                param.requires_grad = True

    def forward(self, x):
        return self.model(x)

In [None]:
# Summary of the architecture of RetinalEnsemble

# summary(EyeSeeNet(4), (3, image_size, image_size))

In [None]:
# Define device : GPU, MPS, or CPU

device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
    
print(device)

In [None]:
# Instanciate and Transfer model on the device

model = MyResnet(4, frozen=True).to(device)

In [None]:
# Loss & Optimizer

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# **TRAIN THE MODEL**

In [None]:
# Train model

train_losses = np.zeros(n_epochs)
val_losses = np.zeros(n_epochs)
best_val_acc = 0.0
best_model_wts = copy.deepcopy(model.state_dict())
since = time.time()

for epoch in range(n_epochs):
    train_corrects = 0
    train_loss = 0.0

    model.train()
    for inputs, targets in tqdm(train_loader, desc=f'Training... Epoch: {epoch + 1}/{n_epochs}'):

        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        _, preds = torch.max(outputs, 1)
        train_corrects += torch.sum(preds == targets.data)
        train_loss += loss.item() * inputs.size(0)

        loss.backward()
        optimizer.step()
            
    train_loss = train_loss / len(train_loader.dataset)
    train_acc = train_corrects / len(train_loader.dataset)
    
    with torch.no_grad():
        val_corrects = 0
        val_loss = 0.0
        model.eval()

        for inputs, targets in tqdm(validation_loader, desc=f'Validating... Epoch: {epoch + 1}/{n_epochs}'):
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            _, preds = torch.max(outputs, 1)
            val_corrects += torch.sum(preds == targets.data)
            val_loss += loss.item() * inputs.size(0)
        
        val_loss = val_loss / len(validation_loader.dataset)
        val_acc = val_corrects / len(validation_loader.dataset)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
    
    # save epoch losses
    train_losses[epoch] = train_loss
    val_losses[epoch] = val_loss
    
    print(f"Epoch {epoch+1}/{n_epochs}:")
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
    print('-'*30)

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
model.load_state_dict(best_model_wts)

In [None]:
plt.plot(train_losses, label="train loss")
plt.plot(val_losses, label="validation loss")
plt.legend()
plt.show()

# **TEST & EVALUATION**



In [18]:
# Calculate Train and Validation Accuracy

def cal_accuracy(data_loader):
    n_correct = 0
    n_total = 0
    
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        _, predictions = torch.max(outputs, 1)

        n_correct += (predictions == targets).sum().item()
        n_total += targets.shape[0]

    accuracy = n_correct / n_total

    return accuracy


print(f"Train Accuracy: {cal_accuracy(train_loader):.4f}, Validation Accuracy: {cal_accuracy(validation_loader):.4f}, Test Accuracy: {cal_accuracy(test_loader):.4f}")

Train Accuracy: 0.9769, Validation Accuracy: 0.8641, Test Accuracy: 0.8756


In [19]:
y_true_list = []
y_pred_list = []

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, predictions = torch.max(outputs, 1)

        y_true_list.append(targets.cpu().numpy())
        y_pred_list.append(predictions.cpu().numpy())

# flatten data of batches into a 1-d list
y_true_list = list(np.concatenate(y_true_list).flat)
y_pred_list = list(np.concatenate(y_pred_list).flat)

In [20]:
# Classification Report

print(classification_report(y_true_list, y_pred_list))

              precision    recall  f1-score   support

           0       0.91      0.92      0.91       156
           1       0.94      0.94      0.94       165
           2       0.85      0.77      0.81       152
           3       0.81      0.87      0.84       162

    accuracy                           0.88       635
   macro avg       0.88      0.87      0.87       635
weighted avg       0.88      0.88      0.88       635


In [None]:
# Confusion Matrix

print(confusion_matrix(y_true_list, y_pred_list))

In [None]:
# Plot the Confusion Matrix

def plot_confusion_matrix(dataset, y_true_list, y_pred_list):
    index_to_label = {v:k for k, v in dataset.class_to_idx.items()}
    confusion_matrix_df = pd.DataFrame(confusion_matrix(y_true_list, y_pred_list)).rename(columns=index_to_label, index=index_to_label)
    fig, ax = plt.subplots(figsize=(14,10))         
    return sns.heatmap(confusion_matrix_df, annot=True, ax=ax)


plot_confusion_matrix(test_set, y_true_list, y_pred_list)

In [None]:
torch.save(model.state_dict(), "myresnet.pth")

# **MISCLASSIFIED SAMPLES**

In [None]:
# Plot some of misclassified instances

n_missclassified = 25

encoded_labels = {v:k for k, v in train_set.class_to_idx.items()}
misclassified_idx = np.where(np.array(y_true_list) != np.array(y_pred_list))[0]

print(f"{n_missclassified} of Misclassified Images:")
fig, axes = plt.subplots(5, 5, figsize=(15, 15))
for i, mis_index in enumerate(misclassified_idx[:n_missclassified]):
    ax = axes.ravel()[i]
    ax.imshow(test_set[mis_index][0].permute(1, 2, 0))
    ax.set_title(f"True: {encoded_labels[y_true_list[mis_index]]}\nPredicted: {encoded_labels[y_pred_list[mis_index]]}")
    ax.axis('off')
plt.subplots_adjust(wspace=0.5, hspace=0.5)
plt.show()