In [None]:
import os
import zipfile
import numpy as np
import scipy.io
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Unzip the Penn Action dataset from Google Drive
data_zip_path = '/content/drive/My Drive/data.zip'  # Path to your data.zip in Google Drive
extract_path = '/content/data'

if not os.path.exists(extract_path):
    os.makedirs(extract_path)

with zipfile.ZipFile(data_zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

Mounted at /content/drive


In [None]:
label_mapping = {
    'tennis_serve': 0, 'golf_swing': 1, 'baseball_pitch': 2, 'bench_press': 3,
    'pullup': 4, 'pushup': 5, 'situp': 6, 'jumping_jacks': 7, 'strum_guitar': 8,
    'bowl': 9, 'tennis_forehand': 10, 'squat': 11, 'jump_rope': 12,
    'clean_and_jerk': 13, 'baseball_swing': 14
}

In [None]:
class PennActionDataset(Dataset):
    def __init__(self, data_path, transform=None, num_frames=32):
        self.data_path = data_path
        self.transform = transform
        self.num_frames = num_frames
        self.frames_dir = os.path.join(data_path, "frames")
        self.labels_dir = os.path.join(data_path, "labels")
        self.video_ids = os.listdir(self.frames_dir)

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, idx):
        video_id = self.video_ids[idx]
        video_path = os.path.join(self.frames_dir, video_id)
        label_path = os.path.join(self.labels_dir, video_id + ".mat")
        frame_paths = sorted([os.path.join(video_path, f) for f in os.listdir(video_path)])
        frame_count = len(frame_paths)
        if frame_count > self.num_frames:
            selected_indices = np.linspace(0, frame_count - 1, self.num_frames).astype(int)
            frame_paths = [frame_paths[i] for i in selected_indices]
        elif frame_count < self.num_frames:
            frame_paths += [frame_paths[-1]] * (self.num_frames - frame_count)
        frames = [Image.open(frame_path).convert("RGB") for frame_path in frame_paths]
        if self.transform:
            frames = [self.transform(frame) for frame in frames]
        frames = torch.stack(frames)
        mat = scipy.io.loadmat(label_path)
        action_label = mat["action"][0]
        label = label_mapping[action_label]
        return frames, label


In [None]:
transform = transforms.Compose([
    transforms.Resize((384, 384)),  # Resize to 384x384 for SWAG weights
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize
])

In [None]:
dataset = PennActionDataset('/content/data/data/Penn_Action/Penn_Action', transform=transform, num_frames=32)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

In [None]:
!git clone https://github.com/KindXiaoming/pykan.git
!pip install ./pykan

Cloning into 'pykan'...
remote: Enumerating objects: 4221, done.[K
remote: Counting objects: 100% (664/664), done.[K
remote: Compressing objects: 100% (245/245), done.[K
remote: Total 4221 (delta 569), reused 419 (delta 419), pack-reused 3557 (from 3)[K
Receiving objects: 100% (4221/4221), 114.76 MiB | 45.71 MiB/s, done.
Resolving deltas: 100% (1580/1580), done.
Processing ./pykan
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pykan
  Building wheel for pykan (setup.py) ... [?25l[?25hdone
  Created wheel for pykan: filename=pykan-0.2.8-py3-none-any.whl size=78235 sha256=a2906fea05b2c84c62f396c529daa16758da6ff9d373eb104df583f18ca662eb
  Stored in directory: /tmp/pip-ephem-wheel-cache-s70g9nqp/wheels/05/9b/6c/6f9f5a9927ba27c99b92cf0cbdd57f190932c31289c49eded1
Successfully built pykan
Installing collected packages: pykan
Successfully installed pykan-0.2.8


In [None]:
vit_weights = models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
pretrained_vit = models.vit_b_16(weights=vit_weights)

import torch.nn.functional as F

# Define KAN Layer
class KANLayer(nn.Module):
    def __init__(self, in_dim, out_dim, activation='relu'):
        super(KANLayer, self).__init__()
        self.weights = nn.Parameter(torch.randn(out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))
        self.activation = activation

    def forward(self, x):
        x = torch.matmul(x, self.weights.T) + self.bias
        if self.activation == 'relu':
            return F.relu(x)
        elif self.activation == 'tanh':
            return torch.tanh(x)
        elif self.activation == 'sigmoid':
            return torch.sigmoid(x)
        return x  # No activation if None

# Freeze pretrained weights
for param in pretrained_vit.parameters():
    param.requires_grad = False

# Update the classification head with custom KAN layers
embedding_dim = 768  # Dimension of ViT embeddings
kan_hidden_units = [64, 32]  # Define hidden layers for KAN

pretrained_vit.heads = nn.Sequential(
    nn.LayerNorm(normalized_shape=embedding_dim),
    KANLayer(in_dim=embedding_dim, out_dim=kan_hidden_units[0], activation='relu'),
    KANLayer(in_dim=kan_hidden_units[0], out_dim=kan_hidden_units[1], activation='relu'),
    KANLayer(in_dim=kan_hidden_units[1], out_dim=len(label_mapping), activation=None)
)

Downloading: "https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16_swag-9ac1b537.pth
100%|██████████| 331M/331M [00:02<00:00, 154MB/s]


In [None]:
import pykan
print(dir(pykan))

['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__']


In [None]:
import pykan.kan
print(dir(pykan.kan))  # If 'layers' exist, check its content


['Abs', 'AccumBounds', 'Add', 'Adjoint', 'AlgebraicField', 'AlgebraicNumber', 'And', 'AppliedPredicate', 'Array', 'AssumptionsContext', 'Atom', 'AtomicExpr', 'BasePolynomialError', 'Basic', 'BlockDiagMatrix', 'BlockMatrix', 'CC', 'CRootOf', 'Catalan', 'Chi', 'Ci', 'Circle', 'CoercionFailed', 'Complement', 'ComplexField', 'ComplexRegion', 'ComplexRootOf', 'Complexes', 'ComputationFailed', 'ConditionSet', 'Contains', 'CosineTransform', 'Curve', 'DeferredVector', 'DenseNDimArray', 'Derivative', 'Determinant', 'DiagMatrix', 'DiagonalMatrix', 'DiagonalOf', 'Dict', 'DiracDelta', 'DisjointUnion', 'Domain', 'DomainError', 'DotProduct', 'Dummy', 'E', 'E1', 'EPath', 'EX', 'EXRAW', 'Ei', 'Eijk', 'Ellipse', 'EmptySequence', 'EmptySet', 'Eq', 'Equality', 'Equivalent', 'EulerGamma', 'EvaluationFailed', 'ExactQuotientFailed', 'Expr', 'ExpressionDomain', 'ExtraneousFactors', 'FF', 'FF_gmpy', 'FF_python', 'FU', 'FallingFactorial', 'FiniteField', 'FiniteSet', 'FlagError', 'Float', 'FourierTransform', 'F

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
pretrained_vit.to(device)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [None]:
import pykan
from pykan.kan import KANLayer
import torch.nn.functional as F

# Freeze pretrained weights
for param in pretrained_vit.parameters():
    param.requires_grad = False

# Update the classification head with pykan KAN layers
embedding_dim = 768  # Dimension of ViT embeddings
kan_hidden_units = [128, 64, 32]  # Increased model capacity for better learning

class KANHead(nn.Module):
    def __init__(self, embedding_dim, kan_hidden_units, num_classes):
        super(KANHead, self).__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.kan1 = KANLayer(in_dim=embedding_dim, out_dim=kan_hidden_units[0])
        self.kan2 = KANLayer(in_dim=kan_hidden_units[0], out_dim=kan_hidden_units[1])
        self.kan3 = KANLayer(in_dim=kan_hidden_units[1], out_dim=kan_hidden_units[2])
        self.kan_out = KANLayer(in_dim=kan_hidden_units[2], out_dim=num_classes)

    def forward(self, x):
        x = self.layer_norm(x)
        x = self.kan1(x)[0]  # Extract only the tensor
        x = F.relu(x)
        x = self.kan2(x)[0]
        x = F.relu(x)
        x = self.kan3(x)[0]
        x = F.relu(x)
        x = self.kan_out(x)[0]  # Extract only the tensor
        return x

pretrained_vit.heads = KANHead(embedding_dim, kan_hidden_units, len(label_mapping))
pretrained_vit.to(device)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pretrained_vit.parameters(), lr=5e-4, weight_decay=1e-5)  # Lower learning rate for stability
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.7)  # Adjusted scheduler for better convergence

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pretrained_vit.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

In [None]:
epochs = 10
for epoch in range(epochs):
    pretrained_vit.train()
    total_loss = 0
    for batch_idx, (frames, labels) in enumerate(dataloader):
        frames = frames.mean(dim=1).to(device)  # Average frames across the time dimension
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = pretrained_vit(frames)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (batch_idx + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Batch {batch_idx + 1}/{len(dataloader)}, Loss: {loss.item():.4f}")
    scheduler.step()
    print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {total_loss / len(dataloader):.4f}")


Epoch 1/10, Batch 10/73, Loss: 2.6983
Epoch 1/10, Batch 20/73, Loss: 2.5328
Epoch 1/10, Batch 30/73, Loss: 2.3363
Epoch 1/10, Batch 40/73, Loss: 1.9419
Epoch 1/10, Batch 50/73, Loss: 1.6446
Epoch 1/10, Batch 60/73, Loss: 1.4655
Epoch 1/10, Batch 70/73, Loss: 1.4562
Epoch 1/10, Average Loss: 2.0725
Epoch 2/10, Batch 10/73, Loss: 1.0318
Epoch 2/10, Batch 20/73, Loss: 1.1978
Epoch 2/10, Batch 30/73, Loss: 0.8471
Epoch 2/10, Batch 40/73, Loss: 0.8779
Epoch 2/10, Batch 50/73, Loss: 0.8052
Epoch 2/10, Batch 60/73, Loss: 0.6011
Epoch 2/10, Batch 70/73, Loss: 0.9108
Epoch 2/10, Average Loss: 0.9741
Epoch 3/10, Batch 10/73, Loss: 0.5722
Epoch 3/10, Batch 20/73, Loss: 0.7375
Epoch 3/10, Batch 30/73, Loss: 0.7700
Epoch 3/10, Batch 40/73, Loss: 0.5065
Epoch 3/10, Batch 50/73, Loss: 0.3851
Epoch 3/10, Batch 60/73, Loss: 0.6137
Epoch 3/10, Batch 70/73, Loss: 0.7443
Epoch 3/10, Average Loss: 0.6347
Epoch 4/10, Batch 10/73, Loss: 0.3347
Epoch 4/10, Batch 20/73, Loss: 0.4752
Epoch 4/10, Batch 30/73, Lo

In [None]:
pretrained_vit.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for frames, labels in dataloader:
        frames = frames.mean(dim=1).to(device)
        labels = labels.to(device)
        outputs = pretrained_vit(frames)
        preds = outputs.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())

In [None]:
accuracy = accuracy_score(all_labels, all_preds)
rmse = np.sqrt(mean_squared_error(all_labels, all_preds))
mae = mean_absolute_error(all_labels, all_preds)

print(f"Accuracy: {accuracy * 100:.2f}%")
print(f"RMSE: {rmse:.4f}")
print(f"MAE: {mae:.4f}")

Accuracy: 99.23%
RMSE: 0.6421
MAE: 0.0477
