In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
import os
import functools

# Environment report: log the installed PyTorch build and which device type
# (CUDA if visible, otherwise CPU) the rest of the notebook will pick.
print(f"PyTorch Version: {torch.__version__}")
print(f"Device: {torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')}")

In [None]:
# -------------------------------
# Parameters
# -------------------------------
DATASET = 'omniglot'
N_WAY = 32
K_SHOT = 1
# 32-way 1-shot with 5 query means 32*(1+5) = 192 images per episode.
N_QUERY = 5
NUM_EPOCHS = 500
EPISODES_PER_EPOCH = 50
TEST_EPISODES = 50
OUTPUT_DIM = 1024 # Fixed to 1024 to support segment_size up to 1024

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: episode sampling uses `random`, weight initialisation uses
# torch's generator. Seed every RNG the notebook imports so that a fresh
# "Restart & Run All" draws the same episodes and initial weights.
# (Some GPU kernels may still be non-deterministic.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# -------------------------------
# Distance Function (The Experiment)
# -------------------------------
def segmented_cosine_similarity(a, b, segment_size=2):
    """
    Segmented cosine similarity over circular sliding windows.

    Each row vector is cut into ``dim`` overlapping segments: segment ``i``
    holds the ``segment_size`` consecutive coordinates starting at index ``i``,
    wrapping around to the start of the vector when it runs past the end.

    Formula (N = dim, one segment per start index):
    Sim = Sum_{i=1}^{N} ( (|q_i|*|s_i|) / (segment_size * |q|*|s|) * Cos(q_i, s_i) )

    Args:
        a: 2-D tensor of query vectors, shape [Q, dim].
        b: 2-D tensor of support vectors, shape [S, dim].
        segment_size: window length, >= 1. Values larger than ``dim`` are
            allowed; the window then wraps around the vector more than once.

    Returns:
        Tensor of shape [Q, S] with the pairwise similarities.

    Raises:
        ValueError: if ``segment_size`` is smaller than 1 or the vectors have
            no features.

    Notes:
        * Every coordinate falls into exactly ``segment_size`` windows, so the
          weighted sum algebraically reduces to the ordinary cosine similarity
          of the full vectors, up to the 1e-8 stabilisers and floating-point
          rounding.
        * Per-segment dot products are computed as sliding-window sums over
          the element-wise products, so the forward pass never builds a
          [Q, S, dim, segment_size] tensor.
    """
    batch_q, dim = a.size()
    batch_s, _ = b.size()

    if segment_size < 1:
        raise ValueError(f"segment_size must be >= 1, got {segment_size}")
    if dim == 0:
        raise ValueError("input vectors must have at least one feature")

    # 1. Circular padding so that windows starting near the end wrap around.
    # A window of size K starting at index N-1 needs K-1 extra elements.
    pad_size = segment_size - 1
    if pad_size > 0:
        # When pad_size exceeds dim, a single copy of the vector is not enough:
        # tile it so the wrap-around stays truly circular.
        reps = -(-pad_size // dim)  # ceil(pad_size / dim)
        wrap_a = a if reps == 1 else a.repeat(1, reps)
        wrap_b = b if reps == 1 else b.repeat(1, reps)
        a_padded = torch.cat([a, wrap_a[:, :pad_size]], dim=1)
        b_padded = torch.cat([b, wrap_b[:, :pad_size]], dim=1)
    else:
        a_padded = a
        b_padded = b

    # 2. Sliding windows (views, step=1): [Batch, N, segment_size] with N == dim,
    # because (dim + segment_size - 1 - segment_size) / 1 + 1 == dim.
    a_seg = a_padded.unfold(dimension=1, size=segment_size, step=1)
    b_seg = b_padded.unfold(dimension=1, size=segment_size, step=1)

    # 3. Dot product per segment: [Q, S, N].
    # Multiply element-wise first ([Q, S, dim + pad]) and then sum each window;
    # this equals sum_d a_seg[q, n, d] * b_seg[s, n, d].
    pair_prod = a_padded.unsqueeze(1) * b_padded.unsqueeze(0)
    dot_prod = pair_prod.unfold(dimension=2, size=segment_size, step=1).sum(dim=-1)

    # 4. Norm per segment, shaped for broadcasting against dot_prod.
    norm_a_seg = torch.norm(a_seg, p=2, dim=-1).unsqueeze(1)  # [Q, 1, N]
    norm_b_seg = torch.norm(b_seg, p=2, dim=-1).unsqueeze(0)  # [1, S, N]

    # 5. Cosine per segment (epsilon avoids division by zero for all-zero segments).
    cos_seg = dot_prod / (norm_a_seg * norm_b_seg + 1e-8)

    # 6. Weights: (|q_i| |s_i|) / (segment_size * |q| |s|), using the global norms.
    norm_a = torch.norm(a, p=2, dim=1).view(batch_q, 1)  # [Q, 1]
    norm_b = torch.norm(b, p=2, dim=1).view(1, batch_s)  # [1, S]

    numerator = norm_a_seg * norm_b_seg  # [Q, S, N]
    denominator = segment_size * norm_a.unsqueeze(2) * norm_b.unsqueeze(2) + 1e-8

    weights = numerator / denominator

    # 7. Weighted sum over all N segments: [Q, S]
    sim = (weights * cos_seg).sum(dim=2)

    return sim

In [None]:
# -------------------------------
# Helper Functions
# -------------------------------
def quantize_to_int8(tensor, scale=None):
    """Fake-quantize `tensor` onto the signed 8-bit grid [-128, 127].

    Args:
        tensor: floating-point tensor to quantize.
        scale: multiplier applied before rounding. When None it is derived
            from the tensor so that its largest magnitude maps to ~127.

    Returns:
        (quantized, scale): `quantized` holds the rounded and clamped values
        but remains a floating-point tensor attached to the autograd graph;
        `scale` is the multiplier that was used.
    """
    if scale is None:
        peak = tensor.abs().max().item()
        scale = 127.0 / (peak + 1e-8)  # epsilon guards an all-zero tensor

    scaled = tensor * scale
    frozen = scaled.detach()
    rounded = frozen.round().clamp(-128, 127)
    # Straight-through estimator: the value equals `rounded`, while the
    # gradient flows through `scaled` as if no rounding had happened.
    quantized = rounded - frozen + scaled
    return quantized, scale

def dequantize_from_int8(quantized, scale):
    """Map quantized values back to the original range by dividing by `scale`.

    The input is cast to float32 first, so integer tensors are accepted.
    """
    as_float = quantized.float()
    return as_float / scale

def organize_by_class(dataset):
    """Group the samples of an iterable of (image, label) pairs by label.

    Returns:
        A plain dict mapping each label to the list of its images, with labels
        and images kept in the order they were first encountered.
    """
    grouped = {}
    for sample, cls in dataset:
        grouped.setdefault(cls, []).append(sample)
    return grouped

def sample_episode(data_by_class, n_way, k_shot, n_query):
    """Draw one few-shot episode from a label -> list-of-images mapping.

    `n_way` classes are sampled without replacement and relabelled 0..n_way-1
    in the order drawn. For each class, `k_shot + n_query` images are sampled:
    the first `k_shot` go to the support set, the rest to the query set. A
    class with too few images has its list repeated until it is long enough,
    so the same image may then appear more than once.

    Returns:
        (support_images, support_labels, query_images, query_labels), with the
        images stacked into tensors and the labels as integer tensors. Both
        sets are ordered class by class.
    """
    per_class = k_shot + n_query
    chosen_classes = random.sample(list(data_by_class.keys()), n_way)

    s_imgs, s_lbls = [], []
    q_imgs, q_lbls = [], []
    for episode_label, cls in enumerate(chosen_classes):
        pool = data_by_class[cls]
        if len(pool) < per_class:
            # Not enough distinct images: tile the list so sampling can succeed.
            pool = pool * (per_class // len(pool) + 1)
        picked = random.sample(pool, per_class)

        s_imgs.extend(picked[:k_shot])
        s_lbls.extend([episode_label] * k_shot)
        q_imgs.extend(picked[k_shot:])
        q_lbls.extend([episode_label] * n_query)

    support = (torch.stack(s_imgs), torch.tensor(s_lbls))
    query = (torch.stack(q_imgs), torch.tensor(q_lbls))
    return support + query

# -------------------------------
# Model (Matched to mix/code.py)
# -------------------------------
class MANN(nn.Module):
    """Four-layer convolutional encoder followed by a linear embedding layer.

    Args:
        dataset: dataset name. 'omniglot' and 'mnist' select a 1-channel
            input with 128 filters per layer; any other name selects a
            3-channel input with 128/256/512/1024 filters.
        out_dim: size of the embedding produced by the final linear layer.
        quantize: default for whether `forward` fake-quantizes the embedding
            to int8 precision.
        input_size: side length of the (square) input images. Defaults to 28,
            which gives the 7x7 feature map this model was originally
            hard-coded for.
    """

    def __init__(self, dataset, out_dim=1024, quantize=False, input_size=28):
        super().__init__()
        grayscale = dataset in ['omniglot', 'mnist']
        in_channels = 1 if grayscale else 3
        planes = [128, 128, 128, 128] if grayscale else [128, 256, 512, 1024]
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, planes[0], kernel_size=5, stride=2, padding=2, bias=False),
            nn.ReLU(),
            nn.Conv2d(planes[0], planes[1], kernel_size=5, stride=1, padding=2, bias=False),
            nn.ReLU(),
            nn.Conv2d(planes[1], planes[2], kernel_size=3, stride=2, padding=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(planes[2], planes[3], kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
        )
        # Spatial size after the conv stack: the two stride-2 layers each map a
        # side length s to (s - 1) // 2 + 1; the stride-1 layers keep it
        # unchanged (28 -> 14 -> 7).
        feat_size = input_size
        for _ in range(2):
            feat_size = (feat_size - 1) // 2 + 1
        self.fc = nn.Linear(planes[3] * feat_size * feat_size, out_dim)
        self.quantize = quantize
        # Quantization scale factor. It is filled in by the first quantized
        # forward pass and then reused for every later call.
        self.scale = None

    def forward(self, x, apply_quantize=None):
        """Embed a batch of images.

        Args:
            x: image batch of shape [B, C, input_size, input_size].
            apply_quantize: overrides the instance-level `quantize` flag for
                this call when not None.

        Returns:
            Tensor of shape [B, out_dim].
        """
        if apply_quantize is None:
            apply_quantize = self.quantize

        features = self.model(x)
        emb = self.fc(features.flatten(1))

        if apply_quantize:
            # Quantize the embedding (the scale is computed on first use).
            quantized_emb, self.scale = quantize_to_int8(emb, self.scale)
            # Immediately dequantize back to float for downstream processing.
            emb = dequantize_from_int8(quantized_emb, self.scale)
        return emb

In [None]:
# -------------------------------
# Experiment logic
# -------------------------------

print(f"Setting up {DATASET} experiment")
print(f"N_WAY={N_WAY}, K_SHOT={K_SHOT}, OUTPUT_DIM={OUTPUT_DIM}")

# Data Setup
# Images are resized to 28x28, converted to tensors, and have their pixel
# intensities inverted (1 - x).
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: 1.0 - x)
])

# Use a local directory for data. In Kaggle, this will be /kaggle/working/data
# You can also look for input data in /kaggle/input if you have added the dataset
data_root = './data'
os.makedirs(data_root, exist_ok=True)

try:
    print(f"Attempting to download/load dataset to {data_root}...")
    train_dataset = torchvision.datasets.Omniglot(root=data_root, background=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.Omniglot(root=data_root, background=False, download=True, transform=transform)
    print("Dataset loaded successfully.")
except Exception as e:
    print(f"Download might have failed: {e}")
    print("If running on Kaggle without internet, please add the 'Omniglot' dataset to your notebook input.")
    # Nothing below can run without the datasets. Stop here with the real
    # cause attached instead of failing later with an unrelated NameError
    # (or silently reusing datasets left over from an earlier kernel run).
    raise RuntimeError(f"Could not load the Omniglot dataset from {data_root!r}") from e

# Group images by class label once so that episodes can be sampled quickly.
train_data = organize_by_class(train_dataset)
test_data = organize_by_class(test_dataset)

# Sweep Parameters
segment_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

results = []


def _sample_episode_on_device(data_by_class):
    """Draw one N_WAY / K_SHOT / N_QUERY episode and move its four tensors to DEVICE."""
    return tuple(
        t.to(DEVICE) for t in sample_episode(data_by_class, N_WAY, K_SHOT, N_QUERY)
    )


# For every segment size: train a fresh model from scratch, then measure its
# few-shot accuracy on episodes drawn from the held-out classes.
for seg_size in segment_sizes:
    print(f"\n=========================================")
    print(f"Running Experiment with Segment Size D={seg_size}")
    print(f"=========================================")

    # 1. Initialize Model
    model = MANN(dataset=DATASET, out_dim=OUTPUT_DIM).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # 2. Define Distance Function for this run
    dist_fn = functools.partial(segmented_cosine_similarity, segment_size=seg_size)

    # 3. Training Loop
    model.train()
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        for ep in range(EPISODES_PER_EPOCH):
            support_imgs, support_labels, query_imgs, query_labels = _sample_episode_on_device(train_data)

            optimizer.zero_grad()

            # Forward
            support_emb = model(support_imgs)
            query_emb = model(query_imgs)

            # Similarity of every query to every support sample: [Q, N_WAY * K_SHOT]
            logits = dist_fn(query_emb, support_emb)

            # The loss targets are class indices, so pool the per-sample
            # similarities into one score per class: [Q, N_WAY]. sample_episode
            # emits the support set class by class (K_SHOT consecutive samples
            # each), which makes this a plain reshape. With K_SHOT == 1 the
            # values are unchanged.
            class_logits = logits.reshape(logits.size(0), N_WAY, K_SHOT).mean(dim=2)

            loss = criterion(class_logits, query_labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f"  Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {epoch_loss/EPISODES_PER_EPOCH:.4f}")

    # 4. Testing Loop
    model.eval()
    total_correct = 0
    total_total = 0
    with torch.no_grad():
        for ep in range(TEST_EPISODES):
            support_imgs, support_labels, query_imgs, query_labels = _sample_episode_on_device(test_data)

            support_emb = model(support_imgs)
            query_emb = model(query_imgs)

            logits = dist_fn(query_emb, support_emb)

            # Nearest-support classification: take the most similar support
            # sample and look up the class it belongs to.
            preds = logits.argmax(dim=1)
            pred_labels = support_labels[preds]
            total_correct += (pred_labels == query_labels).sum().item()
            total_total += query_labels.size(0)

    accuracy = 100.0 * total_correct / total_total
    print(f"  Accuracy for D={seg_size}: {accuracy:.2f}%")

    results.append({
        "Segment_Size_D": seg_size,
        "Accuracy": accuracy
    })

    # Free this run's model before the next one is allocated.
    del model, optimizer
    torch.cuda.empty_cache()

In [None]:
# Persist the sweep results as a CSV table (one row per segment size).
df = pd.DataFrame(results)
df.to_csv("segcos_results.csv", index=False)
print("\nResults saved to segcos_results.csv")

# Accuracy-vs-segment-size curve. The segment sizes are powers of two, hence
# the base-2 logarithmic x axis.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(df["Segment_Size_D"], df["Accuracy"], marker='o')
ax.set_xscale('log', base=2)
ax.set_xlabel('Segment Size (D)')
ax.set_ylabel('Accuracy (%)')
ax.set_title(f'{DATASET} {N_WAY}-way {K_SHOT}-shot Accuracy vs Segment Size')
ax.grid(True)
fig.savefig("segcos_plot.png")
print("Plot saved to segcos_plot.png")
plt.show()