In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import datetime
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load CSV
data_csv = "selected_for_annotation_mt_strcture_20250310_MA.csv"
image_dir = "/mnt/d/lding/CLS/mousumiLiuDinner/set1to5_processed_results/Microtubule_GUV-Liu-20250106T211105Z-001/processed_MT/GUV-MT_obj_tiff_selected_std15/improved_png"
df = pd.read_csv(data_csv)
df = df.dropna(subset=["filename", "label"])
df["label"] = df["label"].astype(str)

# Split into Train (80%) & Validation (20%)
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df["label"], random_state=42)

# Define Transforms
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

val_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Custom Dataset Class
class CustomDataset(Dataset):
    def __init__(self, dataframe, root_dir, transform=None):
        self.data = dataframe.copy()
        self.root_dir = root_dir
        self.transform = transform
        self.label_to_index = {label: idx for idx, label in enumerate(sorted(self.data["label"].unique()))}
        self.data["label"] = self.data["label"].map(self.label_to_index)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.data.iloc[idx, 0])
        image = Image.open(img_name).convert("RGB")
        label = int(self.data.iloc[idx, 1])
        if self.transform:
            image = self.transform(image)
        return image, label

# Create Datasets & DataLoaders
train_dataset = CustomDataset(train_df, image_dir, transform=train_transform)
val_dataset = CustomDataset(val_df, image_dir, transform=val_transform)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Function to initialize different models
def initialize_model(model_name, num_classes):
    if model_name == "efficientnet":
        model = models.efficientnet_b0(pretrained=True)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "resnet":
        model = models.resnet18(pretrained=True)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
    elif model_name == "convnext":
        model = models.convnext_tiny(pretrained=True)
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
    else:
        raise ValueError("Unsupported model type")
    return model.to(device)

# Function to save confusion matrix and loss curves
def save_plots(train_losses, val_losses, train_accuracies, val_accuracies, model_dir):
    plt.figure(figsize=(10,5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss Curve')
    plt.savefig(f"{model_dir}/loss_curve.png")
    plt.close()
    
    plt.figure(figsize=(10,5))
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Accuracy Curve')
    plt.savefig(f"{model_dir}/accuracy_curve.png")
    plt.close()
    
# Function to plot confusion matrices
def plot_confusion_matrix(model, dataloader, dataset_name, model_dir):
    model.eval()
    true_labels, pred_labels = [], []
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            true_labels.extend(labels.cpu().numpy())
            pred_labels.extend(preds.cpu().numpy())
    
    cm = confusion_matrix(true_labels, pred_labels)
    cm_percentage = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100  # Convert to percentages
    labels = list(train_dataset.label_to_index.keys())
    cm_df = pd.DataFrame(cm_percentage, index=labels, columns=labels)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm_df, annot=True, fmt=".1f", cmap="Blues", linewidths=0.5)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title(f"Confusion Matrix ({dataset_name}) - Percentage")
    plt.savefig(f"{model_dir}/confusion_matrix_{dataset_name}_percent.png")
    plt.close()

    cm_df = pd.DataFrame(cm, index=labels, columns=labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues", linewidths=0.5)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title(f"Confusion Matrix ({dataset_name})")
    plt.savefig(f"{model_dir}/confusion_matrix_{dataset_name}_no.png")
    plt.close()

# Run training function with added confusion matrix plotting
def train_model(model_name, learning_rate=0.0003, weight_decay=1e-4, num_epochs=50, patience=10):
    model = initialize_model(model_name, len(train_dataset.label_to_index))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
    patience_counter = 0

    best_val_acc = 0
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model_dir = f"saved_models/{model_name}_lr{learning_rate}_wd{weight_decay}_{timestamp}"
    os.makedirs(model_dir, exist_ok=True)
    
    train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []
    
    for epoch in range(num_epochs):
        model.train()
        running_loss, correct_train, total_train = 0.0, 0, 0
        
        for images, labels in train_dataloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct_train += (preds == labels).sum().item()
            total_train += labels.size(0)
        
        train_loss = running_loss / len(train_dataloader)
        train_acc = correct_train / total_train
        
        # Validation Step
        model.eval()
        running_val_loss, correct_val, total_val = 0.0, 0, 0
        
        with torch.no_grad():
            for images, labels in val_dataloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                running_val_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                correct_val += (preds == labels).sum().item()
                total_val += labels.size(0)
        
        val_loss = running_val_loss / len(val_dataloader)
        val_acc = correct_val / total_val
        
        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Train Acc = {train_acc:.4f}, Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}")
    
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        scheduler.step()

        # Early Stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), f"{model_dir}/best_model.pth")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered. Best Validation Accuracy:", best_val_acc)
                break
    
    save_plots(train_losses, val_losses, train_accuracies, val_accuracies, model_dir)
    plot_confusion_matrix(model, train_dataloader, "Training Set", model_dir)
    plot_confusion_matrix(model, val_dataloader, "Validation Set", model_dir)
    return model_dir

# Run training
model_dir = train_model("convnext")




Epoch 1: Train Loss = 1.4339, Train Acc = 0.3191, Val Loss = 1.2435, Val Acc = 0.4930
Epoch 2: Train Loss = 1.0348, Train Acc = 0.5638, Val Loss = 1.2071, Val Acc = 0.5352
Epoch 3: Train Loss = 0.9820, Train Acc = 0.5922, Val Loss = 1.2539, Val Acc = 0.3944
Epoch 4: Train Loss = 0.8200, Train Acc = 0.6915, Val Loss = 0.9965, Val Acc = 0.5634
Epoch 5: Train Loss = 0.6368, Train Acc = 0.7482, Val Loss = 1.0345, Val Acc = 0.6056
Epoch 6: Train Loss = 0.5026, Train Acc = 0.8298, Val Loss = 1.3209, Val Acc = 0.5493
Epoch 7: Train Loss = 0.4328, Train Acc = 0.8333, Val Loss = 1.1639, Val Acc = 0.6197
Epoch 8: Train Loss = 0.2395, Train Acc = 0.9326, Val Loss = 1.2754, Val Acc = 0.6479
Epoch 9: Train Loss = 0.2202, Train Acc = 0.9220, Val Loss = 1.7112, Val Acc = 0.6056
Epoch 10: Train Loss = 0.2343, Train Acc = 0.9113, Val Loss = 1.6136, Val Acc = 0.6056
Epoch 11: Train Loss = 0.1195, Train Acc = 0.9645, Val Loss = 1.7290, Val Acc = 0.5915
Epoch 12: Train Loss = 0.0613, Train Acc = 0.9858, V