In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from pytorch_metric_learning.losses import SupConLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------------- DATASET ----------------
class EMGDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.X = np.load(x_path)
        self.y = np.argmax(np.load(y_path), axis=1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.X[idx], dtype=torch.float32),
            torch.tensor(self.y[idx], dtype=torch.long)
        )

# ---------------- ENCODER ----------------
class CNNEncoder(nn.Module):
    def __init__(self, emb_dim=256):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv1d(6, 64, kernel_size=10),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=10),
            nn.ReLU(),
            nn.MaxPool1d(3),

            nn.Conv1d(64, 256, kernel_size=10),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=10),
            nn.ReLU(),

            nn.AdaptiveAvgPool1d(1)
        )

        self.fc = nn.Linear(256, emb_dim)

    def forward(self, x):
        x = x.permute(0, 2, 1)      # (B, C, T)
        x = self.features(x).squeeze(-1)
        z = self.fc(x)
        return z                   # (B, emb_dim)

# ---------------- CLASSIFIER ----------------
class Classifier(nn.Module):
    def __init__(self, emb_dim=256, num_classes=62):
        super().__init__()
        self.fc = nn.Linear(emb_dim, num_classes)

    def forward(self, z):
        return self.fc(z)

# ---------------- FULL MODEL ----------------
class FullModel(nn.Module):
    def __init__(self, emb_dim=256, num_classes=62):
        super().__init__()
        self.encoder = CNNEncoder(emb_dim)
        self.classifier = Classifier(emb_dim, num_classes)

    def forward(self, x):
        z = self.encoder(x)
        logits = self.classifier(z)
        return z, logits

# ---------------- EARLY STOPPING ----------------
class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
        self.best_acc = -1
        self.counter = 0
        self.best_state = None

    def step(self, acc, model):
        if acc > self.best_acc:
            self.best_acc = acc
            self.best_state = {k: v.cpu() for k, v in model.state_dict().items()}
            self.counter = 0
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience

    def restore(self, model):
        model.load_state_dict(self.best_state)

# ---------------- TRAIN + EVAL ----------------
def train_and_eval(model, train_loader, test_loader,
                   supcon_loss, alpha=1.0,
                   epochs=500, patience=10):

    ce_loss = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=5e-4)
    stopper = EarlyStopping(patience)

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            z, logits = model(x)

            z_norm = F.normalize(z, dim=1)   # ðŸ”¥ REQUIRED for SupCon

            loss = ce_loss(logits, y) + alpha * supcon_loss(z_norm, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ---- EVAL ----
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                _, logits = model(x)
                pred = logits.argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        acc = correct / total
        print(f"Epoch {epoch+1:02d} | Val Acc: {acc:.4f}")

        if stopper.step(acc, model):
            print("Early stopping")
            break

    stopper.restore(model)
    return stopper.best_acc

# ---------------- MAIN LOOP ----------------
BASE = "models/Data/Data/62_classes/UserDependenet"
supcon_loss = SupConLoss(temperature=0.1)

all_accs = []

for split in range(1, 11):
    print(f"\n===== SPLIT {split} =====")

    train_ds = EMGDataset(
        f"{BASE}/Train/X_train_{split}.npy",
        f"{BASE}/Train/y_train_{split}.npy"
    )
    test_ds = EMGDataset(
        f"{BASE}/Test/X_test_{split}.npy",
        f"{BASE}/Test/y_test_{split}.npy"
    )

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=128, shuffle=False)

    model = FullModel(emb_dim=256, num_classes=62).to(device)

    acc = train_and_eval(
        model,
        train_loader,
        test_loader,
        supcon_loss,
        alpha=1.0,
        epochs=500,
        patience=10
    )

    print(f"Split {split} Accuracy: {acc:.4f}")
    all_accs.append(acc)

print("\nFINAL AVERAGE ACCURACY:", np.mean(all_accs))

Device: cuda

===== SPLIT 1 =====
Epoch 01 | Val Acc: 0.0331
Epoch 02 | Val Acc: 0.2548
Epoch 03 | Val Acc: 0.4024
Epoch 04 | Val Acc: 0.5000
Epoch 05 | Val Acc: 0.5976
Epoch 06 | Val Acc: 0.6266
Epoch 07 | Val Acc: 0.6460
Epoch 08 | Val Acc: 0.6984
Epoch 09 | Val Acc: 0.7185
Epoch 10 | Val Acc: 0.7024
Epoch 11 | Val Acc: 0.7105
Epoch 12 | Val Acc: 0.7427
Epoch 13 | Val Acc: 0.7589
Epoch 14 | Val Acc: 0.7589
Epoch 15 | Val Acc: 0.7476
Epoch 16 | Val Acc: 0.7605
Epoch 17 | Val Acc: 0.7645
Epoch 18 | Val Acc: 0.7766
Epoch 19 | Val Acc: 0.7839
Epoch 20 | Val Acc: 0.7742
Epoch 21 | Val Acc: 0.7782
Epoch 22 | Val Acc: 0.7710
Epoch 23 | Val Acc: 0.7581
Epoch 24 | Val Acc: 0.7694
Epoch 25 | Val Acc: 0.7790
Epoch 26 | Val Acc: 0.7782
Epoch 27 | Val Acc: 0.7815
Epoch 28 | Val Acc: 0.7798
Epoch 29 | Val Acc: 0.7895
Epoch 30 | Val Acc: 0.7677
Epoch 31 | Val Acc: 0.7790
Epoch 32 | Val Acc: 0.7790
Epoch 33 | Val Acc: 0.7766
Epoch 34 | Val Acc: 0.7782
Epoch 35 | Val Acc: 0.7782
Epoch 36 | Val Acc: 0

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from pytorch_metric_learning.losses import SupConLoss

# ---------------- Setup ----------------
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------------- DATASET ----------------
class EMGDataset(Dataset):
    def __init__(self, x_path, y_path):
        X = np.load(x_path)
        y = np.argmax(np.load(y_path), axis=1)

        # channel-wise normalization (MATCHES KERAS)
        for i in range(X.shape[0]):
            for c in range(6):
                X[i, :, c] = (X[i, :, c] - X[i, :, c].mean()) / (X[i, :, c].std() + 1e-8)

        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# ---------------- CNN + BiLSTM ENCODER ----------------
class CNN_BiLSTM_Encoder(nn.Module):
    def __init__(self, emb_dim=256):
        super().__init__()

        # CNN (same as your Keras)
        self.cnn = nn.Sequential(
            nn.Conv1d(6, 64, kernel_size=10),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=10),
            nn.ReLU(),
            nn.MaxPool1d(3),

            nn.Conv1d(64, 256, kernel_size=10),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=10),
            nn.ReLU()
        )

        # BiLSTM (same idea as Keras)
        self.bilstm = nn.LSTM(
            input_size=256,
            hidden_size=128,
            batch_first=True,
            bidirectional=True
        )

        self.fc = nn.Linear(256, emb_dim)

    def forward(self, x):
        x = x.permute(0, 2, 1)        # (B, C, T)
        x = self.cnn(x)               # (B, 256, T')
        x = x.permute(0, 2, 1)        # (B, T', 256)

        lstm_out, _ = self.bilstm(x)  # (B, T', 256)
        feat = lstm_out[:, -1, :]     # last timestep

        z = self.fc(feat)             # (B, emb_dim)
        return z

# ---------------- CLASSIFIER ----------------
class Classifier(nn.Module):
    def __init__(self, emb_dim=256, num_classes=62):
        super().__init__()
        self.fc = nn.Linear(emb_dim, num_classes)

    def forward(self, z):
        return self.fc(z)

# ---------------- FULL MODEL ----------------
class FullModel(nn.Module):
    def __init__(self, emb_dim=256, num_classes=62):
        super().__init__()
        self.encoder = CNN_BiLSTM_Encoder(emb_dim)
        self.classifier = Classifier(emb_dim, num_classes)

    def forward(self, x):
        z = self.encoder(x)
        logits = self.classifier(z)
        return z, logits

# ---------------- EARLY STOPPING ----------------
class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
        self.best_acc = -1
        self.counter = 0
        self.best_state = None

    def step(self, acc, model):
        if acc > self.best_acc:
            self.best_acc = acc
            self.best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            self.counter = 0
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience

    def restore(self, model):
        model.load_state_dict(self.best_state)

# ---------------- TRAIN + EVAL ----------------
def train_and_eval(model, train_loader, test_loader,
                   supcon_loss, alpha=1.0,
                   epochs=500, patience=10):

    ce_loss = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=5e-4)
    stopper = EarlyStopping(patience)

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            z, logits = model(x)
            z = F.normalize(z, dim=1)  # REQUIRED FOR SUPCON

            loss = ce_loss(logits, y) + alpha * supcon_loss(z, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # -------- Validation --------
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                _, logits = model(x)
                pred = logits.argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        acc = correct / total
        print(f"Epoch {epoch+1:03d} | Val Acc: {acc:.4f}")

        if stopper.step(acc, model):
            print("Early stopping")
            break

    stopper.restore(model)
    return stopper.best_acc

# ---------------- MAIN LOOP ----------------
BASE = "models/Data/Data/62_classes/UserDependenet"
supcon_loss = SupConLoss(temperature=0.1)

all_accs = []

for split in range(1, 11):
    print(f"\n===== SPLIT {split} =====")

    train_ds = EMGDataset(
        f"{BASE}/Train/X_train_{split}.npy",
        f"{BASE}/Train/y_train_{split}.npy"
    )
    test_ds = EMGDataset(
        f"{BASE}/Test/X_test_{split}.npy",
        f"{BASE}/Test/y_test_{split}.npy"
    )

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=128)

    model = FullModel(emb_dim=256, num_classes=62).to(device)

    acc = train_and_eval(
        model,
        train_loader,
        test_loader,
        supcon_loss,
        alpha=1.0,
        epochs=500,
        patience=10
    )

    print(f"Split {split} Accuracy: {acc:.4f}")
    all_accs.append(acc)

    del model
    torch.cuda.empty_cache()

print("\nFINAL AVERAGE ACCURACY:", np.mean(all_accs))

Device: cuda

===== SPLIT 1 =====
Epoch 001 | Val Acc: 0.0218
Epoch 002 | Val Acc: 0.0605
Epoch 003 | Val Acc: 0.1000
Epoch 004 | Val Acc: 0.1282
Epoch 005 | Val Acc: 0.1589
Epoch 006 | Val Acc: 0.1992
Epoch 007 | Val Acc: 0.2710
Epoch 008 | Val Acc: 0.2839
Epoch 009 | Val Acc: 0.3847
Epoch 010 | Val Acc: 0.4298
Epoch 011 | Val Acc: 0.4677
Epoch 012 | Val Acc: 0.5185
Epoch 013 | Val Acc: 0.6024
Epoch 014 | Val Acc: 0.6048
Epoch 015 | Val Acc: 0.6226
Epoch 016 | Val Acc: 0.6492
Epoch 017 | Val Acc: 0.6492
Epoch 018 | Val Acc: 0.6669
Epoch 019 | Val Acc: 0.6532
Epoch 020 | Val Acc: 0.6863
Epoch 021 | Val Acc: 0.6790
Epoch 022 | Val Acc: 0.6798
Epoch 023 | Val Acc: 0.6887
Epoch 024 | Val Acc: 0.6855
Epoch 025 | Val Acc: 0.6710
Epoch 026 | Val Acc: 0.7048
Epoch 027 | Val Acc: 0.7000
Epoch 028 | Val Acc: 0.7097
Epoch 029 | Val Acc: 0.7097
Epoch 030 | Val Acc: 0.7113
Epoch 031 | Val Acc: 0.7089
Epoch 032 | Val Acc: 0.6944
Epoch 033 | Val Acc: 0.7218
Epoch 034 | Val Acc: 0.7177
Epoch 035 | Va