# Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import os
import glob
from PIL import Image
from sklearn.model_selection import train_test_split

# Hyperparameters

In [None]:
# Length of one flattened 28x28 Omniglot image.
input_dim = 28 * 28
# Hidden dimension for attributes (output width of the h and v networks).
D = 64
# Dimension of the task representation z.
z_dim = 64
# Regularization parameter (initial value of the learnable mu).
mu = 0.1
# Soft-thresholding parameter (initial value of the learnable lambda).
lambda_ = 0.01
# Number of inner-loop iterations.
num_inner_iter = 5
# Learning rate passed to the optimizer.
lr = 0.001
# Number of training epochs.
num_epochs = 10
# Support set size per task.
support_size = 10
# Query set size per task.
query_size = 10


# Data


In [None]:
# Dataset locations: the training split is the "images_background" folder
# directly under data_root.
# NOTE(review): data_root is a single space, which looks like a placeholder;
# confirm the real dataset path before running.
data_root = " "
train_dir = os.path.join(data_root, "images_background")

# Image pipeline: resize to 28x28, convert to a tensor, then flatten the
# tensor into a 1-D vector.
transform = transforms.Compose(
    [
        transforms.Resize(size=(28, 28)),
        transforms.ToTensor(),
        transforms.Lambda(torch.flatten),
    ]
)


# Loading Data
def load_omniglot_tasks(root_dir, num_tasks=1000, num_anomaly_classes=5):
    """Build anomaly-detection tasks from an alphabet/character directory tree.

    ``root_dir`` is expected to contain one sub-directory per alphabet, each
    holding one sub-directory per character class with that class's images.

    For every task one character class is drawn at random as the "normal"
    class and ``num_anomaly_classes`` *other* classes are drawn at random
    (without replacement) as anomalies. Images are loaded in grayscale and
    passed through the module-level ``transform``; the module-level
    ``support_size`` and ``query_size`` control how many images are used.

    Args:
        root_dir: Directory containing the alphabet folders.
        num_tasks: Number of tasks to generate.
        num_anomaly_classes: How many anomaly classes contribute to each
            task's anomalous query samples (capped at the number of classes
            other than the normal one).

    Returns:
        A list of ``(support, query_anomaly, query_normal)`` tuples of
        stacked tensors, in that order.

    Raises:
        ValueError: If fewer than two character classes are found under
            ``root_dir``, or if ``query_size`` is too small to take at least
            one image from each anomaly class.
    """
    listing_cache = {}

    def _list_files(class_dir):
        # Cache sorted listings: the same class is visited by many tasks, and
        # glob order is filesystem-dependent without sorting.
        if class_dir not in listing_cache:
            listing_cache[class_dir] = sorted(
                glob.glob(os.path.join(class_dir, "*"))
            )
        return listing_cache[class_dir]

    def _load_images(paths):
        images = []
        for path in paths:
            # The context manager closes the file handle once the image has
            # been converted and transformed.
            with Image.open(path) as img:
                images.append(transform(img.convert('L')))
        return images

    # Collect every character class directory across all alphabets.
    all_classes = []
    for alphabet in sorted(glob.glob(os.path.join(root_dir, "*"))):
        all_classes.extend(sorted(glob.glob(os.path.join(alphabet, "*"))))

    if len(all_classes) < 2:
        raise ValueError(
            f"Need at least two character classes under {root_dir!r}, "
            f"found {len(all_classes)}"
        )

    n_anomaly = min(num_anomaly_classes, len(all_classes) - 1)
    per_anomaly_class = query_size // (2 * n_anomaly) if n_anomaly > 0 else 0
    if per_anomaly_class < 1:
        raise ValueError(
            "query_size is too small to draw at least one image from each "
            "anomaly class"
        )

    half_query = query_size // 2
    tasks = []
    for _ in range(num_tasks):
        # Pick the normal class, then sample the anomaly classes at random
        # from the remaining ones so tasks do not all share the same anomalies.
        normal_class = all_classes[np.random.randint(len(all_classes))]
        candidates = [c for c in all_classes if c != normal_class]
        picked = np.random.choice(len(candidates), size=n_anomaly, replace=False)
        anomaly_classes = [candidates[i] for i in picked]

        # Support and normal-query images are disjoint slices of one listing.
        normal_files = _list_files(normal_class)
        support_set = _load_images(normal_files[:support_size])
        query_normal = _load_images(
            normal_files[support_size:support_size + half_query]
        )

        query_anomaly = []
        for anomaly_class in anomaly_classes:
            query_anomaly.extend(
                _load_images(_list_files(anomaly_class)[:per_anomaly_class])
            )

        tasks.append((
            torch.stack(support_set),
            torch.stack(query_anomaly),
            torch.stack(query_normal),
        ))
    return tasks

# Build all training tasks up front; the whole list is held in memory.
tasks = load_omniglot_tasks(root_dir=train_dir, num_tasks=1000)



# Defining functions

In [None]:
# Task encoder f: maps one flattened image to a z_dim-dimensional embedding.
task_encoder = nn.Sequential(
    nn.Linear(input_dim, 256), nn.ReLU(), nn.Linear(256, z_dim)
)

# Network h: takes an image concatenated with the task representation and
# produces a D-dimensional output.
h_network = nn.Sequential(
    nn.Linear(input_dim + z_dim, 256), nn.ReLU(), nn.Linear(256, D)
)

# Network v: same input and output sizes as h.
v_network = nn.Sequential(
    nn.Linear(input_dim + z_dim, 256), nn.ReLU(), nn.Linear(256, D)
)

# Learnable scalars, initialised from the mu and lambda_ hyperparameters.
mu_param = nn.Parameter(torch.tensor(mu))
lambda_param = nn.Parameter(torch.tensor(lambda_))

# Everything the optimizer updates: the weights of the three networks,
# followed by the two learnable scalars.
params = [
    weight
    for module in (task_encoder, h_network, v_network)
    for weight in module.parameters()
]
params += [mu_param, lambda_param]

optimizer = optim.Adam(params, lr=lr)

def compute_scores(support_set, query_set):
    """Score query samples by reconstruction error after adapting to a task.

    The support set is encoded into a task representation ``z`` (mean of the
    ``task_encoder`` outputs). A linear map ``W`` is then fitted in closed
    form (ridge regression with weight ``mu_param``) so that
    ``W @ h_network([x, z])`` reconstructs each support sample ``x``. Each
    query sample's score is its squared L2 reconstruction error under ``W``.

    Args:
        support_set: Support samples of one task, either a 2-D tensor with
            one flattened sample per row or a sequence of 1-D tensors.
        query_set: Query samples, in the same format as ``support_set``.

    Returns:
        A 1-D tensor with one score per query sample, in input order.

    Raises:
        ValueError: If the module-level ``num_inner_iter`` is below 1, since
            ``W`` would never be fitted.
    """
    def _as_batch(samples):
        # The training loop passes already-stacked tensors; torch.stack only
        # accepts sequences, so stack only when given one.
        if torch.is_tensor(samples):
            return samples
        return torch.stack(list(samples))

    if num_inner_iter < 1:
        raise ValueError("num_inner_iter must be at least 1 to fit W")

    X = _as_batch(support_set)   # (N_S, input_dim)
    Q = _as_batch(query_set)     # (N_Q, input_dim)
    N_S = X.shape[0]

    # Task representation z: mean embedding of the support samples.
    z = task_encoder(X).mean(dim=0)

    def _with_task(batch):
        # Append z to every row so h/v see [x, z] for each sample.
        return torch.cat([batch, z.expand(batch.shape[0], -1)], dim=1)

    support_in = _with_task(X)

    # Initial attribute A0.
    # NOTE(review): A (both this initial value and the soft-thresholded
    # update below) is never fed back into the W update or the scores, so
    # every inner iteration yields the same W. Left as in the original
    # pending confirmation of the intended update rule.
    A = v_network(support_in)

    # H, the regularised Gram matrix and its inverse do not change across
    # inner iterations, so they are computed once.
    H = h_network(support_in)    # (N_S, D)
    reg_matrix = mu_param * torch.eye(H.shape[1], dtype=H.dtype, device=H.device)
    inv = torch.inverse(H.T @ H / N_S + reg_matrix)
    threshold = lambda_param / mu_param

    W = None
    for _ in range(num_inner_iter):
        # Updating W: stored as (input_dim, D) so that W @ h maps an h-output
        # back to input space. This is the transpose of
        # inv @ (H.T @ X) / N_S; inv is symmetric.
        W = (X.T @ H / N_S) @ inv

        # Updating A: soft-threshold the support reconstruction residual.
        residual = X - H @ W.T
        A = torch.sign(residual) * torch.relu(torch.abs(residual) - threshold)

    # Computing scores: squared L2 reconstruction error per query sample.
    recon = h_network(_with_task(Q)) @ W.T
    return ((Q - recon) ** 2).sum(dim=1)

# Training


In [None]:
# Meta-training: each task contributes one optimizer step on a smoothed-AUC
# loss that pushes anomaly scores above normal scores.
for epoch in range(num_epochs):
    total_loss = 0.0
    for support, query_anomaly, query_normal in tasks:
        optimizer.zero_grad()

        # Computing scores: score all queries in one call so the adaptation
        # on the support set runs once per task instead of twice, then split
        # the result back into anomaly and normal scores.
        num_anomaly = len(query_anomaly)
        scores = compute_scores(support, torch.cat([query_anomaly, query_normal]))
        anomaly_scores = scores[:num_anomaly]
        normal_scores = scores[num_anomaly:]

        # Smoothed AUC: mean sigmoid of every (anomaly - normal) score pair.
        diff = anomaly_scores.unsqueeze(1) - normal_scores.unsqueeze(0)
        auc = torch.sigmoid(diff).mean()
        loss = 1 - auc

        # Backpropagation
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Avoid a ZeroDivisionError when no tasks were loaded.
    mean_loss = total_loss / len(tasks) if tasks else float("nan")
    print(f"Epoch {epoch}, Loss: {mean_loss}")