In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

In [3]:
def remove_end_line(x, y, tol=1e-6):
    """
    Cut off the straight run of points that terminates a 2-D trajectory.

    The last two points define a reference line. Walking backwards from the
    second-to-last index down to index 1, the first point whose 2-D cross
    product against that line exceeds `tol` in magnitude becomes the new final
    point; everything after it is discarded.

    Parameters:
    - x: array-like, x coordinates
    - y: array-like, y coordinates
    - tol: float, absolute threshold on the (unnormalized) cross product

    Returns:
    - (x, y) as numpy arrays, truncated when an off-line point was found.
      Inputs with fewer than 3 points, or with no off-line point among
      indices 1..len-2, are returned whole. Index 0 is never tested.
    """
    xs, ys = np.asarray(x), np.asarray(y)
    n_points = len(xs)

    if n_points < 3:
        return xs, ys

    anchor = np.array([xs[-2], ys[-2]])
    direction = np.array([xs[-1], ys[-1]]) - anchor

    def leaves_line(k):
        # z-component of direction x (p_k - anchor); zero when p_k is on the line
        offset = np.array([xs[k], ys[k]]) - anchor
        return abs(direction[0] * offset[1] - direction[1] * offset[0]) > tol

    cut = next((k for k in range(n_points - 2, 0, -1) if leaves_line(k)), None)
    if cut is None:
        return xs, ys
    return xs[:cut + 1], ys[:cut + 1]

def resample_trajectory(r, theta, steps_per_rotation=100):
    """
    Resample a polar trajectory onto an evenly spaced theta grid using linear interpolation.

    The grid runs from 0 to the last value of `theta`. Its size is
    `steps_per_rotation` times the number of whole 2*pi rotations contained in
    that last value. When the trajectory covers less than one whole rotation,
    the size is instead proportional to the fraction of a rotation covered
    (never fewer than 2 samples), so short trajectories no longer come back empty.

    Parameters:
    - r: array-like, original radius values
    - theta: array-like, original theta values in radians; the last entry is
      taken as the end of the trajectory. np.interp expects these to be
      increasing and does not check it.
    - steps_per_rotation: int, samples generated per whole rotation (default 100)

    Returns:
    - new_r: ndarray, radius values interpolated at theta_linespace
    - theta_linespace: ndarray, the evenly spaced theta grid

    Raises:
    - IndexError: if theta is empty
    """
    # Convert before indexing: theta[-1] on a pandas Series is a label lookup, not positional.
    r = np.asarray(r)
    theta = np.asarray(theta)

    full_turn = 2 * np.pi
    max_theta = theta[-1]

    whole_rotations = max_theta // full_turn
    if whole_rotations >= 1:
        num_samples = int(steps_per_rotation * whole_rotations)
    else:
        # Less than one whole rotation: the whole-rotation count would be 0 and
        # produce an empty grid, so scale by the covered fraction instead.
        num_samples = max(2, int(np.ceil(steps_per_rotation * (max_theta / full_turn))))

    theta_linespace = np.linspace(0, max_theta, num_samples)
    new_r = np.interp(theta_linespace, theta, r)
    return new_r, theta_linespace

def plot_segmented_spiral(
    gen_theta,
    gen_r,
    num_segments=10,
    num_points=500,
    colors=None,
    figsize=(12, 6),
    plot_title=None,
    axes=None,
    original_data=None,
):
    """
    Plot a spiral given in polar coordinates, split into consecutively colored segments.

    Two views are drawn: radius against unwrapped theta on a Cartesian axes,
    and the spiral itself on a polar axes.

    Parameters:
    - gen_theta (array-like): Angle samples of the spiral, in radians.
    - gen_r (array-like): Radius samples matching gen_theta.
    - num_segments (int): The number of colored segments to divide the plot into. Must be >= 1.
    - num_points (int): How many leading samples are distributed over the segments.
                        Default is 500. Values larger than len(gen_theta) are clamped
                        to the data length.
    - colors (list or None): A list of colors to use for the segments; cycled when shorter
                             than num_segments. If None or empty, the 'tab20' colormap
                             is cycled.
    - figsize (tuple): A tuple (width, height) for the figure size. Only used when
                       axes is None. Default is (12, 6).
    - plot_title (str or None): Title for the entire figure. If None (or empty),
                                'Spiral Generation with {num_segments} Segments' is set.
    - axes (tuple or None): A tuple (cartesian_ax, polar_ax) of Matplotlib axes to plot on.
                            If None, a new figure with both axes is created and shown.
    - original_data (tuple or None): Optional (theta, r) reference curve, drawn as a dashed
                                     gray line in both views.

    Returns:
    - (ax1, ax2): the Cartesian axes and the polar axes that were drawn on.

    Raises:
    - ValueError: if num_segments is less than 1.
    """
    if num_segments < 1:
        raise ValueError("num_segments must be at least 1")

    gen_theta = np.asarray(gen_theta)
    gen_r = np.asarray(gen_r)
    # Unwrap once over the whole curve: unwrapping each segment on its own would
    # restart every segment at its wrapped angle and break continuity.
    gen_theta_unwrapped = np.unwrap(gen_theta)

    # Never spread more points than the data holds, otherwise trailing segments are empty.
    num_points = min(num_points, len(gen_theta))
    points_per_segment = max(1, num_points // num_segments)
    # Each slice takes one extra sample so neighbouring segments join without a gap.
    segment_slices = [
        slice(i * points_per_segment, (i + 1) * points_per_segment + 1)
        for i in range(num_segments)
    ]

    # Use default colors if none are provided
    if colors is None or len(colors) == 0:
        cmap = plt.colormaps['tab20']
        # Wrap the index: a listed colormap clamps out-of-range integers to its last color.
        colors = [cmap(i % cmap.N) for i in range(num_segments)]
    colors = [colors[i % len(colors)] for i in range(num_segments)]

    # Create the plots
    if axes is None:
        # Build each axes exactly once. Adding a polar subplot on top of a grid made by
        # plt.subplots would leave the Cartesian axes of that cell visible underneath.
        fig = plt.figure(figsize=figsize)
        grid = fig.add_gridspec(1, 2, wspace=0.2)
        ax1 = fig.add_subplot(grid[0, 0])
        ax2 = fig.add_subplot(grid[0, 1], projection='polar')
    else:
        ax1, ax2 = axes
        fig = ax1.figure

    if original_data is not None:
        orig_theta, orig_r = original_data
        orig_theta_unwrapped = np.unwrap(orig_theta)

    # Cartesian subplot (Radius vs Theta)
    if original_data is not None:
        ax1.plot(orig_theta_unwrapped, orig_r, color='gray', alpha=0.4, linestyle='--', label='Original Data')
    for segment, color in zip(segment_slices, colors):
        ax1.plot(gen_theta_unwrapped[segment], gen_r[segment], color=color)

    ax1.set_title("Radius vs Theta")
    ax1.grid(True)

    # One tick per multiple of pi, covering the longer of the two curves.
    theta_end = int(gen_theta_unwrapped[-1] / np.pi)
    if original_data is not None:
        theta_end = max(theta_end, int(orig_theta_unwrapped[-1] / np.pi))
    theta_end = max(theta_end, 0)  # a negative count would make np.linspace raise
    tick_positions = np.linspace(0, theta_end * np.pi, theta_end + 1)
    tick_labels = [f'{i}\u03c0' for i in range(theta_end + 1)]

    ax1.set_xticks(tick_positions, tick_labels)
    ax1.set_xlabel("Theta (radians)")
    ax1.set_ylabel("Radius")
    if original_data is not None:
        ax1.legend()

    # Polar subplot (Spiral)
    if original_data is not None:
        ax2.plot(orig_theta, orig_r, color='gray', alpha=0.4, linestyle='--', label='Original Data')
    for segment, color in zip(segment_slices, colors):
        ax2.plot(gen_theta[segment], gen_r[segment], color=color)

    ax2.set_title("Generated spiral")
    ax2.grid(False)
    ax2.axis('off')

    # Title the figure that owns the axes, which is not necessarily pyplot's current figure.
    if plot_title:
        fig.suptitle(plot_title)
    else:
        fig.suptitle(f"Spiral Generation with {num_segments} Segments")

    if axes is None:
        plt.show()

    return ax1, ax2
    
def plot_losses(train_losses, val_losses=None, figsize=(8, 6)):
    """
    Draw the per-epoch loss curves in a new figure and display it.

    Parameters:
    - train_losses: sequence of training loss values, one per epoch
    - val_losses: optional sequence of validation loss values; skipped when None
    - figsize: (width, height) of the figure
    """
    _, ax = plt.subplots(figsize=figsize)

    curves = [(train_losses, 'Training Loss')]
    if val_losses is not None:
        curves.append((val_losses, 'Validation Loss'))
    for values, label in curves:
        ax.plot(values, label=label)

    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss over Epochs')
    ax.legend()
    ax.grid(True)
    plt.show()

In [4]:
class SpiralDataset(Dataset):
    """
    Dataset of trajectories stored as one CSV file each.

    Every `*.csv` file directly inside `folder` is one sample. A file must
    contain the columns 'r' and 'theta'; other columns are ignored. An item is
    a float32 tensor with one row per CSV row and the two columns ordered
    (r, theta).

    Parameters:
    - folder: directory containing the CSV files
    - preload: when True, all files are read once in the constructor and kept
      in memory; when False, a file is read from disk on every item access
    """

    def __init__(self, folder, preload=True):
        # Sort the paths: glob yields them in filesystem order, which would make
        # the index -> file mapping (and any seeded split) differ between machines.
        self.files = sorted(Path(folder).glob("*.csv"))
        self.preload = preload
        self.data = [self._read_array(f) for f in self.files] if preload else None

    @staticmethod
    def _read_array(path):
        """Read one CSV file and return its (r, theta) columns as a float32 array."""
        return pd.read_csv(path)[['r', 'theta']].values.astype('float32')

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        arr = self.data[idx] if self.preload else self._read_array(self.files[idx])
        # torch.tensor copies, so callers cannot modify the preloaded arrays in place.
        return torch.tensor(arr)

In [None]:
class SpiralAutoencoderCNN(nn.Module):
    """
    Fully connected autoencoder that encodes a whole trajectory into one latent vector.

    Despite the "CNN" suffix, the network contains only linear layers: the
    input of shape (B, T, D) is flattened to (B, T*D), passed through
    Linear -> ReLU -> Linear to the latent vector, and decoded by the mirrored
    stack back to (B, T, D).

    Parameters:
    - input_len: size of the flattened sample, i.e. T * D for inputs of shape
      (B, T, D). Passing T alone makes the first linear layer reject the input.
    - latent_dim: size of the latent vector (default 16)
    - hidden_dim: width of the hidden layer in encoder and decoder (default 128)
    """

    def __init__(self, input_len, latent_dim=16, hidden_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_len, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_len)
        )

    def forward(self, x):
        """
        Encode and reconstruct a batch.

        Parameters:
        - x: tensor of shape (B, T, D) with T * D equal to input_len

        Returns:
        - out: reconstruction, shape (B, T, D)
        - latent: latent vectors, shape (B, latent_dim)
        """
        batch_size, seq_len, feat_dim = x.size()
        # reshape instead of view: view raises on non-contiguous inputs such as
        # transposed or sliced tensors, reshape copies only when it has to.
        x_flat = x.reshape(batch_size, seq_len * feat_dim)  # flatten spiral
        latent = self.encoder(x_flat)                       # single vector per spiral
        out_flat = self.decoder(latent)
        out = out_flat.reshape(batch_size, seq_len, feat_dim)
        return out, latent