In [None]:
%load_ext autoreload
%autoreload 2
from custom_trainer import CustomTrainer, CustomClassifierTrainer
from custom_dataset import build_dataloader
from custom_model import ConvTransformerAE, CustomClassifier, AEClassifier, NewCustomClassifier
from torch import nn
import torch
import random
import numpy as np

# Seed every RNG this notebook touches so a Restart & Run All is repeatable.
# NumPy is imported and used below, so its global RNG is seeded alongside
# torch and the stdlib `random` module.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# --- Model: autoencoder backbone wrapped together with a classification head ---
classifier = CustomClassifier(latent_dim=32, num_classes=5, dropout=0.05)
ae = ConvTransformerAE(
    input_dim=32,
    latent_dim=32,
    conv_channels=8,
    d_model=64,
    n_head=4,
    num_layers=10
)
model = AEClassifier(ae_model=ae, classifier_model=classifier)

# Unweighted loss. The training cell passes the class-weighted `criterion_cls`
# to the trainer instead of this one.
criterion = nn.CrossEntropyLoss()

# --- Restore the checkpoint ---
# load_state_dict copies the loaded tensors into the model's existing
# parameters, so the checkpoint can be read on the CPU; this keeps the cell
# runnable on machines without a GPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load("./classifier_model_110000.pth", map_location="cpu"))
for name, module in ae.named_modules():
    print(name, "->", module.__class__.__name__)

# --- Freeze everything, then re-enable only the parts being fine-tuned ---
for p in model.parameters():
    p.requires_grad = False
# NOTE(review): if the trainer calls model.train(), this eval() is undone --
# confirm against CustomClassifierTrainer.
model.ae_model.eval()
for p in model.classifier_model.parameters():
    p.requires_grad = True
for p in model.ae_model.to_latent.parameters():
    p.requires_grad = True

# Alternative setup kept for reference (not active): swap in a fresh
# NewCustomClassifier(latent_dim=32, num_classes=5, dropout=0.05) head and
# freeze ae.conv, ae.embedding, ae.encoder and ae.to_latent, running the
# encoder in eval mode.

# Build the optimizer only after freezing, and only over trainable parameters,
# so weight decay / optimizer state can never touch the frozen backbone.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-4,
    weight_decay=1e-4
)

# StepLR multiplies the learning rate by `gamma` once every `step_size` calls
# to scheduler.step(). Whether a "step" is a batch or an epoch depends on where
# the trainer calls scheduler.step() (not visible in this notebook).
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=400,   # decay once every 400 scheduler steps
    gamma=0.99       # lr *= 0.99 at each decay
)

In [None]:
# --- Training data ---
data_path = './resources/features_norm.npy'
label_path = './resources/labels.npy'
train_loader = build_dataloader(
    data_path=data_path,
    label_path=label_path,
    batch_size=256,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
    offset=10000  # semantics defined by build_dataloader (not visible here) -- TODO confirm
)

# --- Validation data ---
val_path = './resources/val_features.npy'
val_label_path = './resources/val_labels.npy'

# drop_last is disabled for validation: with batch_size=4096, dropping the
# final partial batch would silently exclude up to 4095 samples from the
# metrics (or every sample if the set is smaller than one batch).
# Assumes build_dataloader forwards drop_last to the underlying loader -- TODO confirm.
# NOTE(review): shuffle=True is kept as-is; it is only needed if the trainer
# evaluates on a subset of validation batches -- confirm and disable otherwise.
val_loader = build_dataloader(
    data_path=val_path,
    label_path=val_label_path,
    batch_size=4096,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=False,
)

In [None]:
# --- Class-balanced loss weights from the label distribution ---
# NOTE(review): counts come from the full label file, while train_loader is
# built with offset=10000 -- confirm both cover the same samples.
num_classes = 5
labels = np.load(label_path)
counts = np.bincount(labels.astype(np.int64), minlength=num_classes)
if counts.size != num_classes:
    # bincount grows past minlength when a label >= num_classes is present;
    # fail here rather than with a weight-shape error inside the loss.
    raise ValueError(
        f"labels contain class ids outside [0, {num_classes}): "
        f"max label is {counts.size - 1}"
    )
if counts.sum() == 0:
    raise ValueError(f"no labels found in {label_path}")
freq = counts / counts.sum()

print("Class counts:", counts)
print("Class freq:", freq)

# Inverse-square-root frequency weighting. Classes with no samples are left
# out of the mean so they cannot swamp the normalisation; they end up with
# just the 0.5 floor. When every class is present this is identical to
# 1/sqrt(counts + 1e-6) normalised by its mean.
present = counts > 0
weights = np.zeros(num_classes, dtype=np.float64)
weights[present] = 1.0 / np.sqrt(counts[present] + 1e-6)
# Scale to mean 1 over the present classes, then add a 0.5 floor to every class.
weights = weights / weights[present].mean() + 0.5

# Fall back to the CPU when no GPU is available instead of failing on .to('cuda').
weights_device = "cuda" if torch.cuda.is_available() else "cpu"
class_weights = torch.tensor(weights, dtype=torch.float32).to(weights_device)
criterion_cls = nn.CrossEntropyLoss(weight=class_weights)
print(weights)

In [None]:
# Wire the model, optimizer, LR scheduler and class-weighted loss into the
# project trainer, then run the fine-tuning loop for two epochs.
trainer = CustomClassifierTrainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion_cls,
)
trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=2,
)