In [None]:
from utils import *

In [None]:
import warnings

# Install a global "ignore" filter so no Python warning is shown by later cells.
# NOTE(review): this also hides deprecation notices from the plotting libraries;
# consider restricting the filter to a specific category or module.
warnings.filterwarnings("ignore")

In [None]:
# Build the recordings table for one recording folder. Later cells treat `df` as a
# DataFrame and read its "full_path", "numDetections" and "timestamp" columns.
# NOTE(review): absolute, machine-specific path - the notebook only runs where this
# drive is mounted.
df = generate_recordings_csv(
    "/media/daniel/Data3/Gotcha_data/recordings/first_drone_exp/recording_f2"
)

In [None]:
# Matplotlib plots


def plot_histograms(df):
    """
    Plots the distribution of each per-detection measurement in a single figure.

    Six histograms with a KDE overlay are drawn on a 3x3 subplot grid, one per
    column: "range", "snr", "rcs", "angle", "elevation" and "dopInd". Missing
    values are dropped column by column before plotting, and the finished
    figure is shown.

    Parameters:
    - df: DataFrame containing the six columns listed above

    Raises:
    - KeyError: if one of the expected columns is missing from df
    """
    # (column, bar colour, panel title) in drawing order. Keeping the panels in
    # one table avoids six copy-pasted stanzas drifting apart.
    panels = [
        ("range", "skyblue", "Distribution of range"),
        ("snr", "salmon", "Distribution of snr"),
        ("rcs", "limegreen", "Distribution of rcs"),
        ("angle", "orchid", "Distribution of angle_d"),
        ("elevation", "gold", "Distribution of elevation_d"),
        ("dopInd", "green", "Distribution of dop_ind"),
    ]

    plt.figure(figsize=(15, 10))

    for position, (column, color, title) in enumerate(panels, start=1):
        plt.subplot(3, 3, position)
        sns.histplot(df[column].dropna(), bins=30, kde=True, color=color)
        plt.title(title)

    # Finalise the figure explicitly so the panel titles do not overlap and the
    # figure is rendered by this function rather than by whatever runs next.
    plt.tight_layout()
    plt.show()


def plot_scatter_plots(df):
    """
    Plots pairwise scatter plots of selected detection measurements.

    Four panels are drawn on a 2x2 grid and the figure is shown:
    range vs snr, range vs rcs, angle vs snr and elevation vs range.

    Parameters:
    - df: DataFrame containing the "range", "snr", "rcs", "angle" and
      "elevation" columns
    """
    # (x column, y column, panel title) in drawing order.
    panels = [
        ("range", "snr", "Range vs snr"),
        ("range", "rcs", "Range vs rcs"),
        ("angle", "snr", "Angle vs snr"),
        ("elevation", "range", "Elevation vs range"),
    ]

    plt.figure(figsize=(15, 10))

    for position, (x_column, y_column, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        # No `palette` argument: seaborn applies a palette only to a `hue`
        # grouping, and these plots have none, so it was silently ignored.
        sns.scatterplot(data=df, x=x_column, y=y_column)
        plt.title(title)

    plt.tight_layout()
    plt.show()


def plot_correlations(df):
    """
    Shows an annotated heatmap of the pairwise correlations between columns.

    Only columns whose dtype is float64 or int64 take part; every other
    column is left out before the correlation matrix is computed.

    Parameters:
    - df: DataFrame whose numeric columns should be compared
    """
    plt.figure(figsize=(12, 10))

    numeric_only = df.select_dtypes(include=["float64", "int64"])
    heatmap_style = {"annot": True, "cmap": "coolwarm", "fmt": ".2f"}
    sns.heatmap(numeric_only.corr(), **heatmap_style)

    plt.title("Correlation Heatmap")
    plt.show()


def statistics(df, limit=10):
    """
    Plots summary statistics for a random sample of recordings.

    Up to `limit` rows of df are sampled; each row's detection map is loaded
    from its "full_path" and turned into records by detections_map_to_dicts.
    A "numDetections" vs "timestamp" histogram is shown for the whole of df
    (not only the sample), followed by the histogram, correlation and scatter
    plots of the collected detections.

    Parameters:
    - df: recordings table with "full_path", "numDetections" and "timestamp"
      columns
    - limit: maximum number of recordings to sample
    """
    # DataFrame.sample raises when asked for more rows than exist, so clamp the
    # request to the table size instead of failing on short recordings.
    sample_size = min(limit, len(df))
    sampled = df.sample(n=sample_size)

    detections = []
    for _, row in tqdm(sampled.iterrows(), total=sample_size):
        try:
            det_map = np.load(row["full_path"])
        except Exception as exc:
            # Best effort: report the unreadable file and carry on with the rest.
            print(f"Can`t open {row['full_path']} ({exc})")
            continue

        detections.extend(detections_map_to_dicts(det_map))

    det_df = pd.DataFrame(detections)

    px.histogram(df, y="numDetections", x="timestamp").show()

    if det_df.empty:
        # The plot helpers index fixed column names and would fail with a
        # KeyError on an empty frame.
        print("No detections could be loaded; skipping the detection plots")
        return

    plot_histograms(det_df)
    plot_correlations(det_df)
    plot_scatter_plots(det_df)


# Dry statistics: run the summary plots on 10 randomly sampled recordings of `df`.
statistics(df, limit=10)

In [None]:
def rd_crops(ranges, angles, elevations, n_time=10, start_time=0, recordings=None):
    """
    Collects one cropped, angle/elevation-summed radar slice per recording.

    For every row in the selected window the detection map is loaded from
    "full_path" and converted with get_radar3. The result is indexed as
    [ranges, angles, elevations, :], summed over axes 1 and 2 and transposed.
    Files that cannot be loaded are reported and skipped.

    Parameters:
    - ranges: index or slice applied to axis 0 of the get_radar3 output
    - angles: index or slice applied to axis 1
    - elevations: index or slice applied to axis 2
    - n_time: number of consecutive rows to process
    - start_time: positional index of the first row
    - recordings: recordings table with a "full_path" column; when None the
      notebook-level `df` is used, which keeps existing calls working

    Returns:
    - NumPy array stacking the per-recording crops along axis 0 (empty when
      nothing could be loaded)
    """
    # Prefer an explicit table over the notebook-global `df`, which several
    # cells rebind and which would otherwise make the result order-dependent.
    table = df if recordings is None else recordings

    crops = []
    for _, row in table.iloc[start_time : start_time + n_time].iterrows():
        try:
            det_map = np.load(row["full_path"])
        except Exception as exc:
            print(f"Can`t open {row['full_path']} ({exc})")
            continue

        radar3 = get_radar3(det_map)
        # Collapse the angle and elevation axes; the remaining two axes are
        # transposed so the last axis of radar3 comes first in each crop.
        crops.append(radar3[ranges, angles, elevations, :].sum(axis=(1, 2)).T)

    return np.array(crops)

In [None]:
# Wide crop: indices 0-599 on the range axis and 0-119 on both the angle and
# elevation axes, for 100 consecutive recordings starting at row 1000.
s1 = rd_crops(
    ranges=slice(0, 600),
    angles=slice(0, 120),
    elevations=slice(0, 120),
    n_time=100,
    start_time=1000,
)

In [None]:
# Animate the crops, one frame per entry along axis 0 of `s1`. fftshift along axis 1
# moves the zero-frequency bin to the centre (presumably the Doppler axis - TODO confirm).
px.imshow(np.fft.fftshift(s1, axes=1), animation_frame=0, title="RD map").show()

In [None]:
# Narrower crop: range indices 0-149 and a 10-bin window (55-64) on both the angle and
# elevation axes, for 50 recordings starting at row 100.
s2 = rd_crops(
    ranges=slice(0, 150), angles=slice(55, 65), elevations=slice(55, 65), n_time=50, start_time=100
)

In [None]:
# Animate `s2` frame by frame along axis 0. The commented-out variant below pins the
# colour scale to the global min/max of `s2` so frames are directly comparable.
# px.imshow(s2, animation_frame=0,title='RD map',range_color=[s2.min(),s2.max()]).show()
px.imshow(s2, animation_frame=0, title="RD map").show()

In [None]:
# Single-bin crop: range indices 0-149 at angle index 60 and elevation index 60 only,
# for 50 recordings starting at row 100.
s3 = rd_crops(
    ranges=slice(0, 150), angles=slice(60, 61), elevations=slice(60, 61), n_time=50, start_time=100
)

In [None]:
# Animate the single-bin crop `s3`, one frame per entry along axis 0.
px.imshow(s3, animation_frame=0, title="RD map").show()

In [None]:
def timeDoppler(df, n_time=10, start_time=0):
    """
    Builds a 2-D map with one column per recording from the strongest detection.

    For each selected row the detection map is loaded from "full_path"; the
    detection whose column 4 is largest is picked and its columns 6 onward
    become that recording's column in the output (presumably the Doppler
    profile - TODO confirm the column layout). A map from which no detection
    can be picked contributes a column of zeros. Files that cannot be loaded
    are reported and skipped.

    Parameters:
    - df: recordings table with a "full_path" column
    - n_time: number of consecutive rows to process; None or a negative value
      means every row from start_time to the end of the table
    - start_time: positional index of the first row

    Returns:
    - 2-D NumPy array of shape (values per detection, recordings loaded)
    """
    # Callers pass n_time=-1 to mean "all rows". Feeding that straight into the
    # slice gave iloc[start : start - 1], which drops the last row (or selects
    # nothing when start_time > 0), so map it to an open-ended slice instead.
    stop = None if n_time is None or n_time < 0 else start_time + n_time

    columns = []
    for _, row in df.iloc[start_time:stop].iterrows():
        try:
            det_map = np.load(row["full_path"])
        except Exception as exc:
            print(f"Can`t open {row['full_path']} ({exc})")
            continue

        try:
            strongest = np.argmax(det_map[:, 4])
            dop_row = det_map[strongest, 6:]
        except (ValueError, IndexError):
            # argmax raises ValueError on a map with no detections; keep the
            # time axis aligned by inserting an all-zero column.
            dop_row = np.zeros(det_map.shape[1] - 6)

        columns.append(dop_row)

    return np.array(columns).T

In [None]:
# path = '/home/daniel/Documents/Gotcha/recordings/first_drone_exp/'
path = "/media/daniel/Data3/Gotcha_data/recordings/drone_exp_2_28_11_24/"
# For every entry under `path`: build its recordings table, compute the time-Doppler
# map and print the continuous regions found in it.
# NOTE(review): detect_continuous_velocity_regions is defined in a later cell of this
# notebook, so on a fresh kernel that cell must be executed before this one (unless
# `utils` also exports it - confirm).
# NOTE(review): the loop rebinds the notebook-level `df`, `td` and `regions`; after it
# finishes they describe the last entry that was processed.
for f in os.listdir(path):
    n_time = -1
    start_time = 0
    df = generate_recordings_csv(os.path.join(path, f))
    td = timeDoppler(df, n_time=n_time, start_time=start_time)
    # td = np.fft.fftshift(td,axes=0)

    regions = detect_continuous_velocity_regions(
        td, blur_kernel=(5, 5), canny_threshold1=50, canny_threshold2=150, min_width=10
    )
    print("Continuous Time Intervals:", regions)
    # px.imshow(td,title='TD map').show()

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np


def detect_continuous_velocity_regions(
    heatmap, blur_kernel=(5, 5), canny_threshold1=50, canny_threshold2=150, min_width=10
):
    """
    Detects continuous velocity regions in a 2D heatmap using image processing techniques.

    The heatmap is rescaled to 0-255, blurred, passed through Canny edge
    detection and the external contours of the edge image are extracted.
    Every contour whose bounding box is wider than min_width contributes one
    interval. The original heatmap and a copy masked to the detected
    intervals are also plotted and shown.

    Parameters:
    - heatmap: 2D NumPy array representing the velocity-time heatmap
    - blur_kernel: Tuple representing the Gaussian blur kernel size
    - canny_threshold1: Lower threshold for the Canny edge detection
    - canny_threshold2: Upper threshold for the Canny edge detection
    - min_width: Minimum width of continuous regions to consider

    Returns:
    - continuous_intervals: List of (start, end) column-index tuples, one per
      retained contour, where end is start plus the bounding-box width
    """

    # Step 1: Normalize the heatmap to the range 0-255 for better processing
    normalized_heatmap = cv2.normalize(heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Step 2: Apply a Gaussian blur to smooth the heatmap
    blurred_heatmap = cv2.GaussianBlur(normalized_heatmap, blur_kernel, 0)

    # Step 3: Use the Canny edge detection algorithm to find edges
    edges = cv2.Canny(blurred_heatmap, threshold1=canny_threshold1, threshold2=canny_threshold2)

    # Step 4: Find contours from the edges
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Step 5: Keep only contours whose bounding box is wider than min_width
    # along the horizontal (time) axis
    continuous_intervals = []
    for contour in contours:
        x, _y, w, _h = cv2.boundingRect(contour)
        if w > min_width:
            continuous_intervals.append((x, x + w))

    # Step 6: Plot the original and processed heatmaps
    _plot_velocity_regions(heatmap, continuous_intervals)

    return continuous_intervals


def _plot_velocity_regions(heatmap, continuous_intervals):
    """Shows the heatmap above a copy that is zeroed outside the given column intervals."""
    plt.figure(figsize=(12, 6))

    # Original heatmap
    plt.subplot(2, 1, 1)
    plt.imshow(heatmap, aspect="auto", cmap="viridis", origin="lower")
    plt.colorbar(label="Velocity")
    plt.xlabel("Time")
    plt.ylabel("Velocity")
    plt.title("Original Heatmap")

    # The mask must match the heatmap passed in. It was previously sized from
    # the notebook-level `td`, which fails (or mis-sizes the mask) whenever
    # the argument is not that exact global.
    mask = np.zeros(heatmap.shape)
    for x1, x2 in continuous_intervals:
        mask[:, x1:x2] = 1

    # Heatmap with continuous regions highlighted
    plt.subplot(2, 1, 2)
    plt.imshow(heatmap * mask, aspect="auto", cmap="viridis", origin="lower")
    plt.colorbar(label="Velocity")
    plt.xlabel("Time")
    plt.ylabel("Velocity")
    plt.title("Heatmap with Continuous Regions Highlighted")

    plt.tight_layout()
    plt.show()


# Example usage with sample heatmap data (replace with your actual heatmap)
# NOTE(review): `td` and `regions` are not created in this cell; both must already
# exist from another cell. `heatmap` is not read again by the visible cells.
heatmap = td  # Example data
print("Continuous Time Intervals:", regions)

In [None]:
# Time-Doppler analysis of the filtered first-drone recording.
# NOTE(review): absolute, machine-specific path; this also rebinds the notebook-level
# `df`, `td` and `regions`.
n_time = -1
start_time = 0
df = generate_recordings_csv("/home/daniel/Documents/Gotcha/first_drone_exp_filtered")
td = timeDoppler(df, n_time=n_time, start_time=start_time)
regions = detect_continuous_velocity_regions(
    td, blur_kernel=(5, 5), canny_threshold1=50, canny_threshold2=150, min_width=10
)

# Show only the first 1000 columns of the map.
px.imshow(td[:, :1000], title="TD map").show()

In [None]:
# plot map
# Load the detection map of the first recording in `df` and show columns 6 onward
# (transposed). The shape print and the plot run only when the load succeeded;
# previously a failed load fell through and raised NameError (or silently plotted a
# stale `det_map` left over from an earlier cell).
try:
    det_map = np.load(df.iloc[0]["full_path"])
except Exception as exc:
    print(f"Can`t open file ({exc})")
else:
    print(det_map.shape)
    px.imshow(det_map[:, 6:].T, title="Det map").show()

In [None]:
# Sort by range: reorder the rows of `det_map` by ascending value of column 0
# (assumes column 0 holds the range - TODO confirm), then plot columns 6 onward again.
# NOTE(review): this rebinds `det_map`, so the unsorted map is no longer available.
det_map = det_map[det_map[:, 0].argsort()]

px.imshow(det_map[:, 6:].T, title="Det map").show()

In [None]:
from sklearn.manifold import TSNE


def tsne_plot(df, sss=None, n_time=10, start_time=0, prep=30, model=None):
    """
    Embeds per-detection vectors in two dimensions with t-SNE and plots them.

    Parameters:
    - df: recordings table with a "full_path" column; only read when sss is None
    - sss: optional precomputed 2-D array of vectors, one per row. When None,
      the vectors are columns 6 onward of every detection map loaded from
      df.iloc[start_time : start_time + n_time], concatenated row-wise
    - n_time: number of consecutive rows of df to load
    - start_time: positional index of the first row of df to load
    - prep: t-SNE perplexity
    - model: optional callable applied to each loaded detection map; its
      outputs are collected in loading order

    Returns:
    - reduced_vectors: array of shape (number of vectors, 2)
    - sss: the vectors that were embedded (the object passed in, or the
      concatenated array built from df)
    - Ldrone: list of model outputs; empty when model is None or sss was given

    Raises:
    - ValueError: if sss is None and no detection map could be loaded
    """
    Ldrone = []

    # Identity test on purpose: `sss == None` is evaluated element-wise for a
    # NumPy array and makes the `if` raise, so an array returned by a previous
    # call could not be passed back in.
    if sss is None:
        collected = []

        for _, row in df.iloc[start_time : start_time + n_time].iterrows():
            try:
                det_map = np.load(row["full_path"])
            except Exception as exc:
                print(f"Can`t open {row['full_path']} ({exc})")
                continue

            if model is not None:
                try:
                    Ldrone.append(model(det_map))
                except Exception as exc:
                    # Still skip the recording, as before, but do not report a
                    # model failure as an unreadable file.
                    print(f"Model failed on {row['full_path']} ({exc})")
                    continue

            collected.append(det_map[:, 6:])

        if not collected:
            raise ValueError("No detection maps could be loaded; nothing to embed")

        sss = np.concatenate(collected, axis=0)
        print(f"Len of sss: {len(sss)}")

    # Apply t-SNE to reduce dimensions to 2
    tsne = TSNE(n_components=2, random_state=42, perplexity=prep)
    reduced_vectors = tsne.fit_transform(np.array(sss))

    # Plot the reduced vectors
    plt.figure(figsize=(8, 6))
    plt.scatter(reduced_vectors[:, 0], reduced_vectors[:, 1], s=30, alpha=0.7)
    plt.title("t-SNE Visualization ")
    plt.xlabel("t-SNE Dimension 1")
    plt.ylabel("t-SNE Dimension 2")
    plt.grid(True)
    plt.show()

    return reduced_vectors, sss, Ldrone

In [None]:
# df = generate_recordings_csv('/home/daniel/Documents/Gotcha/first_drone_exp_filtered')
# Reload the recordings table (rebinding the notebook-level `df`) and embed the
# detections of 100 recordings starting at row 200 with a perplexity of 128.
# NOTE(review): absolute, machine-specific path.
df = generate_recordings_csv(
    "/home/daniel/Documents/Gotcha/recordings/first_drone_exp/recording_f2"
)
sss = None
rv, sss, ldrone = tsne_plot(df, sss, 100, 200, prep=128)

In [None]:
# px.imshow(np.array(sss).T,title='Det map').show()

# Split the embedded vectors on their first t-SNE coordinate. The cut-off of 30 is a
# hand-picked value (NOTE(review): it depends on the embedding produced by the
# previous cell).
thr_rule = rv[:, 0] > 30
# thr_rule = ((rv[:,1])**2 + (rv[:,0])**2) < 900


# First plot: the vectors that satisfy the rule.
d1 = np.array(sss)
# d1[thr_rule] = np.zeros(d1.shape[1])
d1 = d1[thr_rule]
px.imshow(d1.T, title="Det map").show()
# Second plot: the remaining vectors (rule not satisfied).
d1 = np.array(sss)
# d1[~thr_rule] = np.zeros(d1.shape[1])
d1 = d1[~thr_rule]
px.imshow(d1.T, title="Det map").show()

In [None]:
# Display the boolean mask of embedded points whose first t-SNE coordinate exceeds 25.
# NOTE(review): the result is only displayed, never stored - looks like a leftover probe.
rv[:, 0] > 25

In [None]:
import torch

# Load the checkpoint into a fresh Simplest1DCnn instance.
# NOTE(review): "test2.ckpt" is a relative path, so it is resolved against the
# kernel's current working directory.
ckpt_path = "test2.ckpt"
ckpt = load_shitty_ckpt(ckpt_path)
model = Simplest1DCnn()
model.load_state_dict(ckpt)
# Presumably switches a torch module to inference mode - confirm against Simplest1DCnn.
model.eval()


# Re-run the embedding on the first 10 recordings of `df`, passing every loaded
# detection map to `model` and collecting its outputs in `ldrone`.
sss = None
rv, sss, ldrone = tsne_plot(df, sss, 10, 0, prep=128, model=model)

In [None]:
# Display the current recordings table (whatever `df` was last bound to above).
df