In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
import pandas as pd
import os
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# 配置参数
CFG = {
	'img_size': 224,
	'batch_size': 16,      # 根据8GB显存调整
	'num_epochs': 30,
	'lr': 3e-5,
	'num_workers': 0,      # Windows系统需要设置为0
	'device': 'cuda' if torch.cuda.is_available() else 'cpu',
	'num_classes': 881,
	'seed': 42
}

# 设置随机种子
torch.manual_seed(CFG['seed'])

# 数据预处理
train_transform = transforms.Compose([
	transforms.RandomResizedCrop(CFG['img_size']),
	transforms.RandomHorizontalFlip(),
	transforms.RandomRotation(15),
	transforms.ToTensor(),
	transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
	transforms.Resize(CFG['img_size']),
	transforms.CenterCrop(CFG['img_size']),
	transforms.ToTensor(),
	transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 自定义数据集类
# 修正后的MedicineDataset类
class MedicineDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)  # 确保索引正确
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = self.df.loc[idx, 'path']
        image = Image.open(img_path).convert('RGB')
        label = self.df.loc[idx, 'label']
        
        if self.transform:
            image = self.transform(image)
            
        # 确保标签为torch.int64类型
        return image, torch.as_tensor(label, dtype=torch.long)

# 加载数据
def load_data(csv_path, root_dir):
    # 改进的CSV读取方式
    df = pd.read_csv(csv_path, header=0)  # 使用第一行作为列头
    df = df.rename(columns={'ID': 'path', 'Label': 'label'})  # 规范列名
    
    # 数据类型转换
    convert_dict = {'path': str, 'label': int}
    df = df.astype(convert_dict)
    
    # 路径修正（根据实际文件结构可能需要调整）
    df['path'] = df['path'].apply(lambda x: os.path.join(root_dir, x))
    
    # 验证数据有效性
    print("正在验证数据完整性...")
    valid_samples = []
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        if os.path.exists(row['path']):
            valid_samples.append(idx)
        else:
            print(f"警告：缺失文件 {row['path']}")
    df = df.loc[valid_samples]
    
    # 过滤无效类别（每个类别至少需要2个样本）
    label_counts = df['label'].value_counts()
    valid_labels = label_counts[label_counts >= 2].index
    df = df[df['label'].isin(valid_labels)]
    
    print(f"有效数据量: {len(df)}")
    print(f"有效类别数: {df['label'].nunique()}")
    
    # 转换为0-based标签
    df['label'] = df['label'] - 1
    
    # 分层划分数据集
    train_df, val_df = train_test_split(
        df,
        test_size=0.2,
        stratify=df['label'],
        random_state=CFG['seed']
    )
    return train_df, val_df
# 创建模型
def create_model():
	model = ViTForImageClassification.from_pretrained(
		'google/vit-base-patch16-224-in21k',
		num_labels=CFG['num_classes'],
		ignore_mismatched_sizes=True
	)
	return model.to(CFG['device'])

# 训练函数
def train_model(model, train_loader, val_loader):
	optimizer = torch.optim.AdamW(model.parameters(), lr=CFG['lr'])
	criterion = nn.CrossEntropyLoss()
	
	best_acc = 0.0
	for epoch in range(CFG['num_epochs']):
		# 训练阶段
		model.train()
		train_loss = 0.0
		progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1} [Train]')
		for images, labels in progress_bar:
			images = images.to(CFG['device'])
			labels = labels.to(CFG['device'])
			
			optimizer.zero_grad()
			
			outputs = model(images)
			loss = criterion(outputs.logits, labels)
			
			loss.backward()
			optimizer.step()
			
			train_loss += loss.item() * images.size(0)
			progress_bar.set_postfix(loss=loss.item())
		
		# 验证阶段
		model.eval()
		val_loss = 0.0
		correct = 0
		total = 0
		with torch.no_grad():
			for images, labels in tqdm(val_loader, desc=f'Epoch {epoch+1} [Val]'):
				images = images.to(CFG['device'])
				labels = labels.to(CFG['device'])
				
				outputs = model(images)
				loss = criterion(outputs.logits, labels)
				
				val_loss += loss.item() * images.size(0)
				_, predicted = torch.max(outputs.logits, 1)
				total += labels.size(0)
				correct += (predicted == labels).sum().item()
		
		# 打印统计信息
		train_loss = train_loss / len(train_loader.dataset)
		val_loss = val_loss / len(val_loader.dataset)
		val_acc = 100 * correct / total
		print(f'Epoch {epoch+1}/{CFG["num_epochs"]}')
		print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
		
		# 保存最佳模型
		if val_acc > best_acc:
			best_acc = val_acc
			torch.save(model.state_dict(), 'best_vit_model.pth')
	
	print(f'Best Validation Accuracy: {best_acc:.2f}%')

# 生成预测结果
def generate_predictions(model, test_dir, transform):
	model.eval()
	predictions = []
	
	# 获取测试图片列表
	test_images = [os.path.join(test_dir, f) for f in os.listdir(test_dir) 
				  if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
	
	# 创建测试数据加载器
	class TestDataset(Dataset):
		def __init__(self, image_paths, transform=None):
			self.image_paths = image_paths
			self.transform = transform
			
		def __len__(self):
			return len(self.image_paths)
		
		def __getitem__(self, idx):
			image = Image.open(self.image_paths[idx]).convert('RGB')
			if self.transform:
				image = self.transform(image)
			return image, os.path.basename(self.image_paths[idx])
	
	test_dataset = TestDataset(test_images, transform=transform)
	test_loader = DataLoader(test_dataset, batch_size=CFG['batch_size'], shuffle=False)
	
	# 进行预测
	with torch.no_grad():
		for images, filenames in tqdm(test_loader, desc='Predicting'):
			images = images.to(CFG['device'])
			outputs = model(images)
			_, preds = torch.max(outputs.logits, 1)
			
			for fn, pred in zip(filenames, preds.cpu().numpy()):
				predictions.append((fn, pred + 1))  # 转换回1-based
	
	# 保存结果
	with open('submission.txt', 'w') as f:
		for fn, pred in predictions:
			f.write(f'{fn}\t{pred}\n')

if __name__ == '__main__':
	# 数据准备
	train_df, val_df = load_data('chinese-medicine-image/train_labels.csv', 'chinese-medicine-image')
	
	# 创建数据集
	train_dataset = MedicineDataset(train_df, train_transform)
	val_dataset = MedicineDataset(val_df, val_transform)
	# 创建数据加载器
	train_loader = DataLoader(train_dataset, batch_size=CFG['batch_size'], shuffle=True)
	val_loader = DataLoader(val_dataset, batch_size=CFG['batch_size'], shuffle=False)
	
	# 初始化模型
	model = create_model()
	
	# 开始训练
	train_model(model, train_loader, val_loader)
	
	# 加载最佳模型进行预测
	model.load_state_dict(torch.load('best_vit_model.pth'))
	
	# 生成提交文件（假设测试图片在test目录下）
	generate_predictions(model, 'chinese-medicine-image/test', val_transform)


正在验证数据完整性...


100%|██████████| 167017/167017 [00:10<00:00, 16399.91it/s]


有效数据量: 167017
有效类别数: 879


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1 [Train]: 100%|██████████| 8351/8351 [1:12:58<00:00,  1.91it/s, loss=3.33]
Epoch 1 [Val]: 100%|██████████| 2088/2088 [13:00<00:00,  2.68it/s]


Epoch 1/30
Train Loss: 5.3754 | Val Loss: 3.7197 | Val Acc: 40.38%


Epoch 2 [Train]: 100%|██████████| 8351/8351 [1:07:48<00:00,  2.05it/s, loss=3.41]
Epoch 2 [Val]: 100%|██████████| 2088/2088 [13:09<00:00,  2.64it/s]


Epoch 2/30
Train Loss: 2.9090 | Val Loss: 2.3762 | Val Acc: 54.54%


Epoch 3 [Train]: 100%|██████████| 8351/8351 [1:08:01<00:00,  2.05it/s, loss=3.04] 
Epoch 3 [Val]: 100%|██████████| 2088/2088 [13:16<00:00,  2.62it/s]


Epoch 3/30
Train Loss: 2.1204 | Val Loss: 2.0144 | Val Acc: 59.57%


Epoch 4 [Train]: 100%|██████████| 8351/8351 [1:09:20<00:00,  2.01it/s, loss=2.94] 
Epoch 4 [Val]: 100%|██████████| 2088/2088 [13:18<00:00,  2.61it/s]


Epoch 4/30
Train Loss: 1.7865 | Val Loss: 1.8679 | Val Acc: 61.91%


Epoch 5 [Train]: 100%|██████████| 8351/8351 [1:08:32<00:00,  2.03it/s, loss=1.59] 
Epoch 5 [Val]: 100%|██████████| 2088/2088 [13:19<00:00,  2.61it/s]


Epoch 5/30
Train Loss: 1.5633 | Val Loss: 1.7679 | Val Acc: 63.35%


Epoch 6 [Train]: 100%|██████████| 8351/8351 [1:08:33<00:00,  2.03it/s, loss=2.19] 
Epoch 6 [Val]: 100%|██████████| 2088/2088 [13:09<00:00,  2.64it/s]


Epoch 6/30
Train Loss: 1.3927 | Val Loss: 1.7139 | Val Acc: 64.18%


Epoch 7 [Train]: 100%|██████████| 8351/8351 [1:07:32<00:00,  2.06it/s, loss=1.71] 
Epoch 7 [Val]: 100%|██████████| 2088/2088 [13:14<00:00,  2.63it/s]


Epoch 7/30
Train Loss: 1.2440 | Val Loss: 1.7077 | Val Acc: 64.43%


Epoch 8 [Train]: 100%|██████████| 8351/8351 [1:07:23<00:00,  2.07it/s, loss=1.22] 
Epoch 8 [Val]: 100%|██████████| 2088/2088 [13:16<00:00,  2.62it/s]


Epoch 8/30
Train Loss: 1.1178 | Val Loss: 1.6964 | Val Acc: 64.71%


Epoch 9 [Train]: 100%|██████████| 8351/8351 [1:07:17<00:00,  2.07it/s, loss=0.781]
Epoch 9 [Val]: 100%|██████████| 2088/2088 [13:18<00:00,  2.61it/s]


Epoch 9/30
Train Loss: 1.0028 | Val Loss: 1.6980 | Val Acc: 64.76%


Epoch 10 [Train]: 100%|██████████| 8351/8351 [1:06:40<00:00,  2.09it/s, loss=0.429] 
Epoch 10 [Val]: 100%|██████████| 2088/2088 [13:18<00:00,  2.61it/s]


Epoch 10/30
Train Loss: 0.9044 | Val Loss: 1.7248 | Val Acc: 64.38%


Epoch 11 [Train]: 100%|██████████| 8351/8351 [1:05:58<00:00,  2.11it/s, loss=0.546] 
Epoch 11 [Val]: 100%|██████████| 2088/2088 [13:17<00:00,  2.62it/s]


Epoch 11/30
Train Loss: 0.8157 | Val Loss: 1.7397 | Val Acc: 64.32%


Epoch 12 [Train]: 100%|██████████| 8351/8351 [1:07:10<00:00,  2.07it/s, loss=1.08]  
Epoch 12 [Val]: 100%|██████████| 2088/2088 [13:30<00:00,  2.58it/s]


Epoch 12/30
Train Loss: 0.7308 | Val Loss: 1.7589 | Val Acc: 64.43%


Epoch 13 [Train]: 100%|██████████| 8351/8351 [1:06:59<00:00,  2.08it/s, loss=0.683] 
Epoch 13 [Val]: 100%|██████████| 2088/2088 [13:33<00:00,  2.57it/s]


Epoch 13/30
Train Loss: 0.6676 | Val Loss: 1.8086 | Val Acc: 63.73%


Epoch 14 [Train]: 100%|██████████| 8351/8351 [1:06:38<00:00,  2.09it/s, loss=0.334] 
Epoch 14 [Val]: 100%|██████████| 2088/2088 [13:32<00:00,  2.57it/s]


Epoch 14/30
Train Loss: 0.6154 | Val Loss: 1.7935 | Val Acc: 64.25%


Epoch 15 [Train]: 100%|██████████| 8351/8351 [1:06:35<00:00,  2.09it/s, loss=0.505] 
Epoch 15 [Val]: 100%|██████████| 2088/2088 [13:28<00:00,  2.58it/s]


Epoch 15/30
Train Loss: 0.5729 | Val Loss: 1.8402 | Val Acc: 63.83%


Epoch 16 [Train]: 100%|██████████| 8351/8351 [1:07:51<00:00,  2.05it/s, loss=0.901] 
Epoch 16 [Val]: 100%|██████████| 2088/2088 [13:24<00:00,  2.60it/s]


Epoch 16/30
Train Loss: 0.5275 | Val Loss: 1.8616 | Val Acc: 63.91%


Epoch 17 [Train]: 100%|██████████| 8351/8351 [1:07:10<00:00,  2.07it/s, loss=0.224] 
Epoch 17 [Val]: 100%|██████████| 2088/2088 [13:28<00:00,  2.58it/s]


Epoch 17/30
Train Loss: 0.5000 | Val Loss: 1.8963 | Val Acc: 63.26%


Epoch 18 [Train]: 100%|██████████| 8351/8351 [1:07:06<00:00,  2.07it/s, loss=0.26]  
Epoch 18 [Val]: 100%|██████████| 2088/2088 [13:35<00:00,  2.56it/s]


Epoch 18/30
Train Loss: 0.4725 | Val Loss: 1.9020 | Val Acc: 63.62%


Epoch 19 [Train]: 100%|██████████| 8351/8351 [1:07:19<00:00,  2.07it/s, loss=0.64]  
Epoch 19 [Val]: 100%|██████████| 2088/2088 [13:30<00:00,  2.58it/s]


Epoch 19/30
Train Loss: 0.4524 | Val Loss: 1.9451 | Val Acc: 63.13%


Epoch 20 [Train]: 100%|██████████| 8351/8351 [1:08:22<00:00,  2.04it/s, loss=0.127] 
Epoch 20 [Val]: 100%|██████████| 2088/2088 [13:48<00:00,  2.52it/s]


Epoch 20/30
Train Loss: 0.4368 | Val Loss: 1.9369 | Val Acc: 63.45%


Epoch 21 [Train]: 100%|██████████| 8351/8351 [1:08:38<00:00,  2.03it/s, loss=1.01]  
Epoch 21 [Val]: 100%|██████████| 2088/2088 [13:46<00:00,  2.53it/s]


Epoch 21/30
Train Loss: 0.4205 | Val Loss: 1.9505 | Val Acc: 63.76%


Epoch 22 [Train]: 100%|██████████| 8351/8351 [1:08:39<00:00,  2.03it/s, loss=0.308] 
Epoch 22 [Val]: 100%|██████████| 2088/2088 [13:36<00:00,  2.56it/s]


Epoch 22/30
Train Loss: 0.4089 | Val Loss: 1.9811 | Val Acc: 63.24%


Epoch 23 [Train]: 100%|██████████| 8351/8351 [1:10:40<00:00,  1.97it/s, loss=0.237] 
Epoch 23 [Val]: 100%|██████████| 2088/2088 [13:41<00:00,  2.54it/s]


Epoch 23/30
Train Loss: 0.4025 | Val Loss: 1.9973 | Val Acc: 63.28%


Epoch 24 [Train]: 100%|██████████| 8351/8351 [1:09:14<00:00,  2.01it/s, loss=0.512]  
Epoch 24 [Val]: 100%|██████████| 2088/2088 [13:49<00:00,  2.52it/s]


Epoch 24/30
Train Loss: 0.3810 | Val Loss: 2.0185 | Val Acc: 62.95%


Epoch 25 [Train]: 100%|██████████| 8351/8351 [1:10:00<00:00,  1.99it/s, loss=0.0655] 
Epoch 25 [Val]: 100%|██████████| 2088/2088 [13:37<00:00,  2.56it/s]


Epoch 25/30
Train Loss: 0.3686 | Val Loss: 2.0183 | Val Acc: 63.13%


Epoch 26 [Train]: 100%|██████████| 8351/8351 [1:09:58<00:00,  1.99it/s, loss=0.334]  
Epoch 26 [Val]: 100%|██████████| 2088/2088 [13:38<00:00,  2.55it/s]


Epoch 26/30
Train Loss: 0.3596 | Val Loss: 2.0213 | Val Acc: 63.73%


Epoch 27 [Train]: 100%|██████████| 8351/8351 [1:09:32<00:00,  2.00it/s, loss=0.576]  
Epoch 27 [Val]: 100%|██████████| 2088/2088 [13:43<00:00,  2.54it/s]


Epoch 27/30
Train Loss: 0.3565 | Val Loss: 2.0578 | Val Acc: 63.15%


Epoch 28 [Train]: 100%|██████████| 8351/8351 [1:09:40<00:00,  2.00it/s, loss=0.156]  
Epoch 28 [Val]: 100%|██████████| 2088/2088 [13:50<00:00,  2.52it/s]


Epoch 28/30
Train Loss: 0.3497 | Val Loss: 2.0725 | Val Acc: 63.04%


Epoch 29 [Train]: 100%|██████████| 8351/8351 [1:09:39<00:00,  2.00it/s, loss=0.18]  
Epoch 29 [Val]: 100%|██████████| 2088/2088 [13:39<00:00,  2.55it/s]


Epoch 29/30
Train Loss: 0.3424 | Val Loss: 2.0424 | Val Acc: 63.53%


Epoch 30 [Train]: 100%|██████████| 8351/8351 [1:10:04<00:00,  1.99it/s, loss=0.747]  
Epoch 30 [Val]: 100%|██████████| 2088/2088 [13:39<00:00,  2.55it/s]


Epoch 30/30
Train Loss: 0.3361 | Val Loss: 2.0567 | Val Acc: 63.36%
Best Validation Accuracy: 64.76%


Predicting: 100%|██████████| 4694/4694 [12:22<00:00,  6.32it/s]
