In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from data import get_data
from model import InceptionNet, InceptionNet_simplified, Mynet_v1, Mynet_v2
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont

In [23]:
# Compute the number of model parameters.
def count_params(model):
    """Return how many trainable scalar parameters ``model`` holds.

    Only tensors with ``requires_grad=True`` are counted; frozen parameters
    are skipped. A module without parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# One freshly initialised instance of every candidate architecture, keyed by
# the display name used in the report below.
models = dict(
    InceptionNet=InceptionNet(),
    InceptionNet_simplified=InceptionNet_simplified(),
    Mynet_v1=Mynet_v1(),
    Mynet_v2=Mynet_v2(),
)

# Print each architecture's trainable-parameter count with thousands separators.
for name, model in models.items():
    print("{} Params: {:,}".format(name, count_params(model)))

InceptionNet Params: 6,275,687
InceptionNet_simplified Params: 2,872,375
Mynet_v1 Params: 488,039
Mynet_v2 Params: 625,111


In [None]:
# Model training: resume Mynet_v2 from a saved checkpoint, train for
# `num_epochs`, and save a new checkpoint whenever accuracy on `test_loader`
# improves.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# bfloat16 autocast is enabled only on CUDA; on CPU the loop runs in float32.
# Passing device_type='cuda' unconditionally makes PyTorch warn and disable
# autocast when CUDA is unavailable, so the device type is taken from `device`.
use_amp = device.type == 'cuda'

num_epochs = 30
# NOTE(review): training resumes from a checkpoint, but best_val_acc starts at
# 0.0, so the first epoch is always saved even if it is worse than the
# checkpoint it started from.
best_val_acc = 0.0

model = Mynet_v2().to(device)
model.load_state_dict(torch.load('model/Mynet_v2_acc0.6857.pth', map_location=device, weights_only=True))
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
# One cosine half-period over the whole run. CosineAnnealingLR does not
# restart: with T_max=10 and 30 epochs the LR would fall to eta_min by epoch
# 10, climb back to the base LR by epoch 20, and fall again.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-8)
# NOTE(review): test_loader doubles as the validation set for checkpoint
# selection, so the best accuracy it reports is optimistically biased.
train_loader, test_loader = get_data(batch_size=32)

torch.set_float32_matmul_precision('high')
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_amp):
            outputs = model(images)
            loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], LR: {scheduler.get_last_lr()[0]:.6f}, Batch [{batch_idx+1}/{len(train_loader)}], Train Loss: {loss.item():.4f}")

    # Stepped once per epoch, so T_max above is measured in epochs.
    scheduler.step()

    # ---- evaluation pass ----
    model.eval()
    val_loss_sum = 0.0
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_amp):
                outputs = model(images)
                loss = loss_fn(outputs, labels)
            batch_size = labels.size(0)
            # loss_fn uses the default 'mean' reduction; weight by batch size so
            # a smaller final batch does not skew the dataset-level average.
            val_loss_sum += loss.item() * batch_size
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_samples += batch_size

    avg_val_loss = val_loss_sum / num_samples
    val_acc = num_correct / num_samples

    print("="*70)
    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"Val Loss: {avg_val_loss:.4f}")
    print(f"Val Accuracy: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # File name is '<model class name>_acc<accuracy>.pth'.
        torch.save(model.state_dict(), f'model/{type(model).__name__}_acc{best_val_acc:.4f}.pth')
        print("-"*70)
        print("Best model saved with accuracy: {:.4f}".format(best_val_acc))
        print("-"*70)

    print("="*70)

In [36]:
# Model testing: classify a single image with a saved checkpoint and save an
# annotated copy of the image.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Alternative checkpoint. The training cell names files
# 'model/<model class name>_acc<accuracy>.pth':
# model = Mynet_v2().to(device)
# model.load_state_dict(torch.load('model/Mynet_v2_acc0.6857.pth', map_location=device, weights_only=True))

model = InceptionNet_simplified().to(device)
model.load_state_dict(torch.load('model/InceptionNet_simplified_acc0.6959.pth', map_location=device, weights_only=True))
image_path = "test_images/surprise.png"

# Read the file once and release the handle. The grayscale copy feeds the
# model. The RGB(A) copy is the one that gets annotated: drawing an RGB colour
# tuple onto a single-band ("L") image raises TypeError in Pillow.
with Image.open(image_path) as img_file:
    gray_image = img_file.convert("L")
    ori_image = img_file.convert("RGBA" if img_file.mode == "RGBA" else "RGB")

# NOTE(review): assumes this preprocessing matches what data.get_data applies
# during training — confirm.
test_transforms = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Add a batch dimension of 1 before running the model.
image = test_transforms(gray_image).unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    output = model(image)
    prediction = output.argmax(dim=1)

emotion_dict = {
            'angry': 0,
            'disgust': 1,
            'fear': 2,
            'happy': 3,
            'neutral': 4,
            'sad': 5,
            'surprise': 6
        }

# Invert label -> index so the predicted class index maps straight to its name.
index_to_emotion = {index: emotion for emotion, index in emotion_dict.items()}
predicted_emotion = index_to_emotion[prediction.item()]
print(f"Predicted Emotion: {predicted_emotion}")

draw = ImageDraw.Draw(ori_image)
font = ImageFont.load_default()
text = f"Predicted: {predicted_emotion}"
draw.text((10, 10), text, fill=(50, 0, 0), font=font)

ori_image.show()
ori_image.save(f"test_images/{predicted_emotion}_result.png")


Predicted Emotion: surprise
