In [1]:
from split import split_data, is_dir_empty

# Dataset layout: the raw source images and the folders each split is written to.
SRC_MINE_DIR = './dataset/mine/images'
TRAINING_DIR = "./dataset/mine/training/"
VALIDATION_DIR = "./dataset/mine/validation/"
TEST_DIR = "./dataset/mine/test/"  # not used by the split call below (test_dir=None)

# Probe both target folders once; the results drive the branch and the messages.
train_dir_empty = is_dir_empty(TRAINING_DIR)
val_dir_empty = is_dir_empty(VALIDATION_DIR)

if not (train_dir_empty and val_dir_empty):
    # At least one target folder already has content: say which one(s) and
    # leave the existing split untouched.
    if not train_dir_empty:
        print(f"❌ Training directory ({TRAINING_DIR}) is not empty!")
    if not val_dir_empty:
        print(f"❌ Validation directory ({VALIDATION_DIR}) is not empty!")
    print("\nSkip data split. Please empty the directories first if you want to re-run the split.")
else:
    # Both folders are empty, so the split can be (re)generated from the source.
    # NOTE(review): split_ratio=0.8 is presumably the training fraction —
    # confirm against split.split_data.
    print("Training and Validation directories are empty. Starting data split...\n")
    split_data(SRC_MINE_DIR, TRAINING_DIR, VALIDATION_DIR,
               test_dir=None, include_test_split=False, split_ratio=0.8)
    

❌ Training directory (./dataset/mine/training/) is not empty!
❌ Validation directory (./dataset/mine/validation/) is not empty!

Skip data split. Please empty the directories first if you want to re-run the split.


In [2]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms

# Loader settings shared by the training and validation splits.
BATCH_SIZE = 64
NUM_WORKERS = 15
PIN_MEMORY = True

# Both splits currently receive the same preprocessing: resize every image to
# 224x224 and convert it to a tensor. They are kept as two separate objects so
# that either pipeline can later be changed independently.
train_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
validation_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# ImageFolder is pointed at each split's root directory.
train_dataset = ImageFolder(TRAINING_DIR, transform=train_transform)
validation_dataset = ImageFolder(VALIDATION_DIR, transform=validation_transform)

# One options mapping guarantees that both loaders are configured identically.
# NOTE(review): the validation loader is shuffled as well; confirm this is
# intentional (e.g. to get mixed-class batches for visualisation).
_loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

train_loader = DataLoader(train_dataset, **_loader_options)
validation_loader = DataLoader(validation_dataset, **_loader_options)

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import math

# Summary of the loaded datasets: sample counts per split, the class names and
# the class-name -> index mapping, emitted one item per line.
print(
    "=== 数据集详情 ===",
    f"训练集数量: {len(train_dataset)}",
    f"验证集数量: {len(validation_dataset)}",
    f"分类类别: {train_dataset.classes}",
    f"类别索引: {train_dataset.class_to_idx}",
    sep="\n",
)

def show_batch_with_labels(loader, dataset_classes, title="Batch"):
    """Plot the first batch produced by ``loader`` as a labelled image grid.

    Args:
        loader: Iterable whose first item is an ``(images, labels)`` pair.
            Each ``images[i]`` must be a channel-first tensor, because it is
            transposed to channel-last before being handed to ``imshow``.
        dataset_classes: Sequence mapping a class index to its display name.
        title: Figure title, also used in the printed batch header.

    Returns:
        None. The tensor shapes are printed and the figure is shown with
        ``plt.show()``.

    Raises:
        StopIteration: If ``loader`` yields no batch at all.
    """
    # Only the first batch is inspected.
    images, labels = next(iter(loader))
    batch_size = len(images)

    print(f"\n--- {title} 信息 ---")
    print(f"图片 Tensor 形状: {images.shape}")
    print(f"标签 Tensor 形状: {labels.shape}")

    cols = 8
    # Always allocate at least one row so subplots() never gets a zero-row grid.
    rows = max(1, math.ceil(batch_size / cols))
    # squeeze=False makes subplots() return a 2-D array for every grid size.
    # The previous special case wrapped the (already array-valued) result in a
    # list when batch_size == 1, so the loop received an ndarray instead of an
    # Axes and failed on ax.imshow().
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 2.5, rows * 3), squeeze=False
    )
    fig.suptitle(title, fontsize=16)

    for i, ax in enumerate(axes.flatten()):
        if i < batch_size:
            # detach()/cpu() keep this working for tensors that live on an
            # accelerator or track gradients; CHW -> HWC is what imshow expects.
            img = images[i].detach().cpu().numpy().transpose((1, 2, 0))
            ax.imshow(img)

            # Convert the label to a plain int before indexing the class list.
            label_name = dataset_classes[int(labels[i])]

            ax.set_xlabel(label_name, fontsize=11, fontweight='bold')
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # Grid cells beyond the batch size stay blank.
            ax.axis('off')

    plt.tight_layout()
    # Leave room at the top for the figure title.
    plt.subplots_adjust(top=0.92)
    plt.show()

# show_batch_with_labels(train_loader, train_dataset.classes, "Training")
# show_batch_with_labels(validation_loader, validation_dataset.classes, "Validation")

=== 数据集详情 ===
训练集数量: 4510
验证集数量: 1130
分类类别: ['biotite', 'bornite', 'chrysocolla', 'malachite', 'muscovite', 'pyrite', 'quartz']
类别索引: {'biotite': 0, 'bornite': 1, 'chrysocolla': 2, 'malachite': 3, 'muscovite': 4, 'pyrite': 5, 'quartz': 6}


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from model import MineralCNN, fit_model

# 0. Seed the RNG so weight initialisation and the loaders' shuffle order are
#    repeatable between runs; without it every execution of this cell trains a
#    differently initialised model on a differently ordered stream.
SEED = 42
torch.manual_seed(SEED)

# 1. Select the compute device (GPU when available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 2. Dataset metadata: the class names determine the size of the output layer
class_names = train_dataset.classes
num_classes = len(class_names)

# 3. Instantiate the baseline model on the selected device
cnn_model = MineralCNN(num_classes=num_classes).to(device)

# 4. Training configuration: cross-entropy loss, Adam with initial lr 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)

# 5. Train (class_names is forwarded to fit_model)
history = fit_model(
    cnn_model,
    train_loader,
    validation_loader,
    criterion,
    optimizer,
    num_epochs=50,  # adjust as needed
    device=device,
    class_names=class_names  # used for the confusion-matrix labels
)

Using device: cuda
Training started. Results will be saved to: runs/train6


                                                                    

Epoch [1/50] | T_Loss: 1.5338 T_Acc: 43.61% | V_Loss: 1.3462 V_Acc: 49.56% | LR: 0.001000


                                                                    

Epoch [2/50] | T_Loss: 1.3006 T_Acc: 53.66% | V_Loss: 1.5110 V_Acc: 46.02% | LR: 0.001000


                                                                    

Epoch [3/50] | T_Loss: 1.2615 T_Acc: 54.94% | V_Loss: 1.3807 V_Acc: 48.58% | LR: 0.001000


                                                                    

Epoch [4/50] | T_Loss: 1.1837 T_Acc: 58.05% | V_Loss: 1.2699 V_Acc: 55.66% | LR: 0.001000


                                                                    

Epoch [5/50] | T_Loss: 1.1320 T_Acc: 59.96% | V_Loss: 1.1698 V_Acc: 58.50% | LR: 0.001000


                                                                    

Epoch [6/50] | T_Loss: 1.0813 T_Acc: 62.37% | V_Loss: 1.5471 V_Acc: 51.15% | LR: 0.001000


                                                                     

Epoch [7/50] | T_Loss: 0.9848 T_Acc: 65.52% | V_Loss: 1.1849 V_Acc: 58.23% | LR: 0.001000


                                                                     

Epoch [8/50] | T_Loss: 0.9277 T_Acc: 67.94% | V_Loss: 1.1865 V_Acc: 60.62% | LR: 0.001000


                                                                     

Epoch [9/50] | T_Loss: 0.8085 T_Acc: 72.66% | V_Loss: 0.9809 V_Acc: 67.17% | LR: 0.001000


                                                                     

Epoch [10/50] | T_Loss: 0.7006 T_Acc: 76.85% | V_Loss: 1.2352 V_Acc: 56.81% | LR: 0.001000


                                                                     

Epoch [11/50] | T_Loss: 0.5914 T_Acc: 80.04% | V_Loss: 0.9630 V_Acc: 68.32% | LR: 0.001000


                                                                     

Epoch [12/50] | T_Loss: 0.5254 T_Acc: 82.26% | V_Loss: 1.2632 V_Acc: 60.09% | LR: 0.001000


                                                                     

Epoch [13/50] | T_Loss: 0.4456 T_Acc: 85.23% | V_Loss: 1.0295 V_Acc: 68.85% | LR: 0.001000


                                                                     

Epoch [14/50] | T_Loss: 0.3855 T_Acc: 87.54% | V_Loss: 0.5797 V_Acc: 82.74% | LR: 0.001000


                                                                     

Epoch [15/50] | T_Loss: 0.3246 T_Acc: 89.29% | V_Loss: 0.7004 V_Acc: 77.08% | LR: 0.001000


                                                                     

Epoch [16/50] | T_Loss: 0.2644 T_Acc: 90.80% | V_Loss: 1.1507 V_Acc: 64.96% | LR: 0.001000


                                                                     

Epoch [17/50] | T_Loss: 0.2664 T_Acc: 91.13% | V_Loss: 0.5473 V_Acc: 85.22% | LR: 0.001000


                                                                     

Epoch [18/50] | T_Loss: 0.1937 T_Acc: 93.77% | V_Loss: 0.5449 V_Acc: 85.66% | LR: 0.001000


                                                                     

Epoch [19/50] | T_Loss: 0.1686 T_Acc: 94.24% | V_Loss: 0.6255 V_Acc: 86.28% | LR: 0.001000


                                                                     

Epoch [20/50] | T_Loss: 0.1682 T_Acc: 94.50% | V_Loss: 0.4985 V_Acc: 87.61% | LR: 0.001000


                                                                     

Epoch [21/50] | T_Loss: 0.1501 T_Acc: 95.03% | V_Loss: 0.4944 V_Acc: 88.85% | LR: 0.001000


                                                                     

Epoch [22/50] | T_Loss: 0.1491 T_Acc: 95.25% | V_Loss: 0.7678 V_Acc: 78.23% | LR: 0.001000


                                                                      

Epoch [23/50] | T_Loss: 0.1214 T_Acc: 96.12% | V_Loss: 0.5008 V_Acc: 88.41% | LR: 0.001000


                                                                      

Epoch [24/50] | T_Loss: 0.0988 T_Acc: 96.76% | V_Loss: 0.6099 V_Acc: 86.81% | LR: 0.001000


                                                                      

Epoch [25/50] | T_Loss: 0.0794 T_Acc: 97.69% | V_Loss: 0.4862 V_Acc: 89.82% | LR: 0.001000


                                                                      

Epoch [26/50] | T_Loss: 0.0819 T_Acc: 97.56% | V_Loss: 0.6341 V_Acc: 85.84% | LR: 0.001000


                                                                      

Epoch [27/50] | T_Loss: 0.0686 T_Acc: 97.83% | V_Loss: 0.5889 V_Acc: 88.23% | LR: 0.001000


                                                                      

Epoch [28/50] | T_Loss: 0.0877 T_Acc: 97.01% | V_Loss: 0.6501 V_Acc: 85.66% | LR: 0.001000


                                                                     

Epoch [29/50] | T_Loss: 0.1195 T_Acc: 95.81% | V_Loss: 0.5671 V_Acc: 87.08% | LR: 0.001000


                                                                     

Epoch [30/50] | T_Loss: 0.1318 T_Acc: 95.43% | V_Loss: 1.0078 V_Acc: 76.19% | LR: 0.001000


                                                                      

Epoch [31/50] | T_Loss: 0.0941 T_Acc: 96.81% | V_Loss: 0.6453 V_Acc: 85.84% | LR: 0.000500


                                                                      

Epoch [32/50] | T_Loss: 0.0600 T_Acc: 97.83% | V_Loss: 0.5118 V_Acc: 88.58% | LR: 0.000500


                                                                      

Epoch [33/50] | T_Loss: 0.0386 T_Acc: 98.58% | V_Loss: 0.4969 V_Acc: 88.94% | LR: 0.000500


                                                                      

Epoch [34/50] | T_Loss: 0.0361 T_Acc: 98.56% | V_Loss: 0.4733 V_Acc: 89.73% | LR: 0.000500


                                                                      

Epoch [35/50] | T_Loss: 0.0364 T_Acc: 98.40% | V_Loss: 0.5104 V_Acc: 89.47% | LR: 0.000500


                                                                      

Epoch [36/50] | T_Loss: 0.0317 T_Acc: 98.71% | V_Loss: 0.5017 V_Acc: 90.27% | LR: 0.000500


                                                                      

Epoch [37/50] | T_Loss: 0.0326 T_Acc: 98.65% | V_Loss: 0.5067 V_Acc: 89.73% | LR: 0.000500


                                                                      

Epoch [38/50] | T_Loss: 0.0316 T_Acc: 98.82% | V_Loss: 0.5093 V_Acc: 89.82% | LR: 0.000500


                                                                      

Epoch [39/50] | T_Loss: 0.0296 T_Acc: 98.60% | V_Loss: 0.5250 V_Acc: 89.47% | LR: 0.000500


                                                                       

Epoch [40/50] | T_Loss: 0.0325 T_Acc: 98.69% | V_Loss: 0.5208 V_Acc: 89.65% | LR: 0.000250


                                                                      

Epoch [41/50] | T_Loss: 0.0273 T_Acc: 98.78% | V_Loss: 0.5099 V_Acc: 89.91% | LR: 0.000250


                                                                      

Epoch [42/50] | T_Loss: 0.0277 T_Acc: 98.71% | V_Loss: 0.5074 V_Acc: 90.09% | LR: 0.000250


                                                                      

Epoch [43/50] | T_Loss: 0.0267 T_Acc: 98.71% | V_Loss: 0.5110 V_Acc: 90.00% | LR: 0.000250


                                                                      

Epoch [44/50] | T_Loss: 0.0271 T_Acc: 98.71% | V_Loss: 0.5104 V_Acc: 90.18% | LR: 0.000250


                                                                      

Epoch [45/50] | T_Loss: 0.0266 T_Acc: 98.76% | V_Loss: 0.5029 V_Acc: 89.91% | LR: 0.000250


                                                                      

Epoch [46/50] | T_Loss: 0.0283 T_Acc: 98.80% | V_Loss: 0.5162 V_Acc: 90.00% | LR: 0.000125


                                                                      

Epoch [47/50] | T_Loss: 0.0260 T_Acc: 98.63% | V_Loss: 0.5074 V_Acc: 90.09% | LR: 0.000125


                                                                      

Epoch [48/50] | T_Loss: 0.0242 T_Acc: 98.89% | V_Loss: 0.5055 V_Acc: 90.35% | LR: 0.000125


                                                                      

Epoch [49/50] | T_Loss: 0.0243 T_Acc: 98.80% | V_Loss: 0.5172 V_Acc: 89.73% | LR: 0.000125


                                                                      

Epoch [50/50] | T_Loss: 0.0264 T_Acc: 98.78% | V_Loss: 0.5065 V_Acc: 90.62% | LR: 0.000125

 Training finished. Generatring reports...
Loading best model for evaluation...
Saving confusion matrix...
Saving prediction samples...
 All done! Check directory: runs/train6


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import importlib
import model
import modelCBAM

# Reload the project modules so that edits made on disk are picked up without
# restarting the kernel. The names used below come from modelCBAM, so that
# module must be reloaded too; reloading only `model` left a stale modelCBAM in
# place. `model` is reloaded first in case modelCBAM depends on it.
importlib.reload(model)
importlib.reload(modelCBAM)
# NOTE: this rebinds `fit_model` to modelCBAM's implementation, shadowing the
# one imported from `model` by the baseline training cell.
from modelCBAM import MineralCBAMResNet, fit_model

# 0. Seed the RNG so weight initialisation and the loaders' shuffle order are
#    repeatable between runs of this cell.
SEED = 42
torch.manual_seed(SEED)

# 1. Select the compute device (GPU when available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 2. Dataset metadata: the class names determine the size of the output layer
class_names = train_dataset.classes
num_classes = len(class_names)

# 3. Instantiate the CBAM-ResNet model (note: MineralCBAMResNet, not MineralCNN)
cbam_model = MineralCBAMResNet(num_classes=num_classes).to(device)

# 4. Training configuration: same loss and initial learning rate (1e-3) as the
#    baseline cell
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cbam_model.parameters(), lr=0.001)

# 5. Train for 50 epochs, matching the baseline run, for comparison
print(">>> 开始训练 CBAM-ResNet 创新架构 <<<")
history_cbam = fit_model(
    cbam_model,
    train_loader,
    validation_loader,
    criterion,
    optimizer,
    num_epochs=50,
    device=device,
    class_names=class_names
)

Using device: cuda
>>> 开始训练 CBAM-ResNet 创新架构 <<<
Training started (CBAM-ResNet). Results will be saved to: runs/train_cbam


                                                                    

Epoch [1/50] | T_Loss: 1.3937 T_Acc: 47.85% | V_Loss: 1.6074 V_Acc: 46.37% | LR: 0.001000


                                                                    

Epoch [2/50] | T_Loss: 1.1708 T_Acc: 57.80% | V_Loss: 1.9923 V_Acc: 38.32% | LR: 0.001000


                                                                    

Epoch [3/50] | T_Loss: 1.0325 T_Acc: 63.15% | V_Loss: 1.4878 V_Acc: 41.77% | LR: 0.001000


                                                                     

Epoch [4/50] | T_Loss: 0.8609 T_Acc: 70.00% | V_Loss: 2.4926 V_Acc: 34.78% | LR: 0.001000


                                                                     

Epoch [5/50] | T_Loss: 0.7290 T_Acc: 75.65% | V_Loss: 1.8637 V_Acc: 45.84% | LR: 0.001000


                                                                     

Epoch [6/50] | T_Loss: 0.5603 T_Acc: 81.37% | V_Loss: 1.0500 V_Acc: 68.23% | LR: 0.001000


                                                                     

Epoch [7/50] | T_Loss: 0.4633 T_Acc: 84.75% | V_Loss: 0.6254 V_Acc: 81.95% | LR: 0.001000


                                                                     

Epoch [8/50] | T_Loss: 0.3716 T_Acc: 88.03% | V_Loss: 0.8395 V_Acc: 74.78% | LR: 0.001000


                                                                     

Epoch [9/50] | T_Loss: 0.3082 T_Acc: 89.49% | V_Loss: 0.6973 V_Acc: 79.82% | LR: 0.001000


                                                                     

Epoch [10/50] | T_Loss: 0.2475 T_Acc: 91.46% | V_Loss: 0.9393 V_Acc: 70.88% | LR: 0.001000


                                                                     

Epoch [11/50] | T_Loss: 0.2148 T_Acc: 92.97% | V_Loss: 0.9597 V_Acc: 74.42% | LR: 0.001000


                                                                     

Epoch [12/50] | T_Loss: 0.1613 T_Acc: 94.63% | V_Loss: 0.9943 V_Acc: 76.11% | LR: 0.000500


                                                                     

Epoch [13/50] | T_Loss: 0.1014 T_Acc: 96.83% | V_Loss: 0.4074 V_Acc: 90.35% | LR: 0.000500


                                                                      

Epoch [14/50] | T_Loss: 0.0596 T_Acc: 98.12% | V_Loss: 0.4230 V_Acc: 90.18% | LR: 0.000500


                                                                      

Epoch [15/50] | T_Loss: 0.0486 T_Acc: 98.38% | V_Loss: 0.4612 V_Acc: 89.12% | LR: 0.000500


                                                                      

Epoch [16/50] | T_Loss: 0.0457 T_Acc: 98.63% | V_Loss: 0.4364 V_Acc: 90.18% | LR: 0.000500


                                                                      

Epoch [17/50] | T_Loss: 0.0500 T_Acc: 98.12% | V_Loss: 0.4445 V_Acc: 90.18% | LR: 0.000500


                                                                      

Epoch [18/50] | T_Loss: 0.0468 T_Acc: 98.18% | V_Loss: 0.4561 V_Acc: 89.91% | LR: 0.000250


                                                                      

Epoch [19/50] | T_Loss: 0.0394 T_Acc: 98.49% | V_Loss: 0.4555 V_Acc: 90.35% | LR: 0.000250


                                                                      

Epoch [20/50] | T_Loss: 0.0341 T_Acc: 98.56% | V_Loss: 0.4281 V_Acc: 90.71% | LR: 0.000250


                                                                      

Epoch [21/50] | T_Loss: 0.0380 T_Acc: 98.47% | V_Loss: 0.4478 V_Acc: 90.44% | LR: 0.000250


                                                                      

Epoch [22/50] | T_Loss: 0.0309 T_Acc: 98.80% | V_Loss: 0.4313 V_Acc: 90.97% | LR: 0.000250


                                                                      

Epoch [23/50] | T_Loss: 0.0308 T_Acc: 98.58% | V_Loss: 0.4400 V_Acc: 90.44% | LR: 0.000125


                                                                      

Epoch [24/50] | T_Loss: 0.0281 T_Acc: 98.67% | V_Loss: 0.4332 V_Acc: 90.44% | LR: 0.000125


                                                                      

Epoch [25/50] | T_Loss: 0.0269 T_Acc: 98.87% | V_Loss: 0.4437 V_Acc: 90.71% | LR: 0.000125


                                                                      

Epoch [26/50] | T_Loss: 0.0264 T_Acc: 98.89% | V_Loss: 0.4540 V_Acc: 90.71% | LR: 0.000125


                                                                      

Epoch [27/50] | T_Loss: 0.0279 T_Acc: 98.80% | V_Loss: 0.4401 V_Acc: 90.53% | LR: 0.000125


                                                                      

Epoch [28/50] | T_Loss: 0.0268 T_Acc: 98.87% | V_Loss: 0.4377 V_Acc: 90.09% | LR: 0.000063


                                                                      

Epoch [29/50] | T_Loss: 0.0261 T_Acc: 98.87% | V_Loss: 0.4615 V_Acc: 90.80% | LR: 0.000063


                                                                      

Epoch [30/50] | T_Loss: 0.0258 T_Acc: 98.82% | V_Loss: 0.4613 V_Acc: 90.88% | LR: 0.000063


                                                                      

Epoch [31/50] | T_Loss: 0.0237 T_Acc: 98.96% | V_Loss: 0.4393 V_Acc: 90.62% | LR: 0.000063


                                                                      

Epoch [32/50] | T_Loss: 0.0256 T_Acc: 98.78% | V_Loss: 0.4433 V_Acc: 90.35% | LR: 0.000063


                                                                      

Epoch [33/50] | T_Loss: 0.0263 T_Acc: 98.80% | V_Loss: 0.4407 V_Acc: 90.44% | LR: 0.000031


                                                                      

Epoch [34/50] | T_Loss: 0.0242 T_Acc: 98.85% | V_Loss: 0.4554 V_Acc: 90.53% | LR: 0.000031


                                                                      

Epoch [35/50] | T_Loss: 0.0235 T_Acc: 98.94% | V_Loss: 0.4412 V_Acc: 90.71% | LR: 0.000031


                                                                      

Epoch [36/50] | T_Loss: 0.0247 T_Acc: 98.98% | V_Loss: 0.4561 V_Acc: 90.35% | LR: 0.000031


                                                                      

Epoch [37/50] | T_Loss: 0.0240 T_Acc: 98.78% | V_Loss: 0.4520 V_Acc: 90.62% | LR: 0.000031


                                                                      

Epoch [38/50] | T_Loss: 0.0255 T_Acc: 98.80% | V_Loss: 0.4449 V_Acc: 90.53% | LR: 0.000016


                                                                      

Epoch [39/50] | T_Loss: 0.0250 T_Acc: 98.96% | V_Loss: 0.4539 V_Acc: 90.27% | LR: 0.000016


                                                                      

Epoch [40/50] | T_Loss: 0.0238 T_Acc: 98.87% | V_Loss: 0.4652 V_Acc: 90.62% | LR: 0.000016


                                                                      

Epoch [41/50] | T_Loss: 0.0240 T_Acc: 98.85% | V_Loss: 0.4568 V_Acc: 90.44% | LR: 0.000016


                                                                      

Epoch [42/50] | T_Loss: 0.0233 T_Acc: 99.05% | V_Loss: 0.4618 V_Acc: 90.71% | LR: 0.000016


                                                                      

Epoch [43/50] | T_Loss: 0.0239 T_Acc: 98.89% | V_Loss: 0.4538 V_Acc: 90.62% | LR: 0.000008


                                                                      

Epoch [44/50] | T_Loss: 0.0236 T_Acc: 98.87% | V_Loss: 0.4532 V_Acc: 90.80% | LR: 0.000008


                                                                       

Epoch [45/50] | T_Loss: 0.0226 T_Acc: 99.05% | V_Loss: 0.4532 V_Acc: 90.80% | LR: 0.000008


                                                                      

Epoch [46/50] | T_Loss: 0.0226 T_Acc: 99.00% | V_Loss: 0.4585 V_Acc: 90.62% | LR: 0.000008


                                                                      

Epoch [47/50] | T_Loss: 0.0243 T_Acc: 98.80% | V_Loss: 0.4527 V_Acc: 90.53% | LR: 0.000008


                                                                      

Epoch [48/50] | T_Loss: 0.0235 T_Acc: 99.00% | V_Loss: 0.4684 V_Acc: 90.71% | LR: 0.000004


                                                                      

Epoch [49/50] | T_Loss: 0.0231 T_Acc: 98.89% | V_Loss: 0.4506 V_Acc: 90.71% | LR: 0.000004


                                                                      

Epoch [50/50] | T_Loss: 0.0233 T_Acc: 98.89% | V_Loss: 0.4568 V_Acc: 90.53% | LR: 0.000004

 Training finished. Generatring reports...
Loading best model for evaluation...
Saving confusion matrix...
 All done! Check directory: runs/train_cbam
