## Training a Masked Autoencoder

In [1]:
import numpy as np
import pandas as pd
from astropy.table import Table
from time import time
import h5py

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler

from gaiaxpy import generate, PhotometricSystem

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
plt.rcParams['text.usetex'] = True
plt.rcParams['font.size'] = 14
plt.rcParams['legend.fontsize'] = 14
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['xtick.major.size'] = 5.0
plt.rcParams['xtick.minor.size'] = 3.0
plt.rcParams['ytick.major.size'] = 5.0
plt.rcParams['ytick.minor.size'] = 3.0
plt.rcParams['xtick.top'] = True
plt.rcParams['ytick.right'] = True

Converting to GPU if available

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


Checking directory

In [3]:
%%bash
cd /scratch/
pwd

/scratch


In [4]:
# scalers for dataloading
metscaler = StandardScaler(); logscaler = StandardScaler(); tefscaler = StandardScaler()
# extscaler = StandardScaler(); parscaler = StandardScaler()
scale = 'standard_scale'

batchlen = 32
lr = 1e-4
epochs = 10
optimize = 'Adam'
datafname = "/arc/home/aydanmckay/mae_tab/lamost_pristine_bprp_gmag.h5"
datashort = 'ViT_MAE_v1'
lossname = 'L2'

In [5]:
# defining the Dataset class
class data_set(Dataset):
    '''
    Main way to access the .h5 file.
    '''
    def __init__(self,file,train=True,valid=False,test=False,noscale=False):
        fn = h5py.File(file, 'r')
        self.f = fn
        
        # get data
        if train:
            name = 'group_1'
        elif valid:
            name = 'group_2'
        elif test or noscale:
            name = 'group_3'
        
        dset = self.f[name]['theta']
        dl = dset[:]
        if noscale:
            self.l = dl.shape[1]
            self.t = torch.Tensor(dl.T)
        else:
            dat = np.array([
                metscaler.fit_transform(dl[[0]].T).flatten(),
                logscaler.fit_transform(dl[[1]].T).flatten(),
                tefscaler.fit_transform(dl[[2]].T).flatten(),
            ])
            self.l = dat.shape[1]
            self.x = torch.Tensor(dat.T)

        ydset = self.f[name]['bprp']
        ydat = ydset[:]
        self.y = torch.Tensor(ydat[:].T)

        errdset = self.f[name]['e_bprp']
        self.err = torch.Tensor(errdset[:].T)
        
        mdset = self.f[name]['mags']
        self.m = torch.Tensor(mdset[:].T)
        
        ddset = self.f[name]['dist']
        self.d = torch.Tensor(ddset[:].T)
        
        edset = self.f[name]['ext']
        self.e = torch.Tensor(edset[:].T)
        
    def __len__(self):
        return self.l
  
    def __getitem__(self, index):
        tg = self.t[index]
        yg = self.y[index]
        mg = self.m[index]
        errg = self.err[index]
        eg = self.e[index]
        dg = self.d[index]
        return (tg,yg,errg,mg,eg,dg)

In [7]:
# MAE from chatgpt
# switch encoder for above tab_vit
class MaskedAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, mask):
        super(MaskedAutoencoder, self).__init__()
        self.mask = torch.tensor(mask, dtype=torch.float32)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        
    def forward(self, x):
        masked_x = x * self.mask
        z = self.encoder(masked_x)
        x_hat = self.decoder(z)
        return x_hat, z


In [8]:
training_data = data_set(datafname)
valid_data = data_set(datafname,train=False,valid=True)

In [9]:
train_dataloader = DataLoader(
    training_data,
    batch_size=batchlen,
    shuffle=True,
    num_workers=0
)
valid_dataloader = DataLoader(
    valid_data,
    batch_size=batchlen,
    shuffle=True,
    num_workers=0
)

In [10]:
model = MAE()
model = model.to(device)

NameError: name 'MAE' is not defined

In [None]:
class Net(nn.Module):
    '''
    The autoencoder to be used to reduce the images of the galaxies down
    to a latent space, from which the original images are to be reconstructed.
    The improvements are as follows: using both batch normalization and maxpool
    layers in the encoder, having a better decoder output layer, in which the
    kernel size is not large, having a MLP in the model, so that the algorithm
    can train on both reconstruction error and prediction error simultaneously.
    '''
    def __init__(self):
        super(Net, self).__init__()
        # The encoder, reducing the images down from a shape of (batchsize, 3, 200, 200)
        # to a shape of (batchsize, 1024)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), # shape (batchsize, 16, 200, 200)
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, stride=2), # shape (batchsize, 16, 100, 100)
            nn.Conv2d(16, 32, kernel_size=3, padding=1), # shape (batchsize, 32, 100, 100)
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, stride=2), # shape (batchsize, 32, 50, 50)
            nn.Conv2d(32, 64, kernel_size=3, padding=1), # shape (batchsize, 64, 50, 50)
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, stride=2), # shape (batchsize, 64, 25, 25)
            nn.Conv2d(64, 128, kernel_size=3, padding=1), # shape (batchsize, 128, 25, 25)
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, stride=2), # shape (batchsize, 128, 12, 12)
            nn.Conv2d(128, 256, kernel_size=3, padding=1),# shape (batchsize, 256, 12, 12)
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2, stride=2), # shape (batchsize, 256, 6, 6)
            nn.Conv2d(256, 512, kernel_size=3, padding=1),# shape (batchsize, 512, 6, 6)
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.MaxPool2d(2, stride=2), # shape (batchsize, 512, 3, 3)
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),# shape (batchsize, 1024, 3, 3)
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(2, stride=2), # shape (batchsize, 1024, 1, 1)
            nn.Flatten(),
            nn.Linear(1024, 1024)
        )
        # the decoder, reconstucting the images to the original shape of (batchsize, 3, 424, 424)
        # from the latent vector of 1024
        self.decoder = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Unflatten(-1, (1024, 1, 1)),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, output_padding=1), # shape (batchsize, 512, 3, 3)
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), # shape (batchsize, 256, 6, 6)
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # shape (batchsize, 128, 12, 12)
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=1), # shape (batchsize, 64, 25, 25)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # shape (batchsize, 32, 50, 50)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), # shape (batchsize, 16, 100, 100)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1), # shape (batchsize, 3, 200, 200)
            # nn.Sigmoid(),
        )
        # the classifier, a simple 3-layer MLP to predict the morphology of the galaxies
        self.mlp = nn.Sequential(
            nn.Linear(1024, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            # sigmoid to make predictions in between 0 and 1
            nn.Sigmoid()
        )

    def forward(self, x):
        # the forward function, encoding then decoding the data
        # and predicting from the latent space
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        predicted = self.mlp(encoded)
        return (decoded,predicted)

In [None]:
# defining the hyperparameters
lr = 0.001
batch_size = 32
num_epochs = 10

# initializing the autoencoder
net = Net()
net.to(device)

# setting the optimization algorithm
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# defining the reconstruction loss function
recon_loss = nn.MSELoss()

# defining the prediction loss function (binary output)
pred_loss = nn.BCELoss()

# defining new training and validation functions
def newtrain(model, dataloader, recon_loss, pred_loss, optimizer):
    '''
    The training algorithm of the autoencoder. This algorithm now
    combines the losses of both reconstruction and prediction to
    train both simultaneously.
    '''
    model.train()
    train_loss = 0.0
    running_recon_loss = 0.0
    running_pred_loss = 0.0
    for batch_imgs,y in dataloader:
        y = y.to(device)
        optimizer.zero_grad()

        # predicting
        outputs,preds = model(batch_imgs)

        # computing the individual losses
        recon_error = recon_loss(outputs, batch_imgs)
        pred_error = pred_loss(preds,y.reshape(-1,1).float())

        # computing the total loss
        loss = recon_error + pred_error

        # backward pass and optimization
        loss.backward()
        optimizer.step()

        # updating the running loss
        running_recon_loss += recon_error.item() * batch_imgs.size(0)
        running_pred_loss += pred_error.item() * batch_imgs.size(0)
        train_loss += loss.item() * batch_imgs.size(0)

    # printing the two losses separately
    print("Epoch %d Reconstruction Loss: %.3f Prediction Loss: %.3f" % (epoch+1,running_recon_loss/len(dataloader.dataset),running_pred_loss/len(dataloader.dataset)))
    return train_loss / len(dataloader.dataset)

def newvalidate(model, dataloader, recon_loss, pred_loss):
    '''
    The validation algorithm for the autoencoder. Similar to the
    training algorithm function, the diffence is that backprop and
    optimization is not done during the evaluation step.
    '''
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_imgs,y in dataloader:
            y = y.to(device)

            # predicting
            outputs,preds = model(batch_imgs)

            # computing the losses
            recon_error = recon_loss(outputs, batch_imgs)
            pred_error = pred_loss(preds,y.reshape(-1,1).float())

            # computing the total loss
            loss = recon_error + pred_error
            val_loss += loss.item() * batch_imgs.size(0)
    return val_loss / len(dataloader.dataset)

In [None]:
# training the algorithm
ntlosses = []
nvlosses = []
for epoch in range(num_epochs):
    train_loss = newtrain(net, new_train_loader, recon_loss, pred_loss, optimizer)
    val_loss = newvalidate(net, new_val_loader, recon_loss, pred_loss)
    print(f"Epoch {epoch + 1}/{num_epochs}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
    ntlosses.append(train_loss)
    nvlosses.append(val_loss)

In [None]:
# plotting the loss
plt.plot(range(1,num_epochs+1),ntlosses,label='Train Loss')
plt.plot(range(1,num_epochs+1),nvlosses,label='Valid Loss')
plt.xlabel('Epoch')
plt.ylabel('Combined Loss')
plt.legend(fancybox=True)
# plt.savefig('/content/drive/MyDrive/lossplotnet.png')

In [None]:
# saving the new algorithm
# torch.save(net.state_dict(), '/content/drive/MyDrive/gpumodel2.pth')

In [None]:
# loading the new algorithm
net.load_state_dict(torch.load('/content/drive/MyDrive/gpumodel2.pth'))
net.eval()