In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
from torchvision.models import ResNet50_Weights
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from glob import glob
import os
import re
import pandas as pd
from PIL import Image


from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

In [None]:
# Inspect the FX node names of ResNet50 to choose a `return_nodes` key for
# `create_feature_extractor`. Node names depend only on the architecture, so a
# model built without a `weights` argument is sufficient here.
# get_graph_node_names returns (train-mode node names, eval-mode node names).
model = models.resnet50()
train, val = get_graph_node_names(model)
val  # eval-mode node names (feature extraction below runs the model in eval mode)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights
from torchvision.models.feature_extraction import create_feature_extractor
import matplotlib.pyplot as plt
import numpy as np
import time

class ExponentialDecayLayer(nn.Module):
    """Scale activations by a fixed exponential decay factor.

    The factor is ``exp(-decay_rate * delay_ms / 1000)``: ``delay_ms`` is
    converted to seconds before use, so ``decay_rate`` is expressed per second.

    Args:
        decay_rate: Decay constant (per second). Stored as a float tensor even
            when an int is passed.
        delay_ms: Delay in milliseconds at which the decay is evaluated.
        verbose: If True (the default), print the applied factor on every
            forward call. Note that printing calls ``.item()``, which forces a
            device synchronisation on GPU.
    """

    def __init__(self, decay_rate=0.1, delay_ms=400, verbose=True):
        super().__init__()
        # Non-persistent buffer: it moves with `.to(device)` / dtype casts like
        # other module state, but does not add a key to `state_dict`.
        self.register_buffer("decay_rate", torch.tensor(float(decay_rate)), persistent=False)
        self.delay_ms = delay_ms
        self.verbose = verbose

    def forward(self, x):
        """Return ``x`` multiplied by the decay factor for ``self.delay_ms``."""
        seconds = self.delay_ms / 1000.0  # convert to seconds
        decay = torch.exp(-self.decay_rate * seconds)
        if self.verbose:
            print(f"Applying decay factor: {decay.item():.3f} for delay {self.delay_ms}ms")
        return x * decay

def plot_multiple_decay_curves(decay_rates, max_delay_ms=1200, num_points=100, save_path=None):
    """
    Plot exp(-rate * t) for several decay rates on one figure and print sample values.

    ``t`` is the delay converted from milliseconds to seconds, i.e. the same
    formula ``ExponentialDecayLayer`` applies.

    Args:
        decay_rates (iterable of float): Decay rates (per second) to plot.
        max_delay_ms (int): Right end of the plotted delay range, in milliseconds.
        num_points (int): Number of samples per curve.
        save_path (str or Path, optional): If given, the figure is also saved
            there at 300 dpi.
    """
    # Materialise once: the rates are iterated twice (plotting, then printout),
    # which would silently skip the printout for a generator.
    decay_rates = list(decay_rates)

    # Define a color palette
    colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']

    # Reference delays used for x-ticks and for the printed table.
    key_times = [0, 100, 200, 300, 400, 800, 1200]
    # set_xticks expands the view limits to include every tick, so ticks past
    # max_delay_ms would stretch the axis beyond the plotted data.
    visible_ticks = [t for t in key_times if t <= max_delay_ms]

    times = np.linspace(0, max_delay_ms, num_points)

    # Square figure
    fig, ax = plt.subplots(figsize=(8, 8))

    for i, decay_rate in enumerate(decay_rates):
        decay_values = np.exp(-decay_rate * (times / 1000.0))
        color = colors[i % len(colors)]  # Cycle through colors if more decay rates than colors
        # `:g` keeps distinct rates distinct (e.g. 0.05 vs 0.1); `.1f` merged them.
        ax.plot(times, decay_values, color=color, linestyle='-',
                label=f'Decay rate = {decay_rate:g}', linewidth=2)

    ax.set_xlabel('Delay (ms)', fontsize=12)
    ax.set_ylabel('Decay Factor', fontsize=12)
    ax.set_title('Comparison of Different Exponential Decay Rates', fontsize=14)
    ax.grid(False)  # Turn off grid
    if decay_rates:  # legend() warns when there are no labelled artists
        ax.legend(fontsize=10)

    ax.set_xticks(visible_ticks)
    ax.set_xticklabels([str(t) for t in visible_ticks], fontsize=10)

    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nPlot saved to: {save_path}")

    plt.show()

    # Print decay values at key timepoints for each rate
    print("\nDecay values at key timepoints:")
    for decay_rate in decay_rates:
        print(f"\nDecay rate = {decay_rate:g}")
        for t in key_times:
            decay = np.exp(-decay_rate * (t / 1000.0))
            print(f"At {t}ms: {decay:.3f}")

# Compare several decay rates on a single figure and save it to the working directory.
decay_rates = [0.1, 1.0, 1.5, 2.5, 10.0]
plot_multiple_decay_curves(
    decay_rates,
    save_path="combined_decay_rates.png",
)

In [None]:
def extract_number(filename):
    """Return the integer N from an ``im<N>.png`` occurrence in *filename*, else 0.

    The pattern is searched anywhere in the string (not anchored), so full
    paths work as input.
    """
    found = re.search(r'im(\d+)\.png', filename)
    return int(found.group(1)) if found is not None else 0

def sort_filenames(filenames):
    """Return a new list of *filenames* ordered by ``extract_number`` (ascending).

    The sort is stable, so names with equal keys keep their input order.
    """
    ordered = list(filenames)
    ordered.sort(key=extract_number)
    return ordered

class HVM200SequenceDataset(Dataset):
    """HVM200 images paired with integer object labels.

    Images are the ``*.png`` files in ``image_path``, ordered by the number in
    their ``im<N>.png`` name. Labels come from the ``object`` column of
    ``working_memory_images_labels.csv`` in the same directory. CSV rows are
    matched to images *by position*. NOTE(review): this assumes the CSV lists
    the images in the same numeric order; confirm against how it was produced.

    Args:
        image_path: Directory containing the PNGs and the labels CSV.
        transform: Callable applied to each RGB PIL image. Defaults to a
            224x224 resize followed by ``ToTensor``.
        num_timesteps: Stored as ``self.num_timesteps``; not used by this class.
            Defaults to 4, the value the notebook configures.

    Raises:
        ValueError: If the CSV row count differs from the number of PNGs found.
    """

    # HVM200 object name -> corresponding COCO1600 object name.
    _HVM200_TO_COCO1600 = {
        "bear": "bear",
        "elephant": "ELEPHANT_M",
        "person": "face0001",
        "car": "alfa155",
        "dog": "breed_pug",
        "apple": "Apple_Fruit_obj",
        "chair": "_001",
        "plane": "f16",
        "bird": "lo_poly_animal_CHICKDEE",
        "zebra": "zebra"
    }

    def __init__(self, image_path, transform=None, num_timesteps=4):
        self.image_path = image_path
        # Formerly read from a notebook global of the same name.
        self.num_timesteps = num_timesteps
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

        # Sort image files
        self.images = sort_filenames(glob(os.path.join(image_path, "*.png")))

        # Load metadata; rows are aligned with self.images by position.
        meta_data_path = os.path.join(image_path, "working_memory_images_labels.csv")
        self.meta_data = pd.read_csv(meta_data_path)
        if len(self.meta_data) != len(self.images):
            raise ValueError(
                f"{meta_data_path} has {len(self.meta_data)} rows but "
                f"{len(self.images)} PNG images were found in {image_path}"
            )
        self.meta_data["img_path"] = self.images

        # Create a mapping of unique objects to integer labels (first-appearance order)
        self.label_to_idx = {obj: i for i, obj in enumerate(self.meta_data["object"].unique())}
        # Row i describes self.images[i], so label indices are precomputed once
        # instead of filtering the whole DataFrame in every __getitem__.
        self.label_indices = [self.label_to_idx[obj] for obj in self.meta_data["object"]]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_index)`` for sample ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.images[idx]
        # The context manager closes the file; convert() returns a new,
        # fully loaded image that no longer needs it.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.label_indices[idx]

    def hvm200_to_coco1600(self, label):
        """Map an HVM200 object name to its COCO1600 name; unknown names pass through unchanged."""
        return self._HVM200_TO_COCO1600.get(label, label)

# Usage example: build the dataset with a grayscale transform and inspect one sample.
image_path = "../hypothesis_wm_experiments/data/hvm200/"

num_timesteps = 4  # NOTE(review): originally annotated "100ms" — unclear whether per step or in total; confirm
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Convert to grayscale and repeat it over 3 channels so the ResNet input keeps 3 channels.
    transforms.Grayscale(3),
    # NOTE(review): mean 0 / std 1 makes this Normalize an identity op, while the ResNet50
    # used later is ImageNet-pretrained — confirm whether ImageNet statistics were intended.
    transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
])

dataset = HVM200SequenceDataset(image_path, transform=transform)
img, label = dataset[0]
# Expected (3, 224, 224) given Resize((224, 224)), ToTensor and Grayscale(3).
img.shape

In [None]:
class FeatureExtractor(nn.Module):
    """ImageNet-pretrained ResNet50 activations at one graph node, before and after decay.

    Args:
        decay_rate: Forwarded to ``ExponentialDecayLayer``.
        delay_ms: Forwarded to ``ExponentialDecayLayer``.
        return_node: torchvision FX node name whose output is extracted (list
            candidates with ``get_graph_node_names``). Defaults to
            ``"layer4.1.bn1"``.
    """

    def __init__(self, decay_rate=0.1, delay_ms=400, return_node="layer4.1.bn1"):
        super().__init__()

        self.return_node = return_node
        # Kept as a registered submodule so existing "base_model.*" state_dict keys stay valid.
        self.base_model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.feature_extractor = create_feature_extractor(
            self.base_model,
            return_nodes={return_node: "pre_decay"},
        )

        self.decay = ExponentialDecayLayer(decay_rate=decay_rate, delay_ms=delay_ms)

    def forward(self, x):
        """Return ``(pre_decay_features, post_decay_features)`` for the input batch ``x``."""
        pre_decay_features = self.feature_extractor(x)["pre_decay"]
        post_decay_features = self.decay(pre_decay_features)
        return pre_decay_features, post_decay_features


def extract_features(dataset, model, device, batch_size=32):
    """Run *model* over *dataset* in order and stack its two outputs and the labels.

    Args:
        dataset: Map-style dataset yielding ``(image, label)`` pairs.
        model: Module whose forward returns a ``(pre_decay, post_decay)`` pair.
            It is switched to eval mode and left in it.
        device: Device the image batches are moved to before the forward pass.
        batch_size: DataLoader batch size.

    Returns:
        ``(pre_decay, post_decay, labels)``, each concatenated along dim 0 on the
        CPU, in dataset order (the loader does not shuffle).
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    collected = {"pre": [], "post": [], "labels": []}

    model.eval()
    with torch.no_grad():
        for images, targets in loader:
            pre, post = model(images.to(device))
            collected["pre"].append(pre.cpu())
            collected["post"].append(post.cpu())
            collected["labels"].append(targets)

    return (
        torch.cat(collected["pre"], dim=0),
        torch.cat(collected["post"], dim=0),
        torch.cat(collected["labels"], dim=0),
    )

In [1]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset parameters
image_path = "../hypothesis_wm_experiments/data/hvm200/"
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Grayscale(3),
    # NOTE(review): mean 0 / std 1 makes this Normalize a no-op, while the
    # backbone uses ImageNet-pretrained weights — confirm this is intended.
    transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
])

# Create dataset
dataset = HVM200SequenceDataset(image_path, transform=transform)

# Delays to simulate (milliseconds) and the decay rate shared by all of them
delays = [0, 100, 400, 800, 1200]
decay_rate = 0.1

save_dir = f"extracted_features/{decay_rate}decay"
os.makedirs(save_dir, exist_ok=True)

# The backbone, its weights and the inputs do not depend on the delay, so the
# pre-decay features are identical for every delay. Load ResNet50 and run the
# dataset through it once, then only re-apply the (cheap) decay per delay.
model = FeatureExtractor(decay_rate=decay_rate, delay_ms=delays[0])
model = model.to(device)
model.eval()
pre_decay, _, labels = extract_features(dataset, model, device)

# Dictionary to store features (pre_decay / labels tensors are shared across
# delays — do not modify them in place).
all_features = {}

for delay in delays:
    print(f"\nExtracting features for {delay}ms delay...")

    # Swap in this delay's decay layer so `model.decay` reflects the delay being
    # processed, then apply it to the cached CPU features.
    model.decay = ExponentialDecayLayer(decay_rate=decay_rate, delay_ms=delay)
    with torch.no_grad():
        post_decay = model.decay(pre_decay)

    # Store features
    all_features[delay] = {
        'pre_decay': pre_decay,
        'post_decay': post_decay,
        'labels': labels
    }

    print("Extracted features shapes:")
    print(f"Pre-decay: {pre_decay.shape}")
    print(f"Post-decay: {post_decay.shape}")

    # Save features
    torch.save({
        'pre_decay': pre_decay,
        'post_decay': post_decay,
        'labels': labels,
        'decay_rate': decay_rate,
        'delay_ms': delay
    }, os.path.join(save_dir, f'features_delay{delay}ms.pt'))

    # Effective decay as a ratio of feature norms. An element-wise mean of
    # post / (pre + eps) would be dominated by near-zero and negative BN
    # activations (values near -eps blow up, exact zeros contribute 0).
    pre_norm = pre_decay.norm()
    decay_effect = (post_decay.norm() / pre_norm).item() if pre_norm > 0 else float("nan")
    print(f"Average decay factor: {decay_effect:.3f}")

NameError: name 'torch' is not defined