## Training a Masked Autoencoder

In [1]:
from astropy.table import Table
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
plt.rcParams['text.usetex'] = True
plt.rcParams['font.size'] = 14
plt.rcParams['legend.fontsize'] = 14
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['xtick.major.size'] = 5.0
plt.rcParams['xtick.minor.size'] = 3.0
plt.rcParams['ytick.major.size'] = 5.0
plt.rcParams['ytick.minor.size'] = 3.0
plt.rcParams['xtick.top'] = True
plt.rcParams['ytick.right'] = True
from time import time
from sklearn.preprocessing import StandardScaler
import pandas as pd
from gaiaxpy import generate, PhotometricSystem

Select the GPU if one is available; otherwise fall back to the CPU

In [2]:
# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


Checking directory

In [3]:
%%bash
# Change into /scratch/ and print the resulting working directory.
cd /scratch/
pwd

/scratch


In [4]:
# defining the Dataset class
class data_set(Dataset):
    '''
    In-memory view of one split of the .h5 file.

    The constructor reads four datasets from a single group of the file
    ('data', 'label', 'e_label' and 'gmag'), transposes each one and stores
    it as a torch.Tensor, so that indexing the object yields one sample.

    Split selection (the flags are checked in this order, first True wins):
        train   -> 'group_1'; rows 0..3 of 'data' are passed through
                   ``fit_transform`` of the module-level scalers
                   (metscaler, logscaler, tefscaler, amscaler), i.e. the
                   scalers are re-fitted every time a training split is built.
        valid   -> 'group_2'; the same rows go through ``transform``.
        test    -> 'group_3'; the same rows go through ``transform``.
        noscale -> 'group_3'; 'data' is used as stored, without scaling.

    NOTE: ``train`` defaults to True and is tested first, so callers must pass
    ``train=False`` explicitly to obtain any other split.

    Parameters:
        file: path (or anything h5py.File accepts) of the .h5 file.
        train, valid, test, noscale: split selectors, see above.

    Raises:
        ValueError: if none of the four flags is True. (Previously this
            produced a half-initialised object that failed later in
            ``__len__`` / ``__getitem__``.)

    Each item is the tuple ``(x, y, gmag, e_label)`` for one sample.
    '''
    def __init__(self,file,train=True,valid=False,test=False,noscale=False):
        # Resolve the split before touching the file so that a bad flag
        # combination fails fast with a clear message.
        if train:
            group_name, scaling = 'group_1', 'fit'
        elif valid:
            group_name, scaling = 'group_2', 'apply'
        elif test:
            group_name, scaling = 'group_3', 'apply'
        elif noscale:
            group_name, scaling = 'group_3', None
        else:
            raise ValueError(
                'data_set needs one of train, valid, test or noscale to be True'
            )

        # NOTE(review): the handle is kept on the instance (and therefore stays
        # open) only because the attribute already existed; everything below is
        # copied into memory, so nothing in this class needs it afterwards.
        self.f = h5py.File(file, 'r')
        group = self.f[group_name]

        # get data: stored as (n_features, n_samples), exposed as (n_samples, n_features)
        raw = group['data'][:]
        if scaling is None:
            features = raw
        else:
            features = self._scale_rows(raw, fit=(scaling == 'fit'))
        self.l = features.shape[1]
        self.x = torch.Tensor(features.T)

        # get label
        # torch.Tensor(...) is used rather than torch.from_numpy(...) because the
        # stored labels are doubles and float32 tensors are wanted.
        self.y = torch.Tensor(group['label'][:].T)

        # get error in label # drop for non-error label runs
        self.err = torch.Tensor(group['e_label'][:].T)

        # G magnitude
        self.g = torch.Tensor(group['gmag'][:].T)

    @staticmethod
    def _scale_rows(raw, fit):
        '''Scale rows 0..3 of ``raw`` with the module-level scalers and stack them.'''
        # Row i of the data array is paired with the i-th scaler; remove the
        # last entry for runs that do not use the fourth input row.
        scalers = (metscaler, logscaler, tefscaler, amscaler)
        rows = []
        for row, scaler in enumerate(scalers):
            column = raw[[row]].T  # scalers expect a 2-D (n_samples, 1) array
            scaled = scaler.fit_transform(column) if fit else scaler.transform(column)
            rows.append(scaled.flatten())
        return np.array(rows)

    def __len__(self):
        '''Number of samples in the selected split.'''
        return self.l

    def __getitem__(self, index):
        '''Return ``(x, y, gmag, e_label)`` for the sample(s) at ``index``.'''
        xg = self.x[index]
        yg = self.y[index]
        gg = self.g[index]
        errg = self.err[index]
        return (xg,yg,gg,errg)

In [6]:
# chatgpt TabularViT
# this will be the encoder
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from einops.layers.torch import Rearrange

# Define the TabularViT model
class TabularViT(nn.Module):
    """Transformer encoder over learned "patches" of a flat feature vector.

    A row of tabular features has no spatial patches, so the model learns them:
    one linear layer maps the ``input_dim`` features to ``num_patches`` vectors
    of size ``patch_dim``, a second one lifts every patch to the transformer
    width ``dim``. Learned positional embeddings are added, the tokens go
    through ``depth`` transformer encoder layers, are mean-pooled over the
    token axis, layer-normalised and projected to ``output_dim``.

    Args:
        input_dim: number of input features per sample.
        output_dim: size of the output vector per sample.
        patch_dim: size of each learned patch before it is lifted to ``dim``.
        num_patches: number of tokens fed to the transformer.
        dim: transformer model width (``d_model``).
        depth: number of transformer encoder layers.
        heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward sub-layer.

    Shapes:
        forward input  ``(batch, input_dim)``
        forward output ``(batch, output_dim)``
    """

    def __init__(self, input_dim, output_dim, patch_dim=64, num_patches=16, dim=256, depth=6, heads=8, mlp_dim=512):
        super().__init__()
        self.patch_dim = patch_dim
        self.num_patches = num_patches
        # All patches are produced by a single linear map and split in forward().
        self.to_patch_embedding = nn.Linear(input_dim, num_patches * patch_dim)
        # The transformer and the positional embeddings work at width `dim`,
        # which in general differs from `patch_dim`.
        self.patch_to_token = nn.Linear(patch_dim, dim)
        # One slot more than there are patches (room for an optional class
        # token); only the first `num_patches` rows are used by this model.
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim), num_layers=depth)
        self.layer_norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, output_dim)

    def forward(self, x):
        """Encode a batch of feature rows ``(batch, input_dim)`` to ``(batch, output_dim)``."""
        batch = x.shape[0]
        patches = self.to_patch_embedding(x).reshape(batch, self.num_patches, self.patch_dim)
        tokens = self.patch_to_token(patches)            # (batch, num_patches, dim)
        tokens = self._add_positional_encoding(tokens)   # (num_patches, batch, dim)
        tokens = self.transformer(tokens)
        # The encoder layers use the default sequence-first layout, so the
        # token axis is 0; pooling over axis 1 would average across samples.
        pooled = self.layer_norm(tokens.mean(dim=0))     # (batch, dim)
        return self.fc(pooled)

    def _add_positional_encoding(self, x):
        """Add positional embeddings to ``(batch, n, dim)`` and return ``(n, batch, dim)``."""
        _, n, _ = x.shape
        position_embeddings = self.position_embeddings[:, :n]
        return (x + position_embeddings).permute(1, 0, 2)

# Define the TabularDataset class to load the data
class TabularDataset(Dataset):
    """Pair an indexable feature container ``X`` with its targets ``y``."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # The dataset is as long as the feature container.
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target

# Load the data and split into train and validation sets
# data = pd.read_csv('data.csv')
# X = data.drop(columns=['target'])
# y = data['target']
# X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# # Create the data loaders
# train_dataset = TabularDataset(X_train.values, y_train.values)
# val_dataset = TabularDataset(X_val.values, y_val.values)
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Define the training loop
def train(model, train_loader, optimizer, criterion, device):
    """Run one training epoch and report the mean loss and accuracy.

    Args:
        model: module to optimise; it is switched to training mode.
        train_loader: iterable yielding ``(data, target)`` batches of tensors.
        optimizer: optimiser stepping the parameters of ``model``.
        criterion: loss callable ``criterion(output, target)`` returning a
            scalar. If it exposes ``reduction == 'sum'`` the batch losses are
            accumulated as-is; otherwise the value is assumed to be a per-batch
            mean and is weighted by the batch size.
        device: device the batches are moved to.

    Returns:
        ``(train_loss, accuracy)``: the per-sample mean loss and the percentage
        of samples whose ``argmax`` over dim 1 equals the target. Both are
        ``0.0`` when the loader yields no samples.
    """
    model.train()
    # A batch-mean loss must be multiplied by the batch size before summing;
    # otherwise dividing by the sample count under-reports the loss by roughly
    # a factor of the batch size.
    loss_is_batch_mean = getattr(criterion, 'reduction', 'mean') != 'sum'
    train_loss = 0.0
    correct = 0
    seen = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        batch_size = data.size(0)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        train_loss += batch_loss * batch_size if loss_is_batch_mean else batch_loss
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        seen += batch_size

    if seen == 0:
        return 0.0, 0.0

    # Normalise by the samples actually processed, which also stays correct
    # when the loader drops the last incomplete batch.
    train_loss /= seen
    accuracy = 100.0 * correct / seen

    return train_loss, accuracy

# Define the validation loop
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating its parameters.

    Args:
        model: module to evaluate; it is switched to evaluation mode.
        val_loader: iterable yielding ``(data, target)`` batches of tensors.
        criterion: loss callable ``criterion(output, target)`` returning a
            scalar. If it exposes ``reduction == 'sum'`` the batch losses are
            accumulated as-is; otherwise the value is assumed to be a per-batch
            mean and is weighted by the batch size.
        device: device the batches are moved to.

    Returns:
        ``(val_loss, accuracy)``: the per-sample mean loss and the percentage
        of samples whose ``argmax`` over dim 1 equals the target. Both are
        ``0.0`` when the loader yields no samples.
    """
    model.eval()
    loss_is_batch_mean = getattr(criterion, 'reduction', 'mean') != 'sum'
    val_loss = 0.0
    correct = 0
    seen = 0
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            batch_size = data.size(0)
            output = model(data)
            batch_loss = criterion(output, target).item()
            val_loss += batch_loss * batch_size if loss_is_batch_mean else batch_loss
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            seen += batch_size

    if seen == 0:
        return 0.0, 0.0

    val_loss /= seen
    accuracy = 100.0 * correct / seen

    return val_loss, accuracy


In [None]:
# MAE from chatgpt
# switch encoder for above tab_vit
class MaskedAutoencoder(nn.Module):
    """MLP autoencoder that reconstructs its input from a masked copy.

    The input is multiplied element-wise by a fixed ``mask`` before encoding,
    so features whose mask entry is 0 are hidden from the encoder while the
    decoder still outputs all ``input_dim`` features.

    Args:
        input_dim: number of features per sample.
        hidden_dim: width of the hidden layers and of the latent code.
        mask: array-like (or tensor) that broadcasts against the input; it is
            converted to a float32 tensor.
    """

    def __init__(self, input_dim, hidden_dim, mask):
        super().__init__()
        # Registered as a buffer so that `.to(device)` / `.cuda()` move the
        # mask together with the weights; a plain tensor attribute stays on its
        # original device and breaks `x * self.mask` once the model is on the
        # GPU. `persistent=False` keeps it out of the state_dict, so existing
        # checkpoints still load.
        mask_tensor = torch.as_tensor(mask, dtype=torch.float32).detach().clone()
        self.register_buffer('mask', mask_tensor, persistent=False)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x):
        """Return ``(x_hat, z)``: the reconstruction and the latent code of the masked input."""
        masked_x = x * self.mask
        z = self.encoder(masked_x)
        x_hat = self.decoder(z)
        return x_hat, z
