In [1]:
import sys

# Prepend the parent directory so the project's own packages (e.g. `core`) import.
sys.path[:0] = ['..']

In [2]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.functional as F
from transformers import TimeSeriesTransformerConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel 
from transformers.models.time_series_transformer.modeling_time_series_transformer import TimeSeriesTransformerEncoder
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy import stats
import json 
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import seaborn as sns
from datetime import datetime
from torch.utils.data import Dataset, DataLoader

from pathlib import Path
from core.multimodal.dataset import collate_fn, ASASSNVarStarDataset
from functools import partial
import matplotlib.pyplot as plt
import pandas as pd
from astropy.io import fits

from core.spectra.dataset import SpectraVDataset
from core.spectra.model import GalSpecNet

In [55]:
# Seed every RNG the notebook uses so runs are repeatable.
random_seed = 42
torch.manual_seed(random_seed)  # seeds the CPU and all CUDA devices
np.random.seed(random_seed)
random.seed(random_seed)
# cuDNN determinism also needs the autotuner off; otherwise benchmark mode may
# pick different (non-deterministic) kernels between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [56]:
# Data locations. `data_root` may be overridden through the ASASSN_DATA_ROOT
# environment variable so the notebook runs outside the original machine;
# the remaining paths are relative to it.
data_root = os.environ.get('ASASSN_DATA_ROOT', '/home/mariia/AstroML/data/asassn')
lamost_spec_file = 'Spectra/lamost_spec.csv'  # index of LAMOST spectra (edr3_source_id -> spec_filename)
lamost_spec_dir = 'Spectra/v2'                # directory of the spectrum FITS files
v_file = 'v.csv'                              # variable-star table; full catalogue: 'asassn_catalog_full.csv'
spectra_v_file = 'spectra_v_merged.csv'       # merged, cleaned table written by the merge cells

In [5]:
# Variable-star types kept for classification (matched against `variable_type`).
CLASSES = ('CWA CWB DCEP DCEPS DSCT EA EB EW '
           'HADS M ROT RRAB RRC RRD RVA SR').split()

In [7]:
def _run_epoch(dataloader, train):
    """Run one pass of the notebook-level `model` over `dataloader`.

    Uses the globals `model`, `criterion`, `device` and, when `train` is
    True, `optimizer`.

    Returns:
        (mean loss per sample, accuracy) over the whole epoch.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.train(train)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # Gradients are only tracked for the training pass.
    with torch.set_grad_enabled(train):
        for fluxes, label in tqdm(dataloader):
            fluxes, label = fluxes.to(device), label.to(device)

            if train:
                optimizer.zero_grad()

            logits = model(fluxes)
            loss = criterion(logits, label)

            if train:
                loss.backward()
                optimizer.step()

            batch_size = label.size(0)
            # criterion averages within a batch (CrossEntropyLoss default);
            # weight by batch size so a short final batch does not skew the mean.
            loss_sum += loss.item() * batch_size
            # argmax of the logits equals argmax of softmax(logits).
            n_correct += (logits.argmax(dim=1) == label).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError('dataloader yielded no samples')

    return loss_sum / n_seen, n_correct / n_seen


def train_epoch():
    """Train `model` for one epoch on `train_dataloader`; returns (loss, accuracy)."""
    return _run_epoch(train_dataloader, train=True)


def val_epoch():
    """Evaluate `model` on `val_dataloader` without gradients; returns (loss, accuracy)."""
    return _run_epoch(val_dataloader, train=False)

In [20]:
class SpectraVDataset(Dataset):
    """LAMOST spectra paired with variable-star type labels.

    Each item is ``(fluxes, label)``. ``fluxes`` is the spectrum linearly
    interpolated onto ``np.arange(3850, 9000, 2)`` (2575 samples) and scaled to
    unit L2 norm, with shape ``(1, 2575)`` and dtype float32. ``label`` is the
    integer id of the row's ``variable_type``.

    Parameters
    ----------
    data_root : str
        Directory that ``lamost_spec_dir`` and ``spectra_v_file`` are relative to.
    lamost_spec_dir : str
        Sub-directory holding the spectrum FITS files.
    spectra_v_file : str
        CSV with at least ``edr3_source_id``, ``variable_type`` and
        ``spec_filename`` columns.
    split : str
        'train' (first 70% of a seeded shuffle), 'val' (next 15%) or 'test'
        (the rest). Any other value keeps every row.
    classes : list of str, optional
        Keep only rows whose ``variable_type`` is in this list.
    z_corr : bool
        If True, correct wavelengths with the FITS header redshift ``Z``.
    """

    def __init__(self, data_root, lamost_spec_dir, spectra_v_file, split='train', classes=None, z_corr=False):
        self.data_root = data_root
        self.lamost_spec_dir = os.path.join(data_root, lamost_spec_dir)
        self.spectra_v_file = os.path.join(data_root, spectra_v_file)
        self.split = split
        self.z_corr = z_corr

        self.df = pd.read_csv(self.spectra_v_file)
        self.df = self.df[['edr3_source_id', 'variable_type', 'spec_filename']]

        if classes:
            self.df = self.df[self.df['variable_type'].isin(classes)]

        # Build the label mapping from the full (class-filtered) table BEFORE
        # splitting. Train/val/test then share one mapping and one num_classes,
        # even when a rare class is missing from one split.
        self.id2target = {i: x for i, x in enumerate(sorted(self.df['variable_type'].unique()))}
        self.target2id = {v: k for k, v in self.id2target.items()}
        self.num_classes = len(self.id2target)

        self._split()

    def _split(self):
        """Keep this instance's slice of a seeded 70/15/15 shuffle of ``self.df``."""
        total_size = len(self.df)
        train_size = int(total_size * 0.7)
        val_size = int(total_size * 0.15)

        # A fixed random_state gives every split the same permutation, so the slices are disjoint.
        shuffled_df = self.df.sample(frac=1, random_state=42)

        if self.split == 'train':
            self.df = shuffled_df[:train_size]
        elif self.split == 'val':
            self.df = shuffled_df[train_size:train_size + val_size]
        elif self.split == 'test':
            self.df = shuffled_df[train_size + val_size:]
        else:
            self.df = shuffled_df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        el = self.df.iloc[idx]
        variable_type, spec_filename = el['variable_type'], el['spec_filename']
        label = self.target2id[variable_type]

        # read spectra
        spectra = self.readLRSFits(os.path.join(self.lamost_spec_dir, spec_filename))
        original_wavelengths, fluxes = spectra[:, 0], spectra[:, 1]

        # Resample onto a common grid. np.interp expects increasing wavelengths
        # and repeats the edge value outside the observed range.
        wavelengths = np.arange(3850, 9000, 2)
        fluxes = np.interp(wavelengths, original_wavelengths, fluxes)

        # Scale to unit L2 norm. An all-zero spectrum stays zeros instead of
        # becoming NaN through 0 / 0.
        norm = np.sqrt(np.sum(fluxes ** 2))
        if norm > 0:
            fluxes = fluxes / norm

        # reshape so batches are [N, 1, (9000-3850)//2]
        fluxes = fluxes.reshape(1, -1).astype(np.float32)

        return fluxes, label

    def readLRSFits(self, filename):
        """Read a LAMOST low-resolution spectrum FITS file.

        Adapted from https://github.com/fandongwei/pylamost

        Parameters
        ----------
        filename : str
            Path of the FITS file.

        Uses ``self.z_corr``: if True, wavelengths are corrected with the header
        redshift ``Z``, or with 0 when the keyword is missing.

        Returns
        -------
        numpy.ndarray of shape (n_pixels, 3)
            Columns are wavelength, flux and inverse variance.

        Raises
        ------
        ValueError
            If the file has neither 1 nor 2 HDUs.
        """
        # The context manager closes the file on every path. Previously the
        # 2-HDU branch never closed it.
        with fits.open(filename) as hdulist:
            len_list = len(hdulist)

            if len_list == 1:
                # Primary data rows are flux, ivar, ...; the wavelengths follow
                # log10(lambda) = COEFF0 + pixel * COEFF1.
                head = hdulist[0].header
                scidata = hdulist[0].data
                coeff0 = head['COEFF0']
                coeff1 = head['COEFF1']
                pixel_num = head['NAXIS1']
                specflux = scidata[0,]
                ivar = scidata[1,]
                wavelength = np.linspace(0, pixel_num - 1, pixel_num)
                wavelength = np.power(10, (coeff0 + wavelength * coeff1))
            elif len_list == 2:
                # The first row of HDU 1 holds (flux, ivar, wavelength).
                head = hdulist[0].header
                scidata = hdulist[1].data
                wavelength = scidata[0][2]
                ivar = scidata[0][1]
                specflux = scidata[0][0]
            else:
                raise ValueError(f'Wrong number of fits files. {len_list} should be 1 or 2')

            if self.z_corr:
                try:
                    # correct for radial velocity of star
                    redshift = head['Z']
                except KeyError as e:
                    print(e, 'Setting redshift to zero')
                    redshift = 0.0

                # lambda * (1 - z): first-order form of lambda / (1 + z)
                wavelength = wavelength - redshift * wavelength

            # vstack copies the data while the file is still open.
            return np.vstack((wavelength, specflux, ivar)).T

In [43]:
class ResNetBlock(nn.Module):
    """1-D residual block with a strided convolutional shortcut.

    Main branch: conv (stride 1) -> BN -> ReLU -> conv (stride `stride`)
    -> BN -> ReLU -> max-pool (stride 1). When `short_cut` is True the input
    goes through a strided conv with the same kernel/stride as the second
    conv, so both branches have the same shape before they are summed.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of every conv and of the pool. The
            length-preserving layers pad by kernel_size // 2, so odd sizes keep
            the length (the default 5 gives the original padding of 2).
        stride: stride of the second conv and of the shortcut conv.
        short_cut: if False, the raw input is added. That only works when input
            and output shapes already match.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=3, short_cut=True):
        super(ResNetBlock, self).__init__()

        # "Same" padding for odd kernels. It also respects MaxPool1d's
        # padding <= kernel_size / 2 rule; the old hard-coded 2 broke kernel_size=3.
        padding = kernel_size // 2

        self.short_cut = short_cut
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=stride)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.pool = nn.MaxPool1d(kernel_size=kernel_size, stride=1, padding=padding)

        if self.short_cut:
            # Unpadded and strided like conv2, so the output lengths match.
            self.shortcut_conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        identity = x

        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.pool(out)

        if self.short_cut:
            identity = self.shortcut_conv(identity)

        return out + identity

class ResNet(nn.Module):
    """Three ResNetBlocks followed by dropout and a two-layer MLP classifier.

    Args:
        num_classes: number of output logits.
        dropout: dropout probability applied to the flattened features.
        input_length: length of the single-channel input spectrum. It sizes
            the first linear layer. The default 2575 matches the
            np.arange(3850, 9000, 2) grid and yields the original 6016 features.
    """

    def __init__(self, num_classes, dropout=0.5, input_length=2575):
        super(ResNet, self).__init__()

        self.model = nn.Sequential(
            ResNetBlock(1, 16),
            ResNetBlock(16, 32),
            ResNetBlock(32, 64),
        )
        self.dropout = nn.Dropout(dropout)
        self.mlp = nn.Sequential(
            nn.Linear(self._flat_features(input_length), 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def _flat_features(self, input_length):
        """Flattened size of `self.model`'s output for a (1, 1, input_length) input."""
        was_training = self.model.training
        # eval() so the probe does not update BatchNorm running statistics;
        # the all-zeros probe consumes no RNG, so later initialisation is unchanged.
        self.model.eval()
        with torch.no_grad():
            n_features = self.model(torch.zeros(1, 1, input_length)).numel()
        self.model.train(was_training)
        return n_features

    def forward(self, x):
        x = self.model(x)
        x = x.view(x.shape[0], -1)
        x = self.dropout(x)
        x = self.mlp(x)

        return x

In [50]:

# Train/val splits over every variable type (classes=None applies no filter).
train_dataset, val_dataset = (
    SpectraVDataset(data_root, lamost_spec_dir, spectra_v_file, classes=None, split=name)
    for name in ('train', 'val')
)

# Only the training loader is shuffled.
train_dataloader, val_dataloader = (
    DataLoader(ds, batch_size=128, shuffle=is_train)
    for ds, is_train in ((train_dataset, True), (val_dataset, False))
)

In [52]:
# Class balance of the training split (rows per variable_type).
train_dataset.df['variable_type'].value_counts()

In [45]:
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
print('Using', device)

# Fresh ResNet sized to the dataset's label set, trained with Adam + cross-entropy.
model = ResNet(train_dataset.num_classes).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Per-epoch history, filled by the training loop.
train_losses, val_losses = [], []
train_accs, val_accs = [], []

In [46]:
for i in range(20):
    print(f'Epoch {i}')

    train_loss, train_acc = train_epoch()
    print(f'Train Loss: {round(train_loss, 3)} Acc: {round(train_acc, 2)}')

    val_loss, val_acc = val_epoch()
    print(f'Val Loss: {round(val_loss, 3)} Acc: {round(val_acc, 2)}')

    # Record history only after both passes of the epoch have finished.
    for history, value in ((train_losses, train_loss), (val_losses, val_loss),
                           (train_accs, train_acc), (val_accs, val_acc)):
        history.append(value)

In [52]:
dataset = SpectraVDataset(data_root, lamost_spec_dir, spectra_v_file, classes=CLASSES, split='all')
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

In [54]:
for i in tqdm(range(len(dataset))):
    X, y = dataset[i]
    
    if X.std() < 0.01:
        print(i)

In [53]:
plt.plot(dataset[7984][0])

In [50]:
stats.median_abs_deviation(dataset[7984][0])

In [48]:
dataset[873][0].std()

In [14]:
model = GalSpecNet(dataset.num_classes)
criterion = nn.CrossEntropyLoss()

In [15]:
X, y = next(iter(dataloader))

In [27]:
# Smoke-test the model on every batch and stop at the first one that fails.
for batch_idx, (X, y) in enumerate(tqdm(dataloader)):
    try:
        out = model(X)
        loss = criterion(out, y)
    except Exception as err:  # not a bare except: Ctrl-C / SystemExit still work
        print(f'batch {batch_idx} failed: {err!r}')
        break

In [8]:
# Report every dataset item that fails to load (e.g. missing or unreadable
# spectrum files), together with the error.
for i in tqdm(range(len(dataset))):
    try:
        dataset[i]
    except Exception as err:  # not a bare except: Ctrl-C / SystemExit still work
        print(i, repr(err))

In [12]:
# Inspect one (fluxes, label) item.
dataset[3]

In [57]:
# Join the variable-star table with the LAMOST spectra index on
# `edr3_source_id`, keeping one row per source on each side.
spec_df = (
    pd.read_csv(os.path.join(data_root, lamost_spec_file), index_col=0)
    .drop_duplicates(subset=['edr3_source_id'])
)
v_df = pd.read_csv(os.path.join(data_root, v_file)).drop_duplicates(subset=['edr3_source_id'])

df = v_df.merge(spec_df, how='inner', on='edr3_source_id')

In [328]:
# Collect (and print) sources whose spectrum file does not exist on disk.
spec_root = os.path.join(data_root, lamost_spec_dir)
source_ids = []
for row_no, (source_id, spec_name) in enumerate(zip(df['edr3_source_id'], df['spec_filename'])):
    if not os.path.exists(os.path.join(spec_root, spec_name)):
        source_ids.append(source_id)
        print(row_no, source_id)

In [329]:
# Drop sources whose spectrum file is missing (collected in the previous cell).
# NOTE(review): rebinding `df` in place makes these cells non-idempotent; re-run from the merge cell.
df = df[~df['edr3_source_id'].isin(source_ids)]

In [333]:
# Sources excluded by hand -- presumably flagged in the spectrum plots below; confirm.
weird_sources = ['EDR3 3714273187707121920', 'EDR3 3222213829875076096', 'EDR3 601653935246445696']
df = df[~df['edr3_source_id'].isin(weird_sources)]

In [334]:
# Persist the cleaned join; this is the `spectra_v_file` SpectraVDataset reads.
df.to_csv(os.path.join(data_root, spectra_v_file), index=False)

In [330]:
# NOTE(review): `train_dataset2` is not defined anywhere in this notebook, so these
# cells fail on a fresh kernel. `spectra[:, 0]` suggests an item was an
# (n_pixels, 3) readLRSFits array, unlike SpectraVDataset's (fluxes, label) -- confirm.
spectra = train_dataset2[3609]
print(train_dataset2.df.iloc[3609]['edr3_source_id'])
plt.plot(spectra[:, 0], spectra[:, 1])

In [331]:
spectra = train_dataset2[3772]
print(train_dataset2.df.iloc[3772]['edr3_source_id'])
plt.plot(spectra[:, 0], spectra[:, 1])

In [332]:
spectra = train_dataset2[8568]
print(train_dataset2.df.iloc[8568]['edr3_source_id'])
plt.plot(spectra[:, 0], spectra[:, 1])

In [326]:
spectra = train_dataset2[8569]
plt.plot(spectra[:, 0], spectra[:, 1])

In [19]:
class SpectraVDataset(Dataset):
    """LAMOST spectra paired with variable-star type labels (robust scaling).

    Each item is ``(fluxes, label)``. ``fluxes`` is the spectrum linearly
    interpolated onto ``np.arange(3850, 9000, 2)`` (2575 samples), centred on
    its mean and divided by its median absolute deviation. It has shape
    ``(1, 2575)`` and dtype float32. ``label`` is the integer id of the row's
    ``variable_type``.

    Parameters
    ----------
    data_root : str
        Directory that ``lamost_spec_dir`` and ``spectra_v_file`` are relative to.
    lamost_spec_dir : str
        Sub-directory holding the spectrum FITS files.
    spectra_v_file : str
        CSV with at least ``edr3_source_id``, ``variable_type`` and
        ``spec_filename`` columns.
    split : str
        'train' (first 70% of a seeded shuffle), 'val' (next 15%) or 'test'
        (the rest). Any other value keeps every row.
    classes : list of str, optional
        Keep only rows whose ``variable_type`` is in this list.
    z_corr : bool
        If True, correct wavelengths with the FITS header redshift ``Z``.
    """

    def __init__(self, data_root, lamost_spec_dir, spectra_v_file, split='train', classes=None, z_corr=False):
        self.data_root = data_root
        self.lamost_spec_dir = os.path.join(data_root, lamost_spec_dir)
        self.spectra_v_file = os.path.join(data_root, spectra_v_file)
        self.split = split
        self.z_corr = z_corr

        self.df = pd.read_csv(self.spectra_v_file)
        self.df = self.df[['edr3_source_id', 'variable_type', 'spec_filename']]

        if classes:
            self.df = self.df[self.df['variable_type'].isin(classes)]

        # Build the label mapping BEFORE splitting, so every split shares one
        # class -> id mapping and num_classes even if a rare class misses a split.
        self.id2target = {i: x for i, x in enumerate(sorted(self.df['variable_type'].unique()))}
        self.target2id = {v: k for k, v in self.id2target.items()}
        self.num_classes = len(self.id2target)

        self._split()

    def _split(self):
        """Keep this instance's slice of a seeded 70/15/15 shuffle of ``self.df``."""
        total_size = len(self.df)
        train_size = int(total_size * 0.7)
        val_size = int(total_size * 0.15)

        # A fixed random_state gives every split the same permutation, so the slices are disjoint.
        shuffled_df = self.df.sample(frac=1, random_state=42)

        if self.split == 'train':
            self.df = shuffled_df[:train_size]
        elif self.split == 'val':
            self.df = shuffled_df[train_size:train_size + val_size]
        elif self.split == 'test':
            self.df = shuffled_df[train_size + val_size:]
        else:
            self.df = shuffled_df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        el = self.df.iloc[idx]
        variable_type, spec_filename = el['variable_type'], el['spec_filename']
        label = self.target2id[variable_type]

        # read spectra
        spectra = self.readLRSFits(os.path.join(self.lamost_spec_dir, spec_filename))
        original_wavelengths, fluxes = spectra[:, 0], spectra[:, 1]

        # Resample onto a common grid. np.interp repeats the edge value outside
        # the observed range, which can create long constant runs.
        wavelengths = np.arange(3850, 9000, 2)
        fluxes = np.interp(wavelengths, original_wavelengths, fluxes)

        # Robust scaling: (flux - mean) / MAD. The MAD is 0 when at least half
        # the samples equal the median (flat or edge-clamped spectra). Fall back
        # to the standard deviation then; if that is 0 too, keep the centred zeros.
        mean = fluxes.mean()
        scale = stats.median_abs_deviation(fluxes)
        if scale == 0:
            scale = fluxes.std()
        fluxes = fluxes - mean
        if scale > 0:
            fluxes = fluxes / scale

        # reshape so batches are [N, 1, (9000-3850)//2]
        fluxes = fluxes.reshape(1, -1).astype(np.float32)

        return fluxes, label

    def readLRSFits(self, filename):
        """Read a LAMOST low-resolution spectrum FITS file.

        Adapted from https://github.com/fandongwei/pylamost

        Parameters
        ----------
        filename : str
            Path of the FITS file.

        Uses ``self.z_corr``: if True, wavelengths are corrected with the header
        redshift ``Z``, or with 0 when the keyword is missing.

        Returns
        -------
        numpy.ndarray of shape (n_pixels, 3)
            Columns are wavelength, flux and inverse variance.

        Raises
        ------
        ValueError
            If the file has neither 1 nor 2 HDUs. Previously this surfaced as an
            UnboundLocalError.
        """
        # The context manager closes the file on every path (the 2-HDU branch leaked it).
        with fits.open(filename) as hdulist:
            len_list = len(hdulist)

            if len_list == 1:
                # Primary data rows are flux, ivar, ...; the wavelengths follow
                # log10(lambda) = COEFF0 + pixel * COEFF1.
                head = hdulist[0].header
                scidata = hdulist[0].data
                coeff0 = head['COEFF0']
                coeff1 = head['COEFF1']
                pixel_num = head['NAXIS1']
                specflux = scidata[0,]
                ivar = scidata[1,]
                wavelength = np.linspace(0, pixel_num - 1, pixel_num)
                wavelength = np.power(10, (coeff0 + wavelength * coeff1))
            elif len_list == 2:
                # The first row of HDU 1 holds (flux, ivar, wavelength).
                head = hdulist[0].header
                scidata = hdulist[1].data
                wavelength = scidata[0][2]
                ivar = scidata[0][1]
                specflux = scidata[0][0]
            else:
                raise ValueError(f'Wrong number of fits files. {len_list} should be 1 or 2')

            if self.z_corr:
                try:
                    # correct for radial velocity of star
                    redshift = head['Z']
                except KeyError:
                    redshift = 0.0

                # lambda * (1 - z): first-order form of lambda / (1 + z)
                wavelength = wavelength - redshift * wavelength

            # vstack copies the data while the file is still open.
            return np.vstack((wavelength, specflux, ivar)).T

In [20]:
# Train/val splits restricted to the CLASSES list.
train_dataset, val_dataset = (
    SpectraVDataset(data_root, lamost_spec_dir, spectra_v_file, classes=CLASSES, split=name)
    for name in ('train', 'val')
)

In [21]:
# Split sizes.
len(train_dataset), len(val_dataset)

In [37]:
# shuffle=False on the train loader keeps batch order equal to train_dataset.df,
# so flux[1] below can be compared with the raw spectrum of df row 1.
# NOTE(review): if this loader is later used for training, it should shuffle.
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=4)

In [38]:
flux, label = next(iter(train_dataloader))

In [47]:
# Raw (un-interpolated, un-normalised) flux of training row 1.
plt.plot(train_dataset.readLRSFits(os.path.join(data_root, lamost_spec_dir, train_dataset.df.iloc[1]['spec_filename']))[:, 1])

In [39]:
# Preprocessed inputs of the first five training items (channel 0).
plt.plot(flux[0, 0, :])

In [43]:
plt.plot(flux[1, 0, :])

In [44]:
plt.plot(flux[2, 0, :])

In [45]:
plt.plot(flux[3, 0, :])

In [46]:
plt.plot(flux[4, 0, :])

In [230]:
# Step through the GalSpecNet layers by hand to find the flattened size.
# `X` is the batch drawn from `dataloader` earlier. For a (B, 1, 2575) input,
# Conv1d(k=3, no padding) removes 2 samples and MaxPool1d(4) floor-divides by 4.
conv1 = nn.Sequential(nn.Conv1d(1, 64, kernel_size=3), nn.ReLU())
mp1 = nn.MaxPool1d(kernel_size=4)

x1 = conv1(X)
print(x1.shape)

x2 = mp1(x1)
print(x2.shape)

In [231]:
conv2 = nn.Sequential(nn.Conv1d(64, 64, kernel_size=3), nn.ReLU())
mp2 = nn.MaxPool1d(kernel_size=4)

x3 = conv2(x2)
print(x3.shape)

x4 = mp2(x3)
print(x4.shape)

In [234]:
conv3 = nn.Sequential(nn.Conv1d(64, 32, kernel_size=3), nn.ReLU())
mp3 = nn.MaxPool1d(kernel_size=4)

x5 = conv3(x4)
print(x5.shape)

x6 = mp3(x5)
print(x6.shape)

In [239]:
# For 2575 inputs: 2573 -> 643 -> 641 -> 160 -> 158 -> 39 -> 37 samples here.
conv4 = nn.Sequential(nn.Conv1d(32, 32, kernel_size=3), nn.ReLU())
x7 = conv4(x6)
x7.shape

In [240]:
x8 = x7.view(x7.shape[0], -1)
x8.shape

In [244]:
# 32 channels * 37 samples = 1184 features (valid only for 2575-sample inputs).
mlp = nn.Sequential(
    nn.Linear(32 * 37, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, train_dataset.num_classes)
)

In [245]:
x9 = mlp(x8)
x9.shape

In [251]:
class GalSpecNet(nn.Module):
    """1-D CNN spectrum classifier.

    Four Conv1d(k=3)+ReLU layers, each of the first three followed by
    MaxPool1d(4). Then dropout(0.2) and a 3-layer MLP.

    Args:
        num_classes: number of output logits.
        input_length: length of the single-channel input. It sizes the first
            linear layer; the default 2575 reproduces the original 32 * 37 features.
    """

    def __init__(self, num_classes, input_length=2575):
        super(GalSpecNet, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv1d(1, 64, kernel_size=3), nn.ReLU())
        self.mp1 = nn.MaxPool1d(kernel_size=4)
        self.conv2 = nn.Sequential(nn.Conv1d(64, 64, kernel_size=3), nn.ReLU())
        self.mp2 = nn.MaxPool1d(kernel_size=4)
        self.conv3 = nn.Sequential(nn.Conv1d(64, 32, kernel_size=3), nn.ReLU())
        self.mp3 = nn.MaxPool1d(kernel_size=4)
        self.conv4 = nn.Sequential(nn.Conv1d(32, 32, kernel_size=3), nn.ReLU())

        self.dropout = nn.Dropout(0.2)

        # Probe the conv stack once to size the MLP. The all-zeros probe creates
        # no parameters and consumes no RNG, so initialisation matches the
        # hard-coded version.
        with torch.no_grad():
            n_features = self._features(torch.zeros(1, 1, input_length)).shape[1]

        self.mlp = nn.Sequential(
            nn.Linear(n_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes)
        )

    def _features(self, x):
        """Conv/pool feature extractor, flattened to (batch, channels * length)."""
        x = self.mp1(self.conv1(x))
        x = self.mp2(self.conv2(x))
        x = self.mp3(self.conv3(x))
        x = self.conv4(x)
        return x.view(x.shape[0], -1)

    def forward(self, x):
        x = self._features(x)
        x = self.dropout(x)
        x = self.mlp(x)

        return x

In [56]:
# Fresh, untrained ResNet; replaces whatever `model` was bound to before.
# NOTE(review): not moved to `device` here, unlike the earlier setup cell.
model = ResNet(train_dataset.num_classes)

In [72]:
# NOTE(review): stray no-op cell; safe to delete.
1

In [73]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Plot training and validation losses on the left side
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Plot training and validation accuracies on the right side
ax2.plot(train_accs, label='Train Accuracy')
ax2.plot(val_accs, label='Validation Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()

fig.suptitle('Only 1 batch norm')
plt.show()

In [55]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Plot training and validation losses on the left side
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Plot training and validation accuracies on the right side
ax2.plot(train_accs, label='Train Accuracy')
ax2.plot(val_accs, label='Validation Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()

fig.suptitle('Dont remember lol')
plt.show()

In [16]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Plot training and validation losses on the left side
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Plot training and validation accuracies on the right side
ax2.plot(train_accs, label='Train Accuracy')
ax2.plot(val_accs, label='Validation Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()

fig.suptitle('Weight decay 0.1 + Res Connections')
plt.show()

In [24]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Plot training and validation losses on the left side
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Plot training and validation accuracies on the right side
ax2.plot(train_accs, label='Train Accuracy')
ax2.plot(val_accs, label='Validation Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()

fig.suptitle('No max pool')
plt.show()

In [91]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Plot training and validation losses on the left side
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Plot training and validation accuracies on the right side
ax2.plot(train_accs, label='Train Accuracy')
ax2.plot(val_accs, label='Validation Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()

fig.suptitle('MaxPool + Batch Norm but less layers and mlp instead of fc')
plt.show()

In [85]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Plot training and validation losses on the left side
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Plot training and validation accuracies on the right side
ax2.plot(train_accs, label='Train Accuracy')
ax2.plot(val_accs, label='Validation Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()

fig.suptitle('MaxPool + Batch Norm but less layers')
plt.show()

In [75]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Plot training and validation losses on the left side
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Plot training and validation accuracies on the right side
ax2.plot(train_accs, label='Train Accuracy')
ax2.plot(val_accs, label='Validation Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()

fig.suptitle('MaxPool + Batch Norm')
plt.show()

In [61]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Plot training and validation losses on the left side
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Plot training and validation accuracies on the right side
ax2.plot(train_accs, label='Train Accuracy')
ax2.plot(val_accs, label='Validation Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()

fig.suptitle('6 Conv1d (1>64, 64>64, 64>64, 64>128, 128>128, 128>256), 2 MP1d, ReLU, DO(0.8), Flatten, 2 Linear (17920>1024>nc)')
plt.show()

In [56]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Plot training and validation losses on the left side
ax1.plot(train_losses, label='Train Loss')
ax1.plot(val_losses, label='Validation Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.legend()

# Plot training and validation accuracies on the right side
ax2.plot(train_accs, label='Train Accuracy')
ax2.plot(val_accs, label='Validation Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy')
ax2.legend()

fig.suptitle('6 Conv1d (1>64, 64>64, 64>64, 64>128, 128>128, 128>256), 2 MP1d, ReLU, DO(0.5), Flatten, 2 Linear (17920>1024>nc)')
plt.show()

In [48]:
# Quick unlabeled overlay: train loss first, then validation loss.
plt.plot(train_losses)
plt.plot(val_losses)

In [49]:
# Same for accuracies (train, then validation).
plt.plot(train_accs)
plt.plot(val_accs)

In [47]:
# Continue training for another 40 epochs. Epoch labels continue from the
# recorded history instead of a hard-coded start: range(10, 50) mislabels
# epochs whenever the history does not hold exactly 10 entries (e.g. after
# the 20-epoch loop above).
start_epoch = len(train_losses)
for i in range(start_epoch, start_epoch + 40):
    print(f'Epoch {i}')

    train_loss, train_acc = train_epoch()
    print(f'Train Loss: {round(train_loss, 3)} Acc: {round(train_acc, 2)}')

    val_loss, val_acc = val_epoch()
    print(f'Val Loss: {round(val_loss, 3)} Acc: {round(val_acc, 2)}')

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

In [7]:
datapath = Path('/home/mariia/AstroML/data/asassn')
# Training split: periodic sources that have spectra, phased light curves in v and g bands.
ds_train = ASASSNVarStarDataset(
    datapath,
    mode='train',
    verbose=True,
    prime=True,
    only_periodic=True,
    recalc_period=False,
    only_sources_with_spectra=True,
    use_bands=['v', 'g'],
    return_phased=True,
    fill_value=0,
)

In [9]:
# Validation split with the same options as ds_train.
ds_val = ASASSNVarStarDataset(
    datapath,
    mode='val',
    verbose=True,
    prime=True,
    only_periodic=True,
    recalc_period=False,
    only_sources_with_spectra=True,
    use_bands=['v', 'g'],
    return_phased=True,
    fill_value=0,
)

In [None]:
def preprocess_batch(batch, masks, context_length=None, scales=None):
    """Turn a light-curve batch into time-sorted, scaled model inputs.

    Parameters
    ----------
    batch : tuple(Tensor, Tensor)
        ``(lcs, classes)``. ``lcs`` is indexed as [batch, 1, 3, n_points]; its
        3 rows are (time, flux, flux_err). ``classes`` is [batch, >=1].
    masks : tuple(Tensor, Tensor)
        ``(lcs_mask, classes_mask)`` aligned with ``batch``. True in
        ``lcs_mask`` marks a missing point. ``classes_mask`` is unused.
    context_length : int, optional
        Number of points kept (cropped or padded at the end). Defaults to the
        notebook-level global ``context_length``, as before.
    scales : dict, optional
        ``{'v': {'mean': float, 'std': float}, ...}``. If omitted, it is read
        from 'scales.json' in the working directory, as before.

    Returns
    -------
    sorted_X : Tensor [batch, context_length, 3]
        Points sorted by time. Flux is standardised; flux_err is divided by std.
    sorted_mask : int Tensor [batch, context_length]
        1 for observed points, 0 for missing or padded ones.
    classes : LongTensor [batch]
        First column of ``classes``.
    """
    if context_length is None:
        if 'context_length' not in globals():
            raise NameError('context_length must be passed or defined as a notebook global')
        context_length = globals()['context_length']

    if scales is None:
        with open('scales.json', 'r') as f:
            scales = json.load(f)
    mean, std = scales['v']['mean'], scales['v']['std']

    lcs, classes = batch
    lcs_mask, _ = masks

    # [B, 1, 3, L] -> [B, 3, L] -> [B, L, 3]
    X = lcs[:, 0, :, :].transpose(1, 2)
    # The mask is the same for time, flux and flux_err, so one row suffices: [B, L]
    mask = lcs_mask[:, 0, 0, :]

    n_points = X.shape[1]
    if n_points < context_length:
        # Pad at the end; padded points are marked missing (True).
        X = F.pad(X, (0, 0, 0, context_length - n_points, 0, 0))
        mask = F.pad(mask, (0, context_length - n_points), value=True)
    else:
        X = X[:, :context_length, :]
        mask = mask[:, :context_length]

    # Sort every light curve by time. One gather per tensor replaces the
    # per-sample Python loops and yields the same reordering.
    sort_indices = torch.argsort(X[:, :, 0], dim=1)
    sorted_X = torch.gather(X, 1, sort_indices.unsqueeze(-1).expand(-1, -1, X.shape[2]))
    sorted_mask = torch.gather(mask, 1, sort_indices)

    # Flip the convention: 1 = observed, 0 = missing.
    sorted_mask = 1 - sorted_mask.int()

    # gather returns new tensors, so these in-place writes leave `lcs` untouched.
    sorted_X[:, :, 1] = (sorted_X[:, :, 1] - mean) / std
    sorted_X[:, :, 2] = sorted_X[:, :, 2] / std

    # Take the first column of classes and convert float -> int64.
    classes = classes[:, 0].long()

    return sorted_X, sorted_mask, classes

In [12]:
# Wrap the ASAS-SN datasets for spectrum-only training.
# NOTE(review): `SpectraDataset` is neither defined nor imported in this notebook
# (only `SpectraVDataset` is), so this fails on a fresh kernel -- confirm its source.
# These assignments also rebind the earlier train_dataset/val_dataset names.
train_dataset = SpectraDataset(ds_train)
val_dataset = SpectraDataset(ds_val)

In [13]:
train_dataset[0]

In [14]:
train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=8)
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False)

In [15]:
class GalSpecNet(nn.Module):
    """https://academic.oup.com/mnras/article/527/1/1163/7283157

    Four Conv1d(k=3) layers, the first three followed by MaxPool1d(4), then a
    4-layer fully connected head with dropout(0.2) on its input.

    Args:
        num_classes: number of output logits.
        input_length: optional length of the single-channel input. When None
            (default), fc1 keeps the original hard-coded 2496 inputs, which
            corresponds to inputs of roughly 5162-5225 samples. Otherwise the
            size is derived from a probe pass through the conv stack.
    """

    def __init__(self, num_classes, input_length=None):
        super(GalSpecNet, self).__init__()

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, padding=0)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, padding=0)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=3, padding=0)
        self.conv4 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, padding=0)

        self.mp1 = nn.MaxPool1d(kernel_size=4)
        self.mp2 = nn.MaxPool1d(kernel_size=4)
        self.mp3 = nn.MaxPool1d(kernel_size=4)

        if input_length is None:
            n_features = 2496
        else:
            # The all-zeros probe creates no parameters and consumes no RNG.
            with torch.no_grad():
                n_features = self._features(torch.zeros(1, 1, input_length)).shape[1]

        self.fc1 = nn.Linear(n_features, 512)
        self.fc2 = nn.Linear(512, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, num_classes)

        self.dropout = nn.Dropout(0.2)

    def _features(self, x):
        """Conv/pool stack, flattened to (batch, channels * length)."""
        # ReLU after max-pool equals max-pool after ReLU, because ReLU is monotone.
        x = F.relu(self.mp1(self.conv1(x)))
        x = F.relu(self.mp2(self.conv2(x)))
        x = F.relu(self.mp3(self.conv3(x)))
        x = F.relu(self.conv4(x))
        return x.view(x.shape[0], -1)

    def forward(self, x):
        x = self._features(x)
        x = F.relu(self.fc1(self.dropout(x)))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)

        return x

In [18]:
device = torch.device('cuda:2') if torch.cuda.is_available() else torch.device('cpu')
print('Using', device)

In [19]:
# One output per entry of the dataset's target lookup; Module.to returns the module itself.
model = GalSpecNet(len(ds_train.target_lookup)).to(device)

In [20]:
criterion = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=1e-3)

In [22]:
def train_epoch():
    """Run one optimisation epoch of `model` over `train_dataloader`.

    Uses the notebook globals `model`, `optimizer`, `criterion` and `device`.

    Returns:
        (mean loss per sample, accuracy) for the epoch.

    Raises:
        ValueError: if `train_dataloader` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for fluxes, label in tqdm(train_dataloader):
        fluxes, label = fluxes.to(device), label.to(device)

        optimizer.zero_grad()
        logits = model(fluxes)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()

        batch_size = label.size(0)
        # criterion averages within a batch (CrossEntropyLoss default);
        # re-weight so a short last batch doesn't skew the epoch mean.
        loss_sum += loss.item() * batch_size
        # argmax over logits equals argmax over softmax(logits).
        n_correct += (logits.argmax(dim=1) == label).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError('train_dataloader yielded no samples')

    return loss_sum / n_seen, n_correct / n_seen

In [23]:
def val_epoch():
    """Evaluate `model` on `val_dataloader` without gradients.

    Uses the notebook globals `model`, `criterion` and `device`.

    Returns:
        (mean loss per sample, accuracy) over the validation set.

    Raises:
        ValueError: if `val_dataloader` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for fluxes, label in tqdm(val_dataloader):
            fluxes, label = fluxes.to(device), label.to(device)

            logits = model(fluxes)
            loss = criterion(logits, label)

            batch_size = label.size(0)
            # criterion averages within a batch (CrossEntropyLoss default);
            # re-weight so a short last batch doesn't skew the mean.
            loss_sum += loss.item() * batch_size
            # argmax over logits equals argmax over softmax(logits).
            n_correct += (logits.argmax(dim=1) == label).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError('val_dataloader yielded no samples')

    return loss_sum / n_seen, n_correct / n_seen

In [24]:
# Train GalSpecNet for 10 epochs, printing metrics.
# NOTE(review): unlike the earlier loops, these do not append to the
# train_losses/val_losses history, so the plotting cells will not show them.
for i in range(10):
    print(f'Epoch {i}')
    
    train_loss, train_acc = train_epoch()
    print(f'Train Loss: {round(train_loss, 3)} Acc: {round(train_acc, 2)}')
    
    val_loss, val_acc = val_epoch()
    print(f'Val Loss: {round(val_loss, 3)} Acc: {round(val_acc, 2)}')

In [25]:
# Continue to epoch 99.
for i in range(10, 100):
    print(f'Epoch {i}')
    
    train_loss, train_acc = train_epoch()
    print(f'Train Loss: {round(train_loss, 3)} Acc: {round(train_acc, 2)}')
    
    val_loss, val_acc = val_epoch()
    print(f'Val Loss: {round(val_loss, 3)} Acc: {round(val_acc, 2)}')

In [111]:
# NOTE(review): the file name says epoch 10, but when run after the cells above
# the model has been trained for 100 epochs.
torch.save(model.state_dict(), f'weights-10.pth')

In [112]:
# Further training with a checkpoint every 10 epochs.
# NOTE(review): the epoch counter restarts at 10, so 'weights-10.pth' from the
# previous cell is overwritten at i == 10.
for i in range(10, 100):
    print(f'Epoch {i}')
    print('Train', train_epoch())
    print('Val', val_epoch())

    if i % 10 == 0:
        torch.save(model.state_dict(), f'weights-{i}.pth')

In [122]:
# Print indices of training items whose flux column (column 1 of the first
# spectrum) is not float32.
for i in tqdm(range(len(ds_train))):
    if ds_train[i]['spectra'][0][0][:, 1].dtype != np.float32:
        print(i)

In [None]:
# Find the first training item whose spectrum cannot be retrieved, and show why.
for i in tqdm(range(len(ds_train))):
    try:
        spectra = ds_train[i]['spectra'][0][0]
    except Exception as err:  # not a bare except: Ctrl-C / SystemExit still work
        print(i, repr(err))
        break

In [176]:
# Spectrum file name for one specific source.
spec_filename = ds_train.spec_df[ds_train.spec_df['edr3_source_id'] == 'EDR3 1314026659889734400']['spec_filename'].iloc[0]

In [185]:
# Collect the base names of spectrum files listed in spec_df but absent on disk.
remove_filenames = []

for el in ds_train.spec_df['spec_filename']:
    filename = (ds_train.data_root / ds_train.lamost_spec_dir / el)

    if not os.path.exists(filename):
        remove_filenames.append(str(filename).split('/')[-1])

In [187]:
# spec_df index labels of those missing files.
ds_train.spec_df[ds_train.spec_df['spec_filename'].isin(remove_filenames)].index

In [177]:
# NOTE(review): `self` and `row_spectra` are not defined at notebook level (this
# looks copied from the dataset class), so this cell raises NameError if the file exists.
filename = (ds_train.data_root / ds_train.lamost_spec_dir / spec_filename)

if os.path.exists(filename):
    row_spectra.append(self._readLRSFits(filename))

In [178]:
os.path.exists(filename)

In [94]:
# Compare five raw spectra with their versions resampled onto a regular grid.
fig, axs = plt.subplots(nrows=5, ncols=2, figsize=(15, 20))
# Loop-invariant grid: 3850..9000 in steps of 2 (2576 points).
# NOTE(review): SpectraVDataset uses np.arange(3850, 9000, 2) (2575 points) -- confirm which is intended.
regular_wavelengths = np.arange(3850, 9002, 2)

for i in range(5):
    # Fetch the item once, so spectrum and label come from the same read
    # (the original indexed ds_train[i] twice, loading the item twice).
    item = ds_train[i]
    spectra = item['spectra'][0][0]
    y = item['classes'][0][0]
    class_name = ds_train.target_lookup[y]

    wavelengths, fluxes = spectra[:, 0], spectra[:, 1]
    interpolated_fluxes = np.interp(regular_wavelengths, wavelengths, fluxes)

    axs[i, 0].plot(wavelengths, fluxes)
    axs[i, 0].set_title(f'Class {class_name} Before Interpolation')
    axs[i, 1].plot(regular_wavelengths, interpolated_fluxes)
    axs[i, 1].set_title(f'Class {class_name} After Interpolation')

plt.show()

In [28]:
# Wavelength range of one training spectrum.
# NOTE(review): `i` is whatever a previous loop left behind (4 when run right
# after the 5-panel plot); set it explicitly to make this cell reproducible.
ds_train[i]['spectra'][0][0][:, 0].min(), ds_train[i]['spectra'][0][0][:, 0].max()

In [30]:
# Class balance of the ASAS-SN training split.
ds_train.df['variable_type'].value_counts()