In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import itertools
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import h5py
import os
import yaml
import torchvision
import EIANN.plot as plot

In [None]:
# Spiral dataset configuration.
SEED = 0      # fixed seed so Restart & Run All regenerates the same noisy spiral
N = 2000      # number of points per spiral arm
K = 4         # number of spiral arms
sigma = 0.16  # scale factor applied to the torch.randn noise on each coordinate

# Seed torch's global RNG before any later cell draws noise from it.
torch.manual_seed(SEED)

t = torch.linspace(0, 1, N) # Generate linearly spaced values from 0 to 1, used as the parameter that varies along the length of the spiral
X = torch.zeros(K*N, 2)  # one (x, y) row per point, filled in arm by arm
y = torch.zeros(K*N)     # arm index of each point
idx = torch.arange(K*N)  # positions 0 .. K*N-1 (not referenced by the cells shown here)

In [None]:
# Each arm is the same curve rotated by a phase offset of arm_index; the rows of
# X and y belonging to one arm are contiguous (N rows per arm).
for arm_index in range(K):
    arm_rows = slice(arm_index*N, (arm_index+1)*N)
    angle = 2*np.pi/K*(2*t+arm_index)
    # Noise is added before multiplying by t, so its amplitude grows along the arm.
    X[arm_rows, 0] = t*(torch.sin(angle) + sigma*torch.randn(N))
    X[arm_rows, 1] = t*(torch.cos(angle) + sigma*torch.randn(N))
    y[arm_rows] = arm_index

In [None]:
# Plot every sample's arm label against its position in the dataset, coloured by label.
sample_positions = range(len(y))
plt.scatter(sample_positions, y, c=y, cmap='viridis', marker='o')

In [None]:
# Pair every point with its dataset index and a one-hot encoding of its arm label.
all_data = []
targets = []

# Build the K x K identity once; row k is the one-hot target for arm k.
one_hot_rows = torch.eye(K)

for index, (data, label) in enumerate(zip(X, y)):
    # clone() so each sample owns its target rather than sharing a row of one_hot_rows
    target = one_hot_rows[int(label)].clone()
    targets.append(target)

    all_data.append((index, data, target))

# Only a head/tail sample is printed; printing inside the loop would emit
# two lines per sample (2*K*N lines in total).
# Print the first few targets to verify
print("First few targets:")
for i in range(min(10, len(targets))):
    print(targets[i])
print("Last few targets:")
for i in range(max(0, len(targets)-10), len(targets)):
    print(targets[i])

In [None]:
# Total number of (index, point, target) tuples collected in all_data;
# the bare expression below displays the count as the cell output.
num_samples = len(all_data)
num_samples

In [None]:
# Split data into test / validation / train subsets (15% / 15% / remainder).
test_size = int(0.15 * num_samples)
val_size = int(0.15 * num_samples)
train_size = num_samples - (test_size + val_size)

test_end_idx = test_size
val_end_idx = test_size + val_size

# all_data is ordered arm by arm, so slicing it contiguously would fill the test
# set with a single arm and leave that arm out of the training set entirely.
# Shuffle the sample order first; a dedicated seeded generator keeps the split
# reproducible without touching the global RNG state.
split_generator = torch.Generator().manual_seed(0)
shuffled_order = torch.randperm(num_samples, generator=split_generator).tolist()

test_data = [all_data[i] for i in shuffled_order[:test_end_idx]]
val_data = [all_data[i] for i in shuffled_order[test_end_idx:val_end_idx]]
train_data = [all_data[i] for i in shuffled_order[val_end_idx:]]

In [None]:
# Loader settings: one sample per training batch, and a dedicated RNG that the
# training DataLoader uses to draw its shuffle order.
batch_size = 1
data_generator = torch.Generator()

In [None]:
# Training samples are served in shuffled order (drawn from data_generator) in
# batches of batch_size; validation and test sets are each served unshuffled
# as a single batch containing the whole subset.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    generator=data_generator,
)
val_loader = DataLoader(
    val_data,
    batch_size=len(val_data),
    shuffle=False,
    num_workers=0,
)
test_loader = DataLoader(
    test_data,
    batch_size=len(test_data),
    shuffle=False,
    num_workers=0,
)

In [None]:
def generate_spiral_data(N = 2000, K = 4, sigma = 0.16, generator = None):
    """Generate a noisy K-arm spiral classification dataset.

    Each arm is the same curve rotated by a phase offset of ``arm_index``.
    Noise is added to the sine/cosine terms before they are multiplied by the
    position ``t`` along the arm, so its amplitude grows from the centre outwards.

    Args:
        N: Number of points per arm.
        K: Number of arms (classes).
        sigma: Scale factor applied to the standard-normal noise.
        generator: Optional ``torch.Generator`` used for the noise draws. When
            ``None`` (the default) torch's global RNG is used, as before.

    Returns:
        A list of ``K*N`` tuples ``(index, data, target)``: ``index`` is the
        sample's position in the list, ``data`` is a tensor holding its two
        coordinates, and ``target`` is the length-``K`` one-hot encoding of its
        arm. Samples are ordered arm by arm.
    """
    t = torch.linspace(0, 1, N) # position along the arm, from 0 (centre) to 1 (tip)
    X = torch.zeros(K*N, 2)
    y = torch.zeros(K*N, dtype=torch.long)

    # arm_index is the offset or phase shift, allowing the function to generate points for each arm in the spiral
    for arm_index in range(K):
        arm_rows = slice(arm_index*N, (arm_index+1)*N)
        angle = 2*np.pi/K*(2*t+arm_index)  # shared by both coordinates, so compute it once
        X[arm_rows, 0] = t*(torch.sin(angle) + sigma*torch.randn(N, generator=generator))
        X[arm_rows, 1] = t*(torch.cos(angle) + sigma*torch.randn(N, generator=generator))
        y[arm_rows] = arm_index

    # Look up every one-hot target in a single indexing op instead of rebuilding
    # the identity matrix for each sample; every row of `targets` is distinct storage.
    targets = torch.eye(K)[y]

    all_data = [
        (index, data, target)
        for index, (data, target) in enumerate(zip(X, targets))
    ]

    return all_data

In [None]:
# Draw three independent spirals, in train -> val -> test order.
# N is the number of points per arm, so each set holds K*N samples in total.
train_data, val_data, test_data = (
    generate_spiral_data(N=points_per_arm) for points_per_arm in (1400, 300, 300)
)

In [None]:
# Loader settings: one sample per training batch, plus a dedicated RNG from
# which the training loader draws its shuffle order.
batch_size = 1
data_generator = torch.Generator()

# Training data is shuffled; validation and test data are each delivered
# unshuffled as one batch spanning the whole set.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    generator=data_generator,
)
val_loader = DataLoader(
    val_data,
    batch_size=len(val_data),
    shuffle=False,
    num_workers=0,
)
test_loader = DataLoader(
    test_data,
    batch_size=len(test_data),
    shuffle=False,
    num_workers=0,
)