In [None]:
import pandas as pd
import os
from pyparsing import C
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch import nn, optim
from sklearn.model_selection import train_test_split
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
import pandas as pd
import os
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# 参数配置
CFG = {
	'data_path': 'chinese-medicine-image',
	'csv_name': 'train_labels.csv',
	'batch_size': 16,        # 根据显存调整
	'num_workers': 0,
	'num_epochs': 30,
	'lr': 3e-4,
	'image_size': 224,
	'num_classes': 881,
	'device': torch.device('cuda'),
	'seed': 42
}

# 数据预处理
train_transform = transforms.Compose([
	transforms.RandomResizedCrop(CFG['image_size']),
	transforms.RandomHorizontalFlip(),
	transforms.ColorJitter(0.2, 0.2, 0.2),
	transforms.ToTensor(),
	transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
	transforms.Resize(256),
	transforms.CenterCrop(CFG['image_size']),
	transforms.ToTensor(),
	transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 自定义数据集
class MedicineDataset(Dataset):
	def __init__(self, df, transform=None):
		self.df = df
		self.transform = transform
		
	def __len__(self):
		return len(self.df)
	
	def __getitem__(self, idx):
		img_path = self.df.iloc[idx, 0]
		label = self.df.iloc[idx, 1] - 1  # 标签转换为0-based
		
		# 关键断言
		assert 0 <= label < CFG['num_classes'], \
			f"无效标签：{label+1} (应为1-{CFG['num_classes']}), 位置：{idx}"
		img = Image.open(img_path).convert('RGB')
		if self.transform:
			img = self.transform(img)
		label = torch.tensor(label, dtype=torch.long)  # 确保label为long类型
		return img, label

# 准备数据
def load_data(csv_path, root_dir):
	# 改进的CSV读取方式
	df = pd.read_csv(csv_path, header=0)  # 使用第一行作为列头
	df = df.rename(columns={'ID': 'path', 'Label': 'label'})  # 规范列名
	
	# 数据类型转换
	convert_dict = {'path': str, 'label': int}
	df = df.astype(convert_dict)
	
	# 路径修正（根据实际文件结构可能需要调整）
	df['path'] = df['path'].apply(lambda x: os.path.join(root_dir, x))
	
	# 验证数据有效性
	print("正在验证数据完整性...")
	valid_samples = []
	for idx, row in tqdm(df.iterrows(), total=len(df)):
		if os.path.exists(row['path']):
			valid_samples.append(idx)
		else:
			print(f"警告：缺失文件 {row['path']}")
	df = df.loc[valid_samples]
	
	# 过滤无效类别（每个类别至少需要2个样本）
	label_counts = df['label'].value_counts()
	valid_labels = label_counts[label_counts >= 2].index
	df = df[df['label'].isin(valid_labels)]
	
	print(f"有效数据量: {len(df)}")
	print(f"有效类别数: {df['label'].nunique()}")
	
	# 转换为0-based标签
	
	# 分层划分数据集
	train_df, val_df = train_test_split(
		df,
		test_size=0.2,
		stratify=df['label'],
		random_state=CFG['seed']
	)
	return train_df, val_df

# 训练函数
def train_model(model, train_loader, test_loader):
	model = model.to(CFG['device'])
	criterion = nn.CrossEntropyLoss()
	optimizer = optim.Adam(model.parameters(), lr=CFG['lr'])
	
	best_acc = 0.0
	for epoch in range(CFG['num_epochs']):
		# 训练阶段
		model.train()
		running_loss = 0.0
		correct = 0
		total = 0
		
		pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CFG["num_epochs"]} [Train]')
		for inputs, labels in pbar:
			inputs = inputs.to(CFG['device'])
			labels = labels.to(CFG['device'])
			
			optimizer.zero_grad()
			outputs = model(inputs)
			loss = criterion(outputs, labels)
			loss.backward()
			optimizer.step()
			
			running_loss += loss.item() * inputs.size(0)
			_, predicted = torch.max(outputs.data, 1)
			total += labels.size(0)
			correct += (predicted == labels).sum().item()
			
			pbar.set_postfix({
				'loss': loss.item(),
				'acc': f'{correct/total:.3f}'
			})
		
		# 验证阶段
		model.eval()
		test_correct = 0
		test_total = 0
		test_loss = 0.0
		
		with torch.no_grad():
			pbar = tqdm(test_loader, desc=f'Epoch {epoch+1}/{CFG["num_epochs"]} [Test]')
			for inputs, labels in pbar:
				inputs = inputs.to(CFG['device'])
				labels = labels.to(CFG['device'])
				
				outputs = model(inputs)
				loss = criterion(outputs, labels)
				
				test_loss += loss.item() * inputs.size(0)
				_, predicted = torch.max(outputs.data, 1)
				test_total += labels.size(0)
				test_correct += (predicted == labels).sum().item()
				
				pbar.set_postfix({
					'acc': f'{test_correct/test_total:.3f}'
				})
		
		# 保存最佳模型
		epoch_acc = test_correct / test_total
		if epoch_acc > best_acc:
			best_acc = epoch_acc
			torch.save(model.state_dict(), 'best_model.pth')
			print(f'New best model saved with acc: {best_acc:.4f}')

# 预测并生成结果
def predict(model, test_loader):
	model.load_state_dict(torch.load('best_model.pth'))
	model.eval()
	
	filenames = []
	predictions = []
	
	with torch.no_grad():
		for inputs, _ in tqdm(test_loader, desc='Predicting'):
			inputs = inputs.to(CFG['device'])
			outputs = model(inputs)
			_, preds = torch.max(outputs, 1)
			predictions.extend((preds + 1).cpu().numpy())  # 转回1-based
	
	# 获取文件名
	test_df = test_loader.dataset.df
	filenames = test_df.iloc[:, 0].apply(lambda x: x.split('/')[-1]).tolist()
	
	# 写入文件
	with open('submission.txt', 'w') as f:
		for fn, pred in zip(filenames, predictions):
			f.write(f"{fn}\t{pred}\n")

if __name__ == '__main__':
	# 准备数据
	train_df, test_df = load_data(CFG['csv_name'], CFG['data_path'])
	
	# 创建数据集和数据加载器
	train_dataset = MedicineDataset(train_df, train_transform)
	test_dataset = MedicineDataset(test_df, test_transform)
	
	train_loader = DataLoader(
		train_dataset,
		batch_size=CFG['batch_size'],
		shuffle=True,
		num_workers=CFG['num_workers']
	)
	
	test_loader = DataLoader(
		test_dataset,
		batch_size=CFG['batch_size'],
		shuffle=False,
		num_workers=CFG['num_workers']
	)
	
	# 初始化模型
	model = models.resnet50(pretrained=True)
	model.fc = nn.Linear(model.fc.in_features, CFG['num_classes'])
	
	# 训练模型
	train_model(model, train_loader, test_loader)
	
	# 生成预测结果
	predict(model, test_loader)


正在验证数据完整性...


100%|██████████| 167017/167017 [00:10<00:00, 16271.12it/s]


有效数据量: 167017
有效类别数: 879


Epoch 1/30 [Train]: 100%|██████████| 8351/8351 [41:04<00:00,  3.39it/s, loss=6.38, acc=0.013]  
Epoch 1/30 [Test]: 100%|██████████| 2088/2088 [12:44<00:00,  2.73it/s, acc=0.044]


New best model saved with acc: 0.0436


Epoch 2/30 [Train]: 100%|██████████| 8351/8351 [41:21<00:00,  3.37it/s, loss=5.88, acc=0.070]  
Epoch 2/30 [Test]: 100%|██████████| 2088/2088 [10:38<00:00,  3.27it/s, acc=0.143]


New best model saved with acc: 0.1430


Epoch 3/30 [Train]: 100%|██████████| 8351/8351 [45:54<00:00,  3.03it/s, loss=5.22, acc=0.144]  
Epoch 3/30 [Test]: 100%|██████████| 2088/2088 [13:19<00:00,  2.61it/s, acc=0.201]


New best model saved with acc: 0.2014


Epoch 4/30 [Train]: 100%|██████████| 8351/8351 [47:22<00:00,  2.94it/s, loss=4.64, acc=0.202]  
Epoch 4/30 [Test]: 100%|██████████| 2088/2088 [13:16<00:00,  2.62it/s, acc=0.257]


New best model saved with acc: 0.2573


Epoch 5/30 [Train]: 100%|██████████| 8351/8351 [45:44<00:00,  3.04it/s, loss=3.64, acc=0.248]  
Epoch 5/30 [Test]: 100%|██████████| 2088/2088 [10:40<00:00,  3.26it/s, acc=0.318]


New best model saved with acc: 0.3182


Epoch 6/30 [Train]: 100%|██████████| 8351/8351 [45:32<00:00,  3.06it/s, loss=5.02, acc=0.287]  
Epoch 6/30 [Test]: 100%|██████████| 2088/2088 [13:20<00:00,  2.61it/s, acc=0.360]


New best model saved with acc: 0.3604


Epoch 7/30 [Train]: 100%|██████████| 8351/8351 [47:21<00:00,  2.94it/s, loss=3.87, acc=0.322]  
Epoch 7/30 [Test]: 100%|██████████| 2088/2088 [13:17<00:00,  2.62it/s, acc=0.386]


New best model saved with acc: 0.3858


Epoch 8/30 [Train]: 100%|██████████| 8351/8351 [45:40<00:00,  3.05it/s, loss=2.6, acc=0.351]   
Epoch 8/30 [Test]: 100%|██████████| 2088/2088 [10:38<00:00,  3.27it/s, acc=0.398]


New best model saved with acc: 0.3979


Epoch 9/30 [Train]: 100%|██████████| 8351/8351 [45:43<00:00,  3.04it/s, loss=2.41, acc=0.375]  
Epoch 9/30 [Test]: 100%|██████████| 2088/2088 [13:27<00:00,  2.59it/s, acc=0.430]


New best model saved with acc: 0.4300


Epoch 10/30 [Train]: 100%|██████████| 8351/8351 [45:54<00:00,  3.03it/s, loss=3.79, acc=0.397]  
Epoch 10/30 [Test]: 100%|██████████| 2088/2088 [10:38<00:00,  3.27it/s, acc=0.443]


New best model saved with acc: 0.4433


Epoch 11/30 [Train]: 100%|██████████| 8351/8351 [46:52<00:00,  2.97it/s, loss=1.47, acc=0.417]  
Epoch 11/30 [Test]: 100%|██████████| 2088/2088 [13:25<00:00,  2.59it/s, acc=0.454]


New best model saved with acc: 0.4543


Epoch 12/30 [Train]: 100%|██████████| 8351/8351 [46:54<00:00,  2.97it/s, loss=2.6, acc=0.432]   
Epoch 12/30 [Test]: 100%|██████████| 2088/2088 [13:21<00:00,  2.61it/s, acc=0.470]


New best model saved with acc: 0.4701


Epoch 13/30 [Train]: 100%|██████████| 8351/8351 [47:27<00:00,  2.93it/s, loss=3.38, acc=0.449]  
Epoch 13/30 [Test]: 100%|██████████| 2088/2088 [13:39<00:00,  2.55it/s, acc=0.486]


New best model saved with acc: 0.4856


Epoch 14/30 [Train]: 100%|██████████| 8351/8351 [48:09<00:00,  2.89it/s, loss=1.67, acc=0.464]   
Epoch 14/30 [Test]: 100%|██████████| 2088/2088 [13:40<00:00,  2.55it/s, acc=0.491]


New best model saved with acc: 0.4907


Epoch 15/30 [Train]: 100%|██████████| 8351/8351 [47:51<00:00,  2.91it/s, loss=4.59, acc=0.477]  
Epoch 15/30 [Test]: 100%|██████████| 2088/2088 [13:41<00:00,  2.54it/s, acc=0.492]


New best model saved with acc: 0.4925


Epoch 16/30 [Train]: 100%|██████████| 8351/8351 [47:47<00:00,  2.91it/s, loss=2.54, acc=0.490]  
Epoch 16/30 [Test]: 100%|██████████| 2088/2088 [13:46<00:00,  2.53it/s, acc=0.505]


New best model saved with acc: 0.5053


Epoch 17/30 [Train]: 100%|██████████| 8351/8351 [47:55<00:00,  2.90it/s, loss=2.31, acc=0.501]  
Epoch 17/30 [Test]: 100%|██████████| 2088/2088 [13:43<00:00,  2.54it/s, acc=0.505]
Epoch 18/30 [Train]: 100%|██████████| 8351/8351 [47:34<00:00,  2.93it/s, loss=2.37, acc=0.514]  
Epoch 18/30 [Test]: 100%|██████████| 2088/2088 [13:43<00:00,  2.54it/s, acc=0.515]


New best model saved with acc: 0.5151


Epoch 19/30 [Train]: 100%|██████████| 8351/8351 [48:06<00:00,  2.89it/s, loss=2.17, acc=0.524]   
Epoch 19/30 [Test]: 100%|██████████| 2088/2088 [13:19<00:00,  2.61it/s, acc=0.523]


New best model saved with acc: 0.5227


Epoch 20/30 [Train]: 100%|██████████| 8351/8351 [49:30<00:00,  2.81it/s, loss=2.02, acc=0.532]   
Epoch 20/30 [Test]: 100%|██████████| 2088/2088 [13:17<00:00,  2.62it/s, acc=0.522]
Epoch 21/30 [Train]: 100%|██████████| 8351/8351 [49:04<00:00,  2.84it/s, loss=1.95, acc=0.544]   
Epoch 21/30 [Test]: 100%|██████████| 2088/2088 [13:29<00:00,  2.58it/s, acc=0.526]


New best model saved with acc: 0.5262


Epoch 22/30 [Train]: 100%|██████████| 8351/8351 [48:38<00:00,  2.86it/s, loss=2.04, acc=0.553]   
Epoch 22/30 [Test]: 100%|██████████| 2088/2088 [13:41<00:00,  2.54it/s, acc=0.527]


New best model saved with acc: 0.5269


Epoch 23/30 [Train]: 100%|██████████| 8351/8351 [48:12<00:00,  2.89it/s, loss=2.03, acc=0.560]   
Epoch 23/30 [Test]: 100%|██████████| 2088/2088 [13:38<00:00,  2.55it/s, acc=0.538]


New best model saved with acc: 0.5380


Epoch 24/30 [Train]: 100%|██████████| 8351/8351 [50:28<00:00,  2.76it/s, loss=3.72, acc=0.572]   
Epoch 24/30 [Test]: 100%|██████████| 2088/2088 [13:34<00:00,  2.56it/s, acc=0.535]
Epoch 25/30 [Train]: 100%|██████████| 8351/8351 [47:03<00:00,  2.96it/s, loss=1.59, acc=0.578]   
Epoch 25/30 [Test]: 100%|██████████| 2088/2088 [13:39<00:00,  2.55it/s, acc=0.530]
Epoch 26/30 [Train]: 100%|██████████| 8351/8351 [47:16<00:00,  2.94it/s, loss=1.6, acc=0.589]    
Epoch 26/30 [Test]: 100%|██████████| 2088/2088 [13:45<00:00,  2.53it/s, acc=0.534]
Epoch 27/30 [Train]: 100%|██████████| 8351/8351 [49:57<00:00,  2.79it/s, loss=2.54, acc=0.595]   
Epoch 27/30 [Test]: 100%|██████████| 2088/2088 [13:34<00:00,  2.56it/s, acc=0.544]


New best model saved with acc: 0.5436


Epoch 28/30 [Train]: 100%|██████████| 8351/8351 [48:04<00:00,  2.90it/s, loss=1.16, acc=0.604]   
Epoch 28/30 [Test]: 100%|██████████| 2088/2088 [13:36<00:00,  2.56it/s, acc=0.549]


New best model saved with acc: 0.5487


Epoch 29/30 [Train]: 100%|██████████| 8351/8351 [48:47<00:00,  2.85it/s, loss=0.869, acc=0.608]  
Epoch 29/30 [Test]: 100%|██████████| 2088/2088 [12:33<00:00,  2.77it/s, acc=0.543]
Epoch 30/30 [Train]: 100%|██████████| 8351/8351 [46:21<00:00,  3.00it/s, loss=1.67, acc=0.617]   
Epoch 30/30 [Test]: 100%|██████████| 2088/2088 [10:52<00:00,  3.20it/s, acc=0.545]


AttributeError: type object 'tqdm' has no attribute 'tqdm'

In [11]:
def generate_predictions(model, test_dir, transform):
	model.eval()
	predictions = []
	
	# 获取测试图片列表
	test_images = [os.path.join(test_dir, f) for f in os.listdir(test_dir) 
				  if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
	
	# 创建测试数据加载器
	class TestDataset(Dataset):
		def __init__(self, image_paths, transform=None):
			self.image_paths = image_paths
			self.transform = transform
			
		def __len__(self):
			return len(self.image_paths)
		
		def __getitem__(self, idx):
			image = Image.open(self.image_paths[idx]).convert('RGB')
			if self.transform:
				image = self.transform(image)
			return image, os.path.basename(self.image_paths[idx])
	
	test_dataset = TestDataset(test_images, transform=transform)
	test_loader = DataLoader(test_dataset, batch_size=CFG['batch_size'], shuffle=False)
	
	# 进行预测
	with torch.no_grad():
		for images, filenames in tqdm(test_loader, desc='Predicting'):
			images = images.to(CFG['device'])
			outputs = model(images)
			_, preds = torch.max(outputs.logits, 1)
			
			for fn, pred in zip(filenames, preds.cpu().numpy()):
				predictions.append((fn, pred + 1))  # 转换回1-based
	
	# 保存结果
	with open('submission.txt', 'w') as f:
		for fn, pred in predictions:
			f.write(f'{fn}\t{pred}\n')


In [13]:

# 生成提交文件（假设测试图片在test目录下）
# 修正 outputs.logits 为 outputs
def generate_predictions(model, test_dir, transform):
	model.eval()
	predictions = []
	
	# 获取测试图片列表
	test_images = [os.path.join(test_dir, f) for f in os.listdir(test_dir) 
				  if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
	
	# 创建测试数据加载器
	class TestDataset(Dataset):
		def __init__(self, image_paths, transform=None):
			self.image_paths = image_paths
			self.transform = transform
			
		def __len__(self):
			return len(self.image_paths)
		
		def __getitem__(self, idx):
			image = Image.open(self.image_paths[idx]).convert('RGB')
			if self.transform:
				image = self.transform(image)
			return image, os.path.basename(self.image_paths[idx])
	
	test_dataset = TestDataset(test_images, transform=transform)
	test_loader = DataLoader(test_dataset, batch_size=CFG['batch_size'], shuffle=False)
	
	# 进行预测
	with torch.no_grad():
		for images, filenames in tqdm(test_loader, desc='Predicting'):
			images = images.to(CFG['device'])
			outputs = model(images)
			_, preds = torch.max(outputs, 1)
			
			for fn, pred in zip(filenames, preds.cpu().numpy()):
				predictions.append((fn, pred + 1))  # 转换回1-based
	
	# 保存结果
	with open('submission.txt', 'w') as f:
		for fn, pred in predictions:
			f.write(f'{fn}\t{pred}\n')

generate_predictions(model, 'chinese-medicine-image/test', test_transform)


Predicting: 100%|██████████| 4694/4694 [11:41<00:00,  6.69it/s]
