In [2]:
# Imports
import os, sys, subprocess, shutil, random, warnings, timm
import numpy as np, librosa, soundfile as sf
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt, seaborn as sns
from tqdm import tqdm
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

def pip_install(pkgs):
    """Install or upgrade packages with pip under the running interpreter.

    Executes ``<sys.executable> -m pip install -q --upgrade <pkgs...>``.

    Args:
        pkgs: List of pip requirement strings; it is concatenated onto the
            base command list.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status
            (the call uses ``check=True``).
    """
    base_cmd = [sys.executable, "-m", "pip", "install", "-q", "--upgrade"]
    full_cmd = base_cmd + pkgs
    subprocess.run(full_cmd, check=True)

# Deps
# NOTE(review): the import statements at the top of this cell run before these
# installs, so a package that is missing (or upgraded here) only takes effect
# on the next interpreter/kernel start — confirm this ordering is intended.
pip_install([
    "librosa==0.10.1",
    "soundfile",
    "scikit-learn",
    "matplotlib",
    "pandas",
    "tqdm",
    "seaborn",
    "timm",
])

# Try the CUDA 12.1 wheel index first, then CUDA 11.8. A non-zero pip exit
# status moves on to the next index instead of raising.
_torch_pkgs = ["torch", "torchvision", "torchaudio"]
for _index_url in (
    "https://download.pytorch.org/whl/cu121",
    "https://download.pytorch.org/whl/cu118",
):
    rc = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *_torch_pkgs, "--index-url", _index_url]
    ).returncode
    if rc == 0:
        break
else:
    # Neither CUDA index succeeded: fall back to the default package index.
    # Unlike the attempts above, pip_install raises if this one fails too.
    pip_install(_torch_pkgs)

# Dataset
# Always start from a fresh checkout: any existing "Keystroke-Datasets" path is
# removed before cloning, so local changes inside it are lost on every run.
if os.path.exists("Keystroke-Datasets"):
    shutil.rmtree("Keystroke-Datasets")
# check=True makes a failed clone raise subprocess.CalledProcessError.
subprocess.run(["git","clone","https://github.com/JBFH-Dev/Keystroke-Datasets.git"], check=True)

# Seeds & determinism
# One seed is applied to Python's `random`, NumPy and torch.
SEED=42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark=False

# Directories searched for one "<key>.wav" recording per entry in KEYS.
DATASET_DIRS = ['Keystroke-Datasets/Zoom']
# The 36 class labels: ten digits followed by the 26 lowercase letters.
KEYS = list("1234567890qwertyuiopasdfghjklzxcvbnm")

# Audio slice & mels
TARGET_SR = 44100                                # sample rate audio is loaded at
TARGET_LEN_S = 0.28                              # length of each cropped keystroke, in seconds
TARGET_SAMPLES = int(TARGET_SR * TARGET_LEN_S)   # same length expressed in samples
# Mel-spectrogram settings passed to librosa in mel_3ch.
N_MELS, N_FFT, HOP = 64, 1024, 255
FMIN, FMAX = 50, 8000
# IMG_W is the number of time frames mel_3ch pads/crops to; IMG_H is not
# referenced elsewhere in the visible code.
IMG_H = IMG_W = 64

# Train setup
EPOCHS = 120         # upper bound; the training loop may stop early
BATCH = 64           # batch size for both data loaders
LR_MAX = 7e-4        # peak learning rate of the one-cycle schedule
WD = 1e-4            # AdamW weight decay
LABEL_SMOOTH = 0.0   # (changed from 0.02)
OUT_DIR = "ckpts_convnext_s_zoom"   # checkpoint and confusion-matrix output directory
os.makedirs(OUT_DIR, exist_ok=True)

# Use the GPU when one is visible to torch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Keystroke isolator
def isolator(signal, sr, n_fft=48, hop=24, before=2400, after=12000, threshold=0.06, min_gap_sec=0.1):
    """Cut windows out of ``signal`` around frames whose energy exceeds a threshold.

    A short-window STFT is computed, the complex bins of every frame are
    summed, and the magnitude of that sum serves as the per-frame energy.
    Frames above ``threshold`` are candidates; a candidate is kept only when
    its sample position lies more than ``min_gap_sec`` seconds past the end of
    the previously kept window.

    Args:
        signal: 1-D audio array.
        sr: Sample rate, used to convert ``min_gap_sec`` to samples.
        n_fft: STFT window size.
        hop: STFT hop length in samples.
        before: Samples kept before the detected position.
        after: Samples kept after the detected position.
        threshold: Energy level a frame must exceed to count as a candidate.
        min_gap_sec: Minimum gap, in seconds, between the end of one kept
            window and the next detected position.

    Returns:
        List of slices of ``signal``, each clipped to the signal bounds (so
        windows near either edge can be shorter than ``before + after``).
    """
    spectrum = librosa.stft(signal, n_fft=n_fft, hop_length=hop)
    frame_energy = np.abs(np.sum(spectrum, axis=0)).astype(float)
    candidate_frames = np.where(frame_energy > threshold)[0]

    min_gap = int(min_gap_sec * sr)
    segments = []
    # Start one gap before zero so the first candidate only needs a position > 0.
    last_end = -min_gap
    for frame in candidate_frames:
        position = (frame * hop) + n_fft // 2
        if position <= last_end + min_gap:
            continue  # still inside (or too close to) the previous window
        lo = max(0, position - before)
        hi = min(len(signal), position + after)
        segments.append(signal[lo:hi])
        last_end = hi
    return segments

def build_strokes(dataset_dirs, keys=KEYS, target_per_key=25, sr_target=TARGET_SR):
    """Extract up to ``target_per_key`` keystroke clips per key from each dataset.

    For every directory and key, ``<dir>/<key>.wav`` is loaded (missing files
    are skipped), the first 0.8 s is discarded, and ``isolator`` is called
    repeatedly while its threshold is nudged down (too few strokes) or up
    (too many) with a step that shrinks by 2% per attempt. The search stops
    once exactly ``target_per_key`` strokes are found or after 60 attempts.

    Args:
        dataset_dirs: Iterable of directory paths.
        keys: Key names; each maps to a ``<key>.wav`` file.
        target_per_key: Desired number of strokes per key. Extra strokes are
            dropped; a shortfall is accepted as-is.
        sr_target: Sample rate passed to ``librosa.load``.

    Returns:
        List of ``(key, stroke_samples, sample_rate)`` tuples.
    """
    rows = []
    for ds in dataset_dirs:
        dataset_name = os.path.basename(ds)
        for ch in tqdm(keys, desc=f"Extracting from {dataset_name}"):
            wav_path = os.path.join(ds, f"{ch}.wav")
            if not os.path.exists(wav_path):
                continue

            y, sr = librosa.load(wav_path, sr=sr_target)
            y = y[int(0.8 * sr):]  # drop head

            level, step = 0.06, 0.005
            found = []
            for _ in range(60):
                if len(found) == target_per_key:
                    break
                found = isolator(y, sr, threshold=level)
                if len(found) < target_per_key:
                    level = max(1e-5, level - step)
                elif len(found) > target_per_key:
                    level += step
                step *= 0.98

            if not found:
                continue
            rows.extend((ch, s, sr) for s in found[:target_per_key])
    return rows

# Run the extraction over every configured dataset directory; `strokes` is the
# list of (key, samples, sample_rate) tuples consumed by the feature step below.
print(">> Building strokes ...")
strokes=build_strokes(DATASET_DIRS)
print("Total strokes:", len(strokes))

# Center-crop around RMS peak
def center_crop_seg(y, sr):
    """Return a ``TARGET_SAMPLES``-long, peak-normalised crop around the loudest frame.

    The crop is centred on the frame with the highest RMS; if that would run
    past the end of ``y`` the window is shifted back to end at the last
    sample. Clips shorter than ``TARGET_SAMPLES`` are zero-padded on the
    right. The result is divided by its maximum absolute value unless that
    value is zero.

    Args:
        y: 1-D audio array.
        sr: Sample rate (accepted for interface symmetry; not used here).

    Returns:
        1-D array of length ``TARGET_SAMPLES``.
    """
    rms_hop = 64
    envelope = librosa.feature.rms(y=y, frame_length=256, hop_length=rms_hop).flatten()
    peak_sample = int(np.argmax(envelope)) * rms_hop

    lo = max(0, peak_sample - TARGET_SAMPLES // 2)
    hi = lo + TARGET_SAMPLES
    if hi > len(y):
        # Window overruns the clip: anchor it to the end instead.
        hi = len(y)
        lo = max(0, hi - TARGET_SAMPLES)

    window = y[lo:hi]
    shortfall = TARGET_SAMPLES - len(window)
    if shortfall > 0:
        window = np.pad(window, (0, shortfall))

    peak_amp = np.max(np.abs(window))
    if peak_amp > 0:
        window = window / (peak_amp + 1e-8)
    return window

def mel_3ch(seg, sr):
    """Turn an audio segment into a 3-channel, per-channel-standardised feature map.

    Channels are the log-mel spectrogram (dB relative to its own maximum), its
    first-order delta and its second-order delta. The time axis is padded by
    repeating the last frame, or truncated, to exactly ``IMG_W`` frames, and
    each channel is then shifted to zero mean and scaled by its standard
    deviation (plus a small epsilon).

    Args:
        seg: 1-D audio array.
        sr: Sample rate of ``seg``.

    Returns:
        ``float32`` array of shape ``(3, N_MELS, IMG_W)``.
    """
    mel_power = librosa.feature.melspectrogram(
        y=seg, sr=sr, n_fft=N_FFT, hop_length=HOP,
        n_mels=N_MELS, fmin=FMIN, fmax=FMAX, power=2.0,
    )
    log_mel = librosa.power_to_db(mel_power, ref=np.max).astype(np.float32)
    first_delta = librosa.feature.delta(log_mel)
    second_delta = librosa.feature.delta(log_mel, order=2)
    feats = np.stack([log_mel, first_delta, second_delta], axis=0)  # (3, n_mels, T)

    n_frames = feats.shape[2]
    if n_frames >= IMG_W:
        feats = feats[:, :, :IMG_W]
    else:
        feats = np.pad(feats, ((0, 0), (0, 0), (0, IMG_W - n_frames)), mode="edge")

    # Standardise each channel independently; the deltas have a very different
    # scale from the dB spectrogram.
    for idx in range(3):
        channel = feats[idx]
        mean = channel.mean()
        scale = channel.std() + 1e-6
        feats[idx] = (channel - mean) / scale
    return feats.astype(np.float32)

# Build arrays
# Each stroke becomes one (3, n_mels, frames) feature map; labels are kept in
# the same order as the features.
X = np.array(
    [
        mel_3ch(center_crop_seg(audio, sr), sr)
        for _, audio, sr in tqdm(strokes, desc="Center-crop + 3ch mel")
    ],
    dtype=np.float32,
)
Y = np.array([label for label, _, _ in strokes])
print("X shape:", X.shape, "num classes:", len(np.unique(Y)))

# Encode/split
# Map key characters to integer class ids, then hold out 25% of the samples,
# stratified so every class keeps the same train/test proportion.
le = LabelEncoder()
le.fit(Y)
y_int = le.transform(Y)
num_classes = len(le.classes_)
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_int,
    test_size=0.25,
    stratify=y_int,
    random_state=SEED,
)
print("Train:", X_train.shape, " Test:", X_test.shape)

# Light SpecAugment
def specaug_light(x):
    """Apply light, randomised SpecAugment-style masking to a feature map.

    Works on a copy; the input array is never modified. Three independent
    augmentations are drawn from Python's ``random`` module:

    * with probability 0.25, zero a band of 3-7 consecutive positions along
      axis 2 (all channels, all rows);
    * with probability 0.25, zero a band of 3-7 consecutive positions along
      axis 1 (all channels, all columns);
    * with probability 0.20, circularly shift axis 2 by -3..3 positions.

    Mask positions are drawn from the input's own axis lengths, so a mask
    always lands inside the array whatever its size. For 64x64 inputs the
    sequence of random draws, and therefore the result for a given seed, is
    the same as with the previous hard-coded bound of 64.

    Args:
        x: 3-D NumPy array indexed as ``(channel, axis1, axis2)``.

    Returns:
        A new array with the same shape as ``x``.
    """
    out = x.copy()
    n_rows, n_cols = out.shape[1], out.shape[2]

    if random.random() < 0.25:
        # Clamp the band width so narrow inputs cannot produce an empty range.
        width = min(random.randint(3, 7), n_cols)
        start = random.randint(0, n_cols - width)
        out[:, :, start:start + width] = 0.0

    if random.random() < 0.25:
        height = min(random.randint(3, 7), n_rows)
        start = random.randint(0, n_rows - height)
        out[:, start:start + height, :] = 0.0

    if random.random() < 0.20:
        shift = random.randint(-3, 3)
        out = np.roll(out, shift, axis=2)

    return out

class KSDataset(Dataset):
    """Map-style dataset pairing feature arrays with integer class labels.

    Args:
        X: Indexable collection of NumPy feature arrays.
        y: Indexable collection of integer labels, aligned with ``X``.
        augment: When true, every fetched sample is passed through
            ``specaug_light`` before conversion to a tensor.
    """

    def __init__(self, X, y, augment=False):
        self.X = X
        self.y = y
        self.augment = augment

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        """Return ``(float32 feature tensor, int64 label tensor)`` for index ``i``."""
        sample = specaug_light(self.X[i]) if self.augment else self.X[i]
        features = torch.from_numpy(sample).float()
        target = torch.tensor(self.y[i], dtype=torch.long)
        return features, target

# Training samples are augmented on the fly; the held-out split is not.
# NOTE(review): the "validation" dataset is built from the test split
# (X_test / y_test), so model selection and final reporting share the same data.
train_ds = KSDataset(X_train, y_train, augment=True)
val_ds   = KSDataset(X_test,  y_test,  augment=False)

# Class-balanced sampling: each training sample is weighted by the inverse of
# its class count (counts are clipped to at least 1 to avoid division by zero),
# and one epoch draws len(sample_w) samples with replacement.
counts = np.bincount(y_train)
class_w = 1.0/np.clip(counts,1,None)
sample_w = class_w[y_train]
sampler = WeightedRandomSampler(sample_w, num_samples=len(sample_w), replacement=True)

train_loader = DataLoader(train_ds, batch_size=BATCH, sampler=sampler, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)

# Model: ConvNeXt-Small
# Pretrained timm backbone with a classification head sized to num_classes.
model = timm.create_model("convnext_small", pretrained=True, in_chans=3, num_classes=num_classes).to(device)

# Loss/opt/sched
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR_MAX, weight_decay=WD)
# One-cycle schedule sized for EPOCHS * len(train_loader) steps, so it must be
# stepped once per batch. The LR starts at LR_MAX / div_factor, rises for the
# first 15% of the steps, then anneals along a cosine curve.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LR_MAX, epochs=EPOCHS, steps_per_epoch=len(train_loader),
    pct_start=0.15, anneal_strategy='cos', div_factor=10.0, final_div_factor=1e3
)
# Gradient scaling for mixed precision; disabled when not running on CUDA.
scaler = torch.amp.GradScaler('cuda', enabled=(device.type=="cuda"))
# Early-stopping state: best validation accuracy so far, the number of
# non-improving epochs tolerated, and the current run of non-improving epochs.
best_acc, patience, no_imp = 0.0, 20, 0
# Mixed precision is only used on CUDA, mirroring the GradScaler's `enabled`
# flag. Requesting CUDA autocast unconditionally makes torch warn on CPU-only
# hosts.
use_amp = (device.type == "cuda")
for epoch in range(1, EPOCHS+1):
    model.train()
    tot, corr, runloss = 0, 0, 0.0
    for xb, yb in tqdm(train_loader, desc=f"Train E{epoch:03d}"):
        xb, yb = xb.to(device, dtype=torch.float32), yb.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            out = model(xb)
            loss = criterion(out, yb)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # the one-cycle schedule advances once per batch
        # Accumulate sample-weighted loss so the epoch mean is exact even when
        # the last batch is smaller.
        runloss += loss.item()*xb.size(0)
        with torch.no_grad():
            preds = out.argmax(1)
            corr += (preds==yb).sum().item()
            tot += yb.size(0)
    tr_loss, tr_acc = runloss/tot, corr/tot

    # val
    # NOTE(review): val_loader wraps the test split, so the checkpoint chosen
    # below is selected on the same data that is reported at the end.
    model.eval()
    v_tot, v_corr, v_loss = 0, 0, 0.0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device, dtype=torch.float32), yb.to(device)
            out = model(xb)
            loss = criterion(out, yb)
            v_loss += loss.item()*xb.size(0)
            v_corr += (out.argmax(1)==yb).sum().item()
            v_tot += yb.size(0)
    val_loss, val_acc = v_loss/v_tot, v_corr/v_tot
    print(f"Epoch {epoch:03d} | Train {tr_loss:.4f}/{tr_acc:.4f} | Val {val_loss:.4f}/{val_acc:.4f}")

    if val_acc > best_acc:
        # New best: reset the patience counter and overwrite the checkpoint.
        best_acc = val_acc
        no_imp = 0
        torch.save({'model_state': model.state_dict(), 'classes': list(map(str, le.classes_))},
                   os.path.join(OUT_DIR, "best.pt"))
        print("  -> Saved best:", best_acc)
    else:
        no_imp += 1
        if no_imp >= patience:
            print("Early stopping.")
            break

# Final eval (with TTA)
# Restore the weights of the best-scoring epoch (the file written by the
# training loop) rather than evaluating the last epoch's weights.
ck = torch.load(os.path.join(OUT_DIR, "best.pt"), map_location=device)
model.load_state_dict(ck['model_state'])
# Switch to inference mode before evaluation.
model.eval()

def tta_logits(model, xb, shifts=(-4, -2, -1, 0, 1, 2, 4)):
    """Average model outputs over circularly shifted copies of a batch.

    Each entry of ``shifts`` rolls ``xb`` along dimension 3 (the time axis of
    the ``(batch, channel, mel, time)`` tensors built in this file); a shift
    of 0 uses the batch unchanged. Everything runs under ``torch.no_grad()``.

    Args:
        model: Callable mapping a batch tensor to logits.
        xb: 4-D input batch.
        shifts: Roll offsets to evaluate; must not be empty.

    Returns:
        The element-wise mean of the outputs over all shifts.
    """
    with torch.no_grad():
        total = None
        for offset in shifts:
            shifted = torch.roll(xb, shifts=offset, dims=3) if offset != 0 else xb
            logits = model(shifted)
            if total is None:
                total = logits
            else:
                total = total + logits
        return total / len(shifts)

# Collect TTA predictions and ground-truth labels over the held-out loader.
all_p, all_t = [], []
with torch.no_grad():
    for xb, yb in val_loader:
        batch = xb.to(device, dtype=torch.float32)
        averaged = tta_logits(model, batch)      # TTA inference
        all_p.extend(averaged.argmax(1).cpu().numpy().tolist())
        all_t.extend(yb.numpy().tolist())

# Human-readable class names, reused for the report and both heatmap axes.
class_names = [str(c) for c in le.classes_]

acc = accuracy_score(all_t, all_p)
print(f"\nTTA accuracy: {acc:.4f}\n")
report = classification_report(all_t, all_p, target_names=class_names)
print("TTA classification report:\n", report)

# Confusion matrix rendered as a heatmap and written next to the checkpoints.
cm = confusion_matrix(all_t, all_p)
plt.figure(figsize=(12,10))
sns.heatmap(
    cm,
    annot=False,
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix (ConvNeXt-S, Zoom-only, TTA)")
plt.tight_layout()
cm_path = os.path.join(OUT_DIR, "cm_tta.png")
plt.savefig(cm_path, dpi=140)
print("Saved confusion matrix image to:", cm_path)
print("Checkpoints in:", OUT_DIR)

Device: cuda
>> Building strokes ...


Extracting from Zoom: 100%|██████████| 36/36 [00:02<00:00, 14.05it/s]


Total strokes: 900


Center-crop + 3ch mel: 100%|██████████| 900/900 [00:08<00:00, 109.85it/s]


X shape: (900, 3, 64, 64) num classes: 36
Train: (675, 3, 64, 64)  Test: (225, 3, 64, 64)


Train E001: 100%|██████████| 11/11 [00:01<00:00,  6.91it/s]


Epoch 001 | Train 3.8996/0.0341 | Val 3.6685/0.0267
  -> Saved best: 0.02666666666666667


Train E002: 100%|██████████| 11/11 [00:01<00:00, 10.28it/s]


Epoch 002 | Train 3.6368/0.0430 | Val 3.6131/0.0267


Train E003: 100%|██████████| 11/11 [00:01<00:00, 10.49it/s]


Epoch 003 | Train 3.5902/0.0444 | Val 3.6194/0.0267


Train E004: 100%|██████████| 11/11 [00:01<00:00, 10.46it/s]


Epoch 004 | Train 3.6079/0.0311 | Val 3.6038/0.0356
  -> Saved best: 0.035555555555555556


Train E005: 100%|██████████| 11/11 [00:01<00:00, 10.38it/s]


Epoch 005 | Train 3.5491/0.0400 | Val 3.6074/0.0444
  -> Saved best: 0.044444444444444446


Train E006: 100%|██████████| 11/11 [00:01<00:00, 10.43it/s]


Epoch 006 | Train 3.5863/0.0385 | Val 3.5345/0.0356


Train E007: 100%|██████████| 11/11 [00:01<00:00,  9.33it/s]


Epoch 007 | Train 3.4155/0.0622 | Val 3.3686/0.0444


Train E008: 100%|██████████| 11/11 [00:01<00:00,  7.82it/s]


Epoch 008 | Train 2.8943/0.1496 | Val 2.7971/0.1244
  -> Saved best: 0.12444444444444444


Train E009: 100%|██████████| 11/11 [00:01<00:00, 10.22it/s]


Epoch 009 | Train 2.6615/0.1822 | Val 2.4930/0.2444
  -> Saved best: 0.24444444444444444


Train E010: 100%|██████████| 11/11 [00:01<00:00, 10.25it/s]


Epoch 010 | Train 2.1085/0.3126 | Val 2.0743/0.3511
  -> Saved best: 0.3511111111111111


Train E011: 100%|██████████| 11/11 [00:01<00:00, 10.50it/s]


Epoch 011 | Train 1.9551/0.3570 | Val 1.9000/0.4044
  -> Saved best: 0.40444444444444444


Train E012: 100%|██████████| 11/11 [00:01<00:00, 10.40it/s]


Epoch 012 | Train 1.4461/0.5141 | Val 1.5892/0.4933
  -> Saved best: 0.49333333333333335


Train E013: 100%|██████████| 11/11 [00:01<00:00,  9.19it/s]


Epoch 013 | Train 1.5528/0.4948 | Val 1.4307/0.5511
  -> Saved best: 0.5511111111111111


Train E014: 100%|██████████| 11/11 [00:01<00:00,  8.04it/s]


Epoch 014 | Train 1.0351/0.6919 | Val 1.0891/0.6667
  -> Saved best: 0.6666666666666666


Train E015: 100%|██████████| 11/11 [00:01<00:00, 10.52it/s]


Epoch 015 | Train 0.8524/0.7452 | Val 0.9633/0.6844
  -> Saved best: 0.6844444444444444


Train E016: 100%|██████████| 11/11 [00:01<00:00, 10.30it/s]


Epoch 016 | Train 0.5661/0.8119 | Val 0.8751/0.7422
  -> Saved best: 0.7422222222222222


Train E017: 100%|██████████| 11/11 [00:01<00:00, 10.40it/s]


Epoch 017 | Train 0.5361/0.8444 | Val 0.7571/0.7867
  -> Saved best: 0.7866666666666666


Train E018: 100%|██████████| 11/11 [00:01<00:00,  9.92it/s]


Epoch 018 | Train 0.4082/0.8815 | Val 0.6859/0.7600


Train E019: 100%|██████████| 11/11 [00:01<00:00,  7.98it/s]


Epoch 019 | Train 0.3169/0.8933 | Val 0.7334/0.7911
  -> Saved best: 0.7911111111111111


Train E020: 100%|██████████| 11/11 [00:01<00:00, 10.49it/s]


Epoch 020 | Train 0.2425/0.9289 | Val 0.5242/0.8089
  -> Saved best: 0.8088888888888889


Train E021: 100%|██████████| 11/11 [00:01<00:00, 10.13it/s]


Epoch 021 | Train 0.2167/0.9437 | Val 0.5819/0.8489
  -> Saved best: 0.8488888888888889


Train E022: 100%|██████████| 11/11 [00:01<00:00,  9.48it/s]


Epoch 022 | Train 0.1307/0.9600 | Val 0.3886/0.8889
  -> Saved best: 0.8888888888888888


Train E023: 100%|██████████| 11/11 [00:01<00:00, 10.15it/s]


Epoch 023 | Train 0.1341/0.9644 | Val 0.4473/0.8667


Train E024: 100%|██████████| 11/11 [00:01<00:00, 10.44it/s]


Epoch 024 | Train 0.1507/0.9585 | Val 0.4766/0.8444


Train E025: 100%|██████████| 11/11 [00:01<00:00,  8.15it/s]


Epoch 025 | Train 0.1466/0.9556 | Val 0.4822/0.8622


Train E026: 100%|██████████| 11/11 [00:01<00:00,  9.40it/s]


Epoch 026 | Train 0.1445/0.9570 | Val 0.5564/0.8578


Train E027: 100%|██████████| 11/11 [00:01<00:00, 10.57it/s]


Epoch 027 | Train 0.1428/0.9585 | Val 0.4434/0.8533


Train E028: 100%|██████████| 11/11 [00:01<00:00, 10.46it/s]


Epoch 028 | Train 0.1152/0.9615 | Val 0.5921/0.8578


Train E029: 100%|██████████| 11/11 [00:01<00:00, 10.42it/s]


Epoch 029 | Train 0.1250/0.9556 | Val 0.4351/0.8756


Train E030: 100%|██████████| 11/11 [00:01<00:00, 10.46it/s]


Epoch 030 | Train 0.0760/0.9748 | Val 0.5985/0.8533


Train E031: 100%|██████████| 11/11 [00:01<00:00, 10.31it/s]


Epoch 031 | Train 0.1036/0.9763 | Val 0.5026/0.8489


Train E032: 100%|██████████| 11/11 [00:01<00:00, 10.40it/s]


Epoch 032 | Train 0.0650/0.9719 | Val 0.2818/0.9200
  -> Saved best: 0.92


Train E033: 100%|██████████| 11/11 [00:01<00:00,  8.82it/s]


Epoch 033 | Train 0.0446/0.9867 | Val 0.2063/0.9289
  -> Saved best: 0.9288888888888889


Train E034: 100%|██████████| 11/11 [00:01<00:00,  9.92it/s]


Epoch 034 | Train 0.0909/0.9674 | Val 0.4219/0.8711


Train E035: 100%|██████████| 11/11 [00:01<00:00, 10.27it/s]


Epoch 035 | Train 0.0625/0.9793 | Val 0.4239/0.8578


Train E036: 100%|██████████| 11/11 [00:01<00:00, 10.61it/s]


Epoch 036 | Train 0.0571/0.9807 | Val 0.4784/0.8756


Train E037: 100%|██████████| 11/11 [00:01<00:00, 10.51it/s]


Epoch 037 | Train 0.0927/0.9733 | Val 0.4474/0.8711


Train E038: 100%|██████████| 11/11 [00:01<00:00, 10.22it/s]


Epoch 038 | Train 0.0504/0.9822 | Val 0.3274/0.8933


Train E039: 100%|██████████| 11/11 [00:01<00:00,  6.53it/s]


Epoch 039 | Train 0.0252/0.9911 | Val 0.3407/0.8978


Train E040: 100%|██████████| 11/11 [00:01<00:00,  7.50it/s]


Epoch 040 | Train 0.0467/0.9881 | Val 0.3608/0.9067


Train E041: 100%|██████████| 11/11 [00:01<00:00, 10.14it/s]


Epoch 041 | Train 0.0499/0.9867 | Val 0.3127/0.9244


Train E042: 100%|██████████| 11/11 [00:01<00:00, 10.35it/s]


Epoch 042 | Train 0.0560/0.9852 | Val 0.3710/0.8800


Train E043: 100%|██████████| 11/11 [00:01<00:00,  9.90it/s]


Epoch 043 | Train 0.0374/0.9881 | Val 0.3893/0.8978


Train E044: 100%|██████████| 11/11 [00:01<00:00, 10.19it/s]


Epoch 044 | Train 0.0660/0.9793 | Val 0.6549/0.8711


Train E045: 100%|██████████| 11/11 [00:01<00:00, 10.19it/s]


Epoch 045 | Train 0.0781/0.9733 | Val 0.5190/0.8711


Train E046: 100%|██████████| 11/11 [00:01<00:00, 10.11it/s]


Epoch 046 | Train 0.0640/0.9719 | Val 0.4750/0.8756


Train E047: 100%|██████████| 11/11 [00:01<00:00,  9.56it/s]


Epoch 047 | Train 0.0498/0.9837 | Val 0.6237/0.8133


Train E048: 100%|██████████| 11/11 [00:01<00:00,  7.14it/s]


Epoch 048 | Train 0.0495/0.9837 | Val 0.4606/0.8844


Train E049: 100%|██████████| 11/11 [00:01<00:00, 10.18it/s]


Epoch 049 | Train 0.0213/0.9941 | Val 0.4475/0.8933


Train E050: 100%|██████████| 11/11 [00:01<00:00, 10.14it/s]


Epoch 050 | Train 0.0128/0.9970 | Val 0.4989/0.8978


Train E051: 100%|██████████| 11/11 [00:01<00:00, 10.16it/s]


Epoch 051 | Train 0.0192/0.9926 | Val 0.5763/0.8800


Train E052: 100%|██████████| 11/11 [00:01<00:00, 10.11it/s]


Epoch 052 | Train 0.0389/0.9926 | Val 0.4276/0.8978


Train E053: 100%|██████████| 11/11 [00:01<00:00, 10.04it/s]


Epoch 053 | Train 0.0442/0.9867 | Val 0.3954/0.8933
Early stopping.

TTA accuracy: 0.8933

TTA classification report:
               precision    recall  f1-score   support

           0       0.86      1.00      0.92         6
           1       0.86      1.00      0.92         6
           2       0.86      1.00      0.92         6
           3       0.83      0.83      0.83         6
           4       0.62      0.83      0.71         6
           5       1.00      0.67      0.80         6
           6       1.00      0.86      0.92         7
           7       1.00      0.57      0.73         7
           8       1.00      0.83      0.91         6
           9       1.00      0.83      0.91         6
           a       1.00      0.86      0.92         7
           b       1.00      1.00      1.00         6
           c       1.00      0.83      0.91         6
           d       0.86      1.00      0.92         6
           e       1.00      0.86      0.92         7
           f    