In [1]:
import sys

# Make the parent directory importable (presumably where the project's `core` package,
# imported below, lives). Guarded so re-running this cell does not keep stacking '..'
# entries at the front of sys.path.
if not sys.path or sys.path[0] != '..':
    sys.path.insert(0, '..')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft

from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader
from core.dataset_multimodal import collate_fn, ASASSNVarStarDataset
from functools import partial
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import random
import numpy as np
import json
import os

In [234]:
# Training split of the ASAS-SN variable-star dataset, restricted to periodic sources that
# also have spectra, using the V and G bands; light curves are not phase-folded.
datapath = Path('../data/asaasn')
ds = ASASSNVarStarDataset(
    datapath,
    mode='train',
    verbose=True,
    only_periodic=True,
    merge_type='inner',
    recalc_period=False,
    prime=True,
    use_bands=['v', 'g'],
    only_sources_with_spectra=True,
    return_phased=False,
    fill_value=0,
)

# Raw light curves (time vs. flux with error bars) for the first few dataset entries:
# one row per entry, one column per band.
N_PREVIEW = min(5, len(ds))
assert N_PREVIEW > 0, 'dataset is empty'
# squeeze=False keeps `axs` 2-D even when only one entry is plotted.
fig, axs = plt.subplots(nrows=N_PREVIEW, ncols=2, figsize=(10, 4 * N_PREVIEW), squeeze=False)

for i in range(N_PREVIEW):
    v, g = ds[i]['lcs'][0]
    for col, (band, arr) in enumerate((('V', v), ('G', g))):
        ax = axs[i, col]
        # Column layout follows the original unpacking: 0 = time, 1 = flux, 2 = flux error.
        ax.errorbar(arr[:, 0], arr[:, 1], yerr=arr[:, 2], fmt='.')
        ax.set_title(f'Entry {i} - {band} Band')
        ax.set_xlabel('time')
        ax.set_ylabel('flux')

fig.tight_layout()
plt.show()

In [235]:
# The same entries, flux only, plotted against array index rather than time.
N_PREVIEW = min(5, len(ds))
assert N_PREVIEW > 0, 'dataset is empty'
# squeeze=False keeps `axs` 2-D even when only one entry is plotted.
fig, axs = plt.subplots(nrows=N_PREVIEW, ncols=2, figsize=(10, 4 * N_PREVIEW), squeeze=False)

for i in range(N_PREVIEW):
    v, g = ds[i]['lcs'][0]
    for col, (band, arr) in enumerate((('V', v), ('G', g))):
        ax = axs[i, col]
        # Column 1 is flux; time and error columns are not needed for this view.
        ax.plot(arr[:, 1], '.')
        ax.set_title(f'Entry {i} - {band} Band')
        ax.set_xlabel('observation index')
        ax.set_ylabel('flux')

fig.tight_layout()
plt.show()

In [236]:
# Column 0 (time) of the first band array of entry 0 — the V band, going by the
# `v, g = ds[i]['lcs'][0]` unpacking used in the plotting cells.
ds[0]['lcs'][0][0][:, 0]

In [237]:
# Collate only light curves, metadata and class labels (no spectra);
# fill_value=0 is presumably the padding value for ragged light curves.
no_spectra_data_keys = ['lcs', 'metadata', 'classes']
no_spectra_collate_fn = partial(
    collate_fn,
    fill_value=0,
    data_keys=no_spectra_data_keys,
)

# Small, unshuffled loader for inspecting the first batch deterministically.
train_dataloader = DataLoader(
    ds,
    collate_fn=no_spectra_collate_fn,
    shuffle=False,
    batch_size=2,
)

In [238]:
# Pull one batch for inspection; the collate function yields a (batch, masks) pair.
batch, masks = next(iter(train_dataloader))
# Presumably ordered like `no_spectra_data_keys` (lcs, metadata, classes) — TODO confirm in collate_fn.
lcs, metadata, classes = batch

In [239]:
# Expected rank 4 given how later cells index it: (batch, band, channel, length),
# with channel 0 used as time and channel 1 as flux — TODO confirm against collate_fn.
lcs.shape

In [257]:
# Split out the channel axis: channel 0 -> timestamps, channel 1 -> fluxes (both views of lcs).
times, fluxes = lcs[:, :, 0], lcs[:, :, 1]

In [258]:
# Time vs. flux for every light curve in the batch (as collated), with zero-filled padding dropped.
# The batch size is read from the data: `B` is not defined yet at this point in a fresh kernel.
n_batch = fluxes.shape[0]
fig, axs = plt.subplots(nrows=n_batch, ncols=2, figsize=(10, 5 * n_batch), squeeze=False)

for i in range(n_batch):
    for band_idx, band in enumerate(('V', 'G')):
        t = times[i, band_idx, :]
        f = fluxes[i, band_idx, :]
        # One joint mask keeps time and flux aligned; masking each independently can leave
        # arrays of different lengths when a genuine value happens to be exactly 0.
        valid = (t != 0) & (f != 0)
        ax = axs[i, band_idx]
        ax.plot(t[valid], f[valid])
        ax.set_title(f'Entry {i} - {band} Band')
        ax.set_xlabel('time')
        ax.set_ylabel('flux')

fig.tight_layout()
plt.show()

In [242]:
# Sort every light curve by its time channel (channel 0) along the last dimension.
# stable=True makes the order of duplicate timestamps deterministic.
# NOTE(review): zero-filled padding is sorted by its time value of 0, so it ends up at the
# front only if all real timestamps are positive.
sorted_values, indices = torch.sort(lcs[:, :, 0, :], dim=-1, stable=True)

# Apply the same per-curve permutation to every channel in one gather (broadcasting the
# indices over the channel axis), so each observation's channels stay together.
sorted_lcs = torch.gather(lcs, -1, indices.unsqueeze(2).expand_as(lcs))

In [256]:
# Time vs. flux for every light curve in the batch after sorting by time,
# with zero-filled padding dropped.
sorted_fluxes = sorted_lcs[:, :, 1, :]
sorted_times = sorted_lcs[:, :, 0, :]

# Batch size read from the data: `B` is not defined yet at this point in a fresh kernel.
n_batch = sorted_fluxes.shape[0]
fig, axs = plt.subplots(nrows=n_batch, ncols=2, figsize=(10, 5 * n_batch), squeeze=False)

for i in range(n_batch):
    for band_idx, band in enumerate(('V', 'G')):
        t = sorted_times[i, band_idx, :]
        f = sorted_fluxes[i, band_idx, :]
        # Joint mask keeps time and flux aligned (independent masks can differ in length).
        valid = (t != 0) & (f != 0)
        ax = axs[i, band_idx]
        ax.plot(t[valid], f[valid])
        ax.set_title(f'Entry {i} - {band} Band')
        ax.set_xlabel('time')
        ax.set_ylabel('flux')

fig.tight_layout()
plt.show()

In [261]:
def FFT_for_Period(x, k=2):
    """Estimate the k dominant periods of a batch of multichannel series with a real FFT.

    Args:
        x: tensor of shape [B, T, C] (batch, time steps, channels); the FFT runs along dim 1.
        k: number of dominant frequencies to select.

    Returns:
        period: numpy int array of shape (k,) holding ``T // f`` for each selected frequency
            index ``f``, ordered by decreasing batch- and channel-averaged amplitude.
        weights: tensor of shape [B, k] with the channel-averaged amplitude of each selected
            frequency for every batch element.
    """
    xf = torch.fft.rfft(x, dim=1)
    amplitude = xf.abs()  # [B, T // 2 + 1, C]; computed once, reused for ranking and weights
    # Rank frequencies globally by averaging amplitudes over batch and channels.
    frequency_list = amplitude.mean(0).mean(-1)
    # Ignore the zero-frequency (DC) bin: it reflects the mean level, not a period.
    frequency_list[0] = 0
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.detach().cpu().numpy()
    # DC can still be selected when the other amplitudes tie at exactly 0; clamp its index to 1
    # so the period becomes T instead of a numpy divide-by-zero that yields a 0 period.
    period = x.shape[1] // np.maximum(top_list, 1)
    return period, amplitude.mean(-1)[:, top_list]

In [262]:
# (batch, band, length) at this point, since sorted_fluxes = sorted_lcs[:, :, 1, :].
sorted_fluxes.shape

In [263]:
# FFT_for_Period expects [B, T, C] (time on dim 1). Rebuild from sorted_lcs instead of
# transposing sorted_fluxes onto itself, so re-running this cell cannot flip the axes back.
sorted_fluxes = sorted_lcs[:, :, 1, :].transpose(1, 2)
sorted_fluxes.shape

In [266]:
# Top-5 dominant periods of the time-sorted flux series, measured in samples (array steps).
# NOTE(review): the FFT treats consecutive samples as evenly spaced and also sees the
# (presumably zero) padding; if the light curves are irregularly sampled these periods are
# in units of sample index, not time — confirm this is intended.
period_list, period_weight = FFT_for_Period(sorted_fluxes, k=5)

In [267]:
# Selected periods and their per-batch-element amplitudes.
period_list, period_weight

In [197]:
# torch.topk returns values sorted in descending order by default, so index 0 is the
# period with the largest averaged amplitude.
period = period_list[0]

In [192]:
# Dimensions of the [B, T, C] series analysed by FFT_for_Period (sorted_fluxes after its
# transpose). `fluxes` is laid out as (batch, band, length), which would make T the band count.
B, T, N = sorted_fluxes.shape

In [200]:
# Fold each channel's series into a 2-D (cycles x period) grid, TimesNet-style.
# Works on the [B, T, C] series that FFT_for_Period analysed (`fluxes` is (batch, band, length)),
# and zero-pads the time axis up to a multiple of `period` so the reshape is exact.
flux_btc = sorted_fluxes
B, T, N = flux_btc.shape
period = int(period)
assert period > 0, 'period must be positive'
padded_len = -(-T // period) * period  # ceil(T / period) * period
if padded_len != T:
    flux_btc = torch.cat([flux_btc, flux_btc.new_zeros(B, padded_len - T, N)], dim=1)
out = flux_btc.reshape(B, padded_len // period, period, N).permute(0, 3, 1, 2).contiguous()

In [201]:
# (batch, channel, n_cycles, period), per the reshape + permute(0, 3, 1, 2) that built `out`.
out.shape