In [27]:
import visdom

import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

class ViTSpineClassifier(nn.Module):
    def __init__(self, num_classes=3):
        super(ViTSpineClassifier, self).__init__()

        # 1. Load the pre-trained ViT-B/16
        weights = ViT_B_16_Weights.IMAGENET1K_V1
        self.vit = vit_b_16(weights=weights)

        # 2. Freeze all ViT parameters
        for param in self.vit.parameters():
            param.requires_grad = False

        # 3. Replace the classification head
        in_features = self.vit.heads.head.in_features  # 默认 768
        self.vit.heads.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.vit(x)


In [39]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
from PIL import Image
from torch.nn import CrossEntropyLoss
from torch import optim
import pandas as pd
import os
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os

BATCH_SIZE = 32


# ==== path ====
img_dir = 'Dataset Binary'
label_file = 'Dataset_Labels.xlsx'

# ==== read label ====
df = pd.read_excel(label_file)
df.columns = ['Spine_Name', 'Spine_Label']

#To facilitate indexing, we convert the DataFrame to a dict
img_dir = 'Dataset Binary'
train_csv = 'spine_train_split.csv'
test_csv  = 'spine_test_split.csv'

train_df = pd.read_csv(train_csv)
val_df   = pd.read_csv(test_csv)

data_transform = {
    "train": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomRotation(10),
        transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]),
    "val": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
}

class BinarySpineDataset(Dataset):
    def __init__(self, dataframe, root_dir, transform=None):
        self.data = dataframe.reset_index(drop=True)  # 接收 DataFrame 不再读取 csv
        self.root_dir = root_dir
        self.transform = transform

        self.label_map = {
            "Mushroom": 0,
            "Stubby": 1,
            "Thin": 2
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        img_path = os.path.join(self.root_dir, row['Spine_Name'])
        image = Image.open(img_path).convert("RGB")

        label_str = row['Spine_Label']
        label = self.label_map[label_str]   # ← Here convert the string to 0/1/2

        if self.transform:
            image = self.transform(image)

        return image, label


train_dataset = BinarySpineDataset(train_df, img_dir, transform=data_transform["train"])
val_dataset   = BinarySpineDataset(val_df,   img_dir, transform=data_transform["val"])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)

In [40]:
import torch
import visdom
from PIL import Image
from torch.nn import CrossEntropyLoss
from torch import optim

BATCH_SIZE = 32
EPOCH = 100                            # Total number of training rounds
save_path = "./Spine_ViT.pth"    # The location for saving model weight parameters

In [41]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")                         # 创建GPU运算环境
print(device)

cpu


In [44]:
import torch
import torch.nn as nn
import visdom

# ---------------------
# 1. Visdom Initialization window
# ---------------------
viz = visdom.Visdom(env="spine_exp")
viz.line([0], [0], win="test_acc", opts=dict(title="Validation Accuracy"))

# ---------------------
# 2. evalute function
# ---------------------
def evalute(model, loader):
    correct = 0
    total = len(loader.dataset)

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            pred = out.argmax(dim=1)
            correct += (pred == y).float().sum().item()

    return correct / total


# ---------------------
# 3. Initialize the model and optimizer
# ---------------------

criterion = nn.CrossEntropyLoss()
net = ViTSpineClassifier(num_classes=3).to(device)
net.load_state_dict(torch.load("./Spine_ViT.pth", map_location=device))
net.eval()

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, net.parameters()),
    lr=1e-4
)

# ---------------------
# 4. Define the training state variable# ---------------------
best_acc = 0
best_epoch = 0
global_step = 0


# ---------------------
# 5. cycle training
# ---------------------
for epoch in range(EPOCH):
    net.train()

    for step, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        global_step += 1

    # Verification for each epoch
    val_acc = evalute(net, val_loader)
    print(f"epoch {epoch+1}/{EPOCH}, val_acc = {val_acc:.4f}")

    # record to visdom
    viz.line([val_acc], [epoch+1], win="test_acc", update="append")

    # keep best model
    if val_acc > best_acc:
        best_acc = val_acc
        best_epoch = epoch + 1
        torch.save(net.state_dict(), save_path)
        print(f"  >>> Best model updated at epoch {best_epoch}! Acc = {best_acc:.4f}")

print("Training Finished!")
print(f"Best epoch = {best_epoch}, best_acc = {best_acc:.4f}")


Setting up a new session...
  net.load_state_dict(torch.load("./Spine_ViT.pth", map_location=device))


epoch 1/100, val_acc = 0.9457
  >>> Best model updated at epoch 1! Acc = 0.9457
epoch 2/100, val_acc = 0.9348
epoch 3/100, val_acc = 0.9348
epoch 4/100, val_acc = 0.9457
epoch 5/100, val_acc = 0.9457
epoch 6/100, val_acc = 0.9457
epoch 7/100, val_acc = 0.9457
epoch 8/100, val_acc = 0.9457
epoch 9/100, val_acc = 0.9457
epoch 10/100, val_acc = 0.9457
epoch 11/100, val_acc = 0.9457
epoch 12/100, val_acc = 0.9457
epoch 13/100, val_acc = 0.9457
epoch 14/100, val_acc = 0.9457
epoch 15/100, val_acc = 0.9565
  >>> Best model updated at epoch 15! Acc = 0.9565
epoch 16/100, val_acc = 0.9457
epoch 17/100, val_acc = 0.9457
epoch 18/100, val_acc = 0.9457
epoch 19/100, val_acc = 0.9457
epoch 20/100, val_acc = 0.9457
epoch 21/100, val_acc = 0.9457
epoch 22/100, val_acc = 0.9457
epoch 23/100, val_acc = 0.9457
epoch 24/100, val_acc = 0.9239
epoch 25/100, val_acc = 0.9239
epoch 26/100, val_acc = 0.9457
epoch 27/100, val_acc = 0.9457
epoch 28/100, val_acc = 0.9457
epoch 29/100, val_acc = 0.9457
epoch 30/