In [17]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn

# ===================== KAN Layer =====================
class KANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_intervals=10, spline_degree=3, device="cpu"):
        super(KANLayer, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_intervals = num_intervals
        self.spline_degree = spline_degree
        self.device = device

        self.grid = torch.linspace(-1, 1, num_intervals + 1).repeat(input_dim, 1).to(device)
        self.spline_coeffs = nn.Parameter(
            torch.randn(num_intervals + spline_degree, input_dim, output_dim).to(device)
        )
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        processed_features = []
        for i in range(self.input_dim):
            feature = x[:, i]
            spline_output = self.evaluate_splines(feature, self.grid[i], self.spline_coeffs[:, i, :])
            processed_features.append(spline_output)

        preprocessed_data = torch.stack(processed_features, dim=1).sum(dim=1)
        return self.fc(preprocessed_data)

    @staticmethod
    def evaluate_splines(x, grid, coeffs):
        degree = coeffs.size(0) - (grid.size(0) - 1)
        spline_output = torch.zeros(x.size(0), coeffs.size(1), device=x.device)
        for d in range(degree):
            spline_output += coeffs[d] * (x.unsqueeze(1) ** d)
        return spline_output

# ===================== Full KAN Model =====================
class FullKAN(nn.Module):
    """Classifier stack: linear projection -> KANLayer -> two linear layers.

    The submodule attribute names (``input_projection``, ``layer1``,
    ``layer2``, ``layer3``, ``dropout``) determine the state_dict keys that
    saved checkpoints use, so they must stay exactly as they are.

    Args:
        input_dim: Length of the flattened input vector.
        hidden_dim: Width of every hidden representation.
        output_dim: Number of output scores.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Registration order is kept so parameter init order is unchanged.
        self.input_projection = nn.Linear(input_dim, hidden_dim)
        self.layer1 = KANLayer(hidden_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return the final linear layer's un-normalized scores for batch ``x``."""
        hidden = torch.relu(self.input_projection(x))
        hidden = self.dropout(torch.relu(self.layer1(hidden)))
        hidden = torch.relu(self.layer2(hidden))
        return self.layer3(hidden)

# ===================== Image Transformation =====================
# Preprocessing shared by inference. It yields a 1x48x48 float tensor scaled
# to [-1, 1]: ToTensor maps pixels to [0, 1], then (v - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),  # collapse to one channel
        transforms.Resize((48, 48)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# ===================== Predict Emotion =====================
def predict_emotion(image_path, model, device):
    """
    Predict the emotion from an input image using the trained model.

    The image is preprocessed with the module-level ``transform`` and
    flattened to a single row vector. The index of the largest model output
    is then mapped to a label.

    Args:
        image_path (str): Path to the input image.
        model (nn.Module): The trained emotion detection model. It is switched
            to eval mode as a side effect.
        device (torch.device): The device to run the model on (cpu or cuda).

    Returns:
        str: Predicted emotion label.
    """
    # The context manager releases the file handle. The transform reads the
    # pixels inside the block, so nothing is needed after the file closes.
    with Image.open(image_path) as img:
        tensor = transform(img).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        # Flatten (1, C, H, W) -> (1, C*H*W) to feed the model's linear input.
        output = model(tensor.flatten(start_dim=1))
        predicted_class = output.argmax(dim=1).item()

    # NOTE(review): this index -> label order must match the order used in
    # training. If training used torchvision ImageFolder on the per-emotion
    # folder layout suggested by the example path (e.g. .../train/sad/), class
    # indices follow sorted folder names. Presumably those are angry, disgust,
    # fear, happy, neutral, sad, surprise, which would disagree with indices
    # 4-6 below. Verify against the training code.
    emotion_labels = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sadness', 'Surprise', 'Neutral']
    return emotion_labels[predicted_class]

# ===================== Example Usage =====================
image_path = '/kaggle/input/fer2013/train/sad/Training_10115766.jpg'  # Provide the path to your input image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# input_dim must equal the flattened size produced by `transform`
# (Grayscale -> 1 channel, Resize -> 48x48).
model = FullKAN(input_dim=48*48, hidden_dim=256, output_dim=7).to(device)

# Load the trained model weights safely. weights_only=True restricts unpickling
# to tensors and plain containers. map_location remaps storages onto `device`,
# so a checkpoint saved on a GPU also loads on a CPU-only machine.
model_weights_path = "/kaggle/input/emotion-classifier-using-kan/pytorch/default/1/kan_emotion_model.pth"
model.load_state_dict(torch.load(model_weights_path, map_location=device, weights_only=True))

# Predict and print the emotion
predicted_emotion = predict_emotion(image_path, model, device)
print(f"Predicted Emotion: {predicted_emotion}")


Predicted Emotion: Anger
