## Training a Masked Autoencoder

In [1]:
import numpy as np
import pandas as pd
from astropy.table import Table
from time import time
import h5py

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler

from gaiaxpy import generate, PhotometricSystem

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
# Global matplotlib styling: LaTeX-rendered text, 14 pt fonts, and inward
# ticks that are also drawn on the top and right axes.
plt.rcParams.update({
    'text.usetex': True,
    'font.size': 14,
    'legend.fontsize': 14,
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.major.size': 5.0,
    'xtick.minor.size': 3.0,
    'ytick.major.size': 5.0,
    'ytick.minor.size': 3.0,
    'xtick.top': True,
    'ytick.right': True,
})

Use the GPU if one is available, otherwise fall back to the CPU

In [2]:
# Select the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


Checking that the scratch directory exists

In [3]:
%%bash
# Print the scratch directory to confirm it exists. The `cd` only affects the
# shell started for this %%bash cell, not the Python kernel's working directory.
cd /scratch/
pwd

/scratch


In [4]:
# Scalers used by data_set: one independent StandardScaler for each of the
# three rows of the 'theta' array (presumably metallicity, log g and Teff,
# judging by the names -- TODO confirm).
metscaler, logscaler, tefscaler = (StandardScaler() for _ in range(3))
# extscaler = StandardScaler(); parscaler = StandardScaler()
scale = 'standard_scale'

# Optimisation settings.
batchlen = 32
lr = 1e-4
epochs = 10
optimize = 'Adam'
lossname = 'L2'

# Input file and a short label for this run.
datafname = "/arc/home/aydanmckay/mae_tab/lamost_pristine_bprp_gmag.h5"
datashort = 'ViT_MAE_v1'

In [5]:
# defining the Dataset class
class data_set(Dataset):
    '''
    Main way to access the .h5 file.

    Reads one group of the HDF5 file fully into memory and serves samples
    as tuples of tensors. The file is closed again before __init__ returns.

    Parameters
    ----------
    file : path of the HDF5 file.
    train, valid, test, noscale : split selection, checked in this order:
        train -> 'group_1', valid -> 'group_2', test or noscale -> 'group_3'.
        Because `train` defaults to True, pass train=False to reach the
        other groups.
    noscale : when True the 'theta' rows are used as stored; otherwise rows
        0, 1 and 2 are standardised with the module-level metscaler,
        logscaler and tefscaler.

    Raises
    ------
    ValueError : if none of train, valid, test or noscale is set.

    Each item is (theta, bprp, e_bprp, mags, ext, dist) for one sample.
    '''
    def __init__(self,file,train=True,valid=False,test=False,noscale=False):
        # get data
        if train:
            name = 'group_1'
        elif valid:
            name = 'group_2'
        elif test or noscale:
            name = 'group_3'
        else:
            # Previously `name` was left unbound here and the failure surfaced
            # later as an unrelated UnboundLocalError.
            raise ValueError(
                "no split selected: set one of train, valid, test or noscale"
            )

        # Everything is copied into memory, so the handle can be released
        # immediately instead of staying open for the dataset's lifetime.
        with h5py.File(file, 'r') as fn:
            group = fn[name]
            dl = group['theta'][:]
            ydat = group['bprp'][:]
            errdat = group['e_bprp'][:]
            mdat = group['mags'][:]
            ddat = group['dist'][:]
            edat = group['ext'][:]

        if noscale:
            self.l = dl.shape[1]
            self.t = torch.Tensor(dl.T)
        else:
            # NOTE(review): fit_transform refits the shared scalers on every
            # split, so each split is standardised with its own statistics and
            # the scalers end up fitted to whichever split was built last.
            dat = np.array([
                metscaler.fit_transform(dl[[0]].T).flatten(),
                logscaler.fit_transform(dl[[1]].T).flatten(),
                tefscaler.fit_transform(dl[[2]].T).flatten(),
            ])
            self.l = dat.shape[1]
            # __getitem__ reads self.t; the scaled tensor used to be stored
            # only as self.x, which made indexing fail with AttributeError.
            self.t = torch.Tensor(dat.T)
            self.x = self.t  # kept as an alias for code that reads .x

        # The arrays are transposed so that axis 0 indexes samples.
        self.y = torch.Tensor(ydat.T)
        self.err = torch.Tensor(errdat.T)
        self.m = torch.Tensor(mdat.T)
        self.d = torch.Tensor(ddat.T)
        self.e = torch.Tensor(edat.T)

    def __len__(self):
        return self.l

    def __getitem__(self, index):
        tg = self.t[index]
        yg = self.y[index]
        mg = self.m[index]
        errg = self.err[index]
        eg = self.e[index]
        dg = self.d[index]
        return (tg,yg,errg,mg,eg,dg)

In [6]:
# chatgpt TabularViT
# this will be the encoder
from einops.layers.torch import Rearrange

# Define the TabularViT model
class TabularViT(nn.Module):
    """Transformer-encoder model intended as the encoder for tabular input.

    Pipeline as written: linear projection of the last axis from `input_dim`
    to `patch_dim`, transpose of axes 1 and 2, addition of learned position
    embeddings, a stack of `depth` TransformerEncoder layers with
    d_model=`dim`, a mean over axis 1, LayerNorm, and a final linear layer
    to `output_dim`.

    NOTE(review): the shapes do not obviously line up. The projection emits
    `patch_dim` features while the transformer and the position embeddings
    use `dim`; confirm the intended input layout before using this class.
    """

    def __init__(self, input_dim, output_dim, patch_dim=64, num_patches=16, dim=256, depth=6, heads=8, mlp_dim=512):
        super().__init__()
        self.patch_dim = patch_dim
        self.num_patches = num_patches
        self.to_patch_embedding = nn.Linear(input_dim, patch_dim)
        # Learned position table of shape (1, num_patches + 1, dim), initialised to zeros.
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        # batch_first is not set, so the encoder uses PyTorch's default input layout.
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim), num_layers=depth)
        self.layer_norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, output_dim)

    def forward(self, x):
        """Run the full pipeline on `x` and return the output of the final linear layer."""
        # Project the last axis to patch_dim, then swap axes 1 and 2.
        x = self.to_patch_embedding(x).transpose(1, 2)
        x = self._add_positional_encoding(x)
        x = self.transformer(x)
        # NOTE(review): _add_positional_encoding has already swapped axes 0 and 1,
        # so dim=1 here is the axis that was first on entry (presumably the
        # batch axis) -- confirm that this is the intended pooling axis.
        x = self.layer_norm(x.mean(dim=1))
        x = self.fc(x)
        return x

    def _add_positional_encoding(self, x):
        """Add the position embeddings to `x` and swap its first two axes."""
        b, n, _ = x.shape
        # NOTE(review): this slices up to n + 1 positions although x has n entries
        # on axis 1; the addition below only works if the shapes broadcast -- confirm.
        position_embeddings = self.position_embeddings[:, :(n + 1)]
        return (x + position_embeddings).permute(1, 0, 2)

# Define the training loop
def train(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over `train_loader`.

    Args:
        model: network to optimise; switched to training mode.
        train_loader: iterable of (data, target) batches exposing `.dataset`.
        optimizer: optimiser stepped once per batch.
        criterion: loss callable applied as criterion(output, target).
        device: device each batch is moved to before the forward pass.

    Returns:
        (train_loss, accuracy): the sum of the per-batch loss values divided
        by the number of samples in `train_loader.dataset`, and the percentage
        of samples whose argmax over dim 1 equals the target.

    NOTE(review): the per-batch losses are divided by the sample count, not
    the batch count, so with a mean-reducing criterion this is not a
    per-sample mean -- confirm which normalisation is wanted.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        top1 = logits.argmax(dim=1, keepdim=True)
        n_correct += (top1 == labels.view_as(top1)).sum().item()

    n_samples = len(train_loader.dataset)
    return loss_sum / n_samples, 100.0 * n_correct / n_samples

# Define the validation loop
def validate(model, val_loader, criterion, device):
    """Evaluate `model` on `val_loader` without updating its weights.

    Mirrors `train`: the same loss normalisation and accuracy definition are
    used so that the two sets of numbers are directly comparable.

    Args:
        model: network to evaluate; switched to evaluation mode.
        val_loader: iterable of (data, target) batches exposing `.dataset`.
        criterion: loss callable applied as criterion(output, target).
        device: device each batch is moved to before the forward pass.

    Returns:
        (val_loss, accuracy): the sum of the per-batch loss values divided by
        the number of samples in `val_loader.dataset`, and the percentage of
        samples whose argmax over dim 1 equals the target.
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    val_loss /= len(val_loader.dataset)
    accuracy = 100.0 * correct / len(val_loader.dataset)

    return val_loss, accuracy


In [7]:
# MAE from chatgpt
# switch encoder for above tab_vit
class MaskedAutoencoder(nn.Module):
    """Fully connected autoencoder that applies a fixed mask to its input.

    Args:
        input_dim: number of input (and reconstructed) features.
        hidden_dim: width of the hidden layers and of the latent code.
        mask: array-like or tensor multiplied element-wise with the input
            before encoding; it must broadcast against the input. It is
            stored as float32.

    forward(x) returns (x_hat, z): the reconstruction with `input_dim`
    features and the latent code with `hidden_dim` features.
    """

    def __init__(self, input_dim, hidden_dim, mask):
        super().__init__()
        if isinstance(mask, torch.Tensor):
            # torch.tensor(tensor) warns; take an explicit detached copy instead.
            mask_tensor = mask.detach().clone().to(torch.float32)
        else:
            mask_tensor = torch.tensor(mask, dtype=torch.float32)
        # Register as a buffer so model.to(device) moves the mask with the
        # weights; a plain attribute stays on the CPU and breaks forward() on
        # CUDA. persistent=False keeps state_dict keys as they were.
        self.register_buffer('mask', mask_tensor, persistent=False)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x):
        # Entries where the mask is 0 are zeroed and never reach the encoder.
        masked_x = x * self.mask
        z = self.encoder(masked_x)
        x_hat = self.decoder(z)
        return x_hat, z


In [8]:
# Build the training split ('group_1') and the validation split ('group_2')
# from the same HDF5 file.
training_data = data_set(file=datafname, train=True)
valid_data = data_set(file=datafname, train=False, valid=True)

In [9]:
# Batch iterators over both splits: shuffled batches of `batchlen` samples,
# loaded in the main process (num_workers=0).
train_dataloader = DataLoader(
    dataset=training_data, batch_size=batchlen, shuffle=True, num_workers=0
)
valid_dataloader = DataLoader(
    dataset=valid_data, batch_size=batchlen, shuffle=True, num_workers=0
)

In [10]:
# NOTE(review): this cell fails when the notebook is run top to bottom. `MAE`
# is only defined in a later cell, and that class requires a `vit` argument,
# so the call would still fail with no arguments once the name exists.
model = MAE()
model = model.to(device)

NameError: name 'MAE' is not defined

In [None]:
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import torchvision

from lightly.data import LightlyDataset
from lightly.data.multi_view_collate import MultiViewCollate
from lightly.models import utils
from lightly.models.modules import masked_autoencoder
from lightly.transforms.mae_transform import MAETransform


class MAE(nn.Module):
    """Masked autoencoder assembled from a ViT and lightly's MAE modules.

    The backbone is created from `vit` with MAEBackbone.from_vit and paired
    with a one-layer MAEDecoder of width 512. On each forward pass a random
    75% of the token positions are masked; the model predicts the patches
    at the masked positions.

    Args:
        vit: a vision transformer exposing `patch_size`, `seq_length` and
            `hidden_dim`.
    """

    def __init__(self, vit):
        super().__init__()

        decoder_width = 512
        self.mask_ratio = 0.75
        self.patch_size = vit.patch_size
        self.sequence_length = vit.seq_length
        # Learnable placeholder inserted at every masked position.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_width))
        self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit)
        self.decoder = masked_autoencoder.MAEDecoder(
            seq_length=vit.seq_length,
            num_layers=1,
            num_heads=16,
            embed_input_dim=vit.hidden_dim,
            hidden_dim=decoder_width,
            mlp_dim=4 * decoder_width,
            out_dim=3 * vit.patch_size**2,
            dropout=0,
            attention_dropout=0,
        )

    def forward_encoder(self, images, idx_keep=None):
        """Encode `images` with the backbone, passing `idx_keep` through to it."""
        return self.backbone.encode(images, idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        """Decode the encoded tokens and return predictions at `idx_mask`."""
        n_batch = x_encoded.shape[0]
        kept_tokens = self.decoder.embed(x_encoded)

        # Start from a sequence made only of mask tokens, then write the
        # embedded encoder outputs back in at the kept positions.
        full_sequence = utils.repeat_token(
            self.mask_token, (n_batch, self.sequence_length)
        )
        full_sequence = utils.set_at_index(
            full_sequence, idx_keep, kept_tokens.type_as(full_sequence)
        )

        decoded = self.decoder.decode(full_sequence)

        # Only the masked positions are turned into pixel predictions.
        return self.decoder.predict(utils.get_at_index(decoded, idx_mask))

    def forward(self, images):
        """Return (predictions, targets) for a freshly drawn random mask."""
        idx_keep, idx_mask = utils.random_token_mask(
            size=(images.shape[0], self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_pred = self.forward_decoder(
            self.forward_encoder(images, idx_keep), idx_keep, idx_mask
        )

        # The patches carry no class token, so the mask indices are shifted
        # down by one to address them.
        patches = utils.patchify(images, self.patch_size)
        target = utils.get_at_index(patches, idx_mask - 1)
        return x_pred, target


# ViT-B/32 without pretrained weights, wrapped in the MAE model.
# NOTE(review): `model` and `device` rebind names assigned in earlier cells;
# `device` becomes a plain string here rather than a torch.device.
vit = torchvision.models.vit_b_32(pretrained=False)
model = MAE(vit)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# we ignore object detection annotations by setting target_transform to return 0
pascal_voc = torchvision.datasets.VOCDetection(
    root="datasets/pascal_voc", download=True, target_transform=lambda t: 0
)
transform = MAETransform()
dataset = LightlyDataset.from_torch_dataset(pascal_voc, transform=transform)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

collate_fn = MultiViewCollate()

# Shuffled batches of 32; the incomplete final batch is dropped.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=8,
    collate_fn=collate_fn,
    drop_last=True,
)

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1.5e-4)

In [None]:
# Peek at the first batch only. Print the shape of each view rather than the
# raw tensors, which would flood the cell output with pixel values.
for images, _, _ in dataloader:
    print([view.shape for view in images])
    break

In [None]:
# Training loop: 10 passes over `dataloader`, printing the mean per-batch
# loss after each epoch.
print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for images, _, _ in dataloader:
        images = images[0].to(device)  # images is a list containing only one view
        # The model returns its predictions together with the matching targets.
        predictions, targets = model(images)
        loss = criterion(predictions, targets)
        # Detach so the running total does not keep each batch's graph alive.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        # Clear the gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()
    # len(dataloader) is the number of batches, so this is the mean batch loss.
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")