In [None]:
# 1. Environment setup and GPU configuration.
#
# CUDA_VISIBLE_DEVICES is assigned before torch is imported so that only
# physical GPU 6 is visible to this process; inside the process that GPU is
# addressed as device index 0 (hence the `0` arguments used below).
# NOTE(review): this unconditionally overrides any CUDA_VISIBLE_DEVICES value
# already exported in the launching shell.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import SwinForImageClassification, SwinConfig
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from tqdm import tqdm
import json
import glob

# Select the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
if torch.cuda.is_available():
    # Report the name and total memory (GiB) of the first visible GPU.
    print(f"GPU名称: {torch.cuda.get_device_name(0)}")
    print(f"GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Seed torch and NumPy for reproducibility (Python's `random` module is not
# seeded here).
torch.manual_seed(42)
np.random.seed(42)


In [None]:
# 2. Data paths and hyper-parameters.
DATA_PATH = "/remote-home/cs_acmis_hby/Galaxy-Zoo-Classification/Contrast_experiment/Modern CNNs/Galaxy-Classification-Using-CNN/output_dataset"

# Experiment hyper-parameters.
IMG_SIZE = 224
BATCH_SIZE = 16  # Swin Transformer is usually fine-tuned with a small batch size
EPOCHS = 50  # Setting taken from the table in the paper
LEARNING_RATE = 5e-5  # Learning rate recommended for fine-tuning a pretrained model
WEIGHT_DECAY = 0.05

# Class names. The position of a name in this list is the integer label used
# for that class throughout the notebook.
CLASS_NAMES = [
    'barred_spirals',
    'cigar_shaped_elliptical',
    'edge_on',
    'in_between_elliptical',
    'irregular',
    'merger',
    'round_elliptical',
    'unbarred_spirals'
]

# Derive the class count from the label list so the classifier head size can
# never drift out of sync with CLASS_NAMES.
NUM_CLASSES = len(CLASS_NAMES)

print("🎯 实验配置:")
print(f"   数据路径: {DATA_PATH}")
print(f"   图像大小: {IMG_SIZE}x{IMG_SIZE}")
print(f"   批次大小: {BATCH_SIZE}")
print(f"   类别数: {NUM_CLASSES}")
print(f"   学习率: {LEARNING_RATE}")
print(f"   最大轮数: {EPOCHS}")


In [None]:
# 3. Dataset class and data loading.
class GalaxyDataset(Dataset):
    """Image-folder dataset laid out as ``data_dir/<class_name>/<image>``.

    Each sample is ``(image, label)``, where ``label`` is the index of the
    image's class-directory name in ``class_names``. Class directories that do
    not exist are skipped and simply contribute zero samples.

    Args:
        data_dir: Split directory (e.g. ``.../train``) holding one
            sub-directory per class.
        transform: Optional callable applied to each loaded PIL image.
        class_names: Ordered class names; defaults to the module-level
            ``CLASS_NAMES`` when omitted.
        pattern: Glob pattern selecting image files inside each class
            directory (defaults to ``'*.jpg'``).
    """

    def __init__(self, data_dir, transform=None, class_names=None, pattern='*.jpg'):
        self.data_dir = data_dir
        self.transform = transform
        self.class_names = list(CLASS_NAMES if class_names is None else class_names)
        self.images = []
        self.labels = []

        # Collect every image path with its label. glob() yields files in
        # arbitrary filesystem order, so sort to keep the sample order (and
        # therefore seeded shuffling) reproducible across machines. The class
        # directory is escaped so characters like '[' or '*' in the path are
        # not interpreted as wildcards.
        for class_idx, class_name in enumerate(self.class_names):
            class_dir = os.path.join(data_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            matches = glob.glob(os.path.join(glob.escape(class_dir), pattern))
            for img_file in sorted(matches):
                self.images.append(img_file)
                self.labels.append(class_idx)

        # basename(normpath(...)) also handles a trailing path separator, for
        # which split('/')[-1] would yield an empty name.
        split_name = os.path.basename(os.path.normpath(data_dir))
        print(f"📁 {split_name} 数据加载:")
        print(f"   总样本数: {len(self.images)}")

        # Per-class sample counts.
        for class_idx, class_name in enumerate(self.class_names):
            count = self.labels.count(class_idx)
            print(f"   {class_name}: {count}")

    def __len__(self):
        """Return the number of samples found."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The file is decoded as 3-channel RGB, then passed through
        ``self.transform`` when one was supplied.
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager closes the file handle promptly; convert()
        # returns a new, fully loaded image that remains valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# Image transforms. Both pipelines resize to IMG_SIZE x IMG_SIZE and normalise
# with the standard ImageNet channel mean/std.
imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: random geometric and photometric augmentation. Galaxy
# images have no preferred orientation, so arbitrary rotations and both flips
# are valid augmentations.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomRotation(180),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    imagenet_normalize,
])

# Validation / test pipeline: deterministic resize + normalisation only.
val_test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    imagenet_normalize,
])

# Build the three splits from DATA_PATH/{train,val,test}.
print("🔄 创建数据集...")
train_dataset = GalaxyDataset(os.path.join(DATA_PATH, 'train'), transform=train_transform)
val_dataset = GalaxyDataset(os.path.join(DATA_PATH, 'val'), transform=val_test_transform)
test_dataset = GalaxyDataset(os.path.join(DATA_PATH, 'test'), transform=val_test_transform)

# Loaders share batch size and worker count; only training data is shuffled.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print("\n✅ 数据加载器创建完成:")
for _caption, _loader in (
    ("训练批次", train_loader),
    ("验证批次", val_loader),
    ("测试批次", test_loader),
):
    print(f"   {_caption}: {len(_loader)}")


In [None]:
# 4. Load the official pretrained Swin Transformer and set up training objects.
print("🔧 加载官方Swin Transformer模型...")

MODEL_ID = "microsoft/swin-base-patch4-window7-224"

# num_labels requests a NUM_CLASSES-way classification head; because that
# differs from the checkpoint's own head, ignore_mismatched_sizes tells
# from_pretrained to re-initialise the head instead of raising.
model = SwinForImageClassification.from_pretrained(
    MODEL_ID,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)
model = model.to(device)

# Parameter counts (total vs. those with requires_grad=True).
_model_params = list(model.parameters())
total_params = sum(p.numel() for p in _model_params)
trainable_params = sum(p.numel() for p in _model_params if p.requires_grad)

for _line in (
    "✅ 模型加载成功!",
    f"   模型: {MODEL_ID}",
    f"   总参数量: {total_params:,}",
    f"   可训练参数: {trainable_params:,}",
    "   预训练: ImageNet-22K → ImageNet-1K",
    "   论文引用: Liu et al. (2021) - Swin Transformer V1",
):
    print(_line)

# AdamW with decoupled weight decay; the cosine schedule decays the learning
# rate from LEARNING_RATE to eta_min over EPOCHS scheduler steps.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-7)
criterion = nn.CrossEntropyLoss()

for _line in (
    "\n📋 训练配置:",
    "   优化器: AdamW",
    "   学习率调度: Cosine Annealing",
    "   损失函数: CrossEntropyLoss",
):
    print(_line)


In [None]:
# 5. Training and validation helpers.
def train_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader``.

    Every batch is moved to ``device`` and passed through ``model``, whose
    output must expose ``.logits``. One optimizer step is taken per batch. A
    tqdm progress bar shows the running mean batch loss and running accuracy.

    Returns:
        ``(mean_loss, accuracy_percent)``: the mean of the per-batch losses and
        the percentage of correctly classified samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    bar = tqdm(train_loader, desc="Training")
    for step, (inputs, targets) in enumerate(bar, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs).logits
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # Accumulate statistics for the progress bar and the return value.
        loss_sum += loss.item()
        predicted = logits.max(dim=1).indices
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

        bar.set_postfix({
            'Loss': f'{loss_sum / step:.4f}',
            'Acc': f'{100. * n_correct / n_seen:.2f}%',
        })

    mean_loss = loss_sum / len(train_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

@torch.no_grad()
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    The model is switched to eval mode and gradient tracking is disabled for
    the whole call. ``model(inputs)`` must return an object with ``.logits``.

    Returns:
        ``(mean_loss, accuracy_percent)``: the mean of the per-batch losses and
        the percentage of correctly classified samples.
    """
    model.eval()
    batch_losses = []
    n_correct = 0
    n_seen = 0

    for inputs, targets in val_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits = model(inputs).logits

        batch_losses.append(criterion(logits, targets).item())
        predicted = logits.max(dim=1).indices
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    return sum(batch_losses) / len(val_loader), 100. * n_correct / n_seen

print("✅ 训练函数定义完成!")


In [None]:
# 6. Training loop: keeps the best-validation checkpoint and stops early.
print("🚀 开始训练官方Swin Transformer...")
print("=" * 60)

best_val_acc = 0.0
patience = 10  # epochs without a val-accuracy improvement before stopping
patience_counter = 0
train_losses = []
train_accs = []
val_losses = []
val_accs = []
# Number of completed epochs; defined up front so the summary below also works
# when the loop body never runs (EPOCHS == 0).
epochs_run = 0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}:")

    # Train, then validate.
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

    # Capture the learning rate actually used this epoch *before* stepping the
    # scheduler; after step() param_groups already holds next epoch's value.
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()
    epochs_run = epoch + 1

    # Record history.
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"LR: {current_lr:.2e}")

    # Save the weights whenever validation accuracy strictly improves.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), 'swin_transformer_official_best.pth')
        print(f"✅ 新的最佳验证准确率: {best_val_acc:.2f}%")
    else:
        patience_counter += 1

    # Early stopping.
    if patience_counter >= patience:
        print(f"⏹️ Early stopping after {patience} epochs without improvement")
        break

print("\n✅ 训练完成!")
print(f"   最佳验证准确率: {best_val_acc:.2f}%")
print(f"   实际训练轮数: {epochs_run}")


In [None]:
# 7. Evaluate the best checkpoint on the held-out test set.
print("📊 测试集评估...")

# Reload the best-validation weights. The file name must match the one the
# training loop writes with torch.save ('swin_transformer_official_best.pth').
# map_location places the tensors on the current device regardless of where
# they were saved from.
model.load_state_dict(torch.load('swin_transformer_official_best.pth', map_location=device))
model.eval()

test_loss = 0.0
correct = 0
total = 0
all_predictions = []
all_labels = []

with torch.no_grad():
    for inputs, targets in tqdm(test_loader, desc="Testing"):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs).logits
        loss = criterion(outputs, targets)

        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Keep per-sample predictions/labels for the sklearn metrics below.
        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(targets.cpu().numpy())

test_accuracy = 100. * correct / total
test_loss = test_loss / len(test_loader)  # mean of per-batch losses

print("\n🎯 测试结果:")
print(f"   测试准确率: {test_accuracy:.2f}%")
print(f"   测试损失: {test_loss:.4f}")

# F1 scores.
f1_macro = f1_score(all_labels, all_predictions, average='macro')
f1_weighted = f1_score(all_labels, all_predictions, average='weighted')

print(f"   Macro F1-Score: {f1_macro:.4f}")
print(f"   Weighted F1-Score: {f1_weighted:.4f}")

# Detailed per-class report. Passing labels explicitly keeps target_names
# aligned with label indices and avoids a ValueError when some class appears
# in neither the true labels nor the predictions.
print("\n📋 详细分类报告:")
print(classification_report(all_labels, all_predictions,
                            labels=list(range(len(CLASS_NAMES))),
                            target_names=CLASS_NAMES, digits=4))
