<div style="background-color:rgb(0, 55, 207); padding: 30px; border-radius: 20px; box-shadow: 0 4px 15px rgba(105, 195, 255, 0.3); color:rgb(187, 201, 248); font-family: 'Times New Roman', serif;">

<h1 style="text-align: center; font-size: 38px; color: white; font-weight: bold;">Fine-Tuning ST-GCN</h1>

<h3 style="font-size: 22px; color: white; font-weight: bold;">Libraries</h3>

</div>

In [None]:
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import DataLoader

from feeder.feeder import Feeder
from main import init_recognition


<h3 style="font-size: 22px; color: white; font-weight: bold;">Create Dataset</h3>

In [None]:
# Read the YAML experiment configuration into a plain Python object.
with open('config/st_gcn/mediapipe-asl.yaml', 'r', encoding='utf-8') as config_file:
    cfg = yaml.safe_load(config_file)

# Build the training dataset from the feeder arguments stored in the config,
# then wrap it in a DataLoader that visits samples in stored order (no shuffle).
feeder_kwargs = cfg['train_feeder_args']
dataset = Feeder(**feeder_kwargs)
loader = DataLoader(dataset, shuffle=False, batch_size=cfg['train_batch_size'])

# Report dataset size, configured batch size and the resulting batch count.
print(f"🧪 Dataset contains {len(dataset)} samples")
print(f"📦 Each batch contains {cfg['train_batch_size']} samples")
print(f"🔁 Total batches = {len(loader)}")

<h3 style="font-size: 22px; color: white; font-weight: bold;">Fine Tuning the Model</h3>

In [None]:
config_path = 'config/st_gcn/mediapipe-asl.yaml'
weights_path = "./weights/fine_tuned_model.pt"

# Make sure the output folder exists *before* the (potentially long) run, so
# a missing directory cannot make torch.save fail after all the work is done.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)

# init_recognition is defined in main.py (not visible here); presumably it
# builds the processor and runs the fine-tuning described by the config —
# TODO confirm against main.py.
processor = init_recognition(config_path)

# Save your fine-tuned model (weights only, as a state_dict)
torch.save(processor.model.state_dict(), weights_path)

<h3 style="font-size: 22px; color: white; font-weight: bold;">Model Evaluation</h3>

In [None]:
# === 1. Set paths ===
# STGCN_PATH is a placeholder value; it is appended to sys.path further down
# so that `model.st_gcn` can be imported.
STGCN_PATH = r"YOUR_MAIN_DIRECTORY"
MODEL_PATH = "./weights/fine_tuned_model.pt"
VAL_DATA_PATH = './data/mediapipe_asl/val_data.npy'
VAL_LABEL_PATH = './data/mediapipe_asl/val_label.pkl' 
EXCEL_PATH = r"LABELS.xlsx"
# NOTE(review): MAP_LABELS_PATH is defined but not read by any visible cell.
MAP_LABELS_PATH = "./data/mediapipe_asl/label_mapping.pkl"

# === 2. Load class names and video_id to gloss mapping ===
df = pd.read_excel(EXCEL_PATH)

# Map each video id (as a string, left-padded with zeros to 5 characters) to
# its gloss. If an id occurs more than once, the last row wins.
video_id_to_gloss = dict(
    (str(row['video_id']).zfill(5), row['gloss'])
    for _, row in df.iterrows()
)

# Class index == position of the gloss in the sorted list of distinct glosses.
CLASS_NAMES = sorted(df['gloss'].unique())
NUM_CLASSES = len(CLASS_NAMES)
print(f"✅ Loaded {NUM_CLASSES} classes.")

# === 3. Set device ===
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"✅ Using device: {device}")

# === 4. Import model ===
# The import has to come after STGCN_PATH is on sys.path, which is why it
# lives here rather than with the other imports.
sys.path.append(STGCN_PATH)
from model.st_gcn import Model

# Constructor arguments for the network. in_channels=6 matches the input
# assembled in the prediction loop, where position and velocity are
# concatenated along the channel axis.
model_kwargs = {
    "in_channels": 6,
    "num_class": NUM_CLASSES,
    "num_point": 21,
    "num_person": 1,
    "graph": "graph.mediapipe_asl.Graph",
    "graph_args": {"layout": "mediapipe_asl", "strategy": "spatial"},
}
model = Model(**model_kwargs)

# Restore the fine-tuned weights onto the selected device and switch the
# network to evaluation mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model = model.to(device)
model.eval()
print("✅ Model loaded successfully.")

# === 5. Load validation data and labels ===
val_data = np.load(VAL_DATA_PATH)

# NOTE(review): pickle.load can execute arbitrary code — only open label
# files that come from a trusted source.
with open(VAL_LABEL_PATH, 'rb') as label_file:
    val_label_data = pickle.load(label_file)

# The pickle holds either a tuple whose first element is the list of entries,
# or the list itself; normalise both layouts to `label_list`.
label_list = val_label_data[0] if isinstance(val_label_data, tuple) else val_label_data

# Every sample needs exactly one label entry.
assert len(val_data) == len(label_list), "Mismatch between val_data and label_list lengths"
print(f"✅ Loaded {len(label_list)} validation samples.")

# === 6. Predict all samples ===
# Runs the model on every validation sample, one at a time, and collects the
# ground-truth class index (y_true) and the predicted class index (y_pred).
y_true = []
y_pred = []

# Gloss -> class index lookup, built once. The previous per-sample
# CLASS_NAMES.index(...) call rescanned the whole class list for each sample.
# NOTE(review): ground truth is derived from the sorted gloss order of the
# spreadsheet; this presumably matches the label encoding used during
# fine-tuning (MAP_LABELS_PATH is defined but never read) — confirm.
gloss_to_index = {gloss: index for index, gloss in enumerate(CLASS_NAMES)}

print("Predicting...")

with torch.no_grad():
    for idx, raw_sample in enumerate(val_data):
        # Drop the trailing singleton axis.
        data_sample = np.squeeze(raw_sample, axis=-1)  # (3, T, V)

        # Create velocity: frame-to-frame difference along the time axis.
        # The last frame has no successor, so it is dropped from the position
        # channels to keep both halves the same length.
        if data_sample.shape[1] > 1:
            velocity = data_sample[:, 1:, :] - data_sample[:, :-1, :]
            position = data_sample[:, :-1, :]
        else:
            # A single frame has no motion: use zero velocity.
            position = data_sample
            velocity = np.zeros_like(data_sample)

        # Stack position and velocity on the channel axis and add a batch axis.
        data_sample = np.concatenate((position, velocity), axis=0)
        data_sample = torch.tensor(data_sample, dtype=torch.float32).unsqueeze(0).to(device)

        filename = label_list[idx]             # e.g., '69206.npy'
        video_id = filename.replace('.npy', '')  # '69206'
        true_gloss = video_id_to_gloss[video_id]  # lookup correct gloss
        true_index = gloss_to_index[true_gloss]  # map gloss to index

        output = model(data_sample)
        _, predicted = torch.max(output, dim=1)  # index of the highest score

        y_true.append(true_index)
        y_pred.append(predicted.item())

print("✅ Prediction complete.")

# === 7. Plot confusion matrix ===
print("Plotting confusion matrix...")

# Pin the label set to every class index so the matrix is always
# len(CLASS_NAMES) x len(CLASS_NAMES). Without `labels`, scikit-learn keeps
# only the classes that occur in y_true or y_pred, so a class missing from
# both shrinks the matrix and it no longer lines up with display_labels.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASS_NAMES))))

fig, ax = plt.subplots(figsize=(18, 18))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASS_NAMES)
disp.plot(ax=ax, cmap='Blues', xticks_rotation=90)
# Set the title on the matrix axes explicitly rather than on whichever axes
# pyplot currently treats as active.
ax.set_title('Confusion Matrix with Class Names')
plt.tight_layout()
plt.show()
print("✅ Confusion matrix plotted.")

In [None]:
# === Calculate metrics ===
# Accuracy is computed directly; precision, recall and F1 all use
# average='weighted' and are evaluated in that order.
acc = accuracy_score(y_true, y_pred)
prec, rec, f1 = (
    metric(y_true, y_pred, average='weighted')
    for metric in (precision_score, recall_score, f1_score)
)

# === Print metrics ===
# Scores are fractions in [0, 1]; they are shown as percentages with two decimals.
print(f"✅ Test Accuracy : {acc*100:.2f}%")
print(f"✅ Test Precision: {prec*100:.2f}%")
print(f"✅ Test Recall   : {rec*100:.2f}%")
print(f"✅ Test F1 Score : {f1*100:.2f}%")