In [1]:
# Helper: single-number accuracy on a loader
def evaluate(loader):
    model.eval()
    correct = 0
    total   = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(2)
            mask  = y != -1
            correct += ((preds == y) & mask).sum().item()
            total   += mask.sum().item()
    return correct / total if total > 0 else 0.0


In [2]:
import torch, numpy as np, matplotlib.pyplot as plt
from torch import nn

class MultiLineMLP(nn.Module):
    def __init__(self, input_dim=6144, hidden_dims=[1024,1024,512,512,256],
                 num_lines=60, num_classes=6, p_drop=0.3):
        super().__init__()
        self.bn_in = nn.BatchNorm1d(input_dim)
        layers, dims = [], [input_dim]+hidden_dims
        for a,b in zip(dims,dims[1:]):
            layers += [nn.Linear(a,b), nn.BatchNorm1d(b), nn.ReLU(), nn.Dropout(p_drop)]
        self.shared = nn.Sequential(*layers)
        self.classifier = nn.Linear(hidden_dims[-1], num_lines*num_classes)
    def forward(self,x):
        x = self.bn_in(x)
        x = self.shared(x)
        return self.classifier(x).view(x.size(0),-1,6)

device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt_path = "best_resampled.pt"        # choose file
model     = MultiLineMLP().to(device)
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()
print("Loaded", ckpt_path)

# define evaluate(), macro_precision_recall() here or import them


RuntimeError: Error(s) in loading state_dict for MultiLineMLP:
	Missing key(s) in state_dict: "bn_in.weight", "bn_in.bias", "bn_in.running_mean", "bn_in.running_var", "shared.0.weight", "shared.0.bias", "shared.1.weight", "shared.1.bias", "shared.1.running_mean", "shared.1.running_var", "shared.4.weight", "shared.4.bias", "shared.5.weight", "shared.5.bias", "shared.5.running_mean", "shared.5.running_var", "shared.8.weight", "shared.8.bias", "shared.9.weight", "shared.9.bias", "shared.9.running_mean", "shared.9.running_var", "shared.12.weight", "shared.12.bias", "shared.13.weight", "shared.13.bias", "shared.13.running_mean", "shared.13.running_var", "shared.16.weight", "shared.16.bias", "shared.17.weight", "shared.17.bias", "shared.17.running_mean", "shared.17.running_var", "classifier.weight", "classifier.bias". 
	Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "history". 

In [None]:
# 11. Post‑training Analysis Helpers (for paper‑style figures)
import matplotlib.pyplot as plt

## 11a. Training curves (stored during loop)
# Make sure train_loss_hist and val_acc_hist were populated above
# If not, uncomment these two lists at the very top of the loop:
# train_loss_hist, val_acc_hist = [], []
# inside epoch loop: train_loss_hist.append(avg_loss); val_acc_hist.append(val_acc)

if 'train_loss_hist' in globals() and 'val_acc_hist' in globals():
    epochs = range(1, len(train_loss_hist)+1)
    fig, ax1 = plt.subplots()
    ax1.plot(epochs, train_loss_hist, label='Train Loss')
    ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss')
    ax2 = ax1.twinx()
    ax2.plot(epochs, val_acc_hist, color='tab:red', label='Val Acc')
    ax2.set_ylabel('Accuracy')
    fig.legend(loc='upper right'); plt.title('Training vs Validation'); plt.show()
else:
    print('Run training loop with train_loss_hist / val_acc_hist lists to get curve plot')


In [None]:
## 11b. Table‑2 Precision/Recall (macro within‑one)
prec, rec = macro_precision_recall(model, val_loader, device)
print("
Table 2 – Precision & Recall (Exact / ±1)")
print("Class  ExactP  ±1P  ExactR  ±1R")
for c in range(6):
    print(f"  {c}    {prec[c,0]:.3f}  {prec[c,1]:.3f}  {rec[c,0]:.3f}  {rec[c,1]:.3f}")


In [None]:
## 11c. Figure‑1 distribution of activity levels
true_cnt = np.zeros(6, dtype=int)
pred_cnt = np.zeros(6, dtype=int)
model.eval()
with torch.no_grad():
    for x, y in val_loader:
        x = x.to(device)
        p = model(x).argmax(2).cpu().numpy(); y = y.numpy()
        m = y != -1
        for c in range(6):
            true_cnt[c] += ((y==c)&m).sum()
            pred_cnt[c] += ((p==c)&m).sum()
levels = ['Inactive','Weak','Mild','Active','Potent','Super']
idx = np.arange(6)
plt.figure();
plt.bar(idx-0.2, pred_cnt, width=0.4, label='Predicted')
plt.bar(idx+0.2, true_cnt, width=0.4, label='True')
plt.xticks(idx, levels, rotation=45); plt.ylabel('Count');
plt.title('Figure 1 – Activity‑level distribution (Val set)'); plt.legend(); plt.tight_layout(); plt.show()
