In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
import os
import functools

# Report the installed PyTorch build and the device type that will be selected
# (CUDA when available, otherwise CPU). Nothing is downloaded here.
print("PyTorch Version: {}".format(torch.__version__))
print("Device: {}".format(torch.device('cuda' if torch.cuda.is_available() else 'cpu')))

In [None]:
# -------------------------------
# Parameters
# -------------------------------
DATASET = 'omniglot'

# Episode layout: N_WAY classes, each contributing K_SHOT support images and
# N_QUERY query images, i.e. 32 * (1 + 5) = 192 images per episode.
N_WAY = 32
K_SHOT = 1
N_QUERY = 5

# Training / evaluation schedule.
NUM_EPOCHS = 500
EPISODES_PER_EPOCH = 50
TEST_EPISODES = 50

# Embedding width; kept at 1024 so that segment sizes up to 1024 are supported.
OUTPUT_DIM = 1024

# Prefer the GPU when one is visible to PyTorch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# -------------------------------
# Distance Function (The Experiment)
# -------------------------------
def _circular_windows(x, segment_size):
    """Return all circular sliding windows of `x` along dim 1.

    Args:
        x: 2-D tensor of shape [B, N] with N >= 1.
        segment_size: window length (>= 1).

    Returns:
        A view of shape [B, N, segment_size]; window i holds elements
        i, i+1, ..., i+segment_size-1 with indices taken modulo N.
    """
    dim = x.size(1)
    pad_size = segment_size - 1
    if pad_size > 0:
        if pad_size <= dim:
            wrap = x[:, :pad_size]
        else:
            # The window is longer than the vector itself: wrap around as many
            # times as needed so that exactly `dim` windows can still be formed.
            reps = -(-pad_size // dim)  # ceil(pad_size / dim)
            wrap = x.repeat(1, reps)[:, :pad_size]
        x = torch.cat([x, wrap], dim=1)
    return x.unfold(dimension=1, size=segment_size, step=1)


def segmented_cosine_similarity(a, b, segment_size=2):
    """
    Segmented cosine similarity over circular sliding windows.

    Every vector is cut into `dim` overlapping segments by sliding a window of
    `segment_size` elements (step 1, with wrap-around), so the number of
    segments always equals the vector dimension.

    Formula:
        Sim = Sum_{i=1}^{N} ( (|q_i|*|s_i|) / (segment_size * |q|*|s|) * Cos(q_i, s_i) )

    The sum is accumulated over chunks of 64 segments so that the full
    [Q, S, N] tensor of per-segment terms is never held in memory at once.

    NOTE(review): up to the 1e-8 stabilisers, weight_i * Cos(q_i, s_i) reduces
    to dot(q_i, s_i) / (segment_size * |q| * |s|), and every element falls in
    exactly `segment_size` window positions, so the total appears to equal the
    ordinary cosine similarity for any segment_size -- confirm this is intended.

    Args:
        a: query embeddings, shape [Q, dim].
        b: support embeddings, shape [S, dim].
        segment_size: window length, an integer >= 1.

    Returns:
        Tensor of shape [Q, S] on `a`'s device, in the promoted dtype of
        `a` and `b`; entry (q, s) is the similarity of a[q] and b[s].

    Raises:
        ValueError: if `a` and `b` have different widths, or segment_size < 1.
    """
    batch_q, dim = a.size()
    batch_s, dim_b = b.size()
    if dim_b != dim:
        raise ValueError(
            f"a and b must have the same feature dimension, got {dim} and {dim_b}"
        )
    if segment_size < 1:
        raise ValueError(f"segment_size must be >= 1, got {segment_size}")

    # Accumulate in the inputs' (promoted) dtype instead of always float32.
    out_dtype = torch.result_type(a, b)
    a = a.to(out_dtype)
    b = b.to(out_dtype)
    final_sim = torch.zeros(batch_q, batch_s, device=a.device, dtype=out_dtype)
    if dim == 0:
        return final_sim

    # Sliding windows: [Q, N, D] and [S, N, D] with N == dim, D == segment_size.
    a_seg = _circular_windows(a, segment_size)
    b_seg = _circular_windows(b, segment_size)

    # Whole-vector norms, shaped to broadcast against [Q, S, C].
    norm_a = torch.norm(a, p=2, dim=1).view(batch_q, 1, 1)
    norm_b = torch.norm(b, p=2, dim=1).view(1, batch_s, 1)
    # Loop-invariant denominator of the segment weights: [Q, S, 1].
    global_norm_denom = segment_size * norm_a * norm_b + 1e-8

    chunk_size = 64  # number of segments processed per iteration
    for start in range(0, dim, chunk_size):
        end = min(start + chunk_size, dim)
        sub_a = a_seg[:, start:end, :]  # [Q, C, D]
        sub_b = b_seg[:, start:end, :]  # [S, C, D]

        # Pairwise per-segment dot products, [Q, S, C], computed without
        # materialising the [Q, S, C, D] broadcasted product.
        dot_prod = torch.einsum('qcd,scd->qsc', sub_a, sub_b)

        # Per-segment norms, broadcast to [Q, S, C] when multiplied.
        norm_a_seg = torch.norm(sub_a, p=2, dim=-1).unsqueeze(1)  # [Q, 1, C]
        norm_b_seg = torch.norm(sub_b, p=2, dim=-1).unsqueeze(0)  # [1, S, C]
        seg_norm_prod = norm_a_seg * norm_b_seg

        cos_seg = dot_prod / (seg_norm_prod + 1e-8)
        weights = seg_norm_prod / global_norm_denom

        # Weighted contribution of this chunk, summed over its segments.
        final_sim += (weights * cos_seg).sum(dim=2)

    return final_sim

In [None]:
# -------------------------------
# Helper Functions
# -------------------------------
def quantize_to_int8(tensor, scale=None):
    """Fake-quantize `tensor` onto the int8 grid while staying differentiable.

    Args:
        tensor: floating-point tensor to quantize.
        scale: multiplier applied before rounding. When None it is derived
            from the tensor as 127 / (max|tensor| + 1e-8).

    Returns:
        (quantized, scale): `quantized` holds the scaled values rounded and
        clamped to [-128, 127], still as a float tensor; `scale` is the
        multiplier that was used (a Python float when derived here).
    """
    if scale is None:
        peak = tensor.abs().max().item()
        scale = 127.0 / (peak + 1e-8)

    scaled = tensor * scale
    frozen = scaled.detach()
    snapped = torch.clamp(torch.round(frozen), -128, 127)
    # Straight-through form: the value is `snapped`, but since `snapped` and
    # `frozen` are detached, the gradient is that of `scaled`.
    quantized = snapped - frozen + scaled
    return quantized, scale

def dequantize_from_int8(quantized, scale):
    """Map quantized values back to the original range by dividing by `scale`.

    The input is converted to float32 first, so integer tensors are accepted.
    """
    as_float = quantized.float()
    return as_float / scale

def organize_by_class(dataset):
    """Group the images of `dataset` by their label.

    Args:
        dataset: iterable yielding (image, label) pairs.

    Returns:
        dict mapping each label to the list of its images, in iteration order.
    """
    grouped = {}
    for image, target in dataset:
        grouped.setdefault(target, []).append(image)
    return grouped

def sample_episode(data_by_class, n_way, k_shot, n_query):
    """Draw one few-shot episode from a label -> images mapping.

    `n_way` classes are picked at random; from each, `k_shot + n_query`
    images are drawn without replacement (the class's image list is repeated
    first if it is too short). The first `k_shot` become support examples and
    the rest query examples. Labels are episode-local indices 0..n_way-1.

    Returns:
        (support_images, support_labels, query_images, query_labels), with
        images stacked into tensors and labels as 1-D tensors.
    """
    per_class = k_shot + n_query
    chosen_classes = random.sample(list(data_by_class.keys()), n_way)

    s_images, s_labels = [], []
    q_images, q_labels = [], []
    for episode_label, cls in enumerate(chosen_classes):
        pool = data_by_class[cls]
        if len(pool) < per_class:
            # Not enough distinct images: repeat the list until it is long enough.
            pool = pool * (per_class // len(pool) + 1)
        drawn = random.sample(pool, per_class)

        s_images.extend(drawn[:k_shot])
        s_labels.extend([episode_label] * k_shot)
        q_images.extend(drawn[k_shot:])
        q_labels.extend([episode_label] * n_query)

    support = (torch.stack(s_images), torch.tensor(s_labels))
    query = (torch.stack(q_images), torch.tensor(q_labels))
    return support[0], support[1], query[0], query[1]

# -------------------------------
# Model (Matched to mix/code.py)
# -------------------------------
class MANN(nn.Module):
    """Four-layer convolutional encoder followed by a linear embedding head.

    Args:
        dataset: dataset name; 'omniglot' and 'mnist' select single-channel
            input with 128 channels per conv layer, anything else selects
            3-channel input with widths 128/256/512/1024.
        out_dim: size of the returned embedding.
        quantize: default for whether `forward` fake-quantizes the embedding.
    """

    def __init__(self, dataset, out_dim=1024, quantize=False):
        super().__init__()
        single_channel = dataset in ['omniglot', 'mnist']
        in_channels = 1 if single_channel else 3
        planes = [128, 128, 128, 128] if single_channel else [128, 256, 512, 1024]

        # (kernel_size, stride, padding) for each of the four bias-free convs.
        conv_specs = [(5, 2, 2), (5, 1, 2), (3, 2, 1), (3, 1, 1)]
        layers = []
        prev_channels = in_channels
        for width, (kernel, stride, pad) in zip(planes, conv_specs):
            layers.append(
                nn.Conv2d(prev_channels, width, kernel_size=kernel,
                          stride=stride, padding=pad, bias=False)
            )
            layers.append(nn.ReLU())
            prev_channels = width
        self.model = nn.Sequential(*layers)

        # The two stride-2 convs reduce a 28x28 input to 7x7 feature maps.
        self.fc = nn.Linear(planes[3] * 7 * 7, out_dim)
        self.quantize = quantize
        # Quantization scale; None until the first quantized forward pass,
        # after which the stored value is passed back in on later calls.
        self.scale = None

    def forward(self, x, apply_quantize=None):
        """Embed a batch of images.

        Args:
            x: image batch whose conv features flatten to planes[3]*7*7 values.
            apply_quantize: overrides `self.quantize` for this call when not None.

        Returns:
            Tensor of shape [batch, out_dim].
        """
        use_quant = self.quantize if apply_quantize is None else apply_quantize

        features = self.model(x)
        flat = features.view(x.size(0), -1)
        emb = self.fc(flat)

        if use_quant:
            # Quantize the embedding ...
            quantized_emb, self.scale = quantize_to_int8(emb, self.scale)
            # ... and immediately dequantize back to float for downstream use.
            emb = dequantize_from_int8(quantized_emb, self.scale)
        return emb

In [None]:
# -------------------------------
# Experiment logic
# -------------------------------
# For each of NUM_RUNS runs: train a fresh MANN with segment_size=1, then
# evaluate that same model on one fixed set of test episodes under every
# segment size, and finally aggregate mean/std accuracy per segment size.

print(f"Setting up {DATASET} experiment")
print(f"N_WAY={N_WAY}, K_SHOT={K_SHOT}, OUTPUT_DIM={OUTPUT_DIM}")

# Data Setup: resize to 28x28 (the size MANN's linear layer is built for),
# convert to a tensor and invert the intensities (1 - x).
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: 1.0 - x)
])

# Use a local directory for data. In Kaggle, this will be /kaggle/working/data
data_root = './data'
os.makedirs(data_root, exist_ok=True)

try:
    print(f"Attempting to download/load dataset to {data_root}...")
    train_dataset = torchvision.datasets.Omniglot(root=data_root, background=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.Omniglot(root=data_root, background=False, download=True, transform=transform)
    print("Dataset loaded successfully.")
except Exception as e:
    print(f"Download might have failed: {e}")
    # Nothing below can run without the datasets; re-raise so the real cause
    # is shown instead of a later NameError on `train_dataset`.
    raise

train_data = organize_by_class(train_dataset)
test_data = organize_by_class(test_dataset)

# Parameters
segment_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
NUM_RUNS = 10  # Train 10 models to get statistics
BASE_SEED = 0  # run r is seeded with BASE_SEED + r: runs differ, reruns repeat
results_per_size = {d: [] for d in segment_sizes}

print(f"\n=========================================")
print(f"Starting Experiment: {NUM_RUNS} runs")
print(f"Each run trains ONE model with Standard Cosine (D=1), then tests on ALL segment sizes.")
print(f"=========================================")

for run in range(NUM_RUNS):
    print(f"\n--- Run {run+1}/{NUM_RUNS} ---")

    # Seed every RNG used below (episode sampling uses `random`, weight
    # initialisation uses torch) so the run can be reproduced.
    run_seed = BASE_SEED + run
    random.seed(run_seed)
    np.random.seed(run_seed)
    torch.manual_seed(run_seed)

    # 1. Initialize Model
    model = MANN(dataset=DATASET, out_dim=OUTPUT_DIM).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # 2. Train with segment_size=1 (the "standard cosine" setting).
    train_dist_fn = functools.partial(segmented_cosine_similarity, segment_size=1)

    model.train()
    print("  Training model...")
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        for ep in range(EPISODES_PER_EPOCH):
            support_imgs, support_labels, query_imgs, query_labels = sample_episode(
                train_data, N_WAY, K_SHOT, N_QUERY)

            support_imgs = support_imgs.to(DEVICE)
            support_labels = support_labels.to(DEVICE)
            query_imgs = query_imgs.to(DEVICE)
            query_labels = query_labels.to(DEVICE)

            optimizer.zero_grad()

            support_emb = model(support_imgs)
            query_emb = model(query_imgs)

            # logits[q, s] is the similarity of query q to support sample s.
            # NOTE(review): the columns index support samples, so using
            # query_labels as cross-entropy targets only lines up with class
            # indices when K_SHOT == 1 (support_labels is then 0..N_WAY-1).
            logits = train_dist_fn(query_emb, support_emb)
            loss = criterion(logits, query_labels)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        if (epoch + 1) % 50 == 0 or (epoch + 1) == NUM_EPOCHS:
            print(f"    Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {epoch_loss/EPISODES_PER_EPOCH:.4f}")

    # 3. Test on ALL segment sizes using the SAME trained model
    print("  Testing on all segment sizes...")
    model.eval()

    # Sample one fixed set of episodes per run so that every segment size is
    # scored on exactly the same data.
    test_episodes_data = [
        sample_episode(test_data, N_WAY, K_SHOT, N_QUERY)
        for _ in range(TEST_EPISODES)
    ]

    correct_per_size = {d: 0 for d in segment_sizes}
    total_queries = 0
    with torch.no_grad():
        for support_imgs, support_labels, query_imgs, query_labels in test_episodes_data:
            support_imgs = support_imgs.to(DEVICE)
            support_labels = support_labels.to(DEVICE)
            query_imgs = query_imgs.to(DEVICE)
            query_labels = query_labels.to(DEVICE)

            # The embeddings do not depend on the segment size, so compute
            # them once per episode and reuse them for every segment size.
            support_emb = model(support_imgs)
            query_emb = model(query_imgs)
            total_queries += query_labels.size(0)

            for seg_size in segment_sizes:
                logits = segmented_cosine_similarity(
                    query_emb, support_emb, segment_size=seg_size)

                # Nearest support sample decides the predicted label.
                preds = logits.argmax(dim=1)
                pred_labels = support_labels[preds]
                correct_per_size[seg_size] += (pred_labels == query_labels).sum().item()

    for seg_size in segment_sizes:
        accuracy = 100.0 * correct_per_size[seg_size] / total_queries
        results_per_size[seg_size].append(accuracy)

    # Clean up before the next run allocates a new model.
    del model, optimizer
    torch.cuda.empty_cache()

# Aggregate Results
final_results = []
print("\n=== Final Aggregated Results ===")
for seg_size in segment_sizes:
    accs = results_per_size[seg_size]
    mean_acc = np.mean(accs)
    std_acc = np.std(accs)
    print(f"D={seg_size}: {mean_acc:.2f}% ± {std_acc:.2f}%")

    final_results.append({
        "Segment_Size_D": seg_size,
        "Mean_Accuracy": mean_acc,
        "Std_Dev": std_acc,
        "All_Accuracies": str(accs)
    })

In [None]:
# Persist the aggregated results as a CSV table.
df = pd.DataFrame(final_results)
df.to_csv("segcos_results.csv", index=False)
print("\nResults saved to segcos_results.csv")

# Mean accuracy per segment size with +/- one standard deviation as error bars.
fig, ax = plt.subplots(figsize=(10, 6))
ax.errorbar(
    df["Segment_Size_D"],
    df["Mean_Accuracy"],
    yerr=df["Std_Dev"],
    marker='o',
    capsize=5,
    linestyle='-',
    linewidth=2,
    markersize=8,
)

# Segment sizes are powers of two, so use a base-2 log axis.
ax.set_xscale('log', base=2)
ax.set_xlabel('Segment Size (D)')
ax.set_ylabel('Accuracy (%)')
ax.set_title(f'{DATASET} {N_WAY}-way {K_SHOT}-shot Accuracy vs Segment Size\n(Trained with Cosine, Tested with SegCos)')
ax.grid(True, which="both", ls="-", alpha=0.2)

# Write the image before showing it.
fig.savefig("segcos_plot.png")
print("Plot saved to segcos_plot.png")
plt.show()