In [24]:
!pip install --upgrade --no-deps pytorchvideo==0.1.5 timm==1.0.19 decord==0.6.0 opencv-python==4.12.0.88 transformers==4.53.3 sentencepiece==0.2.0 sentence_transformers==2.2.2


Collecting pytorchvideo==0.1.5
  Downloading pytorchvideo-0.1.5.tar.gz (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.7/132.7 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting sentence_transformers==2.2.2
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytorchvideo, sentence_transformers
  Building wheel for pytorchvideo (setup.py) ... [?25l[?25hdone
  Created wheel for pytorchvideo: filename=pytorchvideo-0.1.5-py3-none-any.whl size=188686 sha256=d9341cb3991308cc4bc20c40919d5cb1147420d61070402b99f6967dbb7f6920
  Stored in directory: /root/.cache/pip/wheels/a4/6d/ae/d016375a73be141a0e11bb42289e2d0b046c35687fc8010ecc
  Building wheel for sentence_transformers (setup.py

In [25]:
!pip uninstall -y scikit-learn sklearn
!pip install --no-cache-dir scikit-learn==1.5.2 numpy==1.26.4 scipy==1.13.1


Found existing installation: scikit-learn 1.2.2
Uninstalling scikit-learn-1.2.2:
  Successfully uninstalled scikit-learn-1.2.2
[0mCollecting scikit-learn==1.5.2
  Downloading scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting scipy==1.13.1
  Downloading scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
Downloading scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.3/13.3 MB[0m [31m276.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.6/38.6 MB[0m [31m265.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25

In [26]:
!pip install --upgrade --no-cache-dir numpy==1.26.4 scipy==1.13.1 scikit-learn==1.5.2




In [27]:


# MViT imports
from transformers import AutoModelForVideoClassification, AutoConfig
# Clinical embeddings imports
from transformers import AutoTokenizer, AutoModel
import json

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [28]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score, precision_score, recall_score
import cv2

In [29]:
!pip install decord==0.6.0 --quiet


In [30]:
import decord
from decord import VideoReader, cpu

In [31]:
import torch
import torch.nn as nn
import timm
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [32]:
import os
from glob import glob
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

In [33]:
class GaitDataset(Dataset):
    def __init__(self, video_paths, labels, num_frames=16, frame_size=224):
        self.video_paths = video_paths
        self.labels = labels
        self.num_frames = num_frames
        self.frame_size = frame_size

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        try:
            vr = VideoReader(self.video_paths[idx], ctx=cpu(0))
            total_frames = len(vr)

            if total_frames <= self.num_frames:
                frame_indices = list(range(total_frames))
                while len(frame_indices) < self.num_frames:
                    frame_indices.append(frame_indices[-1])
            else:
                frame_indices = np.linspace(0, total_frames-1, self.num_frames, dtype=int)

            frames = vr.get_batch(frame_indices).asnumpy()

            # Resize frames
            resized_frames = []
            for frame in frames:
                resized_frame = cv2.resize(frame, (self.frame_size, self.frame_size))
                resized_frames.append(resized_frame)
            frames = np.array(resized_frames)

            # Convert to tensor and normalize
            frames = torch.from_numpy(frames).permute(3, 0, 1, 2).float()
            if frames.max() > 1.0:
                frames = frames / 255.0

            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
            frames = (frames - mean) / std

            return frames, self.labels[idx]

        except Exception as e:
            print(f"Error loading {self.video_paths[idx]}: {e}")
            dummy_frames = torch.randn(3, self.num_frames, self.frame_size, self.frame_size)
            return dummy_frames, self.labels[idx]

In [34]:
def load_and_split_dataset(base_path, test_size=0.2, val_size=0.1):
    # Find all video files and extract labels from folder names
    video_paths = []
    labels = []
    label_to_idx = {}
    idx_counter = 0

    for root, dirs, files in os.walk(base_path):
        for file in files:
            if file.endswith(('.mp4', '.MOV', '.mov')):
                video_path = os.path.join(root, file)
                # Use immediate parent folder as label
                label_name = os.path.basename(root)

                if label_name not in label_to_idx:
                    label_to_idx[label_name] = idx_counter
                    idx_counter += 1

                video_paths.append(video_path)
                labels.append(label_to_idx[label_name])

    # 70:20:10 split (train:test:val)
    train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
        video_paths, labels, test_size=test_size, random_state=42, stratify=labels
    )

    val_ratio = val_size / (1 - test_size)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_val_paths, train_val_labels, test_size=val_ratio, random_state=42, stratify=train_val_labels
    )

    print(f"Dataset loaded: {len(video_paths)} videos")
    print(f"Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")
    print(f"Classes: {label_to_idx}")

    return train_paths, train_labels, val_paths, val_labels, test_paths, test_labels, label_to_idx

In [None]:
from google.colab import drive

# Mount your Google Drive
drive.mount('/content/drive', force_remount=True)

# To access shared drives, you might need to authorize it after mounting
# You can then access shared drives under /content/drive/Shared with me/

In [None]:
!pip freeze requiremets.txt
print("Requirements file made")

In [None]:
import zipfile
import os

# Load your dataset

# base_path = "/content/drive/MyDrive/GiatLabDatset.zip/" # Original path, commented out

zip_path = "/content/drive/MyDrive/GiatLabDatset.zip"
unzip_path = "/content/GaitLabDataset/"

if os.path.exists(zip_path):
    print(f"Unzipping {zip_path} to {unzip_path}")
    os.makedirs(unzip_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(unzip_path)
    print("Unzipping complete.")
    # Update base_path to the unzipped directory
    base_path = unzip_path
else:
    print(f"Error: Zip file not found at {zip_path}")
    # Keep the original base_path if zip file is not found
    base_path = "/content/drive/MyDrive/GiatLabDatset.zip/" # Or handle this case as appropriate for your workflow


# Verify the contents of the unzipped directory
if os.path.exists(base_path):
    print(f"Contents of {base_path}:")
    try:
        for item in os.listdir(base_path):
            print(f"  - {item}")
    except Exception as e:
        print(f"Error listing contents: {e}")
else:
    print(f"Error: Unzip path not found at {base_path}")

In [35]:
!ls /kaggle/input/gaitlabdataset

GiatLabDatset


In [37]:
# Load your dataset after unzipping
# Moved the load_and_split_dataset call here
base_path="/kaggle/input/gaitlabdataset/GiatLabDatset"
if os.path.exists(base_path):
    train_paths, train_labels, val_paths, val_labels, test_paths, test_labels, class_mapping = load_and_split_dataset(base_path)
else:
    print("Cannot load dataset as the base path does not exist.")
    train_paths, train_labels, val_paths, val_labels, test_paths, test_labels, class_mapping = [], [], [], [], [], [], {}

Dataset loaded: 230 videos
Train: 161, Val: 23, Test: 46
Classes: {'Normal': 0, 'Assistive': 1, 'NonAssistive': 2, 'PD_Mild': 3, 'PD_Early': 4, 'PD_Severe': 5, 'KOA_Early': 6, 'KOA_Mild': 7, 'KOA_Severe': 8}


In [38]:
# Create datasets
train_dataset = GaitDataset(train_paths, train_labels)
val_dataset = GaitDataset(val_paths, val_labels)
test_dataset = GaitDataset(test_paths, test_labels)

# Create data loaders
batch_size = 12
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Data loaders created with batch size {batch_size}")

Data loaders created with batch size 12


In [39]:
class MViTTeacher(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        # Use a pre-trained MViT from timm and adapt it for video
        self.backbone = timm.create_model('mvitv2_base', pretrained=pretrained)

        # Modify the head for video classification
        # Remove the original head
        self.backbone.head = nn.Identity()

        # Determine the number of features from the backbone's output
        # We need a dummy forward pass to get the shape
        # Assuming an input shape like [1, 3, 16, 224, 224] (batch=1, channels=3, frames=16, height=224, width=224)
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 16, 224, 224)
            # Reshape for the MViT backbone (treating frames as batch)
            dummy_input = dummy_input.permute(0, 2, 1, 3, 4).contiguous()
            dummy_input = dummy_input.view(1 * 16, 3, 224, 224)
            num_ftrs = self.backbone(dummy_input).shape[-1]


        # Add a new classification head that expects flattened features [batch_size, num_frames * features_per_frame]
        self.classifier = nn.Linear(num_ftrs * 16, num_classes) # Assuming 16 frames as defined in GaitDataset

    def forward(self, x):
        # x is [batch, channels, frames, height, width]
        batch_size, channels, frames, height, width = x.size()

        # Reshape for the MViT backbone (treating frames as batch)
        # The timm MViT expects [batch*frames, channels, height, width]
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(batch_size * frames, channels, height, width)

        # Pass through the MViT backbone (without the original head)
        visual_features = self.backbone(x) # Shape: [batch*frames, num_ftrs]

        # Reshape back to separate batch and frames
        visual_features = visual_features.view(batch_size, frames, -1) # Shape: [batch, frames, num_ftrs]

        # Flatten the temporal and feature dimensions for the new classifier
        visual_features_flat = visual_features.view(batch_size, -1) # Shape: [batch, frames * num_ftrs]

        # Pass through the new classifier
        logits = self.classifier(visual_features_flat)

        return logits

In [40]:
num_classes = len(class_mapping)
teacher_model = MViTTeacher(num_classes=num_classes)
teacher_model = teacher_model.to(device)
print(f"MViT Teacher initialized with {num_classes} classes")

model.safetensors:   0%|          | 0.00/206M [00:00<?, ?B/s]

MViT Teacher initialized with 9 classes


In [41]:
class ClinicalEnhancedStudent(nn.Module):
    def __init__(self, num_classes, clinical_dim=768): # Changed clinical_dim to 768
        super().__init__()

        # Visual backbone
        self.visual_encoder = nn.Sequential(
            nn.Conv3d(3, 16, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(16), nn.ReLU(), nn.MaxPool3d((1, 2, 2)),

            nn.Conv3d(16, 32, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(32), nn.ReLU(), nn.MaxPool3d((1, 2, 2)),

            nn.Conv3d(32, 64, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(64), nn.ReLU(), nn.MaxPool3d((1, 2, 2)),

            nn.AdaptiveAvgPool3d((None, 7, 7))
        )

        # Clinical embeddings (from language model)
        self.clinical_proj = nn.Linear(clinical_dim, 128)

        # Fusion classifier
        # Need to determine the output size of visual_encoder to calculate the input size for the classifier
        # Let's perform a dummy forward pass
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, 16, 224, 224) # Assuming input shape [batch, channels, frames, height, width]
            visual_output_dummy = self.visual_encoder(dummy_input)
            visual_flat_size = visual_output_dummy.view(1, -1).shape[-1]

        self.classifier = nn.Sequential(
            nn.Linear(visual_flat_size + 128, 256), # Updated input size for classifier
            nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x, clinical_embeds=None):
        # Visual features
        visual_features = self.visual_encoder(x)
        batch_size = visual_features.size(0)
        visual_flat = visual_features.view(batch_size, -1)

        # Clinical features (if provided)
        if clinical_embeds is not None:
            clinical_proj = self.clinical_proj(clinical_embeds)
            fused_features = torch.cat([visual_flat, clinical_proj], dim=1)
        else:
            fused_features = visual_flat

        return self.classifier(fused_features)

In [42]:
student_model = ClinicalEnhancedStudent(num_classes=num_classes)
student_model = student_model.to(device)
print("Clinical-enhanced student model initialized")

Clinical-enhanced student model initialized


In [65]:
!pip install --upgrade transformers huggingface-hub


Collecting transformers
  Downloading transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub
  Downloading huggingface_hub-1.0.0-py3-none-any.whl.metadata (13 kB)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Downloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Downloading transformers-4.57.1-py3-none-any.whl (12.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m102.9 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hDownloading huggingface_hub-0.36.0-py3-none-any.whl (566 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m566.1/566.1 kB[0m [31m35.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.ma

In [69]:
from transformers import AutoTokenizer, AutoModel
import torch

In [70]:
class ClinicalEmbedder:
    def __init__(self):
        # Use a more stable model that doesn't have chat template issues
        model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name)
            self.model.eval()
            print(f"Clinical embedder loaded: {model_name}")
        except Exception as e:
            print(f"Failed to load clinical model: {e}")
            print("Using fallback embedding method...")
            self.model = None
            self.tokenizer = None
    
    def get_embedding(self, text):
        if self.model is None:
            # Fallback: return random embeddings of correct dimension
            return torch.randn(1, 768)
        
        try:
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
            with torch.no_grad():
                outputs = self.model(**inputs)
            return outputs.last_hidden_state.mean(dim=1)
        except Exception as e:
            print(f" Embedding generation failed: {e}")
            return torch.randn(1, 768)


In [71]:
# Initialize clinical embedder
clinical_embedder = ClinicalEmbedder()
print("Clinical embedder initialized")

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Failed to load clinical model: 404 Client Error. (Request ID: Root=1-6900a3ac-7f5dae1a69c954e32353b54e;be341991-24e6-484b-8f7e-bf651359e9b2)

Entry Not Found for url: https://huggingface.co/api/models/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract/tree/main/additional_chat_templates?recursive=false&expand=false.
additional_chat_templates does not exist on "main"
Using fallback embedding method...
Clinical embedder initialized


In [72]:
# Clinical descriptions for each gait condition
clinical_descriptions = {
    "Normal": "Normal symmetrical gait pattern with balanced stride length and cadence",
    "KOA_Early": "Early knee osteoarthritis with mild gait modifications and reduced knee flexion",
    "KOA_Mild": "Mild knee osteoarthritis showing limping gait and asymmetric weight bearing",
    "KOA_Severe": "Severe knee osteoarthritis with significant antalgic gait and reduced mobility",
    "PD_Early": "Early Parkinson's disease showing slight shuffling gait and reduced arm swing",
    "PD_Mild": "Mild Parkinson's disease with festinating gait and postural instability",
    "PD_Severe": "Severe Parkinson's disease showing freezing of gait and significant bradykinesia",
    "Disabled_Assistive": "Disabled gait using assistive devices with modified weight distribution",
    "Disabled_NonAssistive": "Disabled gait without assistive devices showing compensatory movements"
}

In [73]:
clinical_embedder = ClinicalEmbedder()
print("Clinical embedder initialized")

Failed to load clinical model: 404 Client Error. (Request ID: Root=1-6900a3bb-7446ab1d1ea7d2667c08a221;86ca3719-0b2d-4ce0-8ade-3111f3a6c176)

Entry Not Found for url: https://huggingface.co/api/models/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract/tree/main/additional_chat_templates?recursive=false&expand=false.
additional_chat_templates does not exist on "main"
Using fallback embedding method...
Clinical embedder initialized


In [74]:
class KnowledgeDistillationTrainer:
    def __init__(self, teacher, student, temperature=3.0, alpha=0.7):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        self.alpha = alpha
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()

    def compute_loss(self, student_logits, teacher_logits, labels, clinical_embeds=None):
        # Knowledge distillation loss
        soft_loss = self.kl_loss(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)

        # Classification loss
        hard_loss = self.ce_loss(student_logits, labels)

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss, soft_loss, hard_loss

In [75]:
distillation_trainer = KnowledgeDistillationTrainer(teacher_model, student_model)
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4, weight_decay=1e-5)
print("Distillation trainer and optimizer setup complete")

Distillation trainer and optimizer setup complete


In [76]:
def train_epoch():
    teacher_model.eval()
    student_model.train()

    total_loss = 0
    all_preds = []
    all_labels = []

    for batch_idx, (videos, labels) in enumerate(train_loader):
        videos, labels = videos.to(device), labels.to(device)

        # Get clinical embeddings for this batch
        clinical_batch = []
        for label in labels:
            class_name = list(class_mapping.keys())[list(class_mapping.values()).index(label.item())]
            desc = clinical_descriptions.get(class_name) # Use .get to handle missing keys
            if desc:
                clinical_emb = clinical_embedder.get_embedding(desc).to(device)
                clinical_batch.append(clinical_emb)
            else:
                # Append a zero tensor if description is missing
                # Assuming clinical embedding size is 768 based on ClinicalEmbedder
                clinical_batch.append(torch.zeros(1, 768, device=device))


        clinical_batch = torch.cat(clinical_batch, dim=0) if clinical_batch else None

        optimizer.zero_grad()

        with torch.no_grad():
            teacher_outputs = teacher_model(videos)

        student_outputs = student_model(videos, clinical_batch)

        loss, soft_loss, hard_loss = distillation_trainer.compute_loss(
            student_outputs, teacher_outputs, labels, clinical_batch
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = torch.argmax(student_outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    return total_loss / len(train_loader), accuracy, precision, recall, f1

In [77]:
def validate():
    student_model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for videos, labels in val_loader:
            videos, labels = videos.to(device), labels.to(device)

            # Get clinical embeddings
            clinical_batch = []
            for label in labels:
                class_name = list(class_mapping.keys())[list(class_mapping.values()).index(label.item())]
                desc = clinical_descriptions.get(class_name) # Use .get to handle missing keys
                if desc:
                    clinical_emb = clinical_embedder.get_embedding(desc).to(device)
                    clinical_batch.append(clinical_emb)
                else:
                    # Append a zero tensor if description is missing
                    # Assuming clinical embedding size is 768 based on ClinicalEmbedder
                    clinical_batch.append(torch.zeros(1, 768, device=device))


            clinical_batch = torch.cat(clinical_batch, dim=0) if clinical_batch else None
            outputs = student_model(videos, clinical_batch)
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    class_names = list(class_mapping.keys())
    # Pass all possible labels to classification_report
    report = classification_report(all_labels, all_preds, target_names=class_names, zero_division=0, labels=list(range(len(class_mapping))))

    print("VALIDATION RESULTS")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"\nClassification Report:\n{report}")


    return accuracy, precision, recall, f1

## Weighted Sampling for Training

### Subtask:
Implement weighted random sampling for the training data loader to address class imbalance during training.

**Reasoning**:
Implement weighted random sampling for the training DataLoader to ensure that batches during training have a more balanced representation of classes. This helps the model learn from minority classes more effectively. Weighted random sampling is typically applied only to the training set to avoid biasing the evaluation metrics on the validation and test sets.

In [79]:
from torch.utils.data import WeightedRandomSampler
from collections import Counter

In [80]:
# Calculate sample weights for the training data
label_counts = Counter(train_labels)
total_samples = len(train_labels)

# Create a list of weights for each sample in the training set
sample_weights = [1.0 / label_counts[label] for label in train_labels]

# Create a WeightedRandomSampler for the training set
train_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=total_samples, # Sample with replacement for the size of the dataset
    replacement=True
)

# Update the train_loader to use the weighted sampler
# Keep validation and test loaders without sampling (shuffle=False is standard)
batch_size = 12
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Data loaders updated with weighted sampling for training.")
print(f"Training dataset with augmentation: {len(train_dataset)} videos")
print(f"Validation dataset: {len(val_dataset)} videos")
print(f"Test dataset: {len(test_dataset)} videos")

Data loaders updated with weighted sampling for training.
Training dataset with augmentation: 161 videos
Validation dataset: 23 videos
Test dataset: 46 videos


In [81]:
distillation_trainer = KnowledgeDistillationTrainer(teacher_model, student_model)
print("Distillation trainer setup complete")

Distillation trainer setup complete


In [82]:
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4, weight_decay=1e-5)
print("Optimizer setup complete")

Optimizer setup complete


In [83]:
# Training loop
num_epochs = 5

print("Starting training...")
for epoch in range(num_epochs):
    train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch()
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Rec: {train_rec:.4f}, F1: {train_f1:.4f}")

    # Validation
    val_acc, val_prec, val_rec, val_f1 = validate()
    print(f"Epoch {epoch+1}/{num_epochs} - Val Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Rec: {val_rec:.4f}, F1: {val_f1:.4f}")

print("Training and validation complete.")


Starting training...


[aac @ 0x20de76c0] Input buffer exhausted before END element found
[aac @ 0x20463940] Input buffer exhausted before END element found


Batch 0, Loss: 0.7187


[aac @ 0x24cce400] Input buffer exhausted before END element found
[aac @ 0x2f7dc340] Input buffer exhausted before END element found
[aac @ 0x24ccb1c0] Input buffer exhausted before END element found
[aac @ 0x24d3f100] Input buffer exhausted before END element found
[aac @ 0x24d3f100] Input buffer exhausted before END element found
[aac @ 0x2f7dc340] Input buffer exhausted before END element found
[aac @ 0x2f7d6900] Input buffer exhausted before END element found
[aac @ 0x24d5cb00] Input buffer exhausted before END element found
[aac @ 0x2231b840] Input buffer exhausted before END element found
[aac @ 0x2f7d61c0] Input buffer exhausted before END element found
[aac @ 0x24d3f100] Input buffer exhausted before END element found
[aac @ 0x24d3b8c0] Input buffer exhausted before END element found
[aac @ 0x25442a80] Input buffer exhausted before END element found
[aac @ 0x2230e9c0] Input buffer exhausted before END element found
[aac @ 0x24ccddc0] Input buffer exhausted before END element f

Batch 10, Loss: 0.6917


[aac @ 0x2f7d1080] Input buffer exhausted before END element found
[aac @ 0x2f7d1080] Input buffer exhausted before END element found
[aac @ 0x2c03a7c0] Input buffer exhausted before END element found
[aac @ 0x24fa4740] Input buffer exhausted before END element found
[aac @ 0x2f7d1080] Input buffer exhausted before END element found
[aac @ 0x22fe8b40] Input buffer exhausted before END element found


Epoch 1/5 - Train Loss: 0.9636, Acc: 0.2857, Prec: 0.2822, Rec: 0.2857, F1: 0.2802


[aac @ 0x22fe9f00] Input buffer exhausted before END element found
[aac @ 0x2f7d1080] Input buffer exhausted before END element found
[aac @ 0x24ced4c0] Input buffer exhausted before END element found
[aac @ 0x2780c240] Input buffer exhausted before END element found
[aac @ 0x2780c240] Input buffer exhausted before END element found
[aac @ 0x24d04980] Input buffer exhausted before END element found
[aac @ 0x2542d280] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.1739
Precision: 0.0717
Recall: 0.1739
F1-Score: 0.0975

Classification Report:
              precision    recall  f1-score   support

      Normal       0.00      0.00      0.00         6
   Assistive       0.30      1.00      0.46         3
NonAssistive       0.00      0.00      0.00         4
     PD_Mild       0.00      0.00      0.00         1
    PD_Early       0.00      0.00      0.00         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.25      0.33      0.29         3
    KOA_Mild       0.00      0.00      0.00         4
  KOA_Severe       0.00      0.00      0.00         1

    accuracy                           0.17        23
   macro avg       0.06      0.15      0.08        23
weighted avg       0.07      0.17      0.10        23

Epoch 1/5 - Val Acc: 0.1739, Prec: 0.0717, Rec: 0.1739, F1: 0.0975


[aac @ 0x2542e2c0] Input buffer exhausted before END element found
[aac @ 0x2f748180] Input buffer exhausted before END element found
[aac @ 0x24d1e4c0] Input buffer exhausted before END element found
[aac @ 0x25a99bc0] Input buffer exhausted before END element found


Batch 0, Loss: 0.6540


[aac @ 0x25d162c0] Input buffer exhausted before END element found
[aac @ 0x2f79d880] Input buffer exhausted before END element found
[aac @ 0x21027680] Input buffer exhausted before END element found
[aac @ 0x2f73e540] Input buffer exhausted before END element found
[aac @ 0x25a96980] Input buffer exhausted before END element found
[aac @ 0x2c6f8300] Input buffer exhausted before END element found
[aac @ 0x2f743bc0] Input buffer exhausted before END element found
[aac @ 0x2f79cc80] Input buffer exhausted before END element found
[aac @ 0x25a999c0] Input buffer exhausted before END element found
[aac @ 0x25a999c0] Input buffer exhausted before END element found
[aac @ 0x2f73fc00] Input buffer exhausted before END element found
[aac @ 0x2f7a0600] Input buffer exhausted before END element found
[aac @ 0x25d12500] Input buffer exhausted before END element found
[aac @ 0x25d12500] Input buffer exhausted before END element found
[aac @ 0x2f7a0600] Input buffer exhausted before END element f

Batch 10, Loss: 0.6126


[aac @ 0x24d1d680] Input buffer exhausted before END element found
[aac @ 0x2c925640] Input buffer exhausted before END element found
[aac @ 0x2f78c440] Input buffer exhausted before END element found
[aac @ 0x2f73fc00] Input buffer exhausted before END element found
[aac @ 0x2f79d0c0] Input buffer exhausted before END element found
[aac @ 0x2c6fe200] Input buffer exhausted before END element found
[aac @ 0x2f79d0c0] Input buffer exhausted before END element found
[aac @ 0x2f7a0600] Input buffer exhausted before END element found


Epoch 2/5 - Train Loss: 0.6143, Acc: 0.4410, Prec: 0.4036, Rec: 0.4410, F1: 0.4040


[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x24d021c0] Input buffer exhausted before END element found
[aac @ 0x25af7640] Input buffer exhausted before END element found
[aac @ 0x24d02cc0] Input buffer exhausted before END element found
[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x24d02cc0] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.4348
Precision: 0.5290
Recall: 0.4348
F1-Score: 0.4478

Classification Report:
              precision    recall  f1-score   support

      Normal       0.75      0.50      0.60         6
   Assistive       1.00      0.33      0.50         3
NonAssistive       0.67      1.00      0.80         4
     PD_Mild       0.00      0.00      0.00         1
    PD_Early       0.00      0.00      0.00         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.00      0.00      0.00         3
    KOA_Mild       0.50      0.50      0.50         4
  KOA_Severe       0.00      0.00      0.00         1

    accuracy                           0.43        23
   macro avg       0.32      0.26      0.27        23
weighted avg       0.53      0.43      0.45        23

Epoch 2/5 - Val Acc: 0.4348, Prec: 0.5290, Rec: 0.4348, F1: 0.4478


[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x30d80b80] Input buffer exhausted before END element found


Batch 0, Loss: 0.5793


[aac @ 0x2c7ad540] Input buffer exhausted before END element found
[aac @ 0x2c7ad540] Input buffer exhausted before END element found
[aac @ 0x2c7ad540] Input buffer exhausted before END element found
[aac @ 0x24ddd480] Input buffer exhausted before END element found
[aac @ 0x2f7416c0] Input buffer exhausted before END element found
[aac @ 0x2f726fc0] Input buffer exhausted before END element found
[aac @ 0x2f73a080] Input buffer exhausted before END element found
[aac @ 0x2f73a040] Input buffer exhausted before END element found
[mov,mp4,m4a,3gp,3g2,mj2 @ 0x2f73aa80] moov atom not found
[11:19:14] /github/workspace/src/video/video_reader.cc:83: ERROR opening: /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV, Invalid data found when processing input


Error loading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV: Error reading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV...


[mov,mp4,m4a,3gp,3g2,mj2 @ 0x24d59d40] moov atom not found
[11:19:31] /github/workspace/src/video/video_reader.cc:83: ERROR opening: /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV, Invalid data found when processing input


Error loading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV: Error reading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV...


[aac @ 0x2c8d4d80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2c7ad440] Input buffer exhausted before END element found
[aac @ 0x2c7ad440] Input buffer exhausted before END element found
[aac @ 0x25a954c0] Input buffer exhausted before END element found
[aac @ 0x2f7416c0] Input buffer exhausted before END element found
[aac @ 0x2f7416c0] Input buffer exhausted before END element found
[aac @ 0x24d20e80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25ae4e00] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2c038ec0] Input buffer exhausted before END element found
[aac @ 0x25d10b40] Input buffer exhausted before END element found
[aac @ 0x2f79e200] Input buffer exhausted before END element f

Batch 10, Loss: 0.5498


[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2c6fde80] Input buffer exhausted before END element found


Epoch 3/5 - Train Loss: 0.5839, Acc: 0.6149, Prec: 0.6513, Rec: 0.6149, F1: 0.6023


[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x25af8080] Input buffer exhausted before END element found
[aac @ 0x24d02780] Input buffer exhausted before END element found
[aac @ 0x2c0389c0] Input buffer exhausted before END element found
[aac @ 0x24fa3940] Input buffer exhausted before END element found
[aac @ 0x25a920c0] Input buffer exhausted before END element found
[aac @ 0x24cef240] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.6087
Precision: 0.6304
Recall: 0.6087
F1-Score: 0.6078

Classification Report:
              precision    recall  f1-score   support

      Normal       1.00      0.83      0.91         6
   Assistive       0.75      1.00      0.86         3
NonAssistive       1.00      0.75      0.86         4
     PD_Mild       1.00      1.00      1.00         1
    PD_Early       0.50      1.00      0.67         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.25      0.33      0.29         3
    KOA_Mild       0.00      0.00      0.00         4
  KOA_Severe       0.00      0.00      0.00         1

   micro avg       0.61      0.61      0.61        23
   macro avg       0.50      0.55      0.51        23
weighted avg       0.63      0.61      0.61        23

Epoch 3/5 - Val Acc: 0.6087, Prec: 0.6304, Rec: 0.6087, F1: 0.6078


[aac @ 0x2102af40] Input buffer exhausted before END element found
[aac @ 0x2c925d00] Input buffer exhausted before END element found
[aac @ 0x2f726f80] Input buffer exhausted before END element found


Batch 0, Loss: 0.5650


[aac @ 0x25a98340] Input buffer exhausted before END element found
[aac @ 0x25ac0c40] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x24d81e00] Input buffer exhausted before END element found
[aac @ 0x25a98340] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x25af5540] Input buffer exhausted before END element found
[aac @ 0x25a98340] Input buffer exhausted before END element found
[aac @ 0x4b47d140] Input buffer exhausted before END element found
[aac @ 0x2f73f880] Input buffer exhausted before END element found
[aac @ 0x25a98340] Input buffer exhausted before END element found
[aac @ 0x25a98340] Input buffer exhausted before END element found
[aac @ 0x24d0f740] Input buffer exhausted before END element found
[aac @ 0x25aea500] Input buffer exhausted before END element found
[aac @ 0x2f738b00] Input buffer exhausted before END element f

Batch 10, Loss: 0.5415


[aac @ 0x2f72e5c0] Input buffer exhausted before END element found
[aac @ 0x24d02d00] Input buffer exhausted before END element found
[aac @ 0x2c036880] Input buffer exhausted before END element found
[aac @ 0x25ae3c80] Input buffer exhausted before END element found
[aac @ 0x25a95580] Input buffer exhausted before END element found


Epoch 4/5 - Train Loss: 0.5373, Acc: 0.6832, Prec: 0.6928, Rec: 0.6832, F1: 0.6792


[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x2c03b4c0] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x2c03b4c0] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x2c03b4c0] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.4348
Precision: 0.5543
Recall: 0.4348
F1-Score: 0.4373

Classification Report:
              precision    recall  f1-score   support

      Normal       1.00      0.50      0.67         6
   Assistive       1.00      0.33      0.50         3
NonAssistive       0.67      1.00      0.80         4
     PD_Mild       0.33      1.00      0.50         1
    PD_Early       0.00      0.00      0.00         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.25      0.33      0.29         3
    KOA_Mild       0.00      0.00      0.00         4
  KOA_Severe       0.00      0.00      0.00         1

   micro avg       0.43      0.43      0.43        23
   macro avg       0.36      0.35      0.31        23
weighted avg       0.55      0.43      0.44        23

Epoch 4/5 - Val Acc: 0.4348, Prec: 0.5543, Rec: 0.4348, F1: 0.4373


[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found


Batch 0, Loss: 0.5736


[aac @ 0x214b5880] Input buffer exhausted before END element found
[aac @ 0x2e705340] Input buffer exhausted before END element found
[aac @ 0x25d12500] Input buffer exhausted before END element found
[aac @ 0x25af6480] Input buffer exhausted before END element found
[aac @ 0x2c7ada00] Input buffer exhausted before END element found
[aac @ 0x24cf6940] Input buffer exhausted before END element found
[aac @ 0x24cf6940] Input buffer exhausted before END element found
[aac @ 0x2e705340] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24d02980] Input buffer exhausted before END element found
[aac @ 0x25af8080] Input buffer exhausted before END element found
[aac @ 0x2c7aafc0] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element f

Batch 10, Loss: 0.4926


[aac @ 0x2f72d5c0] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2c7ada00] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x214b5880] Input buffer exhausted before END element found
[aac @ 0x214b5880] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x2f741ac0] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found


Epoch 5/5 - Train Loss: 0.5263, Acc: 0.7764, Prec: 0.8019, Rec: 0.7764, F1: 0.7592


[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.4783
Precision: 0.5109
Recall: 0.4783
F1-Score: 0.4859

Classification Report:
              precision    recall  f1-score   support

      Normal       0.80      0.67      0.73         6
   Assistive       1.00      0.67      0.80         3
NonAssistive       0.80      1.00      0.89         4
     PD_Mild       0.00      0.00      0.00         1
    PD_Early       0.00      0.00      0.00         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.25      0.33      0.29         3
    KOA_Mild       0.00      0.00      0.00         4
  KOA_Severe       0.00      0.00      0.00         1

   micro avg       0.48      0.48      0.48        23
   macro avg       0.32      0.30      0.30        23
weighted avg       0.51      0.48      0.49        23

Epoch 5/5 - Val Acc: 0.4783, Prec: 0.5109, Rec: 0.4783, F1: 0.4859
Training and validation complete.


In [85]:
num_epochs = 5

print("Starting training...")
for epoch in range(num_epochs):
    train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch()
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Rec: {train_rec:.4f}, F1: {train_f1:.4f}")

    # Validation
    val_acc, val_prec, val_rec, val_f1 = validate()
    print(f"Epoch {epoch+1}/{num_epochs} - Val Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Rec: {val_rec:.4f}, F1: {val_f1:.4f}")

print("Training and validation complete.")


Starting training...


[aac @ 0x25af5d40] Input buffer exhausted before END element found
[aac @ 0x25af5d40] Input buffer exhausted before END element found
[aac @ 0x25af5d40] Input buffer exhausted before END element found
[aac @ 0x25af5d40] Input buffer exhausted before END element found


Batch 0, Loss: 0.4647


[aac @ 0x214b5880] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25aead00] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24cf6900] Input buffer exhausted before END element found
[aac @ 0x214b5880] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x214b5880] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x2f72d980] Input buffer exhausted before END element f

Batch 10, Loss: 0.5095


[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x2542e040] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found


Epoch 1/5 - Train Loss: 0.5094, Acc: 0.7826, Prec: 0.8068, Rec: 0.7826, F1: 0.7676


[aac @ 0x4b47eac0] Input buffer exhausted before END element found
[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x21029240] Input buffer exhausted before END element found
[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x4b47eac0] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.6087
Precision: 0.5693
Recall: 0.6087
F1-Score: 0.5805

Classification Report:
              precision    recall  f1-score   support

      Normal       0.86      1.00      0.92         6
   Assistive       1.00      0.67      0.80         3
NonAssistive       0.80      1.00      0.89         4
     PD_Mild       1.00      1.00      1.00         1
    PD_Early       0.00      0.00      0.00         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.25      0.33      0.29         3
    KOA_Mild       0.00      0.00      0.00         4
  KOA_Severe       0.00      0.00      0.00         1

   micro avg       0.61      0.61      0.61        23
   macro avg       0.43      0.44      0.43        23
weighted avg       0.57      0.61      0.58        23

Epoch 1/5 - Val Acc: 0.6087, Prec: 0.5693, Rec: 0.6087, F1: 0.5805


[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2f76b880] Input buffer exhausted before END element found


Batch 0, Loss: 0.4917


[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x4b47eac0] Input buffer exhausted before END element found
[aac @ 0x2f737480] Input buffer exhausted before END element found
[aac @ 0x24fa4700] Input buffer exhausted before END element found
[aac @ 0x2f70c600] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x24ceef80] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x2f722fc0] Input buffer exhausted before END element found
[aac @ 0x2f722fc0] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x24ceef80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element f

Error loading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV: Error reading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV...
Batch 10, Loss: 0.5368


[aac @ 0x2f76bac0] Input buffer exhausted before END element found
[aac @ 0x25aead00] Input buffer exhausted before END element found
[aac @ 0x2c8cf880] Input buffer exhausted before END element found
[aac @ 0x25671d40] Input buffer exhausted before END element found


Epoch 2/5 - Train Loss: 0.5223, Acc: 0.8012, Prec: 0.8233, Rec: 0.8012, F1: 0.7918


[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x2542e040] Input buffer exhausted before END element found
[aac @ 0x2f72d500] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x2542e040] Input buffer exhausted before END element found
[aac @ 0x2542e040] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.6957
Precision: 0.7671
Recall: 0.6957
F1-Score: 0.7133

Classification Report:
              precision    recall  f1-score   support

      Normal       0.86      1.00      0.92         6
   Assistive       1.00      1.00      1.00         3
NonAssistive       1.00      1.00      1.00         4
     PD_Mild       0.00      0.00      0.00         1
    PD_Early       0.00      0.00      0.00         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.50      0.33      0.40         3
    KOA_Mild       1.00      0.50      0.67         4
  KOA_Severe       0.00      0.00      0.00         1

   micro avg       0.70      0.70      0.70        23
   macro avg       0.48      0.43      0.44        23
weighted avg       0.77      0.70      0.71        23

Epoch 2/5 - Val Acc: 0.6957, Prec: 0.7671, Rec: 0.6957, F1: 0.7133


[aac @ 0x24db3780] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found


Batch 0, Loss: 0.4832


[aac @ 0x25a992c0] Input buffer exhausted before END element found
[aac @ 0x25a992c0] Input buffer exhausted before END element found
[aac @ 0x25a992c0] Input buffer exhausted before END element found
[aac @ 0x25a970c0] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x24d216c0] Input buffer exhausted before END element found
[aac @ 0x273ec900] Input buffer exhausted before END element found
[aac @ 0x25671d40] Input buffer exhausted before END element found
[aac @ 0x2c737c00] Input buffer exhausted before END element found
[aac @ 0x25a992c0] Input buffer exhausted before END element found
[aac @ 0x2f735880] Input buffer exhausted before END element found
[aac @ 0x25671d40] Input buffer exhausted before END element found
[aac @ 0x25671d40] Input buffer exhausted before END element found
[aac @ 0x1ca436c0] Input buffer exhausted before END element found
[aac @ 0x2f79fc40] Input buffer exhausted before END element f

Error loading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV: Error reading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV...


[aac @ 0x25d24780] Input buffer exhausted before END element found


Batch 10, Loss: 0.5070


[aac @ 0x25671d40] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found


Epoch 3/5 - Train Loss: 0.5206, Acc: 0.7205, Prec: 0.7291, Rec: 0.7205, F1: 0.7016


[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x2f736d40] Input buffer exhausted before END element found
[aac @ 0x25a97c00] Input buffer exhausted before END element found
[aac @ 0x25a97c00] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.6522
Precision: 0.5975
Recall: 0.6522
F1-Score: 0.6212

Classification Report:
              precision    recall  f1-score   support

      Normal       0.86      1.00      0.92         6
   Assistive       1.00      1.00      1.00         3
NonAssistive       1.00      1.00      1.00         4
     PD_Mild       1.00      1.00      1.00         1
    PD_Early       0.00      0.00      0.00         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.20      0.33      0.25         3
    KOA_Mild       0.00      0.00      0.00         4
  KOA_Severe       0.00      0.00      0.00         1

   micro avg       0.65      0.65      0.65        23
   macro avg       0.45      0.48      0.46        23
weighted avg       0.60      0.65      0.62        23

Epoch 3/5 - Val Acc: 0.6522, Prec: 0.5975, Rec: 0.6522, F1: 0.6212


[aac @ 0x2f72e140] Input buffer exhausted before END element found


Batch 0, Loss: 0.5250


[aac @ 0x2f726900] Input buffer exhausted before END element found
[aac @ 0x2f70fdc0] Input buffer exhausted before END element found
[aac @ 0x25af6480] Input buffer exhausted before END element found
[aac @ 0x25ae7d40] Input buffer exhausted before END element found
[aac @ 0x25aa8ac0] Input buffer exhausted before END element found
[aac @ 0x25aa8ac0] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25aa8ac0] Input buffer exhausted before END element found
[aac @ 0x25aa8ac0] Input buffer exhausted before END element found
[aac @ 0x25aa8ac0] Input buffer exhausted before END element found
[aac @ 0x24dad200] Input buffer exhausted before END element found
[aac @ 0x25d12500] Input buffer exhausted before END element f

Batch 10, Loss: 0.4542


[aac @ 0x24ddd300] Input buffer exhausted before END element found
[aac @ 0x1ca436c0] Input buffer exhausted before END element found
[aac @ 0x25d12500] Input buffer exhausted before END element found
[aac @ 0x2f7416c0] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found


Epoch 4/5 - Train Loss: 0.4963, Acc: 0.8137, Prec: 0.8091, Rec: 0.8137, F1: 0.8008


[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2f725cc0] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x2f725cc0] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.7391
Precision: 0.6957
Recall: 0.7391
F1-Score: 0.7101

Classification Report:
              precision    recall  f1-score   support

      Normal       1.00      1.00      1.00         6
   Assistive       1.00      1.00      1.00         3
NonAssistive       1.00      1.00      1.00         4
     PD_Mild       1.00      1.00      1.00         1
    PD_Early       1.00      1.00      1.00         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.33      0.67      0.44         3
    KOA_Mild       0.00      0.00      0.00         4
  KOA_Severe       0.00      0.00      0.00         1

   micro avg       0.74      0.74      0.74        23
   macro avg       0.59      0.63      0.60        23
weighted avg       0.70      0.74      0.71        23

Epoch 4/5 - Val Acc: 0.7391, Prec: 0.6957, Rec: 0.7391, F1: 0.7101


[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x2102ae00] Input buffer exhausted before END element found


Batch 0, Loss: 0.4475


[aac @ 0x25c7d7c0] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25d11100] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[mov,mp4,m4a,3gp,3g2,mj2 @ 0x25d12500] moov atom not found
[12:05:04] /github/workspace/src/video/video_reader.cc:83: ERROR opening: /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV, Invalid data found when processing input


Error loading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV: Error reading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV...


[aac @ 0x24fa3fc0] Input buffer exhausted before END element found
[aac @ 0x24fa3fc0] Input buffer exhausted before END element found
[aac @ 0x24fa4a00] Input buffer exhausted before END element found
[aac @ 0x24cf6900] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25af8480] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x4b47eac0] Input buffer exhausted before END element found
[aac @ 0x4b47eac0] Input buffer exhausted before END element found
[aac @ 0x4b47eac0] Input buffer exhausted before END element found
[aac @ 0x4b47eac0] Input buffer exhausted before END element found
[aac @ 0x25af8480] Input buffer exhausted before END element found
[aac @ 0x2f76b880] Input buffer exhausted before END element found
[aac @ 0x25aa7140] Input buffer exhausted before END element found
[aac @ 0x4b47eac0] Input buffer exhausted before END element f

Batch 10, Loss: 0.4641


[aac @ 0x2c7370c0] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25ae8b80] Input buffer exhausted before END element found
[aac @ 0x25a95740] Input buffer exhausted before END element found
[mov,mp4,m4a,3gp,3g2,mj2 @ 0x25d12500] moov atom not found
[12:08:19] /github/workspace/src/video/video_reader.cc:83: ERROR opening: /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV, Invalid data found when processing input


Error loading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV: Error reading /kaggle/input/gaitlabdataset/GiatLabDatset/Normal/015_NM_02.MOV...


[aac @ 0x4b47eac0] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25aeb200] Input buffer exhausted before END element found


Epoch 5/5 - Train Loss: 0.4880, Acc: 0.8571, Prec: 0.8598, Rec: 0.8571, F1: 0.8475


[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25ae8b80] Input buffer exhausted before END element found
[aac @ 0x25a95740] Input buffer exhausted before END element found
[aac @ 0x25a95740] Input buffer exhausted before END element found


VALIDATION RESULTS
Accuracy: 0.6957
Precision: 0.6149
Recall: 0.6957
F1-Score: 0.6466

Classification Report:
              precision    recall  f1-score   support

      Normal       0.86      1.00      0.92         6
   Assistive       1.00      1.00      1.00         3
NonAssistive       1.00      1.00      1.00         4
     PD_Mild       1.00      1.00      1.00         1
    PD_Early       0.00      0.00      0.00         1
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.33      0.67      0.44         3
    KOA_Mild       0.00      0.00      0.00         4
  KOA_Severe       0.00      0.00      0.00         1

   micro avg       0.70      0.70      0.70        23
   macro avg       0.47      0.52      0.49        23
weighted avg       0.61      0.70      0.65        23

Epoch 5/5 - Val Acc: 0.6957, Prec: 0.6149, Rec: 0.6957, F1: 0.6466
Training and validation complete.


In [86]:
def test():
    student_model.eval()
    all_preds = []
    all_labels = []
    all_probabilities = []

    with torch.no_grad():
        for videos, labels in test_loader:
            videos, labels = videos.to(device), labels.to(device)

            # Get clinical embeddings
            clinical_batch = []
            for label in labels:
                class_name = list(class_mapping.keys())[list(class_mapping.values()).index(label.item())]
                desc = clinical_descriptions.get(class_name) # Use .get to handle missing keys
                if desc:
                    clinical_emb = clinical_embedder.get_embedding(desc).to(device)
                    clinical_batch.append(clinical_emb)
                else:
                    # Append a zero tensor if description is missing
                    # Assuming clinical embedding size is 768 based on ClinicalEmbedder
                    clinical_batch.append(torch.zeros(1, 768, device=device))


            clinical_batch = torch.cat(clinical_batch, dim=0) if clinical_batch else None
            outputs = student_model(videos, clinical_batch)
            probabilities = F.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_labels, average='weighted', zero_division=0) # Corrected recall calculation
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)

    # Classification report
    class_names = list(class_mapping.keys())
    # Pass all possible labels to classification_report
    report = classification_report(all_labels, all_preds, target_names=class_names, zero_division=0, labels=list(range(len(class_mapping))))


    print("TEST RESULTS")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"\nConfusion Matrix:\n{cm}")
    print(f"\nClassification Report:\n{report}")

    return accuracy, precision, recall, f1, cm, all_probabilities

print("Testing function defined - ready for final evaluation")

Testing function defined - ready for final evaluation


In [87]:
test()

[aac @ 0x25af6480] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x24dabac0] Input buffer exhausted before END element found
[aac @ 0x25aa6c40] Input buffer exhausted before END element found
[aac @ 0x25aa6c40] Input buffer exhausted before END element found
[aac @ 0x2f7219c0] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25af7840] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x24ceff80] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found
[aac @ 0x25d24780] Input buffer exhausted before END element found


TEST RESULTS
Accuracy: 0.6957
Precision: 0.6913
Recall: 1.0000
F1-Score: 0.6768

Confusion Matrix:
[[11  0  0  0  1  0  0  0  0]
 [ 0  5  0  0  0  0  0  0  0]
 [ 0  0  8  0  0  0  0  0  0]
 [ 2  0  0  1  0  0  0  0  0]
 [ 2  0  0  0  0  1  0  0  0]
 [ 0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  4  2  0]
 [ 0  0  0  0  0  0  4  3  1]
 [ 0  0  0  0  0  0  0  1  0]]

Classification Report:
              precision    recall  f1-score   support

      Normal       0.73      0.92      0.81        12
   Assistive       1.00      1.00      1.00         5
NonAssistive       1.00      1.00      1.00         8
     PD_Mild       1.00      0.33      0.50         3
    PD_Early       0.00      0.00      0.00         3
   PD_Severe       0.00      0.00      0.00         0
   KOA_Early       0.50      0.67      0.57         6
    KOA_Mild       0.50      0.38      0.43         8
  KOA_Severe       0.00      0.00      0.00         1

    accuracy                           0.70        46
   macro a

(0.6956521739130435,
 0.6913043478260869,
 1.0,
 0.6768461007591443,
 array([[11,  0,  0,  0,  1,  0,  0,  0,  0],
        [ 0,  5,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  8,  0,  0,  0,  0,  0,  0],
        [ 2,  0,  0,  1,  0,  0,  0,  0,  0],
        [ 2,  0,  0,  0,  0,  1,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  4,  2,  0],
        [ 0,  0,  0,  0,  0,  0,  4,  3,  1],
        [ 0,  0,  0,  0,  0,  0,  0,  1,  0]]),
 [array([0.06266162, 0.08269726, 0.45133704, 0.05008547, 0.08282047,
         0.0586655 , 0.05075288, 0.08210236, 0.0788774 ], dtype=float32),
  array([0.12175764, 0.08469369, 0.07969507, 0.2161927 , 0.13128613,
         0.10403188, 0.07667983, 0.09298921, 0.09267383], dtype=float32),
  array([0.22710061, 0.063676  , 0.08536635, 0.1270154 , 0.13825499,
         0.10287557, 0.08026368, 0.0901201 , 0.08532731], dtype=float32),
  array([0.22539927, 0.06776547, 0.08220093, 0.10919493, 0.17442921,
         0.111343

In [88]:
def export_model(model, filename="student_model.pth"):
    """Saves the trained student model."""
    torch.save(model.state_dict(), filename)
    print(f"Model saved to {filename}")

In [89]:
# Export the trained student model
export_model(student_model)

Model saved to student_model.pth


In [None]:
def infer_with_model(video_input, clinical_description, model_path, clinical_embedder, class_mapping, num_frames=16, frame_size=224, device='cuda'):
    """
    Performs inference on a single video input (path or data) and clinical description using the student model.

    Args:
        video_input (str or np.ndarray or torch.Tensor): Path to the video file or the video data directly.
                                                          If video data is provided, it should be a NumPy array
                                                          or a PyTorch tensor with shape (frames, height, width, channels)
                                                          or (channels, frames, height, width).
        clinical_description (str): Text description of the clinical condition.
        model_path (str): Path to the saved student model state dictionary (.pth file).
        clinical_embedder (ClinicalEmbedder): The clinical embedder model.
        class_mapping (dict): Dictionary mapping class names to indices.
        num_frames (int): Number of frames to sample from the video (only used if video_input is a path).
        frame_size (int): Size to resize frames to.
        device (str): Device to run inference on ('cuda' or 'cpu').

    Returns:
        tuple: A tuple containing:
            - logits (torch.Tensor): Raw output logits from the model.
            - probabilities (torch.Tensor): Softmax probabilities over classes.
            - predicted_class_index (int): Index of the predicted class.
            - predicted_class_name (str): Name of the predicted class.
    """
    # Initialize the student model architecture
    num_classes = len(class_mapping)
    student_model = ClinicalEnhancedStudent(num_classes=num_classes)
    student_model.to(device)

    # Load the saved model state dictionary
    try:
        student_model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Student model loaded successfully from {model_path}")
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}")
        return None, None, None, "Error: Model not loaded"
    except Exception as e:
        print(f"Error loading model from {model_path}: {e}")
        return None, None, None, "Error: Model not loaded"


    student_model.eval()

    try:
        if isinstance(video_input, str):
            # Process video from path
            vr = VideoReader(video_input, ctx=cpu(0))
            total_frames = len(vr)

            if total_frames <= num_frames:
                frame_indices = list(range(total_frames))
                while len(frame_indices) < num_frames:
                    frame_indices.append(frame_indices[-1])
            else:
                frame_indices = np.linspace(0, total_frames-1, num_frames, dtype=int)

            frames = vr.get_batch(frame_indices).asnumpy()

            # Resize frames
            resized_frames = []
            for frame in frames:
                resized_frame = cv2.resize(frame, (frame_size, frame_size))
                resized_frames.append(resized_frame)
            frames = np.array(resized_frames)

            # Convert to tensor and normalize
            videos = torch.from_numpy(frames).permute(3, 0, 1, 2).float()
            if videos.max() > 1.0:
                videos = videos / 255.0

            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
            videos = (videos - mean) / std

        elif isinstance(video_input, (np.ndarray, torch.Tensor)):
            # Process video from data
            videos = video_input
            if isinstance(videos, np.ndarray):
                videos = torch.from_numpy(videos).float()

            # Ensure the tensor has the correct shape [channels, frames, height, width]
            # Assume input is either [frames, height, width, channels] or [channels, frames, height, width]
            if videos.ndim == 4:
                if videos.shape[-1] == 3: # Assuming last dim is channel if size is 3
                    videos = videos.permute(3, 0, 1, 2) # Convert from [frames, h, w, c] to [c, frames, h, w]
                elif videos.shape[0] != 3: # Assuming first dim is channel if not frames
                     # This case is ambiguous, might need more specific checks or require a fixed input format
                     # For now, assume it's already [c, frames, h, w] if first dim is not 3 but ndim is 4
                     pass # Already in [c, frames, h, w] format

            # Resize frames if necessary (assuming the input data might not be the target size)
            current_frame_size = videos.shape[2] # Assuming shape is [c, frames, h, w]
            if current_frame_size != frame_size:
                 # Need to resize each frame. This is more complex for a batched tensor.
                 # For simplicity, let's assume the input video data is already pre-processed
                 # to the correct frame_size. If not, this would require iterating or using
                 # torchvision.transforms.functional.resize, which might not be ideal for 4D tensors.
                 # Add a warning or error if resizing is needed but not implemented for tensor input.
                 if current_frame_size != frame_size:
                     print(f"Warning: Input video data has frame size {current_frame_size}, expected {frame_size}. Resizing not implemented for tensor input.")


            # Normalize if necessary (check max value as a simple heuristic)
            if videos.max() > 1.0:
                videos = videos / 255.0

            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
            videos = (videos - mean) / std

        else:
            raise TypeError("video_input must be a file path (str) or video data (np.ndarray or torch.Tensor)")


        videos = videos.unsqueeze(0).to(device) # Add batch dimension and move to device

        # Get clinical embedding
        clinical_embeds = clinical_embedder.get_embedding(clinical_description).to(device)


        # Perform inference
        with torch.no_grad():
            logits = student_model(videos, clinical_embeds)
            probabilities = F.softmax(logits, dim=1)
            predicted_class_index = torch.argmax(logits, dim=1).item()

        # Get predicted class name
        idx_to_class = {v: k for k, v in class_mapping.items()}
        predicted_class_name = idx_to_class.get(predicted_class_index, "Unknown")

        return logits, probabilities, predicted_class_index, predicted_class_name

    except Exception as e:
        print(f"Error during inference: {e}")
        return None, None, None, "Error"